diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index afdbc46a64b..19534f193ed 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -318,16 +318,19 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, Example:: >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) - >>> images,boxes,labels = torch.rand(4,3,600,1200), torch.rand(4,11,4), torch.rand(4,11) # For Training + >>> # For training + >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4) + >>> labels = torch.randint(1, 91, (4, 11)) >>> images = list(image for image in images) - >>> targets = [] + >>> targets = [] >>> for i in range(len(images)): >>> d = {} >>> d['boxes'] = boxes[i] - >>> d['labels'] = labels[i].type(torch.int64) + >>> d['labels'] = labels[i] >>> targets.append(d) - >>> output = model(images,targets) - >>> model.eval() # For inference + >>> output = model(images, targets) + >>> # For inference + >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x)