In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
from torchmetrics.classification import Dice
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import random
import torch
from utils import unet, mypreprocess, util_functions, eff_unet, eff_unet2, dataset3d, dataset2d
from tqdm import tqdm
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss
import json
import warnings
warnings.filterwarnings("ignore")
import math
import gc
import albumentations as A
import numpy as np
import torch.optim.lr_scheduler as lr_scheduler
from mcdropout import MCDropout2D

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# base_path = '/scratch/student/sinaziaee/datasets/3d_dataset'
# train_dir = os.path.join(base_path, 'training')
# valid_dir = os.path.join(base_path, 'validation')
# test_dir = os.path.join(base_path, 'testing')
# # IMG_SIZE = 512
# BATCH_SIZE = 80

# transform_input, transform_output = util_functions.custom_transformers(scale=(0.5, 2), 
#                                                     contrast=(0.5, 2), brightness=(0.5, 1.5), rotation=180, blur=1)
# aug1_dataset = my_dataset.SegmentationDataset(input_root=f'{train_dir}/images/',target_root=f'{train_dir}/labels/',
#                                transform_input= transform_input, transform_target=transform_output)

In [3]:
# Dataset root. Override with the DATASET_2D_ROOT environment variable when running elsewhere;
# the default is the original cluster location.
base_path = os.environ.get('DATASET_2D_ROOT', '/scratch/student/sinaziaee/datasets/2d_dataset/')
train_dir = os.path.join(base_path, 'training')
valid_dir = os.path.join(base_path, 'validation')
test_dir = os.path.join(base_path, 'testing')
# IMG_SIZE = 512
BATCH_SIZE = 80


def make_segmentation_dataset(split_dir, transform_input, transform_target):
    """Build a ``dataset2d.SegmentationDataset`` for one split.

    Args:
        split_dir: directory containing ``images/`` and ``labels/`` sub-folders.
        transform_input: transform applied to each input image.
        transform_target: transform applied to each label mask.
    """
    return dataset2d.SegmentationDataset(input_root=f'{split_dir}/images/', target_root=f'{split_dir}/labels/',
                                         transform_input=transform_input, transform_target=transform_target)


# Two differently-parameterised augmentation pipelines applied to the same training images.
aug1_transform_input, aug1_transform_target = util_functions.custom_transformers(
    scale=(0.5, 2), contrast=(0.5, 2), brightness=(0.5, 1.5), rotation=180, blur=1)
aug2_transform_input, aug2_transform_target = util_functions.custom_transformers(
    scale=(0.7, 1.4), brightness=(0.75, 1.25), contrast=(0.5, 2), rotation=360, blur=1)
# Non-augmented pipeline: only converts images/masks to tensors.
valid_transformer = transforms.Compose([transforms.ToTensor()])

aug1_dataset = make_segmentation_dataset(train_dir, aug1_transform_input, aug1_transform_target)
aug2_dataset = make_segmentation_dataset(train_dir, aug2_transform_input, aug2_transform_target)
plain_train_dataset = make_segmentation_dataset(train_dir, valid_transformer, valid_transformer)
valid_dataset = make_segmentation_dataset(valid_dir, valid_transformer, valid_transformer)
test_dataset = make_segmentation_dataset(test_dir, valid_transformer, valid_transformer)

# Loaders. The training set is the plain images plus both augmented copies.
train_dataset = ConcatDataset([plain_train_dataset, aug1_dataset, aug2_dataset])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation loaders keep a fixed order. With shuffling, the samples that land in the smaller
# final batch change on every pass, so batch-averaged metrics drift between identical evaluations.
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("Number of images, training:", len(train_loader.dataset), ", validation", len(valid_loader.dataset), " testing:", len(test_loader.dataset))

Number of images, training: 41013 , validation 1047  testing: 1063


In [5]:
def eval_fn(data_loader, model, criterion, device):
    """Evaluate `model` over every batch of `data_loader` without gradient tracking.

    Args:
        data_loader: iterable of ``(images, masks)`` batches that supports ``len()``.
        model: network to evaluate; it is switched to ``eval()`` mode and left there.
            It is not moved, so it should already live on `device`.
        criterion: loss called as ``criterion(outputs, masks)``, returning a scalar tensor.
        device: device each batch is moved to.

    Returns:
        ``(avg_loss, avg_iou, avg_dice)``. Each value is the mean of its per-batch values.
        The per-batch Dice value is ``util_functions.dice_coefficient(batch_loss)``.

    Raises:
        ValueError: if `data_loader` has no batches (the averages would be undefined).
    """
    num_batches = len(data_loader)
    if num_batches == 0:
        raise ValueError("eval_fn received an empty data_loader; cannot average metrics")

    model.eval()
    total_loss = 0.0
    total_iou = 0.0
    total_dice = 0.0
    with torch.no_grad():
        for images, masks in data_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)
            batch_loss = loss.item()

            total_loss += batch_loss
            total_iou += util_functions.calculate_IoU(outputs, masks).item()
            # Average Dice across batches like loss/IoU. Previously only the last batch's value
            # was returned. float() keeps the result JSON-serialisable.
            total_dice += float(util_functions.dice_coefficient(batch_loss))

    return total_loss / num_batches, total_iou / num_batches, total_dice / num_batches

