<a href="https://colab.research.google.com/github/profsuccodifrutta/patch_core_brain_mri/blob/main/patch_core_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
import zipfile
import os

# Mount Google Drive so the dataset zip and the saved checkpoints under
# MyDrive are reachable from this Colab runtime.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


In [None]:
zip_path = '/content/drive/MyDrive/brainmri.zip'
extract_path = '/content/brain_dataset'  # local (non-Drive) Colab folder

# Extract only once per runtime: if the target folder already exists the
# archive is assumed to be unpacked and nothing is done.
already_extracted = os.path.exists(extract_path)
if not already_extracted:
    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall(extract_path)
    print("✅ Estrazione completata!")

In [None]:
# Install the notebook's dependencies (IPython "!" runs these as shell commands).
# NOTE(review): versions are unpinned; the prediction-batch handling below tries
# both dict and attribute access, which suggests the anomalib API differs across
# releases — consider pinning the version this notebook was validated against.
!pip install anomalib
!pip install faiss-cpu

In [None]:
import torch
from pathlib import Path
from anomalib.data import Folder
from anomalib.models import Patchcore
from anomalib.engine import Engine


DATA_PATH = Path('/content/brain_dataset')

# Sub-folder (under both "Training" and "Testing") that holds the healthy scans.
NORMAL_CLASS = "notumor"

# anomalib's Folder collects images from each directory recursively, so passing
# the whole "Testing" folder as abnormal_dir would also pick up the healthy
# Testing/notumor scans and label them anomalous (they would then sit in both
# the normal and the abnormal test set). List the tumour class folders
# explicitly instead: every sub-folder of Testing except the normal one.
abnormal_test_dirs = sorted(
    f"Testing/{class_dir.name}"
    for class_dir in (DATA_PATH / "Testing").iterdir()
    if class_dir.is_dir() and class_dir.name != NORMAL_CLASS
)
if not abnormal_test_dirs:
    raise RuntimeError(
        f"No abnormal class folders found in {DATA_PATH / 'Testing'} "
        f"(expected sub-folders other than '{NORMAL_CLASS}')."
    )
print(f"Abnormal test folders: {abnormal_test_dirs}")

datamodule = Folder(
    name="brain_mri",
    root=DATA_PATH,
    normal_dir=f"Training/{NORMAL_CLASS}",      # healthy scans used for training
    normal_test_dir=f"Testing/{NORMAL_CLASS}",  # healthy scans used for testing
    abnormal_dir=abnormal_test_dirs,            # tumour scans (test only)
    train_batch_size=8,
    eval_batch_size=8,
    num_workers=2
)
datamodule.setup()

print("✅ Configurazione completata. Pronto per l'estrazione delle feature.")

In [None]:
# PatchCore model: patch features are taken from the backbone layers listed
# below; coreset_sampling_ratio controls how much of the training feature set
# is kept in the memory bank.
patchcore_settings = {
    "backbone": "resnet18",  # heavier alternative: "wide_resnet50_2"
    "layers": ["layer2", "layer3"],
    "coreset_sampling_ratio": 0.01,
}
model = Patchcore(**patchcore_settings)

engine = Engine(
    accelerator="auto",
    devices=1,
    max_epochs=1,
    enable_progress_bar=False,
)

# Build the memory bank from the healthy training scans.
print(" Avvio estrazione feature e creazione Memory Bank...")
engine.fit(model=model, datamodule=datamodule)
print(" Memory Bank creata con successo!")

In [None]:
# Run the test split through the model, comparing each scan against the
# healthy memory bank; `results` holds whatever engine.test returns.
print(" Valutazione sulle anomalie...")
results = engine.test(datamodule=datamodule, model=model)

print("\n--- RISULTATI FINALI ---", results, sep="\n")

In [None]:
# Save the fitted model to Google Drive so it can be reloaded later for
# inference without rebuilding the memory bank.
from anomalib.engine import Engine
path_salvataggio = "/content/drive/MyDrive/patchcore_modelresnet18.ckpt"
trainer = engine.trainer
trainer.save_checkpoint(path_salvataggio)
print(f"✅ Modello salvato correttamente in: {path_salvataggio}")

