# Import Modules

In [18]:
import torch 
import torch.nn as nn
from torchmetrics import Dice
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.optim import lr_scheduler
import albumentations as A
import torchvision.transforms.functional as TF
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

import numpy as np
import cv2
import os, time
import matplotlib.pyplot as plt
from glob import glob

from torch.utils.tensorboard import SummaryWriter
from path import Path


# Check GPUs

In [19]:
# device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
# torch.cuda.set_device(0)

# print(device)
# print(torch.cuda.current_device())

# Set Parameters

In [20]:
LEARNING_RATE= 1e-3
BATCH_SIZE= 8
NUM_EPOCHS= 10
NUM_WORKERS= 0

IMAGE_HEIGHT= 512
IMAGE_WIDTH= 416
PIN_MEMORY= True
LOAD_MODEL= False

# num_block= [3, 4, 6, 3];
features_depth= [64, 128, 256, 512]
input_channel= 1 
num_classes= 3

model_category = 'segAN'
checkpoint_path = 'segAN.pth'
# training_checkpoint = 'training_checkpoint.pth'

TRAIN_IMG_DIR = sorted(glob("/home/haobo/HaoboSeg-pytorch/data_all/train_f/*"))
TRAIN_MASK_DIR = sorted(glob("/home/haobo/HaoboSeg-pytorch/data_all/train_m/*"))

VAL_IMG_DIR = sorted(glob("/home/haobo/HaoboSeg-pytorch/data_all/val_f/*"))
VAL_MASK_DIR = sorted(glob("/home/haobo/HaoboSeg-pytorch/data_all/val_m/*"))

data_str = f"Dataset Size:\nTrain images: {len(TRAIN_IMG_DIR)}\t Train masks: {len(TRAIN_MASK_DIR)}"
print(data_str)

data_str = f"Val images: {len(VAL_IMG_DIR)}\t Val masks: {len(VAL_MASK_DIR)}"
print(data_str)

Dataset Size:
Train images: 3813	 Train masks: 3813
Val images: 195	 Val masks: 195


# Create Dataset

In [21]:
class EchoDataset(Dataset):
    def __init__(self, images_path, masks_path, transform=None):
        self.images_path = images_path
        self.masks_path = masks_path
        self.transform = transform


    def __getitem__(self, index):
        image = cv2.imread(self.images_path[index], cv2.IMREAD_GRAYSCALE)
        image = cv2.resize(image, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_NEAREST)
        image = image/image.max()
        image = np.expand_dims(image, axis=0)
        image = image.astype(np.float32)

        mask = cv2.imread(self.masks_path[index], cv2.IMREAD_GRAYSCALE)        
        mask = cv2.resize(mask, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_NEAREST)
        masks = [(mask==c) for c in range(3)]
        mask = np.stack(masks, axis=0)
        mask = mask.astype(np.float32)

        if self.transform is not None:
            augmentation= self.transform(image= image, mask= mask)
            image = augmentation['image']
            mask = augmentation['mask']

            # image = np.transpose(image, (1,2,0)).to(torch.float32)
            # mask = mask.to(torch.float32)

        return image, mask, self.images_path[index]

    def __len__(self):
        return len(self.images_path)

transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomGamma(gamma_limit= 70,p=0.6)

])

def get_train_data(train_img_dir, train_mask_dir, val_img_dir, val_mask_dir, batch_size, train_transform, val_transform, num_workers, pin_memory):
    train_ds= EchoDataset(train_img_dir, train_mask_dir, train_transform)
    train_dataloader= DataLoader(train_ds, batch_size=batch_size,
                                 shuffle=True, 
                                 num_workers=num_workers,
                                 pin_memory=pin_memory)
    val_ds= EchoDataset(val_img_dir, val_mask_dir, val_transform)
    val_dataloader= DataLoader(val_ds, batch_size=batch_size,
                               shuffle=False,
                               num_workers=num_workers,
                               pin_memory=pin_memory)

    return train_dataloader, val_dataloader

def get_test_data(test_img_dir, test_mask_dir, batch_size, test_transform, num_workers, pin_memory):
    test_ds= EchoDataset(test_img_dir, test_mask_dir, test_transform)
    test_dataloader= DataLoader(test_ds, batch_size=batch_size,
                                shuffle= False,
                                num_workers=num_workers,
                                pin_memory=pin_memory) 
    return test_dataloader


