## Import

In [1]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

def visualize(**images):
    """Show the given images side by side in a single row.

    Every keyword argument becomes one panel. The keyword is used as the
    panel title (underscores replaced by spaces, words title-cased) and
    the value is handed to ``plt.imshow``.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, key in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        # The panels are pictures, not plots: hide the axis ticks.
        plt.xticks([])
        plt.yticks([])
        plt.title(key.replace('_', ' ').title())
        plt.imshow(images[key])
    plt.show()

def round(temp):
    """Min-max normalise ``temp`` to [0, 1] and round every element to 0 or 1.

    NOTE: the name is kept for compatibility with the callers in this
    notebook, but it shadows the built-in ``round`` for every cell that runs
    after this definition.

    Args:
        temp: non-empty numeric array (anything ``np.asarray`` accepts).

    Returns:
        numpy.ndarray of the same shape holding 0.0 / 1.0 values. A constant
        input (max == min) yields all zeros instead of the NaNs that a 0/0
        division would produce.
    """
    temp = np.asarray(temp)
    low = np.min(temp)
    span = np.max(temp) - low
    if span == 0:
        # Constant input: (temp - low) is already all zeros, so divide by 1
        # rather than by 0 to keep the result finite.
        span = 1
    return np.round((temp - low) / span)

def _yasai_overlay_panels(prefix, channels, image):
    """Return ``{prefix_i: overlay}`` for every leading-axis slice of ``channels``.

    Each overlay blends the binarised slice (weight 0.4) with the first
    channel of ``image`` (weight 0.6).
    """
    backdrop = image[..., 0].squeeze()
    return {
        prefix + '_' + str(i): 0.4 * round(channels[i]) + 0.6 * backdrop
        for i in range(channels.shape[0])
    }


def yasai_show_v1(dataset, idx, model=None):
    """Plot one dataset sample with its mask channels and, optionally, predictions.

    Args:
        dataset: indexable object whose items are ``(image, mask)`` pairs.
            ``image`` must be a tensor that can be transposed from
            channel-first to channel-last; ``mask`` is iterated along its
            first axis.
        idx: index of the sample to show.
        model: optional callable applied to ``image.unsqueeze(0)``. When
            given, a first figure with one overlay per prediction channel is
            drawn before the mask figure.

    The forward pass runs inside ``torch.no_grad()`` and the result is
    detached and copied to the CPU before conversion to NumPy; converting a
    grad-tracking or non-CPU tensor with ``np.asarray`` raises. An
    ``nn.Module`` is evaluated in eval mode (its previous mode is restored)
    and receives the input on the device of its first parameter.
    """
    image, mask = dataset[idx]

    pred = None
    if model is not None:
        batch = image.unsqueeze(0)
        was_training = False
        if isinstance(model, torch.nn.Module):
            first_param = next(model.parameters(), None)
            if first_param is not None:
                batch = batch.to(first_param.device)
            was_training = model.training
            model.eval()
        try:
            with torch.no_grad():
                pred = model(batch)
        finally:
            if was_training:
                model.train()
        pred = pred.detach().cpu().numpy().squeeze()

    # Channel-first tensor -> channel-last array for plotting.
    image = np.asarray(image).transpose(1, 2, 0)
    mask = np.asarray(mask)

    if pred is not None:
        panels = {'image': image}
        panels.update(_yasai_overlay_panels('pred', pred, image))
        visualize(**panels)

    panels = {'image': image}
    panels.update(_yasai_overlay_panels('mask', mask, image))
    visualize(**panels)

def yasai_model_save_v1(model, text=''):
    """Checkpoint ``model`` into the current working directory.

    The file name is ``'model_'`` + ``text`` + a ``yymmddHHMM`` timestamp +
    ``'.pt'``. The file holds a dict with the model's ``state_dict`` under
    ``'state_dict'`` and the model object itself under ``'model'``. The
    destination path is printed once the file is written.
    """
    stamp = datetime.now().strftime("%y%m%d%H%M.pt")
    file_name = 'model_' + text + stamp
    destination = os.path.join(os.getcwd(), file_name)
    checkpoint = {'state_dict': model.state_dict(), 'model': model}
    torch.save(checkpoint, destination)
    print('Successfully saved to ' + destination)

def yasai_model_load_v1(path, map_location=None):
    """Load a checkpoint written by ``yasai_model_save_v1`` and return the model.

    Args:
        path: checkpoint file holding a dict with the pickled model under
            ``'model'`` and its weights under ``'state_dict'``.
        map_location: optional device remapping forwarded to ``torch.load``
            (e.g. ``'cpu'`` to open a GPU checkpoint on a CPU-only machine).

    Returns:
        The unpickled model with the stored state dict loaded into it.

    SECURITY: the checkpoint embeds a pickled model object, and unpickling
    can execute arbitrary code. Only load files from a trusted source.
    """
    try:
        # The stored dict contains a full nn.Module, which the
        # weights-only unpickler refuses, so ask for the full unpickler
        # explicitly.
        checkpoint = torch.load(path, map_location=map_location, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the ``weights_only`` keyword.
        checkpoint = torch.load(path, map_location=map_location)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    print('Successfully loaded from ' + str(path))
    return model

def yasai_compute_iou_v1(pred, label):
    """Intersection-over-union of the binarised prediction and the label.

    ``label`` counts as positive where it equals 1; ``pred`` is first passed
    through this notebook's ``round`` helper and counts as positive where
    the result equals 1.

    Returns:
        intersection / union, or ``None`` when the union is empty or the
        label contains no positive element.
    """
    positives_true = (label == 1)
    positives_pred = (round(pred) == 1)

    overlap = np.logical_and(positives_pred, positives_true).sum()
    combined = np.logical_or(positives_pred, positives_true).sum()

    # Undefined score: nothing to intersect with.
    if combined == 0 or np.sum(positives_true) == 0:
        return None
    return overlap / combined
    
def yasai_compute_batch_iou_v1(model, data_loader):
    """Print and return the mean per-batch IoU of ``model`` over ``data_loader``.

    Args:
        model: callable applied to every image batch.
        data_loader: iterable of ``(image, mask)`` batches.

    Returns:
        The mean of the defined per-batch IoU values, or ``None`` when no
        batch produced a defined IoU (``yasai_compute_iou_v1`` returns
        ``None`` for such batches, and those are skipped).

    Notes:
        * ``tqdm`` must be in scope when this is called; in this notebook it
          is imported by a later cell than the one defining this function.
        * The forward pass runs inside ``torch.no_grad()`` and the output is
          detached and copied to the CPU before conversion to NumPy;
          converting a grad-tracking or non-CPU tensor raises.
        * An ``nn.Module`` is evaluated in eval mode (its previous mode is
          restored) and receives inputs on the device of its first parameter.
    """
    first_param = None
    was_training = False
    if isinstance(model, torch.nn.Module):
        first_param = next(model.parameters(), None)
        was_training = model.training
        model.eval()

    ious = []
    try:
        for image, mask in tqdm(data_loader, desc='Iterating'):
            if first_param is not None:
                image = image.to(first_param.device)
            with torch.no_grad():
                pred = model(image)
            pred = pred.detach().cpu().numpy().squeeze()
            mask = np.asarray(mask)
            iou = yasai_compute_iou_v1(pred, mask)
            # Skip undefined scores; they would make sum() raise TypeError.
            if iou is not None:
                ious.append(iou)
    finally:
        if was_training:
            model.train()

    if not ious:
        print('No batch produced a defined IoU')
        return None
    mean_iou = sum(ious) / len(ious)
    print(mean_iou)
    return mean_iou

In [2]:
import torchvision
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch import optim
from PIL import Image
import torch.nn.functional as F
import cv2
# from sklearn.model_selection import train_test_split

%matplotlib inline

## Root

In [3]:
# Root directory that contains the dataset folders. The location is
# machine-specific, so it can be overridden with the YASAI_ROOT_PATH
# environment variable; without it the original path is used.
ROOT_PATH = os.environ.get(
    'YASAI_ROOT_PATH',
    '/home/yasaisen/Desktop/09_research/09_research_main/lab_03',
)

In [4]:
dataset_folder = 'dataset_C_v_2.9.3'

# Every split keeps its RGB images and its masks in two sub-folders of
# ROOT_PATH/dataset_folder; build the (image dir, mask dir) pair per split.
train_img_path, train_mask_path = (
    os.path.join(ROOT_PATH, dataset_folder, sub_folder)
    for sub_folder in ('train_for_base_imgs_rgb', 'train_for_base_mask')
)

valid_img_path, valid_mask_path = (
    os.path.join(ROOT_PATH, dataset_folder, sub_folder)
    for sub_folder in ('valid_imgs_rgb', 'valid_mask')
)

test_img_path, test_mask_path = (
    os.path.join(ROOT_PATH, dataset_folder, sub_folder)
    for sub_folder in ('test_imgs_rgb', 'test_mask')
)

## Config (image size, batch sizes, device, epochs)

In [5]:
# Hyper-parameters and runtime settings shared by the cells below.
img_size = 224
train_bsz = 32
# Fall back to the CPU when no GPU is visible (the same rule train() applies),
# so that `.to(device)` also works on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 30
valid_bsz = 8
test_bsz = 8

## Dataset

In [6]:
def get_df(img_path, mask_path):
    """Index the files of ``img_path`` as a DataFrame of paths and class labels.

    The class label is the part of the file name before the first
    underscore, e.g. ``NORMAL_G1_Lid1_LRid293_Gid3133_Bl30.png`` gives
    ``NORMAL``. The number of indexed files is printed.

    Args:
        img_path: directory whose entries are listed.
        mask_path: unused; kept so existing calls keep working.

    Returns:
        pandas.DataFrame with the columns ``'images'`` (full path) and
        ``'labels'``, one row per directory entry, ordered by file name.
    """
    # os.listdir returns entries in arbitrary order; sort them so the row
    # order (and therefore every dataset index) is reproducible.
    file_names = sorted(os.listdir(img_path))

    PathDF = pd.DataFrame({
        'images': [os.path.join(img_path, file_name) for file_name in file_names],
        'labels': [file_name.split('_')[0] for file_name in file_names],
    })
    print(len(file_names))
    return PathDF

In [7]:
# Index the three splits; each call prints the number of files it found.
train_df = get_df(img_path=train_img_path, mask_path=train_mask_path)
valid_df = get_df(img_path=valid_img_path, mask_path=valid_mask_path)
test_df = get_df(img_path=test_img_path, mask_path=test_mask_path)

3771
1037
917


In [8]:
# Preprocessing pipeline shared by all splits; its only step is ToTensor.
transform = transforms.Compose([transforms.ToTensor()])

In [9]:
# mask_path = '/home/yasaisen/Desktop/09_research/09_research_main/lab_03/dataset_C_v_2.9.3/train_for_base_mask/RSLN_L_G10_Lid45_LRid112_Gid7024_C4.png'
# label = Image.open(mask_path)
# label = np.array(label)

In [10]:
class mod_Dataset(Dataset):
    """Image-classification dataset backed by a DataFrame of paths and labels.

    Each item is ``(image, one_hot_label)``. The image is opened from the
    row's ``'images'`` path, converted to RGB, resized with
    ``transforms.Resize(img_size)`` and then passed through ``transform``
    when one is given. The label is a float32 one-hot vector over
    (NORMAL, RLN, RSLN) built from the row's ``'labels'`` value.

    Args:
        path_df: DataFrame with the columns ``'images'`` and ``'labels'``.
        transform: optional callable applied to the resized image. When it
            is ``None`` the resized image is returned as is.
        img_size: size argument handed to ``transforms.Resize``.
    """

    # Class name -> position of the 1 in the one-hot target.
    _CLASS_INDEX = {'NORMAL': 0, 'RLN': 1, 'RSLN': 2}

    def __init__(self, path_df, transform=None, img_size=224):
        self.path_df = path_df
        self.transform = transform
        self.img_size = img_size
        # Built once instead of on every item access.
        self._resize = transforms.Resize(img_size)

    def __len__(self):
        return self.path_df.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, one_hot_label)`` for row ``idx``.

        Raises:
            ValueError: if the row's label is not one of the known classes.
        """
        row = self.path_df.iloc[idx]

        label_name = row['labels']
        if label_name not in self._CLASS_INDEX:
            raise ValueError(
                'Unknown label %r for %r; expected one of %s'
                % (label_name, row['images'], sorted(self._CLASS_INDEX))
            )
        lables = torch.zeros(len(self._CLASS_INDEX), dtype=torch.float32)
        lables[self._CLASS_INDEX[label_name]] = 1.0

        # The context manager closes the file handle; convert() returns an
        # independent image, so it stays usable afterwards.
        with Image.open(row['images']) as source:
            images = self._resize(source.convert('RGB'))
        if self.transform is not None:
            images = self.transform(images)

        return images, lables

