In [100]:
import os
import cv2
import torch
import numpy as np
from torch import nn
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms

from VOC_dataset import VOCDataset
from config import DefaultConfig
from fcos import FCOS, DetectHead 
from loss import coords_fmap2orig

In [101]:
# Collect the test images in a deterministic (sorted) order: os.listdir returns
# entries in arbitrary order, which would make `image_list[idx]` pick a
# different file on different machines/runs.
root_dir = './test_images/'
image_list = sorted(
    os.path.join(root_dir, fn) for fn in os.listdir(root_dir) if fn.endswith('.jpg')
)

In [102]:
config = DefaultConfig


class FCOSDetector(nn.Module):
    """Thin wrapper that holds the FCOS network under the ``fcos_body`` attribute.

    The attribute name matters: it becomes the prefix of every state_dict key,
    so it must match the keys stored in the checkpoint that is loaded later.

    Args:
        mode: stored on the instance; not read anywhere in this class.
        config: forwarded unchanged to ``FCOS(config=...)``.
    """

    def __init__(self, mode="training", config=None):
        super().__init__()
        self.mode = mode
        self.fcos_body = FCOS(config=config)

    def forward(self, inputs):
        """Run the FCOS body on a batch of images and return its raw outputs."""
        return self.fcos_body(inputs)
    
CKPT_PATH = "./models/voc2012_512x800_epoch100_loss0.6055.pth"

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source.
ckpt = torch.load(CKPT_PATH, map_location=torch.device('cpu'))

# Pass the config defined above explicitly instead of relying on FCOS's
# handling of config=None.
model = FCOSDetector(mode="inference", config=config)
model.load_state_dict(ckpt)
model.eval();  # trailing ';' suppresses the module repr in the notebook

INFO===>success frozen BN
INFO===>success frozen backbone stage1



In [110]:
# Preprocessing pipeline: resize (int size -> shorter edge becomes 512 px,
# aspect ratio kept), convert to a [0, 1] float tensor, then normalise with the
# ImageNet mean/std.
# NOTE(review): the checkpoint name suggests training at 512x800; Resize(512)
# does not cap the longer side at 800 -- confirm this matches training.
resize = transforms.Resize(512)
to_tensor = transforms.ToTensor()
to_normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

idx = 0  # which entry of `image_list` to run the detector on
image = Image.open(image_list[idx]).convert('RGB')
tensor_img = to_normalize(to_tensor(resize(image))).unsqueeze(0)  # add batch dim

# Inference only: no autograd graph needed.
with torch.no_grad():
    out = model(tensor_img)

### DetectHead

In [105]:
def _reshape_cat_out(inputs,strides):
    batch_size=inputs[0].shape[0]
    c=inputs[0].shape[1]
    out=[]
    coords=[]
    for pred,stride in zip(inputs,strides):
        pred=pred.permute(0,2,3,1)
        coord=coords_fmap2orig(pred,stride).to(device=pred.device)
        pred=torch.reshape(pred,[batch_size,-1,c])
        out.append(pred)
        coords.append(coord)
    return torch.cat(out,dim=1),torch.cat(coords,dim=0)

In [65]:
def _coords2boxes(coords, offsets):
    x1y1=coords[None,:,:]-offsets[...,:2]
    x2y2=coords[None,:,:]+offsets[...,2:] #[batch_size,sum(_h*_w),2]
    boxes=torch.cat([x1y1,x2y2],dim=-1) #[batch_size,sum(_h*_w),4]
    return boxes

