# Import Modules

In [13]:
import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.optim import lr_scheduler
import albumentations as A
import torchvision.transforms.functional as TF


import numpy as np
import cv2
import os, time
import matplotlib.pyplot as plt
from glob import glob

from torch.utils.tensorboard import SummaryWriter


# Check GPUs

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
print(device)

# The CUDA-specific calls only make sense (and only succeed) when a GPU was
# selected; on a CPU-only machine they would raise and abort the notebook.
if device.type == "cuda":
    torch.cuda.set_device(0)
    print(torch.cuda.current_device())

cuda
0


# Set Parameters

In [15]:
# --- Training hyper-parameters ---
LEARNING_RATE = 1e-3
BATCH_SIZE = 4
NUM_EPOCHS = 10
NUM_WORKERS = 0

# --- Input geometry / loader options ---
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 416
PIN_MEMORY = True
LOAD_MODEL = False

# --- Model definition ---
features_depth = [64, 128, 256, 512]
input_channel = 1
num_classes = 3

model_category = 'segAN'
checkpoint_path = 'segAN.pth'

# Dataset root; set the SEG_DATA_ROOT environment variable to point elsewhere
# instead of editing this path.
DATA_ROOT = os.environ.get("SEG_DATA_ROOT", "/home/haobo/HaoboSeg-pytorch/data_all")

# NOTE: despite the *_DIR names these are sorted lists of file paths. Images
# and masks are paired by position after sorting, so their counts must match.
TRAIN_IMG_DIR = sorted(glob(os.path.join(DATA_ROOT, "train_f", "*")))
TRAIN_MASK_DIR = sorted(glob(os.path.join(DATA_ROOT, "train_m", "*")))

VAL_IMG_DIR = sorted(glob(os.path.join(DATA_ROOT, "val_f", "*")))
VAL_MASK_DIR = sorted(glob(os.path.join(DATA_ROOT, "val_m", "*")))

data_str = f"Dataset Size:\nTrain images: {len(TRAIN_IMG_DIR)}\t Train masks: {len(TRAIN_MASK_DIR)}"
print(data_str)

data_str = f"Val images: {len(VAL_IMG_DIR)}\t Val masks: {len(VAL_MASK_DIR)}"
print(data_str)

# A count mismatch would silently pair images with the wrong masks.
assert len(TRAIN_IMG_DIR) == len(TRAIN_MASK_DIR), "train image/mask counts differ"
assert len(VAL_IMG_DIR) == len(VAL_MASK_DIR), "val image/mask counts differ"

Dataset Size:
Train images: 3813	 Train masks: 3813
Val images: 195	 Val masks: 195


# Create Dataset and DataLoaders

In [16]:
class EchoDataset(Dataset):
    """Paired grayscale images and label masks for segmentation.

    Each item is ``(image, mask, image_path)``:

    * ``image`` -- float32 array of shape (1, IMAGE_HEIGHT, IMAGE_WIDTH),
      divided by its own maximum pixel value. An entirely black image is
      left at zero instead of becoming NaN through 0/0.
    * ``mask`` -- float32 one-hot array of shape
      (num_classes, IMAGE_HEIGHT, IMAGE_WIDTH); channel ``c`` is 1 where
      the mask pixel value equals ``c``.
    * ``image_path`` -- the source path of the image.

    Args:
        images_path: sequence of image file paths.
        masks_path: sequence of mask file paths, paired with ``images_path``
            by index.
        transform: optional albumentations-style callable used as
            ``transform(image=..., mask=...)``. It must return a dict with
            ``'image'`` and ``'mask'`` numpy arrays. It receives channel-last
            (H, W, C) arrays, as albumentations expects, and its outputs are
            converted back to channel-first.
        num_classes: number of label values to one-hot encode (default 3).

    Raises:
        ValueError: if ``images_path`` and ``masks_path`` differ in length.
        FileNotFoundError: from ``__getitem__`` if OpenCV cannot read a file.
    """
    def __init__(self, images_path, masks_path, transform=None, num_classes=3):
        if len(images_path) != len(masks_path):
            raise ValueError(
                f"got {len(images_path)} images but {len(masks_path)} masks; "
                "they are paired by index and must match"
            )
        self.images_path = images_path
        self.masks_path = masks_path
        self.transform = transform
        self.num_classes = num_classes

    @staticmethod
    def _read_gray(path):
        """Read ``path`` as a single-channel image, raising if it cannot be read."""
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:  # cv2.imread reports failure by returning None, not by raising
            raise FileNotFoundError(f"could not read image file: {path}")
        return img

    @staticmethod
    def _to_channel_first(arr):
        """Convert an (H, W) or (H, W, C) array into a contiguous (C, H, W) array."""
        arr = np.asarray(arr)
        if arr.ndim == 2:
            arr = arr[..., np.newaxis]
        return np.ascontiguousarray(arr.transpose(2, 0, 1))

    def __getitem__(self, index):
        image_path = self.images_path[index]

        image = self._read_gray(image_path)
        image = cv2.resize(image, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_NEAREST)
        peak = image.max()
        if peak > 0:  # guard against 0/0 -> NaN on an all-black frame
            image = image / peak
        image = np.expand_dims(image, axis=0).astype(np.float32)  # (1, H, W)

        mask = self._read_gray(self.masks_path[index])
        # Nearest-neighbour keeps label values exact (no interpolated classes).
        mask = cv2.resize(mask, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_NEAREST)
        mask = np.stack([mask == c for c in range(self.num_classes)], axis=0).astype(np.float32)

        if self.transform is not None:
            # Albumentations treats axis 1 as width and the last axis as channels;
            # feeding (C, H, W) would make spatial transforms act on the wrong axes.
            augmentation = self.transform(
                image=np.ascontiguousarray(image.transpose(1, 2, 0)),
                mask=np.ascontiguousarray(mask.transpose(1, 2, 0)),
            )
            image = self._to_channel_first(augmentation['image'])
            mask = self._to_channel_first(augmentation['mask'])

        return image, mask, image_path

    def __len__(self):
        return len(self.images_path)

