In [1]:
!nvidia-smi

Fri May  2 03:25:46 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.216.03             Driver Version: 535.216.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  | 00000000:4B:00.0 Off |                    0 |
| N/A   38C    P0              43W / 300W |     17MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [1]:
import os
import csv
import cv2
import time
import torch
import itertools 
import numpy as np
import imagecodecs
from PIL import Image
from tqdm import tqdm
import torch.nn as nn
import tifffile as tifi
import torch.optim as optim
from torch.optim import Adam
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchmetrics import Accuracy
from torchvision import transforms
import torch.backends.cudnn as cudnn
from torchmetrics import Accuracy, Metric
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms as transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
LAMBDA_ADV = 0.1
MAX_LR = 0.00001
LEARNING_RATE_SEG = 0.00001
LEARNING_RATE_DIS = 0.0001
BATCH_SIZE = 8
BCE_LOSS = torch.nn.BCEWithLogitsLoss()

In [8]:
# class ImageMaskDataset(Dataset):
#     def __init__(self, images, masks):
#         self.images = torch.tensor(images, dtype=torch.float32).permute(0, 3, 1, 2)
#         self.masks = torch.tensor(masks, dtype=torch.float32).permute(0, 3, 1, 2)

#     def __len__(self):
#         return len(self.images)

#     def __getitem__(self, idx):
#         return self.images[idx], self.masks[idx]

class ImageMaskDataset(Dataset):
    def __init__(self, images, masks):
        self.images = images  # Keep as NumPy arrays, not torch tensors
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = torch.from_numpy(self.images[idx]).permute(2, 0, 1).float()
        mask = torch.from_numpy(self.masks[idx]).permute(2, 0, 1).float()
        return image, mask

In [4]:
image_dir = 'overlap_patch/img'
mask_dir = 'overlap_patch/mask'

image_dataset = [] 
mask_dataset = []

imgNames = os.listdir(image_dir)
imgAddr = image_dir + '/'
maskAddr = mask_dir + '/'

count = 0

for i in tqdm(range(len(imgNames)), desc="Processing images"):
# for i in tqdm(range(100), desc="Processing images"):
    try:
        mask = cv2.imread(maskAddr + imgNames[i], 0) 
        mask = Image.fromarray(mask)
        mask_array = np.array(mask)
        
        if np.sum(mask_array) > 8000:
            image = cv2.imread(imgAddr + imgNames[i], cv2.IMREAD_COLOR)  
            image = Image.fromarray(image)
            image_dataset.append(np.array(image))
            mask_dataset.append(np.array(mask))
            count += 1
        
    except Exception as e:
        print(f"Error processing {imgNames[i]}: {e}")
        continue

Processing images:  98%|████████████████████████████████████████████████████████▌ | 12919/13239 [07:34<00:10, 30.28it/s]

Error processing .ipynb_checkpoints: 'NoneType' object has no attribute '__array_interface__'


Processing images: 100%|██████████████████████████████████████████████████████████| 13239/13239 [07:47<00:00, 28.34it/s]


In [5]:
image_dataset = np.array(image_dataset)  
mask_dataset = np.array(mask_dataset)    

image_dataset = np.expand_dims(image_dataset, axis=3)
image_dataset = np.squeeze(image_dataset)
mask_dataset = np.expand_dims(mask_dataset, axis=3)

source_img_train, source_img_test, source_mask_train, source_mask_test = train_test_split(
    image_dataset, mask_dataset, test_size=0.3, random_state=44, shuffle=True)

source_img_val, source_img_test, source_mask_val, source_mask_test = train_test_split(
    source_img_test, source_mask_test, test_size=0.2, random_state=44, shuffle=True)

print(f"Total processed images: {count}")
print(f"Training set size: {len(source_img_train)}, Validation set size: {len(source_img_val)}, Test set size: {len(source_img_test)}")

Total processed images: 10367
Training set size: 7256, Validation set size: 2488, Test set size: 623


In [9]:
# train_dataset_s = ImageMaskDataset(source_img_train, source_mask_train)
# val_dataset_s = ImageMaskDataset(source_img_val, source_mask_val)
test_dataset_s = ImageMaskDataset(source_img_test, source_mask_test)

