## Import

In [None]:
import torchvision
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch import optim
from PIL import Image
import torch.nn.functional as F
import cv2
# from sklearn.model_selection import train_test_split

%matplotlib inline

## Root

In [None]:
# Absolute root of the local research workspace. The dataset directories used
# by this notebook are joined onto it, so adjust it when running elsewhere.
ROOT_PATH = '/home/yasaisen/Desktop/09_research/09_research_main/lab_03'

In [None]:
# Name of the dataset directory located directly under ROOT_PATH.
dataset_folder = 'dataset_C_v_2.9.3'

# Each split has one directory of input images and one directory of masks.
train_img_path = os.path.join(ROOT_PATH, dataset_folder, 'train_for_base_imgs')
train_mask_path = os.path.join(ROOT_PATH, dataset_folder, 'train_for_base_mask')

valid_img_path = os.path.join(ROOT_PATH, dataset_folder, 'valid_imgs')
valid_mask_path = os.path.join(ROOT_PATH, dataset_folder, 'valid_mask')

test_img_path = os.path.join(ROOT_PATH, dataset_folder, 'test_imgs')
test_mask_path = os.path.join(ROOT_PATH, dataset_folder, 'test_mask')

## Aug

In [None]:
# Global hyper-parameters shared by the loaders and the training cells.
img_size = 224   # target image resolution in pixels
train_bsz = 4    # training batch size
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 30      # default number of training epochs
valid_bsz = 8    # validation batch size
test_bsz = 8     # test batch size

## Dataset

In [None]:
def get_df(img_path, mask_path):
    """Build a DataFrame pairing every image file with its mask file.

    The mask file name is derived from the image file name by swapping the
    last ``_``-separated token for ``C4.png``, e.g.
    ``NORMAL_G1_Lid1_LRid293_Gid3133_Bl30.png`` ->
    ``NORMAL_G1_Lid1_LRid293_Gid3133_C4.png``. A name without any ``_``
    maps to plain ``C4.png``.

    Args:
        img_path: Directory whose entries are listed as images.
        mask_path: Directory the derived mask names are joined onto. The
            mask files are not checked for existence.

    Returns:
        A ``pandas.DataFrame`` with string columns ``images`` and ``masks``
        holding full paths, one row per entry of ``img_path``.

    Side effects:
        Prints the number of rows.
    """
    images, masks = [], []

    # os.listdir returns entries in arbitrary order; sorting keeps row indices
    # reproducible between runs and machines.
    for img_name in sorted(os.listdir(img_path)):
        # rpartition touches only the final token. A plain str.replace would
        # also rewrite earlier occurrences of that token inside the name.
        prefix, sep, _ = img_name.rpartition('_')
        images.append(os.path.join(img_path, img_name))
        masks.append(os.path.join(mask_path, prefix + sep + 'C4.png'))

    path_df = pd.DataFrame({'images': images, 'masks': masks})
    print(len(path_df))
    return path_df

In [None]:
# One image/mask path table per split; get_df prints each split's row count.
train_df = get_df(train_img_path, train_mask_path)
valid_df = get_df(valid_img_path, valid_mask_path)
test_df = get_df(test_img_path, test_mask_path)

In [None]:
def plot_example(idx, df):
    """Show the image, its mask and a 30/70 blend of both for row ``idx``.

    Args:
        idx: Positional row index into ``df``.
        df: DataFrame with ``images`` and ``masks`` path columns.

    The blend is a plain element-wise sum of the two pixel arrays, so their
    shapes have to be broadcast-compatible.
    """
    image_pixels = np.array(Image.open(df['images'].iloc[idx]))
    mask_pixels = np.array(Image.open(df['masks'].iloc[idx]))

    fig, axes = plt.subplots(1, 3, figsize=(8,4))

    # 30 % image + 70 % mask, truncated back to 8-bit for display.
    blended = (image_pixels * 0.3 + mask_pixels * 0.7).astype(np.uint8)

    panels = (
        (image_pixels.astype(np.uint8), "Image"),
        (mask_pixels.astype(np.uint8), "Mask"),
        (blended, ''),
    )
    for axis, (pixels, title) in zip(axes, panels):
        axis.imshow(pixels)
        axis.set_title(title)
    plt.show()

