<a href="https://colab.research.google.com/github/profsuccodifrutta/Variational_Autoencoder_for_Anomaly_Detection/blob/main/fifth_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
import os
# Mount Google Drive under /content/drive using the google.colab helper.
drive.mount('/content/drive')

In [None]:
# Create a local folder on Colab for the data (fast)
!mkdir -p /content/dataset_local

# Path of the zip archive to extract.
path_zip = "/content/drive/MyDrive/brainmri.zip"

# Extract the archive into the local folder (-o overwrite, -q quiet).
!unzip -o -q "{path_zip}" -d /content/dataset_local

print("Scompattamento completato!")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import glob
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Recursively collect every .jpg inside a Training/notumor folder of the
# locally extracted dataset; these are the healthy images.
file_sani = glob.glob("/content/dataset_local/**/Training/notumor/*.jpg", recursive=True)
print(f"--- ANALISI DATASET ---")
print(f"Totale immagini sane trovate: {len(file_sani)}")

class BrainDataset(Dataset):
    """Dataset that lazily loads grayscale images from a list of file paths.

    Args:
        file_list: Sequence of image file paths; one sample per path.
        transform: Optional callable applied to the loaded PIL image before
            it is returned.
    """

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` as single-channel grayscale.

        Returns:
            The output of ``transform`` when one was given, otherwise the
            PIL image in mode ``'L'``.
        """
        img_path = self.file_list[idx]
        # Image.open is lazy and keeps the underlying file open. convert()
        # reads the pixels and returns an independent image, so the source
        # handle can be closed right away instead of waiting for garbage
        # collection.
        with Image.open(img_path) as source:
            image = source.convert('L')  # grayscale
        if self.transform is not None:
            image = self.transform(image)
        return image

# Split the healthy images: 70% train, 20% validation, 10% test.
if not file_sani:
    print("ERRORE: Dataset non trovato.")
else:
    full_healthy_ds = BrainDataset(file_sani)

    # Train and validation sizes are truncated; the test split takes the
    # remainder so the three sizes always add up to the dataset length.
    train_size = int(0.7 * len(full_healthy_ds))
    val_size = int(0.2 * len(full_healthy_ds))
    test_size = len(full_healthy_ds) - (train_size + val_size)

    # A fixed generator seed keeps the split identical from run to run.
    train_subset, val_subset, test_healthy_subset = random_split(
        full_healthy_ds,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )

    # Anomalies for the final test: every Testing/<class>/*.jpg whose path
    # does not contain "notumor".
    file_anomalie = [
        path
        for path in glob.glob("/content/dataset_local/**/Testing/*/*.jpg", recursive=True)
        if "notumor" not in path
    ]

    print(f"Training: {len(train_subset)} | Val: {len(val_subset)} | Test Sani: {len(test_healthy_subset)}")
    print(f"Anomalie trovate per test: {len(file_anomalie)}")

In [None]:
# 4. TRANSFORM CONFIGURATION
# No mean/std normalisation: ToTensor already maps pixels from [0, 255] to
# [0.0, 1.0], the range used together with the decoder's Sigmoid output.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),           # resize straight to 224x224
    transforms.RandomHorizontalFlip(p=0.5),  # augmentation for training only
    transforms.ToTensor(),
])

base_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

batch_size = 32


def _paths_of(subset):
    """Return the file paths selected by a ``random_split`` subset, in subset order."""
    return [subset.dataset.file_list[i] for i in subset.indices]


# Each split gets its own BrainDataset so it can carry its own transform.
train_loader = DataLoader(
    BrainDataset(_paths_of(train_subset), transform=train_transform),
    batch_size=batch_size,
    shuffle=True,
)

val_loader = DataLoader(
    BrainDataset(_paths_of(val_subset), transform=base_transform),
    batch_size=batch_size,
    shuffle=False,
)

# Test loaders: held-out healthy images and the anomalous images.
test_loader_sani = DataLoader(
    BrainDataset(_paths_of(test_healthy_subset), transform=base_transform),
    batch_size=batch_size,
    shuffle=False,
)

anno_loader = DataLoader(
    BrainDataset(file_anomalie, transform=base_transform),
    batch_size=batch_size,
    shuffle=False,
)

print(f" Configurazione completata. Batch size: {batch_size}")

In [None]:
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """Residual block: conv-BN-LeakyReLU-conv-BN plus an identity skip.

    Both convolutions are 3x3 with padding 1 and keep ``in_channels``
    channels, so the output has the same shape as the input.
    """

    def __init__(self, in_channels):
        super().__init__()

        def conv_bn():
            # Bias-free 3x3 convolution followed by batch normalisation.
            return (
                nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(in_channels),
            )

        self.conv = nn.Sequential(
            *conv_bn(),
            nn.LeakyReLU(0.2, inplace=True),
            *conv_bn(),
        )
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Return ``LeakyReLU(x + conv(x))``."""
        summed = x + self.conv(x)
        return self.relu(summed)

class VAE_V5(nn.Module):
    """Convolutional VAE with residual blocks and a 256x7x7 bottleneck.

    The encoder halves the spatial size five times (224 -> 7 for 224x224
    single-channel inputs) and two linear heads produce ``mu`` and ``logvar``
    of size ``latent_dim``. The decoder mirrors the encoder with transposed
    convolutions and ends with a Sigmoid, so outputs lie in [0, 1].

    NOTE(review): ``reparameterize`` always samples noise, in eval mode too,
    so reconstructions are stochastic at inference time.
    """

    def __init__(self, latent_dim=1024):
        super().__init__()

        # ENCODER: 224x224 -> 112x112 -> 56x56 -> 28x28 -> 14x14 -> 7x7
        self.encoder = nn.Sequential(
            *self._down(1, 32),      # 112
            *self._down(32, 64),     # 56
            ResBlock(64),
            *self._down(64, 128),    # 28
            ResBlock(128),
            *self._down(128, 256),   # 14
            *self._down(256, 256),   # 7
        )

        # 256 * 7 * 7 = 12,544 features feed the two latent heads.
        self.flatten_dim = 256 * 7 * 7
        self.fc_mu = nn.Linear(self.flatten_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.flatten_dim, latent_dim)

        # DECODER
        self.decoder_input = nn.Linear(latent_dim, self.flatten_dim)
        self.unflatten = nn.Unflatten(1, (256, 7, 7))

        self.decoder = nn.Sequential(
            *self._up(256, 256),     # 14
            *self._up(256, 128),     # 28
            ResBlock(128),
            *self._up(128, 64),      # 56
            ResBlock(64),
            *self._up(64, 32),       # 112
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),  # 224
            nn.Sigmoid(),            # output in [0, 1] to match the input range
        )

    @staticmethod
    def _down(in_ch, out_ch):
        """Return the layers of one stride-2 encoder stage (conv, BN, LeakyReLU)."""
        return (
            nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2),
        )

    @staticmethod
    def _up(in_ch, out_ch):
        """Return the layers of one stride-2 decoder stage (transposed conv, BN, LeakyReLU)."""
        return (
            nn.ConvTranspose2d(in_ch, out_ch, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2),
        )

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(x_recon, mu, logvar)``.
        """
        features = self.encoder(x).flatten(start_dim=1)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        z = self.reparameterize(mu, logvar)
        decoded = self.unflatten(self.decoder_input(z))
        return self.decoder(decoded), mu, logvar

In [None]:
import torch
import torch.nn.functional as F

def ssim_loss(img1, img2, window_size=11):
    """Return ``1 - mean(SSIM)`` between two image batches.

    Local means, variances and covariance are estimated with a uniform
    window: average pooling with stride 1 and padding ``window_size // 2``.

    Args:
        img1: First image batch.
        img2: Second image batch, same shape as ``img1``.
        window_size: Side length of the pooling window.

    Returns:
        Scalar tensor; 0 when the two batches are identical.
    """
    half_window = window_size // 2

    def local_mean(t):
        return F.avg_pool2d(t, window_size, stride=1, padding=half_window)

    mean_a = local_mean(img1)
    mean_b = local_mean(img2)

    # Var[X] = E[X^2] - E[X]^2 and Cov[X, Y] = E[XY] - E[X]E[Y].
    var_a = local_mean(img1 * img1) - mean_a.pow(2)
    var_b = local_mean(img2 * img2) - mean_b.pow(2)
    cov_ab = local_mean(img1 * img2) - (mean_a * mean_b)

    # Stability constants.
    c1, c2 = 0.01**2, 0.03**2

    numerator = (2 * mean_a * mean_b + c1) * (2 * cov_ab + c2)
    denominator = (mean_a.pow(2) + mean_b.pow(2) + c1) * (var_a + var_b + c2)
    ssim_map = numerator / denominator

    return 1 - ssim_map.mean()

def loss_function(recon_x, x, mu, logvar, beta=0.01, ssim_weight=1000):
    """Combine BCE, weighted SSIM and beta-weighted KL into the training loss.

    Args:
        recon_x: Reconstruction with values in [0, 1] (required by BCE).
        x: Target batch with values in [0, 1]; ``x.shape[0]`` is used as the
            batch size for normalisation.
        mu: Latent means.
        logvar: Latent log-variances.
        beta: Weight of the KL term.
        ssim_weight: Multiplier applied to the SSIM loss so it is on a scale
            comparable to the summed BCE. Defaults to 1000, the value that
            was previously hard-coded.

    Returns:
        Tuple ``(total_loss, bce, kl)``. ``bce`` and ``kl`` are unweighted;
        the SSIM term is included in ``total_loss`` only.
    """
    n_samples = x.shape[0]

    # BCE (pixel intensity): summed over all pixels, averaged over the batch
    # to keep the values stable.
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum') / n_samples

    # SSIM (structural integrity), rescaled to balance it against the BCE.
    ssim = ssim_loss(recon_x, x) * ssim_weight

    # KL divergence between N(mu, exp(logvar)) and N(0, I), averaged over the batch.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / n_samples

    total_loss = bce + ssim + (beta * kl)

    return total_loss, bce, kl

In [None]:
import torch.optim as optim

latent_dim = 1024
model = VAE_V5(latent_dim=latent_dim).to(device)

# Conservative learning rate.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Scheduler: multiplies the LR by 0.5 when the validation loss stops
# improving (patience=5).
# `verbose` is intentionally not passed: it is deprecated in PyTorch >= 2.2
# and may be rejected by newer releases. Use scheduler.get_last_lr() to
# inspect the current learning rate.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)


In [None]:
def plot_reconstruction(model, val_loader, device, epoch):
    """Plot the first validation image next to its reconstruction.

    Switches ``model`` to eval mode (and leaves it there), takes the first
    batch from ``val_loader`` and shows sample 0 and its reconstruction side
    by side.

    Args:
        model: Network returning ``(reconstruction, mu, logvar)``.
        val_loader: Loader whose batches are image tensors.
        device: Device the batch is moved to before the forward pass.
        epoch: Zero-based epoch index; shown as ``epoch + 1`` in the title.
    """
    model.eval()
    with torch.no_grad():
        # Only the first validation batch is used.
        first_batch = next(iter(val_loader)).to(device)
        reconstructed, _mu, _logvar = model(first_batch)

        # Move to the CPU and drop singleton dimensions for imshow.
        original_np = first_batch[0].cpu().squeeze().numpy()
        recon_np = reconstructed[0].cpu().squeeze().numpy()

        panels = (
            (original_np, "Originale"),
            (recon_np, f"Ricostruzione Epoca {epoch+1}"),
        )

        plt.figure(figsize=(10, 5))
        for position, (pixels, title) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.imshow(pixels, cmap='gray')
            plt.title(title)
            plt.axis('off')

        plt.show()

In [None]:
import numpy as np

# training loop

def train_vae(model, train_loader, val_loader, optimizer, scheduler, num_epochs,
              save_path='/content/drive/MyDrive/VAE_Brain_Project/best_model_v5.pth'):
    """Train the VAE with KL warm-up, validation, LR scheduling and checkpointing.

    Besides its arguments, this function reads the module-level names
    ``device``, ``target_beta`` and ``warmup_epochs`` and calls the
    module-level ``loss_function`` and ``plot_reconstruction``; they must be
    defined before the call.

    Args:
        model: Network whose forward pass returns ``(reconstruction, mu, logvar)``.
        train_loader: Loader of training batches.
        val_loader: Loader of validation batches.
        optimizer: Optimizer updating the model parameters.
        scheduler: Scheduler stepped once per epoch with the mean validation loss.
        num_epochs: Number of epochs to run.
        save_path: File the best state dict (lowest mean validation loss) is
            written to. The default matches the path the model-loading cell
            reads.

    Returns:
        Tuple ``(train_history, val_history)``. Each is a dict with keys
        ``'total'``, ``'recon'`` and ``'kl'`` holding one batch-averaged
        value per epoch. ``'recon'`` and ``'kl'`` are the second and third
        values returned by ``loss_function``.
    """
    train_history = {'total': [], 'recon': [], 'kl': []}
    val_history = {'total': [], 'recon': [], 'kl': []}

    best_val_loss = float('inf')

    # Create the checkpoint folder up front so the first save cannot fail
    # after a full epoch of training.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()

        # Linear KL warm-up: beta grows from 0 to target_beta over
        # warmup_epochs epochs. Without a warm-up the target is used directly.
        if warmup_epochs > 0:
            current_beta = min(target_beta, target_beta * (epoch / warmup_epochs))
        else:
            current_beta = target_beta

        train_total, train_recon, train_kl = 0, 0, 0

        for batch in train_loader:
            batch = batch.to(device)

            optimizer.zero_grad()
            recon_batch, mu, logvar = model(batch)

            loss, r_loss, k_loss = loss_function(recon_batch, batch, mu, logvar, beta=current_beta)

            loss.backward()
            # Clip the global gradient norm before the optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_total += loss.item()
            train_recon += r_loss.item()
            train_kl += k_loss.item()

        # VALIDATION
        # NOTE(review): the validation loss uses the current (annealed) beta,
        # so values from warm-up epochs are not strictly comparable with
        # later ones when picking the best checkpoint.
        model.eval()
        val_total, val_recon, val_kl = 0, 0, 0

        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                recon_batch, mu, logvar = model(batch)
                loss, r_loss, k_loss = loss_function(recon_batch, batch, mu, logvar, beta=current_beta)

                val_total += loss.item()
                val_recon += r_loss.item()
                val_kl += k_loss.item()

        # Per-epoch averages over the number of batches.
        n_train_batches = len(train_loader)
        n_val_batches = len(val_loader)
        avg_train_loss = train_total / n_train_batches
        avg_val_loss = val_total / n_val_batches

        scheduler.step(avg_val_loss)

        # Record every tracked component, not only the total.
        train_history['total'].append(avg_train_loss)
        train_history['recon'].append(train_recon / n_train_batches)
        train_history['kl'].append(train_kl / n_train_batches)
        val_history['total'].append(avg_val_loss)
        val_history['recon'].append(val_recon / n_val_batches)
        val_history['kl'].append(val_kl / n_val_batches)

        print(f"Epoch [{epoch+1}/{num_epochs}] | Beta: {current_beta:.5f} | "
              f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            print(" Miglior modello salvato!")

        # Every 5 epochs show a reconstruction and the loss curves so far.
        if (epoch + 1) % 5 == 0:
            print(f"\n--- Visualizzazione Ricostruzione (Epoca {epoch+1}) ---")
            plot_reconstruction(model, val_loader, device, epoch)

            plt.figure(figsize=(8, 4))
            plt.plot(train_history['total'], label='Train Loss')
            plt.plot(val_history['total'], label='Val Loss')
            plt.title(f"Andamento Loss fino a Epoca {epoch+1}")
            plt.legend()
            plt.show()

    return train_history, val_history

# Esecuzione
# history = train_vae(model, train_loader, val_loader, optimizer, scheduler, num_epochs)

In [None]:
# TRAINING
num_epochs = 80
# target_beta and warmup_epochs are not passed to train_vae: the function
# reads them as module-level globals, so they must be set before the call.
target_beta = 0.005
warmup_epochs = 20

# train_vae returns the training history and the validation history.
history_train, history_val = train_vae(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs
)

In [None]:
# CARICAMENTO MODELLO
from google.colab import drive
import torch
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

latent_dim = 1024
model = VAE_V5(latent_dim=latent_dim).to(device)

model_path_v5 = "/content/drive/MyDrive/VAE_Brain_Project/best_model_v5.pth"

# Only the two calls that can actually fail stay inside the try block.
# NOTE(review): on failure the model keeps its random initial weights and
# the following cells still run; stop here if either error is printed.
try:
    checkpoint = torch.load(model_path_v5, map_location=device)
    model.load_state_dict(checkpoint)
except FileNotFoundError:
    print(f" Errore: Il file '{model_path_v5}' non esiste.")
    print("Controlla se il nome del file o il percorso nel Drive sono corretti.")
except RuntimeError as e:
    print(f" Errore di architettura: Incompatibilità rilevata.")
    print("Assicurati di non aver modificato la classe VAE_V5 rispetto a quando hai salvato i pesi.")
    # Show the original message: it names the missing/unexpected keys or the
    # mismatching shapes, which is what is needed to diagnose the problem.
    print(f"Dettagli: {e}")
else:
    model.eval()
    print(f" Modello V5 caricato correttamente!")
    total_params = sum(p.numel() for p in model.parameters())
    print(f" Capacità: {total_params:,} parametri.")

In [None]:
# logica per la classificazione
def get_anomaly_scores(model, loader, device):
    """Compute one anomaly score per image: its mean per-pixel BCE.

    Args:
        model: Network returning ``(reconstruction, mu, logvar)``; it is
            switched to eval mode.
        loader: Iterable of image batches with values in [0, 1].
        device: Device each batch is moved to.

    Returns:
        1-D NumPy array with one score per image, in loader order.
    """
    per_image_scores = []
    model.eval()
    with torch.no_grad():
        for images in loader:
            images = images.to(device)
            reconstruction, _mu, _logvar = model(images)
            # Unreduced BCE, then averaged over channel, height and width.
            pixel_bce = F.binary_cross_entropy(reconstruction, images, reduction='none')
            image_bce = pixel_bce.mean(dim=(1, 2, 3))
            per_image_scores.extend(image_bce.cpu().numpy())
    return np.array(per_image_scores)

In [None]:
import numpy as np
from sklearn.metrics import (classification_report, confusion_matrix, roc_auc_score,
                             roc_curve, ConfusionMatrixDisplay, accuracy_score, f1_score, recall_score)
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

test_healthy_scores = get_anomaly_scores(model, test_loader_sani, device)
test_anomaly_scores = get_anomaly_scores(model, anno_loader, device)

# Labels follow the concatenation order: 0 = healthy, 1 = anomaly.
y_true = np.array([0] * len(test_healthy_scores) + [1] * len(test_anomaly_scores))
y_scores = np.concatenate([test_healthy_scores, test_anomaly_scores])

fpr, tpr, thresholds = roc_curve(y_true, y_scores)
auc_value = roc_auc_score(y_true, y_scores)
# Youden's J statistic: the threshold that maximises TPR - FPR.
# NOTE(review): the threshold is chosen on the same test data used for the
# metrics below, so those metrics are optimistic.
optimal_idx = np.argmax(tpr - fpr)
optimal_threshold = thresholds[optimal_idx]

# roc_curve counts a sample as positive when score >= threshold. The same
# inclusive comparison is used here so the confusion matrix and metrics
# match the operating point marked on the ROC plot. A strict '>' would
# misclassify the sample sitting exactly on the threshold.
y_pred = (y_scores >= optimal_threshold).astype(int)

acc_modello = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
# No Information Rate: accuracy of always predicting the majority class.
conteggio_classi = np.bincount(y_true)
nir = np.max(conteggio_classi) / len(y_true)


fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 5))

# ROC curve
ax1.plot(fpr, tpr, color='darkred', lw=2, label=f'AUC = {auc_value:.3f}')
ax1.plot([0, 1], [0, 1], color='gray', linestyle='--')
ax1.scatter(fpr[optimal_idx], tpr[optimal_idx], color='black', label='Soglia Ottimale')
ax1.set_title('ROC Curve (Capacità Discriminativa)')
ax1.set_xlabel('False Positive Rate')
ax1.set_ylabel('True Positive Rate')
ax1.legend(loc="lower right")

# Score distribution histogram
ax2.hist(test_healthy_scores, bins=30, alpha=0.6, label='Sani (Normali)', color='green', density=True)
ax2.hist(test_anomaly_scores, bins=30, alpha=0.6, label='Anomalie (Tumori)', color='red', density=True)
ax2.axvline(optimal_threshold, color='black', linestyle='--', label=f'Threshold: {optimal_threshold:.4f}')
ax2.set_title('Distribuzione Errori di Ricostruzione (BCE)')
ax2.set_xlabel('Anomaly Score')
ax2.legend()

# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Sano', 'Tumore'])
disp.plot(ax=ax3, cmap='Reds', values_format='d')
ax3.set_title('Matrice di Confusione')

plt.tight_layout()
plt.show()

# Final report
print(f"\n" + "="*30)
print(f"   PERFORMANCE REPORT V5")
print(f"="*30)
print(f"Accuracy:           {acc_modello:.4f}")
print(f"F1-Score:           {f1:.4f} (Bilanciamento Prec/Rec)")
print(f"Recall (Sensibilità): {recall:.4f} <-- IMPORTANTE")
print(f"AUC Score:          {auc_value:.4f}")
print(f"No Information Rate: {nir:.4f}")
print(f"Soglia calcolata:    {optimal_threshold:.6f}")
print("-" * 30)
print(classification_report(y_true, y_pred, target_names=['Sano', 'Tumore']))

In [None]:
import torch.nn.functional as F

def visualize_mri_anomaly(model, loader_sani, loader_anno, device):
    """Show a healthy and an anomalous image with reconstructions and BCE heatmaps.

    The first image of the first batch of each loader is reconstructed and
    drawn on a 2x3 grid: original, reconstruction and per-pixel BCE map.
    Each error-map title reports the mean of its map as the score.

    Args:
        model: Network returning ``(reconstruction, mu, logvar)``; it is
            switched to eval mode.
        loader_sani: Loader of healthy images.
        loader_anno: Loader of anomalous images.
        device: Device used for the forward pass.
    """
    model.eval()
    with torch.no_grad():
        # Slicing with 0:1 keeps the batch dimension for the forward pass.
        healthy_img = next(iter(loader_sani))[0:1].to(device)
        tumor_img = next(iter(loader_anno))[0:1].to(device)

        # Forward pass
        healthy_rec, _, _ = model(healthy_img)
        tumor_rec, _, _ = model(tumor_img)

        # Per-pixel BCE error maps
        healthy_err = F.binary_cross_entropy(healthy_rec, healthy_img, reduction='none').cpu().squeeze()
        tumor_err = F.binary_cross_entropy(tumor_rec, tumor_img, reduction='none').cpu().squeeze()

        healthy_score = healthy_err.mean().item()
        tumor_score = tumor_err.mean().item()

        fig, axes = plt.subplots(2, 3, figsize=(18, 10))

        # Grayscale panels: originals and reconstructions for both cases.
        gray_panels = (
            (axes[0, 0], healthy_img, "SANO: Originale"),
            (axes[0, 1], healthy_rec, "SANO: Ricostruzione VAE"),
            (axes[1, 0], tumor_img, "TUMORE: Originale"),
            (axes[1, 1], tumor_rec, "TUMORE: Ricostruzione VAE"),
        )
        for axis, tensor, title in gray_panels:
            axis.imshow(tensor.cpu().squeeze(), cmap='gray')
            axis.set_title(title, fontsize=12)

        # Healthy error map (colour scale chosen automatically).
        healthy_map = axes[0, 2].imshow(healthy_err, cmap='hot')
        axes[0, 2].set_title(f"SANO: Mappa Errore (Score: {healthy_score:.5f})", fontsize=12, fontweight='bold')
        fig.colorbar(healthy_map, ax=axes[0, 2], fraction=0.046, pad=0.04)

        # Tumour error map, clipped at 80% of the largest error of either map.
        # NOTE(review): this vmax is applied to the tumour map only, so the
        # two heatmaps are not on the same colour scale.
        vmax_val = max(tumor_err.max(), healthy_err.max()) * 0.8
        tumor_map = axes[1, 2].imshow(tumor_err, cmap='hot', vmax=vmax_val)
        axes[1, 2].set_title(f"TUMORE: Mappa Errore (Score: {tumor_score:.5f})", fontsize=12, fontweight='bold', color='red')
        fig.colorbar(tumor_map, ax=axes[1, 2], fraction=0.046, pad=0.04)

        for axis in axes.flatten():
            axis.axis('off')

        plt.suptitle(f"Analisi Qualitativa Anomaly Detection", fontsize=16, y=0.95)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()


# Qualitative check: one healthy test image against one anomalous image.
visualize_mri_anomaly(model, test_loader_sani, anno_loader, device)

In [None]:
# Identify healthy images that the model mistook for tumours.
def analyze_false_positives(model, loader_sani, threshold, device, n_max=5):
    """Find healthy images scored above ``threshold`` and plot the worst ones.

    The score is the mean per-pixel BCE between each image and its
    reconstruction. Images whose score is strictly greater than
    ``threshold`` are false positives. They are sorted by decreasing score
    and up to ``n_max`` are shown as original, reconstruction and absolute
    difference map.

    Args:
        model: Network returning ``(reconstruction, mu, logvar)``; it is
            switched to eval mode.
        loader_sani: Loader of healthy images.
        threshold: Score above which an image counts as anomalous.
        device: Device used for the forward pass.
        n_max: Maximum number of false positives to plot.

    Returns:
        None. Prints a message and returns early when no false positive is found.
    """
    model.eval()
    fp_list = []
    with torch.no_grad():
        for images in loader_sani:
            images = images.to(device)
            reconstructions, _, _ = model(images)

            scores = F.binary_cross_entropy(reconstructions, images, reduction='none').mean(dim=(1, 2, 3))

            # False positives: healthy images with score > threshold. The
            # index tensor is empty when none qualifies.
            flagged = torch.nonzero(scores > threshold, as_tuple=True)[0]
            for idx in flagged:
                fp_list.append({
                    'img': images[idx].cpu(),
                    'rec': reconstructions[idx].cpu(),
                    'score': scores[idx].item()
                })

    # Largest error first.
    fp_list.sort(key=lambda fp: fp['score'], reverse=True)

    if not fp_list:
        print(" Non sono stati trovati Falsi Positivi nel set analizzato.")
        return

    n_plot = min(n_max, len(fp_list))
    print(f" Trovati {len(fp_list)} Falsi Positivi. Visualizzo i primi {n_plot}:")

    # squeeze=False keeps axes two-dimensional even for a single row.
    fig, axes = plt.subplots(n_plot, 3, figsize=(15, 5 * n_plot), squeeze=False)

    for i, (row_axes, fp) in enumerate(zip(axes, fp_list)):
        original = fp['img'].squeeze()
        reconstruction = fp['rec'].squeeze()
        abs_error = torch.abs(original - reconstruction)

        row_axes[0].imshow(original, cmap='gray')
        row_axes[0].set_title(f"FP #{i+1} Originale")

        row_axes[1].imshow(reconstruction, cmap='gray')
        row_axes[1].set_title(f"Ricostruzione (Fallita)")

        error_image = row_axes[2].imshow(abs_error, cmap='magma')
        row_axes[2].set_title(f"Errore (Score: {fp['score']:.5f})")
        fig.colorbar(error_image, ax=row_axes[2])

        for axis in row_axes:
            axis.axis('off')

    plt.tight_layout()
    plt.show()


# Inspect healthy test images scored above the threshold selected from the ROC curve.
analyze_false_positives(model, test_loader_sani, optimal_threshold, device)