# Because this competition is evaluated on instance segmentation, this notebook **cannot** be used directly for submission.
 
Please use this notebook as a reference for semantic segmentation.

The code in this notebook is based on https://github.com/YutaroOgawa/pytorch_advanced/tree/master/3_semantic_segmentation.


Copyright (c) 2019 Yutaro Ogawa

Released under the MIT license
https://github.com/YutaroOgawa/pytorch_advanced/blob/master/LICENSE

In [None]:
import os
import sys

import random
import math
import time
import pandas as pd
import numpy as np

import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# Seed every RNG used in this notebook (PyTorch, NumPy, Python) so runs are reproducible.
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(1234)

In [None]:
# Working directories: ./utils receives the helper modules copied in the next
# cell and is put on sys.path so they can be imported by bare module name;
# ./weights receives the checkpoint written at the end of training.
# exist_ok / the membership guard make re-running this cell safe (os.mkdir
# raised FileExistsError, and sys.path kept growing duplicate entries).
os.makedirs('./utils', exist_ok=True)
os.makedirs('./weights', exist_ok=True)
if './utils' not in sys.path:
    sys.path.append('./utils')

In [None]:
from shutil import copyfile

# Stage the helper modules from the attached dataset into ./utils so they are importable.
for src_path, dst_path in (
    ("../input/utils-train/data_augumentation.py", "./utils/data_augumentation.py"),
    ("../input/utils-train/dataloader.py", "./utils/dataloader.py"),
    ("../input/utils-train/pspnet.py", "./utils/pspnet.py"),
):
    copyfile(src=src_path, dst=dst_path)

In [None]:
from dataloader import make_datapath_list, DataTransform, VOCDataset

# Build the train/val image and annotation path lists from the competition input directory.
rootpath = "../input/sartorius-cell-instance-segmentation/"
(train_img_list, train_anno_list,
 val_img_list, val_anno_list) = make_datapath_list(rootpath=rootpath)

# Per-channel statistics handed to DataTransform.
# NOTE(review): mean = std = 1.0 look like placeholders; confirm how
# DataTransform applies them.
color_mean = (1.0, 1.0, 1.0)
color_std = (1.0, 1.0, 1.0)
transform_kwargs = dict(input_size=520, color_mean=color_mean, color_std=color_std)

# Each dataset gets its own DataTransform instance.
train_dataset = VOCDataset(train_img_list, train_anno_list,
                           phase="train", transform=DataTransform(**transform_kwargs))
val_dataset = VOCDataset(val_img_list, val_anno_list,
                         phase="val", transform=DataTransform(**transform_kwargs))

batch_size = 8

# Only the training loader is shuffled; validation order stays fixed.
dataloaders_dict = {
    "train": data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    "val": data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False),
}
train_dataloader = dataloaders_dict["train"]
val_dataloader = dataloaders_dict["val"]

In [None]:
# Sanity-check one validation sample: first element's shape, second element's
# shape, then the whole sample. Fetch it once so all three prints describe the
# same sample and the loading/transform work runs only once (previously three
# separate __getitem__(0) calls).
sample = val_dataset[0]
print(sample[0].shape)
print(sample[1].shape)
print(sample)

In [None]:
from utils.pspnet import PSPNet

# Construct the network with its 150-class head layout, then replace both
# classifier convolutions with single-output-channel 1x1 convolutions.
# NOTE(review): no pretrained weights are loaded in this notebook, so the
# 150-class construction only matters if weights are loaded elsewhere.
net = PSPNet(n_classes=150)

n_classes = 1
net.decode_feature.classification = nn.Conv2d(512, n_classes, kernel_size=1, stride=1, padding=0)
net.aux.classification = nn.Conv2d(256, n_classes, kernel_size=1, stride=1, padding=0)

def weights_init(m):
    """Xavier-normal initialise a Conv2d's weight and zero its bias.

    Intended for use with ``Module.apply``; modules that are not ``nn.Conv2d``
    are left untouched.

    Args:
        m: A module visited by ``apply``.
    """
    if isinstance(m, nn.Conv2d):
        # nn.init functions already run under torch.no_grad(), so pass the
        # Parameter itself rather than the legacy ``.data`` alias.
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

