# Import Libraries

In [1]:
import os
import sys
import shutil
import copy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Adam


import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
from torchvision import transforms
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn.functional as F
#import torchvision.transforms.functional as F
#from torchvision.datasets import OxfordIIITPet

from PIL import Image
import io
from io import BytesIO 

import cv2

from tqdm import tqdm
from tqdm.notebook import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

from albumentations import (
    Compose, OneOf, Normalize, CenterCrop, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, RandomRotate90, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, HueSaturationValue, CoarseDropout,GridDropout
    )

from IPython.display import Image, display
from tensorflow.keras.preprocessing.image import load_img
import PIL
from PIL import Image as Im
from PIL import ImageOps

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Configurations

In [2]:
 class CFG:
        """Notebook-wide configuration constants plus two small helpers.

        The helpers are static methods and are called on the class itself,
        e.g. ``CFG.to_numpy(tensor)``.
        """
        # Locations of the dataset root, the images and the annotations.
        BASE_PATH = '../input/oxfordiiitpet-dataset/'
        IMAGE_PATH = '../input/oxfordiiitpet-dataset/images/images/'
        ANNOTATION_PATH = '../input/oxfordiiitpet-dataset/annotations/annotations/'

        # First CUDA GPU when available, otherwise the CPU.
        DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        IMAGE_SIZE = 256  # side length handed to A.Resize in get_transform
        BATCH_SIZE = 16

        DEBUG  = False

        @staticmethod
        def debug(DEBUG, df=None):
            """Return ``df``, reduced to a random half of its rows when ``DEBUG`` is truthy.

            Args:
                DEBUG: when falsy, ``df`` is returned untouched.
                df: DataFrame to sample from. It is optional only so that the
                    original one-argument call shape keeps working.

            Returns:
                ``df`` itself, or a 50% random sample of it with a fresh 0..n-1 index.

            Raises:
                ValueError: if ``DEBUG`` is truthy and no ``df`` was given.
            """
            # The previous version read an unbound local ``df`` and therefore
            # raised UnboundLocalError on every call.
            if not DEBUG:
                return df
            if df is None:
                raise ValueError("CFG.debug(True, df) needs the DataFrame to sample from")
            return df.sample(frac = 0.5).reset_index(drop = True)

        @staticmethod
        def to_numpy(tensor):
            """Return ``tensor`` as a numpy array living in host memory.

            ``detach()`` drops the autograd graph, so this works whether or not
            the tensor requires grad. For a tensor that is already on the CPU
            the array can share memory with it; copy before editing in place.
            """
            return tensor.detach().cpu().numpy()

# Create DataFrame

In [3]:
# Read the annotation index: skip its first five lines, keep the first four
# columns and give them readable names.
annotation_list_file = CFG.ANNOTATION_PATH + 'list.txt'
all_df = (
    pd.read_table(annotation_list_file, sep=' ', skiprows=5)
    .iloc[:, :4]
    .set_axis(['image_name', 'id', 'species', 'breed'], axis=1)
)
all_df

# Split train and validation 8:2

In [4]:
from sklearn.model_selection import train_test_split

# Fix the seed so the 80/20 split is identical on every Restart & Run All;
# without it each run trains and validates on different rows.
SPLIT_SEED = 42
train_df, valid_df = train_test_split(all_df, test_size = 0.2, random_state = SPLIT_SEED)
print(train_df.shape, valid_df.shape)

In [5]:
# Renumber the rows 0..n-1 (old index dropped) so row labels equal row positions.
train_df = train_df.reset_index(drop=True)
train_df

In [6]:
# Renumber the rows 0..n-1 (old index dropped) so row labels equal row positions.
valid_df = valid_df.reset_index(drop=True)
valid_df

# Visualize images

In [7]:
# normal image: open and show the first picture listed in all_df
first_image_file = CFG.IMAGE_PATH + all_df["image_name"][0] + ".jpg"
image = Im.open(first_image_file)
display(image)
# array form of the picture; its shape is this cell's output
i = np.array(image)
i.shape