# train_loader_s = DataLoader(train_dataset_s, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)
# val_loader_s = DataLoader(val_dataset_s, batch_size=8, shuffle=False, num_workers=4, pin_memory=True)
test_loader_s = DataLoader(test_dataset_s, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
del image_dataset, mask_dataset, source_img_train, source_img_test, source_mask_train, source_mask_test, source_img_val, source_mask_val

In [10]:
target_image_dir = 'aidpath_target/images_512'
target_mask_dir = 'aidpath_target/masks_512'
target_image = [] 
target_mask= []

imgNames = os.listdir(target_image_dir)
imgAddr = target_image_dir + '/'
maskAddr = target_mask_dir + '/'
count = 0

for i in tqdm(range(len(imgNames)), desc="Processing images"):
# for i in tqdm(range(100), desc="Processing images"):
    try:
        mask = cv2.imread(maskAddr + imgNames[i], 0) 
#         mask = Image.fromarray(mask)
        mask_array = np.array(mask)
        
        if np.sum(mask_array) > 0:
            image = cv2.imread(imgAddr + imgNames[i], cv2.IMREAD_COLOR)  
            mask = mask / 255.
            mask = (mask > 0.5).astype(int)
            target_image.append(np.array(image))
            target_mask.append(np.array(mask))
#             break
            count += 1
        
    except Exception as e:
        print(f"Error processing {imgNames[i]}: {e}")
        continue


Processing images:  98%|██████████████████████████████████████████████████████████▌ | 2432/2493 [01:41<00:02, 23.56it/s]

Error processing .ipynb_checkpoints: '>' not supported between instances of 'NoneType' and 'int'


Processing images: 100%|████████████████████████████████████████████████████████████| 2493/2493 [01:44<00:00, 23.87it/s]


In [10]:
#  'SESCAM_2_0.svs', 'SESCAM_3_0.svs', 'SESCAM_4_0.svs', 'SESCAM_7_0.svs', 'SESCAM_8_0.svs','VUHSK_1912.svs', 'VUHSK_1992.svs'

In [11]:
target_image = np.array(target_image)  
target_mask = np.array(target_mask)    

target_image = np.expand_dims(target_image, axis=3)
target_image = np.squeeze(target_image)
target_mask = np.expand_dims(target_mask, axis=3)

target_train_image, target_test_image, target_train_mask, target_test_mask = train_test_split(
    target_image, target_mask, test_size=0.1, random_state=44, shuffle=True)

target_val_image, target_test_image, target_val_mask, target_test_mask = train_test_split(
    target_test_image, target_test_mask, test_size=0.2, random_state=44, shuffle=True)

print(f"Total processed images: {count}")
print(f"Training set size: {len(target_train_image)}, Validation set size: {len(target_val_image)}, Test set size: {len(target_test_image)}")

Total processed images: 2492
Training set size: 2242, Validation set size: 200, Test set size: 50


In [12]:
# train_dataset_t = ImageMaskDataset(target_train_image, target_train_mask)
# val_dataset_t = ImageMaskDataset(target_val_image, target_val_mask)
test_dataset_t = ImageMaskDataset(target_test_image, target_test_mask)

# train_loader_t = DataLoader(train_dataset_t, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)
# val_loader_t = DataLoader(val_dataset_t, batch_size=8, shuffle=False, num_workers=4, pin_memory=True)
test_loader_t = DataLoader(test_dataset_t, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

In [13]:
del target_image, target_mask, target_train_image, target_test_image, target_train_mask, target_test_mask, target_val_image, target_val_mask

In [14]:
class ResidualBlock(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(ResidualBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out)
        )
        self.residual = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=True)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        residual = self.residual(x)
        x = self.conv(x)
        x += residual  # Residual connection
        return self.relu(x)

class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),A
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.up(x)