In [11]:
# Wrap every split in a dataset that shares the same preprocessing.
train_data = mod_Dataset(path_df=train_df, transform=transform)
valid_data = mod_Dataset(path_df=valid_df, transform=transform)
test_data = mod_Dataset(path_df=test_df, transform=transform)

# Only the training loader shuffles, pins memory and drops the last
# incomplete batch; the evaluation loaders keep the dataset order.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=train_bsz,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
)
valid_loader = DataLoader(
    dataset=valid_data,
    batch_size=valid_bsz,
    shuffle=False,
    num_workers=0,
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=test_bsz,
    shuffle=False,
    num_workers=0,
)

## Model

In [12]:
class resnet34(nn.Module):
    """torchvision ResNet-34 (built with ``weights=None``) plus a linear head.

    The backbone's 1000 outputs are mapped to ``num_classes`` scores by
    ``fc1``.
    """

    def __init__(self, num_classes):
        super(resnet34, self).__init__()
        # Creation order (backbone first, head second) is kept so parameter
        # initialisation draws from the RNG in the same sequence.
        backbone = torchvision.models.resnet34(weights=None)
        head = nn.Linear(in_features=1000, out_features=num_classes)
        self.base_model = backbone
        self.fc1 = head

    def forward(self, input):
        """Return the head's scores for the backbone output of ``input``."""
        return self.fc1(self.base_model(input))