In [None]:
# Spot-check a few samples per split. The hard-coded row indices must exist
# in the corresponding DataFrame.
plot_example(0, train_df)
plot_example(50, valid_df)
plot_example(88, test_df)
plot_example(190, train_df)

In [None]:
# Image pipeline handed to the datasets: a single PIL -> tensor conversion.
to_tensor_step = transforms.ToTensor()
transform = transforms.Compose([to_tensor_step])

In [None]:
# mask_path = '/home/yasaisen/Desktop/09_research/09_research_main/lab_03/dataset_C_v_2.9.3/train_for_base_mask/RSLN_L_G10_Lid45_LRid112_Gid7024_C4.png'
# label = Image.open(mask_path)
# label = np.array(label)

In [None]:
class mod_Dataset(Dataset):
    """Segmentation dataset yielding ``(image, one_hot_mask)`` pairs.

    Args:
        path_df: DataFrame with ``images`` and ``masks`` columns of file paths.
        transform: Optional callable applied to the resized RGB PIL image.
            When ``None`` the resized PIL image is returned unchanged.
        num_classes: Number of channels of the one-hot mask. ``None`` (the
            default) keeps the historical behaviour of using ``max(mask) + 1``
            of each individual mask. That makes the channel count vary
            between samples, so pass a fixed value whenever samples are
            batched.
        img_size: Size handed to ``transforms.Resize`` (default 224). With an
            int, torchvision resizes the shorter edge and keeps the aspect
            ratio.
    """

    def __init__(self, path_df, transform=None, num_classes=None, img_size=224):
        self.path_df = path_df
        self.transform = transform
        self.num_classes = num_classes
        # Built once here instead of once per item.
        self.image_resize = transforms.Resize(img_size)
        # Masks hold class indices: nearest-neighbour resizing avoids the
        # blended, non-existent label values an interpolating filter produces.
        self.mask_resize = transforms.Resize(
            img_size, interpolation=transforms.InterpolationMode.NEAREST)

    def __len__(self):
        return self.path_df.shape[0]

    @staticmethod
    def _one_hot(mask, num_classes=None):
        """Convert a 2-D integer label map into a (C, H, W) float32 one-hot array.

        Raises:
            ValueError: If ``mask`` is not 2-D or holds a label that does not
                fit into ``num_classes`` channels.
        """
        mask = np.asarray(mask)
        if mask.ndim != 2:
            raise ValueError(
                'expected a 2-D label mask, got shape %s' % (mask.shape,))
        # int() avoids integer wrap-around when the mask dtype is uint8.
        top_label = int(mask.max())
        if num_classes is None:
            num_classes = top_label + 1
        elif top_label >= num_classes:
            raise ValueError(
                'mask contains label %d but num_classes is %d'
                % (top_label, num_classes))
        class_ids = np.arange(num_classes).reshape(-1, 1, 1)
        # Broadcasting compares every pixel against every class id at once.
        return (mask[np.newaxis] == class_ids).astype(np.float32)

    def __getitem__(self, idx):
        row = self.path_df.iloc[idx]

        # convert() returns a new, fully loaded image, so the file can be
        # closed as soon as the block exits.
        with Image.open(row['images']) as image_file:
            images = self.image_resize(image_file.convert('RGB'))
        if self.transform is not None:
            images = self.transform(images)

        # The pixel data is copied into a numpy array before the file closes.
        with Image.open(row['masks']) as mask_file:
            mask = np.array(self.mask_resize(mask_file))
        masks = torch.from_numpy(self._one_hot(mask, self.num_classes))

        return images, masks

In [None]:
# Wrap each split's path table in a dataset; all three share `transform`.
train_data = mod_Dataset(train_df, transform)
valid_data = mod_Dataset(valid_df, transform)
test_data  = mod_Dataset(test_df, transform)

# Only the training loader shuffles, pins memory and drops the last incomplete
# batch; validation and test keep the dataset order and every sample.
train_loader = DataLoader(train_data, batch_size=train_bsz, shuffle=True , num_workers=0, pin_memory=True, drop_last=True)
valid_loader = DataLoader(valid_data, batch_size=valid_bsz, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_data , batch_size=test_bsz , shuffle=False, num_workers=0)