# Re-initialise only the two freshly created classifier heads (decoder first, then aux).
for head in (net.decode_feature.classification, net.aux.classification):
    head.apply(weights_init)

In [None]:
# Display the module tree of the modified network as this cell's output.
net

In [None]:
from sklearn.metrics import log_loss

In [None]:
class PSPLoss(nn.Module):
    """Binary segmentation loss over PSPNet's main and auxiliary outputs.

    The total is ``BCE(main) + aux_weight * BCE(aux)``, where each term is
    ``F.binary_cross_entropy_with_logits`` with its default mean reduction.
    """

    def __init__(self, aux_weight=0.4):
        super(PSPLoss, self).__init__()
        self.aux_weight = aux_weight

    def forward(self, outputs, targets):
        """Compute the weighted loss.

        Args:
            outputs: Indexable pair ``(main_logits, aux_logits)``. Each must hold
                as many elements as ``targets`` (e.g. ``(B, 1, H, W)`` logits for
                ``(B, H, W)`` masks).
            targets: Binary mask tensor; cast to the logits' dtype.

        Returns:
            Scalar tensor ``loss + aux_weight * loss_aux``.
        """
        main_logits, aux_logits = outputs[0], outputs[1]
        target = targets.to(main_logits.dtype)
        # Match the targets' shape instead of a hard-coded 520x520, so any
        # input size (not only DataTransform(input_size=520)) works.
        loss = F.binary_cross_entropy_with_logits(main_logits.reshape(target.shape), target)
        loss_aux = F.binary_cross_entropy_with_logits(aux_logits.reshape(target.shape), target)

        return loss + self.aux_weight * loss_aux


criterion = PSPLoss(aux_weight=0.4)

In [None]:
# Backbone and pyramid-pooling modules train with lr=1e-3; the decoder and
# auxiliary branches, whose classifiers were replaced above, use lr=1e-2.
param_groups = [
    {'params': module.parameters(), 'lr': lr}
    for module, lr in (
        (net.feature_conv, 1e-3),
        (net.feature_res_1, 1e-3),
        (net.feature_res_2, 1e-3),
        (net.feature_dilated_res_1, 1e-3),
        (net.feature_dilated_res_2, 1e-3),
        (net.pyramid_pooling, 1e-3),
        (net.decode_feature, 1e-2),
        (net.aux, 1e-2),
    )
]
optimizer = optim.SGD(param_groups, momentum=0.9, weight_decay=0.0001)

def lambda_epoch(epoch, max_epoch=100):
    """Return the "poly" LR factor ``(1 - epoch / max_epoch) ** 0.9`` for LambdaLR.

    Args:
        epoch: Epoch counter supplied by the scheduler.
        max_epoch: Epoch at which the factor reaches 0. Defaults to 100, the
            value previously hard-coded here.

    Returns:
        The multiplicative learning-rate factor. It stays at 0.0 for
        ``epoch >= max_epoch``; previously ``math.pow`` of a negative base
        with a fractional exponent raised ``ValueError`` there.
    """
    remaining = max(0.0, 1 - epoch / max_epoch)
    return math.pow(remaining, 0.9)

# Each param group's LR becomes base_lr * lambda_epoch(epoch) as the scheduler steps.
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda_epoch)