In [13]:
# Build the 3-class classifier and move it to the configured device.
model = resnet34(3).to(device)
# print(model)
# Smoke test: push a random batch of four 3x224x224 inputs through the model
# and print the input and output shapes.
# NOTE(review): this forward pass runs with the module in its default mode and
# with gradients enabled - presumably harmless for a shape check; confirm that
# it should not run under eval()/no_grad().
t = torch.randn((4, 3, 224, 224)).to(device)
print(t.shape)
get = model(t)
print(get.shape)

# Print the image and label shapes of the first training batch, then stop.
for x, y in train_loader:
    print(x.shape)
    print(y.shape)
    break

torch.Size([4, 3, 224, 224])
torch.Size([4, 3])
torch.Size([32, 3, 224, 224])
torch.Size([32, 3])


## Train

In [14]:
def check_accuracy(loader, model, device):
    """Return the fraction of samples whose arg-max prediction matches the target.

    Args:
        loader: iterable of ``(x, y)`` batches, where ``y`` holds one score
            per class (e.g. one-hot rows); the target class is its arg-max
            along dimension 1.
        model: ``nn.Module`` producing one score per class along dimension 1.
        device: device the batches are moved to before the forward pass.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.

    The model is evaluated in eval mode under ``torch.no_grad()`` and is put
    back into whatever mode it was in when the function was called.
    """
    num_correct = 0
    num_samples = 0
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)

                predictions = model(x).argmax(dim=1)
                targets = y.argmax(dim=1)
                num_correct += (predictions == targets).sum()
                num_samples += predictions.size(0)
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    if num_samples == 0:
        raise ValueError('check_accuracy received a loader without samples')
    return (num_correct / num_samples).item()