In [8]:
# mask image: load the trimap of the first sample; autocontrast spreads its
# few distinct values over the display range so they can be told apart.
mask_image = PIL.ImageOps.autocontrast(load_img(CFG.ANNOTATION_PATH+'trimaps/'+all_df["image_name"][0]+'.png'))
display(mask_image)

# Build the label array from the single-channel version. The previous code
# created this grayscale image but then converted the multi-channel
# mask_image instead. mask_image itself is left untouched for later cells.
mask_images = mask_image.convert('L')
mask = np.array(mask_images)

obj_ids = np.unique(mask)
print(obj_ids)

# convert [0,127,255] to [0,1,2] -- the same encoding TrainDataset produces
# (0 already has its final value, so only two assignments are needed)
mask[mask == 127] = 1
mask[mask == 255] = 2
print(np.unique(mask))

In [9]:
# Report the Python type of the picture and of its mask, one per line.
print(type(image), type(mask_image), sep='\n')

# image and mask

In [10]:
def get_concat_h(im1, im2):
    """Return a new RGB image with ``im1`` on the left and ``im2`` on the right.

    The canvas is as tall as the taller of the two inputs, so a second image
    that is taller than the first is no longer cut off. Canvas area that
    neither image covers keeps the default fill of ``Im.new``.

    Args:
        im1: left image (needs ``width`` / ``height`` and must be pasteable).
        im2: right image.

    Returns:
        The combined image; the inputs are not modified.
    """
    canvas_size = (im1.width + im2.width, max(im1.height, im2.height))
    dst = Im.new('RGB', canvas_size)
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst

get_concat_h(image, mask_image)

# Augmentation function

In [11]:
def get_transform(data):
    """Build the albumentations pipeline for one dataset split.

    Args:
        data: ``'train'`` for resize, random horizontal flip, random grid
            dropout, normalize and to-tensor; ``'valid'`` for resize,
            normalize and to-tensor only.

    Returns:
        An ``A.Compose`` to be called as ``pipeline(image=..., mask=...)``.

    Raises:
        ValueError: if ``data`` is neither ``'train'`` nor ``'valid'``. The
            previous version silently returned ``None`` in that case.
    """
    if data == 'train':
        # Random augmentations are applied to the training split only.
        augmentations = [
            A.HorizontalFlip(p=0.5),
            A.GridDropout(ratio=0.2, unit_size_min=None, unit_size_max=None, holes_number_x=5, holes_number_y=5, shift_x=0, shift_y=0, random_offset=False, fill_value=0, mask_fill_value=None, always_apply=False, p=0.5),
        ]
    elif data == 'valid':
        augmentations = []
    else:
        raise ValueError(f"data must be 'train' or 'valid', got {data!r}")

    # Both splits share the resize at the front and normalize/to-tensor at the end.
    return A.Compose([
        A.Resize(CFG.IMAGE_SIZE,CFG.IMAGE_SIZE),
        *augmentations,
        A.Normalize(),
        ToTensorV2(),
    ])

In [12]:
# number of rows in all_df
all_df.shape[0]

In [13]:
# confirm the two file paths the dataset class will build for the first sample
image_path = CFG.IMAGE_PATH
mask_path = CFG.ANNOTATION_PATH
sample_name = all_df["image_name"][0]
for sample_file in (
    os.path.join(image_path, sample_name + '.jpg'),
    os.path.join(mask_path, 'trimaps/' + sample_name + '.png'),
):
    print(sample_file)

# Dataset

