In [1]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
img_list = []
mask_list = []

for i in os.listdir('/kaggle/input/chasedb1'):
    if '1stHO' in i:
        mask_list.append('/kaggle/input/chasedb1/'+i)
        
for i in os.listdir('/kaggle/input/processed-chasedb/processed_dataset'):
    img_list.append('/kaggle/input/processed-chasedb/processed_dataset/'+i)

img_list.sort()
mask_list.sort()
print("Done")

Done


In [10]:
img_list = []
mask_list = []

for i in os.listdir('/kaggle/input/chasedb1'):
    if '.jpg' in i:
        img_list.append('/kaggle/input/chasedb1/'+i)
    else:
        if '1stHO' in i:
            mask_list.append('/kaggle/input/chasedb1/'+i)
            
img_list.sort()
mask_list.sort()
print("Done")

Done


In [11]:
class RetinalDataset(Dataset):
    def __init__(self, images_list, masks_list, transform=None):
        self.transform = transform
        self.images_list = images_list
        self.masks_list = masks_list

    def __len__(self):
        return len(self.images_list)

    def __getitem__(self, idx):
        image = np.array(Image.open(self.images_list[idx]).convert('RGB')).astype(np.float32)
        mask = np.array(Image.open(self.masks_list[idx])).astype(np.uint8)

        # Center Crop
        height, width = image.shape[:2]
        crop_size = 960
        top = (height - crop_size) // 2
        left = (width - crop_size) // 2
        image = image[top:top+crop_size, left:left+crop_size]
        mask = mask[top:top+crop_size, left:left+crop_size]
        mask = mask.reshape((960, 960, 1))

        # Apply transformations
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']

        image = image.astype(np.float32) 
        mask = mask.astype(np.float32) 
        mask = torch.tensor(mask).permute(2, 0, 1) 
        image = torch.tensor(image).permute(2, 0, 1)  
        return image, mask

# transform = A.Compose([
#     A.HorizontalFlip(p=0.5),
#     A.VerticalFlip(p=0.5),
#     A.RandomRotate90(p=0.5),
#     A.RandomBrightnessContrast(p=0.2, brightness_limit=0.2, contrast_limit=0.2),
#     A.Normalize(mean=(0.5,), std=(0.5,)),
# ])

dataset = RetinalDataset(img_list, mask_list, transform=None)
train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=123)
train_loader = DataLoader(train_data, batch_size=2, shuffle=True)
val_loader = DataLoader(val_data, batch_size=2, shuffle=False)
print('Done')

Done


In [4]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()

        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        self.encoder1 = conv_block(in_channels, 64)
        self.encoder2 = conv_block(64, 128)
        self.encoder3 = conv_block(128, 256)
        self.encoder4 = conv_block(256, 512)
        self.bottleneck = conv_block(512, 1024)

        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = conv_block(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder
        e1 = self.encoder1(x)
        e2 = self.encoder2(nn.MaxPool2d(2)(e1))
        e3 = self.encoder3(nn.MaxPool2d(2)(e2))
        e4 = self.encoder4(nn.MaxPool2d(2)(e3))
        b = self.bottleneck(nn.MaxPool2d(2)(e4))

        # Decoder
        d4 = self.upconv4(b)
        d4 = self.decoder4(torch.cat([d4, e4], dim=1))
        d3 = self.upconv3(d4)
        d3 = self.decoder3(torch.cat([d3, e3], dim=1))
        d2 = self.upconv2(d3)
        d2 = self.decoder2(torch.cat([d2, e2], dim=1))
        d1 = self.upconv1(d2)
        d1 = self.decoder1(torch.cat([d1, e1], dim=1))

        return torch.sigmoid(self.final_conv(d1))

print(True)

True


In [5]:
def dice_loss(pred, target):
    smooth = 1.0
    pred = pred.contiguous()
    target = target.contiguous()
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = 1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth))
    return loss.mean()

print(True)

True


In [6]:
model = UNet().cuda()
criterion_bce = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

best_train_loss = float('inf')
best_model_path = "best_model.pth"  
num_epochs = 150

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    
    for images, masks in tqdm(train_loader):
        
        images = images.cuda()
        masks = masks.cuda()
        optimizer.zero_grad()
        outputs = model(images)
        total_loss = criterion_bce(outputs, masks) + dice_loss(outputs, masks)
        total_loss.backward()
        optimizer.step()
        train_loss += total_loss.item()
        
    avg_train_loss = train_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_train_loss:.4f}")
    
    if avg_train_loss < best_train_loss:
        best_train_loss = avg_train_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"Model saved with train loss: {best_train_loss:.4f}")


100%|██████████| 11/11 [00:16<00:00,  1.49s/it]


Epoch 1/150, Loss: 1.3366
Model saved with train loss: 1.3366


100%|██████████| 11/11 [00:15<00:00,  1.42s/it]


