In [1]:
from src.core import *
from src.rois import *
from pathlib import Path
from functools import partial
from tqdm import tqdm
tqdm.pandas()

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.nn.functional import one_hot
from torchvision.io import read_image, ImageReadMode
import torchvision.transforms.v2 as v2
from torchvision import models
from torchvision.transforms.v2.functional import resized_crop

from fastai.vision.all import DataLoaders, OptimWrapper, Learner, AvgMetric, Metric

In [2]:
class RCNNDataset(Dataset):
    def __init__(self, ann_path, imgs_path, sample_frac=1, crop_size=(224,224), tfms=None):
        ann_path = Path(ann_path)
        df, self.id2label, self.id2img = load_data(str(ann_path.parent), imgs_path, ann_path.stem)
        replace = True if sample_frac>1 else False
        df = df.sample(frac=sample_frac, replace=replace)
        self.crop_size = crop_size
        self.tfms = tfms
        if self.tfms is None:
            self.tfms = v2.Compose([
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        
        res = df.progress_apply(get_annotated_rois, axis=1, id2img=self.id2img)
        self.img_ids = torch.cat([row[0] for row in res])
        self.rois = torch.cat([row[1] for row in res])
        self.roi_ids = torch.cat([row[2] for row in res])
        self.offsets = torch.cat([row[3] for row in res])

    def __len__(self):
        return self.img_ids.shape[0]

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img = read_image(self.id2img[img_id.item()], mode=ImageReadMode.RGB)
        img = self.tfms(img)
        x_min, y_min, w, h = self.rois[idx].int().tolist()
        crop = resized_crop(img, top=y_min, left=x_min, height=h, width=w, size=self.crop_size)
        
        return crop, img_id, self.rois[idx], self.roi_ids[idx], self.offsets[idx]

In [3]:
train_ds = RCNNDataset('tmp/train.json', 'data/train', 0.01)
valid_ds = RCNNDataset('tmp/valid.json', 'data/train', 0.1)

100%|███████████████████████████████████████████████████████████████████████████████████| 51/51 [00:29<00:00,  1.72it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 57/57 [00:34<00:00,  1.65it/s]


In [4]:
num_neg = torch.sum(train_ds.roi_ids==0)
num_pos = train_ds.roi_ids.numel()-num_neg
print(num_pos/num_neg)

tensor(0.6564)


In [5]:
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size=32, shuffle=False, pin_memory=True)

dls = DataLoaders(train_dl, valid_dl)
dls.n_inp = 1

In [6]:
def get_dls(train_ds, valid_ds, bs=64, tfms=None):
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, pin_memory=True)
    valid_dl = DataLoader(valid_ds, batch_size=bs, shuffle=False, pin_memory=True)
    
    dls = DataLoaders(train_dl, valid_dl)
    dls.n_inp = 1
    return dls

In [7]:
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

vgg16.classifier[0].in_features

25088

In [8]:
class RCNN(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        encode_dim = self.model.classifier[0].in_features

        head = nn.Sequential(
            nn.Linear(encode_dim, 4096), nn.ReLU(),
            nn.BatchNorm1d(4096), nn.Dropout(0.5),
            nn.Linear(4096, 512), nn.ReLU(),
            nn.BatchNorm1d(512), nn.Dropout(0.5),
            nn.Linear(512, n_classes+5)
        )
        self.model.classifier = head

    def forward(self, crops): return self.model(crops)

In [9]:
def reg_loss(preds, *targs):
    _, _, roi_ids, offsets = targs
    loss = torch.tensor(0.0, requires_grad=True)
    mask = roi_ids!=0
    if torch.sum(mask)>0:
        loss = nn.L1Loss()(preds[mask, -4:], offsets[mask])
    return loss

def cls_loss(preds, *targs):
    _, _, roi_ids, _ = targs
    loss = torch.tensor(0.0, requires_grad=True)
    n_classes = preds.shape[1]-4
    cats = one_hot(roi_ids, num_classes=n_classes)
    preds[:,:-4] = nn.Sigmoid()(preds[:,:-4])
    loss = nn.BCELoss()(preds[:,1:-4], cats[:,1:].to(torch.float32))
    return loss

def detn_loss(preds, *targs):
    return cls_loss(preds, *targs) + reg_loss(preds, *targs)

In [10]:
model = RCNN(len(train_ds.id2label))

In [11]:
for i, batch in enumerate(train_dl):
    if i==3:  break
    preds = model(batch[0])
    print(cls_loss(preds, *batch[1:]))
    print(reg_loss(preds, *batch[1:]))
    print(detn_loss(preds, *batch[1:]))
    print('\n')

tensor(0.7639, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.7309, grad_fn=<MeanBackward0>)
tensor(1.6954, grad_fn=<AddBackward0>)


tensor(0.7586, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.6911, grad_fn=<MeanBackward0>)
tensor(1.6562, grad_fn=<AddBackward0>)


tensor(0.7649, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.7239, grad_fn=<MeanBackward0>)
tensor(1.6845, grad_fn=<AddBackward0>)




In [12]:
import io, sys
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

class mAP(Metric):
    def __init__(self, gt_path, pred_path):
        self.gt_path, self.pred_path = gt_path, pred_path
        self.reset()
        
    def reset(self):
        with open(self.pred_path, "w") as f:
            json.dump([], f, indent=4)

    def accumulate(self, learn):
        probs, pred_offsets = learn.pred[:,:-4], learn.pred[:,-4:]
        scores, pred_ids = probs.max(dim=1)
        img_ids, rois, roi_ids, _  = learn.y
        mask = roi_ids!=0
        pred_bbs = apply_offsets(rois, pred_offsets)
        
        self.write_to_file(img_ids[mask], pred_ids[mask], pred_bbs[mask], scores[mask])

    def write_to_file(self, img_ids, ids, pred_bbs, scores):
        iterables = [img_ids.tolist(), ids.tolist(), pred_bbs.tolist(), scores.tolist()]
        new_anns = [
            {'image_id':img_id, 'category_id':id, 'bbox':bbox, 'score':score}
            for img_id, id, bbox, score in zip(*iterables)
        ]
        with open(self.pred_path, 'r') as f:
            data = json.load(f)
        data += new_anns
        with open(self.pred_path, 'w') as f:
            json.dump(data, f, indent=4)
            
    @property
    def value(self):
        with io.StringIO() as buf:
            save_stdout = sys.stdout
            sys.stdout = buf  # Redirect standard output
            coco_gt = COCO(self.gt_path)
            coco_pred = coco_gt.loadRes(self.pred_path)
            cocoEval = COCOeval(coco_gt, coco_pred, 'bbox')
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()
            sys.stdout = save_stdout  # Restore standard output
        return cocoEval.stats[0]

In [13]:
opt_func = partial(OptimWrapper, opt=torch.optim.Adam)
mAP_metric = mAP('tmp/valid.json', 'tmp/valid_preds.json')
reg_metric, cls_metric = AvgMetric(reg_loss), AvgMetric(cls_loss)
metrics = [mAP_metric, reg_metric, cls_metric]

learn = Learner(dls, model, loss_func=detn_loss, opt_func=opt_func, metrics=metrics)

In [14]:
learn.fit_one_cycle(n_epoch=2, lr_max=1e-4)

epoch,train_loss,valid_loss,m_ap,reg_loss,cls_loss,time
0,1.60849,1.219979,2.4e-05,0.521118,0.962353,06:12
1,1.64905,1.084734,0.000253,0.384446,0.964177,06:01
