# Import Libraries

In [1]:
import os
import sys
import shutil
import copy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Adam


import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
from torchvision import transforms
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn.functional as F
#import torchvision.transforms.functional as F
#from torchvision.datasets import OxfordIIITPet

from PIL import Image
import io
from io import BytesIO 

import cv2

from tqdm import tqdm
from tqdm.notebook import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

from albumentations import (
    Compose, OneOf, Normalize, CenterCrop, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, RandomRotate90, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, HueSaturationValue, CoarseDropout,GridDropout
    )

from IPython.display import Image, display
from tensorflow.keras.preprocessing.image import load_img
import PIL
from PIL import Image as Im
from PIL import ImageOps

# Prefer the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Configurations

In [2]:
 class CFG:
        """Notebook-wide configuration: dataset paths, device and training sizes."""
        BASE_PATH = '../input/oxfordiiitpet-dataset/'
        IMAGE_PATH = '../input/oxfordiiitpet-dataset/images/images/'
        ANNOTATION_PATH = '../input/oxfordiiitpet-dataset/annotations/annotations/'

        DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        IMAGE_SIZE = 128  # images and masks are resized to IMAGE_SIZE x IMAGE_SIZE
        BATCH_SIZE = 64

        @staticmethod
        def to_numpy(tensor):
            """Return ``tensor`` as a NumPy array on the host.

            ``detach()`` is harmless for tensors that do not require grad, so it is
            applied unconditionally. ``@staticmethod`` makes the helper callable both
            as ``CFG.to_numpy(t)`` and on an instance.
            """
            return tensor.detach().cpu().numpy()

# Create DataFrame

In [3]:
# Load the annotation list into a DataFrame with one row per image.
# NOTE(review): skiprows=5 makes pandas use the 6th line of list.txt as the header
# row; that line is presumably a '#' comment, not real column names, so only the
# first 4 columns are kept and then renamed below. Confirm against the file.
all_df = pd.read_table(CFG.ANNOTATION_PATH+'list.txt',sep=' ',skiprows=5 )
all_df = all_df.iloc[:,:4]
all_df = all_df.set_axis(['image_name','id','species','breed'],axis=1)
all_df

In [4]:
# Non-null counts per column for each class id, i.e. the number of images per class.
all_df.groupby("id").count()

In [5]:
# Set DEBUG = True to iterate quickly on a random half of the annotations.
DEBUG = False
all_df = all_df.sample(frac=0.5).reset_index(drop=True) if DEBUG else all_df

In [6]:
# Re-display the annotation table (possibly subsampled by DEBUG above).
all_df

# Split into training and validation sets (70:30)

In [7]:
from sklearn.model_selection import train_test_split

# 70/30 split. A fixed seed makes the split reproducible across runs. Stratifying
# on the class id keeps every breed in (approximately) the same proportion in
# both sets.
SPLIT_SEED = 42
train_df, valid_df = train_test_split(
    all_df, test_size=0.3, random_state=SPLIT_SEED, stratify=all_df["id"]
)
print(train_df.shape, valid_df.shape)

In [8]:
# train_test_split keeps the original row labels. Reset them to a 0..n-1
# RangeIndex so label lookups such as df["image_name"][idx] line up with
# row positions.
train_df = train_df.reset_index(drop=True)
train_df

In [9]:
# Same as for train_df: a 0..n-1 RangeIndex so label lookups match positions.
valid_df = valid_df.reset_index(drop=True)
valid_df

# Visualize images

In [10]:
# Normal image: open the first image listed in all_df, show it, and inspect its array shape.
image = Im.open(f'{CFG.IMAGE_PATH}{all_df["image_name"][0]}.jpg')
display(image)
i = np.array(image)
i.shape

In [11]:
# Mask image. The trimap is rescaled with autocontrast, so its three labels
# become the gray levels 0, 127 and 255.
mask_image = PIL.ImageOps.autocontrast(load_img(CFG.ANNOTATION_PATH+'trimaps/'+all_df["image_name"][0]+'.png'))
display(mask_image)

# load_img loads in RGB by default. Collapse to one gray channel so the mask
# is 2-D (H, W), like the masks used for training.
mask = np.array(mask_image.convert('L'))
print(mask.shape)

obj_ids = np.unique(mask)
print(obj_ids)

