In [None]:
# Quick check of the GPU(s) and driver visible on this machine before training.
!nvidia-smi

In [None]:
import os
import csv
import cv2
import time
import torch
import itertools 
import numpy as np
import imagecodecs
from PIL import Image
from tqdm import tqdm
import torch.nn as nn
import tifffile as tifi
import torch.optim as optim
from torch.optim import Adam
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchmetrics import Accuracy
from torchvision import transforms
import torch.backends.cudnn as cudnn
from torchmetrics import Accuracy, Metric
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms as transforms

In [None]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays off when
# input shapes stay the same from batch to batch, at the cost of bitwise determinism.
cudnn.benchmark = True

In [None]:
# --- Optimisation hyper-parameters ---
LEARNING_RATE_SEG = 0.00001   # Adam learning rate of the segmentation network
LEARNING_RATE_DIS = 0.0001    # Adam learning rate of the discriminator
MAX_LR = 0.0001               # only referenced (as min_lr) by the commented-out LR schedulers
BATCH_SIZE = 8                # mini-batch size
LAMBDA_ADV = 0.1              # weight of the adversarial term in the segmenter loss
# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
BCE_LOSS = torch.nn.BCEWithLogitsLoss()

In [None]:
class ImageMaskDataset(Dataset):
    """Map-style dataset of (image, mask) pairs backed by in-memory NumPy arrays.

    Each stored sample is an HWC array. It is converted to a CHW float32 tensor only
    when the item is requested, so the arrays are never duplicated as torch tensors.
    """

    def __init__(self, images, masks):
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _hwc_to_chw_float(array):
        # HWC numpy array -> CHW float32 tensor
        return torch.from_numpy(array).permute(2, 0, 1).float()

    def __getitem__(self, idx):
        return (self._hwc_to_chw_float(self.images[idx]),
                self._hwc_to_chw_float(self.masks[idx]))

In [None]:
image_dir = 'overlap_patch/img'
mask_dir = 'overlap_patch/mask'

image_dataset = []
mask_dataset = []

imgNames = os.listdir(image_dir)
imgAddr = image_dir + '/'
maskAddr = mask_dir + '/'

# Number of foreground patches kept; the background cell below collects the same number.
count = 0

for name in tqdm(imgNames, desc="Processing images"):
    try:
        mask_array = cv2.imread(maskAddr + name, 0)  # grayscale
        if mask_array is None:
            raise ValueError("could not read mask")

        # Keep only patches whose raw mask sum exceeds 8000 (selection criterion unchanged).
        if np.sum(mask_array) > 8000:
            image = cv2.imread(imgAddr + name, cv2.IMREAD_COLOR)
            if image is None:
                raise ValueError("could not read image")
            # Binarize to {0, 1} the same way the background/target loaders do, so masks
            # are valid BCE targets whether they are stored as 0/255 or already as 0/1.
            if mask_array.max() > 1:
                mask_bin = mask_array > 127   # == (mask / 255. > 0.5) for uint8
            else:
                mask_bin = mask_array > 0
            image_dataset.append(image)
            mask_dataset.append(mask_bin.astype(np.uint8))
            count += 1

    except Exception as e:
        # Best effort: report and skip unreadable/broken files.
        print(f"Error processing {name}: {e}")
        continue

In [None]:
imagebg_dir = 'train_dataset/img_ng'
maskbg_dir = 'train_dataset/mask_ng'

