Install dependencies

In [None]:
pip install nibabel numpy scipy matplotlib scikit-image torch torchvision pandas

GPU Available?

In [None]:
import torch

# Report the installed PyTorch build and whether CUDA can be used.
print(torch.__version__)
print(torch.cuda.is_available())

# get_device_name(0) raises when no CUDA device is usable, so only query the
# device name after confirming that CUDA is available.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device available; running on CPU")

AlzeNet

In [None]:
import os
import pandas as pd
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torch.nn.functional as F


# Training hyperparameters.
BATCH_SIZE = 4
EPOCHS = 50
LR = 0.0001

# Spatial size every volume is resampled to before it is returned by the dataset.
TARGET_SHAPE = (128,) * 3

# Run on CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class NiftiDataset(Dataset):
    """Dataset that loads NIfTI volumes listed in a DataFrame.

    Each row of ``df`` must provide a ``"Ruta"`` column (the path handed to
    ``nib.load``) and a ``"Grupo"`` column (the class label, which is encoded
    with ``label_encoder.transform``).
    """

    def __init__(self, df, label_encoder):
        # Reset the index so that positional ``idx`` values can be used with
        # ``.loc`` even when ``df`` is a subset produced by a split.
        self.df = df.reset_index(drop=True)
        self.label_encoder = label_encoder

    def __len__(self):
        """Return the number of rows (one volume per row)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, normalise and resize the volume at row ``idx``.

        Returns:
            A ``(volume, label)`` tuple: ``volume`` is a tensor of shape
            ``(1, *TARGET_SHAPE)`` and ``label`` is a 0-dim int64 tensor
            holding the encoded class index.
        """
        path = self.df.loc[idx, "Ruta"]
        label = self.df.loc[idx, "Grupo"]

        # Load directly as float32 instead of materialising a float64 array
        # first and casting it afterwards.
        img = nib.load(path)
        volume = img.get_fdata(dtype=np.float32)

        # Z-score normalisation; the epsilon guards against constant volumes.
        volume = (volume - np.mean(volume)) / (np.std(volume) + 1e-8)

        # Resize to the common grid and add the channel axis.
        volume = self.resize_volume(volume, TARGET_SHAPE)
        volume = np.expand_dims(volume, axis=0)

        # Pin the label dtype to int64 (what nn.CrossEntropyLoss expects for
        # class indices) instead of inheriting whatever the encoder returns.
        encoded = self.label_encoder.transform([label])[0]
        return torch.tensor(volume), torch.tensor(int(encoded), dtype=torch.long)

    def resize_volume(self, volume, target_shape):
        """Resample a 3-D volume to ``target_shape`` by trilinear interpolation.

        Args:
            volume: 3-D array-like of floating-point voxel values.
            target_shape: Desired ``(D, H, W)`` output size.

        Returns:
            A NumPy array of shape ``target_shape``.

        Raises:
            ValueError: If ``volume`` does not have exactly three dimensions.
        """
        tensor = torch.tensor(volume)
        if tensor.dim() != 3:
            raise ValueError(
                f"Expected a 3-D volume, got shape {tuple(tensor.shape)}"
            )
        batched = tensor.unsqueeze(0).unsqueeze(0)  # [1,1,D,H,W]
        resized = F.interpolate(batched, size=target_shape, mode="trilinear", align_corners=False)
        # Strip exactly the two axes added above. A bare ``squeeze()`` would
        # also remove any size-1 axis requested in ``target_shape``.
        return resized[0, 0].numpy()

# AlzeNet3D architecture
class AlzeNet3D(nn.Module):
    """Small 3-D CNN classifier for single-channel volumes.

    Three ``Conv3d -> ReLU -> MaxPool3d(2)`` stages (16, 32 and 64 channels)
    are followed by a two-layer fully connected classifier with dropout.

    Args:
        num_classes: Number of output logits.
        input_shape: Spatial size ``(D, H, W)`` of the input volumes. It is
            used to size the first linear layer; the default matches the
            previously hard-coded ``64 * 16 * 16 * 16`` for 128^3 inputs.

    Raises:
        ValueError: If ``input_shape`` is not three values of at least 8.
    """

    def __init__(self, num_classes=3, input_shape=(128, 128, 128)):
        super().__init__()
        if len(input_shape) != 3 or min(input_shape) < 8:
            raise ValueError(
                "input_shape must contain three spatial sizes of at least 8, "
                f"got {tuple(input_shape)}"
            )

        self.features = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(16, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(32, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool3d(2),
        )

        # The padded 3x3x3 convolutions keep the spatial size; each
        # MaxPool3d(2) floor-halves it, so three pools divide every axis by 8.
        flat_features = 64
        for size in input_shape:
            flat_features *= size // 8

        # Layer order is unchanged so existing state_dict keys still match.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 128), nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        x = self.features(x)
        x = self.classifier(x)
        return x

# Read the CSV index and encode the group column as integer class labels.
df = pd.read_csv(r"D:/ADNI Data/BETandGroups.csv")
le = LabelEncoder()
df['labels'] = le.fit_transform(df['Grupo'])

# Stratified 75/25 split with a fixed seed so it is reproducible.
# NOTE(review): the split is per row; if several rows belong to the same
# subject they may end up on both sides of the split — confirm against the CSV.
df_train, df_val = train_test_split(
    df,
    test_size=0.25,
    stratify=df['labels'],
    random_state=42,
)

# One dataset per partition, both sharing the fitted encoder.
train_dataset, val_dataset = (NiftiDataset(part, le) for part in (df_train, df_val))

# Only the training loader shuffles; validation keeps the row order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Training
model = AlzeNet3D(num_classes=len(le.classes_)).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    train_loss, train_correct = 0.0, 0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean; weight it by the batch size so
        # a smaller final batch contributes proportionally to the epoch total.
        train_loss += loss.item() * inputs.size(0)
        preds = outputs.argmax(dim=1)
        train_correct += (preds == labels).sum().item()

    # Divide the accumulated sum by the number of samples so the reported
    # loss is a per-sample mean rather than a total that grows with the
    # dataset size.
    avg_train_loss = train_loss / len(train_dataset)
    acc_train = 100 * train_correct / len(train_dataset)
    print(f"Epoch {epoch+1} - Loss: {avg_train_loss:.4f}, Train Acc: {acc_train:.2f}%")

    # Validation: eval mode and no gradient tracking.
    model.eval()
    val_correct = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            val_correct += (preds == labels).sum().item()
    acc_val = 100 * val_correct / len(val_dataset)
    # NOTE: despite the "Test" label, this is accuracy on the validation split.
    print(f"Test Acc: {acc_val:.2f}%")

# ========== SAVE ==========
# Persist only the learned parameters (the state_dict), not the module object.
trained_weights = model.state_dict()
torch.save(trained_weights, "alzenet3d_nifti.pth")
print("Modelo guardado como alzenet3d_nifti.pth")