# Convert [0,127,255] -> [0,1,2] (0 = pet, 1 = background, 2 = border), the
# same encoding the dataset class uses. 0 already maps to 0. The new values 1
# and 2 never collide with a source value still to be replaced, so in-place
# replacement is safe.
mask[mask == 127] = 1
mask[mask == 255] = 2
print(mask)
print(np.unique(mask))

In [12]:
# Both previews are PIL images (the second one came through keras' load_img).
print(*map(type, (image, mask_image)), sep='\n')

# Image and mask side by side

In [13]:
def get_concat_h(im1, im2):
    """Paste ``im2`` to the right of ``im1`` and return the combined RGB image.

    The canvas is as tall as the taller input, so neither image is cropped when
    the heights differ. Both images are top-aligned; unused area stays black.
    """
    dst = Im.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst

get_concat_h(image, mask_image)

# Augmentation function

In [14]:
def get_transform(data):
    """Build the albumentations pipeline for one dataset split.

    Args:
        data: ``'train'`` (resize + random horizontal flip) or ``'valid'``
            (resize only).

    Returns:
        An ``A.Compose`` that resizes to ``CFG.IMAGE_SIZE`` x ``CFG.IMAGE_SIZE``,
        normalizes with ``A.Normalize()``'s defaults and converts to a tensor with
        ``ToTensorV2``.

    Raises:
        ValueError: for any other split name. Returning ``None`` here would make
            the dataset skip the transform and fail later with a less obvious error.
    """
    if data not in ('train', 'valid'):
        raise ValueError(f"unknown split {data!r}; expected 'train' or 'valid'")

    steps = [A.Resize(CFG.IMAGE_SIZE, CFG.IMAGE_SIZE)]
    if data == 'train':
        steps.append(A.HorizontalFlip(p=0.5))
    steps += [A.Normalize(), ToTensorV2()]
    return A.Compose(steps)

In [15]:
# Total number of annotated images.
len(all_df)

In [16]:
# Sanity-check the image and trimap paths the dataset class builds for the first row.
image_path = CFG.IMAGE_PATH
mask_path = CFG.ANNOTATION_PATH
print(os.path.join(image_path, all_df["image_name"][0]+'.jpg'))
print(os.path.join(mask_path, 'trimaps/'+all_df["image_name"][0]+'.png'))

# Dataset

In [17]:
class TrainDataset(Dataset):
    """Oxford-IIIT Pet images paired with their trimap segmentation masks.

    Each item is ``(image, mask, image_path, mask_path)``. ``mask`` has a leading
    channel axis and holds class ids 0 = pet, 1 = background, 2 = border.
    """

    def __init__(self, image_paths, mask_paths, df, transforms=None):
        """
        Args:
            image_paths: directory containing ``<image_name>.jpg`` files.
            mask_paths: annotation directory containing ``trimaps/<image_name>.png``.
            df: DataFrame with an ``image_name`` column. Rows are addressed by
                position, so its index does not need to be reset.
            transforms: optional albumentations pipeline, called as
                ``transforms(image=..., mask=...)``.
        """
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.df = df
        # Positional list: independent of the DataFrame's index labels.
        self.image_names = df["image_name"].tolist()
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _trimap_to_classes(trimap):
        """Map raw trimap labels {1, 2, 3} to class ids {0, 1, 2} as uint8.

        NOTE: assumes the raw trimap PNGs store 1 = pet, 2 = background,
        3 = border, which gives the 0/1/2 encoding used throughout this notebook.
        Mapping the raw labels directly does not depend on which labels happen to
        be present in a given image, unlike a per-image autocontrast stretch.

        Raises:
            ValueError: if a label outside 1..3 is present.
        """
        trimap = np.asarray(trimap)
        if trimap.size and (trimap.min() < 1 or trimap.max() > 3):
            raise ValueError(
                f'unexpected trimap labels {np.unique(trimap).tolist()}; expected values in 1..3'
            )
        # Subtract in a signed type, then store as uint8 like an 8-bit PNG array.
        return (trimap.astype(np.int16) - 1).astype(np.uint8)

    def _load_image(self, path):
        """Read an image file as an (H, W, 3) uint8 RGB array and close the file."""
        # convert('RGB') also handles grayscale/RGBA/CMYK files. A plain
        # ``[:, :, :3]`` slice fails on 2-D grayscale arrays and mislabels CMYK.
        with Im.open(path) as img:
            return np.array(img.convert('RGB'))

    def _load_mask(self, path):
        """Read a trimap PNG and return its (H, W) class-id mask."""
        with Im.open(path) as img:
            trimap = np.array(img.convert('L'))
        return self._trimap_to_classes(trimap)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_paths, image_name + '.jpg')
        mask_path = os.path.join(self.mask_paths, 'trimaps/' + image_name + '.png')

        image = self._load_image(image_path)
        mask = self._load_mask(mask_path)

        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']
        else:
            # No pipeline: still return tensors (image as CHW) so the code below
            # and the DataLoader work.
            image = torch.from_numpy(image.transpose(2, 0, 1).copy())
            mask = torch.from_numpy(mask)

        # (H, W) -> (1, H, W); the training loop squeezes this axis again.
        mask = mask.unsqueeze(0)

        return image, mask, image_path, mask_path