class encoder(nn.Module):
    def __init__(self, img_ch=3, output_ch=1):
        super(encoder, self).__init__()
        
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.ResBlock1 = ResidualBlock(ch_in=img_ch, ch_out=64)
        self.ResBlock2 = ResidualBlock(ch_in=64, ch_out=128)
        self.ResBlock3 = ResidualBlock(ch_in=128, ch_out=256)
        self.ResBlock4 = ResidualBlock(ch_in=256, ch_out=512)
        self.ResBlock5 = ResidualBlock(ch_in=512, ch_out=1024)
        
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_ResBlock5 = ResidualBlock(ch_in=1024, ch_out=512)
        
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_ResBlock4 = ResidualBlock(ch_in=512, ch_out=256)
        
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_ResBlock3 = ResidualBlock(ch_in=256, ch_out=128)
        
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_ResBlock2 = ResidualBlock(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoding path
#         x = x.permute(0, 3, 1, 2)
        x1 = self.ResBlock1(x)
        x2 = self.Maxpool(x1)
        x2 = self.ResBlock2(x2)
        
        x3 = self.Maxpool(x2)
        x3 = self.ResBlock3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.ResBlock4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.ResBlock5(x5)

        # Decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_ResBlock5(d5)
        
        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_ResBlock4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_ResBlock3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_ResBlock2(d2)

        d1 = self.Conv_1x1(d2)
#         d1 = d1.permute(0, 2, 3, 1)
        return d1

In [15]:
class FCDiscriminator(nn.Module):

    def __init__(self, num_classes, ndf=64):
        super(FCDiscriminator, self).__init__()

        self.conv1 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)
        self.classifier = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
#         x = x.permute(0, 3, 1, 2)
        x = self.conv1(x)
        x = self.leaky_relu(x)
        x = self.conv2(x)
        x = self.leaky_relu(x)
        x = self.conv3(x)
        x = self.leaky_relu(x)
        x = self.conv4(x)
        x = self.leaky_relu(x)
        x = self.classifier(x)
#         x = x.permute(0, 2, 3, 1)
        return x

In [16]:
def dice_coef(y_pred, y_true):
    smooth = 1.

    iflat = y_pred.view(-1)
    tflat = y_true.view(-1)
    intersection = (iflat * tflat).sum()

    return ((2. * intersection) + smooth) / (iflat.sum() + tflat.sum() + smooth)



In [17]:
def diceLoss(y_pred, y_true):
    smooth = 1.

    iflat = y_pred.view(-1)
    tflat = y_true.view(-1)
    intersection = (iflat * tflat).sum()

    return 1.0 - ((2. * intersection) + smooth) / (iflat.sum() + tflat.sum() + smooth)

In [18]:
def norm_image(image):
    min_val = image.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0]  # Get min per channel
    max_val = image.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]  # Get max per channel
    
    normalized_img = (image - min_val) / (max_val - min_val + 1e-8)  # Add epsilon to avoid division by zero
    return normalized_img

In [None]:
target_loader_cycle = itertools.cycle(train_loader_t)

In [None]:
valtarget_loader_cycle = itertools.cycle(val_loader_t)

In [None]:
epochs = 100
source_label = 0
target_label = 1
model_seg = encoder()
model_dismain = FCDiscriminator(1)
model_seg = model_seg.train()
model_dismain = model_dismain.train()
model_seg = model_seg.to(device)
model_dismain = model_dismain.to(device)

In [None]:
optimizer_seg = optim.Adam(model_seg.parameters(), lr=LEARNING_RATE_SEG, betas=(0.9, 0.999))
# scheduler_seg = ReduceLROnPlateau(optimizer_seg, mode='min', factor=0.2, patience=5, min_lr = MAX_LR, verbose=True)

optimizer_dismain = optim.Adam(model_dismain.parameters(), lr=LEARNING_RATE_DIS, betas=(0.9, 0.999))
# scheduler_dis = ReduceLROnPlateau(optimizer_dismain, mode='min', factor=0.2, patience=5, min_lr = MAX_LR, verbose=True)
accuracy_train_metric = Accuracy(task='binary')
accuracy_val_metric = Accuracy(task='binary')
accuracy_train_metric = accuracy_train_metric.to(device)
accuracy_val_metric = accuracy_val_metric.to(device)

In [None]:
print(LEARNING_RATE_SEG)

