In [1]:
from src.core import *
from src.rois import *
from functools import partial

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision import transforms, models
from torchvision.ops import nms

from sklearn.model_selection import train_test_split
from fastai.vision.all import DataLoaders, OptimWrapper, Learner
from selectivesearch import selective_search

In [2]:
# Load the training annotations.
# NOTE(review): JSONLoader is not defined in this notebook (it arrives via the
# star imports above); presumably 'data' is the dataset directory, `df` is a
# DataFrame of samples and `id2label` maps class ids to label names — confirm.
loader = JSONLoader('data')
df, id2label = loader.load_train()

In [3]:
def extract_rois_ss(img):
    """Propose candidate regions for one image using selective search.

    Args:
        img: Image tensor indexed as (channels, height, width); the area is
            computed from ``shape[1] * shape[2]`` and the tensor is permuted to
            channels-last before being handed to ``selective_search``.

    Returns:
        ``ROIs`` wrapping an ``(N, 4)`` tensor holding the ``'rect'`` entries of
        the proposals that pass the size filter. ``N`` may be 0.
    """
    img_area = img.shape[1] * img.shape[2]
    _, regions = selective_search(img.permute((1, 2, 0)), scale=200, min_size=100)
    # reshape(-1, 4) keeps the tensor two-dimensional even when no region was
    # proposed. Without it, torch.tensor([]) is 1-D and the 2-D indexing below
    # raises IndexError. Assumes each 'rect' has exactly four entries.
    rois = torch.tensor([r['rect'] for r in regions]).reshape(-1, 4)
    sizes = torch.tensor([r['size'] for r in regions])
    # Keep proposals larger than 5% of the image but smaller than the full image.
    mask = (sizes > 0.05 * img_area) & (sizes < img_area)
    return ROIs(rois[mask, :])

In [4]:
# Hold out 20% of the working subset for evaluation. The fixed random_state
# makes the split identical on every Restart & Run All, so losses are
# comparable between runs.
# NOTE(review): if `df` is a pandas DataFrame, `.loc[:1000]` slices by label and
# includes label 1000 — confirm this is the intended subset.
train_df, eval_df = train_test_split(df.loc[:1000], test_size=0.2, random_state=42)

# Both datasets generate region proposals on the fly with selective search.
train_ds = ObjectDataset(df=train_df, extract_rois=extract_rois_ss)
eval_ds = ObjectDataset(df=eval_df, extract_rois=extract_rois_ss)

# Each dataset supplies its own collate_fn; only the training loader shuffles.
train_dl = DataLoader(
    train_ds, batch_size=4, shuffle=True,
    collate_fn=train_ds.collate_fn, pin_memory=True,
)
eval_dl = DataLoader(
    eval_ds, batch_size=4, shuffle=False,
    collate_fn=eval_ds.collate_fn, pin_memory=True,
)

In [5]:
# Bundle the train/eval loaders for fastai.
dls = DataLoaders(train_dl, eval_dl)
# Tell fastai that only the first element of each batch is the model input;
# the remaining elements are treated as targets and forwarded to the loss.
dls.n_inp = 1

In [6]:
# Load VGG16 with torchvision's default pretrained weights to inspect its layout.
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

# Input width of the first classifier layer, i.e. the length of the flattened
# feature vector produced by the convolutional part (shown as the cell output).
vgg16.classifier[0].in_features

25088

In [7]:
class RCNN(nn.Module):
    """Classification and box-regression heads on a frozen, pretrained VGG16.

    The VGG16 classifier is replaced by an empty ``nn.Sequential`` so the
    encoder returns the flattened convolutional features. Two heads consume
    those features: a linear classifier over ``n_classes`` and a small MLP that
    predicts four box values squashed to (0, 1) by a sigmoid.

    Args:
        n_classes: Number of output classes of the classification head.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.img_encoder = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        # Freeze the backbone: only the two heads below are trained.
        for param in self.img_encoder.parameters():
            param.requires_grad = False
        self.img_encoder.eval()
        # Read the feature width before discarding the original classifier.
        encode_dim = self.img_encoder.classifier[0].in_features
        self.img_encoder.classifier = nn.Sequential()

        self.cls_head = nn.Linear(encode_dim, n_classes)
        self.cls_loss = nn.CrossEntropyLoss()
        self.reg_head = nn.Sequential(
            nn.Linear(encode_dim, 512), nn.ReLU(),
            nn.Linear(512, 4), nn.Sigmoid(),
        )
        self.reg_loss = nn.L1Loss()

    def forward(self, crops):
        """Encode a batch of crops and run both heads.

        Returns:
            Tuple ``(logits, bbox)``: unnormalised class scores (the loss
            applies softmax internally via ``CrossEntropyLoss``) and the four
            sigmoid-bounded box values per crop.
        """
        features = self.img_encoder(crops)
        logits = self.cls_head(features)
        bbox = self.reg_head(features)
        return logits, bbox

    def calc_loss(self, preds, ids, offsets, beta=0.1):
        """Weighted sum of classification and box-regression losses.

        Args:
            preds: ``(logits, bbox)`` tuple as returned by ``forward``.
            ids: Target class index per crop. Crops with id 0 are excluded
                from the regression term (presumably the background class).
            offsets: Target box values per crop, same shape as ``bbox``.
            beta: Weight of the classification loss; the regression loss is
                weighted by ``1 - beta``.

        Returns:
            Scalar tensor ``beta * cls_loss + (1 - beta) * reg_loss``.
        """
        logits, bbox = preds
        cls_loss = self.cls_loss(logits, ids)

        # Only crops with a non-zero class id contribute to box regression.
        mask = ids != 0
        pos_bbox, pos_offsets = bbox[mask], offsets[mask]
        if mask.any():
            reg_loss = self.reg_loss(pos_bbox, pos_offsets)
        else:
            # No positive crop in this batch: L1Loss over empty tensors would be
            # NaN (mean of nothing). Use a zero that stays on the same device
            # and remains attached to the autograd graph.
            reg_loss = pos_bbox.sum() * 0.0

        return beta * cls_loss + (1 - beta) * reg_loss

In [8]:
# One output class per entry in id2label.
model = RCNN(len(id2label))
# Wrap a plain PyTorch Adam optimizer so fastai's Learner can drive it.
opt_func = partial(OptimWrapper, opt=torch.optim.Adam)

In [9]:
# The model's own calc_loss is the loss function; fastai calls it with the model
# output followed by the batch targets.
learn = Learner(dls, model, loss_func=model.calc_loss, opt_func=opt_func)

In [None]:
# Train for 3 epochs with the one-cycle schedule, peaking at a learning rate of 1e-3.
learn.fit_one_cycle(n_epoch=3, lr_max=1e-3)

epoch,train_loss,valid_loss,time


