In [None]:
import os
# Remember the notebook's own working directory so it can be restored once
# the imports below are done.
curr_dir = os.getcwd()
dir_1AlgoG3 = '/workspace/1AlgoG3'
code_repo_dir = 'code_repo/'

# Temporarily switch into <dir_1AlgoG3>/<code_repo_dir>.
# NOTE(review): presumably this is what makes the project-local packages
# below importable (resolved relative to the current directory) - confirm.
os.chdir(dir_1AlgoG3)
os.chdir(code_repo_dir)


import utils.image_utils as iu
import pma.utils as pu
import utils.utils as u
import s3.s3 as s3
import compath.tissue_utils as tu
# NOTE(review): later cells use `pd`, `plt` and `tqdm` without importing them;
# presumably this star import (or the %run scripts) provides them - confirm.
from standard_libraries import *

import yaml
import shutil
import random
from tifffile import imread, imsave
import warnings

#from stardist_utils.utils import * 
# Restore the original working directory.
# NOTE(review): os.getcwd() returns an absolute path, so the first chdir here
# is immediately superseded by the second one.
os.chdir(dir_1AlgoG3)
os.chdir(curr_dir)
# Expose only the GPU with index 1 to CUDA for this process.
# NOTE(review): with a single visible device the multi-GPU (DataParallel)
# branch further down would never trigger - confirm this is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
def load_yaml(file_path):
    """Parse the YAML document stored at ``file_path``.

    The file is opened in text mode and handed to ``yaml.safe_load``.

    Returns:
        The parsed content, or ``None`` when the document is not valid YAML
        (the parser error is printed in that case).

    Errors raised while opening the file are not caught and propagate to
    the caller.
    """
    try:
        with open(file_path, 'r') as handle:
            return yaml.safe_load(handle)
    except yaml.YAMLError as err:
        print(f"Error reading YAML file: {err}")
    return None

In [None]:
import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import Tensor
import segmentation_models_pytorch as smp

In [None]:
# Execute the training helper scripts with IPython's %run so that the names
# they define are loaded into this notebook's namespace.
# NOTE(review): CustomDataset, get_training_augmentation, train, validate and
# save_avg_val_metrics are used below but never defined in this notebook;
# presumably they come from these scripts - confirm.
%run /workspace/1AlgoG3/code_repo/smp_unet++/training_utils/vis_utils.py
%run /workspace/1AlgoG3/code_repo/smp_unet++/training_utils/dataloading.py
%run /workspace/1AlgoG3/code_repo/smp_unet++/training_utils/metrics.py
%run /workspace/1AlgoG3/code_repo/smp_unet++/training_utils/validation_logic.py
%run /workspace/1AlgoG3/code_repo/smp_unet++/training_utils/train_logic.py

In [None]:
def combine_train_val_metrics(epoch, train_metrics_per_epoch, val_metrics_per_epoch, agg='mean'):
    """Collapse per-batch train/val metrics into a single row for one epoch.

    Each metric column is reduced over all rows (batches) of its frame. The
    default reduction is the mean, so that epoch-level values do not scale
    with the number of batches; this keeps train and validation curves
    comparable even when the two loaders use different batch sizes.

    Args:
        epoch: Epoch identifier stored under the ``'epoch'`` key.
        train_metrics_per_epoch: DataFrame with one row per training batch.
        val_metrics_per_epoch: DataFrame with one row per validation batch.
        agg: Reduction passed to ``Series.agg`` (e.g. ``'mean'`` or ``'sum'``).
            Pass ``'sum'`` to obtain the previous summed values.

    Returns:
        dict mapping ``'epoch'`` plus every metric column name to its
        aggregated value. If both frames share a column name, the value from
        the validation frame wins because it is processed last.
    """
    # Bookkeeping columns that must not be aggregated as metrics.
    exclude_cols = {'epoch', 'batch_idx'}

    combined = {'epoch': epoch}
    for metric_df in (train_metrics_per_epoch, val_metrics_per_epoch):
        for col in metric_df.columns:
            if col not in exclude_cols:
                combined[col] = metric_df[col].agg(agg)

    return combined