In [None]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword becomes one panel; its name, with underscores turned into
    spaces and title-cased, is used as the panel title.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(images[name])
    plt.show()

def round(temp):
    """Min-max normalise ``temp`` to [0, 1], then round every value to 0 or 1.

    NOTE: this deliberately keeps its historical name, which shadows the
    builtin ``round`` for the rest of the notebook.

    Args:
        temp: Non-empty numeric array.

    Returns:
        Array of the same shape holding 0.0 / 1.0. A constant input (for
        example an all-zero mask channel) yields all zeros instead of the
        NaNs a division by a zero range would produce.
    """
    lowest = np.min(temp)
    span = np.max(temp) - lowest
    if span == 0:
        # Constant input: (temp - lowest) is already all zeros, so any
        # non-zero divisor gives the desired result without dividing by zero.
        span = 1
    return np.round((temp - lowest) / span)

def yasai_show_v2(dataset, idx, model=None, label=False):
    """Visualise predictions and/or ground-truth channels for one sample.

    Args:
        dataset: Indexable returning ``(image, mask)``; the image is
            transposed from channel-first to channel-last for display and
            the mask is iterated along its first axis.
        idx: Index of the sample to show.
        model: Optional callable producing a prediction for a batch of one
            image. When given, one panel per predicted channel is drawn.
        label: When true, one panel per mask channel is drawn.

    Each channel is binarised with ``round`` and overlaid (40 %) on the first
    image channel (60 %).
    """
    image, mask = dataset[idx]

    pred = None
    if model is not None:
        batch = image.unsqueeze(0)
        is_module = isinstance(model, torch.nn.Module)
        was_training = model.training if is_module else False
        if is_module:
            # Run on the device the weights live on, in eval mode, so
            # visualising does not update normalisation statistics.
            first_param = next(model.parameters(), None)
            if first_param is not None:
                batch = batch.to(first_param.device)
            model.eval()
        try:
            # The forward pass itself must sit inside no_grad. A tensor that
            # requires grad cannot be converted to numpy afterwards.
            with torch.no_grad():
                pred = model(batch)
        finally:
            if is_module:
                model.train(was_training)
        # Drop only the batch axis so a single-class output keeps its
        # channel axis.
        pred = pred.detach().cpu().numpy().squeeze(0)

    image = np.asarray(image).transpose(1, 2, 0)
    mask = np.asarray(mask)
    background = image[..., 0].squeeze()

    if pred is not None:
        tempdict = {'image': image}
        for i in range(pred.shape[0]):
            tempdict['pred_' + str(i)] = 0.4 * round(pred[i]) + 0.6 * background
        visualize(**tempdict)

    if label:
        tempdict = {'image': image}
        for i in range(mask.shape[0]):
            tempdict['mask_' + str(i)] = 0.4 * round(mask[i]) + 0.6 * background
        visualize(**tempdict)

def yasai_model_save_v1(model, text=''):
    """Save ``model`` into the current working directory.

    The file is named ``model_<text><yymmddHHMM>.pt`` and stores a dict with
    the model's ``state_dict`` and the model object itself. The destination
    path is printed afterwards.
    """
    file_name = 'model_' + text + datetime.now().strftime("%y%m%d%H%M.pt")
    target = os.path.join(os.getcwd(), file_name)
    checkpoint = {'state_dict': model.state_dict(), 'model': model}
    torch.save(checkpoint, target)
    print('Successfully saved to ' + target)

def yasai_model_load_v1(path, map_location=None):
    """Load a model from a checkpoint holding ``model`` and ``state_dict``.

    Args:
        path: Checkpoint file containing a dict with the keys ``model`` (the
            pickled model object) and ``state_dict``.
        map_location: Forwarded to ``torch.load``; pass e.g. ``'cpu'`` to open
            a checkpoint saved on a GPU on a machine without one.

    Returns:
        The unpickled model with ``state_dict`` loaded into it.

    SECURITY: the checkpoint contains a pickled model object, so loading it
    executes arbitrary code from the file. Only load checkpoints you trust.
    """
    try:
        # Newer torch releases default to weights_only=True, which refuses to
        # unpickle a whole model object; this checkpoint format needs it off.
        checkpoint = torch.load(path, map_location=map_location, weights_only=False)
    except TypeError:
        # Older torch releases do not know the weights_only keyword.
        checkpoint = torch.load(path, map_location=map_location)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    print('Successfully loaded from ' + path)
    return model

