In [1]:
import os
from PIL import Image
import matplotlib.pyplot as plt

In [6]:
# === Configuration ===
# Project root defaults to the original absolute location; set the
# TUMOR_DETECTION_ROOT environment variable to run this notebook elsewhere.
PROJECT_ROOT = os.environ.get("TUMOR_DETECTION_ROOT", "/home/ubuntu/project/Tumor_Detection")
PRED_DIR = os.path.join(PROJECT_ROOT, "data", "pred_mask_old")
VAL_DIR = os.path.join(PROJECT_ROOT, "data", "Validation")
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "visualizations")

MAX_WSIS = 6               # number of slides to visualize (first N IDs in sorted order)
PRED_SUFFIX = "_pred.png"  # filename suffix of prediction masks in PRED_DIR

os.makedirs(OUTPUT_DIR, exist_ok=True)

# === Identify WSI IDs ===
# Strip the exact suffix instead of splitting on "_": IDs that themselves contain
# underscores would otherwise be truncated and no longer match f"{wsi_id}_pred.png".
wsi_ids = sorted(
    f[: -len(PRED_SUFFIX)] for f in os.listdir(PRED_DIR) if f.endswith(PRED_SUFFIX)
)[:MAX_WSIS]

# Column titles, in panel order
col_titles = ["WSI", "GT Mask", "Pred Mask", "Overlay"]


def _load_image(path):
    """Open an image file and return an in-memory copy.

    The context manager releases the underlying file handle immediately;
    ``copy()`` forces the pixel data to be loaded first, so the returned image
    stays usable after the file is closed.
    """
    with Image.open(path) as img:
        return img.copy()


# === Generate 1 plot per WSI ===
for wsi_id in wsi_ids:
    # Paths to all required images (raises FileNotFoundError if one is missing)
    wsi_path = os.path.join(VAL_DIR, f"{wsi_id}.png")
    gt_mask_path = os.path.join(VAL_DIR, f"{wsi_id}_mask.png")
    pred_mask_path = os.path.join(PRED_DIR, f"{wsi_id}_pred.png")
    overlay_path = os.path.join(PRED_DIR, f"{wsi_id}_overlap.png")

    # Load all images without leaving file handles open
    wsi_img = _load_image(wsi_path)
    gt_mask_img = _load_image(gt_mask_path)
    pred_mask_img = _load_image(pred_mask_path)
    overlay_img = _load_image(overlay_path)
    images = [wsi_img, gt_mask_img, pred_mask_img, overlay_img]

    # One row of panels: slide, ground truth, prediction, overlay
    fig, axes = plt.subplots(1, len(col_titles), figsize=(28, 8))
    for ax, img, title in zip(axes, images, col_titles):
        ax.imshow(img)
        ax.set_title(title, fontsize=16)
        ax.axis("off")

    # Save figure
    save_path = os.path.join(OUTPUT_DIR, f"{wsi_id}_viz.png")
    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    # Close this specific figure (not just the "current" one) so a long loop
    # doesn't accumulate open figures in memory.
    plt.close(fig)
    