In [18]:
# One dataset per split; only the training split gets the random flip augmentation.
train_dataset = TrainDataset(
    image_paths=CFG.IMAGE_PATH,
    mask_paths=CFG.ANNOTATION_PATH,
    df=train_df,
    transforms=get_transform('train'),
)
valid_dataset = TrainDataset(
    image_paths=CFG.IMAGE_PATH,
    mask_paths=CFG.ANNOTATION_PATH,
    df=valid_df,
    transforms=get_transform('valid'),
)

In [19]:
# Check the number of CPU cores (used as num_workers for the DataLoaders below).
import os
os.cpu_count()  # number of cores

In [20]:
# Shuffle the training set every epoch. Without shuffling, every epoch sees the
# same batches in the same order. drop_last keeps every batch at exactly
# CFG.BATCH_SIZE samples.
train_loader = DataLoader(train_dataset, CFG.BATCH_SIZE, shuffle=True, num_workers=os.cpu_count(), pin_memory=True, drop_last=True)
train_dataset[0][2:]

In [21]:
# Validation order is fixed. NOTE(review): drop_last=True discards the final
# partial batch, so a few validation images are never scored.
valid_loader = DataLoader(
    valid_dataset,
    batch_size=CFG.BATCH_SIZE,
    shuffle=False,
    num_workers=os.cpu_count(),
    pin_memory=True,
    drop_last=True,
)
valid_dataset[0]

In [22]:
# Paths of the first validation sample: the last two fields of a dataset item.
best_f_image_path, best_f_mask_path = valid_dataset[0][-2:]
best_f_image_path, best_f_mask_path

# Define U-Net model

## U-Net building blocks

In [23]:
#pytorch official
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps H and W unchanged.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two stages; falls back to
            ``out_channels`` when not given (or falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        # Convolutions have no bias: the following BatchNorm adds its own shift.
        layers = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder step: 2x2 max-pool (halves H and W), then a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder step: upsample ``x1``, pad it to ``x2``'s size, concatenate, DoubleConv.

    Args:
        in_channels: channel count expected after concatenating ``x2`` and the
            upsampled ``x1``.
        out_channels: channels produced by the DoubleConv.
        bilinear: if True, use parameter-free bilinear upsampling; otherwise use
            a stride-2 transposed convolution that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Pad x1 so odd-sized inputs still line up with the skip connection x2.
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]
        padding = [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2]
        x1 = F.pad(x1, padding)

        return self.conv(torch.cat([x2, x1], dim=1))
    
class OutConv(nn.Module):
    """1x1 convolution mapping feature channels to one logit map per class."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

## Full U-Net model

In [24]:
#pytorch official
class UNet(nn.Module):
    """U-Net: four down-sampling and four up-sampling stages joined by skip connections.

    Args:
        n_channels: channels of the input image.
        n_classes: number of output classes; ``forward`` returns one logit map per class.
        bilinear: use bilinear upsampling in ``Up`` instead of transposed
            convolutions. The deepest blocks are then built with half the
            channels (``factor``).
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        factor = 2 if bilinear else 1

        # Encoder. Registration order is unchanged, so parameter order and
        # state_dict keys stay the same.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        # Decoder.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every stage output above the bottleneck for the skips.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3):
            skips.append(down(skips[-1]))
        x = self.down4(skips[-1])
        # Decoder: consume the skips deepest-first.
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            x = up(x, skip)
        return self.outc(x)

# Custom metrics