Epoch 2/150, Loss: 1.3040
Model saved with train loss: 1.3040


100%|██████████| 11/11 [00:15<00:00,  1.45s/it]


Epoch 3/150, Loss: 1.2725
Model saved with train loss: 1.2725


100%|██████████| 11/11 [00:16<00:00,  1.49s/it]


Epoch 4/150, Loss: 1.1858
Model saved with train loss: 1.1858


100%|██████████| 11/11 [00:16<00:00,  1.53s/it]


Epoch 5/150, Loss: 1.1183
Model saved with train loss: 1.1183


100%|██████████| 11/11 [00:17<00:00,  1.57s/it]


Epoch 6/150, Loss: 0.9234
Model saved with train loss: 0.9234


100%|██████████| 11/11 [00:17<00:00,  1.60s/it]


Epoch 7/150, Loss: 0.8914
Model saved with train loss: 0.8914


100%|██████████| 11/11 [00:18<00:00,  1.65s/it]


Epoch 8/150, Loss: 0.7707
Model saved with train loss: 0.7707


100%|██████████| 11/11 [00:18<00:00,  1.72s/it]


Epoch 9/150, Loss: 0.7286
Model saved with train loss: 0.7286


100%|██████████| 11/11 [00:19<00:00,  1.79s/it]


Epoch 10/150, Loss: 0.6127
Model saved with train loss: 0.6127


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 11/150, Loss: 0.6017
Model saved with train loss: 0.6017


100%|██████████| 11/11 [00:18<00:00,  1.72s/it]


Epoch 12/150, Loss: 0.5690
Model saved with train loss: 0.5690


100%|██████████| 11/11 [00:19<00:00,  1.74s/it]


Epoch 13/150, Loss: 0.5270
Model saved with train loss: 0.5270


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 14/150, Loss: 0.5160
Model saved with train loss: 0.5160


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 15/150, Loss: 0.5190


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 16/150, Loss: 0.5188


100%|██████████| 11/11 [00:19<00:00,  1.74s/it]


Epoch 17/150, Loss: 0.5404


100%|██████████| 11/11 [00:19<00:00,  1.74s/it]


Epoch 18/150, Loss: 0.5357


100%|██████████| 11/11 [00:19<00:00,  1.74s/it]


Epoch 19/150, Loss: 0.5348


100%|██████████| 11/11 [00:19<00:00,  1.74s/it]


Epoch 20/150, Loss: 0.5323


100%|██████████| 11/11 [00:19<00:00,  1.74s/it]


Epoch 21/150, Loss: 0.5085
Model saved with train loss: 0.5085


100%|██████████| 11/11 [00:19<00:00,  1.74s/it]


Epoch 22/150, Loss: 0.5181


100%|██████████| 11/11 [00:19<00:00,  1.73s/it]


Epoch 23/150, Loss: 0.5165


100%|██████████| 11/11 [00:19<00:00,  1.74s/it]


Epoch 24/150, Loss: 0.5054
Model saved with train loss: 0.5054


100%|██████████| 11/11 [00:19<00:00,  1.74s/it]


Epoch 25/150, Loss: 0.5054


100%|██████████| 11/11 [00:19<00:00,  1.74s/it]


Epoch 26/150, Loss: 0.4983
Model saved with train loss: 0.4983


100%|██████████| 11/11 [00:19<00:00,  1.74s/it]


Epoch 27/150, Loss: 0.4772
Model saved with train loss: 0.4772


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 28/150, Loss: 0.4992


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 29/150, Loss: 0.4932


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 30/150, Loss: 0.4730
Model saved with train loss: 0.4730


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 31/150, Loss: 0.4768


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 32/150, Loss: 0.4937


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 33/150, Loss: 0.4661
Model saved with train loss: 0.4661


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 34/150, Loss: 0.4647
Model saved with train loss: 0.4647


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 35/150, Loss: 0.4596
Model saved with train loss: 0.4596


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 36/150, Loss: 0.4740


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 37/150, Loss: 0.4730


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 38/150, Loss: 0.4886


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 39/150, Loss: 0.4717


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 40/150, Loss: 0.4545
Model saved with train loss: 0.4545


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 41/150, Loss: 0.4451
Model saved with train loss: 0.4451


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 42/150, Loss: 0.4606


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 43/150, Loss: 0.4640


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 44/150, Loss: 0.4492


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 45/150, Loss: 0.4491


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 46/150, Loss: 0.4368
Model saved with train loss: 0.4368


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 47/150, Loss: 0.4255
Model saved with train loss: 0.4255


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 48/150, Loss: 0.4314


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 49/150, Loss: 0.4235
Model saved with train loss: 0.4235


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 50/150, Loss: 0.4201
Model saved with train loss: 0.4201


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 51/150, Loss: 0.4332


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 52/150, Loss: 0.4479


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 53/150, Loss: 0.4318


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 54/150, Loss: 0.4364


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 55/150, Loss: 0.4254


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 56/150, Loss: 0.4230


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 57/150, Loss: 0.4097
Model saved with train loss: 0.4097


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 58/150, Loss: 0.4083
Model saved with train loss: 0.4083


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 59/150, Loss: 0.4045
Model saved with train loss: 0.4045


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 60/150, Loss: 0.4293


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 61/150, Loss: 0.4088


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 62/150, Loss: 0.4038
Model saved with train loss: 0.4038


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 63/150, Loss: 0.3972
Model saved with train loss: 0.3972


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 64/150, Loss: 0.4014


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 65/150, Loss: 0.3999


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 66/150, Loss: 0.3999


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 67/150, Loss: 0.3924
Model saved with train loss: 0.3924


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 68/150, Loss: 0.3970


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 69/150, Loss: 0.3883
Model saved with train loss: 0.3883