In [None]:
log_name  = 'en_dis_0305'
writer = SummaryWriter('runs/' + log_name)
csv_file = open(log_name + '.csv', mode='w', newline='')
csv_writer = csv.writer(csv_file)
csv_writer.writerow([
    "Epoch", "Epoch Time (mins)", 
    "Train Seg Loss", "Val Seg Loss",
    "Train Dice Loss", "Val Dice Loss",
    "Train Adv Loss", "Val Adv Loss",
    "Train Dis Loss", "Val Dis Loss", 
    "Train Accuracy", "Val Accuracy", 
    "Seg LR", "Dis LR"
])

In [None]:
for epoch in range(epochs):
    model_seg.train()
    model_dismain.train()
    torch.cuda.empty_cache()
    segloss_b_t = []
    diceloss_b_t = []
    advloss_b_t = []
    disloss_b_t = []
    epoch_start_time = time.time()
    for i, (batch_s) in enumerate(train_loader_s):
        optimizer_seg.zero_grad()
        optimizer_dismain.zero_grad()
        
        batch_t = next(target_loader_cycle)
        
        image_s, label_s = batch_s
        image_t, label_t = batch_t
        image_s, label_s = image_s.to(device), label_s.to(device)
        image_s = norm_image(image_s) #color normalization
#         label_s = label_s.permute(0, 3, 1, 2)

        image_t, label_t = image_t.to(device), label_t.to(device)
        image_t = norm_image(image_t) #color normalization
#         label_t = label_t.permute(0, 3, 1, 2)

        #train discriminator model - start
        for param in model_dismain.parameters():
            param.requires_grad = True
            
        pred_s = model_seg(image_s).detach()
        pred_t = model_seg(image_t).detach()
        dis_output_s = model_dismain(pred_s)
        dis_output_t = model_dismain(pred_t)

        loss_dis_source = BCE_LOSS(dis_output_s, torch.FloatTensor(dis_output_s.data.size()).fill_(source_label).to(device))
        loss_dis_target = BCE_LOSS(dis_output_t, torch.FloatTensor(dis_output_t.data.size()).fill_(target_label).to(device))
        loss_dis = (loss_dis_target + loss_dis_source)
        disloss_b_t.append(loss_dis.data.cpu().numpy())
        loss_dis_source.backward()
        loss_dis_target.backward()
#         loss_dis.backward()
        optimizer_dismain.step()
        #train discriminator model - end

        #train segmentation model - start
        for param in model_dismain.parameters():
            param.requires_grad = False
                    
        pred_s = model_seg(image_s)
        loss_seg = BCE_LOSS(pred_s, label_s)
        pred_s = torch.sigmoid(pred_s)
        pred_binary = (pred_s > 0.5).float()
        pred_binary = pred_binary.to(device)
        dice_loss = diceLoss(pred_s, label_s)        
        accuracy_train_metric(pred_binary, label_s)
        segloss_b_t.append(loss_seg.data.cpu().numpy())
        diceloss_b_t.append(dice_loss.data.cpu().numpy())
        (0.7 * loss_seg).backward(retain_graph=True)
        (0.3 * dice_loss).backward(retain_graph=True)
        #train segmentation model - end
        
        #Adv loss to segmentation model - start
        pred_t = model_seg(image_t)
        dis_output_t = model_dismain(pred_t)
        loss_adv_target = LAMBDA_ADV*BCE_LOSS(dis_output_t, torch.FloatTensor(dis_output_t.data.size()).fill_(source_label).to(device))
        advloss_b_t.append(loss_adv_target.data.cpu().numpy())
        loss_adv_target.backward(retain_graph=True)
        #Adv loss to segmentation model - end
        optimizer_seg.step()
        
    #Batch loop ends here
    segloss_train = np.mean(segloss_b_t)
    diceLoss_train = np.mean(diceloss_b_t)
    advloss_train = np.mean(advloss_b_t)
    disloss_train = np.mean(disloss_b_t)
    accuracy_train = accuracy_train_metric
    accuracy_train = accuracy_train.compute()
    accuracy_train_metric.reset()
    
    model_seg.eval()
    model_dismain.eval()
    segloss_b_v = []
    diceloss_b_v = []
    advloss_b_v = []
    disloss_b_v = []
    with torch.no_grad():
        for i, data_vs in enumerate(val_loader_s):
            image_vs, label_vs = data_vs
            image_vs = image_vs.to(device)
            label_vs = label_vs.to(device)
            image_vs = norm_image(image_vs) #color normalization
            voutput_s = model_seg(image_vs)
