Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 6 additions & 20 deletions test/test_backbone_utils.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,11 @@
import unittest


import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

import pytest

class ResnetFPNBackboneTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dtype = torch.float32

def test_resnet18_fpn_backbone(self):
device = torch.device('cpu')
x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device)
resnet18_fpn = resnet_fpn_backbone(backbone_name='resnet18', pretrained=False)
y = resnet18_fpn(x)
self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool'])

def test_resnet50_fpn_backbone(self):
device = torch.device('cpu')
x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device)
resnet50_fpn = resnet_fpn_backbone(backbone_name='resnet50', pretrained=False)
y = resnet50_fpn(x)
self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool'])
@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50'))
def test_resnet_fpn_backbone(backbone_name):
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu')
y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
assert list(y.keys()) == ['0', '1', '2', '3', 'pool']