In [1]:
import torch
import torch.distributed as dist
import sys
import torch
from sklearn.metrics import roc_auc_score
from torch import nn
from tqdm import tqdm

class ConfusionMatrix:
    def __init__(self, num_classes,device):
        self.num_classes = num_classes
        self.mat = torch.zeros((num_classes, num_classes), dtype=torch.int64,device=device)

    def update(self, a, b):
        n = self.num_classes
        with torch.no_grad():
            mask = (a >= 0) & (a < n)
            a = a[mask].to(torch.int64)
            b = b[mask].to(torch.int64)
            indices = n * a + b
            counts = torch.bincount(indices, minlength=n ** 2)
            self.mat += counts.reshape(n, n)

    def reset(self):
        self.mat.zero_()

    def compute(self):
        h = self.mat.float()
        diag_sum = torch.diag(h).sum()
        total_sum = h.sum()
        acc_global = (diag_sum / total_sum).item() if total_sum > 0 else 0.0
        se = (h[1, 1] / h[1].sum()).item() if h[1].sum() > 0 else 0.0
        sp = (h[0, 0] / h[0].sum()).item() if h[0].sum() > 0 else 0.0
        pr = (h[1, 1] / h[:, 1].sum()).item() if h[:, 1].sum() > 0 else 0.0
        F1 = 2 * (pr * se) / (pr + se) if (pr + se) > 0 else 0.0
        return acc_global, se, sp, F1, pr

def evaluate(model, data_loader, num_classes):
    model.eval()
    confmat = ConfusionMatrix(num_classes,device="cuda")
    data_loader = tqdm(data_loader)
    mask = None
    predict = None
    dice_c = 0

    with torch.no_grad():
        for data in data_loader:
            for image, target, eye_masks in data:
                image, target = image.cuda(), target.cuda()
                output = model(image)
                if type(output) is list:
                    output = output[0]
                output = torch.sigmoid(output)
                truth = output.clone()
                output[output >= 0.5] = 1
                output[output < 0.5] = 0
                confmat.update(target.flatten(), output.long().flatten())
                dice_c += dice_coeff(output, target)
                mask = target.flatten() if mask is None else torch.cat((mask, target.flatten()))
                predict = truth.flatten() if predict is None else torch.cat((predict, truth.flatten()))

    mask = mask.cpu().numpy()
    predict = predict.cpu().numpy()
    AUC_ROC = roc_auc_score(mask, predict)
    iou = calculate_iou(model,data_loader)
    return confmat.compute()[0], confmat.compute()[1], confmat.compute()[2], confmat.compute()[3], confmat.compute()[
        4], AUC_ROC, dice_c / len(data_loader), iou

    
def dice_coeff(x: torch.Tensor, target: torch.Tensor, epsilon=1e-6):
    d = 0.
    batch_size = x.shape[0]
    for i in range(batch_size):
        x_i = x[i].reshape(-1)
        t_i = target[i].reshape(-1).float()
        inter = torch.dot(x_i, t_i)
        sets_sum = torch.sum(x_i) + torch.sum(t_i)
        if sets_sum == 0:
            sets_sum = 2 * inter
        d += (2 * inter + epsilon) / (sets_sum + epsilon)
    return d / batch_size
    
def calculate_iou1(y_true, y_pred, labels=[0, 1]):
    IoU_scores = []
    for label in labels:
        jaccard = jaccard_score(y_true.flatten().cpu().numpy(), y_pred.flatten().cpu().numpy(), pos_label=label, average='weighted')
        IoU_scores.append(jaccard)     
    return torch.tensor(IoU_scores).mean().item()
    
def calculate_iou(model, dataloader):
    model.eval()
    total_iou = 0.0
    num_batches = len(dataloader)
    with torch.no_grad():
        for data in dataloader:
            z = len(data)
            for images, masks, eye_masks in data:
                images = images.to(device)
                masks = masks.to(device)
                outputs = model(images)
                
                # Convert outputs to binary predictions
                binary_preds = (outputs > 0.5).float()
                
                # Calculate intersection and union
                intersection = torch.logical_and(binary_preds, masks).sum((1, 2))  # Sum over height and width
                union = torch.logical_or(binary_preds, masks).sum((1, 2))         # Sum over height and width
                
                # Calculate IOU for each sample in the batch
                iou_per_sample = torch.where(union == 0, torch.ones_like(union), intersection.float() / union.float())
                
                # Average IOU across the batch
                batch_iou = iou_per_sample.mean().item()
                total_iou += batch_iou
    
    # Calculate average IOU across all batches
    avg_iou = total_iou / (num_batches*z)
    return avg_iou