# Training-time augmentation pipeline (albumentations); operates on
# channel-last (H, W, C) arrays.
# RandomGamma's gamma_limit is in percent (gamma = value / 100). Albumentations
# 1.x expands a scalar gamma_limit into the symmetric range (-limit, +limit).
# The previous scalar 70 could therefore draw negative gammas, which send black
# pixels to inf. An explicit (low, high) range avoids that.
# NOTE(review): 70 is kept as the lower bound; confirm the intended upper bound.
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomGamma(gamma_limit=(70, 130), p=0.6),
])

def get_train_data(train_img_dir, train_mask_dir, val_img_dir, val_mask_dir, batch_size, train_transform, val_transform, num_workers, pin_memory):
    """Build the training and validation DataLoaders over ``EchoDataset``.

    Both loaders share ``batch_size``, ``num_workers`` and ``pin_memory``.
    Only the training loader shuffles; validation order follows the given
    path lists.

    Returns:
        tuple: ``(train_dataloader, val_dataloader)``.
    """
    shared_opts = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

    train_loader = DataLoader(
        EchoDataset(train_img_dir, train_mask_dir, train_transform),
        shuffle=True,
        **shared_opts,
    )
    val_loader = DataLoader(
        EchoDataset(val_img_dir, val_mask_dir, val_transform),
        shuffle=False,
        **shared_opts,
    )
    return train_loader, val_loader