In [25]:
def Dice_Coeff(x, y, flag):
    """Per-sample Dice coefficient for one class of a 3-class segmentation.

    Args:
        x: predicted class-id tensor shaped (batch, ...), e.g. the argmax over logits.
        y: ground-truth class-id tensor with the same shape as ``x``.
        flag: class to score: 1 = class id 0 (foreground), 2 = class id 1
            (background), 3 = class id 2 (border).

    Returns:
        A list with one 0-d float tensor per sample, ``2|X∩Y| / (|X| + |Y|)``.
        When the class is absent from both prediction and target, the sample
        scores 1.0 (perfect agreement) instead of 0/0 = NaN.

    Raises:
        ValueError: if ``flag`` is not 1, 2 or 3.
    """
    if flag not in (1, 2, 3):
        raise ValueError(f"flag must be 1, 2 or 3, got {flag!r}")
    class_id = flag - 1

    n = x.shape[0]  # real batch size, so a short final batch is scored correctly
    if n == 0:
        return []

    # Binary masks of the selected class on the CPU (inputs may live on the GPU).
    pred = (x.detach().cpu() == class_id).reshape(n, -1)
    true = (y.detach().cpu() == class_id).reshape(n, -1)

    intersection = (pred & true).sum(dim=1)
    total = pred.sum(dim=1) + true.sum(dim=1)

    dice = torch.where(
        total > 0,
        2 * intersection / total.clamp(min=1),
        torch.ones_like(total, dtype=torch.float32),
    )
    return list(dice.unbind(0))

# Visualize image (input and mask)

In [26]:
def show_pred_mask(original,true_tensor,out_tensor,cnt):
        unnormalize = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.255],
        std=[1/0.229, 1/0.224, 1/0.255]
        )
        original = unnormalize(original)

        y = CFG.to_numpy(true_tensor)
        z = CFG.to_numpy(out_tensor)

        original = original.permute(2,1,0)
        original = CFG.to_numpy(original)
        original = np.rot90(original)
        original = np.flipud(original)
        
        #convert [0,127,255] -> [0,1,2] #0が黒→物体内部、127がグレー→背景、255が白→物体輪郭
        #mask_images[mask_images == 127] = 1
        #mask_images[mask_images == 255] = 2

        y[y==0] = 255
        y[y==1] = 127
        y[y==2] = 0

        z[z==0] = 255
        z[z==1] = 127
        z[z==2] = 0

        n_data = 3
        row=1
        col=3
        
        fig, ax = plt.subplots(nrows=row, ncols=col,figsize=(15,18))
        #input image
        ax[0].imshow(original)
        #truth mask
        ax[1].imshow(y,cmap='Greys')
        #pred mask
        ax[2].imshow(z,cmap='Greys')
        plt.show()
        


# Save the best, median, and worst samples
For each one, keep:
* the original (input) image tensor
* the predicted mask tensor
* the ground-truth mask tensor

# Training

In [27]:
import statistics

# Per-epoch averages, appended by training_model once per epoch for each phase.
plot_train_loss, plot_train_dice1, plot_train_dice2, plot_train_dice3 = [], [], [], []
plot_valid_loss, plot_valid_dice1, plot_valid_dice2, plot_valid_dice3 = [], [], [], []

# Per-batch values appended by training_model: loss (L) and Dice for
# foreground (D1), background (D2) and border (D3); _t = train, _v = valid.
L_t, D1_t, L_v, D1_v = [], [], [], []
D2_t, D2_v = [], []
D3_t, D3_v = [], []