In [None]:
from anomalib.models import Patchcore
from anomalib.engine import Engine
from anomalib.data import Folder

# Inference from the saved checkpoint, without rebuilding the memory bank.
path_checkpoint = "/content/drive/MyDrive/patchcore_modelresnet18.ckpt"

# Re-create the model skeleton; backbone and layers must match the ones used
# when the checkpoint was saved so the stored weights fit.
model = Patchcore(layers=["layer2", "layer3"], backbone="resnet18")

engine = Engine(enable_progress_bar=False, devices=1)

# NOTE(review): reuses `datamodule` from the data-setup cell — run that cell first.
predictions = engine.predict(
    model=model,
    datamodule=datamodule,
    ckpt_path=path_checkpoint,
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import random


def _get_field(batch, *names):
    """Return the first non-None field among *names* from a prediction batch.

    Batches returned by ``engine.predict`` are handled both as dicts and as
    objects exposing the fields as attributes. Returns ``None`` when none of
    the requested fields is available.
    """
    for name in names:
        if isinstance(batch, dict):
            value = batch.get(name)
        else:
            value = getattr(batch, name, None)
        if value is not None:
            return value
    return None


def _to_display_image(image):
    """Turn a channel-first image tensor into an (H, W, C) array min-max scaled to [0, 1]."""
    img = image.cpu().permute(1, 2, 0).numpy()
    lo, hi = img.min(), img.max()
    if hi == lo:
        # A constant image would otherwise give a 0/0 division (all NaNs).
        return np.zeros_like(img)
    return (img - lo) / (hi - lo)


all_samples = []

for batch in predictions:
    imgs = _get_field(batch, "image")
    masks = _get_field(batch, "pred_mask")
    amaps = _get_field(batch, "anomaly_map")
    pred_lbls = _get_field(batch, "pred_label")
    lbls = _get_field(batch, "gt_label", "label")

    # Without a ground-truth label a sample cannot be judged right/wrong.
    if lbls is None:
        continue

    for k in range(imgs.shape[0]):
        all_samples.append({
            "image": imgs[k],
            "mask": None if masks is None else masks[k],
            "amap": amaps[k],
            "label": lbls[k],
            "pred_label": None if pred_lbls is None else pred_lbls[k],
        })


num_to_show = min(3, len(all_samples))
selected_samples = random.sample(all_samples, num_to_show)

print(f" Mostro {num_to_show} immagini casuali...")


for sample in selected_samples:

    is_sick_truth = (sample["label"].item() == 1)

    amap = sample["amap"].cpu().numpy().squeeze()
    if sample["mask"] is not None:
        mask_val = sample["mask"].cpu().numpy().squeeze()
    else:
        # No predicted mask in the batch: show an empty mask instead of crashing.
        mask_val = np.zeros_like(amap)

    # Use the model's image-level decision when available; a pixel-level mask
    # (mask.max() > 0) is a different decision and can disagree with it.
    if sample["pred_label"] is not None:
        is_sick_pred = bool(sample["pred_label"])
    else:
        is_sick_pred = bool(mask_val.max() > 0)

    txt_truth = "Non-Healthy" if is_sick_truth else "Healthy"
    txt_pred = "Non-Healthy" if is_sick_pred else "Healthy"

    colore = "green" if (is_sick_truth == is_sick_pred) else "red"
    titolo = f"Truth: {txt_truth}  |  Prediction: {txt_pred}"

    fig, axes = plt.subplots(1, 3, figsize=(10, 3))
    fig.suptitle(titolo, color=colore, fontweight='bold', fontsize=11, y=0.98)

    img = _to_display_image(sample["image"])

    axes[0].imshow(img)
    axes[0].axis('off')
    axes[0].set_title("Original", fontsize=9)

    # Anomaly map overlaid on the input image.
    axes[1].imshow(img, cmap='gray')
    axes[1].imshow(amap, cmap='jet', alpha=0.5)
    axes[1].axis('off')
    axes[1].set_title("Heatmap", fontsize=9)

    axes[2].imshow(mask_val, cmap='gray')
    axes[2].axis('off')
    axes[2].set_title("Mask", fontsize=9)

    plt.tight_layout()
    plt.show()