In [14]:
class TrainDataset(Dataset):
    """Image / trimap-mask pairs, addressed by row position in ``df``.

    Args:
        image_paths: directory containing ``<image_name>.jpg``.
        mask_paths: directory whose ``trimaps/`` sub-folder holds ``<image_name>.png``.
        df: DataFrame with an ``image_name`` column; one sample per row.
        transforms: optional callable invoked as ``transforms(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
    """
    def __init__(self, image_paths, mask_paths, df, transforms=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.df = df
        self.image_names = df["image_name"]
        self.transforms = transforms

    def __len__(self):
        """Number of samples, i.e. rows of ``df``."""
        return len(self.df)

    def __getitem__(self,idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        Returns:
            image: whatever ``transforms`` produced for the image; if that is
                still a numpy (H, W, 3) array it is turned into a (3, H, W) tensor.
            mask: tensor of shape (1, H, W) holding the class ids.
        """
        # Positional lookup: also correct when df still carries a non-0..n-1
        # index (label lookup would pick the wrong row or raise KeyError).
        image_name = self.image_names.iloc[idx]
        image_path = os.path.join(self.image_paths, image_name+'.jpg')
        mask_path = os.path.join(self.mask_paths, 'trimaps/'+image_name+'.png')

        # Force three RGB channels. Slicing [:, :, :3] only handled files with
        # an extra alpha channel and failed on single-channel images.
        with Im.open(image_path) as image_file:
            image = np.array(image_file.convert('RGB'))

        # Trimap -> contrast-stretched -> single channel -> numpy
        stretched = PIL.ImageOps.autocontrast(load_img(mask_path))
        mask = np.array(stretched.convert('L'))
        #convert [0,127,255] -> [0,1,2]
        # NOTE(review): this relies on the stretched trimap containing exactly
        # the values 0/127/255 -- confirm for trimaps with fewer than three
        # distinct values.
        mask[mask == 127] = 1
        mask[mask == 255] = 2

        if self.transforms:
            augmented = self.transforms(image=image,mask=mask)
            image,mask = augmented['image'],augmented['mask']

        # Previously ``mask`` was only bound inside the branch above, so a
        # dataset without transforms raised NameError. Convert whatever is
        # still a numpy array so both paths return tensors.
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
        mask = torch.as_tensor(mask).unsqueeze(0)

        return image,mask

In [15]:
# Same folders for both splits; only the rows and the transform pipeline differ.
train_dataset = TrainDataset(
    image_paths=CFG.IMAGE_PATH,
    mask_paths=CFG.ANNOTATION_PATH,
    df=train_df,
    transforms=get_transform(data='train'),
)
valid_dataset = TrainDataset(
    image_paths=CFG.IMAGE_PATH,
    mask_paths=CFG.ANNOTATION_PATH,
    df=valid_df,
    transforms=get_transform(data='valid'),
)

In [16]:
# Check the number of CPU cores
import os
os.cpu_count()  # number of cores

In [17]:
# Training batches are reshuffled every epoch; with shuffle=False the model
# saw the samples in the same fixed order each time.
train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.BATCH_SIZE,
    shuffle=True,
    num_workers=os.cpu_count() or 0,  # cpu_count() can return None
    pin_memory=True,
)
# Sanity-check one sample by its shapes instead of printing whole tensors.
train_sample_image, train_sample_mask = train_dataset[0]
train_sample_image.shape, train_sample_mask.shape

In [18]:
# Validation keeps a fixed order (no shuffling).
valid_loader = DataLoader(
    valid_dataset,
    batch_size=CFG.BATCH_SIZE,
    shuffle=False,
    num_workers=os.cpu_count() or 0,  # cpu_count() can return None
    pin_memory=True,
)
# Sanity-check one sample by its shapes instead of printing whole tensors.
valid_sample_image, valid_sample_mask = valid_dataset[0]
valid_sample_image.shape, valid_sample_mask.shape

# Define U-Net model

## part of U-Net models

In [19]:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions (stride 1, padding 1, no bias), each followed by
    BatchNorm and an in-place ReLU.

    ``mid_channels`` is the channel count between the two convolutions; when
    it is falsy, ``out_channels`` is used instead.
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)
        layers = [
            nn.Conv2d(in_channels, hidden, **conv_kwargs),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, **conv_kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv/norm/activation stages to ``x``."""
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and run the result through the double convolution."""
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved

class Up(nn.Module):
    """Decoder stage: transposed-conv upsampling, pad to the skip tensor's
    size, concatenate with it on the channel axis, then a DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1,x2):
        """``x1`` is the tensor to upsample, ``x2`` the skip connection."""
        upsampled = self.up(x1)

        # Size gap between skip tensor and upsampled tensor on dims 2 and 3.
        gap_y = x2.shape[2] - upsampled.shape[2]
        gap_x = x2.shape[3] - upsampled.shape[3]
        before_x, before_y = gap_x // 2, gap_y // 2
        # F.pad order: last dim (before, after), then second-to-last dim.
        padding = [before_x, gap_x - before_x, before_y, gap_y - before_y]
        upsampled = F.pad(upsampled, padding)

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
    
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels; spatial size is unchanged."""
        projected = self.conv(x)
        return projected

## main part of U-Net

In [20]:
class UNet(nn.Module):
    """U-Net built from the blocks above: an input DoubleConv, four Down
    stages, four Up stages fed with the matching encoder outputs, and a 1x1
    output convolution producing ``n_classes`` channels of logits."""
    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder: channel width doubles at every stage.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64,128)
        self.down2 = Down(128,256)
        self.down3 = Down(256,512)
        self.down4 = Down(512,1024)
        # Decoder: channel width halves at every stage.
        self.up1 = Up(1024,512)
        self.up2 = Up(512,256)
        self.up3 = Up(256,128)
        self.up4 = Up(128,64)
        self.outc = OutConv(64,n_classes)

    def forward(self, x):
        """Return the logits for input batch ``x``."""
        # Keep every encoder output; the decoder consumes them last-in-first-out.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        x = skips.pop()  # deepest feature map
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())

        return self.outc(x)

# Custom metrics

In [21]:
def Dice_Coeff(x,y):
    """Smoothed Dice coefficient of two equally sized arrays or tensors.

    Computes ``(2 * sum(x*y) + 1) / (sum(x) + sum(y) + 1)`` over all elements.
    The ``+ 1`` terms keep the result defined (and equal to 1.0) when both
    inputs are all zeros.

    Args:
        x, y: numpy arrays or torch tensors of the same shape; anything that
            offers ``flatten()`` and ``sum()``.

    Returns:
        The coefficient as a numpy scalar or a 0-dim tensor, depending on the input type.
    """
    x = x.flatten()
    y = y.flatten()

    # Use the array's own reduction: builtin sum() walks the elements one by
    # one in Python and accumulates in the element dtype, which is slow for
    # image-sized inputs and can wrap around for small integer dtypes.
    intersection = (x*y).sum()
    dice = (2*intersection+1.0)/(x.sum() + y.sum()+1.0)
    return dice

# Visualize image (input and mask)

In [22]:
def show_pred_mask(original,true_tensor,out_tensor,cnt):
    """Plot the input image, its ground-truth mask and the predicted mask in one row.

    Args:
        original: image tensor of shape (3, H, W), assumed to have been
            normalized with the ImageNet mean/std (the defaults of ``A.Normalize()``).
        true_tensor: 2-D tensor of ground-truth class ids (0, 1, 2).
        out_tensor: 2-D tensor of predicted class ids (0, 1, 2).
        cnt: not used; kept so existing calls keep working.

    The input tensors are not modified.
    """
    # Undo the normalization per channel: x * std + mean.
    # (The previous constants used 0.255 for the third channel's std; the
    # value matching the normalization is 0.225.)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    restored = original.detach().cpu() * std + mean

    # (C, H, W) -> (H, W, C) for imshow; this single permute is what the former
    # permute(2,1,0) + rot90 + flipud chain amounted to. Clip so tiny
    # overshoots do not trigger imshow's range warning.
    picture = np.clip(CFG.to_numpy(restored.permute(1, 2, 0)), 0.0, 1.0)

    panels = []
    for label_tensor in (true_tensor, out_tensor):
        # Copy first: the numpy view of a CPU tensor shares its memory, so
        # remapping in place would overwrite the caller's labels.
        gray = CFG.to_numpy(label_tensor).copy()
        gray[gray==0] = 255
        gray[gray==1] = 127
        gray[gray==2] = 0
        panels.append(gray)
    y, z = panels

    row=1 # number of rows
    col=3 # number of columns
    fig, ax = plt.subplots(nrows=row, ncols=col,figsize=(15,18))

    #input image
    ax[0].imshow(original_picture := picture)
    ax[0].set_title('input')

    # Fixed vmin/vmax so a given class has the same shade in both masks, even
    # when one of them does not contain all three classes.
    #truth mask
    ax[1].imshow(y,cmap='Greys',vmin=0,vmax=255)
    ax[1].set_title('ground truth')

    #pred mask
    ax[2].imshow(z,cmap='Greys',vmin=0,vmax=255)
    ax[2].set_title('prediction')

    plt.show()

# Training

In [23]:
# Per-epoch history, appended to by training_model.
plot_train_loss = []
plot_train_dice = []

plot_valid_loss = []
plot_valid_dice = []

def training_model(model, datasets, dataloaders, criterion, optimizer, num_epochs, device):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Args:
        model: network that maps an image batch to per-class logits.
        datasets: not used; kept so existing calls keep working.
        dataloaders: mapping with ``'train'`` and ``'valid'`` loaders that
            yield ``(inputs, masks)`` batches, masks shaped (B, 1, H, W).
        criterion: loss called as ``criterion(logits, masks)`` with masks
            shaped (B, H, W). It receives the raw logits, which is what
            ``nn.CrossEntropyLoss`` expects.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over both loaders.
        device: device the model and every batch are moved to.

    Returns:
        The model with the weights it has after the last epoch.

    Side effects:
        Appends one mean loss per epoch to the module-level lists
        ``plot_train_loss`` / ``plot_valid_loss``, prints progress, and shows
        the first sample of the first batch of every phase.
    """
    # Move once up front instead of on every batch.
    model = model.to(device)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print('-'*10)

        for phase in ['train','valid']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            num_batches = 0

            stream = tqdm(dataloaders[phase])
            for cnt, (inputs, masks) in enumerate(stream, start=1):
                original = inputs
                inputs = inputs.to(device=device, dtype=torch.float32)
                # Drop the channel axis: (B, 1, H, W) -> (B, H, W).
                masks = masks.to(device=device, dtype=torch.long).squeeze(1)
                optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    # Pass logits straight to the criterion; applying softmax
                    # first made CrossEntropyLoss normalize a second time.
                    loss = criterion(outputs, masks)

                    # Every batch is optimized. The previous `else: continue`
                    # skipped backward/step and the loss bookkeeping for all
                    # batches except the first one.
                    if is_train:
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item()
                num_batches = cnt

                # Show the input with its true and predicted mask once per phase.
                if cnt == 1:
                    predicted = outputs.detach().argmax(dim=1)
                    # Clones, so plotting can never alter the training tensors.
                    show_pred_mask(original[0], masks[0].clone(), predicted[0].clone(), cnt)

            # max(..., 1) keeps an empty loader from raising.
            epoch_loss = running_loss / max(num_batches, 1)

            if is_train:
                plot_train_loss.append(epoch_loss)
            else:
                plot_valid_loss.append(epoch_loss)

            print(f'{phase} Loss: {epoch_loss}')

        print()

    return model

In [24]:
# Per-phase objects, keyed the way the training loop looks them up.
datasets = dict(train=train_dataset, valid=valid_dataset)
dataloaders = dict(train=train_loader, valid=valid_loader)

# RGB input, three output classes.
model = UNet(n_channels=3, n_classes=3)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
num_epochs = 10

In [25]:
# Run the training loop with the objects configured in the previous cell.
trained_model = training_model(
    model=model,
    datasets=datasets,
    dataloaders=dataloaders,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    device=CFG.DEVICE,
)

In [None]:
# 4x4 test pattern: a frame of 127 around one row of 0 and one row of 255.
tensor = torch.tensor([
    [127, 127, 127, 127],
    [127,   0,   0, 127],
    [127, 255, 255, 127],
    [127, 127, 127, 127],
])
x = CFG.to_numpy(tensor)
x

In [None]:
# Draw the test pattern with matplotlib's default colormap.
fig, ax = plt.subplots()
ax.imshow(x)
plt.show()

In [None]:
# NOTE(review): show_from_tensor is not defined in any cell visible in this
# notebook -- confirm it exists elsewhere, or remove / replace this call.
show_from_tensor(tensor)