# Imports

# In[ ]:
import torch
import torch.nn as nn
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
# from albumentations import HorizontalFlip
from torchviz import make_dot, make_dot_from_trace
import torch.optim as optim
from torch.optim import lr_scheduler
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision.transforms.functional as TF

import segmentation_models_pytorch as smp
from collections import defaultdict
import os, time
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score, confusion_matrix
import random
import matplotlib.pyplot as plt
from path import Path
import pandas as pd

from torch.utils.tensorboard import SummaryWriter
from datetime import datetime



# Check for a GPU and use it if available

# In[ ]:
# Pick the first CUDA device when one is present, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
if device.type == "cuda":
    # Only touch the CUDA runtime when it exists: set_device/current_device
    # raise on CPU-only machines, which defeated the fallback above.
    torch.cuda.set_device(0)

print(device)
if device.type == "cuda":
    print(torch.cuda.current_device())

# Set Parameters

# In[ ]:
LEARNING_RATE= 1e-3
BATCH_SIZE= 4
NUM_EPOCHS= 10
NUM_WORKERS= 0

IMAGE_HEIGHT= 512
IMAGE_WIDTH= 416
PIN_MEMORY= True
LOAD_MODEL= False

num_block= [3, 4, 6, 3];
input_channel=3 

model_category = 'uresnet'
checkpoint_path = 'uresnet.pth'
training_checkpoint = 'training_checkpoint.pth'

TRAIN_IMG_DIR = sorted(glob("/home/haobo/HaoboSeg-pytorch/data_all/train_f/*"))
TRAIN_MASK_DIR = sorted(glob("/home/haobo/HaoboSeg-pytorch/data_all/train_m/*"))

VAL_IMG_DIR = sorted(glob("/home/haobo/HaoboSeg-pytorch/data_all/val_f/*"))
VAL_MASK_DIR = sorted(glob("/home/haobo/HaoboSeg-pytorch/data_all/val_m/*"))

data_str = f"Dataset Size:\nTrain images: {len(TRAIN_IMG_DIR)}\t Train masks: {len(TRAIN_MASK_DIR)}"
print(data_str)

data_str = f"Val images: {len(VAL_IMG_DIR)}\t Val masks: {len(VAL_MASK_DIR)}"
print(data_str)

# Create Dataset

# In[ ]:
class EchoDataset(Dataset):
    """Paired grayscale frame / label-mask dataset for segmentation.

    Each item is ``(image, mask, image_path)``:

    * ``image`` - float32 array of shape ``(1, IMAGE_HEIGHT, IMAGE_WIDTH)``,
      divided by its own maximum (all-black frames stay all zeros).
    * ``mask`` - float32 one-hot array of shape
      ``(num_classes, IMAGE_HEIGHT, IMAGE_WIDTH)`` where channel ``c`` marks
      pixels whose stored label equals ``c``.
    * ``image_path`` - the path the image was read from.

    Args:
        images_path: Sequence of image file paths.
        masks_path: Sequence of mask file paths, paired with ``images_path``
            by index.
        transform: Optional albumentations pipeline. It is applied to the 2-D
            ``(H, W)`` frame and label mask *before* the channel axis is added,
            because albumentations expects channel-last input. It should
            return numpy arrays (i.e. no ``ToTensorV2``).
        num_classes: Number of label values (``0 .. num_classes - 1``) to
            one-hot encode. Defaults to 3.

    Raises:
        FileNotFoundError: from ``__getitem__`` when a file cannot be read.
    """

    def __init__(self, images_path, masks_path, transform=None, num_classes=3):
        self.images_path = images_path
        self.masks_path = masks_path
        self.transform = transform
        self.num_classes = num_classes

    @staticmethod
    def _read_gray(path):
        """Read ``path`` as a single-channel image, raising if it is unreadable."""
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports missing/unreadable files by returning None,
            # which would otherwise surface later as an obscure resize error.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return img

    def __getitem__(self, index):
        size = (IMAGE_WIDTH, IMAGE_HEIGHT)  # cv2.resize takes (width, height)

        image = self._read_gray(self.images_path[index])
        image = cv2.resize(image, size, interpolation=cv2.INTER_NEAREST)
        peak = image.max()
        # Guard all-black frames: dividing by a zero maximum produced NaNs.
        image = image / peak if peak > 0 else image.astype(np.float64)
        image = image.astype(np.float32)

        mask = self._read_gray(self.masks_path[index])
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)

        if self.transform is not None:
            # Augment while both arrays are still (H, W) so geometric
            # transforms act on the real height/width axes and image and
            # mask shapes agree.
            augmentation = self.transform(image=image, mask=mask)
            image = augmentation['image']
            mask = augmentation['mask']

        image = np.expand_dims(image, axis=0).astype(np.float32)
        mask = np.stack([(mask == c) for c in range(self.num_classes)], axis=0)
        mask = mask.astype(np.float32)

        return image, mask, self.images_path[index]

    def __len__(self):
        return len(self.images_path)

