# Check model training

In [1]:
from albumentations import Compose, Resize, RandomCropNearBBox, ShiftScaleRotate, GaussNoise, ElasticTransform
from albumentations import RandomBrightnessContrast, Normalize

import cv2
import torch

import sys

sys.path.insert(0, "../code/deeplab")


from dataflow.dataloaders import get_train_val_loaders
from dataflow.datasets import get_train_dataset
from dataflow.transforms import ToTensor

In [2]:

train_transforms = Compose([
    ShiftScaleRotate(shift_limit=0.2, scale_limit=0.075, rotate_limit=45, interpolation=cv2.INTER_CUBIC, p=0.3),
    Resize(224, 224, interpolation=cv2.INTER_CUBIC),
    GaussNoise(),
    RandomBrightnessContrast(),
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensor(),
])
train_transform_fn = lambda dp: train_transforms(**dp)



val_transforms = Compose([
    Resize(224, 224, interpolation=cv2.INTER_CUBIC),
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensor(),    
])
val_transform_fn = lambda dp: val_transforms(**dp)


In [3]:
train_loader, val_loader, train_eval_loader = get_train_val_loaders(root_path="/home/storage_ext4_1tb/", 
                                                                    train_transforms=train_transform_fn,
                                                                    val_transforms=val_transform_fn,
                                                                    batch_size=4,
                                                                    random_seed=12)

In [4]:
num_classes = 21

In [5]:
from models.deeplabv3 import DeepLabV3
from models.backbones import build_resnet18_backbone

model = DeepLabV3(build_resnet18_backbone, num_classes=num_classes)

In [6]:
import torch.nn as nn

from ignite.utils import convert_tensor, to_onehot

In [7]:
device = 'cuda'

model = model.to(device).half()
criterion = nn.BCEWithLogitsLoss().to(device)

In [13]:
for batch in train_loader:
    x, y = batch['image'], batch['mask']
    x = convert_tensor(x, device, non_blocking=True).half()
    y = convert_tensor(y, device, non_blocking=True).long()
    
    # Ignore boundaries
    y[y == 255] = 0
    # to OHE
    batch_size = y.shape[0]
    size = y.shape[1:]
    y = to_onehot(y.reshape(-1), num_classes=num_classes).reshape(batch_size, -1, *size)
    
    break

In [14]:
batch['mask'].shape

torch.Size([4, 224, 224])

In [15]:
x.shape, y.shape, torch.unique(y)

(torch.Size([4, 3, 224, 224]),
 torch.Size([4, 21, 224, 224]),
 tensor([0., 1.], device='cuda:0'))

In [10]:
y_pred = model(x)

In [11]:
y_pred.shape, y_pred.type()

(torch.Size([4, 21, 224, 224]), 'torch.cuda.HalfTensor')

In [12]:
loss = criterion(y_pred.float(), y)
loss

tensor(0.8636, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [13]:
loss.backward()