def get_test_data(test_img_dir, test_mask_dir, batch_size, test_transform, num_workers, pin_memory):
    """Build a non-shuffled DataLoader over ``EchoDataset`` for evaluation.

    Returns:
        DataLoader: yields ``(image, mask, image_path)`` batches in path order.
    """
    dataset = EchoDataset(test_img_dir, test_mask_dir, test_transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


# Augmentation is currently disabled: both loaders are built with transform=None.
train_dataloader, val_dataloader = get_train_data(
    train_img_dir=TRAIN_IMG_DIR,
    train_mask_dir=TRAIN_MASK_DIR,
    val_img_dir=VAL_IMG_DIR,
    val_mask_dir=VAL_MASK_DIR,
    batch_size=BATCH_SIZE,
    train_transform=None,
    val_transform=None,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Number of training batches per epoch.
len(train_dataloader)

                        

954

# Define Model

In [45]:
class down_conv(nn.Module):
    """Downsampling block: 4x4 stride-2 convolution, optional BatchNorm, LeakyReLU.

    With ``padding=1`` the convolution maps a spatial size H to ``floor(H / 2)``
    (exactly H/2 for even H).

    Args:
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        batch_normalization (bool): when True, apply ``nn.BatchNorm2d(out_c)``
            after the convolution; when False, ``self.bn`` is None and the
            step is skipped.
    """
    def __init__(self, in_c, out_c, batch_normalization=True):
        super().__init__()
        self.conv = nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1)
        self.bn = nn.BatchNorm2d(out_c) if batch_normalization else None
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        x = self.conv(x)
        # Explicit None check rather than relying on Module truthiness.
        if self.bn is not None:
            x = self.bn(x)
        return self.relu(x)

class up_conv(nn.Module):
    """Upsampling block: 2x2 stride-2 transposed conv, then 3x3 conv, BatchNorm, ReLU.

    Doubles the spatial size, (H, W) -> (2H, 2W), and maps ``in_c`` input
    channels to ``out_c`` output channels.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()

    def forward(self, x):
        upsampled = self.upsample(x)
        return self.relu(self.bn(self.conv(upsampled)))

class final_conv(nn.Module):
    """Last upsampling stage: 2x2 stride-2 transposed conv followed by a 3x3 conv.

    Doubles the spatial size and maps ``in_c`` to ``out_c`` channels. No
    normalization or activation is applied, so the output is unbounded.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Deliberately no sigmoid/softmax: apply one downstream (or use a
        # logits-based loss) if probabilities are needed.
        return self.conv(self.upsample(x))

class segmentor(nn.Module):
    """Encoder/decoder segmentation network with skip connections.

    Four ``down_conv`` stages each halve the spatial size. Three ``up_conv``
    stages and a ``final_conv`` each double it back. Before every decoder
    stage after the first, the matching encoder feature map is concatenated
    along the channel dimension. A 1x1 convolution then maps the 3 channels
    produced by ``final_conv`` to ``num_classes`` output channels. No output
    activation is applied.

    Args:
        num_classes (int): number of output channels.
        input_depth (int): number of input image channels.
        filter_size (sequence of int): channel widths of the four encoder
            stages; only the first four entries are used.

    The input height and width must be divisible by 16. Otherwise the skip
    connections either fail to concatenate, or the output comes out smaller
    than the input (e.g. 161 -> 160).
    """
    def __init__(self, num_classes, input_depth, filter_size=(64, 128, 256, 512)):
        super().__init__()
        # Encoder: each stage halves H and W; no BatchNorm on the first stage.
        self.encoder1 = down_conv(input_depth, filter_size[0], batch_normalization=False)  # input_depth -> f0
        self.encoder2 = down_conv(filter_size[0], filter_size[1])  # f0 -> f1
        self.encoder3 = down_conv(filter_size[1], filter_size[2])  # f1 -> f2
        self.encoder4 = down_conv(filter_size[2], filter_size[3])  # f2 -> f3

        # Decoder: each stage doubles H and W. From decoder2 on, the input is the
        # previous decoder output concatenated with an encoder map (2x channels).
        self.decoder1 = up_conv(filter_size[3], filter_size[2])  # f3 -> f2
        self.decoder2 = up_conv(filter_size[2] + filter_size[2], filter_size[1])  # (f2+f2) -> f1
        self.decoder3 = up_conv(filter_size[1] + filter_size[1], filter_size[0])  # (f1+f1) -> f0

        self.final_conv = final_conv(filter_size[0] + filter_size[0], 3)  # (f0+f0) -> 3

        # 1x1 projection from the 3 final_conv channels to num_classes.
        self.output = nn.Conv2d(3, num_classes, kernel_size=1)

    def forward(self, x):
        height, width = x.shape[-2:]
        if height % 16 or width % 16:
            raise ValueError(
                f"segmentor input height and width must be divisible by 16, got {height}x{width}"
            )

        el1 = self.encoder1(x)    # 1/2 resolution
        el2 = self.encoder2(el1)  # 1/4
        el3 = self.encoder3(el2)  # 1/8
        el4 = self.encoder4(el3)  # 1/16

        dl1 = self.decoder1(el4)                             # 1/8
        dl2 = self.decoder2(torch.cat([dl1, el3], dim=1))    # 1/4
        dl3 = self.decoder3(torch.cat([dl2, el2], dim=1))    # 1/2
        dl4 = self.final_conv(torch.cat([dl3, el1], dim=1))  # full resolution

        return self.output(dl4)

class critic(nn.Module):
    """Critic (discriminator) network -- INCOMPLETE STUB.

    Only the layers are declared: two 5x5 convolutions (1 -> 10 -> 20
    channels) and a channel-wise dropout. There is no ``forward`` method yet,
    so calling an instance raises ``NotImplementedError`` from ``nn.Module``.

    TODO: implement ``forward``. NOTE(review): conv1 accepts a single input
    channel; confirm this matches what the critic will be fed.
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        
        




# Smoke test: push a random batch through an untrained segmentor built from
# the config values. The output must keep the spatial size and have one
# channel per class.
x = torch.rand((8, input_channel, 160, 160))
with torch.no_grad():  # shape check only; no autograd graph needed
    ans = segmentor(num_classes, input_channel, filter_size=features_depth)(x)
print(x.shape)
print(ans.shape)
assert ans.shape == (x.shape[0], num_classes, *x.shape[2:]), "unexpected segmentor output shape"



torch.Size([8, 1, 160, 160])
torch.Size([8, 3, 160, 160])
