In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast, GradScaler
from torchmetrics import Accuracy, Metric
import torchmetrics
import os
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import time
import itertools
import torch.backends.cudnn as cudnn
from torchmetrics import Accuracy 
cudnn.benchmark = True
import csv
from torch.nn import init

In [None]:
# Choose the compute device once: the GPU when CUDA is available, else the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [None]:
# Initial optimiser step size for the segmentation network.
LEARNING_RATE_SEG = 1e-3
# Binary cross-entropy that consumes raw logits (the sigmoid is applied inside the loss).
BCE_LOSS = nn.BCEWithLogitsLoss()

In [None]:
# Load every image patch whose mask carries enough foreground.
# A mask is looked up in `mask_dir` under the same file name as its image.
image_dir = 'overlap_patch/img'
mask_dir = 'overlap_patch/mask'

# A patch is kept only when the sum of its mask pixels exceeds this value.
# NOTE(review): how many foreground pixels that means depends on how the masks
# encode foreground (e.g. 1 vs 255) -- confirm against the data.
MIN_MASK_SUM = 8000

image_dataset = []
mask_dataset = []

# os.listdir returns entries in arbitrary order; sorting keeps the dataset
# order (and therefore any seeded split built from it) reproducible.
imgNames = sorted(os.listdir(image_dir))
imgAddr = image_dir + '/'
maskAddr = mask_dir + '/'

count = 0

for name in tqdm(imgNames, desc="Processing images"):
    try:
        # cv2.imread does not raise on a missing/unreadable file, it returns None.
        mask = cv2.imread(os.path.join(mask_dir, name), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print(f"Error processing {name}: mask could not be read")
            continue

        # Skip patches with too little annotated foreground.
        if np.sum(mask) <= MIN_MASK_SUM:
            continue

        # The image is only read once the mask has passed the filter.
        # It stays in OpenCV's native BGR channel order.
        image = cv2.imread(os.path.join(image_dir, name), cv2.IMREAD_COLOR)
        if image is None:
            print(f"Error processing {name}: image could not be read")
            continue

        image_dataset.append(image)
        mask_dataset.append(mask)
        count += 1

    except Exception as e:
        # Best effort: report the offending file and carry on with the rest.
        print(f"Error processing {name}: {e}")
        continue

In [None]:
# Stack the per-patch arrays and split them into train / validation / test.
if len(image_dataset) == 0:
    raise ValueError("No patches were loaded; cannot build train/validation/test splits.")

image_dataset = np.asarray(image_dataset)
mask_dataset = np.asarray(mask_dataset)

# Give the masks a trailing channel axis. The guard keeps the cell idempotent,
# so re-running it does not keep appending axes.
if mask_dataset.ndim == 3:
    mask_dataset = mask_dataset[..., np.newaxis]

# 70% of the patches go to training; the remaining 30% are held out.
source_xtrain, holdout_x, source_ytrain, holdout_y = train_test_split(
    image_dataset, mask_dataset, test_size=0.3, random_state=44, shuffle=True)

# The hold-out is split 80/20 into validation and test
# (i.e. 24% and 6% of all patches respectively).
source_xval, source_xtest, source_yval, source_ytest = train_test_split(
    holdout_x, holdout_y, test_size=0.2, random_state=44, shuffle=True)

print(f"Total processed images: {count}")
print(f"Training set size: {len(source_xtrain)}, Validation set size: {len(source_xval)}, Test set size: {len(source_xtest)}")

In [None]:
class ImageMaskDataset(Dataset):
    """In-memory dataset pairing each image with its segmentation mask.

    Both inputs are copied into ``float32`` tensors once, at construction
    time. Indexing returns an ``(image, mask)`` tuple taken from the same
    position of the two tensors.
    """

    def __init__(self, images, masks):
        float_kwargs = {'dtype': torch.float32}
        self.images = torch.tensor(images, **float_kwargs)
        self.masks = torch.tensor(masks, **float_kwargs)

    def __len__(self):
        # The sample count is the length of the image tensor's leading axis.
        n_samples = len(self.images)
        return n_samples

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]
        return image, mask