In [None]:
def _post_process(self,preds_topk):
    _cls_scores_post=[]
    _cls_classes_post=[]
    _boxes_post=[]
    cls_scores_topk,cls_classes_topk,boxes_topk=preds_topk
    
    for batch in range(cls_classes_topk.shape[0]):
        mask=cls_scores_topk[batch]>=score_threshold
        _cls_scores_b=cls_scores_topk[batch][mask]
        _cls_classes_b=cls_classes_topk[batch][mask]
        _boxes_b=boxes_topk[batch][mask]
        nms_ind=self.batched_nms(_boxes_b,_cls_scores_b,_cls_classes_b,self.nms_iou_threshold)
        _cls_scores_post.append(_cls_scores_b[nms_ind])
        _cls_classes_post.append(_cls_classes_b[nms_ind])
        _boxes_post.append(_boxes_b[nms_ind])
    scores,classes,boxes= torch.stack(_cls_scores_post,dim=0),torch.stack(_cls_classes_post,dim=0),torch.stack(_boxes_post,dim=0)

    return scores,classes,boxes

In [82]:
 def box_nms(boxes,scores,thr):
        if boxes.shape[0]==0:
            return torch.zeros(0,device=boxes.device).long()
        assert boxes.shape[-1]==4
        x1,y1,x2,y2=boxes[:,0],boxes[:,1],boxes[:,2],boxes[:,3]
        areas=(x2-x1+1)*(y2-y1+1)
        order=scores.sort(0,descending=True)[1]
        keep=[]
        while order.numel()>0:
            if order.numel()==1:
                i=order.item()
                keep.append(i)
                break
            else:
                i=order[0].item()
                keep.append(i)
            
            xmin=x1[order[1:]].clamp(min=float(x1[i]))
            ymin=y1[order[1:]].clamp(min=float(y1[i]))
            xmax=x2[order[1:]].clamp(max=float(x2[i]))
            ymax=y2[order[1:]].clamp(max=float(y2[i]))
            inter=(xmax-xmin).clamp(min=0)*(ymax-ymin).clamp(min=0)
            iou=inter/(areas[i]+areas[order[1:]]-inter)
            idx=(iou<=thr).nonzero().squeeze()
            if idx.numel()==0:
                break
            order=order[idx+1]
        return torch.LongTensor(keep)