# train_ds= EchoDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, transform)
# print(ds[1][0].dtype)
train_dataloader, val_dataloader = get_train_data(train_img_dir= TRAIN_IMG_DIR, train_mask_dir= TRAIN_MASK_DIR, 
                                                  val_img_dir= VAL_IMG_DIR, val_mask_dir= VAL_MASK_DIR, 
                                                  train_transform=None, val_transform=None,
                                                  batch_size= BATCH_SIZE, 
                                                  num_workers= NUM_WORKERS, 
                                                  pin_memory= PIN_MEMORY)

len(train_dataloader)

                        

477

# Call Model

In [24]:
class down_conv(nn.Module):
    """ Down convolution with a kernel size of 4x4 and optional batch normalization.
    Args:
    in_c : int
    out_c: int
    stride: int
    batch_normalization: bool
    """
    def __init__(self, in_c, out_c, batch_normalization=True):
        super().__init__()
        
        self.conv= nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1)
        if batch_normalization:
            self.bn= nn.BatchNorm2d(out_c)
        else: 
            self.bn = None
        self.relu= nn.LeakyReLU()

    def forward(self, x):
        x= self.conv(x)
        if self.bn:
            x= self.bn(x)
        x= self.relu(x)

        return x

class up_conv(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.upsample= nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        # self.conv= nn.Conv2d(in_c+in_c, out_c, kernel_size=3, stride=1)
        self.conv= nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
        self.bn= nn.BatchNorm2d(out_c)
        self.relu= nn.ReLU()

    def forward(self, x):
        # x= torch.cat([x, skip], axis=1)
        x= self.upsample(x)
        # print(x.shape)
        x= self.conv(x)
        x= self.bn(x)
        x= self.relu(x)
        return x

class final_conv(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.upsample= nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv= nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x= self.upsample(x)
        x= self.conv(x)
        # need to add sigmoid?
        return x

class segmentor(nn.Module):
    def __init__(self, num_classes, filter_size= [64, 128, 256, 512]):
        super().__init__()
        self.encoder1= down_conv(1, filter_size[0], batch_normalization=False) # 1 -> 64
        self.encoder2= down_conv(filter_size[0], filter_size[1]) # 64 -> 128
        self.encoder3= down_conv(filter_size[1], filter_size[2]) # 128 -> 256
        self.encoder4= down_conv(filter_size[2], filter_size[3]) # 256 -> 512

        self.decoder1= up_conv(filter_size[3], filter_size[2]) # 512 -> 256
        self.decoder2= up_conv(filter_size[2] + filter_size[2], filter_size[1]) # (256+256) -> 128 (double the initial size because of concatenation)
        self.decoder3= up_conv(filter_size[1] + filter_size[1], filter_size[0]) # (128+128) -> 64

        self.final_conv= final_conv(filter_size[0] + filter_size[0], 3) # (64+64) -> 3

        self.output= nn.Conv2d(3, num_classes, kernel_size=1)
    
    def forward(self, x):
        el1= self.encoder1(x)
        # print("el1:" ,el1.shape)
        el2= self.encoder2(el1)
        # print("el2:" ,el2.shape)
        el3= self.encoder3(el2)
        # print("el3:" ,el3.shape)
        el4= self.encoder4(el3)
        # print("el4:" ,el4.shape)

        dl1= self.decoder1(el4)
        # print("dl1:" ,dl1.shape)
        # print("cat:", torch.cat([dl1, el3], axis=1).shape)
        dl2= self.decoder2(torch.cat([dl1, el3], axis=1))
        # print("dl2:" ,dl2.shape)
        dl3= self.decoder3(torch.cat([dl2, el2], axis=1))
        # print("dl3:" ,dl3.shape)

        dl4= self.final_conv(torch.cat([dl3, el1], axis=1))

        output= self.output(dl4)

        return torch.sigmoid(output)

class critic(nn.Module):
    def __init__(self, input_c, filter_size=[64, 128, 256, 512]):
        super().__init__()

        self.conv1= down_conv(input_c, filter_size[0], batch_normalization=False)
        self.conv2= down_conv(filter_size[0], filter_size[1])
        self.conv3= down_conv(filter_size[1], filter_size[2])
        self.conv4= down_conv(filter_size[2], filter_size[3])

        self.conv5= up_conv(filter_size[3], filter_size[2])
        self.conv6= up_conv(filter_size[2], filter_size[1])
        self.conv7= up_conv(filter_size[1], filter_size[0])

        self.conv= final_conv(filter_size[0], input_c)


    def forward(self, pred, true, ground_truth):
        masked_pred= pred*true # masking the predicted image with the true image
        masked_truth= ground_truth*true # masking the ground truth with the true image
        c1_pred= self.conv1(masked_pred) # level 1
        c2_pred= self.conv2(c1_pred) # level 2
        c3_pred= self.conv3(c2_pred) # level 3

        c1_gt= self.conv1(masked_truth)
        c2_gt= self.conv2(c1_gt)
        c3_gt= self.conv3(c2_gt)

        c= loss(masked_pred, masked_truth)
        c1= loss(c1_pred, c1_gt)
        c2= loss(c2_pred, c2_gt)
        c3= loss(c3_pred, c3_gt)    
        # print(c.shape)
        # print(c1.shape)
        # print(c2.shape)
        # print(c3.shape)


        l_mae= torch.cat([c, c1, c2, c3], axis=1).mean(dim=1)

        return l_mae
    
def loss(pred, target):
    loss1= (pred-target)**2
    loss1= loss1.mean(dim=[1,2,3])
    return loss1.unsqueeze(dim=1)

class GAN(pl.LightningModule):
    def __init__(self, num_classes, lr=0.0002):
        super().__init__()
        self.save_hyperparameters()

        self.segmentor= segmentor(self.hparams.num_classes)
        self.critic = critic(self.hparams.num_classes)
        self.automatic_optimization = False
        
        self.validation_step_dice = []
        self.validation_step_bce = []
    
    def forward(self, z):
        return self.segmentor(z)
    
    def adversarial_loss(self, l_mae):
        return l_mae.mean()
    
    def training_step(self, batch, batch_idx):
    
        true_imgs, gt, _= batch
        opt_c, opt_s= self.optimizers()

        # train the critic
        self.toggle_optimizer(opt_c)
        lmae_c= self.critic(self(true_imgs).detach(), true_imgs, gt)
        c_loss= self.adversarial_loss(lmae_c)
        self.manual_backward(c_loss)
        self.log("c_loss", c_loss, prog_bar=True)
        opt_c.step()
        opt_c.zero_grad()
        self.untoggle_optimizer(opt_c)

        # train the segmentor
        self.toggle_optimizer(opt_s)
        pred_imgs= self(true_imgs)
        lmae_s= self.critic(pred_imgs, true_imgs, gt)
        loss_dice= dice_loss(pred_imgs, gt)
        s_loss= self.adversarial_loss(lmae_s)+loss_dice
        # print(s_loss)
        self.log("s_loss", s_loss, prog_bar=True)
        self.manual_backward(s_loss)
        opt_s.step()
        opt_s.zero_grad()
        self.untoggle_optimizer(opt_s)



    def configure_optimizers(self):
        lr = self.hparams.lr

        opt_s= torch.optim.Adam(self.segmentor.parameters(), lr=lr)
        opt_c= torch.optim.Adam(self.critic.parameters(), lr=lr)

        return [opt_c, opt_s], []

    def validation_step(self, batch, batch_index):
        true_imgs, gt, _= batch
        pred_imgs=self(true_imgs)

        dice_score= dice_loss(pred_imgs, gt)
        bce_loss= nn.BCELoss()(pred_imgs, gt)

        self.validation_step_dice.append(dice_score.item())
        self.validation_step_bce.append(bce_loss.item())
        # val_loss= self.adversarial_loss(val_lmae)
    
    # def test_step(self, batch, batch_index):
        # true_imgs, gt, files= batch

        
    # def on_train_epoch_end(self):
        # print('\n')
    
    def on_validation_epoch_end(self):
        dice_score= np.mean(self.validation_step_dice)
        bce_loss= np.mean(self.validation_step_bce)
        print("\nvalidation dice loss in epoch {}: {}".format(self.current_epoch, dice_score))
        print("validation bce loss in epoch {}: {}".format(self.current_epoch, bce_loss))
        self.log("dice loss", dice_score)
        self.log("bce", bce_loss)
        self.validation_step_dice.clear()
        self.validation_step_bce.clear()
        
def dice_loss(pred, target, smooth = 1.):
    pred = pred.contiguous()
    target = target.contiguous()    

    intersection = (pred * target).sum(dim=2).sum(dim=2)
    # print(intersection.shape)
    dice = 1-((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth))
    
    return dice.mean()
        
x= torch.randn((8,2,20,20))
y= torch.randn((8,1,512,416))
z= torch.randn((8,1,228,228))

# segmentor(3)(y)


In [25]:
model= GAN(num_classes=3, lr= 0.0003)
torch.manual_seed(2023)
trainer = pl.Trainer(max_epochs=40, devices=[0])
early_stopping= EarlyStopping(monitor="dice loss", mode= 'max', patience=3)
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type      | Params
----------------------------------------
0 | segmentor | segmentor | 4.4 M 
1 | critic    | critic    | 4.2 M 
----------------------------------------
8.6 M     Trainable params
0         Non-trainable params
8.6 M     Total params
34.436    Total estimated model params size (MB)


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 21.78it/s]
validation dice loss in epoch 0: 0.7327018082141876
validation bce loss in epoch 0: 0.6778037250041962
Epoch 0: 100%|██████████| 477/477 [01:13<00:00,  6.52it/s, v_num=86, c_loss=0.00714, s_loss=0.433]
validation dice loss in epoch 0: 0.4786368644237518
validation bce loss in epoch 0: 2.589010057449341
Epoch 1: 100%|██████████| 477/477 [01:12<00:00,  6.54it/s, v_num=86, c_loss=0.00698, s_loss=0.383]
validation dice loss in epoch 1: 0.4467296624183655
validation bce loss in epoch 1: 3.1609934425354003
Epoch 2: 100%|██████████| 477/477 [01:12<00:00,  6.55it/s, v_num=86, c_loss=0.00942, s_loss=0.381]
validation dice loss in epoch 2: 0.44116597294807436
validation bce loss in epoch 2: 3.656405839920044
Epoch 3: 100%|██████████| 477/477 [01:12<00:00,  6.55it/s, v_num=86, c_loss=0.008, s_loss=0.345]  
validation dice loss in epoch 3: 0.45806516528129576
validation bce loss in epoch 3: 3.9045039081573485
Epoch 4: 100%

`Trainer.fit` stopped: `max_epochs=40` reached.


Epoch 39: 100%|██████████| 477/477 [01:14<00:00,  6.39it/s, v_num=86, c_loss=0.000248, s_loss=0.0219]


### performance log

epoch 17 
dice= 0.559
c_loss = 0.00059, s_loss= 0.00059

version = 53 \
dice loss= 0.614 \
c_loss = 0.0028 s_loss= 0.0032 

version = 57 \
dice= 0.7\
c_loss= 0.011, s_loss= 0.00367




# Evaluate model

In [None]:
model = model.load_from_checkpoint(checkpoint_path=checkpoint_path)
model.cuda()
# model= model.cuda()

model.eval()

test_loader= get_test_data(TRAIN_IMG_DIR, TRAIN_MASK_DIR, BATCH_SIZE, 
                           test_transform=None, 
                           num_workers=NUM_WORKERS, 
                           pin_memory=PIN_MEMORY)

with Path("/home/haobo/HaoboSeg-pytorch/eizzaty/"):
    if not os.path.exists(model_category):
        print(model_category + " does not exist. Creating directory...")
        os.makedirs(model_category)
        print(model_category+ " created!")

    for j, (inputs, masks, file_names) in enumerate(test_loader):

        inputs, masks= inputs.cuda(), masks.cuda()

        pred= model(inputs)
        pred= torch.sigmoid(pred)
        pred = pred.cpu().detach().numpy()
        ori_imgs = inputs.cpu().detach().numpy()
        ground_truth = masks.cpu().detach().numpy()


        if not os.path.exists(model_category+"/predicted_images"):
            print(model_category+"/predicted_images" + " does not exist. Creating directory...")
            os.makedirs(model_category+"/predicted_images")
            print(model_category+"/predicted_images"+ " created!")
        
        if not os.path.exists(model_category+"/original_images"):
            print(model_category+"/original_images" + " does not exist. Creating directory...")
            os.makedirs(model_category+"/original_images")
            print(model_category+"/original_images"+ " created!")

        if not os.path.exists(model_category+"/groundtruth_images"):
            print(model_category+"/groundtruth_images" + " does not exist. Creating directory...")
            os.makedirs(model_category+"/groundtruth_images")
            print(model_category+"/groundtruth_images"+ " created!")
        

        for i in range(inputs.size(0)):
            ground_truth_final= np.argmax(np.transpose(ground_truth[i].reshape(3, IMAGE_HEIGHT, IMAGE_WIDTH), (1,2,0)), axis=2)
            pred_final= np.argmax(np.transpose(pred[i].reshape(3, IMAGE_HEIGHT, IMAGE_WIDTH), (1,2,0)), axis=2)
            cv2.imwrite(model_category+"/groundtruth_images/"+Path(file_names[i]).stem + "_groundtruth.png", 100*ground_truth_final)
            cv2.imwrite(model_category+"/predicted_images/"+Path(file_names[i]).stem + "_pred.png", 100*pred_final)
            cv2.imwrite(model_category+"/original_images/"+Path(file_names[i]).stem + "_ori.png", 255*ori_imgs[i].reshape(IMAGE_HEIGHT, IMAGE_WIDTH))