In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from time import time
from thop import profile
import pandas as pd
from tqdm import tqdm
import cv2
import json
import argparse

# Add parent directory to path
import sys
sys.path.append('../src')

# Import project modules
from models.unet import UNet
from models.enhanced_unet import EnhancedUNet, SpatialAttentionUNet, UltraLightUNet
from utils.metrics import dice_coefficient, iou_coefficient, pixel_accuracy
from data.isic_dataset import load_isic_data
from data.busi_dataset import load_busi_data

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [26]:
# Name of the dataset whose test split and checkpoints are used below
# ('isic' or 'busi').
DATASET = 'isic'

# Dispatch table: dataset name -> zero-argument callable that loads it.
_DATASET_LOADERS = {
    'isic': lambda: load_isic_data(dataset_path='../data/isic_2018_task1_data'),
    'busi': lambda: load_busi_data(dataset_path='../data/busi_dataset/Dataset_BUSI_with_GT'),
}

if DATASET not in _DATASET_LOADERS:
    raise ValueError("Unsupported dataset name")
data = _DATASET_LOADERS[DATASET]()

test_loader = data['test_loader']
dataset_name = DATASET

# Every architecture below is built with 3 input channels and 1 output class.
_IO_KWARGS = {'n_channels': 3, 'n_classes': 1}

# Model name -> freshly constructed network on `device`. The names double as
# the checkpoint sub-directory names used by the loading cell.
models = {
    'unet_standard': UNet(**_IO_KWARGS).to(device),
    'unet_with_depthwise': EnhancedUNet(
        use_se=False, use_lightweight=True, **_IO_KWARGS
    ).to(device),
    'unet_with_se_depthwise': EnhancedUNet(
        use_se=True, use_lightweight=True, **_IO_KWARGS
    ).to(device),
    'unet_with_se_depthwise_reduced': EnhancedUNet(
        use_se=True, use_lightweight=True, se_reduction=32, **_IO_KWARGS
    ).to(device),
    'unet_with_spatial_attn': SpatialAttentionUNet(
        use_se=True, use_lightweight=True, **_IO_KWARGS
    ).to(device),
}

Found 2594 image-mask pairs
Train: 1815, Validation: 389, Test: 390


In [27]:
# Load the best checkpoint for every model from
# ../saved_models/<DATASET>/<model_name>/best_model.pth.
for model_name, model in models.items():
    model_path = os.path.join(f'../saved_models/{DATASET}', model_name, 'best_model.pth')
    if os.path.exists(model_path):
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True if the installed
        # torch version supports it and the file holds a plain state_dict).
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Loaded model: {model_name}")
    else:
        # Best-effort: the model keeps its freshly initialised weights, so any
        # metrics reported for it later do not reflect a trained network.
        print(f"⚠️  Model file not found for {model_name}: {model_path}")

    # Always switch to inference mode, even when no checkpoint was found:
    # later cells run forward passes on every entry of `models`, and a model
    # left in training mode would behave differently there (e.g. dropout
    # active, normalisation statistics updated).
    model.eval()

Loaded model: unet_standard
Loaded model: unet_with_depthwise
Loaded model: unet_with_se_depthwise
Loaded model: unet_with_se_depthwise_reduced
Loaded model: unet_with_spatial_attn