def yasai_compute_iou_v1(pred, label):
    """Return the intersection-over-union of a binarised prediction and a label.

    ``pred`` is binarised with ``round``; ``label`` counts as positive where
    it equals 1.

    Returns:
        ``intersection / union``, or ``None`` when the union is empty or the
        label has no positive element.
    """
    positives = label == 1
    predicted = round(pred) == 1

    overlap = np.logical_and(predicted, positives).sum()
    combined = np.logical_or(predicted, positives).sum()

    # IoU is treated as undefined when there is nothing to match against.
    if combined == 0 or np.sum(positives) == 0:
        return None
    return overlap / combined
    
def yasai_compute_batch_iou_v1(model, data_loader):
    """Print and return the mean per-batch IoU of ``model`` over ``data_loader``.

    Args:
        model: Callable mapping an image batch to a prediction batch.
        data_loader: Iterable of ``(image, mask)`` batches.

    Returns:
        The mean of all defined batch IoUs, or ``None`` when no batch has a
        defined IoU (``yasai_compute_iou_v1`` returns ``None`` for those).
    """
    is_module = isinstance(model, torch.nn.Module)
    first_param = next(model.parameters(), None) if is_module else None
    was_training = model.training if is_module else False
    if is_module:
        model.eval()

    ious = []
    try:
        # The forward pass has to run under no_grad. A tensor that requires
        # grad cannot be converted to numpy afterwards.
        with torch.no_grad():
            for image, mask in tqdm(data_loader, desc='Iterating'):
                if first_param is not None:
                    image = image.to(first_param.device)
                # No squeeze: pred and mask keep identical axes, so the
                # element-wise comparison cannot broadcast across unrelated
                # dimensions.
                pred = model(image).detach().cpu().numpy()
                iou = yasai_compute_iou_v1(pred, np.asarray(mask))
                if iou is not None:
                    ious.append(iou)
    finally:
        if is_module:
            model.train(was_training)

    if not ious:
        print('No batch with a defined IoU')
        return None
    mean_iou = sum(ious) / len(ious)
    print(mean_iou)
    return mean_iou

In [None]:
# Only yasai_show_v2 is defined in this notebook; calling the old v1 name
# raised NameError. With no model given, label=True is what makes the call
# draw anything (the ground-truth channels of the sample).
yasai_show_v2(train_data, 1, None, label=True)

## Model

In [None]:
class Decoder(nn.Module):
    """Decoder stage: upsample, concatenate a skip tensor, then conv + ReLU.

    Args:
        in_channels: Channels of the tensor that gets upsampled.
        middle_channels: Channels after concatenation, i.e. ``out_channels``
            plus the channels of the skip tensor.
        out_channels: Channels produced by the upsampling and by the stage.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # A 2x2 transposed convolution with stride 2 doubles height and width.
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        fuse_conv = nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1)
        self.conv_relu = nn.Sequential(fuse_conv, nn.ReLU(inplace=True))

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate ``x2`` along channels and fuse both."""
        upsampled = self.up(x1)
        merged = torch.cat((upsampled, x2), dim=1)
        return self.conv_relu(merged)