def get_metrics(predict, target, threshold=0.5, predict_b=None):
    predict = torch.sigmoid(predict).cpu().detach().numpy().flatten()
    if predict_b is not None:
        predict_b = predict_b.flatten()
    else:
        predict_b = np.where(predict >= threshold, 1, 0)
    if torch.is_tensor(target):
        target = target.cpu().detach().numpy().flatten()
    else:
        target = target.flatten()
    tp = (predict_b * target).sum()
    tn = ((1 - predict_b) * (1 - target)).sum()
    fp = ((1 - target) * predict_b).sum()
    fn = ((1 - predict_b) * target).sum()
    auc = roc_auc_score(target, predict)
    acc = (tp + tn) / (tp + fp + fn + tn)
    pre = tp / (tp + fp)
    sen = tp / (tp + fn)
    spe = tn / (tn + fp)
    iou = tp / (tp + fp + fn)
    f1 = 2 * pre * sen / (pre + sen)
    return iou

In [2]:
import torch.nn.functional as F
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import v2
from PIL import Image
import torchsummary
import torch.optim as optim
import numpy as np
import torchvision.utils
# from train_utils.train_and_eval import train_one_epoch, evaluate, create_lr_scheduler

def reverse_transform(inp):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    inp = (inp * 255).astype(np.uint8)

    return inp
    
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None,windows_size=(128,128),stride=(32,32)):
        self.root_dir = root_dir
        self.transform = transform
        self.window_size = windows_size
        self.stride = stride
        
        self.image_dir = os.path.join(root_dir,'images')
        self.mask_dir = os.path.join(root_dir,'1st_manual')
        self.eye_mask_dir = os.path.join(root_dir,'mask')
        
        self.image_filenames = os.listdir(self.image_dir)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_name = self.image_filenames[idx]
        image_path = os.path.join(self.image_dir,image_name)
        image = Image.open(image_path).convert('L')
        patches_image = self.extract_patches(image)
        
        mask_name = image_name.replace("training.tif","manual1.gif")
        mask_path = os.path.join(self.mask_dir,mask_name)
        mask = Image.open(mask_path).convert('L')
        patches_mask = self.extract_patches(mask)
        

        eye_mask_name = image_name.replace(".tif","_mask.gif")
        eye_mask_path = os.path.join(self.eye_mask_dir,eye_mask_name)
        eye_mask = Image.open(eye_mask_path).convert('L')
        patches_eye_mask = self.extract_patches(eye_mask)

        data = []
        if self.transform:
            for patch1,patch2,patch3 in zip(patches_image,patches_mask,patches_eye_mask):
                data.append((self.transform(patch1),self.transform(patch2),self.transform(patch3)))

        return data

    def extract_patches(self, image):
        patches = []
        width, height = image.size
        window_height, window_width = self.window_size
        stride_vertical, stride_horizontal = self.stride

        for y in range(0, height - window_height + 1, stride_vertical):
            for x in range(0, width - window_width + 1, stride_horizontal):
                patch = image.crop((x, y, x + window_width, y + window_height))
                patches.append(patch)

        return patches

transform = v2.Compose([
    # v2.Resize((224,224)),
    # v2.RandomRotation(degrees=(0,180)),
    v2.Compose([v2.ToImage(),v2.ToDtype(torch.float32,scale=True)]),
    # v2.Normalize([0.485], [0.229]),
])

In [3]:
train_dataset = CustomDataset(root_dir="./training",transform=transform)
val_dataset = CustomDataset(root_dir="./validation",transform=transform)

train_loader = DataLoader(train_dataset,batch_size=32,shuffle=False)
val_loader = DataLoader(val_dataset,batch_size=32,shuffle=False)