In [None]:
def train_model(net, dataloaders_dict, criterion, scheduler, optimizer, num_epochs,
                batch_multiplier=3, val_every=5):
    """Train ``net`` with gradient accumulation and periodic validation.

    Each epoch makes one pass over ``dataloaders_dict["train"]``. Gradients of
    ``batch_multiplier`` consecutive mini-batches are accumulated before every
    ``optimizer.step()``. A trailing partial group at the end of an epoch is
    also applied rather than discarded. ``scheduler.step()`` runs once per
    epoch, after that epoch's optimizer steps. On every ``val_every``-th epoch
    the model is also evaluated on ``dataloaders_dict["val"]`` without
    gradients.

    Mini-batches that contain a single image are skipped, as before
    (presumably because of BatchNorm layers in the model; confirm).

    Per-epoch losses are averaged over the images actually processed. They are
    printed and the cumulative log is rewritten to ``log_output.csv`` after
    every epoch; ``val_loss`` is NaN for epochs without validation. The final
    weights are saved to ``weights/pspnet50_<num_epochs>.pth``.

    Args:
        net: Model whose output is accepted by ``criterion``.
        dataloaders_dict: ``{"train": DataLoader, "val": DataLoader}`` yielding
            ``(images, annotations)`` pairs.
        criterion: Callable ``(outputs, targets) -> scalar loss tensor``.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        optimizer: Optimizer over ``net``'s parameters.
        num_epochs: Number of training epochs.
        batch_multiplier: Mini-batches accumulated per optimizer step.
        val_every: Validate on every ``val_every``-th epoch.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)
    # cuDNN autotuning; it pays off when input sizes stay constant.
    torch.backends.cudnn.benchmark = True

    iteration = 1
    logs = []

    for epoch in range(num_epochs):
        t_epoch_start = time.time()
        t_iter_start = time.time()
        epoch_train_loss = 0.0  # sum over batches of (mean batch loss * batch size)
        epoch_val_loss = 0.0
        num_train_seen = 0
        num_val_seen = 0

        print('-------------')
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-------------')

        # ---- training ----
        net.train()
        optimizer.zero_grad()
        print('（train）')
        accumulated = 0  # mini-batches backpropagated since the last optimizer step
        for imges, anno_class_imges in dataloaders_dict["train"]:
            n_imgs = imges.size(0)
            if n_imgs == 1:
                continue

            imges = imges.to(device)
            anno_class_imges = anno_class_imges.to(device)

            outputs = net(imges)
            # Scale so the accumulated gradient is the mean over the group.
            loss = criterion(outputs, anno_class_imges.long()) / batch_multiplier
            loss.backward()
            accumulated += 1
            if accumulated == batch_multiplier:
                optimizer.step()
                optimizer.zero_grad()
                accumulated = 0

            batch_loss = loss.item() * batch_multiplier  # undo the accumulation scaling
            if iteration % 10 == 0:
                t_iter_finish = time.time()
                duration = t_iter_finish - t_iter_start
                print('iterations {} || Loss: {:.4f} || 10iter: {:.4f} sec.'.format(
                    iteration, batch_loss, duration))
                t_iter_start = time.time()

            epoch_train_loss += batch_loss * n_imgs
            num_train_seen += n_imgs
            iteration += 1

        if accumulated:
            # Apply the trailing partial group instead of letting the next
            # epoch's zero_grad() throw it away.
            optimizer.step()
            optimizer.zero_grad()

        # PyTorch convention: step the LR scheduler after the optimizer steps.
        scheduler.step()

        # ---- validation ----
        if (epoch + 1) % val_every == 0:
            net.eval()
            print('-------------')
            print('（val）')
            with torch.no_grad():
                for imges, anno_class_imges in dataloaders_dict["val"]:
                    n_imgs = imges.size(0)
                    if n_imgs == 1:
                        continue
                    imges = imges.to(device)
                    anno_class_imges = anno_class_imges.to(device)
                    outputs = net(imges)
                    loss = criterion(outputs, anno_class_imges.long())
                    epoch_val_loss += loss.item() * n_imgs
                    num_val_seen += n_imgs

        train_loss = epoch_train_loss / num_train_seen if num_train_seen else float('nan')
        val_loss = epoch_val_loss / num_val_seen if num_val_seen else float('nan')

        t_epoch_finish = time.time()
        print('-------------')
        print('epoch {} || Epoch_TRAIN_Loss:{:.4f} ||Epoch_VAL_Loss:{:.4f}'.format(
            epoch+1, train_loss, val_loss))
        print('timer:  {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))

        logs.append({'epoch': epoch+1, 'train_loss': train_loss, 'val_loss': val_loss})
        # Rewritten every epoch, so the log is on disk even if the run stops early.
        pd.DataFrame(logs).to_csv("log_output.csv")

    # num_epochs (not the loop variable) so num_epochs == 0 does not raise NameError.
    torch.save(net.state_dict(), 'weights/pspnet50_' +
               str(num_epochs) + '.pth')

In [None]:
# Launch training for 40 epochs with the network, loaders, loss, optimizer and scheduler built above.
num_epochs = 40
train_model(net, dataloaders_dict, criterion, scheduler, optimizer,
            num_epochs)