In [1]:
import os
import sys
path = os.path.join(os.getcwd(), '..')
sys.path.append(path)

import warnings
warnings.filterwarnings('ignore')

from fastai2.vision.all import *

from src.data.dls import build_dataloaders
from src.model.FasterRCNN import get_faster_rcnn
from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN

# Use the GPU when one is available; fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Auto-reload edited project modules (e.g. src.*) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
from pdb import set_trace

#### 1a. Set Up One Batch

In [3]:
# Root of the wheat dataset. Set the WHEAT_DATA_PATH environment variable to point
# elsewhere instead of editing this hard-coded, user-specific default.
data_path = os.environ.get('WHEAT_DATA_PATH', '/userhome/34/h3509807/wheat-data')

# Batch size 8, resize size 256, normalisation off, fixed random seed.
dls = build_dataloaders(data_path,
                        bs=8,
                        resize_sz=256,
                        norm=False,
                        rand_seed=144)

In [14]:
# Grab one batch and show the shapes of its first three tensors
# (per the output below: images, encoded boxes, category ids).
b = dls.one_batch()
tuple(b[i].shape for i in range(3))

(torch.Size([8, 3, 256, 256]), torch.Size([8, 77, 4]), torch.Size([8, 77]))

In [15]:
# Top-left 3x3 pixels of channel 0 of the first image. The values look like raw
# [0, 1] intensities, consistent with norm=False — confirm.
b[0][0, 0, :3, :3]

tensor([[0.4196, 0.5098, 0.5412],
        [0.5020, 0.6118, 0.5725],
        [0.4784, 0.5059, 0.4471]], device='cuda:0')

#### 1b. Transform the Batch into Model-Ready Inputs

In [16]:
def decode_bboxs(enc_bboxs, sz=256):
    """Map encoded box coordinates back to pixel coordinates.

    Assumes coordinates were scaled to [-1, 1] (fastai's convention);
    `(x + 1) * sz / 2` undoes that for an image of side `sz`.

    Args:
        enc_bboxs: tensor of encoded box coordinates.
        sz: image side length in pixels. It should match the `resize_sz` the
            DataLoaders were built with (256 in this notebook). It drives both the
            scale factor and the `img_size` tag, so the two can no longer disagree.

    Returns:
        TensorPoint of pixel coordinates tagged with `img_size=sz`.
    """
    return TensorPoint((enc_bboxs + 1) * float(sz) / 2, img_size=sz)

In [17]:
# Split the image batch into a list of per-image tensors; the detection model is fed a list.
xb = list(b[0])

In [18]:
# Build one target dict ('boxes', 'labels') per image. Entries with category id 0
# are dropped (presumably batch padding — TODO confirm against build_dataloaders).
yb = []
for bboxs, cats in zip(b[1], b[2]):
    keep = cats != 0
    yb.append(dict(boxes=decode_bboxs(bboxs[keep]), labels=cats[keep]))

#### 2. Set Up the Model

In [19]:
# Fresh Faster R-CNN moved to the selected device (nn.Module.to returns the module).
model = get_faster_rcnn().to(device)

#### 3. Test Forward Behavior
- train vs. eval mode
- with and without target

In [21]:
# Training mode returns a dict of losses. Targets must be passed here:
# calling model(xb) without yb raises an error.
model.train(True)
with torch.no_grad():
    pred = model(xb, yb)
(type(pred), pred.keys(), len(pred))

(dict,
 dict_keys(['loss_classifier', 'loss_box_reg', 'loss_objectness', 'loss_rpn_box_reg']),
 4)

In [27]:
# Eval mode with targets: the output is still a list of per-image predictions.
model.train(False)
with torch.no_grad():
    pred = model(xb, yb)
(type(pred), len(pred), pred[0].keys())

(list, 8, dict_keys(['boxes', 'labels', 'scores']))

In [28]:
# Eval mode without targets: same list-of-predictions output as above.
model.train(False)
with torch.no_grad():
    pred = model(xb)
(type(pred), len(pred), pred[0].keys())

(list, 8, dict_keys(['boxes', 'labels', 'scores']))

#### Conclusion
- In training mode, the output is a dict of per-component losses, and targets are required (calling `model(xb)` alone raises an error).
- In eval mode, the output is a list of predictions (one dict of `boxes`, `labels`, `scores` per image), whether or not targets are passed.

#### 4. Monkey-Patch `GeneralizedRCNN.eager_outputs` to Return Both Losses and Detections, and Try Again

In [29]:
@patch
def eager_outputs(self: GeneralizedRCNN, losses, detections):
    """Return a `(losses, detections)` tuple instead of picking one based on mode.

    Patched onto GeneralizedRCNN via fastcore's @patch; the `self` annotation
    selects the class being patched.
    """
    both = losses, detections
    return both

In [30]:
# Re-create the model after patching and move it to the selected device.
model = get_faster_rcnn().to(device)

#### 4a. Training Mode

In [46]:
# Patched training mode: the output is a (losses, detections) tuple. Targets are
# still required; calling model(xb) without yb raises an error.
model.train(True)
with torch.no_grad():
    pred = model(xb, yb)
(type(pred), len(pred))

(tuple, 2)

In [47]:
# Losses from the patched train-mode forward pass.
# NOTE(review): the saved output shows loss_rpn_box_reg ~4.8e5, which looks far too
# large; worth checking the box decoding/format passed as targets.
pred[0]

{'loss_classifier': tensor(0.6281, device='cuda:0'),
 'loss_box_reg': tensor(0.0005, device='cuda:0'),
 'loss_objectness': tensor(15.5977, device='cuda:0'),
 'loss_rpn_box_reg': tensor(477296.1250, device='cuda:0')}

In [48]:
# Detections from the patched train-mode forward pass (empty in the saved output).
pred[1]

[]

#### 4b. Eval Mode

In [39]:
# Patched eval mode with targets: the losses part comes back empty (no losses are
# computed), while detections hold one entry per image.
model.eval()
with torch.no_grad():
    pred = model(xb, yb)
type(pred), len(pred)

(tuple, 2)

In [40]:
# (losses, detections) in eval mode: an empty losses dict and one detection entry per image.
type(pred[0]), len(pred[0]), type(pred[1]), len(pred[1])

(dict, 0, list, 8)

In [43]:
# Patched eval mode without targets works fine and yields the same tuple shape
# as the call with targets above.
model.eval()
with torch.no_grad():
    pred = model(xb)
type(pred), len(pred)

(tuple, 2)

In [44]:
# Same summary without targets: an empty losses dict and one detection entry per image.
type(pred[0]), len(pred[0]), type(pred[1]), len(pred[1])

(dict, 0, list, 8)

In [45]:
# Display the full module tree for inspection.
model

FasterRCNN(
  (transform): GeneralizedRCNNTransform()
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d()
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d()
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d()
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d()
          )
  