In [28]:
def count_parameters(model):
    """Return the total number of elements across all parameters of `model`
    that have `requires_grad` set (frozen parameters are not counted)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def measure_inference_time(model, input_size=(1, 3, 256, 256), num_iterations=50):
    """Measure the average wall-clock time of one forward pass.

    The model is temporarily switched to eval mode, run 10 times as warm-up,
    then timed over `num_iterations` forward passes on a random input, all
    with gradients disabled. The model's original train/eval mode is restored
    before returning.

    Args:
        model: torch.nn.Module to benchmark. The input is created on the
            device of the model's first parameter; a model without parameters
            falls back to CUDA when available, else CPU.
        input_size: Shape of the random input tensor passed to the model.
        num_iterations: Number of timed forward passes (must be > 0).

    Returns:
        Average time per forward pass, in seconds (float).

    Raises:
        ValueError: If `num_iterations` is not positive.
    """
    # Function-scope import: the notebook's import cell only exposes
    # time.time, and perf_counter is the monotonic, high-resolution clock
    # intended for measuring short durations.
    from time import perf_counter

    if num_iterations <= 0:
        raise ValueError("num_iterations must be a positive integer")

    # Benchmark on the device the model actually lives on instead of relying
    # on a notebook-level global.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        run_device = first_param.device
    else:
        run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    on_cuda = run_device.type == 'cuda'

    def _sync():
        # CUDA kernels are launched asynchronously, so the clock may only be
        # read after pending work has finished. On CPU there is nothing to
        # wait for, and calling torch.cuda.synchronize() would raise.
        if on_cuda:
            torch.cuda.synchronize(run_device)

    dummy_input = torch.randn(input_size).to(run_device)

    # Time inference behaviour: without eval mode, a model left in training
    # mode would keep dropout active and update normalisation statistics.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Warm-up passes, excluded from the measurement.
            for _ in range(10):
                model(dummy_input)

            _sync()
            start_time = perf_counter()

            for _ in range(num_iterations):
                model(dummy_input)

            _sync()
            end_time = perf_counter()
    finally:
        model.train(was_training)

    return (end_time - start_time) / num_iterations

# Per-model efficiency figures keyed by model name. Stays empty when the
# efficiency pass below is skipped.
efficiency_metrics = {}
eval_type = 'both'
run_efficiency = eval_type == 'efficiency' or eval_type == 'both'

if run_efficiency:
    print("\n=== Computing Efficiency Metrics ===")

    for model_name, model in models.items():
        # Number of trainable parameters.
        params = count_parameters(model)

        # Operation count reported by thop's profile() for one 1x3x256x256
        # input; the second return value is discarded.
        dummy_input = torch.randn(1, 3, 256, 256).to(device)
        flops, _ = profile(model, inputs=(dummy_input,), verbose=False)

        # Average seconds per forward pass.
        inference_time = measure_inference_time(model)

        efficiency_metrics[model_name] = dict(
            params=params,
            flops=flops,
            inference_time=inference_time,
        )

        print(
            f"{model_name}:",
            f"  Parameters: {params:,}",
            f"  FLOPs: {flops / 1e9:.2f} G",
            f"  Inference time: {inference_time * 1000:.2f} ms",
            sep="\n",
        )

    # Persist the figures as JSON, coercing every value to a built-in
    # int/float first.
    results_dir = f'../results/{dataset_name}'
    os.makedirs(results_dir, exist_ok=True)
    with open(f'{results_dir}/efficiency_metrics.json', 'w') as f:
        serializable_metrics = {
            model_name: {
                'params': int(metrics['params']),
                'flops': float(metrics['flops']),
                'inference_time': float(metrics['inference_time']),
            }
            for model_name, metrics in efficiency_metrics.items()
        }
        json.dump(serializable_metrics, f, indent=2)

# One-line-per-model recap (prints only the header if nothing was measured).
print("\n=== Efficiency Summary ===")
for model_name, metrics in efficiency_metrics.items():
    summary_line = (
        f"{model_name:<30} | Params: {metrics['params']:,} | "
        f"FLOPs: {metrics['flops'] / 1e9:.2f} G | "
        f"Inference: {metrics['inference_time'] * 1000:.2f} ms"
    )
    print(summary_line)



=== Computing Efficiency Metrics ===
unet_standard:
  Parameters: 31,037,633
  FLOPs: 54.74 G
  Inference time: 7.10 ms
unet_with_depthwise:
  Parameters: 5,988,252
  FLOPs: 14.14 G
  Inference time: 4.02 ms
unet_with_se_depthwise:
  Parameters: 6,206,364
  FLOPs: 14.16 G
  Inference time: 5.78 ms
unet_with_se_depthwise_reduced:
  Parameters: 6,097,308
  FLOPs: 14.16 G
  Inference time: 5.73 ms
unet_with_spatial_attn:
  Parameters: 6,206,756
  FLOPs: 14.16 G
  Inference time: 6.33 ms

=== Efficiency Summary ===
unet_standard                  | Params: 31,037,633 | FLOPs: 54.74 G | Inference: 7.10 ms
unet_with_depthwise            | Params: 5,988,252 | FLOPs: 14.14 G | Inference: 4.02 ms
unet_with_se_depthwise         | Params: 6,206,364 | FLOPs: 14.16 G | Inference: 5.78 ms
unet_with_se_depthwise_reduced | Params: 6,097,308 | FLOPs: 14.16 G | Inference: 5.73 ms
unet_with_spatial_attn         | Params: 6,206,756 | FLOPs: 14.16 G | Inference: 6.33 ms


In [29]:
def evaluate_model(model, test_loader, device):
    """Score `model` on every batch yielded by `test_loader`.

    The model is put in eval mode and run with gradients disabled. Each batch
    is scored with dice_coefficient, iou_coefficient and pixel_accuracy; the
    per-batch scores are then averaged with equal weight per batch.

    Args:
        model: Network to evaluate.
        test_loader: Iterable yielding (images, masks) pairs.
        device: Device the images and masks are moved to before the forward pass.

    Returns:
        dict with keys 'dice', 'iou' and 'accuracy', each the mean per-batch
        score multiplied by 100 and rounded to 2 decimals.
    """
    per_batch_scores = {'dice': [], 'iou': [], 'accuracy': []}

    model.eval()
    with torch.no_grad():
        progress = tqdm(test_loader, desc="Evaluating", leave=False)
        for images, masks in progress:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)

            # Score the raw model outputs against the ground-truth masks.
            per_batch_scores['dice'].append(dice_coefficient(outputs, masks))
            per_batch_scores['iou'].append(iou_coefficient(outputs, masks))
            per_batch_scores['accuracy'].append(pixel_accuracy(outputs, masks))

    return {
        metric: round(np.mean(values) * 100, 2)
        for metric, values in per_batch_scores.items()
    }

# Test-set metrics for every model, keyed by model name.
all_eval_results = {}

for name, model in models.items():
    print(f"Evaluating {name}...")
    all_eval_results[name] = metrics = evaluate_model(model, test_loader, device)
    report = f"{name} - Dice: {metrics['dice']}%, IoU: {metrics['iou']}%, Acc: {metrics['accuracy']}%"
    print(report)

Evaluating unet_standard...


                                                                                                                

unet_standard - Dice: 87.98%, IoU: 78.77%, Acc: 95.27%
Evaluating unet_with_depthwise...


                                                                                                                

unet_with_depthwise - Dice: 88.51%, IoU: 79.54%, Acc: 95.27%
Evaluating unet_with_se_depthwise...


                                                                                                                

unet_with_se_depthwise - Dice: 89.2%, IoU: 80.65%, Acc: 95.56%
Evaluating unet_with_se_depthwise_reduced...


                                                                                                                

unet_with_se_depthwise_reduced - Dice: 89.12%, IoU: 80.52%, Acc: 95.47%
Evaluating unet_with_spatial_attn...


                                                                                                                

unet_with_spatial_attn - Dice: 88.91%, IoU: 80.19%, Acc: 95.42%