#             label_vs = label_vs.permute(0, 3, 1, 2)
            vloss = BCE_LOSS(voutput_s, label_vs) #seg loss on source val
            voutput_s = torch.sigmoid(voutput_s)
            voutputs_binary_s = (voutput_s > 0.5).float()
            voutputs_binary_s = voutputs_binary_s.to(device)
            dice_loss = diceLoss(voutput_s, label_vs) #dice score on source val
            
            batch_vt = next(valtarget_loader_cycle)
            input_vt, label_vt = batch_vt
            input_vt = input_vt.to(device)
            label_vt = label_vt.to(device)
            input_vt = norm_image(input_vt) #color normalization
            voutput_t = model_seg(input_vt)
#             label_vt = label_vt.permute(0, 3, 1, 2)
            voutputs_binary_t = (torch.sigmoid(voutput_t) > 0.5).float()
            voutputs_binary_t = voutputs_binary_t.to(device)

            dis_output_t = model_dismain(voutput_t)
            loss_adv_target = LAMBDA_ADV*BCE_LOSS(dis_output_t, torch.FloatTensor(dis_output_t.data.size()).fill_(source_label).to(device)) #adv loss on target val

            dis_output_s = model_dismain(voutput_s)
            loss_dis_source = BCE_LOSS(dis_output_s, torch.FloatTensor(dis_output_s.data.size()).fill_(source_label).to(device)) #dis loss on source val
            loss_dis_target = BCE_LOSS(dis_output_t, torch.FloatTensor(dis_output_t.data.size()).fill_(target_label).to(device)) #dis loss on target val
            loss_dis = loss_dis_target + loss_dis_source

            advloss_b_v.append(loss_adv_target.data.cpu().numpy())
            segloss_b_v.append(vloss.data.cpu().numpy())
            diceloss_b_v.append(dice_loss.data.cpu().numpy())
            disloss_b_v.append(loss_dis.data.cpu().numpy())
            accuracy_val_metric(voutputs_binary_s, label_vs)
        #val_loader_s loop ends here
    #no grad ends here
    accuracy_val = accuracy_val_metric
    accuracy_val = accuracy_val.compute()
    segloss_val = np.mean(segloss_b_v)
    diceLoss_val = np.mean(diceloss_b_v)
    advloss_val = np.mean(advloss_b_v)
    disloss_val = np.mean(disloss_b_v)
    accuracy_val_metric.reset() 
    epoch_end_time = time.time()
    epoch_time = (epoch_end_time - epoch_start_time) / 60
    for param_group in optimizer_seg.param_groups:
        seglr = param_group['lr']
    for param_group in optimizer_dismain.param_groups:
        dislr = param_group['lr']

#     print('[Epoch:{}, {:.2f}mins] Train Seg Loss: {:.2f}, Val Seg Loss: {:.2f}, Train Dice Score: {:.2f}, Val Dice Score: {:.2f}, '
#       'Train Adv Loss: {:.2f}, Val Adv Loss: {:.2f}, Train Dis Loss: {:.2f}, Val Dis Loss: {:.2f}, Train Dec Loss: {:.2f},'
#       'Val Dec Loss: {:.2f}'.format(epoch+1, epoch_time, segloss_train, segloss_val, diceLoss_train, diceLoss_val, advloss_train, advloss_val, disloss_train, disloss_val))
    #update learning rate
#     scheduler_seg.step(segloss_val)
#     scheduler_dis.step(disloss_val)

    #csv writer
    csv_writer.writerow([
        epoch + 1, epoch_time, 
        segloss_train, segloss_val, 
        diceLoss_train, diceLoss_val, 
        advloss_train, advloss_val, 
        disloss_train, disloss_val,
        accuracy_train.item(), accuracy_val.item(),
        seglr, dislr
    ])
    csv_file.flush() 
    #model save
    if (epoch+1) % 10 == 0:
        torch.save(model_seg, log_name+'.pth')