def train(epochs, model, loader=None, eval_loader=None, lr=1e-4):
    """Train ``model`` with Adam and cross-entropy, reporting accuracy per epoch.

    Args:
        epochs: number of passes over the training batches.
        model: ``nn.Module`` to optimise; it is moved to CUDA when available,
            otherwise to the CPU.
        loader: iterable of ``(data, targets)`` training batches. Defaults to
            the notebook-level ``train_loader``.
        eval_loader: batches handed to ``check_accuracy`` after every epoch.
            Defaults to the notebook-level ``test_loader``.
            NOTE(review): that default reports test-set accuracy during
            training; confirm whether ``valid_loader`` was intended.
        lr: Adam learning rate.

    Returns:
        The trained model (the same object, on the selected device).

    After each epoch the mean training loss and the accuracy are printed.
    The mean loss is NaN when ``loader`` yields no batch.
    """
    # Fall back to the notebook globals so existing two-argument calls behave
    # exactly as before.
    if loader is None:
        loader = train_loader
    if eval_loader is None:
        eval_loader = test_loader

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        losses = []

        pbar = tqdm(loader, position=0, leave=True, desc=f"Epoch {epoch}")
        for data, targets in pbar:
            data = data.to(device)
            targets = targets.to(device)

            # forward
            scores = model(data)
            loss = criterion(scores, targets)
            losses.append(loss.item())

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Guard against an empty loader instead of dividing by zero.
        avg_loss = sum(losses) / len(losses) if losses else float('nan')
        acc = check_accuracy(eval_loader, model, device)
        print(f"Loss:{avg_loss:.8f}\tAccuracy:{acc:.8f}")

    return model

In [15]:
# Use the `epochs` setting from the config cell (30) instead of repeating the
# number here, so the configuration has a single source of truth.
trained_resnet34_model = train(epochs, model)

Epoch 0: 100%|██████████| 117/117 [01:37<00:00,  1.20it/s]


Loss:0.36553614	Accuracy:0.34569246


Epoch 1: 100%|██████████| 117/117 [01:32<00:00,  1.27it/s]