class Unet(nn.Module):
    """U-Net style segmentation network on top of a torchvision ResNet-18.

    Args:
        n_class: Number of output channels of the final 1x1 convolution.

    The encoder reuses the children of the ResNet-18 except for its first
    convolution, which is replaced by a newly constructed 7x7 convolution.
    The full ResNet also stays attached as ``base_model``. Attribute names
    and their creation order are kept as they were, since saved state dicts
    are keyed by these names.
    """

    def __init__(self, n_class):
        super().__init__()

        self.base_model = torchvision.models.resnet18(weights='DEFAULT')
        self.base_layers = list(self.base_model.children())

        # Encoder: new stem convolution followed by the backbone's children.
        stem_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.layer1 = nn.Sequential(stem_conv, *self.base_layers[1:3])
        self.layer2 = nn.Sequential(self.base_layers[3], self.base_layers[4])
        self.layer3, self.layer4, self.layer5 = self.base_layers[5:8]

        # Decoder: the middle channel count is upsampled + skip channels.
        self.decode4 = Decoder(512, 512, 256)   # 256 + 256
        self.decode3 = Decoder(256, 384, 256)   # 256 + 128
        self.decode2 = Decoder(256, 192, 128)   # 128 + 64
        self.decode1 = Decoder(128, 128, 64)    # 64 + 64
        self.decode0 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
        )
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        """Return an ``n_class``-channel map for an image batch.

        The resolution notes below assume the standard ResNet-18 strides
        (TODO confirm against the installed torchvision).
        """
        # Encoder path; every stage output is kept as a skip connection.
        skip1 = self.layer1(input)       # 64 channels, presumably 1/2 resolution
        skip2 = self.layer2(skip1)       # 64 channels, presumably 1/4
        skip3 = self.layer3(skip2)       # 128 channels, presumably 1/8
        skip4 = self.layer4(skip3)       # 256 channels, presumably 1/16
        bottleneck = self.layer5(skip4)  # 512 channels, presumably 1/32

        # Decoder path; each stage doubles the resolution and fuses one skip.
        up4 = self.decode4(bottleneck, skip4)  # 256 channels
        up3 = self.decode3(up4, skip3)         # 256 channels
        up2 = self.decode2(up3, skip2)         # 128 channels
        up1 = self.decode1(up2, skip1)         # 64 channels
        full_res = self.decode0(up1)           # 64 channels, upsampled 2x again
        return self.conv_last(full_res)        # n_class channels

In [None]:
# Load a previously saved checkpoint from an absolute local path. The file is
# expected to contain both the pickled model object and its state dict.
model = yasai_model_load_v1('/home/yasaisen/Desktop/09_research/09_research_main/lab_10/model_bast_ima_2305151053.pt')

In [None]:
# import segmentation_models_pytorch as smp

In [None]:
# model = smp.Unet(
#     in_channels=3,
#     classes=4,
#     activation="softmax").to(device)

In [None]:
# model = Unet(4).to(device)
# # print(model)
# t = torch.randn((4, 3, 224, 224)).to(device)
# print(t.shape)
# get = model(t)
# print(get.shape)

# for x, y in train_loader:
#     print(x.shape)
#     print(y.shape)
#     break

## Train

In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss: ``alpha * (1 - pt) ** gamma * BCE``.

    Args:
        alpha: Constant scaling factor of the loss.
        gamma: Exponent of the ``(1 - pt)`` modulating term.
        logits: When true ``inputs`` are treated as raw logits, otherwise as
            probabilities.
        reduction: When true the mean over all elements is returned,
            otherwise the element-wise loss tensor.
    """

    def __init__(self, alpha=1, gamma=2, logits=False, reduction=True):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma
        self.logits, self.reduction = logits, reduction

    def forward(self, inputs, targets):
        if self.logits:
            bce_fn = F.binary_cross_entropy_with_logits
        else:
            bce_fn = F.binary_cross_entropy
        bce = bce_fn(inputs, targets, reduction='none')

        # exp(-BCE) recovers the probability assigned to the true class.
        prob_true = torch.exp(-bce)
        focal = self.alpha * (1 - prob_true) ** self.gamma * bce

        return focal.mean() if self.reduction else focal

In [None]:
def dice_coef_metric(pred, label):
    """Return the Dice coefficient ``2 * sum(pred * label) / (sum(pred) + sum(label))``.

    When both ``pred`` and ``label`` sum to zero the result is defined as 1.
    """
    pred_total = pred.sum()
    label_total = label.sum()
    overlap = (pred * label).sum()
    if pred_total == 0 and label_total == 0:
        return 1
    return 2.0 * overlap / (pred_total + label_total)

In [None]:
def train_loop(model, optimizer, criterion, train_loader, device=device):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Network to optimise; switched to train mode.
        optimizer: Optimiser stepping the model parameters.
        criterion: Loss called as ``criterion(output, masks)``.
        train_loader: Iterable of ``(imgs, masks)`` batches.
        device: Device the batches are moved to.

    Returns:
        Dict with ``'dice coef'`` (mean of the per-batch Dice scores) and
        ``'loss'`` (mean loss per processed sample).

    Raises:
        ValueError: If ``train_loader`` yields no batch.
    """
    model.train()
    running_loss = 0
    final_dice_coef = 0
    seen_samples = 0
    num_batches = 0

    for imgs, masks in tqdm(train_loader, desc='Iterating over train data'):
        imgs = imgs.to(device)
        masks = masks.to(device)

        # forward
        out = model(imgs)
        loss = criterion(out, masks)

        # optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = imgs.shape[0]
        running_loss += loss.item() * batch_size
        seen_samples += batch_size
        num_batches += 1

        # NOTE(review): the raw model output is thresholded at 0.5; confirm
        # the model emits probabilities rather than logits.
        out_cut = (out.detach().cpu().numpy() >= 0.5).astype(np.float32)
        final_dice_coef += dice_coef_metric(out_cut, masks.detach().cpu().numpy())

    if num_batches == 0:
        raise ValueError('train_loader yielded no batches')

    # Average over the samples actually processed. Dividing by the sampler
    # length would count samples that a drop_last loader never yields.
    return {'dice coef': final_dice_coef / num_batches,
            'loss': running_loss / seen_samples}