In [6]:
def train_loop(n_epochs, model, optimizer, train_loader, valid_loader, device,
                                criterion1, scheduler=None):
    """Train `model` for `n_epochs`, checkpointing the weights with the lowest validation loss.

    Each epoch:
      - runs ``util_functions.train_fn`` over `train_loader`;
      - runs ``eval_fn`` over `valid_loader`;
      - steps the scheduler, if one is given;
      - rewrites ``<results_folder>/training_trend.json`` with all metrics so far.

    Curves are plotted via ``util_functions.visualize_training`` every 10 epochs and after
    the final epoch.

    Args:
        n_epochs: number of training epochs.
        model: network to train; moved to `device` first.
        optimizer: optimizer updating `model`'s parameters.
        train_loader: DataLoader of training batches.
        valid_loader: DataLoader of validation batches.
        device: torch.device to train on.
        criterion1: loss used for both training and validation.
        scheduler: optional LR scheduler, stepped once per epoch. A ``ReduceLROnPlateau``
            scheduler is stepped with the epoch's validation loss.

    Returns:
        The checkpoint path ``f'{results_folder}/best_model.pt'``. The file is written only
        when some epoch's validation loss is lower than the best seen so far. For example, it
        is never written if every validation loss is NaN.
    """
    model = model.to(device)

    # float('inf') rather than np.Inf, which was removed in NumPy 2.0.
    best_valid_loss = float('inf')

    train_loss_list = []
    valid_loss_list = []
    valid_iou_list = []
    valid_dice_list = []

    results_folder = util_functions.create_result_folder(path='results')
    # Ensure the destination for checkpoints/trend files exists before anything is written.
    os.makedirs(results_folder, exist_ok=True)
    print(results_folder)
    best_model_path = f'{results_folder}/best_model.pt'

    for epoch in tqdm(range(n_epochs)):
        train_loss = util_functions.train_fn(data_loader=train_loader, model=model, criterion=criterion1,
                                             optimizer=optimizer, device=device)
        valid_loss, valid_iou, valid_dice = eval_fn(data_loader=valid_loader, model=model, criterion=criterion1,
                                                    device=device)

        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(valid_loss)
            else:
                scheduler.step()

        # LR that will be used for the next epoch. scheduler.get_lr() is not meant to be called
        # outside step(); for StepLR it reapplies gamma on decay epochs and misreports the LR.
        current_lr = optimizer.param_groups[0]['lr']

        train_loss_list.append(train_loss)
        valid_loss_list.append(valid_loss)
        valid_iou_list.append(valid_iou)
        valid_dice_list.append(valid_dice)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), best_model_path)
            print('SAVED-MODEL')

        print(f'Epoch: {epoch+1}, Train Loss: {train_loss}, Valid Loss: {valid_loss}, Valid IoU: {valid_iou}, lr: {current_lr}')
        # Plot periodically, and always after the last epoch so the final curves are saved.
        if epoch % 10 == 0 or epoch == n_epochs - 1:
            util_functions.visualize_training(train_loss_list=train_loss_list, valid_loss_list=valid_loss_list,
                                              valid_iou_list=valid_iou_list, valid_dice_list=valid_dice_list,
                                              results_folder=results_folder)

        lists_dict = {
            'train_loss_list': train_loss_list,
            'valid_loss_list': valid_loss_list,
            'valid_iou_list': valid_iou_list,
            'valid_dice_list': valid_dice_list,
        }
        # Rewritten every epoch so partial progress survives an interrupted run.
        with open(f'{results_folder}/training_trend.json', 'w') as f:
            json.dump(lists_dict, f)

    return best_model_path
                

In [7]:
# Seed the Python, NumPy and torch RNGs so weight init and loader shuffling are repeatable.
# Whether the augmentations are covered depends on which RNG util_functions uses; confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

torch.cuda.empty_cache()
model = eff_unet2.EffUNet(in_channels=1, classes=1)
# Fall back to CPU so the cell still runs (slowly) on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
# Move the model before building the optimizer so it references the on-device parameters.
model.to(device)
criterion1 = DiceLoss(mode="binary")
learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Multiply the LR by 0.1 every 8 scheduler steps (train_loop steps once per epoch).
scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)
n_epochs = 21

# train_loop returns the path of the best checkpoint file (.../best_model.pt), not the folder.
result_folder = train_loop(n_epochs, model, optimizer, train_loader, valid_loader, device, criterion1, scheduler=scheduler)

cuda:0
results/2023-12-14_13-43


  0%|          | 0/21 [00:00<?, ?it/s]

In [None]:
# Show where the best checkpoint from the training cell was written (a file path, not a folder).
best_checkpoint_path = result_folder
print(best_checkpoint_path)

results/2023-11-15_15-32/best_model.pt