In [4]:
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.encoder = nn.ModuleList([
            DoubleConv(in_channels, 64),
            DoubleConv(64, 128),
            DoubleConv(128, 256),
            DoubleConv(256, 512),
            DoubleConv(512, 1024)
        ])
        
        self.decoder = nn.ModuleList([
            DoubleConv(1024 + 512, 512),  # Adjusted for concatenation
            DoubleConv(512 + 256, 256),   # Adjusted for concatenation
            DoubleConv(256 + 128, 128),   # Adjusted for concatenation
            DoubleConv(128 + 64, 64),     # Adjusted for concatenation
        ])
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder
        encoder_outputs = []
        for encoder in self.encoder:
            x = encoder(x)
            encoder_outputs.append(x)
            x = nn.MaxPool2d(kernel_size=2)(x)
        x = self.upsample(x)
        # Decoder
        for i, (decoder, encoder_output) in enumerate(zip(self.decoder, reversed(encoder_outputs[:-1]))):
            # print(x.shape)
            x = self.upsample(x)
            # print(f"Shape of x after upsampling in decoder {i+1}: {x.shape}")
            x = torch.cat([x, encoder_output], dim=1)
            x = decoder(x)
        
        # Final Convolution
        x = self.final_conv(x)
        # x = torch.sigmoid(x)  # Assuming binary segmentation
        return x

    # def upsample(self, x):
    #     return nn.ConvTranspose2d(x.shape[1], x.shape[1], kernel_size=2, stride=2).forward(x).to('cuda')


In [None]:
import torch.optim as optim
import matplotlib.pyplot as plt
import wandb

train_iou_history = []
val_iou_history = []

device = torch.device('cuda')
model = UNet(in_channels=1, out_channels=1).to('cuda')
model = model.cuda()

criterion = nn.BCELoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)
logs = wandb.init(project="Unet model")

best_acc = 0

num_epochs = 100 
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for data in train_loader:
        for images, masks, eye_masks in data:
            images = images.cuda()
            masks = masks.cuda()
            
            optimizer.zero_grad()
            
            outputs = model(images)
            outputs = torch.sigmoid(outputs)
            loss = criterion(outputs, masks)

            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * images.size(0)
    
    epoch_loss = running_loss / len(train_loader.dataset)
    acc, se, sp, F1, pr, AUC_ROC, dice, iou = evaluate(model, val_loader, num_classes=2)
    # IOU = []
    # for data in val_loader:
    #     for images, masks, eye_masks in data:
    #         images = images.cuda()
    #         masks = masks.cuda()

    #         model.eval()
    #         outputs = model(images)
    #         outputs = torch.sigmoid(outputs)
            
    #         iou = get_metrics(outputs,masks)
    #         IOU.append(iou)
    # print(np.mean(IOU))
    logs.log({
        "acc": acc,
        "sensitivity": se,
        "specificity": sp,
        "F1-score": F1,
        "AUC_ROC": AUC_ROC,
        "Dice": dice,
        "epoch": epoch,
        "IOU": iou
    })
    if best_acc < acc:
        best_acc = acc
        print(f"Best accuracy attained: {best_acc*100:.4f}%")
        torch.save(model.state_dict(),"unet_model.pth")
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {acc*100:.4f}%, IOU: {iou}')

[34m[1mwandb[0m: Currently logged in as: [33mbattulasaikiran2002[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.07s/it]


Best accuracy attained: 90.4111%
Epoch [1/100], Loss: 51.6498, Accuracy: 90.4111%, IOU: 0.4858096198666663


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.01s/it]


Best accuracy attained: 93.1440%
Epoch [2/100], Loss: 36.9695, Accuracy: 93.1440%, IOU: 0.44009455499194916


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  6.00s/it]


Best accuracy attained: 93.8739%
Epoch [3/100], Loss: 32.5368, Accuracy: 93.8739%, IOU: 0.4714545215879168


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.29s/it]


Best accuracy attained: 93.9483%
Epoch [4/100], Loss: 30.4811, Accuracy: 93.9483%, IOU: 0.5591428765228816


  0%|                                                                         | 0/1 [00:00<?, ?it/s]

In [None]:
# torch.save(model.state_dict(),"unet_model.pth")

In [None]:
import torch
from torchvision import transforms
from PIL import Image

model.eval()


transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])


input_image = Image.open('15_test.tif').convert('L')  
input_tensor = transform(input_image).unsqueeze(0) 


with torch.no_grad():
    output = model(input_tensor.cuda())

output = torch.sigmoid(output)
predicted_mask = (output > 0.5).float()


import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(input_image,cmap='gray')
plt.title('Input Image')

plt.subplot(1, 2, 2)
plt.imshow(predicted_mask.squeeze().cpu().numpy(), cmap='gray')
plt.title('Predicted Mask')

plt.show()