100%|██████████| 11/11 [00:19<00:00,  1.75s/it]


Epoch 70/150, Loss: 0.3932


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 71/150, Loss: 0.3889


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 72/150, Loss: 0.3841
Model saved with train loss: 0.3841


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 73/150, Loss: 0.3835
Model saved with train loss: 0.3835


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 74/150, Loss: 0.3876


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 75/150, Loss: 0.3854


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 76/150, Loss: 0.3701
Model saved with train loss: 0.3701


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 77/150, Loss: 0.3691
Model saved with train loss: 0.3691


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 78/150, Loss: 0.3743


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 79/150, Loss: 0.3800


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 80/150, Loss: 0.3793


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 81/150, Loss: 0.3658
Model saved with train loss: 0.3658


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 82/150, Loss: 0.3632
Model saved with train loss: 0.3632


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 83/150, Loss: 0.3696


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 84/150, Loss: 0.3630
Model saved with train loss: 0.3630


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 85/150, Loss: 0.3478
Model saved with train loss: 0.3478


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 86/150, Loss: 0.3461
Model saved with train loss: 0.3461


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 87/150, Loss: 0.3486


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 88/150, Loss: 0.3440
Model saved with train loss: 0.3440


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 89/150, Loss: 0.3426
Model saved with train loss: 0.3426


100%|██████████| 11/11 [00:19<00:00,  1.79s/it]


Epoch 90/150, Loss: 0.3514


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 91/150, Loss: 0.3395
Model saved with train loss: 0.3395


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 92/150, Loss: 0.3269
Model saved with train loss: 0.3269


100%|██████████| 11/11 [00:19<00:00,  1.79s/it]


Epoch 93/150, Loss: 0.3253
Model saved with train loss: 0.3253


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 94/150, Loss: 0.3257


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 95/150, Loss: 0.3242
Model saved with train loss: 0.3242


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 96/150, Loss: 0.3157
Model saved with train loss: 0.3157


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 97/150, Loss: 0.3152
Model saved with train loss: 0.3152


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 98/150, Loss: 0.3096
Model saved with train loss: 0.3096


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 99/150, Loss: 0.3078
Model saved with train loss: 0.3078


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 100/150, Loss: 0.3212


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 101/150, Loss: 0.3138


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 102/150, Loss: 0.3116


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 103/150, Loss: 0.3052
Model saved with train loss: 0.3052


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 104/150, Loss: 0.3075


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 105/150, Loss: 0.3037
Model saved with train loss: 0.3037


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 106/150, Loss: 0.2951
Model saved with train loss: 0.2951


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 107/150, Loss: 0.2962


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 108/150, Loss: 0.2918
Model saved with train loss: 0.2918


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 109/150, Loss: 0.2932


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 110/150, Loss: 0.2907
Model saved with train loss: 0.2907


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 111/150, Loss: 0.2888
Model saved with train loss: 0.2888


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 112/150, Loss: 0.2841
Model saved with train loss: 0.2841


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 113/150, Loss: 0.2794
Model saved with train loss: 0.2794


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 114/150, Loss: 0.2784
Model saved with train loss: 0.2784


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 115/150, Loss: 0.2775
Model saved with train loss: 0.2775


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 116/150, Loss: 0.2729
Model saved with train loss: 0.2729


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 117/150, Loss: 0.2715
Model saved with train loss: 0.2715


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 118/150, Loss: 0.2733


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 119/150, Loss: 0.2688
Model saved with train loss: 0.2688


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 120/150, Loss: 0.2720


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 121/150, Loss: 0.2730


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 122/150, Loss: 0.2687
Model saved with train loss: 0.2687


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 123/150, Loss: 0.2575
Model saved with train loss: 0.2575


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 124/150, Loss: 0.2558
Model saved with train loss: 0.2558


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 125/150, Loss: 0.2515
Model saved with train loss: 0.2515