#epoch loop ends here
csv_file.close()
torch.save(model_seg, log_name+'.pth')

In [38]:
torch.save(model_seg, log_name+'.pth')

In [None]:
batch = next(iter(train_loader_s))

# Extract the image and mask from the batch
image, mask = batch

# Convert the image and mask to numpy arrays
image = image.numpy()
mask = mask.numpy()
# Plot the image and mask
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(image[0])
ax[0].set_title('Image')
ax[1].imshow(mask[0], cmap='gray')
ax[1].set_title('Mask')
plt.show()

In [None]:
torch.save(model_seg, 'udapy_0212.pth')

In [None]:
for param_group in optimizer_dismain.param_groups:
        lr = param_group['lr']
print(lr)

In [19]:
model_seg = torch.load('en_dis_0305.pth')

In [20]:
model_seg.eval()
def visualize_prediction(test_image, ground_truth, predicted_mask, dice_score):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Test Image
    axes[0].imshow(test_image)  # Permute for correct shape (H, W, C)
    axes[0].set_title('Test Image')
    axes[0].axis('off')

    # Ground Truth
    axes[1].imshow(ground_truth)  # Squeeze to remove channel dimension
    axes[1].set_title('Ground Truth')
    axes[1].axis('off')

    # Predicted Mask
    axes[2].imshow(predicted_mask)  # Squeeze to remove channel dimension
    axes[2].set_title(f'Predicted Mask\nDice Coeff: {dice_score:.4f}')
    axes[2].axis('off')

    plt.show()

dice_scores_s = []   
with torch.no_grad():
    for i, (test_image, ground_truth) in enumerate(test_loader_s):
        test_image = test_image.to(device)
        ground_truth = ground_truth.to(device)
        
        test_image = norm_image(test_image)
        preds = model_seg(test_image)
        preds = (preds > 0.5).float()
        test_image = test_image.permute(0, 2, 3, 1)
        tm = test_image[0].cpu().numpy()
#         ground_truth = ground_truth.permute(0, 3, 1, 2)
        preds = preds.permute(0, 2, 3, 1)
        ground_truth = ground_truth.permute(0, 2, 3, 1)
        dice_score = dice_coef(ground_truth, preds)
        dice_scores_s.append(dice_score.cpu().numpy())
#         visualize_prediction(tm, ground_truth[0].cpu().numpy(), preds[0].cpu().numpy(), dice_score)

In [21]:
print('source dice score:',np.mean(dice_scores_s))

source dice score: 0.91261804


In [22]:
model_seg.eval()
def visualize_prediction(test_image, ground_truth, predicted_mask, dice_score):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Test Image
    axes[0].imshow(test_image)  # Permute for correct shape (H, W, C)
    axes[0].set_title('Test Image')
    axes[0].axis('off')

    # Ground Truth
    axes[1].imshow(ground_truth)  # Squeeze to remove channel dimension
    axes[1].set_title('Ground Truth')
    axes[1].axis('off')

    # Predicted Mask
    axes[2].imshow(predicted_mask)  # Squeeze to remove channel dimension
    axes[2].set_title(f'Predicted Mask\nDice Coeff: {dice_score:.4f}')
    axes[2].axis('off')

    plt.show()

dice_scores_t = []   
with torch.no_grad():
    for i, (test_image, ground_truth) in enumerate(test_loader_t):
        test_image = test_image.to(device)
        ground_truth = ground_truth.to(device)
        
        test_image = norm_image(test_image)
        preds = model_seg(test_image)
        preds = (preds > 0.5).float()
        test_image = test_image.permute(0, 2, 3, 1)
        tm = test_image[0].cpu().numpy()
#         ground_truth = ground_truth.permute(0, 3, 1, 2)
        preds = preds.permute(0, 2, 3, 1)
        ground_truth = ground_truth.permute(0, 2, 3, 1)
        dice_score = dice_coef(ground_truth, preds)
        dice_scores_t.append(dice_score.cpu().numpy())
#         visualize_prediction(tm, ground_truth[0].cpu().numpy(), preds[0].cpu().numpy(), dice_score)

In [23]:
print(np.mean(dice_scores_t))

0.6695752