def training_model(model, datasets, dataloaders, criterion, optimizer, num_epochs, device):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Args:
        model: network returning per-class logits shaped (batch, classes, H, W).
        datasets: unused; kept so existing calls keep working.
        dataloaders: dict with ``'train'`` and ``'valid'`` loaders yielding
            ``(images, masks, image_paths, mask_paths)``, masks shaped (batch, 1, H, W).
        criterion: loss taking raw logits and integer class masks, e.g.
            ``nn.CrossEntropyLoss`` (which applies log-softmax itself).
        optimizer: optimizer over the model's parameters.
        num_epochs: number of epochs.
        device: device the model and every batch are moved to.

    Side effects:
        - Appends per-batch metrics to the global L_*/D*_* lists and per-epoch
          averages to the global plot_* lists.
        - Prints progress and shows figures via show_pred_mask.

    Returns:
        The trained model.
    """
    import statistics

    # Dice_Coeff flag -> (log label, wording used in the "at the moment" messages).
    class_info = {
        1: ('Foreground', 'Dice'),
        2: ('Background', 'Dice'),
        3: ('Border', 'Border Dice'),
    }
    # Global history lists: (loss, dice1, dice2, dice3) per batch ...
    batch_history = {
        'train': (L_t, D1_t, D2_t, D3_t),
        'valid': (L_v, D1_v, D2_v, D3_v),
    }
    # ... and per epoch.
    epoch_history = {
        'train': (plot_train_loss, plot_train_dice1, plot_train_dice2, plot_train_dice3),
        'valid': (plot_valid_loss, plot_valid_dice1, plot_valid_dice2, plot_valid_dice3),
    }

    # Validation statistics over every validation batch of every epoch.
    best_loss, worst_loss = float('inf'), float('-inf')
    best_dice = {k: 0.0 for k in class_info}
    worst_dice = {k: 1.0 for k in class_info}
    valid_losses = []
    valid_dice = {k: [] for k in class_info}

    # Samples (image, true mask, predicted mask) with the highest / lowest summed
    # 3-class Dice seen, plus the one at the highest per-batch median. These are
    # tracked over both phases.
    max_value, min_value, med_value = 0, 10**9, 0
    max_sample = min_sample = med_sample = None

    def _snapshot(images, true_masks, pred_masks, i):
        # Copy one sample to the CPU so the stored tensors do not keep whole
        # (possibly GPU-resident) batches alive.
        return tuple(t[i].detach().cpu().clone() for t in (images, true_masks, pred_masks))

    model = model.to(device)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print('-'*10)

        for phase in ['train', 'valid']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            loader = dataloaders[phase]
            running_loss = 0.0
            running_dice = {k: 0.0 for k in class_info}

            for cnt, (inputs, masks, _im_path, _ms_path) in enumerate(tqdm(loader), start=1):
                original = inputs  # normalized images, still on the CPU, for plotting
                inputs = inputs.to(device=device, dtype=torch.float32)
                # (batch, 1, H, W) -> (batch, H, W) class ids, as CrossEntropyLoss expects.
                masks = masks.to(device=device, dtype=torch.long).squeeze(1)

                if is_train:
                    optimizer.zero_grad()
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    # Feed raw logits: CrossEntropyLoss applies log-softmax itself,
                    # so a softmax here would be applied twice.
                    loss = criterion(outputs, masks)
                if is_train:
                    # Update on every training batch; nothing below may skip this.
                    loss.backward()
                    optimizer.step()

                with torch.no_grad():
                    _, pred = torch.max(outputs, dim=1)  # predicted class id per pixel
                    dice_lists = {k: Dice_Coeff(pred, masks, k) for k in class_info}
                # Batch-mean Dice per class, and summed 3-class Dice per sample.
                batch_dice = {k: float(sum(v) / len(v)) for k, v in dice_lists.items()}
                totals = [float(a + b + c) for a, b, c in zip(dice_lists[1], dice_lists[2], dice_lists[3])]

                if totals:
                    batch_max = max(totals)
                    if batch_max > max_value:
                        max_value = batch_max
                        max_sample = _snapshot(original, masks, pred, totals.index(batch_max))
                    batch_min = min(totals)
                    if batch_min < min_value:
                        min_value = batch_min
                        min_sample = _snapshot(original, masks, pred, totals.index(batch_min))
                    # median_low always returns an element of the list, so
                    # .index() cannot fail (median() may average two values).
                    batch_med = statistics.median_low(totals)
                    if batch_med > med_value:
                        med_value = batch_med
                        med_sample = _snapshot(original, masks, pred, totals.index(batch_med))

                loss_value = loss.item()
                running_loss += loss_value
                loss_log, *dice_logs = batch_history[phase]
                loss_log.append(loss_value)
                for k, log in zip(class_info, dice_logs):
                    running_dice[k] += batch_dice[k]
                    log.append(batch_dice[k])

                # Preview the first two samples of the first batch of each phase.
                if cnt == 1:
                    for i in range(min(2, len(original))):
                        show_pred_mask(original[i], masks[i], pred[i], cnt)

                if not is_train:
                    valid_losses.append(loss_value)
                    best_loss = min(best_loss, loss_value)
                    worst_loss = max(worst_loss, loss_value)
                    for k, (label, what) in class_info.items():
                        value = batch_dice[k]
                        valid_dice[k].append(value)
                        if value > best_dice[k]:
                            best_dice[k] = value
                            print(f'[{label}] Best {what} Result at the moment EPOCH:{epoch+1} Dice:{value}')
                        if value < worst_dice[k]:
                            worst_dice[k] = value
                            print(f'[{label}] Worst {what} Result at the moment EPOCH:{epoch+1} Dice:{value}')

            n_batches = max(len(loader), 1)
            epoch_loss = running_loss / n_batches
            epoch_dice1, epoch_dice2, epoch_dice3 = (running_dice[k] / n_batches for k in class_info)
            for log, value in zip(epoch_history[phase], (epoch_loss, epoch_dice1, epoch_dice2, epoch_dice3)):
                log.append(value)

            print(f'{phase} Loss: {epoch_loss} Foreground Dice: {epoch_dice1} Background Dice: {epoch_dice2} Border Dice: {epoch_dice3}')

        print()

    for title, sample in (('max dice image', max_sample),
                          ('median dice image', med_sample),
                          ('min dice image', min_sample)):
        print(title)
        if sample is not None:  # None when no batch beat the initial threshold
            show_pred_mask(*sample, 0)

    def _median(values):
        # True median of the recorded validation batches (NaN if there were none).
        return statistics.median(values) if values else float('nan')

    print(f'Best val Loss:{best_loss}')
    print(f'Median val Loss:{_median(valid_losses)}')
    print(f'Worst val Loss:{worst_loss}')

    for k, (label, _) in class_info.items():
        print(f'[{label}] Best val Dice:{best_dice[k]}　')
        print(f'[{label}] Median val Dice:{_median(valid_dice[k])}　')
        print(f'[{label}] Worst val Dice:{worst_dice[k]}　')

    return model

In [28]:
datasets = dict(train=train_dataset, valid=valid_dataset)
dataloaders = dict(train=train_loader, valid=valid_loader)

# RGB input (3 channels) and 3 output classes: pet / background / border.
model = UNet(n_channels=3, n_classes=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()
num_epochs = 30

In [29]:
# Run training; per-batch and per-epoch metrics accumulate in the global history lists defined above.
trained_model = training_model(model,datasets, dataloaders, criterion, optimizer, num_epochs, CFG.DEVICE)

In [30]:
# Distribution of per-batch validation Dice for each class.
data = (list(np.array(D1_v)),list(np.array(D2_v)),list(np.array(D3_v)))
fig,ax = plt.subplots()
ax.boxplot(data)
# Label the ticks only after boxplot() has placed them at x = 1, 2, 3.
ax.set_xticklabels(['Foreground', 'Background', 'Border'])
ax.set_title('Validation Dice')
ax.set_ylabel('Dice')
plt.show()


In [41]:
# Distribution of per-batch training Dice for each class.
data = (list(np.array(D1_t)),list(np.array(D2_t)),list(np.array(D3_t)))
fig,ax = plt.subplots()
ax.boxplot(data)
# Label the ticks only after boxplot() has placed them at x = 1, 2, 3.
ax.set_xticklabels(['Foreground', 'Background', 'Border'])
ax.set_title('Training Dice')
ax.set_ylabel('Dice')
plt.show()

In [46]:
# Per-epoch mean loss and validation Dice. The x-axis follows the length of the
# recorded history, and the loss range is left to autoscale so no point is clipped.
n_epochs_logged = len(plot_train_loss)
epochs = range(1, n_epochs_logged + 1)

# Loss
fig, ax = plt.subplots()
ax.set_title("Loss", fontsize=18)
ax.set_xlabel("Epoch", fontsize=14)
ax.set_ylabel("Loss", fontsize=14)
ax.set_xlim(0, n_epochs_logged + 1)
ax.set_xticks(np.arange(0, n_epochs_logged + 1, 1))
ax.plot(epochs, plot_train_loss, label='Training Loss', marker='o')
ax.plot(epochs, plot_valid_loss, label='Validation Loss', marker='o')
ax.legend(frameon=False, fontsize=14)
plt.show()

# Validation Dice for all three classes (Dice is bounded by [0, 1])
fig, ax = plt.subplots()
ax.set_title("Validation Dice", fontsize=18)
ax.set_xlabel("Epoch", fontsize=14)
ax.set_ylabel("Dice", fontsize=14)
ax.set_ylim(0, 1)
ax.set_xlim(0, n_epochs_logged + 1)
ax.set_xticks(np.arange(0, n_epochs_logged + 1, 1))
ax.plot(epochs, plot_valid_dice1, label='Foreground Dice', marker='o')
ax.plot(epochs, plot_valid_dice2, label='Background Dice', marker='o')
ax.plot(epochs, plot_valid_dice3, label='border Dice', marker='o')
ax.legend(frameon=False, fontsize=14)
plt.show()

In [33]:
# Per-epoch distribution of validation batch losses. L_v is split into
# num_epochs near-equal chunks, which is one chunk per epoch when every epoch
# logs the same number of batches.
fig,ax = plt.subplots()
bp = ax.boxplot(list(np.array_split(L_v, num_epochs)))
ax.set_title('Loss')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
# No fixed y-limits: a hard-coded window would clip boxes that fall outside it.
ax.grid()

In [34]:
# Per-epoch distribution of validation foreground Dice.
fig, ax = plt.subplots()
bp = ax.boxplot(list(np.array_split(D1_v, num_epochs)))
ax.set_title('Dice Foreground')
ax.set_ylabel('dice')
ax.set_ylim([0, 1.0])
ax.grid()

In [35]:
# Per-epoch distribution of validation background Dice.
fig, ax = plt.subplots()
bp = ax.boxplot(list(np.array_split(D2_v, num_epochs)))
ax.set_title('Dice Background')
ax.set_ylabel('dice')
ax.set_ylim([0, 1.0])
ax.grid()

In [36]:
# Per-epoch distribution of validation border Dice.
fig, ax = plt.subplots()
bp = ax.boxplot(list(np.array_split(D3_v, num_epochs)))
ax.set_title('Dice Border')
ax.set_ylabel('dice')
ax.set_ylim([0, 1.0])
ax.grid()

# Alternative Dice denominator: $d = \dfrac{2\,\lvert X \cap Y \rvert}{S^2 B + S^2 B}$, where $S$ = `CFG.IMAGE_SIZE` and $B$ = `CFG.BATCH_SIZE`

In [37]:
import torch

# Hand-checkable example of the per-sample Dice computation: two 4x4 binary
# "prediction" (x) and "ground truth" (y) samples.
tensorx = torch.tensor([[[1,0,1,1],[1,1,1,1],[1,0,0,1],[1,1,0,0]],[[1,0,1,0],[0,0,1,1],[0,0,0,1],[0,1,0,0]]])
tensory = torch.tensor([[[1,0,0,1],[1,0,0,1],[1,1,0,1],[1,0,0,0]],[[0,1,0,1],[0,1,0,1],[0,1,1,1],[1,0,0,0]]])

print(f'pred {tensorx}')
print(f'true {tensory}')
print('------------------')
print(f'and {tensorx & tensory}')
print('------------------')
print(tensorx[0])

print(torch.sum(tensorx[0]))

# Dice denominator for sample 0. Index explicitly: this cell has not defined `i`
# yet (a leftover global would be used instead). Use torch.sum for both terms,
# since builtin sum() only reduces the first axis.
print(torch.sum(tensorx[0]) + torch.sum(tensory[0]))

d = []
inter = tensorx&tensory
print(f'intersection {inter}')

for i in range(2):
    intersection = torch.sum(inter[i])
    print(intersection)
    dice = (2*intersection)/(torch.sum(tensorx[i])+torch.sum(tensory[i]))
    d.append(dice)
print(f'Dice{d}')

In [None]:
import torch

# Scratch: collecting tensors in a Python list.
tensorx = torch.tensor([[[1,0,1,1],[1,1,1,1],[1,0,0,1],[1,1,0,0]],[[1,0,1,0],[0,0,1,1],[0,0,0,1],[0,1,0,0]]])
tensory = torch.tensor([[[1,0,0,1],[1,0,0,1],[1,1,0,1],[1,0,0,0]],[[0,1,0,1],[0,1,0,1],[0,1,1,1],[1,0,0,0]]])

ori, pre, ch = [], [], []
ori.append(tensorx)
print(ori)
ori.append(tensory)
print(ori)

print(ori[0])

In [None]:
l = list((1, 4, 2, 5))
l