In [22]:
import torch
import numpy as np
import pandas as pd
import segmentation_models_pytorch as sm
import cv2
import os
import glob
from tqdm import tqdm
import multiprocessing as mp
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import io
import torchvision.transforms as T
import PIL
import torch.nn as nn
import torch.nn.functional as F
import time
import torch.optim as optim
import warnings
warnings.filterwarnings("ignore")
from torch.utils.tensorboard import SummaryWriter

In [2]:
csv_file = pd.read_csv('../../files/train_ship_segmentations_v2.csv')
csv_file = csv_file.groupby('ImageId')['EncodedPixels'].apply(list).reset_index()
image_ids, pixels = csv_file['ImageId'].values.tolist(), csv_file['EncodedPixels'].values.tolist()

In [3]:
csv_file['fixed_inputs'] = csv_file['ImageId'].apply(lambda x: '../../files/train_v2/' + x)
csv_file['mask_paths'] = csv_file['ImageId'].apply(lambda x: '../../files/masks_v1/train/' + x.split('.')[0] + '.' + 'png')

In [4]:
for x in tqdm(csv_file['fixed_inputs'].values.tolist()):
    if os.path.exists(x) == False:
        print(x)

100%|██████████| 192556/192556 [00:00<00:00, 627781.51it/s]


In [5]:
for x in tqdm(csv_file['mask_paths'].values.tolist()):
    if os.path.exists(x) == False:
        print(x)

100%|██████████| 192556/192556 [00:00<00:00, 606901.49it/s]


In [6]:
csv_file['fixed_inputs'].values.tolist()[0]

'../../files/train_v2/00003e153.jpg'

In [7]:
csv_file = csv_file[csv_file['fixed_inputs'] != '../../files/train_v2/6384c3e78.jpg']

In [8]:
def split_datasets(csv_file, test_size = 0.01):
    train, test = train_test_split(csv_file, test_size = test_size, random_state=42)
    train, val = train_test_split(train, test_size = test_size, random_state=42)
    return train, val, test

In [9]:
# csv_file = pd.read_csv
train, val, test = split_datasets(csv_file)

In [10]:
type_of_label = []
for x in train['EncodedPixels'].values.tolist():
    if type(x[0]) == str:
        type_of_label.append(1)
    else:
        type_of_label.append(0)

In [11]:
unet_incep = sm.Unet(
    encoder_name='inceptionv4',
    encoder_weights=None,
    in_channels=3,
    classes=1
)

In [12]:
class ShipSegmentationData(Dataset):
    def __init__(self, csv_file, output_shape, is_val = False):
        self.csv_file = csv_file
        self.imgs = self.csv_file['fixed_inputs'].values.tolist()
        self.masks = self.csv_file['mask_paths'].values.tolist()
        self.flip_probs = np.random.randint(0, 2, size = (len(self.imgs)))
        self.brightness_factor = np.random.uniform(1, 2, len(self.imgs))
        self.contrast_factor = np.random.uniform(1, 2, len(self.imgs))
        self.output_shape = output_shape
        self.is_val = is_val

    def __getitem__(self, idx):
        img = io.read_image(self.imgs[idx])
        mask = io.read_image(self.masks[idx])
        if self.is_val:
            if self.flip_probs[idx] == 1:
                img = T.functional.hflip(img)
                mask = T.functional.hflip(mask)
            img = T.functional.adjust_brightness(img, self.brightness_factor[idx])
            img = T.functional.adjust_contrast(img, self.contrast_factor[idx])
        img = T.functional.resize(img, self.output_shape)
        mask = T.functional.resize(mask, self.output_shape)
        img = img / 255
        mask = mask / 255
        mask = torch.where(mask < 1, 0, 1)
        return img, mask

    def __len__(self):
        return len(self.imgs)

In [13]:
train_dataset = ShipSegmentationData(train[:1000], output_shape=(512, 512))
val_dataset = ShipSegmentationData(val, output_shape=(512, 512), is_val=True)
trainloader = DataLoader(train_dataset, shuffle=True, num_workers=10, prefetch_factor=2, batch_size=24)
valloader = DataLoader(val_dataset, shuffle=False, num_workers=2, prefetch_factor=2, batch_size=24)
dataloaders = {
    'train': trainloader,
    'val': valloader
}
dataset_sizes = {
    'train' : len(train),
    'val' : len(val)
}

In [14]:
def dice_bce_loss(inputs, targets, smooth = 1):
    inputs = F.sigmoid(inputs)
    inputs = inputs.view(-1)
    targets = targets.view(-1)
    intersection = (inputs * targets).sum() 
    dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
    targets = targets.float()
    BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
    Dice_BCE = BCE + dice_loss
    return Dice_BCE

In [15]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [23]:
def train_model(model, criterion, optimizer, writer, num_epochs = 5):
    since = time.time()

    for epoch in range(num_epochs):
        print(f'epoch: {epoch}/{num_epochs - 1}')
        print('-' * 10)
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0.0

            for inputs, masks in dataloaders[phase]:
                inputs = inputs.to(device)
                masks = masks.to(device)
                optimizer.zero_grad()

                # Forward
                # Track hist if it's only train
                with torch.set_grad_enabled(phase=='train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, masks)

                    # backward + optimize only if training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                # some stats
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(outputs == masks.data)
            

            epoch_loss = running_loss / dataset_sizes[phase]
            writer.add_scalar("Loss/train", epoch_loss, epoch)
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            torch.save(model.state_dict(), f'../../weights/{epoch}.pth')

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    # print(f'Best val Acc: {best_acc:4f}')

In [24]:
# unet_incep.to(device)
writer = SummaryWriter()
optimzer = optim.SGD(unet_incep.parameters(), lr=0.01)
train_model(unet_incep, dice_bce_loss, optimzer, writer)

epoch: 0/4
----------
epoch: 1/4
----------
epoch: 2/4
----------
epoch: 3/4
----------
epoch: 4/4
----------
Training complete in 5m 25s