# Wrap each split in a dataset, in train / validation / test order.
train_dataset_s, val_dataset_s, test_dataset_s = (
    ImageMaskDataset(split_images, split_masks)
    for split_images, split_masks in (
        (source_xtrain, source_ytrain),
        (source_xval, source_yval),
        (source_xtest, source_ytest),
    )
)

# Only the training loader shuffles; the test loader yields one sample per batch.
train_loader_s = DataLoader(
    train_dataset_s,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)
val_loader_s = DataLoader(
    val_dataset_s,
    batch_size=8,
    shuffle=False,
    num_workers=4,
)
test_loader_s = DataLoader(
    test_dataset_s,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

In [None]:
def dice_coef(y_pred, y_true, smooth=1.):
    """Smoothed Dice coefficient between two tensors.

    Both tensors are flattened and compared element-wise, so they must hold
    the same number of elements. The score is
    ``(2 * sum(pred * true) + smooth) / (sum(pred) + sum(true) + smooth)``.

    Args:
        y_pred: Predicted mask tensor.
        y_true: Reference mask tensor.
        smooth: Additive smoothing term; keeps the ratio defined (and equal
            to 1) when both tensors are all zeros. Defaults to 1.

    Returns:
        A 0-dimensional tensor holding the Dice score.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # permuted or transposed tensors, copying only when it has to.
    pred_flat = y_pred.reshape(-1)
    true_flat = y_true.reshape(-1)
    overlap = (pred_flat * true_flat).sum()

    return (2. * overlap + smooth) / (pred_flat.sum() + true_flat.sum() + smooth)

In [None]:
class conv_block(nn.Module):
    """Two consecutive 3x3 convolution stages, each followed by batch norm and ReLU.

    The convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes from ``ch_in`` to
    ``ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        stages = []
        # First stage maps ch_in -> ch_out, second stage maps ch_out -> ch_out.
        for stage_in in (ch_in, ch_out):
            stages.extend([
                nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class up_conv(nn.Module):
    """Upsampling stage: 2x upsample, then 3x3 convolution, batch norm and ReLU.

    The upsampling uses ``nn.Upsample`` with its default interpolation mode;
    the convolution (stride 1, padding 1) changes the channel count from
    ``ch_in`` to ``ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2)
        project = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        normalise = nn.BatchNorm2d(ch_out)
        activate = nn.ReLU(inplace=True)
        self.up = nn.Sequential(enlarge, project, normalise, activate)

    def forward(self, x):
        return self.up(x)

class U_Net(nn.Module):
    """U-Net with a five-level encoder and a four-step decoder with skip connections.

    The network takes channels-last input ``(batch, height, width, img_ch)``
    and returns channels-last raw logits ``(batch, height, width, output_ch)``;
    no sigmoid is applied inside the model.

    Args:
        img_ch: Number of channels of the input image.
        output_ch: Number of channels of the output map.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        # A single pooling layer is shared by all encoder levels.
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: channel width doubles at every level.
        self.Conv1 = conv_block(img_ch, 64)
        self.Conv2 = conv_block(64, 128)
        self.Conv3 = conv_block(128, 256)
        self.Conv4 = conv_block(256, 512)
        self.Conv5 = conv_block(512, 1024)

        # Decoder: each step upsamples, then fuses the concatenated skip
        # features (hence the doubled input width of every Up_conv block).
        self.Up5 = up_conv(1024, 512)
        self.Up_conv5 = conv_block(1024, 512)

        self.Up4 = up_conv(512, 256)
        self.Up_conv4 = conv_block(512, 256)

        self.Up3 = up_conv(256, 128)
        self.Up_conv3 = conv_block(256, 128)

        self.Up2 = up_conv(128, 64)
        self.Up_conv2 = conv_block(128, 64)

        # 1x1 projection from the last 64 feature maps to the output channels.
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Channels-last -> channels-first for the convolution layers.
        x = x.permute(0, 3, 1, 2)

        # Contracting path: every level after the first is pooled, then convolved.
        # NOTE(review): with four 2x poolings, height and width presumably need
        # to be multiples of 16 for the skip concatenations below to line up.
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.Maxpool(enc1))
        enc3 = self.Conv3(self.Maxpool(enc2))
        enc4 = self.Conv4(self.Maxpool(enc3))
        bottleneck = self.Conv5(self.Maxpool(enc4))

        # Expanding path: upsample, concatenate the matching encoder output
        # along the channel axis (skip first), then convolve.
        dec = self.Up5(bottleneck)
        dec = self.Up_conv5(torch.cat((enc4, dec), dim=1))

        dec = self.Up4(dec)
        dec = self.Up_conv4(torch.cat((enc3, dec), dim=1))

        dec = self.Up3(dec)
        dec = self.Up_conv3(torch.cat((enc2, dec), dim=1))

        dec = self.Up2(dec)
        dec = self.Up_conv2(torch.cat((enc1, dec), dim=1))

        logits = self.Conv_1x1(dec)
        # Back to channels-last so the output lines up with the masks.
        return logits.permute(0, 2, 3, 1)

In [None]:
# Train the segmentation U-Net with BCE-with-logits loss. After every epoch the
# model is validated and the metrics are written to the console, TensorBoard
# and a CSV file; the model is checkpointed every 10 epochs.
epochs = 150
model_seg = U_Net().to(device)

optimizer_seg = optim.Adam(model_seg.parameters(), lr=LEARNING_RATE_SEG, betas=(0.9, 0.999))
# Multiply the learning rate by 0.2 (never below 1e-5) once the validation
# loss has stopped improving for longer than `patience` epochs.
scheduler = ReduceLROnPlateau(optimizer_seg, mode='min', factor=0.2, patience=5, min_lr=1e-5, verbose=True)

accuracy_train_metric = torchmetrics.Accuracy(task='binary').to(device)
accuracy_val_metric = torchmetrics.Accuracy(task='binary').to(device)

log_name = 'baseunet3001'
writer = SummaryWriter('runs/' + log_name)
csv_file = open(log_name + '.csv', 'w', newline='')
# try/finally guarantees both log sinks are closed even when training is
# interrupted or raises part-way through.
try:
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow([
        'Epoch', 'Training Seg Loss', 'Validation Seg Loss',
        'Training Dice Score', 'Validation Dice Score',
        'Training Accuracy', 'Validation Accuracy',
        'Learning Rate', 'Time (mins)'
    ])

    for epoch in range(epochs):
        epoch_start_time = time.time()
        # Learning rate in effect for this epoch; read before the scheduler
        # steps so TensorBoard and the CSV report the same value.
        current_lr = optimizer_seg.param_groups[0]['lr']

        # ---- Training pass ----
        model_seg.train()
        seg_loss_b = []
        dice_b = []
        for image_s, label_s in train_loader_s:
            optimizer_seg.zero_grad()
            image_s, label_s = image_s.to(device), label_s.to(device)

            pred_s = model_seg(image_s)
            loss_seg = BCE_LOSS(pred_s, label_s)

            # Metrics are computed on hard 0/1 predictions and need no gradients.
            with torch.no_grad():
                pred_binary = (torch.sigmoid(pred_s) > 0.5).float()
                dice_b.append(dice_coef(pred_binary, label_s).item())
                accuracy_train_metric.update(pred_binary, label_s)

            loss_seg.backward()
            optimizer_seg.step()

            seg_loss_b.append(loss_seg.item())

        # Per-epoch training metrics (loss and Dice are means over batches).
        segloss_train = np.mean(seg_loss_b)
        diceScore_train = np.mean(dice_b)
        accuracy_train = accuracy_train_metric.compute().item()
        accuracy_train_metric.reset()

        # ---- Validation pass ----
        model_seg.eval()
        seg_loss_b = []
        dice_b = []

        with torch.no_grad():
            for vinputs, vlabels in val_loader_s:
                vinputs, vlabels = vinputs.to(device), vlabels.to(device)

                voutputs = model_seg(vinputs)
                vloss = BCE_LOSS(voutputs, vlabels)

                voutputs_binary = (torch.sigmoid(voutputs) > 0.5).float()
                dice_b.append(dice_coef(voutputs_binary, vlabels).item())
                accuracy_val_metric.update(voutputs_binary, vlabels)
                seg_loss_b.append(vloss.item())

        segloss_val = np.mean(seg_loss_b)
        diceScore_val = np.mean(dice_b)
        accuracy_val = accuracy_val_metric.compute().item()
        accuracy_val_metric.reset()

        epoch_end_time = time.time()
        epoch_time = (epoch_end_time - epoch_start_time) / 60
        print('[Epoch: {}, {:.2f}mins] Training Seg Loss: {:.2f}, Validation Seg Loss: {:.2f}, '
              'Training Dice Score: {:.2f}, Validation Dice Score: {:.2f}, '
              'Training Accuracy: {:.2f}, Validation Accuracy: {:.2f}'.format(
              epoch+1, epoch_time, segloss_train, segloss_val, diceScore_train, diceScore_val, accuracy_train, accuracy_val))

        writer.add_scalar('Segmentation Loss/train', segloss_train, epoch+1)
        writer.add_scalar('Segmentation Loss/validation', segloss_val, epoch+1)
        writer.add_scalar('Segmentation Accuracy/train', accuracy_train, epoch+1)
        writer.add_scalar('Segmentation Accuracy/validation', accuracy_val, epoch+1)
        writer.add_scalar('Dice Score/train', diceScore_train, epoch+1)
        writer.add_scalar('Dice Score/validation', diceScore_val, epoch+1)
        writer.add_scalar('Learning Rate', current_lr, epoch+1)
        writer.add_scalar('Time', epoch_time, epoch+1)

        # The plateau scheduler is driven by the mean validation loss.
        scheduler.step(segloss_val)

        csv_writer.writerow([
            epoch + 1, segloss_train, segloss_val,
            diceScore_train, diceScore_val,
            accuracy_train, accuracy_val,
            current_lr, epoch_time
        ])
        # Flush every epoch so the CSV stays usable while training is running.
        csv_file.flush()

        # Checkpoint every 10 epochs, overwriting the previous file.
        # NOTE: this pickles the whole module object, so loading it later
        # needs the U_Net class definition to be available.
        if (epoch+1) % 10 == 0:
            torch.save(model_seg, log_name + '.pth')
finally:
    writer.close()
    csv_file.close()

In [None]:
# Put the network in evaluation mode before running it on the test set.
model_seg.eval()
def visualize_prediction(test_image, ground_truth, predicted_mask, dice_score):
    """Show an image, its reference mask and the predicted mask side by side.

    The three inputs are handed to ``imshow`` exactly as given (no reshaping
    or channel reordering happens here), so they must already be in a layout
    matplotlib can draw.

    Args:
        test_image: Image drawn in the left panel.
        ground_truth: Reference mask drawn in the middle panel.
        predicted_mask: Predicted mask drawn in the right panel.
        dice_score: Dice coefficient printed, to four decimals, in the
            title of the right panel.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # One (picture, title) pair per panel, left to right.
    panels = (
        (test_image, 'Test Image'),
        (ground_truth, 'Ground Truth'),
        (predicted_mask, f'Predicted Mask\nDice Coeff: {dice_score:.4f}'),
    )
    for ax, (picture, caption) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(caption)
        ax.axis('off')

    plt.show()

# Run the model over the test set, collecting one Dice score per sample and
# plotting every prediction next to its image and reference mask.
dice_scores = []
with torch.no_grad():
    for test_image, ground_truth in test_loader_s:
        test_image = test_image.to(device)
        ground_truth = ground_truth.to(device)

        # The network returns raw logits: apply the sigmoid before thresholding
        # at 0.5, exactly as the training and validation loops do.
        preds = (torch.sigmoid(model_seg(test_image)) > 0.5).float()

        dice_score = dice_coef(preds, ground_truth).item()
        dice_scores.append(dice_score)

        # The patches were read with cv2.imread, i.e. in BGR channel order;
        # reverse the channel axis so matplotlib displays true RGB colours.
        tm = test_image[0].cpu().numpy().astype('uint8')[..., ::-1]
        visualize_prediction(tm, ground_truth[0].cpu().numpy(), preds[0].cpu().numpy(), dice_score)

In [None]:
# Report the mean Dice coefficient over all test samples.
mean_test_dice = np.mean(dice_scores)
print(mean_test_dice)