In [35]:
dice_scores_t2 = []   
with torch.no_grad():
    for i, (test_image, ground_truth) in enumerate(test_loader_t):
        test_image = test_image.to(device)
        ground_truth = ground_truth.to(device)
        
        test_image = norm_image(test_image)
        preds = model_seg(test_image)
        preds = (preds > 0.5).float()
        test_image = test_image.permute(0, 2, 3, 1)
        tm = test_image[0].cpu().numpy()
#         ground_truth = ground_truth.permute(0, 3, 1, 2)
        preds = preds.permute(0, 2, 3, 1)
        ground_truth = ground_truth.permute(0, 2, 3, 1)
        dice_score = dice_coef(ground_truth, preds)
        if(np.sum(ground_truth.cpu().numpy())>8000):
            dice_scores_t2.append(dice_score.cpu().numpy())

In [36]:
print(np.mean(dice_scores_t2))

0.76928437


In [None]:
model_seg = torch.load('saved_model/udapy_2811.pth')
model_seg = model_seg.to(device)

In [83]:
wsi_path = 'aidpth_test_wsi/VUHSK_1992_2.ome.tif'
wsi = tifi.imread(wsi_path)

In [84]:
fname = wsi_path.split('/')[1].split('.')[0]

In [85]:
wsi.shape

(26188, 4393, 3)

In [None]:
stride = 512
patch_size = 512
location = [0, 255]

# Example WSI dimensions (replace with actual WSI shape)
width, height, _ = wsi.shape  # Assuming `wsi` is your input WSI
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the output masks
img = np.zeros([width, height], dtype=np.float32)
img_final_th = np.zeros([width, height], dtype=bool)
count = 0

# Loop through locations
for eh in tqdm(location, desc="Processing patches at locations"):
    h = eh
    while h < height:
        w = eh
        while w < width:
            # Extract patch
            wsi_region = wsi[w:w + patch_size, h:h + patch_size]

            # Pad if the patch is smaller than the required size
            if wsi_region.shape[0] < patch_size or wsi_region.shape[1] < patch_size:
                padded_patch = np.zeros((patch_size, patch_size, 3), dtype=np.float32)
                padded_patch[:wsi_region.shape[0], :wsi_region.shape[1]] = wsi_region
                input_patch = padded_patch
            else:
                input_patch = wsi_region

            # Prepare the patch for model inference
#             input_patch = np.array(input_patch).transpose(0,1,2)  # Convert to (C, H, W)
            input_patch = torch.tensor(input_patch, dtype=torch.float32).unsqueeze(0).to(device)
            input_patch = input_patch.permute(0, 3, 1, 2)
#             print(input_patch.shape)
            input_patch = norm_image(input_patch)
            # Perform inference with the PyTorch model
            with torch.no_grad():
                input_patch = norm_image(input_patch)
                output = model_seg(input_patch)
                output = torch.sigmoid(output)  # Assuming binary output
                output = output.squeeze().cpu().numpy()
                output = np.where(output >= 0.5, 1, 0)
                input_patch = input_patch.permute(0, 2, 3, 1)
                # Visualization for patches with positive predictions
#                 sum_score = np.sum(output)
#                 if sum_score > 0:
#                     fig, axes = plt.subplots(1, 2, figsize=(15, 5))

#                     # Test Image
#                     axes[0].imshow(input_patch[0].cpu().numpy())  # (H, W, C)
#                     axes[0].set_title('Test Image')
#                     axes[0].axis('off')

#                     # Predicted Mask
#                     axes[1].imshow(output, cmap='gray')  # Binary mask
#                     axes[1].set_title(f'Predicted Mask\nSum count: {sum_score}')
#                     axes[1].axis('off')

#                     plt.show()

                # Save the patch output into the final image
                w_end = min(w + patch_size, width)
                h_end = min(h + patch_size, height)
                img[w:w_end, h:h_end] = output[:w_end - w, :h_end - h]

            w += stride

        h += stride

    # Threshold the output image for the current iteration
    img1_th = np.where(img >= 0.5, 1, 0)
    img_final_th = np.logical_or(img_final_th, img1_th)

    # Reset `img` for the next iteration
    img = np.zeros([width, height], dtype=np.float32)

    print(f'Completed iteration {count + 1}, img shape: {img1_th.shape}, img_final_th shape: {img_final_th.shape}')
    count += 1

