Install dependencies

In [None]:
pip install nibabel numpy scipy matplotlib scikit-image scikit-learn tqdm torch torchvision pandas

GPU Available?

In [None]:
import torch

# Report the installed PyTorch version and whether CUDA can be used.
print(torch.__version__)
print(torch.cuda.is_available())
if torch.cuda.is_available():
    # Only query the device name when a CUDA device exists; the call raises
    # on CPU-only machines, which would abort this cell.
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device available; running on CPU")

ResNet50


In [None]:
import os
import pandas as pd
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import scipy.ndimage

# ---- Run configuration ----
# CSV file that is read with pandas further down in this file.
CSV_PATH = r"D:/ADNI Data/BETandGroups.csv"
# Spatial size every volume is resampled to (same length on all three axes).
IMG_SIZE = (128,) * 3
# Samples per optimisation step.
BATCH_SIZE = 1
# Number of passes over the training split.
EPOCHS = 30
# Learning rate handed to the optimizer.
LR = 0.0001
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Nifti3DDataset(Dataset):
    """Dataset that loads NIfTI volumes from disk as ``(volume, label)`` pairs.

    Each row of ``df`` supplies a file path in the ``"Ruta"`` column and a
    class name in the ``"Grupo"`` column.  A volume is resampled to
    ``target_shape``, standardised to zero mean / unit variance, clipped to
    ``[-3, 3]`` and returned with a leading channel axis.

    Args:
        df: DataFrame with ``"Ruta"`` and ``"Grupo"`` columns.  Its index is
            reset so rows can be addressed by position.
        label_encoder: Object whose ``transform`` maps a list of ``"Grupo"``
            values to integer class indices (presumably a fitted
            ``LabelEncoder`` -- confirm against the caller).
        target_shape: Spatial shape every volume is resampled to.  Defaults
            to the module-level ``IMG_SIZE`` when omitted.
    """

    def __init__(self, df, label_encoder, target_shape=None):
        self.df = df.reset_index(drop=True)
        self.label_encoder = label_encoder
        self.target_shape = IMG_SIZE if target_shape is None else tuple(target_shape)

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return len(self.df)

    def _preprocess(self, volume):
        """Resample, standardise and clip ``volume``; return a float32 tensor.

        Raises:
            ValueError: If the volume (after dropping singleton axes of a
                higher-dimensional array) does not have as many axes as
                ``target_shape``.
        """
        volume = np.asarray(volume, dtype=np.float32)

        if volume.ndim > len(self.target_shape):
            # Some files carry extra singleton axes (e.g. a length-1 fourth
            # axis); drop them so the zoom factors line up with the data.
            volume = np.squeeze(volume)
        if volume.ndim != len(self.target_shape):
            raise ValueError(
                f"Expected a {len(self.target_shape)}-D volume, got shape {volume.shape}"
            )

        # Per-axis scale factors that bring the volume to the target shape.
        zoom_factors = [t / s for t, s in zip(self.target_shape, volume.shape)]
        volume = scipy.ndimage.zoom(volume, zoom=zoom_factors, order=3)

        # Z-score normalisation; the epsilon avoids division by zero for a
        # constant volume.  Clipping limits the influence of outliers.
        volume = (volume - np.mean(volume)) / (np.std(volume) + 1e-8)
        volume = np.clip(volume, -3, 3)

        # Add the channel axis expected by 3-D convolutions.
        return torch.tensor(volume, dtype=torch.float32).unsqueeze(0)

    def _encode_label(self, label):
        """Return the class index of ``label`` as a scalar int64 tensor."""
        encoded = self.label_encoder.transform([label])[0]
        # Force int64: the classification loss requires long targets, and the
        # integer width returned by the encoder is not guaranteed.
        return torch.tensor(int(encoded), dtype=torch.long)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(volume, label)`` tensors."""
        path = self.df.loc[idx, "Ruta"]
        label = self.df.loc[idx, "Grupo"]

        volume = nib.load(path).get_fdata().astype(np.float32)

        return self._preprocess(volume), self._encode_label(label)

# Bottleneck architecture
class Bottleneck3D(nn.Module):
    """Three-convolution residual block (1x1x1 -> 3x3x3 -> 1x1x1) for 3-D inputs.

    The block outputs ``out_channels * expansion`` channels.  ``stride`` is
    applied by the middle 3x3x3 convolution.  ``downsample``, when given, is
    applied to the block input before it is added to the main path.
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        width = out_channels
        expanded = width * self.expansion

        # Channel reduction.
        self.conv1 = nn.Conv3d(in_channels, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(width)

        # Spatial convolution; this is where the stride is applied.
        self.conv2 = nn.Conv3d(
            width, width, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm3d(width)

        # Channel expansion.
        self.conv3 = nn.Conv3d(width, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(expanded)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Run the main path and add the (optionally projected) input to it."""
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)

        y = self.conv2(y)
        y = self.bn2(y)
        y = self.relu(y)

        y = self.conv3(y)
        y = self.bn3(y)

        # Use the raw input as shortcut unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)

# ResNet3D architecture
class ResNet3D(nn.Module):
    """ResNet-style classifier built from 3-D convolutions.

    Args:
        block: Residual block class.  It must expose an ``expansion`` class
            attribute and accept ``(in_channels, out_channels, stride,
            downsample)``.
        layers: Number of blocks in each of the four stages.
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, block, layers, num_classes):
        super().__init__()
        # Channel count entering the next stage; updated by ``_make_layer``.
        self.in_channels = 64

        # Stem: single-channel input -> 64 feature maps, then max pooling.
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; every stage after the first uses stride 2.
        depth1, depth2, depth3, depth4 = layers[0], layers[1], layers[2], layers[3]
        self.layer1 = self._make_layer(block, 64, depth1)
        self.layer2 = self._make_layer(block, 128, depth2, stride=2)
        self.layer3 = self._make_layer(block, 256, depth3, stride=2)
        self.layer4 = self._make_layer(block, 512, depth4, stride=2)

        # Head: global average pooling followed by a linear classifier.
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``."""
        expanded = out_channels * block.expansion

        # A projection shortcut is needed whenever the first block changes
        # the spatial resolution or the channel count.
        shortcut = None
        if stride != 1 or self.in_channels != expanded:
            shortcut = nn.Sequential(
                nn.Conv3d(self.in_channels, expanded, kernel_size=1, stride=stride),
                nn.BatchNorm3d(expanded),
            )

        # Only the first block receives the stride and the projection.
        stage = [block(self.in_channels, out_channels, stride, shortcut)]
        self.in_channels = expanded
        stage.extend(block(self.in_channels, out_channels) for _ in range(blocks - 1))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the class logits for a batch ``x``."""
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.avgpool(features)
        return self.fc(pooled.flatten(1))

def resnet50_3d(num_classes):
    """Build a ``ResNet3D`` of ``Bottleneck3D`` blocks with stage depths 3, 4, 6, 3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet3D(Bottleneck3D, stage_depths, num_classes)

# Load the data
df = pd.read_csv(CSV_PATH)
le = LabelEncoder()
df['labels'] = le.fit_transform(df['Grupo'])

# Stratified 75/25 train/validation split with a fixed seed for reproducibility.
df_train, df_val = train_test_split(df, test_size=0.25, stratify=df['labels'], random_state=42)
train_dataset = Nifti3DDataset(df_train, le)
val_dataset = Nifti3DDataset(df_val, le)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Train the model
model = resnet50_3d(num_classes=len(le.classes_)).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    train_loss, train_correct = 0.0, 0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean, so weight it by the batch
        # size to accumulate a per-sample sum.
        train_loss += loss.item() * inputs.size(0)
        train_correct += (outputs.argmax(1) == labels).sum().item()

    # Report the mean per-sample loss; the raw sum would scale with the
    # dataset size and is not comparable across runs.
    train_loss /= len(train_dataset)
    acc_train = 100 * train_correct / len(train_dataset)
    print(f"Epoch {epoch+1} - Loss: {train_loss:.4f}, Train Acc: {acc_train:.2f}%")

    # Checkpoint after every epoch so an interrupted run keeps its progress.
    torch.save(model.state_dict(), "resnet50_3d_alzheimer.pth")
    print("Modelo guardado como resnet50_3d_alzheimer.pth")

    # Evaluate on the validation split without tracking gradients.
    model.eval()
    val_correct = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            val_correct += (outputs.argmax(1) == labels).sum().item()
    acc_val = 100 * val_correct / len(val_dataset)
    # This accuracy is measured on the validation split, so label it as such.
    print(f"Val Acc: {acc_val:.2f}%")

# ========= SAVE =========
torch.save(model.state_dict(), "resnet50_3d_alzheimer.pth")
print("Modelo guardado como resnet50_3d_alzheimer.pth")