In [None]:
def eval_loop(model, criterion, eval_loader, device=device):
    """Evaluate ``model`` on ``eval_loader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to eval mode.
        criterion: Loss called as ``criterion(output, masks)``.
        eval_loader: Iterable of ``(imgs, masks)`` batches.
        device: Device the batches are moved to.

    Returns:
        Dict with ``'dice coef'`` (mean of the per-batch Dice scores) and
        ``'loss'`` (mean loss per processed sample).

    Raises:
        ValueError: If ``eval_loader`` yields no batch.
    """
    running_loss = 0
    final_dice_coef = 0
    seen_samples = 0
    num_batches = 0

    model.eval()
    with torch.no_grad():
        for imgs, masks in tqdm(eval_loader, desc='Interating over evaluation data'):
            imgs = imgs.to(device)
            masks = masks.to(device)

            out = model(imgs)
            loss = criterion(out, masks)

            batch_size = imgs.shape[0]
            running_loss += loss.item() * batch_size
            seen_samples += batch_size
            num_batches += 1

            # NOTE(review): the raw model output is thresholded at 0.5;
            # confirm the model emits probabilities rather than logits.
            out_cut = (out.detach().cpu().numpy() >= 0.5).astype(np.float32)
            final_dice_coef += dice_coef_metric(out_cut, masks.detach().cpu().numpy())

    if num_batches == 0:
        raise ValueError('eval_loader yielded no batches')

    # Average over the samples actually processed, so the result stays
    # correct for loaders that drop or re-batch samples.
    return {'dice coef': final_dice_coef / num_batches,
            'loss': running_loss / seen_samples}

In [None]:
def train(model, optimizer, criterion, scheduler, train_loader,
          valid_loader, device=device,
          num_epochs=epochs,
          valid_loss_min=np.inf):
    """Run the full training schedule and checkpoint the best model.

    Every epoch runs one training pass and one validation pass, steps the
    scheduler with the validation Dice score, prints the metrics and saves
    the state dict to ``UNET.pt`` whenever the validation loss is not worse
    than the best seen so far (starting from ``valid_loss_min``).

    Returns:
        ``[train_loss_list, train_dice_coef, val_loss_list, val_dice_coef]``,
        each a list with one entry per epoch.
    """
    history = {'train_loss': [], 'train_dice': [], 'val_loss': [], 'val_dice': []}
    best_val_loss = valid_loss_min

    for epoch_number in range(1, num_epochs + 1):
        train_metrics = train_loop(model, optimizer, criterion, train_loader, device=device)
        val_metrics = eval_loop(model, criterion, valid_loader, device=device)

        scheduler.step(val_metrics['dice coef'])

        history['train_loss'].append(train_metrics['loss'])
        history['train_dice'].append(train_metrics['dice coef'])
        history['val_loss'].append(val_metrics['loss'])
        history['val_dice'].append(val_metrics['dice coef'])

        print(
            f"Epoch: {epoch_number}\n"
            f"Train Loss: {train_metrics['loss']:.5f}\n"
            f"Train Dice Coef: {train_metrics['dice coef']:.5f}\n"
            f"Valid Loss: {val_metrics['loss']:.5f}\n"
            f"Valid Dice Coef: {val_metrics['dice coef']:.5f}\n"
        )

        # Keep the weights of the epoch with the lowest validation loss.
        if val_metrics["loss"] <= best_val_loss:
            torch.save(model.state_dict(), "UNET.pt")
            best_val_loss = val_metrics["loss"]

    return [history['train_loss'],
            history['train_dice'],
            history['val_loss'],
            history['val_dice']]

