In [1]:
from src.core import *
from src.rois import *
from functools import partial

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torchvision import transforms, models
from torchvision.transforms.functional import resized_crop

from sklearn.model_selection import train_test_split
from fastai.vision.all import DataLoaders, OptimWrapper, Learner

In [2]:
loader = JSONLoader('data')
df, id2label = loader.load_train()

In [3]:
from tqdm import tqdm

tqdm.pandas()
res = df.loc[:500].progress_apply(get_annotated_rois, axis=1)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 501/501 [09:44<00:00,  1.17s/it]


In [5]:
img_paths = sum([row[0] for row in res], [])
rois = torch.cat([row[1] for row in res])
roi_ids = torch.cat([row[2] for row in res])
offsets = torch.cat([row[3] for row in res])

In [14]:
class RCNNDataset(Dataset):
    def __init__(self, img_paths, rois, roi_ids, offsets, crop_size=(224,224)):
        self.img_paths, self.rois, self.roi_ids, self.offsets = img_paths, rois, roi_ids, offsets
        self.crop_size = crop_size
        self.img_tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = Image.open(self.img_paths[idx]).convert('RGB')
        img = self.img_tfms(img)
        _, H, W = img.shape
        offset = self.offsets[idx]/torch.tensor([W,H,W,H])
        x_min, y_min, w, h = self.rois[idx].int().tolist()
        crop = resized_crop(img, top=y_min, left=x_min, height=h, width=w, size=self.crop_size)
        return crop, self.roi_ids[idx], offset

In [15]:
train_idxs, eval_idxs = train_test_split(range(len(img_paths)), test_size=0.2)

train_ds = RCNNDataset([img_paths[i] for i in train_idxs], rois[train_idxs], roi_ids[train_idxs], offsets[train_idxs])
eval_ds = RCNNDataset([img_paths[i] for i in eval_idxs], rois[eval_idxs], roi_ids[eval_idxs], offsets[eval_idxs])

In [16]:
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, pin_memory=True)
eval_dl = DataLoader(eval_ds, batch_size=32, shuffle=False, pin_memory=True)

dls = DataLoaders(train_dl, eval_dl)
dls.n_inp = 1

In [9]:
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

vgg16.classifier[0].in_features

25088

In [None]:
class RCNN(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.img_encoder = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        for param in self.img_encoder.parameters():
            param.requires_grad = False
        self.img_encoder.eval()
        encode_dim = self.img_encoder.classifier[0].in_features
        self.img_encoder.classifier = nn.Sequential()

        self.cls_head = nn.Linear(encode_dim, n_classes)
        self.cls_loss_func = nn.CrossEntropyLoss()
        self.reg_head = nn.Sequential(
            nn.Linear(encode_dim, 512), nn.ReLU(),
            nn.Linear(512, 4), nn.Tanh(),
        )
        self.reg_loss_func = nn.MSELoss()

    def forward(self, crops):
        features = self.img_encoder(crops)
        probs = self.cls_head(features)
        bbox = self.reg_head(features)
        return probs, bbox

    def calc_loss(self, preds, ids, offsets, beta=0.2):
        probs, bbox = preds
        cls_loss = self.cls_loss_func(probs, ids)
        mask = ids!=0
        bbox, offsets = bbox[mask], offsets[mask]
        reg_loss = self.reg_loss_func(bbox, offsets) if len(mask)>0 else torch.tensor(0.0, requires_grad=True)
    
        print(cls_loss, reg_loss)
        return beta*cls_loss + (1-beta)*reg_loss

In [11]:
model = RCNN(len(id2label))
opt_func = partial(OptimWrapper, opt=torch.optim.Adam)

In [17]:
learn = Learner(dls, model, loss_func=model.calc_loss, opt_func=opt_func)

In [18]:
learn.fit_one_cycle(n_epoch=2, lr_max=1e-4)

epoch,train_loss,valid_loss,time


tensor(1.8425, grad_fn=<NllLossBackward0>) tensor(0.0487, grad_fn=<MseLossBackward0>)
tensor(1.8672, grad_fn=<NllLossBackward0>) tensor(0.1609, grad_fn=<MseLossBackward0>)
tensor(1.9660, grad_fn=<NllLossBackward0>) tensor(0.0767, grad_fn=<MseLossBackward0>)
tensor(1.7868, grad_fn=<NllLossBackward0>) tensor(0.0816, grad_fn=<MseLossBackward0>)
tensor(1.7599, grad_fn=<NllLossBackward0>) tensor(0.0719, grad_fn=<MseLossBackward0>)
tensor(1.7597, grad_fn=<NllLossBackward0>) tensor(0.0960, grad_fn=<MseLossBackward0>)
tensor(1.8872, grad_fn=<NllLossBackward0>) tensor(0.0637, grad_fn=<MseLossBackward0>)
tensor(1.7564, grad_fn=<NllLossBackward0>) tensor(0.0849, grad_fn=<MseLossBackward0>)
tensor(1.8057, grad_fn=<NllLossBackward0>) tensor(0.0593, grad_fn=<MseLossBackward0>)
tensor(1.6273, grad_fn=<NllLossBackward0>) tensor(0.1016, grad_fn=<MseLossBackward0>)
tensor(1.6211, grad_fn=<NllLossBackward0>) tensor(0.0248, grad_fn=<MseLossBackward0>)
tensor(1.5030, grad_fn=<NllLossBackward0>) tensor(0.06

KeyboardInterrupt: 