Loss:0.05602545	Accuracy:0.32824427


Epoch 2: 100%|██████████| 117/117 [01:30<00:00,  1.29it/s]


Loss:0.05395426	Accuracy:0.42639038


Epoch 3: 100%|██████████| 117/117 [01:31<00:00,  1.28it/s]


Loss:0.02843638	Accuracy:0.27480915


Epoch 4: 100%|██████████| 117/117 [01:31<00:00,  1.27it/s]


Loss:0.00343214	Accuracy:0.29770991


Epoch 5: 100%|██████████| 117/117 [01:30<00:00,  1.29it/s]


Loss:0.00103095	Accuracy:0.30643401


Epoch 6: 100%|██████████| 117/117 [01:33<00:00,  1.25it/s]


Loss:0.00071755	Accuracy:0.28353325


Epoch 7: 100%|██████████| 117/117 [01:30<00:00,  1.29it/s]


Loss:0.00029976	Accuracy:0.30752453


Epoch 8: 100%|██████████| 117/117 [01:33<00:00,  1.25it/s]


Loss:0.00020224	Accuracy:0.30316249


Epoch 9: 100%|██████████| 117/117 [01:32<00:00,  1.26it/s]


Loss:0.00013655	Accuracy:0.28571427


Epoch 10: 100%|██████████| 117/117 [01:32<00:00,  1.27it/s]


Loss:0.00009814	Accuracy:0.28680480


Epoch 11: 100%|██████████| 117/117 [01:31<00:00,  1.27it/s]


Loss:0.00007442	Accuracy:0.28571427


Epoch 12: 100%|██████████| 117/117 [01:31<00:00,  1.27it/s]


Loss:0.00014898	Accuracy:0.28026173


Epoch 13: 100%|██████████| 117/117 [01:30<00:00,  1.29it/s]


Loss:0.00043606	Accuracy:0.27044711


Epoch 14: 100%|██████████| 117/117 [01:31<00:00,  1.28it/s]


Loss:0.00037538	Accuracy:0.36859322


Epoch 15: 100%|██████████| 117/117 [01:31<00:00,  1.28it/s]


Loss:0.03340307	Accuracy:0.44383860


Epoch 16: 100%|██████████| 117/117 [01:30<00:00,  1.29it/s]


Loss:0.09908101	Accuracy:0.41875678


Epoch 17: 100%|██████████| 117/117 [01:30<00:00,  1.29it/s]


Loss:0.01649511	Accuracy:0.51799345


Epoch 18: 100%|██████████| 117/117 [01:31<00:00,  1.28it/s]


Loss:0.02305389	Accuracy:0.32933477


Epoch 19: 100%|██████████| 117/117 [01:31<00:00,  1.29it/s]


Loss:0.03440163	Accuracy:0.36205015


Epoch 20: 100%|██████████| 117/117 [01:31<00:00,  1.27it/s]


Loss:0.00170275	Accuracy:0.30098146


Epoch 21: 100%|██████████| 117/117 [01:33<00:00,  1.26it/s]


Loss:0.00049440	Accuracy:0.31188658


Epoch 22: 100%|██████████| 117/117 [01:31<00:00,  1.28it/s]


Loss:0.00053530	Accuracy:0.32606325


Epoch 23: 100%|██████████| 117/117 [01:32<00:00,  1.27it/s]


Loss:0.01555282	Accuracy:0.39149398


Epoch 24: 100%|██████████| 117/117 [01:31<00:00,  1.28it/s]


Loss:0.01940561	Accuracy:0.29880041


Epoch 25: 100%|██████████| 117/117 [01:30<00:00,  1.30it/s]


Loss:0.00443666	Accuracy:0.36205015


Epoch 26: 100%|██████████| 117/117 [01:31<00:00,  1.28it/s]


Loss:0.02051344	Accuracy:0.40676117


Epoch 27: 100%|██████████| 117/117 [01:31<00:00,  1.28it/s]


Loss:0.00874591	Accuracy:0.42420936


Epoch 28: 100%|██████████| 117/117 [01:31<00:00,  1.27it/s]


Loss:0.00069688	Accuracy:0.41112322


Epoch 29: 100%|██████████| 117/117 [01:32<00:00,  1.27it/s]


Loss:0.00032586	Accuracy:0.41330424


In [16]:
# Persist the trained classifier; the helper prints the destination path.
yasai_model_save_v1(model=trained_resnet34_model, text='resnet34_trainbasergb_')

Successfully saved to /home/yasaisen/Desktop/09_research/09_research_main/lab_10/model_resnet34_trainbasergb_2305151701.pt
