In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import sys
sys.path.append('../scripts/')
from unet_custom_implementation import Unet
import torchvision
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import wandb
import segmentation_models_pytorch as smp
from tqdm import tqdm

In [2]:
# Choose the compute device and note whether more than one GPU is present.
# `multiple_gpus` previously started as True, which made the count check dead
# code; `device` is now always defined (CPU fallback) so later cells do not
# raise NameError on machines without CUDA.
multiple_gpus = torch.cuda.device_count() > 1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# One row per image: every EncodedPixels entry (presumably a run-length
# encoded ship mask) belonging to an ImageId is collected into a list.
csv_file = (
    pd.read_csv('../../ship_data/train_ship_segmentations_v2.csv')
    .groupby('ImageId')['EncodedPixels']
    .apply(list)
    .reset_index()
)
image_ids = csv_file['ImageId'].values.tolist()
pixels = csv_file['EncodedPixels'].values.tolist()

In [4]:
# Input image path (JPEG as named in ImageId) and the matching mask path:
# same basename (text before the first '.') with a .png extension.
csv_file['fixed_inputs'] = '../../ship_data/train_v2/' + csv_file['ImageId']
csv_file['mask_paths'] = (
    '../../ship_data/masks_v1/train/'
    + csv_file['ImageId'].str.split('.').str[0]
    + '.png'
)

In [5]:
# Exclude one specific training image (presumably unreadable/corrupted —
# TODO confirm). The previous filter compared against a '../../files/...'
# prefix that never matches the '../../ship_data/...' paths built above, so it
# removed nothing; matching on ImageId is independent of the directory layout.
csv_file = csv_file[csv_file['ImageId'] != '6384c3e78.jpg']

In [6]:
def split_datasets(csv_file, test_size = 0.01, random_state=42):
    """Split ``csv_file`` into train / validation / test DataFrames.

    ``test_size`` of the rows is held out as the test set; the same fraction
    of the *remaining* rows then becomes the validation set.

    Args:
        csv_file: DataFrame to split.
        test_size: fraction used for both the test split and the validation
            split of the remainder.
        random_state: seed passed to both ``train_test_split`` calls so the
            splits are reproducible (default 42, the previously hard-coded value).

    Returns:
        Tuple ``(train, val, test)`` of DataFrames.
    """
    train_val, test = train_test_split(csv_file, test_size=test_size, random_state=random_state)
    train, val = train_test_split(train_val, test_size=test_size, random_state=random_state)
    return train, val, test

In [7]:
# Hold out 1% for test, then 1% of the remainder for validation.
train, val, test = split_datasets(csv_file, test_size=0.01)

In [8]:
class GetData(Dataset):
    """Dataset yielding ``(image, mask)`` tensor pairs for segmentation.

    File paths come from the ``fixed_inputs`` (image) and ``mask_paths`` (mask)
    columns of ``csv_file``. Images are resized to ``img_size`` and masks to
    ``mask_size``; both are then divided by 255.

    Args:
        csv_file: DataFrame with ``fixed_inputs`` and ``mask_paths`` columns.
        img_size: ``(height, width)`` the input image is resized to.
        mask_size: ``(height, width)`` the mask is resized to. The default
            ``(68, 68)`` matches the output size Unet(3) produced for a
            256x256 input in this notebook.
    """

    def __init__(self, csv_file: pd.DataFrame, img_size=(256, 256), mask_size=(68, 68)):
        self.img_paths = csv_file['fixed_inputs'].values.tolist()
        self.mask_paths = csv_file['mask_paths'].values.tolist()
        self.img_size = tuple(img_size)
        self.mask_size = tuple(mask_size)

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, index):
        img = torchvision.io.decode_jpeg(torchvision.io.read_file(self.img_paths[index]))
        mask = torchvision.io.decode_image(torchvision.io.read_file(self.mask_paths[index]))
        img = torchvision.transforms.functional.resize(img, self.img_size)
        # Nearest-neighbour resizing only copies existing mask pixel values;
        # the default bilinear mode would blend label pixels at object edges
        # into fractional targets.
        mask = torchvision.transforms.functional.resize(
            mask,
            self.mask_size,
            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        )
        return img / 255, mask / 255

In [9]:
# Loss function
def dice_bce_loss(inputs, targets, smooth = 1e-5):
    """Soft-Dice loss plus mean binary cross-entropy, computed from logits.

    Args:
        inputs: raw model outputs (logits). A sigmoid is applied here, so the
            model itself must not apply one.
        targets: ground-truth mask with the same number of elements as
            ``inputs``, values in [0, 1].
        smooth: additive constant keeping the Dice ratio defined when both the
            prediction and the target are empty.

    Returns:
        ``(loss, number_of_pixels)``: the scalar loss tensor
        (``1 - dice + mean BCE``) and the number of predicted elements in the
        batch (``inputs.numel()``).
    """
    # Count the elements actually predicted; the former hard-coded
    # batch * 512 * 512 * 3 matched neither the 256x256 inputs nor the masks.
    number_of_pixels = inputs.numel()
    # reshape (not view) also accepts non-contiguous tensors.
    logits = inputs.reshape(-1)
    targets = targets.reshape(-1).to(logits.dtype)
    probs = torch.sigmoid(logits)
    # Dice is computed over the whole batch as one flattened set of pixels.
    intersection = (probs * targets).sum()
    dice = (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
    dice_loss = 1 - dice
    # Pixel-wise BCE averaged over all pixels. The *_with_logits form fuses
    # sigmoid and log for numerical stability, and `reduction=` replaces the
    # deprecated `reduce=` keyword.
    bce = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction='mean')
    return dice_loss + bce, number_of_pixels

