In [8]:
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from os.path import join
from PIL import Image

from data_load import Dataset
from model.LN_UXFormer import LN_UXFormer


def dice_coefficient(pred_mask, true_mask, smooth=1e-6):
    pred_flat = pred_mask.reshape(-1)
    target_flat = true_mask.reshape(-1)
    
    intersection = torch.sum(pred_flat * target_flat)
    return (2.0 * intersection + smooth) / (torch.sum(pred_flat) + torch.sum(target_flat) + smooth)


def calculate_binary_dice_score(pred, target, threshold=0.5):
    pred_binary = (pred > threshold).float()
    
    target = target.contiguous()
    pred_binary = pred_binary.contiguous()
    
    dice = dice_coefficient(pred_binary, target).item()
    
    return {'Foreground': dice}


def save_segmentation_results(images, targets, outputs, batch_idx, batch_size, output_dir, dataset, threshold=0.5):
    os.makedirs(os.path.join(output_dir, 'input'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'target'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'prediction'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'overlay_pred'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'overlay_target'), exist_ok=True)
    
    pred_sigmoid = torch.sigmoid(outputs)
    pred_binary = (pred_sigmoid > threshold).float().cpu().numpy()
    
    for i in range(images.size(0)):
        idx = batch_idx * batch_size + i
        
        if idx >= len(dataset):
            break
        
        input_img = images[i, 0].cpu().numpy()
        input_img = (input_img * 255).astype(np.uint8)
        
        target_mask = targets[i, 0].cpu().numpy()
        pred_mask = pred_binary[i, 0]
        
        colors = {
            'background': [0, 0, 0],
            'foreground': [255, 255, 255],
            'target_overlay': [255, 0, 0],
            'pred_overlay': [0, 255, 0]
        }
        
        target_rgb = np.zeros((target_mask.shape[0], target_mask.shape[1], 3), dtype=np.uint8)
        pred_rgb = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
        
        target_rgb[target_mask > 0] = colors['foreground']
        pred_rgb[pred_mask > 0] = colors['foreground']
        
        input_rgb = np.stack([input_img] * 3, axis=2)
        
        overlay_pred = input_rgb.copy()
        pred_overlay_mask = pred_mask > 0
        overlay_pred[pred_overlay_mask] = (overlay_pred[pred_overlay_mask] * 0.3 + 
                                          np.array(colors['pred_overlay']) * 0.7).astype(np.uint8)
        
        overlay_target = input_rgb.copy()
        target_overlay_mask = target_mask > 0
        overlay_target[target_overlay_mask] = (overlay_target[target_overlay_mask] * 0.3 + 
                                              np.array(colors['target_overlay']) * 0.7).astype(np.uint8)
        
        filename = f"image_{idx:04d}"
        
        Image.fromarray(input_img).save(os.path.join(output_dir, 'input', f"{filename}.png"))
        Image.fromarray(target_rgb).save(os.path.join(output_dir, 'target', f"{filename}.png"))
        Image.fromarray(pred_rgb).save(os.path.join(output_dir, 'prediction', f"{filename}.png"))
        Image.fromarray(overlay_pred).save(os.path.join(output_dir, 'overlay_pred', f"{filename}.png"))
        Image.fromarray(overlay_target).save(os.path.join(output_dir, 'overlay_target', f"{filename}.png"))


def evaluate_test_set(model, test_loader, device, save_segmentation=False, output_dir=None, dataset=None, threshold=0.5):
    model.eval()
    dice_scores = []
    
    if save_segmentation and output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
    
    print(f"Starting evaluation... Total {len(test_loader)} batches")
    
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(test_loader)):
            images, targets = batch_data
            
            images = images.to(device)
            targets = targets.to(device)
            
            try:
                outputs = model(images)
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                
                if batch_idx == 0:
                    print(f"Model output shape: {outputs.shape}")
                    print(f"Target shape: {targets.shape}")
                
                dice_score = calculate_binary_dice_score(outputs, targets, threshold)
                dice_scores.append(dice_score['Foreground'])
                
                if save_segmentation and output_dir is not None:
                    save_segmentation_results(images, targets, outputs, batch_idx, test_loader.batch_size, output_dir, dataset, threshold)
                
            except Exception as e:
                print(f"Error processing batch {batch_idx}: {e}")
                import traceback
                traceback.print_exc()
                if batch_idx == 0:
                    raise
    
    mean_dice = np.mean(dice_scores)
    
    print("\n" + "="*50)
    print(f"Foreground Dice Score: {mean_dice:.4f}")
    print("="*50)
    
    if output_dir is not None:
        with open(os.path.join(output_dir, 'evaluation_results.txt'), 'w') as f:
            f.write(f"Foreground Dice Score: {mean_dice:.4f}\n")
            f.write("\nSaved folders:\n")
            f.write("- input: Original input images\n")
            f.write("- target: Ground truth segmentation masks\n")
            f.write("- prediction: Predicted segmentation masks\n")
            f.write("- overlay_pred: Input + Prediction overlay (green)\n")
            f.write("- overlay_target: Input + Ground truth overlay (red)\n")
    
    return mean_dice


def main():
    test_dir = '/data3/glocal_project/multiclass_seg_kid_ney/data/covid_test'
    model_path = '/data3/glocal_project/LN_UXFormer/result_bestmodel/LN_UXFormer_epoch.pt'
    gpu_id = 3
    batch_size = 16
    save_segmentation = True
    output_dir = '/data3/glocal_project/LN_UXFormer/result_image'
    threshold = 0.5
    
    if gpu_id >= 0 and torch.cuda.is_available():
        device = torch.device(f"cuda:{gpu_id}")
        print(f"Using GPU {gpu_id}")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    
    try:
        model = LN_UXFormer(n_channels=1, n_classes=1).to(device)
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint, strict=False)
        print(f"Model loaded: {model_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return
    
    try:
        test_dataset = Dataset(test_dir)
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True if device.type == 'cuda' else False
        )
        print(f"Test dataset loaded: {len(test_dataset)} images")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        import traceback
        traceback.print_exc()
        return
    
    try:
        mean_dice = evaluate_test_set(
            model, test_loader, device,
            save_segmentation, output_dir, test_dataset,
            threshold
        )
        
        if save_segmentation:
            print(f"\nSegmentation results saved: {output_dir}")
            print("\nFolder structure:")
            print("├── input/           : Original input images")
            print("├── target/          : Ground truth segmentation masks")
            print("├── prediction/      : Predicted segmentation masks")
            print("├── overlay_pred/    : Input + Prediction overlay (green)")
            print("└── overlay_target/  : Input + Ground truth overlay (red)")
            
    except Exception as e:
        print(f"Error during evaluation: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()

Using GPU 3
Model loaded: /data3/glocal_project/LN_UXFormer/result_bestmodel/LN_UXFormer_epoch.pt
Found 231 valid image pairs
Skipped 0 pairs with empty masks
Test dataset loaded: 231 images
Starting evaluation... Total 15 batches


  0%|          | 0/15 [00:00<?, ?it/s]

Model output shape: torch.Size([16, 1, 224, 224])
Target shape: torch.Size([16, 1, 224, 224])


100%|██████████| 15/15 [00:08<00:00,  1.68it/s]


Foreground Dice Score: 0.6227

Segmentation results saved: /data3/glocal_project/LN_UXFormer/result_image

Folder structure:
├── input/           : Original input images
├── target/          : Ground truth segmentation masks
├── prediction/      : Predicted segmentation masks
├── overlay_pred/    : Input + Prediction overlay (green)
└── overlay_target/  : Input + Ground truth overlay (red)