In [None]:
from matplotlib.ticker import MaxNLocator
def create_overall_plots(overall_metrics, overall_metrics_dir):
    """Write one train-vs-validation line chart per tracked metric.

    For every metric name the columns ``train_<metric>`` and ``val_<metric>``
    of ``overall_metrics`` are plotted against its ``epoch`` column, and the
    figure is saved to ``overall_metrics_dir`` as
    ``<Metric> Over Epochs.png``.

    Args:
        overall_metrics: Table indexable by column name (e.g. a DataFrame)
            holding ``epoch`` and the train/val metric columns.
        overall_metrics_dir: Directory the PNG files are written to.
    """
    tracked_metrics = ('accuracy', 'dice_score', 'iou', 'bce_loss', 'dice_loss', 'loss')

    for var in tracked_metrics:
        fig, ax = plt.subplots(figsize=(10, 6))

        ax.plot(overall_metrics['epoch'], overall_metrics[f'train_{var}'], label=f'Train {var}', marker='o')
        ax.plot(overall_metrics['epoch'], overall_metrics[f'val_{var}'], label=f'Val {var}', marker='o')

        ax.set_title(f'{var.capitalize()} Over Epochs')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(var.capitalize())
        # Epoch numbers are integers, so keep the x ticks on whole numbers.
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.legend()

        fig.savefig(f'{overall_metrics_dir}/{var.capitalize()} Over Epochs.png')
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

# Define Training Folder

In [None]:
# Load the run configuration; every later cell indexes into this mapping.
train_config = load_yaml('train_config.yaml')
if train_config is None:
    # load_yaml returns None when the YAML cannot be parsed, and an empty
    # document also parses to None. Fail here with a clear message instead of
    # an opaque "'NoneType' object is not subscriptable" in a later cell.
    raise ValueError("train_config.yaml could not be parsed or is empty")

In [None]:
# Show the configured run location and optionally override the run name
# before any run folders are derived from it.
print(f"Training Root: {train_config['training_root']}\nRun Name: {train_config['run_name']}\n-----------------")
change_run_name = input(f"\nWant to change run_name? (0 -No, 1-Yes)")
# A blank answer (just pressing Enter) counts as "No" instead of crashing in
# int(''); any other non-numeric answer still raises ValueError.
if bool(int(change_run_name.strip() or 0)):
    train_config['run_name'] = input("Input new run_name")

In [None]:
# Values read straight from the run configuration.
input_image_shape = train_config['input_image_shape']
# Everything produced by this run lives under <training_root>/<run_name>.
training_root = f"{train_config['training_root']}/{train_config['run_name']}"
data_root = train_config['data_root']

# Output sub-folders of the run.
model_dir = f"{training_root}/models"
val_images_dir = f"{training_root}/val_images"
metrics_dir = f"{training_root}/metrics"
train_images_dir = f"{training_root}/train_images"
overall_metrics_dir = f'{training_root}/overall_metrics'

# Create the run root first, then each sub-folder.
for _run_folder in (
    training_root,
    model_dir,
    val_images_dir,
    metrics_dir,
    train_images_dir,
    overall_metrics_dir,
):
    u.create_folder(_run_folder)

In [None]:
# Optionally wipe everything stored under the current run's root.
rem_dir = input(f"Want to remove training root? (0 -No, 1-Yes)")
# A blank answer counts as "No" instead of crashing in int('').
if bool(int(rem_dir.strip() or 0)):
    shutil.rmtree(training_root)
    print('train root removed')
    # The training cells below write into these folders; recreate them empty
    # so that a run can continue after the wipe instead of failing on a
    # missing directory when metrics or checkpoints are saved.
    for _run_folder in (
        training_root,
        model_dir,
        val_images_dir,
        metrics_dir,
        train_images_dir,
        overall_metrics_dir,
    ):
        u.create_folder(_run_folder)