In [None]:
# optimizer = optim.Adam(model.parameters(), lr=0.01)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=3)
# # criterion = nn.BCELoss(reduction='mean')
# criterion = FocalLoss()
# train_loss_list, train_dice_coef,val_loss_list,val_dice_coef = train(
#     model, optimizer, criterion, scheduler, train_loader, valid_loader)

In [None]:
# yasai_compute_batch_iou(model, test_loader)

In [None]:
# yasai_compute_batch_iou(model, valid_loader)

In [None]:
# Visualise predictions for at most the first 1500 test samples. The bound is
# clamped to the dataset size so indexing cannot run past the last sample.
for i in range(min(1500, len(test_data))):
    yasai_show_v2(test_data, i, model)

## Eval

In [None]:
def plot_predictions(model, idx, transforms):
    """Plot image, mask, raw prediction and rounded prediction for one test row.

    Args:
        model: Network to run; switched to eval mode.
        idx: Positional row index into the global ``test_df``.
        transforms: Callable turning the PIL image into a tensor. This
            parameter shadows the ``torchvision.transforms`` module inside
            the function, so the module is referenced by its full path.

    The image and mask are resized to the global ``img_size`` before use so
    the model sees the same resolution as during training. Prints the
    maximum of the raw prediction.
    """
    resize_image = torchvision.transforms.Resize(img_size)
    # Nearest-neighbour keeps mask values unblended.
    resize_mask = torchvision.transforms.Resize(
        img_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)

    img = resize_image(Image.open(test_df['images'].iloc[idx]).convert('RGB'))
    mask = resize_mask(Image.open(test_df['masks'].iloc[idx]))

    tensor_img = transforms(img)
    tensor_img = tensor_img.unsqueeze(0).to(device)

    model.eval()

    with torch.no_grad():
        pred = model(tensor_img)[0].detach().cpu().numpy()
        # Channel-first -> channel-last for imshow; squeeze drops a single
        # channel axis.
        pred = pred.transpose((1,2,0)).squeeze()
        print(np.max(pred))
        rounded = np.round(pred)

    plot_images = {
        'Image': img,
        'Mask': mask,
        'Predicted Mask': pred,
        'Predicted Rounded Mask':rounded
    }

    fig, ax = plt.subplots(1, 4, figsize=(16,4))
    for i, (key, picture) in enumerate(plot_images.items()):
        ax[i].imshow(picture)
        ax[i].set_title(key)

    plt.show()

In [None]:
# Loss curves. `train_loss_list` / `val_loss_list` are the lists returned by
# train(), so the training cell has to be run first. The x-axis is sized from
# the recorded history rather than the global `epochs`, which keeps the plot
# valid when training ran for a different number of epochs.
plt.plot(np.arange(1, len(train_loss_list) + 1), train_loss_list, label="train loss")
plt.plot(np.arange(1, len(val_loss_list) + 1), val_loss_list, label="val loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Training and validation loss")
plt.show()

In [None]:
# Dice curves. `train_dice_coef` / `val_dice_coef` are the lists returned by
# train(), so the training cell has to be run first. The x-axis is sized from
# the recorded history rather than the global `epochs`, which keeps the plot
# valid when training ran for a different number of epochs.
plt.plot(np.arange(1, len(train_dice_coef) + 1), train_dice_coef, label="train dice score")
plt.plot(np.arange(1, len(val_dice_coef) + 1), val_dice_coef, label="val dice score")
plt.xlabel("Epoch")
plt.ylabel("Dice")
plt.legend()
plt.title("Training and validation Dice Score")
plt.show()

In [None]:
# Show predictions for a few hand-picked rows of the test split. The indices
# must exist in test_df.
plot_predictions(model, 59, transform)
plot_predictions(model, 0, transform)
plot_predictions(model, 26, transform)
plot_predictions(model, 3, transform)