In [83]:
def batched_nms(boxes, scores, idxs, iou_threshold):
    """Class-aware NMS: boxes only suppress boxes with the same ``idxs`` value.

    Args:
        boxes: [N, 4] tensor of (x1, y1, x2, y2) boxes.
        scores: [N] tensor of scores.
        idxs: [N] tensor of class ids.
        iou_threshold: IoU threshold forwarded to ``box_nms``.

    Returns:
        int64 tensor of kept indices into ``boxes``.
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # Strategy: shift every class into its own disjoint coordinate region, then
    # run a single plain NMS. The shift must exceed the full coordinate SPAN
    # (max - min), not just max: boxes here are not clamped yet and may have
    # negative coordinates, in which case `max + 1` lets classes overlap.
    span = boxes.max() - boxes.min() + 1
    offsets = idxs.to(boxes) * span
    boxes_for_nms = boxes + offsets[:, None]
    keep = box_nms(boxes_for_nms, scores, iou_threshold)
    return keep

In [87]:
def _post_process(preds_topk):
    _cls_scores_post=[]
    _cls_classes_post=[]
    _boxes_post=[]
    cls_scores_topk,cls_classes_topk,boxes_topk=preds_topk
    
    for batch in range(cls_classes_topk.shape[0]):
        mask=cls_scores_topk[batch]>=score_threshold
        _cls_scores_b=cls_scores_topk[batch][mask]#[?]
        _cls_classes_b=cls_classes_topk[batch][mask]#[?]
        _boxes_b=boxes_topk[batch][mask]#[?,4]
        nms_ind=batched_nms(_boxes_b,_cls_scores_b,_cls_classes_b,nms_iou_threshold)
        
        _cls_scores_post.append(_cls_scores_b[nms_ind])
        _cls_classes_post.append(_cls_classes_b[nms_ind])
        _boxes_post.append(_boxes_b[nms_ind])
        
    scores,classes,boxes= torch.stack(_cls_scores_post,dim=0),torch.stack(_cls_classes_post,dim=0),torch.stack(_boxes_post,dim=0)
    return scores,classes,boxes

In [112]:
# Inlined detection-head post-processing, replacing
# `DetectHead(config.score_threshold, config.nms_iou_threshold,
#             config.max_detection_boxes_num, config.strides, config)`
# so every intermediate tensor can be inspected.
score_threshold = config.score_threshold
nms_iou_threshold = config.nms_iou_threshold
max_detection_boxes_num = config.max_detection_boxes_num
strides = config.strides

inputs = out

# Flatten all levels; `coords` holds the image-space location of each position.
cls_logits, coords = _reshape_cat_out(inputs[0], strides)  # [batch_size,sum(_h*_w),class_num]
cnt_logits, _ = _reshape_cat_out(inputs[1], strides)  # [batch_size,sum(_h*_w),1]
reg_preds, _ = _reshape_cat_out(inputs[2], strides)  # [batch_size,sum(_h*_w),4]

# Out-of-place sigmoid: the raw logits stay intact for inspection.
cls_preds = cls_logits.sigmoid()
cnt_preds = cnt_logits.sigmoid()

cls_scores, cls_classes = cls_preds.max(dim=-1)  # [batch_size,sum(_h*_w)]
cls_scores = cls_scores * cnt_preds.squeeze(dim=-1)  # [batch_size,sum(_h*_w)]
# 1-based class ids (presumably index 0 of CLASSES_NAME is background -- confirm).
cls_classes = cls_classes + 1

boxes = _coords2boxes(coords, reg_preds)

# Keep the top-k candidates per image; gather replaces the per-image Python loop.
max_num = min(max_detection_boxes_num, cls_scores.shape[-1])
topk_ind = torch.topk(cls_scores, max_num, dim=-1, largest=True, sorted=True)[1]  # [batch_size,max_num]
cls_scores_topk = cls_scores.gather(1, topk_ind)  # [batch_size,max_num]
cls_classes_topk = cls_classes.gather(1, topk_ind)  # [batch_size,max_num]
boxes_topk = boxes.gather(1, topk_ind.unsqueeze(-1).expand(-1, -1, 4))  # [batch_size,max_num,4]

scores, classes, boxes = _post_process([cls_scores_topk, cls_classes_topk, boxes_topk])

# Clip boxes to the network-input image. `boxes` itself is clipped (the drawing
# cell reads it); `batch_boxes` is kept as an alias of the same tensor.
h, w = tensor_img.shape[2:]
boxes = boxes.clamp(min=0)
boxes[..., [0, 2]] = boxes[..., [0, 2]].clamp(max=w - 1)
boxes[..., [1, 3]] = boxes[..., [1, 3]].clamp(max=h - 1)
batch_boxes = boxes

In [113]:
# Draw the detections on the image at network-input resolution, the coordinate
# space the boxes were clipped to. New names are used instead of rebinding
# `image` / `boxes` / `classes` / `scores`, so this cell can be re-run safely.
input_h, input_w = tensor_img.shape[2], tensor_img.shape[3]
image_resized = image.resize((input_w, input_h))
cv_img = cv2.cvtColor(np.array(image_resized), cv2.COLOR_RGB2BGR)

det_boxes = boxes[0].cpu().tolist()
det_classes = classes[0].cpu().tolist()
det_scores = scores[0].cpu().tolist()

for box, cls_id, score in zip(det_boxes, det_classes, det_scores):
    pt1 = (int(box[0]), int(box[1]))
    pt2 = (int(box[2]), int(box[3]))
    cv2.rectangle(cv_img, pt1, pt2, (0, 255, 0))
    label = "%s %.3f" % (VOCDataset.CLASSES_NAME[int(cls_id)], score)
    cv2.putText(cv_img, label, (int(box[0] + 5), int(box[1]) + 15),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, [0, 200, 20], 2)

# NOTE(review): cv2.imshow opens a native window and waitKey(0) blocks the
# kernel until a key is pressed.
cv2.imshow('', cv_img)
cv2.waitKey(0)
cv2.destroyAllWindows()