In [None]:
# Set to False to force CPU execution even when a GPU is present.
use_cuda = True
# Always build a torch.device (never a bare string) so that `.type` is
# available below and in later cells.
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
print(f'Runing on: {device} | GPU available: {torch.cuda.is_available()}')

# Define num_gpus on every path: later cells read it to decide whether to
# wrap the model for multi-GPU use, which previously raised NameError on CPU.
num_gpus = torch.cuda.device_count() if device.type == 'cuda' else 0
if device.type == 'cuda':
    print(f"{num_gpus} GPU's Available")

In [None]:
# (disabled) class-distribution table for the training data
#dist_df = pd.read_csv(f'{data_root}/train_data_class_dist.csv')

# Training split: images and their masks live in sibling folders of data_root.
x_train_dir = f"{data_root}/train_images"
y_train_dir = f"{data_root}/train_masks"

# Validation split, same layout.
x_valid_dir = f"{data_root}/val_images"
y_valid_dir = f"{data_root}/val_masks"

# Check Loaders

In [None]:
# Build datasets and loaders for a quick sanity check, bracketed by the
# project's timer helpers.
# NOTE(review): presumably start_timer/stop_timer report the elapsed time of
# this construction - confirm in utils.utils.
u.start_timer()
# Only the training dataset receives an augmentation pipeline.
train_dataset = CustomDataset(x_train_dir, y_train_dir, augmentation=get_training_augmentation(input_image_shape))
valid_dataset = CustomDataset(x_valid_dir, y_valid_dir)

# Training batches are shuffled; validation order is kept fixed.
# These loaders are rebuilt in later cells before the actual training run.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)
u.stop_timer()

In [None]:
# Visual sanity check of the training data. This rebinds train_loader with a
# small batch size; the loader is recreated again before training.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

for batch in tqdm(train_loader, total = len(train_loader), desc = f'Batches'):   
    # The first two entries of a batch are the images and their masks.
    images, true_masks = batch[0], batch[1]
    for image_tensor, mask_tensor in zip(images, true_masks):
        # Move the first axis to the end before plotting.
        # NOTE(review): presumably channels-first (C, H, W) -> (H, W, C) - confirm.
        image_tensor = image_tensor.cpu().numpy().transpose((1, 2, 0))
        # Distinct label values present in this mask.
        print(torch.unique(mask_tensor))
        iu.plot_image_series([image_tensor, mask_tensor.cpu().numpy()])
        print('\n')
    # With the break disabled, every batch of the dataset is plotted.
    #break

# Initialize Models

In [None]:
# Encoder backbone and the pretrained weights it is initialised with.
ENCODER = "resnet50"
ENCODER_WEIGHTS = 'imagenet'
# Number of output channels of the segmentation head.
n_classes = 1
# NOTE(review): DEVICE is not referenced in the visible cells (the model is
# moved to `device`); kept in case other code relies on it.
DEVICE = 'cuda'

# UNet++ segmentation model taking 3-channel input images.
model = smp.UnetPlusPlus(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,
    classes=n_classes
)

model.to(device);
# Wrap for multi-GPU execution only when actually running on CUDA with more
# than one visible device. The condition is computed here rather than read
# from `num_gpus`, which is undefined when no GPU is in use.
# torch.device(...) accepts both a torch.device and a plain string.
if torch.device(device).type == 'cuda' and torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

In [None]:
# Number of training epochs (fixed here, not read from the config file).
epochs = 10
# Optimisation settings from train_config.yaml. float() is applied to the
# numeric entries; NOTE(review): presumably because values such as "1e-4"
# are loaded from YAML as strings - confirm.
gradient_clipping = float(train_config['gradient_clipping'])
weight_decay = float(train_config['weight_decay'])
learning_rate = float(train_config['learning_rate'])
# Used below as the `enabled` flag of the gradient scaler.
amp = train_config['amp']