# Augmentation pipeline intended for EchoDataset. The loaders built below pass
# transform=None, so it is currently unused.
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    # gamma_limit is given in percent. Pass an explicit (low, high) range: a
    # bare scalar such as 70 is not read as "around 100"; albumentations
    # expands it into a range whose upper bound is 70 (gamma <= 0.7).
    A.RandomGamma(gamma_limit=(70, 130), p=0.6),
])

def get_train_data(train_img_dir, train_mask_dir, val_img_dir, val_mask_dir, batch_size, train_transform, val_transform, num_workers, pin_memory):
    """Create the training and validation DataLoaders.

    The training loader reshuffles every epoch; the validation loader keeps
    the order of the given paths. Both share batch size, worker count and
    pin-memory settings.

    Returns:
        tuple: ``(train_dataloader, val_dataloader)``.
    """
    shared = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }

    train_loader = DataLoader(
        EchoDataset(train_img_dir, train_mask_dir, train_transform),
        shuffle=True,
        **shared,
    )
    val_loader = DataLoader(
        EchoDataset(val_img_dir, val_mask_dir, val_transform),
        shuffle=False,
        **shared,
    )
    return train_loader, val_loader

def get_test_data(test_img_dir, test_mask_dir, batch_size, test_transform, num_workers, pin_memory):
    """Create an unshuffled DataLoader over the given image/mask paths.

    Returns:
        DataLoader: batches in the order of ``test_img_dir``.
    """
    dataset = EchoDataset(test_img_dir, test_mask_dir, test_transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


# Build the module-level loaders consumed by train_network().
# No augmentation is applied to either split (both transforms are None).
train_dataloader, val_dataloader = get_train_data(
    train_img_dir=TRAIN_IMG_DIR,
    train_mask_dir=TRAIN_MASK_DIR,
    val_img_dir=VAL_IMG_DIR,
    val_mask_dir=VAL_MASK_DIR,
    train_transform=None,
    val_transform=None,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Number of training batches per epoch (shown as the cell output).
len(train_dataloader)

# Define Loss Function

# In[ ]:
# Scratch check of a soft-Dice formula on random NumPy data.
# random.seed only seeds Python's `random` module, not NumPy, so a seeded
# NumPy Generator is used for the arrays to make the check reproducible.
random.seed(42)
rng = np.random.default_rng(42)
# (batch, height, width), matching the layout used by the dataset.
y_pred = rng.standard_normal((BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH))
y_true = rng.standard_normal((BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH))
# np.sum reduces over every element. The builtin sum() iterated only over
# axis 0, which made the denominator, and hence the "loss", a 2-D array.
loss = 1 - (2 * np.sum(y_pred * y_true)) / (np.sum(y_pred ** 2) + np.sum(y_true ** 2) + 1)
print(loss.shape)



# In[ ]:
# Random tensors in the (batch, classes, height, width) layout of the dataset's
# masks, used only to exercise the loss functions below.
pred = torch.randn((BATCH_SIZE, 3, IMAGE_HEIGHT, IMAGE_WIDTH))
true = torch.randn((BATCH_SIZE, 3, IMAGE_HEIGHT, IMAGE_WIDTH))

# In[ ]:
def dice_loss1(y_pred, y_true):
    """Soft Dice loss treating the whole batch as one flattened volume.

    Args:
        y_pred: Prediction tensor, presumably probabilities in [0, 1]; any shape.
        y_true: Target tensor with the same number of elements.

    Returns:
        0-dim tensor ``1 - dice``, with additive smoothing of 1.
    """
    smooth = 1.
    # reshape (unlike view) also accepts non-contiguous tensors.
    y_pred = y_pred.reshape(-1)
    y_true = y_true.reshape(-1)

    intersection = (y_pred * y_true).sum()
    # Denominator is |pred| + |true|; previously y_pred was summed twice.
    dice = (2 * intersection + smooth) / (y_pred.sum() + y_true.sum() + smooth)
    # The flattened Dice already aggregates the whole batch, so it is not
    # divided by the batch size (that only rescaled the loss arbitrarily).
    return 1 - dice
    

# Evaluate the flattened Dice loss on the random tensors (cell output).
dice_loss1(y_pred=pred, y_true=true)


# In[ ]:
def dice_loss(pred, target, smooth=1.):
    """Mean soft Dice loss computed per sample and per channel.

    Args:
        pred: Probabilities shaped ``(N, C, *spatial)`` with at least one
            spatial axis.
        target: Tensor with the same shape as ``pred``.
        smooth: Additive smoothing that keeps empty channels finite.
            Defaults to 1.

    Returns:
        0-dim tensor: ``1 - dice`` averaged over every (sample, channel) pair.

    Raises:
        ValueError: if ``pred`` has fewer than 3 dimensions.
    """
    if pred.dim() < 3:
        raise ValueError(f"expected (N, C, *spatial) input, got shape {tuple(pred.shape)}")

    # Reduce over every spatial axis, so 2-D images (N, C, H, W) and 3-D
    # volumes (N, C, D, H, W) are both handled.
    spatial_dims = tuple(range(2, pred.dim()))
    intersection = (pred * target).sum(dim=spatial_dims)
    denominator = pred.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
    loss = 1 - (2. * intersection + smooth) / (denominator + smooth)

    return loss.mean()

def combo_loss(pred, target, bce_weight=0.5):
    """Weighted sum of binary cross-entropy and soft Dice loss.

    Args:
        pred: Raw model logits, same shape as ``target``.
        target: Binary / one-hot float targets (EchoDataset masks are 0/1).
        bce_weight: Weight of the BCE term; Dice is weighted ``1 - bce_weight``.

    Returns:
        tuple: ``(dice, bce, loss)`` as 0-dim tensors.
    """
    # BCE straight from the logits: the fused log-sigmoid form is numerically
    # stable, whereas sigmoid followed by binary_cross_entropy saturates
    # (log is clamped) for large |logit| and kills the gradient.
    bce = F.binary_cross_entropy_with_logits(pred, target)

    # Dice is defined on probabilities.
    dice = dice_loss(torch.sigmoid(pred), target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    return dice, bce, loss

# dice, bce, loss= combo_loss(pred, true)

# Call Model

# In[ ]:
# U-Net with an ImageNet-initialised ResNet-34 encoder: one input channel
# (grayscale frames) and three output channels, one per mask class.
model = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=1, classes=3)

# Train the network

# In[ ]:
def train_network(model, optimizer, scheduler, num_epochs, tb_writer):
    """Train ``model`` and keep the weights with the lowest validation loss.

    Uses the module-level ``train_dataloader`` / ``val_dataloader`` and the
    ``combo_loss`` objective. After each epoch's training pass the current
    weights are written to ``training_checkpoint``. Whenever the mean
    validation loss improves, they are also written to ``checkpoint_path``.
    On return, the best checkpoint written during this run (if any) is loaded
    back into ``model``.

    Args:
        model: Segmentation network. Batches are moved to the device of its
            parameters.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of passes over the training set.
        tb_writer: TensorBoard ``SummaryWriter`` receiving the train/val loss.

    Returns:
        The same ``model`` instance.

    Raises:
        ValueError: if the training or validation loader yields no samples.
    """
    device = next(model.parameters()).device
    best_vloss = 1e10
    saved_best = False

    for epoch in range(num_epochs):
        since = time.time()

        avg_bce_loss, avg_dice_loss, avg_train_loss = _train_one_epoch(model, optimizer, device, epoch)
        # One save per epoch: the previous per-batch save rewrote the whole
        # state dict after every optimizer step.
        torch.save(model.state_dict(), training_checkpoint)

        avg_val_bce, avg_val_dice, avg_val_loss = _validate_one_epoch(model, device)

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print("Train: bce:{}\t dice:{}\t loss:{}\t".format(avg_bce_loss, avg_dice_loss, avg_train_loss))
        print("Validation: bce:{}\t dice:{}\t loss:{}\t".format(avg_val_bce, avg_val_dice, avg_val_loss))
        scheduler.step()
        for param_group in optimizer.param_groups:
            print("LR", param_group['lr'])

        tb_writer.add_scalars('Training vs. Validation Loss', {'Training': avg_train_loss, 'Validation': avg_val_loss}, epoch + 1)

        if avg_val_loss < best_vloss:
            print('Validation loss improved from {:.4f} to {:.4f}. Model is saved to {}'.format(best_vloss, avg_val_loss, checkpoint_path))
            best_vloss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)
            saved_best = True

    if saved_best:
        # Reload only a checkpoint written by this run; a stale file from an
        # earlier run (or a missing one) must not replace the weights.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model


def _train_one_epoch(model, optimizer, device, epoch):
    """Run one optimisation pass over ``train_dataloader``.

    Returns the per-sample means ``(bce, dice, loss)``.
    """
    model.train()
    running_loss, running_dice_loss, running_bce_loss = 0.0, 0.0, 0.0
    samples_per_epoch = 0
    with tqdm(train_dataloader, unit="batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch}")
        # EchoDataset yields (image, mask, path); the path is not needed here.
        for inputs, masks, _ in tepoch:
            inputs, masks = inputs.to(device), masks.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            dice, bce, loss = combo_loss(outputs, masks)
            loss.backward()
            optimizer.step()

            batch = inputs.size(0)
            running_loss += loss.item() * batch
            running_dice_loss += dice.item() * batch
            running_bce_loss += bce.item() * batch
            samples_per_epoch += batch
            tepoch.set_postfix(loss=running_loss / samples_per_epoch)

    if samples_per_epoch == 0:
        raise ValueError("train_dataloader yielded no samples")
    return (running_bce_loss / samples_per_epoch,
            running_dice_loss / samples_per_epoch,
            running_loss / samples_per_epoch)


@torch.no_grad()  # evaluation only: no autograd graph, far less memory
def _validate_one_epoch(model, device):
    """Evaluate ``model`` on ``val_dataloader``.

    Returns the per-sample means ``(bce, dice, loss)``, weighted the same way
    as the training averages so a smaller last batch is not over-counted.
    """
    model.eval()
    running_vloss, running_val_dice, running_val_bce = 0.0, 0.0, 0.0
    n_samples = 0
    for inputs, masks, _ in val_dataloader:
        inputs, masks = inputs.to(device), masks.to(device)
        outputs = model(inputs)
        val_dice, val_bce, val_loss = combo_loss(outputs, masks)

        batch = inputs.size(0)
        running_vloss += val_loss.item() * batch
        running_val_dice += val_dice.item() * batch
        running_val_bce += val_bce.item() * batch
        n_samples += batch

    if n_samples == 0:
        raise ValueError("val_dataloader yielded no samples")
    return (running_val_bce / n_samples,
            running_val_dice / n_samples,
            running_vloss / n_samples)

# In[ ]:
# Train a fresh U-Net and log the loss curves to TensorBoard under
# runs/smp_uresnet_<timestamp>.
NUM_EPOCHS = 10  # NOTE: overrides the value from the parameter cell
since = time.time()
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/smp_uresnet_{}'.format(timestamp))

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=1,
    classes=3,
).to(device)  # follow the device chosen in the GPU cell (CPU fallback)