In [10]:
# IOU metric
# SMOOTH = 1e-5
def iou_score(inputs, targets, thres = None, smooth=1e-5):
    """Intersection-over-Union between ``sigmoid(inputs)`` and ``targets``.

    Args:
        inputs: raw logits; a sigmoid is applied here.
        targets: ground-truth mask with the same number of elements as
            ``inputs``.
        thres: if given, probabilities are binarised with ``> thres`` before
            scoring (hard IoU); if None, the soft probabilities are used.
        smooth: additive constant keeping the ratio defined for empty masks.

    Returns:
        Scalar tensor with the IoU over all elements of the batch.
    """
    probs = torch.sigmoid(inputs)
    # Identity check is the idiomatic None test (`!= None` invokes __ne__).
    if thres is not None:
        probs = (probs > thres).float()
    # reshape (not view) also accepts non-contiguous tensors.
    probs = probs.reshape(-1)
    targets = targets.reshape(-1)
    intersection = torch.sum(probs * targets)
    union = torch.sum(probs + targets) - intersection
    return (intersection + smooth) / (union + smooth)

In [11]:
def _run_phase(model, loader, optimizer, device, is_train):
    """Run one pass over ``loader``; return (mean batch loss, mean batch IoU).

    Gradients are tracked and the optimizer stepped only when ``is_train``.
    """
    model.train(is_train)
    running_loss, running_iou = 0.0, 0.0
    with tqdm(loader, unit='batch') as tepoch:
        for img, label in tepoch:
            img = img.to(device)
            label = label.to(device)
            if is_train:
                optimizer.zero_grad()
            with torch.set_grad_enabled(is_train):
                outputs = model(img)
                loss, _ = dice_bce_loss(outputs, label)
                iou = iou_score(outputs, label)
                if is_train:
                    loss.backward()
                    optimizer.step()
            running_loss += loss.item()
            running_iou += iou.item()
            tepoch.set_postfix(loss=loss.item(), iou=iou.item())
    # Metrics are means of per-batch values; guard against an empty loader.
    n_batches = max(len(loader), 1)
    return running_loss / n_batches, running_iou / n_batches


def train_model(model, train_set, val_set, epochs, lr=0.01, device=None):
    """Train ``model`` with SGD, validating and logging metrics every epoch.

    Args:
        model: segmentation network producing logits.
        train_set: DataLoader of ``(image, mask)`` training batches.
        val_set: DataLoader of ``(image, mask)`` validation batches.
        epochs: number of training passes, each followed by a validation pass.
        lr: SGD learning rate (default 0.01, the previously hard-coded value).
        device: torch device to train on; defaults to ``cuda:0`` when CUDA is
            available, otherwise CPU.

    Returns:
        The trained model, not wrapped in ``nn.DataParallel``.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # DataParallel only helps with more than one GPU; otherwise it just adds a
    # `module.` prefix to state_dict keys.
    net = nn.DataParallel(model) if torch.cuda.device_count() > 1 else model
    optimizer = optim.SGD(net.parameters(), lr=lr)
    for epoch in range(epochs):
        train_loss, train_iou = _run_phase(net, train_set, optimizer, device, is_train=True)
        val_loss, val_iou = _run_phase(net, val_set, optimizer, device, is_train=False)
        metrics = {
            'train_loss': train_loss,
            'val_loss': val_loss,
            'train_iou': train_iou,
            'val_iou': val_iou,
        }
        # wandb.log needs an active run (wandb.init()); no visible cell starts
        # one, so fall back to printing the epoch summary.
        if wandb.run is not None:
            wandb.log(metrics)
        else:
            print(f'epoch {epoch + 1}/{epochs}: {metrics}')
    return model

In [12]:
# Loader settings; the worker count is capped at the CPUs actually present.
BATCH_SIZE = 128
NUM_WORKERS = min(22, os.cpu_count() or 1)
train_dataset = GetData(train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dataset = GetData(val)
# Validation metrics are averaged over all batches, so shuffling gains nothing.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [13]:
# Train a freshly initialised Unet(3) for 10 epochs on the loaders above.
train_model(Unet(3), train_loader, val_loader, epochs=10)

  3%|▎         | 47/1475 [00:20<10:34,  2.25batch/s, iou=0.00211, loss=1.96] 


KeyboardInterrupt: 

In [21]:
# NOTE(review): `x` is not defined in any visible cell, so this leftover debug
# cell raises NameError on a fresh kernel; the saved output presumably comes
# from an earlier session. Remove it or define `x` first.
x.shape

torch.Size([64, 64, 68, 68])

In [14]:
# def get_model():
#     model = smp.Unet(
#         in_channels=3,             
#         classes=1,                
#     )
#     return model

In [27]:
# Freshly constructed (untrained) model, used only for the output-shape check.
model = Unet(3)

In [28]:
# Random dummy batch shaped (1, 3, 256, 256), mirroring GetData's 256x256 image resize.
temp = torch.rand(1, 3, 256, 256)

In [29]:
# Forward pass for a shape check only: disable autograd so no graph is built.
with torch.no_grad():
    pred_temp = model(temp)

In [30]:
# original - 68, new -  

In [31]:
# Model output size; GetData resizes masks to (68, 68) to match it.
pred_temp.shape

torch.Size([1, 1, 68, 68])