In [None]:
import os
import sys

import matplotlib.pyplot as plt
import torch

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../src")))

from src.data_modules.segmentation_data_module import (
    CbisDdsmDataModuleSegmentation as SegmentationDataModule,
)
from src.models.segmentation_model import SegmentationModel

In [None]:
# CBIS-DDSM segmentation data restricted to the "mass" tumor type.
# batch_size=1 so every batch holds exactly one sample for plotting.
# NOTE(review): if this is a Lightning-style data module, `setup()` may need to
# be called before the dataloaders are requested — confirm against the class.
datamodule = SegmentationDataModule(
    root_dir="../data/cbis-ddsm-segme",
    tumor_type="mass",
    batch_size=1,
    num_workers=4,
)

In [None]:
# Sanity-check the training pipeline by plotting the first image/mask pair.
# `next(iter(...))` fetches a single batch instead of looping and breaking.
batch = next(iter(datamodule.train_dataloader()))
x, y = batch[0][0], batch[1][0]  # first sample of the image and mask tensors

image = x.squeeze().numpy()
mask = y.squeeze().numpy()

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(image, cmap="gray")
axes[0].set_title("Image")
# NOTE(review): mask[1] assumes the squeezed mask keeps a leading class/channel
# axis and that index 1 is the foreground class — confirm with the data module.
axes[1].imshow(mask[1], cmap="gray")
axes[1].set_title("Ground Truth Mask")

plt.show()

In [None]:
# Build the segmentation model from the trained mass-segmentation checkpoint.
model = SegmentationModel(
    weight_path="../models/mass-segmentation.ckpt",
)

In [None]:
# Compare the model's prediction with the ground truth on one validation sample.
images, masks = next(iter(datamodule.val_dataloader()))
x, y = images[0], masks[0]  # first sample of the batch

# Prepare image and mask for visualization
image = x.squeeze().numpy()
mask = y.squeeze().numpy()

# Switch to inference behaviour once, before running the model.
model.eval()
with torch.no_grad():
    # unsqueeze(0) restores the batch dimension removed by indexing above.
    output = model(x.unsqueeze(0).to(model.device))
    # Per-pixel class index; assumes classes lie along dim=1 — TODO confirm.
    prediction = torch.argmax(output, dim=1)
    # .cpu() is required before .numpy() when the model runs on a GPU.
    prediction = prediction.squeeze().cpu().numpy()

# Plot results
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(image, cmap="gray")
axes[0].set_title("Image")
# NOTE(review): mask[1] assumes index 1 of the mask's leading axis is the
# foreground class — confirm with the data module.
axes[1].imshow(mask[1], cmap="gray")
axes[1].set_title("Ground Truth Mask")
axes[2].imshow(prediction, cmap="gray")
axes[2].set_title("Model Prediction")
plt.show()