if LOAD_MODEL:
    # Resume from the last per-epoch training checkpoint.
    model.load_state_dict(torch.load(training_checkpoint, map_location=device))

optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=3, gamma=0.5)
try:
    model = train_network(model, optimizer_ft, exp_lr_scheduler, num_epochs=NUM_EPOCHS, tb_writer=writer)
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()
time_elapsed = time.time() - since
print('Total time taken: {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))


# Evaluation

### Save predictions to a local folder

# In[ ]:
# Reload the best checkpoint, then write three PNGs per image under
# <model_category>/: the original frame, the ground-truth label map and the
# predicted label map.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)

model.eval()

# NOTE(review): this loader is built from the *training* split
# (TRAIN_IMG_DIR / TRAIN_MASK_DIR). Point it at held-out data if these
# predictions are meant for evaluation.
test_loader = get_test_data(TRAIN_IMG_DIR, TRAIN_MASK_DIR, BATCH_SIZE,
                            test_transform=None,
                            num_workers=NUM_WORKERS,
                            pin_memory=PIN_MEMORY)

# NOTE(review): path.Path used as a context manager is expected to chdir into
# this folder for the duration of the block, so the relative output folders
# below are created there. Confirm against the `path` package version in use.
with Path("/home/haobo/HaoboSeg-pytorch/eizzaty/"):
    # Create all output folders once, before the loop, instead of per batch.
    for folder in (model_category,
                   model_category + "/predicted_images",
                   model_category + "/original_images",
                   model_category + "/groundtruth_images"):
        if not os.path.exists(folder):
            print(folder + " does not exist. Creating directory...")
            os.makedirs(folder)
            print(folder + " created!")

    with torch.no_grad():  # inference only: no autograd graph needed
        for inputs, masks, file_names in test_loader:
            inputs = inputs.to(device)

            pred = torch.sigmoid(model(inputs)).cpu().numpy()
            ori_imgs = inputs.cpu().numpy()
            ground_truth = masks.numpy()

            for i, file_name in enumerate(file_names):
                stem = Path(file_name).stem
                # (classes, H, W) -> (H, W) label map: highest-scoring class per pixel.
                ground_truth_final = np.argmax(ground_truth[i], axis=0)
                pred_final = np.argmax(pred[i], axis=0)
                # Labels 0/1/2 are scaled to 0/100/200 so they are visible in the
                # PNG; cast to uint8 explicitly instead of relying on cv2.
                cv2.imwrite(model_category + "/groundtruth_images/" + stem + "_groundtruth.png",
                            (100 * ground_truth_final).astype(np.uint8))
                cv2.imwrite(model_category + "/predicted_images/" + stem + "_pred.png",
                            (100 * pred_final).astype(np.uint8))
                # Frames are max-normalised to [0, 1]; map back to 0..255
                # with rounding and saturation.
                frame = np.clip(np.rint(255 * ori_imgs[i][0]), 0, 255).astype(np.uint8)
                cv2.imwrite(model_category + "/original_images/" + stem + "_ori.png", frame)

### Evaluate the predictions