100%|██████████| 11/11 [00:19<00:00,  1.76s/it]


Epoch 126/150, Loss: 0.2482
Model saved with train loss: 0.2482


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 127/150, Loss: 0.2451
Model saved with train loss: 0.2451


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 128/150, Loss: 0.2441
Model saved with train loss: 0.2441


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 129/150, Loss: 0.2399
Model saved with train loss: 0.2399


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 130/150, Loss: 0.2516


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 131/150, Loss: 0.2583


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 132/150, Loss: 0.2570


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 133/150, Loss: 0.2578


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 134/150, Loss: 0.2447


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 135/150, Loss: 0.2345
Model saved with train loss: 0.2345


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 136/150, Loss: 0.2252
Model saved with train loss: 0.2252


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 137/150, Loss: 0.2224
Model saved with train loss: 0.2224


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 138/150, Loss: 0.2235


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 139/150, Loss: 0.2174
Model saved with train loss: 0.2174


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 140/150, Loss: 0.2258


100%|██████████| 11/11 [00:19<00:00,  1.77s/it]


Epoch 141/150, Loss: 0.2186


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 142/150, Loss: 0.2121
Model saved with train loss: 0.2121


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 143/150, Loss: 0.2059
Model saved with train loss: 0.2059


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 144/150, Loss: 0.2027
Model saved with train loss: 0.2027


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 145/150, Loss: 0.1976
Model saved with train loss: 0.1976


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 146/150, Loss: 0.1983


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 147/150, Loss: 0.1954
Model saved with train loss: 0.1954


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 148/150, Loss: 0.1860
Model saved with train loss: 0.1860


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 149/150, Loss: 0.1797
Model saved with train loss: 0.1797


100%|██████████| 11/11 [00:19<00:00,  1.78s/it]


Epoch 150/150, Loss: 0.1749
Model saved with train loss: 0.1749


In [7]:
from IPython.display import FileLink
FileLink(r'best_model.pth')

In [12]:
import torch

model = UNet().cuda()
best_model_path = "/kaggle/input/unet_original/pytorch/default/1/original_data.pth"
model.load_state_dict(torch.load(best_model_path))
model.eval()
print("Model loaded successfully!")


  model.load_state_dict(torch.load(best_model_path))


Model loaded successfully!


In [13]:
from skimage.metrics import structural_similarity as ssim
import numpy as np
from sklearn.metrics import confusion_matrix

def evaluate_model(model, val_loader):
    model.eval()
    accuracies = []
    f1_scores = []
    ssims = []
    psnrs = []
    sensitivities = []
    specificities = []
    
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.cuda()
            masks = masks.cuda()
            outputs = model(images)
            preds = (outputs > 0.5).float()

            # Convert to numpy for accuracy_func
            preds_np = preds.cpu().numpy().squeeze() * 255
            masks_np = masks.cpu().numpy().squeeze() * 255

            # Compute metrics
            for pred, mask in zip(preds_np, masks_np):
                pred = pred.astype(np.uint8)
                mask = mask.astype(np.uint8)

                ssim_score = ssim(mask, pred, data_range=255)
                psnr_score = float('inf')
                mse = np.mean((pred - mask) ** 2)
                if mse != 0:
                    max_pixel = 255.0
                    psnr_score = 20 * np.log10(max_pixel / np.sqrt(mse))

                generated_binary = (pred > 0).astype(np.uint8)
                target_binary = (mask > 0).astype(np.uint8)

                tn, fp, fn, tp = confusion_matrix(target_binary.ravel(), generated_binary.ravel()).ravel()
                accuracy = (tp + tn) / (tp + tn + fp + fn)
                sensitivity = tp / (tp + fn + 1e-7)
                specificity = tn / (tn + fp + 1e-7)
                f1_score = 2 / ((1 / specificity) + (1 / sensitivity))

                ssims.append(ssim_score)
                psnrs.append(psnr_score)
                accuracies.append(accuracy)
                sensitivities.append(sensitivity)
                specificities.append(specificity)
                f1_scores.append(f1_score)

    # Print mean metrics
    print(f"Mean SSIM: {np.mean(ssims):.4f}")
    print(f"Mean PSNR: {np.mean(psnrs):.4f}")
    print(f"Mean Accuracy: {np.mean(accuracies):.4f}")
    print(f"Mean Sensitivity: {np.mean(sensitivities):.4f}")
    print(f"Mean Specificity: {np.mean(specificities):.4f}")
    print(f"Mean F1 Score: {np.mean(f1_scores):.4f}")

# Call the evaluate function
evaluate_model(model, val_loader)


Mean SSIM: 0.8721
Mean PSNR: 62.5135
Mean Accuracy: 0.9634
Mean Sensitivity: 0.6191
Mean Specificity: 0.9903
Mean F1 Score: 0.7543
