# In[None]:
from loss import *
from dataset import *
from utils import *
from model import *
from viz import *
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

# YOLOv1 output layout: an S x S grid where every cell predicts C class scores
# followed by B boxes of 5 values each. These constants are the single source
# of truth for the model, the loss and the decoding code, so they must not be
# duplicated as literals anywhere below.
C = 20  # number of object classes (must equal len(name2id))
B = 2   # boxes predicted per grid cell
S = 7   # grid cells per image side
prob_threshold = 0.5  # minimum box confidence for a prediction to count as a detection
iou_threshold = 0.5   # overlap threshold handed to non_max_suppression

# Build the network and loss from the shared constants rather than from
# repeated magic numbers, so changing S/B/C keeps everything consistent.
net = Yolov1(split_size=S, num_boxes=B, num_classes=C).cuda()
optimizer = optim.Adam(net.parameters(), lr=2e-5)
loss_fn = YoloLoss(S=S, B=B, C=C)

# Evaluation preprocessing: resize + normalize only (horizontal flip is left
# disabled). format='albumentations' is normalized pascal_voc.
transform = A.Compose(
    [
        A.Resize(448, 448),
        # A.HorizontalFlip(),
        A.Normalize(),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format='albumentations'),
)
test_dataset = VOCDataset(root='data', train=False, transform=transform)
# shuffle=False keeps the iteration order deterministic.
dataloader = DataLoader(dataset=test_dataset, batch_size=4, shuffle=False)

# Class name -> integer class id.
name2id = {
    'aeroplane': 0, 'bicycle': 1, 'bird': 2, 'boat': 3, 'bottle': 4,
    'bus': 5, 'car': 6, 'cat': 7, 'chair': 8, 'cow': 9,
    'diningtable': 10, 'dog': 11, 'horse': 12, 'motorbike': 13, 'person': 14,
    'pottedplant': 15, 'sheep': 16, 'sofa': 17, 'train': 18, 'tvmonitor': 19,
}

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
net.load_state_dict(torch.load('model.pt'))
net.eval()
# Presumably here so the notebook cell does not echo the module repr returned
# by net.eval() -- TODO confirm.
_ = 1

# In[None]:
# Run the network over the test set, decode the raw grid output into boxes,
# and keep the per-class, per-image detections that survive NMS.

# Flat cell index 0..S*S-1; used to recover each selected cell's grid position.
indices = torch.arange(S * S).cuda()
# {class_id: {image_id: detections kept after NMS, rows of [p, x, y, w, h]}}
results = {class_id: {} for class_id in name2id.values()}
loop = tqdm(dataloader)

# The ground-truth targets yielded by the loader are not used for inference,
# so they are neither unpacked into names that the prediction tensors would
# shadow nor copied to the GPU.
for img, _targets, train_index in loop:
    N = img.shape[0]  # batch size
    img = img.cuda()
    with torch.no_grad():
        out = net(img)

    predictions = out.reshape(-1, S * S, C + 5 * B)  # (N, S*S, C+5B)
    # Everything after the C class scores is B boxes of 5 values each.
    pred_boxes = predictions[..., C:].reshape(-1, S * S, B, 5)  # (N, S*S, B, 5)

    # A box is a detection when its confidence exceeds the threshold, all five
    # of its values are positive, and its cell-relative x/y are below 1.
    N_box_det = (
        (pred_boxes[..., 0] > prob_threshold)
        & (pred_boxes > 0).all(-1)
        & (pred_boxes[..., 1:3] < 1).all(-1)
    )  # (N, S*S, B)
    # A cell is kept when at least one of its B boxes is a detection.
    N_Iobj = N_box_det.any(-1)  # (N, S*S)

    for n in range(N):  # for each image in the batch
        cell_mask = N_Iobj[n]  # (S*S,)

        # Restrict everything to the I cells that hold a detection. Boolean
        # indexing copies, so the in-place edits below never touch `out`.
        box_det = N_box_det[n][cell_mask]                # (I, B)
        box = pred_boxes[n][cell_mask]                   # (I, B, 5)
        box_indices = indices[cell_mask]                 # (I,)
        class_probs = predictions[n, :, :C][cell_mask]   # (I, C)

        # Best class and its score per cell in one call; this avoids building
        # a CPU arange index to pair with the CUDA argmax result.
        label_probs, label = class_probs.max(-1)         # (I,), (I,)

        for c in torch.unique(label):  # each class predicted in this image
            same_class = label == c
            bbnd = box[same_class]  # (I_c, B, 5)
            indx = box_indices[same_class]
            # Assumes cells are flattened row-major (index = row * S + col)
            # -- TODO confirm against the model's output layout.
            ii, jj = indx // S, indx % S
            # Shift cell-relative x/y by the cell's column/row offset.
            bbnd[:, :, 1] += jj[:, None]
            bbnd[:, :, 2] += ii[:, None]
            # Scale box confidence by the winning class score.
            bbnd[..., 0] *= label_probs[same_class][:, None]
            # Drop the boxes of kept cells that did not pass detection.
            bbnd = bbnd[box_det[same_class]]  # (K, 5)
            ind_after_nms = non_max_suppression(bboxes=bbnd, iou_threshold=iou_threshold)
            results[c.item()][train_index[n].item()] = bbnd[ind_after_nms]