# Final mask
img_final_th = img_final_th.astype(np.uint8) 
plt.imsave('output/'+ fname+ '_0403.png', img_final_th, cmap='gray')

In [None]:
width,height,_ =wsi.shape

stride = 512
patch_size = 512
location = [0, 255]

img = np.zeros([width, height], dtype=np.float32)
img_final_th = np.zeros([width,height])
count = 0

# Loop through locations
for eh in tqdm(location, desc="Processing patches at locations"):
    h = eh
    while h < height:
        w = eh
        while w < width:
            wsi_region = wsi[w:w + patch_size, h:h + patch_size]  # Extract patch
            
            if wsi_region.shape[0] < patch_size or wsi_region.shape[1] < patch_size:
                padded_patch = np.zeros([patch_size, patch_size, 3], dtype=np.float32)
                padded_patch[0:wsi_region.shape[0], 0:wsi_region.shape[1]] = wsi_region
                input_patch = padded_patch
            else:
                input_patch = wsi_region

                input_patch = np.array(input_patch).transpose(0,1,2)
                input_patch = torch.tensor(input_patch, dtype=torch.float32).unsqueeze(0).to(device)

                with torch.no_grad():
                    
                    output = model_seg(input_patch)
                    output = torch.sigmoid(output) 
                    output = output.squeeze().cpu().numpy()
                    
                    output = np.where(output >= 0.5, 1, 0)
                    sum_score = np.sum(output)
                    if sum_score>0:
                        fig, axes = plt.subplots(1, 2, figsize=(15, 5))
                        # Test Image
                        axes[0].imshow(wsi_region)  # Permute for correct shape (H, W, C)
                        axes[0].set_title('Test Image')
                        axes[0].axis('off')


                        # Predicted Mask
                        axes[1].imshow(output)  # Squeeze to remove channel dimension
                        axes[1].set_title(f'Predicted Mask\nsum count: {sum_score}')
                        axes[1].axis('off')

                        plt.show()
                # Save the output in the final image
                
                img[w:w + patch_size, h:h + patch_size] = output

#             print(f"Processed patch at (h={h}, w={w})")
            w += stride

        h += stride

    # Threshold the output image
    img1_th = np.where(img >= 0.5, 1, 0)
    img_final_th = np.logical_or(img_final_th, img1_th)

    # Reset img for the next iteration
    img = np.zeros([width, height], dtype=np.float32)
    
    print(f'Completed iteration {count + 1}, img shape: {img1_th.shape}, img_final_th shape: {img_final_th.shape}')
    count += 1

In [None]:
plt.imsave('output/14218B_24_2811.png',img_final_th)

In [None]:
np.sum(img_final_th)

In [None]:
target_loader_cycle = itertools.cycle(train_loader_s)
for i, (batch_s) in enumerate(train_loader_s):
    batch_t = next(target_loader_cycle)
    image_t, label_t = batch_t
    label_t = label_t / 255.
    label_t = (label_t > 0.5).float()
    pc = np.sum(np.array(label_t))
    if pc<5000:
        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        image_t = image_t.cpu().numpy().astype('uint8')

        # Test Image
        axes[0].imshow(image_t[0])  # Permute for correct shape (H, W, C)
        axes[0].set_title('Test Image')
        axes[0].axis('off')
        # Ground Truth
        axes[1].imshow(label_t[0])  # Squeeze to remove channel dimension
        axes[1].set_title(f'Ground Truth (Sum: {pc:.2f})')
        axes[1].axis('off')
        plt.show()

In [None]:
for i, (batch_s) in enumerate(train_loader_s):
    image_t, label_t = batch_s
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    image_t = image_t.cpu().numpy().astype('uint8')

    # Test Image
    axes[0].imshow(image_t[0])  
    axes[0].set_title('Test Image')
    axes[0].axis('off')
    # Ground Truth
    axes[1].imshow(label_t[0]) 
    axes[1].set_title(f'Ground Truth (Sum: {pc:.2f})')
    axes[1].axis('off')
    plt.show()