In [1]:
import os
import torch
import torch.nn as nn
import numpy as np
import cv2
from torch.utils.data import DataLoader
from dataset import LiverTumorDataset
from unet import UNet

In [3]:
# Pick the GPU when present, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
# map_location remaps tensors that were saved on a GPU onto `device`. Without it,
# torch.load tries to restore them to their original device and fails on a
# CPU-only machine, which would defeat the CPU fallback above.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load("unet_liver_final.pth", map_location=device))
model.eval()  # switch layers such as dropout / batch-norm (if UNet has any) to inference behaviour
print("Loaded final model for inference.")

Loaded final model for inference.


In [5]:
# Validation split listed in val_balanced.txt; image_size=(256, 256) is handed to
# the dataset (presumably its resize target — confirm in dataset.py).
val_dataset = LiverTumorDataset(
    "val_balanced.txt",
    image_size=(256, 256),
)
# Deterministic, single-sample loading in the main process:
#   batch_size=1  -> every loader step yields exactly one sample
#   shuffle=False -> iteration order matches dataset indices
#   num_workers=0 -> no worker subprocesses
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

In [7]:
def dice_score(pred, target, smooth=1e-6):
    """Dice coefficient between two tensors, computed over all of their elements.

    Both tensors are flattened and treated as a single sample, so for a batch
    with more than one item this is a batch-level score, not a per-image mean.

    Args:
        pred: prediction tensor (the evaluation code passes a 0/1 float mask).
        target: ground-truth tensor with the same number of elements as ``pred``.
        smooth: small constant added to numerator and denominator to avoid 0/0;
            when both tensors are all zeros the score is 1.

    Returns:
        0-dim tensor holding ``(2*|pred*target| + smooth) / (|pred| + |target| + smooth)``.
    """
    # reshape (not view): view(-1) raises on non-contiguous inputs such as
    # transposed or sliced tensors, while reshape copies only when it must.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

def iou_score(pred, target, smooth=1e-6):
    """Intersection-over-Union (Jaccard index) between two tensors over all elements.

    Both tensors are flattened and treated as a single sample, so for a batch
    with more than one item this is a batch-level score, not a per-image mean.

    Args:
        pred: prediction tensor (the evaluation code passes a 0/1 float mask).
        target: ground-truth tensor with the same number of elements as ``pred``.
        smooth: small constant added to numerator and denominator to avoid 0/0;
            when both tensors are all zeros the score is 1.

    Returns:
        0-dim tensor holding ``(|pred*target| + smooth) / (|pred| + |target| - |pred*target| + smooth)``.
    """
    # reshape (not view): view(-1) raises on non-contiguous inputs such as
    # transposed or sliced tensors, while reshape copies only when it must.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return (intersection + smooth) / (union + smooth)

In [9]:
# Run the model over the whole validation loader, binarise its output at 0.5 and
# record one Dice and one IoU value per loader batch (val_loader uses batch_size=1,
# so that is one value per image). The printed figures are plain means of the lists.
# NOTE(review): 0.5 is applied directly to model(...) output; it acts as a
# probability threshold only if UNet ends in a sigmoid — confirm in unet.py.
dice_scores, iou_scores = [], []

with torch.no_grad():  # inference only: skip building the autograd graph
    for images, masks in val_loader:
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)
        preds = outputs.gt(0.5).float()

        # Dice first, then IoU — each metric appends to its own list.
        for bucket, metric in ((dice_scores, dice_score), (iou_scores, iou_score)):
            bucket.append(metric(preds, masks).item())

print(f"\nValidation Dice Score: {np.mean(dice_scores):.4f}")
print(f"Validation IoU Score : {np.mean(iou_scores):.4f}")


Validation Dice Score: 0.9022
Validation IoU Score : 0.8706


In [11]:
os.makedirs("val_predictions", exist_ok=True)

NUM_VIS_SAMPLES = 5  # how many validation samples to render side by side


def _to_bgr_uint8(gray):
    """Scale a [0, 1] single-channel array to 0-255 and convert it to 3-channel BGR.

    Values are clipped before the uint8 cast so out-of-range inputs saturate
    at 0/255 instead of wrapping around.
    """
    scaled = np.clip(np.asarray(gray).astype(np.float32, copy=False) * 255, 0, 255)
    return cv2.cvtColor(scaled.astype(np.uint8), cv2.COLOR_GRAY2BGR)


# Render [input | ground-truth mask | thresholded prediction] for the first few
# validation samples. min(...) avoids an IndexError on a validation set smaller
# than NUM_VIS_SAMPLES.
# NOTE(review): assumes image, mask and prediction squeeze to 2-D single-channel
# arrays in [0, 1] (COLOR_GRAY2BGR requires one channel) — confirm against the dataset.
for i in range(min(NUM_VIS_SAMPLES, len(val_dataset))):
    image, mask = val_dataset[i]
    with torch.no_grad():
        pred = model(image.unsqueeze(0).to(device))
        pred = (pred > 0.5).float().cpu().numpy().squeeze()

    stacked = np.hstack([
        _to_bgr_uint8(image.squeeze().numpy()),
        _to_bgr_uint8(mask.squeeze().numpy()),
        _to_bgr_uint8(pred),
    ])

    out_path = f"val_predictions/sample_{i}.png"
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(out_path, stacked):
        raise IOError(f"Failed to write {out_path}")

print("Saved visual predictions to val_predictions/")

Saved visual predictions to val_predictions/
