# Prostate WSI Segmentation - Inference Demo

This notebook shows how to run the trained segmentation model on new whole-slide images (WSIs).


In [None]:
# Setup
import sys
# Make the project's ../src directory importable (the path is relative to the
# notebook's current working directory) so the local modules below resolve.
sys.path.append('../src')

# Project-local modules, found via the path appended above.
# NOTE(review): generate_test_predictions, visualize_predictions and plt are not
# used in the cells visible here — presumably used in later cells; confirm.
from config import InferenceConfig
from inference import predict_wsi, generate_test_predictions
from utils import visualize_predictions

import torch
import matplotlib.pyplot as plt

In [None]:
# Load the inference configuration and choose where the model will run:
# the first CUDA GPU when one is visible to PyTorch, otherwise the CPU.
config = InferenceConfig()
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f" Using device: {device}")

In [None]:
# Example: Predict single WSI
from model import SegmentationModel

# Build the architecture. These hyperparameters must match the ones the
# checkpoint was trained with, otherwise load_state_dict will reject it.
# pretrained=False: the weights are restored from the checkpoint below.
model = SegmentationModel(num_classes=4, encoder_name="resnet34", pretrained=False)

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoints from a trusted source.
checkpoint = torch.load(config.BEST_MODEL_PATH, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
# Switch to inference mode so dropout / batch-norm layers (if any) use their
# evaluation behaviour instead of training behaviour.
model.eval()

# The IoU metric is informational only; don't crash if the checkpoint lacks it.
best_iou = checkpoint.get('best_wsi_iou')
if best_iou is not None:
    print(f" Model loaded! Best IoU: {best_iou:.4f}")
else:
    print(" Model loaded! (checkpoint has no 'best_wsi_iou' entry)")

In [None]:
# Predict on test WSI
import glob

import numpy as np  # required for np.unique below

# glob returns files in arbitrary, filesystem-dependent order; sort so the
# "first" test image picked below is the same on every run.
test_images = sorted(glob.glob(str(config.TEST_DIR / "*.png")))
if test_images:
    sample_wsi = test_images[0]
    print(f" Predicting: {sample_wsi}")

    # NOTE(review): predict_wsi presumably returns a per-pixel label map
    # (array-like with .shape); verify against its implementation.
    prediction = predict_wsi(model, sample_wsi, device=device)

    print(f" Prediction shape: {prediction.shape}")
    print(f" Classes found: {np.unique(prediction)}")
else:
    print(" No test images found")