In [None]:
# Gradient scaler for mixed-precision training; inactive when `amp` is falsy.
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Start from cleared gradients.
optimizer.zero_grad()
# The training loop steps this scheduler with the mean validation IoU, hence
# mode 'max': the learning rate is reduced once that score stops improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize validation IoU

print(f"Epochs: {epochs}")

In [None]:
# Datasets and loaders used by the training loop.
# The augmentation size comes from the configured input_image_shape (as in
# the loader check above) instead of a hard-coded 512, so a config change is
# not silently ignored here.
# NOTE(review): trial_run=True is passed to both datasets; presumably this
# restricts them to a subset for a quick trial - confirm, and disable it for
# a full training run.
train_dataset = CustomDataset(x_train_dir, y_train_dir, trial_run = True, augmentation=get_training_augmentation(input_image_shape))
valid_dataset = CustomDataset(x_valid_dir, y_valid_dir, trial_run = True)
# Training batches are shuffled; validation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)

In [None]:
# Report the number of training samples and of batches per epoch.
print(f"Dataset Length: {len(train_dataset)}\nTotal Batches: {len(train_loader)}")

In [None]:
# NOTE(review): epoch_logs, batch_logs and global_step are not used by this
# cell; kept in case other code reads them.
epoch_logs = []
batch_logs = []
# One dict of aggregated train/val metrics per finished epoch.
overall_metrics_list = []

# Best mean validation IoU seen so far. Starting at 0 means the first epoch
# with a positive IoU is stored as the initial best checkpoint.
max_val_iou = 0
global_step = 0

model.train()

# Epochs are numbered from 1. Iterating over the 1-based range directly
# replaces the former in-loop `epoch += 1`, which also made the old
# `epoch == 0` branch unreachable (that branch is therefore dropped).
for epoch in tqdm(range(1, epochs + 1), total=epochs, desc = "Epochs"):
    # ---- training pass -----------------------------------------------------
    model, optimizer, grad_scaler, learning_rate, train_metrics_per_epoch = train(
        model,
        optimizer,
        grad_scaler,
        learning_rate,
        device,
        train_loader,
        epoch,
        train_config
    )
    train_metrics_per_epoch = pd.DataFrame(train_metrics_per_epoch)

    # ---- validation pass (model switched to eval mode only for its duration)
    model.eval()
    val_metrics_per_epoch = validate(model, device, valid_loader, epoch, train_config)
    model.train()

    val_metrics_per_epoch = pd.DataFrame(val_metrics_per_epoch)
    val_metrics_per_epoch.to_csv(f'{metrics_dir}/val_metrics_per_batch_for_epoch{epoch}.csv', index = False)

    save_avg_val_metrics(val_metrics_per_epoch, epoch, training_root)

    # Mean IoU over the validation batches drives both the LR schedule and
    # the best-checkpoint selection.
    validation_iou = val_metrics_per_epoch['val_iou'].mean()
    scheduler.step(validation_iou)

    # ---- checkpoints -------------------------------------------------------
    if validation_iou > max_val_iou:
        max_val_iou = validation_iou
        torch.save(model, f'{model_dir}/max_val_score_model_epoch{epoch}_score({round(validation_iou,5)}).pth')

    # Periodic checkpoint on every second epoch, independent of the score.
    if epoch % 2 == 0:
        torch.save(model, f'{model_dir}/default_save_epoch{epoch}.pth')

    # ---- per-epoch summary: CSV and plots are rewritten after every epoch --
    temp_dict = combine_train_val_metrics(epoch, train_metrics_per_epoch, val_metrics_per_epoch)
    overall_metrics_list.append(temp_dict)

    overall_metrics = pd.DataFrame(overall_metrics_list)
    overall_metrics.to_csv(f'{overall_metrics_dir}/overall_metrics.csv', index = False)

    create_overall_plots(overall_metrics, overall_metrics_dir)