imgbgNames = os.listdir(imagebg_dir)
imgbgAddr = imagebg_dir + '/'
maskbgAddr = maskbg_dir + '/'
# Balance the source set: collect as many empty (background-only) patches as foreground ones.
target_count = count
valid_count = 0
for name in tqdm(imgbgNames, desc="Processing images"):
    if valid_count >= target_count:
        break  # Stop once we reach the desired count

    mask = cv2.imread(maskbgAddr + name, 0)  # grayscale
    if mask is None:
        raise ValueError(f"Mask is None: {maskbgAddr + name}")

    # Normalize to [0, 1], binarize at 0.5, store compactly as uint8.
    mask = (mask / 255. > 0.5).astype(np.uint8)

    # Skip if mask has any foreground (non-zero)
    if mask.any():
        continue

    image = cv2.imread(imgbgAddr + name, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(f"Image is None: {imgbgAddr + name}")

    image_dataset.append(image)
    mask_dataset.append(mask)
    valid_count += 1

In [None]:
# Stack the per-patch arrays. Masks are 2-D grayscale arrays, so give them a trailing
# channel axis for the HWC -> CHW permute done in ImageMaskDataset. Images already
# carry their colour axis; the old expand_dims + np.squeeze round-trip on them was a
# no-op at best and would also drop the batch axis for a single-image dataset.
image_dataset = np.array(image_dataset)
mask_dataset = np.expand_dims(np.array(mask_dataset), axis=3)
n_total = len(image_dataset)

# 70% train; the remaining 30% is split 80/20 into validation/test.
source_img_train, source_img_test, source_mask_train, source_mask_test = train_test_split(
    image_dataset, mask_dataset, test_size=0.3, random_state=44, shuffle=True)
del image_dataset, mask_dataset
source_img_val, source_img_test, source_mask_val, source_mask_test = train_test_split(
    source_img_test, source_mask_test, test_size=0.2, random_state=44, shuffle=True)

# `count` only covers the foreground patches; report the full total as well.
print(f"Total processed images: {n_total} ({count} foreground + {n_total - count} background)")
print(f"Training set size: {len(source_img_train)}, Validation set size: {len(source_img_val)}, Test set size: {len(source_img_test)}")

In [None]:
train_dataset_s = ImageMaskDataset(source_img_train, source_mask_train)
val_dataset_s = ImageMaskDataset(source_img_val, source_mask_val)
test_dataset_s = ImageMaskDataset(source_img_test, source_mask_test)

# Shared loader settings; the batch size comes from the config cell instead of a repeated literal.
_loader_kwargs = dict(num_workers=4, pin_memory=True)
train_loader_s = DataLoader(train_dataset_s, batch_size=BATCH_SIZE, shuffle=True, **_loader_kwargs)
val_loader_s = DataLoader(val_dataset_s, batch_size=BATCH_SIZE, shuffle=False, **_loader_kwargs)
# One patch per batch so the test cells can score each image individually.
test_loader_s = DataLoader(test_dataset_s, batch_size=1, shuffle=False, **_loader_kwargs)

In [None]:
# Drop the notebook-level names for the source splits. The datasets keep references to
# these arrays (ImageMaskDataset stores them as-is), so the memory is only released once
# the datasets/loaders themselves are discarded.
del source_img_train, source_img_test, source_mask_train, source_mask_test, source_img_val, source_mask_val

In [None]:
target_image_dir = 'aidpath_target/images_512'
target_mask_dir = 'aidpath_target/masks_512'
target_image = []
target_mask = []
imgNames = os.listdir(target_image_dir)
imgAddr = target_image_dir + '/'
maskAddr = target_mask_dir + '/'
count = 0  # number of target patches kept

for name in tqdm(imgNames, desc="Processing images"):
    try:
        mask = cv2.imread(maskAddr + name, 0)  # grayscale
        if mask is None:
            raise ValueError("could not read mask")

        # Keep only patches whose mask contains some foreground.
        if np.sum(mask) > 0:
            image = cv2.imread(imgAddr + name, cv2.IMREAD_COLOR)
            if image is None:
                # Previously np.array(None) was appended silently, corrupting the stacked array.
                raise ValueError("could not read image")
            target_image.append(image)
            # Normalize to [0, 1], binarize at 0.5, store compactly as uint8.
            target_mask.append((mask / 255. > 0.5).astype(np.uint8))
            count += 1

    except Exception as e:
        # Best effort: report and skip unreadable/broken files.
        print(f"Error processing {name}: {e}")
        continue


In [None]:
class ImageMaskDataset1(Dataset):
    """Target-domain dataset yielding (image, mask, exists) triples from NumPy arrays.

    `exists` holds one 0/1 flag per sample. The training loop applies the supervised
    target loss only to samples whose flag equals 1.
    """

    def __init__(self, images, masks, exists):
        self.images, self.masks, self.exists = images, masks, exists

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _hwc_to_chw_float(array):
        # HWC numpy array -> CHW float32 tensor
        return torch.from_numpy(array).permute(2, 0, 1).float()

    def __getitem__(self, idx):
        flag = torch.tensor(self.exists[idx]).float()
        return (self._hwc_to_chw_float(self.images[idx]),
                self._hwc_to_chw_float(self.masks[idx]),
                flag)

In [None]:
# Stack the per-patch arrays; masks get a trailing channel axis for the HWC -> CHW
# permute in the datasets. (No expand_dims + squeeze round-trip on the images: it was a
# no-op at best and np.squeeze would also drop the batch axis for a single image.)
target_image = np.array(target_image)
target_mask = np.expand_dims(np.array(target_mask), axis=3)

# 90% train; the remaining 10% is split 80/20 into validation/test.
target_train_image, target_test_image, target_train_mask, target_test_mask = train_test_split(
    target_image, target_mask, test_size=0.1, random_state=44, shuffle=True)

target_val_image, target_test_image, target_val_mask, target_test_mask = train_test_split(
    target_test_image, target_test_mask, test_size=0.2, random_state=44, shuffle=True)

print(f"Total processed images: {count}")
print(f"Training set size: {len(target_train_image)}, Validation set size: {len(target_val_image)}, Test set size: {len(target_test_image)}")

In [None]:
def _half_labelled_flags(n, rng):
    """Return float32 flags of length n where n // 2 randomly chosen entries are 0.

    A flag of 1 means the target label of that sample is used in the loss, 0 means
    the sample is treated as unlabelled.
    """
    flags = np.ones((n,), dtype=np.float32)
    flags[rng.choice(n, size=n // 2, replace=False)] = 0
    return flags


# Seeded so the labelled/unlabelled assignment is reproducible (same seed as the splits).
# The previous np.ones_like(<image array>) allocation was immediately discarded, so it is dropped.
_exists_rng = np.random.default_rng(44)
train_exists = _half_labelled_flags(target_train_image.shape[0], _exists_rng)
val_exists = _half_labelled_flags(target_val_image.shape[0], _exists_rng)

In [None]:
train_dataset_t = ImageMaskDataset1(target_train_image, target_train_mask, train_exists)
val_dataset_t = ImageMaskDataset1(target_val_image, target_val_mask, val_exists)
# The test split is scored against all of its labels, so it uses the plain (image, mask) dataset.
test_dataset_t = ImageMaskDataset(target_test_image, target_test_mask)

# Shared loader settings; the batch size comes from the config cell instead of a repeated literal.
_loader_kwargs = dict(num_workers=4, pin_memory=True)
train_loader_t = DataLoader(train_dataset_t, batch_size=BATCH_SIZE, shuffle=True, **_loader_kwargs)
val_loader_t = DataLoader(val_dataset_t, batch_size=BATCH_SIZE, shuffle=False, **_loader_kwargs)
test_loader_t = DataLoader(test_dataset_t, batch_size=1, shuffle=False, **_loader_kwargs)

In [None]:
# Drop the notebook-level names for the target arrays. The datasets still reference the
# split arrays (they are stored as-is), so their memory is released only with the loaders.
del target_image, target_mask, target_train_image, target_test_image, target_train_mask, target_test_mask, target_val_image, target_val_mask

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv + BN layers (ReLU in between) plus a 1x1 conv shortcut.

    Spatial size is preserved (stride 1, padding 1) and channels go ch_in -> ch_out.
    The two branches are summed and passed through a final ReLU.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # Submodule names and order are unchanged so existing checkpoints still load.
        main_branch = [
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
        ]
        self.conv = nn.Sequential(*main_branch)
        # 1x1 projection so the shortcut matches ch_out channels.
        self.residual = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.residual(x)
        return self.relu(self.conv(x) + shortcut)

class up_conv(nn.Module):
    """Decoder up-step: 2x upsampling, then 3x3 conv (ch_in -> ch_out), BN and ReLU."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = (
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )
        # Kept as a single Sequential named `up` so state_dict keys are unchanged.
        self.up = nn.Sequential(*layers)

    def forward(self, x):
        upsampled = self.up(x)
        return upsampled


class encoder(nn.Module):
    """Residual U-Net producing `output_ch` raw logit maps at input resolution.

    Encoder: five ResidualBlocks (64 -> 1024 channels) with 2x2 max-pooling between
    them. Decoder: four up_conv steps, each concatenated with the matching encoder
    feature map (skip first, upsampled second) and fused by a ResidualBlock.
    No activation is applied to the output.
    NOTE(review): H and W presumably need to be divisible by 16 so skip shapes match.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()
        # Attribute names and registration order are unchanged (checkpoint compatibility).
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.ResBlock1 = ResidualBlock(ch_in=img_ch, ch_out=64)
        self.ResBlock2 = ResidualBlock(ch_in=64, ch_out=128)
        self.ResBlock3 = ResidualBlock(ch_in=128, ch_out=256)
        self.ResBlock4 = ResidualBlock(ch_in=256, ch_out=512)
        self.ResBlock5 = ResidualBlock(ch_in=512, ch_out=1024)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_ResBlock5 = ResidualBlock(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_ResBlock4 = ResidualBlock(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_ResBlock3 = ResidualBlock(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_ResBlock2 = ResidualBlock(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoding path: full-resolution block, then pool -> block at each deeper level.
        features = [self.ResBlock1(x)]
        for block in (self.ResBlock2, self.ResBlock3, self.ResBlock4, self.ResBlock5):
            features.append(block(self.Maxpool(features[-1])))

        # Decoding path: start from the bottleneck and consume skips deepest-first.
        decoded = features.pop()
        stages = (
            (self.Up5, self.Up_ResBlock5),
            (self.Up4, self.Up_ResBlock4),
            (self.Up3, self.Up_ResBlock3),
            (self.Up2, self.Up_ResBlock2),
        )
        for upsample, fuse in stages:
            skip = features.pop()
            decoded = fuse(torch.cat((skip, upsample(decoded)), dim=1))

        return self.Conv_1x1(decoded)

In [None]:
class FCDiscriminator(nn.Module):
    """Fully-convolutional domain discriminator.

    Five 4x4, stride-2, padding-1 convolutions, with LeakyReLU(0.2) after the first
    four. Channel widths are num_classes -> ndf -> 2ndf -> 4ndf -> 8ndf -> 1. The output
    is a 1-channel map of raw logits, one per spatial cell.
    """

    def __init__(self, num_classes, ndf=64):
        super().__init__()
        widths = [num_classes, ndf, ndf * 2, ndf * 4, ndf * 8]
        # Registers conv1..conv4 in the same order and with the same names as before.
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'conv{idx}', nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
        self.classifier = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.leaky_relu(conv(x))
        return self.classifier(x)

In [None]:
def dice_coef(y_pred, y_true, smooth=1.):
    """Smoothed Dice coefficient over all elements of two same-sized tensors.

    dice = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    The whole batch is pooled into one score, and the function is symmetric in its
    arguments. `reshape` is used instead of `view` so that non-contiguous inputs
    (e.g. permuted tensors) are accepted.

    Returns a 0-d tensor.
    """
    iflat = y_pred.reshape(-1)
    tflat = y_true.reshape(-1)
    intersection = (iflat * tflat).sum()

    return ((2. * intersection) + smooth) / (iflat.sum() + tflat.sum() + smooth)



In [None]:
def diceLoss(y_pred, y_true, smooth=1.):
    """Soft Dice loss: 1 - smoothed Dice coefficient, pooled over all elements.

    loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    `y_pred` is expected to hold probabilities (the training loop passes sigmoid outputs).
    `reshape` is used instead of `view` so that non-contiguous inputs are accepted.

    Returns a 0-d tensor.
    """
    iflat = y_pred.reshape(-1)
    tflat = y_true.reshape(-1)
    intersection = (iflat * tflat).sum()

    return 1.0 - ((2. * intersection) + smooth) / (iflat.sum() + tflat.sum() + smooth)

In [None]:
def norm_image(image):
    """Min-max rescale every channel plane of every image to [0, 1] independently.

    The min and max are taken over the last two dimensions. For a (B, C, H, W) batch,
    each (b, c) plane is therefore scaled on its own. The 1e-8 epsilon keeps constant
    planes finite; they map to 0.
    """
    spatial_dims = (-2, -1)
    plane_min = torch.amin(image, dim=spatial_dims, keepdim=True)
    plane_max = torch.amax(image, dim=spatial_dims, keepdim=True)
    return (image - plane_min) / (plane_max - plane_min + 1e-8)

In [None]:
# Endless stream of target training batches, to pair with each source batch.
# itertools.cycle() would cache every batch of the first pass and replay it, which
# freezes the shuffled order after epoch 1 and keeps all target batches in memory.
# Re-iterating the DataLoader on each pass reshuffles it and streams the data instead.
target_loader_cycle = itertools.chain.from_iterable(itertools.repeat(train_loader_t))

In [None]:
# Endless stream of target validation batches. itertools.cycle() would cache every
# batch in memory after the first pass; re-iterating the (unshuffled) DataLoader
# yields the same sequence without holding all batches.
valtarget_loader_cycle = itertools.chain.from_iterable(itertools.repeat(val_loader_t))

In [None]:
epochs = 100
# Domain labels the discriminator is trained to output: source -> 0, target -> 1.
source_label, target_label = 0, 1
# Segmenter (1-channel logits) and a discriminator that consumes those 1-channel maps.
model_seg = encoder().train().to(device)
model_dismain = FCDiscriminator(1).train().to(device)

In [None]:
# Separate Adam optimizers (and learning rates) for the segmenter and the discriminator.
adam_betas = (0.9, 0.999)
optimizer_seg = optim.Adam(model_seg.parameters(), lr=LEARNING_RATE_SEG, betas=adam_betas)
optimizer_dismain = optim.Adam(model_dismain.parameters(), lr=LEARNING_RATE_DIS, betas=adam_betas)

# Pixel-wise binary accuracy, tracked separately for train and validation on the compute device.
accuracy_train_metric = Accuracy(task='binary').to(device)
accuracy_val_metric = Accuracy(task='binary').to(device)

In [None]:
# Echo the segmenter learning rate so it appears in the notebook output for this run.
print(LEARNING_RATE_SEG)

In [None]:
# Run tag: names the TensorBoard directory, the CSV log and the saved model files.
log_name = 'ss_en_dis50_3005'
writer = SummaryWriter(f'runs/{log_name}')

# One CSV row per epoch; columns pair each train metric with its validation counterpart.
CSV_HEADER = [
    "Epoch", "Epoch Time (mins)",
    "Train Seg Loss", "Val Seg Loss",
    "Train Dice Loss", "Val Dice Loss",
    "Train Adv Loss", "Val Adv Loss",
    "Train Dis Loss", "Val Dis Loss",
    "Train Accuracy", "Val Accuracy",
    "Seg LR", "Dis LR",
]
csv_file = open(f'{log_name}.csv', mode='w', newline='')
csv_writer = csv.writer(csv_file)
csv_writer.writerow(CSV_HEADER)

In [None]:
best_val_dice = 1.0  # lowest validation *dice loss* seen so far (lower is better)


def _domain_target(logits, label):
    """Tensor shaped like `logits`, filled with a domain label (source_label / target_label)."""
    return torch.full_like(logits, float(label))


def _train_one_epoch():
    """Run one adversarial epoch over train_loader_s, pairing each batch with a target batch.

    For each batch:
      1. Update the discriminator on the segmenter's outputs, which are fixed for this step.
      2. Update the segmenter with supervised BCE (source plus labelled target) and the
         LAMBDA_ADV-weighted adversarial term, with the discriminator frozen.

    Returns a dict with the mean 'seg', 'dice', 'adv' and 'dis' losses and the source
    pixel accuracy 'acc'.
    """
    model_seg.train()
    model_dismain.train()
    seg_losses, dice_losses, adv_losses, dis_losses = [], [], [], []

    for image_s, label_s in train_loader_s:
        image_t, label_t, exist_t = next(target_loader_cycle)
        image_s = norm_image(image_s.to(device))  # per-image, per-channel min-max
        label_s = label_s.to(device)
        image_t = norm_image(image_t.to(device))
        label_t = label_t.to(device)
        exist_t = exist_t.to(device)  # same device as the predictions it will index

        optimizer_seg.zero_grad()
        optimizer_dismain.zero_grad()

        # ---- 1) Discriminator: source predictions -> source_label, target -> target_label ----
        for param in model_dismain.parameters():
            param.requires_grad = True
        with torch.no_grad():  # constants for this step (equivalent to .detach(), but keeps no graph)
            pred_s_fixed = model_seg(image_s)
            pred_t_fixed = model_seg(image_t)
        dis_out_s = model_dismain(pred_s_fixed)
        dis_out_t = model_dismain(pred_t_fixed)
        loss_dis = (BCE_LOSS(dis_out_t, _domain_target(dis_out_t, target_label))
                    + BCE_LOSS(dis_out_s, _domain_target(dis_out_s, source_label)))
        loss_dis.backward()
        optimizer_dismain.step()
        dis_losses.append(loss_dis.item())

        # ---- 2) Segmenter: supervised loss + adversarial loss (discriminator frozen) ----
        for param in model_dismain.parameters():
            param.requires_grad = False
        pred_s = model_seg(image_s)
        pred_t = model_seg(image_t)
        loss_seg = BCE_LOSS(pred_s, label_s)
        labelled = exist_t == 1
        if labelled.any():  # target labels only count for patches flagged as labelled
            loss_seg = loss_seg + BCE_LOSS(pred_t[labelled], label_t[labelled])
        # Make target predictions look like source predictions to the discriminator.
        dis_out_adv = model_dismain(pred_t)
        loss_adv = LAMBDA_ADV * BCE_LOSS(dis_out_adv, _domain_target(dis_out_adv, source_label))
        # A single backward over the sum gives the same gradients as two separate backward
        # calls, with no retain_graph and no extra forward pass.
        (loss_seg + loss_adv).backward()
        optimizer_seg.step()

        with torch.no_grad():
            prob_s = torch.sigmoid(pred_s)
            accuracy_train_metric((prob_s > 0.5).float(), label_s)
            dice_losses.append(diceLoss(prob_s, label_s).item())
        seg_losses.append(loss_seg.item())
        adv_losses.append(loss_adv.item())

    stats = {
        'seg': np.mean(seg_losses),
        'dice': np.mean(dice_losses),
        'adv': np.mean(adv_losses),
        'dis': np.mean(dis_losses),
        'acc': accuracy_train_metric.compute().item(),
    }
    accuracy_train_metric.reset()
    return stats


def _validate_one_epoch():
    """Evaluate on val_loader_s, paired with target validation batches.

    Returns the same keys as _train_one_epoch. 'dice' and 'acc' are computed on the
    source split only.
    """
    model_seg.eval()
    model_dismain.eval()
    seg_losses, dice_losses, adv_losses, dis_losses = [], [], [], []

    with torch.no_grad():
        for image_vs, label_vs in val_loader_s:
            image_vs = norm_image(image_vs.to(device))
            label_vs = label_vs.to(device)
            input_vt, label_vt, exist_vt = next(valtarget_loader_cycle)
            input_vt = norm_image(input_vt.to(device))
            label_vt = label_vt.to(device)
            exist_vt = exist_vt.to(device)

            logits_s = model_seg(image_vs)
            logits_t = model_seg(input_vt)

            vloss = BCE_LOSS(logits_s, label_vs)
            labelled = exist_vt == 1
            if labelled.any():
                vloss = vloss + BCE_LOSS(logits_t[labelled], label_vt[labelled])

            prob_s = torch.sigmoid(logits_s)
            dice_losses.append(diceLoss(prob_s, label_vs).item())  # dice loss on source val
            accuracy_val_metric((prob_s > 0.5).float(), label_vs)

            # The discriminator is trained on logits, so feed it logits for both domains
            # (the source branch used to receive sigmoid probabilities here).
            dis_out_t = model_dismain(logits_t)
            dis_out_s = model_dismain(logits_s)
            loss_adv = LAMBDA_ADV * BCE_LOSS(dis_out_t, _domain_target(dis_out_t, source_label))
            loss_dis = (BCE_LOSS(dis_out_t, _domain_target(dis_out_t, target_label))
                        + BCE_LOSS(dis_out_s, _domain_target(dis_out_s, source_label)))

            seg_losses.append(vloss.item())
            adv_losses.append(loss_adv.item())
            dis_losses.append(loss_dis.item())

    stats = {
        'seg': np.mean(seg_losses),
        'dice': np.mean(dice_losses),
        'adv': np.mean(adv_losses),
        'dis': np.mean(dis_losses),
        'acc': accuracy_val_metric.compute().item(),
    }
    accuracy_val_metric.reset()
    return stats


try:
    for epoch in range(epochs):
        torch.cuda.empty_cache()
        epoch_start_time = time.time()
        train_stats = _train_one_epoch()
        val_stats = _validate_one_epoch()
        epoch_time = (time.time() - epoch_start_time) / 60
        seglr = optimizer_seg.param_groups[-1]['lr']
        dislr = optimizer_dismain.param_groups[-1]['lr']

        csv_writer.writerow([
            epoch + 1, epoch_time,
            train_stats['seg'], val_stats['seg'],
            train_stats['dice'], val_stats['dice'],
            train_stats['adv'], val_stats['adv'],
            train_stats['dis'], val_stats['dis'],
            train_stats['acc'], val_stats['acc'],
            seglr, dislr,
        ])
        csv_file.flush()
        # Mirror the CSV metrics to TensorBoard (the SummaryWriter was otherwise unused).
        for key in ('seg', 'dice', 'adv', 'dis', 'acc'):
            writer.add_scalar(f'{key}/train', train_stats[key], epoch + 1)
            writer.add_scalar(f'{key}/val', val_stats[key], epoch + 1)

        # Checkpoint whenever the validation dice loss improves.
        if val_stats['dice'] < best_val_dice:
            best_val_dice = val_stats['dice']
            torch.save(model_seg, log_name + '.pth')
finally:
    # Close the logs even if training is interrupted, so the CSV is not left open or truncated.
    csv_file.close()
    writer.close()

torch.save(model_seg, log_name + 'final.pth')

Evaluation with the final saved model (weights after the last epoch)

In [None]:
model_seg.eval()


def visualize_prediction(test_image, ground_truth, predicted_mask, dice_score, bgr=True):
    """Show an input patch, its ground-truth mask and the predicted mask side by side.

    Args:
        test_image: channels-last (H, W, 3) image, e.g. norm_image output permuted to HWC.
        ground_truth: (H, W) or (H, W, 1) mask; a trailing channel axis is dropped.
        predicted_mask: (H, W) or (H, W, 1) mask; a trailing channel axis is dropped.
        dice_score: scalar (float or 0-d tensor) shown in the prediction title.
        bgr: the patches are read with cv2.imread (BGR order); flip to RGB for display.
    """
    def _drop_channel(mask):
        return mask[..., 0] if mask.ndim == 3 and mask.shape[-1] == 1 else mask

    if bgr:
        test_image = test_image[..., ::-1]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (test_image, 'Test Image'),
        (_drop_channel(ground_truth), 'Ground Truth'),
        (_drop_channel(predicted_mask), f'Predicted Mask\nDice Coeff: {dice_score:.4f}'),
    )
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    plt.close(fig)  # avoid accumulating open figures when called in a loop

# Per-patch dice on the source test split (test_loader_s yields one patch per batch).
dice_scores_s = []
with torch.no_grad():
    for test_image, ground_truth in test_loader_s:
        test_image = norm_image(test_image.to(device))
        ground_truth = ground_truth.to(device)

        # model_seg returns logits; apply sigmoid before the 0.5 threshold, as in training.
        preds = (torch.sigmoid(model_seg(test_image)) > 0.5).float()
        # dice_coef pools all elements, so no channels-last permute is needed here.
        dice_score = dice_coef(ground_truth, preds)
        dice_scores_s.append(dice_score.cpu().numpy())

In [None]:
# Mean per-patch dice on the source test split, using the weights from the last epoch.
print('source dice score with final saved model:',np.mean(dice_scores_s))

In [None]:
# Per-patch dice on the target test split (test_loader_t yields one patch per batch).
dice_scores_t = []
with torch.no_grad():
    for test_image, ground_truth in test_loader_t:
        test_image = norm_image(test_image.to(device))
        ground_truth = ground_truth.to(device)

        # model_seg returns logits; apply sigmoid before the 0.5 threshold, as in training.
        preds = (torch.sigmoid(model_seg(test_image)) > 0.5).float()
        dice_score = dice_coef(ground_truth, preds)
        dice_scores_t.append(dice_score.cpu().numpy())

In [None]:
# Mean per-patch dice on the target test split, using the weights from the last epoch.
print('target dice score with final saved model:', np.mean(dice_scores_t))

In [None]:
# Target test dice restricted to patches with more than 8000 foreground pixels.
dice_scores_t2 = []
with torch.no_grad():
    for test_image, ground_truth in test_loader_t:
        ground_truth = ground_truth.to(device)
        # Target masks are binarized, so the sum counts foreground pixels; skip small ones early.
        if ground_truth.sum().item() <= 8000:
            continue

        test_image = norm_image(test_image.to(device))
        # model_seg returns logits; apply sigmoid before the 0.5 threshold, as in training.
        preds = (torch.sigmoid(model_seg(test_image)) > 0.5).float()
        dice_score = dice_coef(ground_truth, preds)
        dice_scores_t2.append(dice_score.cpu().numpy())

In [None]:
# Mean dice over target test patches with more than 8000 foreground pixels (last-epoch weights).
print('target dice score with filter and final saved model:',np.mean(dice_scores_t2))

Evaluation with the best saved model (lowest validation dice loss)

In [None]:
# Reload the best checkpoint (lowest validation dice loss) written during training.
# The file is a whole pickled nn.Module produced by this notebook (trusted input), so it
# must be unpickled fully: torch >= 2.6 defaults to weights_only=True and would reject it.
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
best_model_path = log_name + '.pth'
try:
    model_seg = torch.load(best_model_path, map_location=device, weights_only=False)
except TypeError:  # older torch versions without the weights_only argument
    model_seg = torch.load(best_model_path, map_location=device)

In [None]:
model_seg.eval()


# Redefined here so this section runs on its own after reloading the checkpoint.
def visualize_prediction(test_image, ground_truth, predicted_mask, dice_score, bgr=True):
    """Show an input patch, its ground-truth mask and the predicted mask side by side.

    Args:
        test_image: channels-last (H, W, 3) image, e.g. norm_image output permuted to HWC.
        ground_truth: (H, W) or (H, W, 1) mask; a trailing channel axis is dropped.
        predicted_mask: (H, W) or (H, W, 1) mask; a trailing channel axis is dropped.
        dice_score: scalar (float or 0-d tensor) shown in the prediction title.
        bgr: the patches are read with cv2.imread (BGR order); flip to RGB for display.
    """
    def _drop_channel(mask):
        return mask[..., 0] if mask.ndim == 3 and mask.shape[-1] == 1 else mask

    if bgr:
        test_image = test_image[..., ::-1]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (test_image, 'Test Image'),
        (_drop_channel(ground_truth), 'Ground Truth'),
        (_drop_channel(predicted_mask), f'Predicted Mask\nDice Coeff: {dice_score:.4f}'),
    )
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    plt.close(fig)  # avoid accumulating open figures when called in a loop

In [None]:
# Per-patch dice on the source test split with the best checkpoint.
dice_scores_s = []
with torch.no_grad():
    for test_image, ground_truth in test_loader_s:
        test_image = norm_image(test_image.to(device))
        ground_truth = ground_truth.to(device)

        # model_seg returns logits; apply sigmoid before the 0.5 threshold, as in training.
        preds = (torch.sigmoid(model_seg(test_image)) > 0.5).float()
        dice_score = dice_coef(ground_truth, preds)
        dice_scores_s.append(dice_score.cpu().numpy())
print('source dice score with best saved model:',np.mean(dice_scores_s))

In [None]:
# Per-patch dice on the target test split with the best checkpoint.
dice_scores_t = []
with torch.no_grad():
    for test_image, ground_truth in test_loader_t:
        test_image = norm_image(test_image.to(device))
        ground_truth = ground_truth.to(device)

        # model_seg returns logits; apply sigmoid before the 0.5 threshold, as in training.
        preds = (torch.sigmoid(model_seg(test_image)) > 0.5).float()
        dice_score = dice_coef(ground_truth, preds)
        dice_scores_t.append(dice_score.cpu().numpy())
print('target dice score with best saved model:',np.mean(dice_scores_t))

In [None]:
# Target test dice (best checkpoint) restricted to patches with more than 8000 foreground pixels.
dice_scores_t2 = []
with torch.no_grad():
    for test_image, ground_truth in test_loader_t:
        ground_truth = ground_truth.to(device)
        # Target masks are binarized, so the sum counts foreground pixels; skip small ones early.
        if ground_truth.sum().item() <= 8000:
            continue

        test_image = norm_image(test_image.to(device))
        # model_seg returns logits; apply sigmoid before the 0.5 threshold, as in training.
        preds = (torch.sigmoid(model_seg(test_image)) > 0.5).float()
        dice_score = dice_coef(ground_truth, preds)
        dice_scores_t2.append(dice_score.cpu().numpy())
print('target dice score with filter and best saved model:',np.mean(dice_scores_t2))

In [None]:
def calculate_accuracy(ground_truth, preds):
    """Pixel-wise accuracy: the fraction of positions where `preds` equals `ground_truth`.

    Both tensors are flattened and compared element-wise, so they must have the same
    number of elements. Returns a Python float, and 0.0 for empty input.
    """
    ground_truth_flat = ground_truth.flatten()
    preds_flat = preds.flatten()
    total = ground_truth_flat.numel()
    if total == 0:
        return 0.0
    correct = (ground_truth_flat == preds_flat).sum().item()
    return correct / total

def calculate_f1_score(ground_truth, preds):
    """F1 score of binary `preds` against binary `ground_truth`, with 1 as the positive class.

    Inputs are flattened and compared element-wise. Precision, recall and F1 each fall
    back to 0.0 when their denominator is zero. Returns a Python float.
    """
    ground_truth_flat = ground_truth.flatten()
    preds_flat = preds.flatten()

    # Confusion counts as exact Python ints (float32 sums lose precision above 2**24).
    tp = ((preds_flat == 1) & (ground_truth_flat == 1)).sum().item()
    fp = ((preds_flat == 1) & (ground_truth_flat == 0)).sum().item()
    fn = ((preds_flat == 0) & (ground_truth_flat == 1)).sum().item()

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0

    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

In [None]:
dice_scores_s = []
accuracy_scores_s = []
f1_scores_s = []

with torch.no_grad():
    for i, (test_image, ground_truth) in enumerate(test_loader_s):
        test_image = test_image.to(device)
        ground_truth = ground_truth.to(device)

        test_image = norm_image(test_image)
        # model_seg returns logits; apply sigmoid before the 0.5 threshold, as in training.
        preds = (torch.sigmoid(model_seg(test_image)) > 0.5).float()

        # Channels-last for visualization; the metrics flatten, so layout does not affect them.
        test_image = test_image.permute(0, 2, 3, 1)
        preds = preds.permute(0, 2, 3, 1)
        ground_truth = ground_truth.permute(0, 2, 3, 1)

        # Compute metrics for each patch in the batch
        for j in range(test_image.shape[0]):
            gt_j, pred_j = ground_truth[j:j + 1], preds[j:j + 1]

            dice_score = dice_coef(gt_j, pred_j)
            dice_scores_s.append(dice_score.cpu().numpy())
            accuracy_scores_s.append(float(calculate_accuracy(gt_j, pred_j)))
            f1_scores_s.append(float(calculate_f1_score(gt_j, pred_j)))

            # Plot only every 10th batch to keep the output readable.
            if i % 10 == 0:
                tm = test_image[j].cpu().numpy()
                visualize_prediction(tm, ground_truth[j].cpu().numpy(), preds[j].cpu().numpy(), dice_score)

# Calculate and display the final test metrics
final_test_dice = sum(dice_scores_s) / len(dice_scores_s) if dice_scores_s else 0.0
final_test_accuracy = sum(accuracy_scores_s) / len(accuracy_scores_s) if accuracy_scores_s else 0.0
final_test_f1 = sum(f1_scores_s) / len(f1_scores_s) if f1_scores_s else 0.0

print(f'Final Test Dice Score: {final_test_dice:.4f}')
print(f'Final Test Accuracy: {final_test_accuracy:.4f}')
print(f'Final Test F1 Score: {final_test_f1:.4f}')

In [None]:
dice_scores_t2 = []
accuracy_scores_t2 = []
f1_scores_t2 = []

with torch.no_grad():
    for test_image, ground_truth in test_loader_t:
        ground_truth = ground_truth.to(device)
        # Only score patches with more than 8000 foreground pixels (target masks are binarized).
        if ground_truth.sum().item() <= 8000:
            continue

        test_image = norm_image(test_image.to(device))
        # model_seg returns logits; apply sigmoid before the 0.5 threshold, as in training.
        preds = (torch.sigmoid(model_seg(test_image)) > 0.5).float()

        # All three metrics flatten their inputs, so no channels-last permute is needed.
        dice_scores_t2.append(float(dice_coef(ground_truth, preds)))
        accuracy_scores_t2.append(float(calculate_accuracy(ground_truth, preds)))
        f1_scores_t2.append(float(calculate_f1_score(ground_truth, preds)))

# Calculate and display the final test metrics
final_test_dice = sum(dice_scores_t2) / len(dice_scores_t2) if dice_scores_t2 else 0.0
final_test_accuracy = sum(accuracy_scores_t2) / len(accuracy_scores_t2) if accuracy_scores_t2 else 0.0
final_test_f1 = sum(f1_scores_t2) / len(f1_scores_t2) if f1_scores_t2 else 0.0

print(f'Final Test Dice Score: {final_test_dice:.4f}')
print(f'Final Test Accuracy: {final_test_accuracy:.4f}')
print(f'Final Test F1 Score: {final_test_f1:.4f}')