In [None]:
import os
import glob
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

In [None]:
def load_and_flatten(path, target_size=64):
    """Load an image volume and return its middle slice as a flat feature vector.

    The middle slice along the third axis is z-scored per image (zero mean,
    unit std), bilinearly resized to ``target_size x target_size`` and flattened.

    Args:
        path: Path to an image file readable by ``nibabel.load``.
        target_size: Side length of the square resized slice (default 64).

    Returns:
        1-D float32 tensor of length ``target_size ** 2`` (4096 for the default).

    Raises:
        ValueError: If the image has fewer than 3 dimensions.
    """
    data = nib.load(path).get_fdata()

    if data.ndim < 3:
        raise ValueError(
            f"expected an image with at least 3 dimensions, got shape {data.shape} from {path!r}"
        )
    # Collapse any extra axes (e.g. volumes/echoes) by taking the first entry of
    # each, so 4-D (as before) and 5-D images reduce to a single 3-D volume.
    while data.ndim > 3:
        data = data[..., 0]

    # Middle slice along the third axis. This is "axial" only if the file's voxel
    # orientation puts the inferior-superior axis last -- NOTE(review): confirm.
    zmid = data.shape[2] // 2
    slice2d = data[:, :, zmid].astype(np.float32)

    # Per-image z-score; the epsilon avoids division by zero on a constant slice.
    slice2d = (slice2d - slice2d.mean()) / (slice2d.std() + 1e-6)

    # F.interpolate needs a batch and a channel axis: shape (1, 1, H, W).
    t = torch.tensor(slice2d).unsqueeze(0).unsqueeze(0)
    t = F.interpolate(t, size=(target_size, target_size), mode="bilinear", align_corners=False)

    # Drop batch/channel axes and flatten row-major -> (target_size * target_size,).
    return t.reshape(-1)

In [None]:
def build_tensor_dataset(paths, labels):
    """Load every image in ``paths`` and pair it with its class label.

    Args:
        paths: Iterable of image paths, each passed to ``load_and_flatten``.
        labels: Iterable of integer class labels, one per path, in the same order.

    Returns:
        ``TensorDataset(features, targets)``. ``features`` stacks the 1-D vectors
        from ``load_and_flatten`` into shape ``(n, feature_len)``; ``targets`` is
        an int64 tensor of shape ``(n,)``.

    Raises:
        ValueError: If ``paths`` and ``labels`` differ in length, or are empty.
    """
    paths = list(paths)
    labels = list(labels)
    # zip() would silently drop the tail of the longer sequence and mis-size the dataset.
    if len(paths) != len(labels):
        raise ValueError(f"got {len(paths)} paths but {len(labels)} labels")
    if not paths:
        # torch.stack on an empty list fails with an opaque error; fail clearly instead.
        raise ValueError("cannot build a dataset from an empty list of paths")

    features = torch.stack([load_and_flatten(p) for p in paths])
    targets = torch.tensor(labels, dtype=torch.long)
    return TensorDataset(features, targets)

# NOTE(review): train_files / train_labels / test_files / test_labels are not defined
# in any visible cell -- presumably built by an earlier (possibly deleted) cell using
# glob + train_test_split. Confirm they exist so "Restart & Run All" works.

# Seed torch so weight initialisation and DataLoader shuffle order are reproducible.
SEED = 42
torch.manual_seed(SEED)

train_ds = build_tensor_dataset(train_files, train_labels)
test_ds  = build_tensor_dataset(test_files, test_labels)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=4, shuffle=False)

# Derive the input width from the data (4096 = 64 * 64 by default), so changing
# the resize target no longer silently breaks the first Linear layer.
n_features = train_ds.tensors[0].shape[1]
# Two-way output head; CrossEntropyLoss requires labels in {0, 1}.
n_classes = 2

model = nn.Sequential(
    nn.Linear(n_features, 256),
    nn.ReLU(),
    nn.Linear(256, n_classes),
)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)



In [None]:
def _run_epoch(net, loader, loss_fn, opt=None):
    """Run one pass over ``loader``: train when ``opt`` is given, otherwise evaluate.

    Returns:
        ``(mean loss per sample, accuracy)`` over the pass, or ``(nan, nan)`` when
        the loader yields no samples (instead of raising ZeroDivisionError).
    """
    training = opt is not None
    net.train(training)  # train(False) is equivalent to eval()
    total_loss, correct, total = 0.0, 0, 0

    # Build the autograd graph only when backward() will be called.
    with torch.set_grad_enabled(training):
        for X, y in loader:
            if training:
                opt.zero_grad()
            out = net(X)
            loss = loss_fn(out, y)
            if training:
                loss.backward()
                opt.step()

            batch_size = y.size(0)
            # Assumes loss_fn uses mean reduction (nn.CrossEntropyLoss() default):
            # re-weight by batch size so a smaller last batch is averaged correctly.
            total_loss += loss.item() * batch_size
            correct += (out.argmax(dim=1) == y).sum().item()
            total += batch_size

    if total == 0:
        return float("nan"), float("nan")
    return total_loss / total, correct / total


def train_one_epoch(net=None, loader=None, loss_fn=None, opt=None):
    """Train for one epoch; return ``(mean train loss, train accuracy)``.

    Every argument is optional. Each omitted one falls back, at call time, to the
    notebook-level ``model``, ``train_loader``, ``criterion`` or ``optimizer``.
    The original zero-argument call therefore keeps working.
    Accuracy is measured on each batch's outputs before that batch's optimizer step.
    """
    return _run_epoch(
        model if net is None else net,
        train_loader if loader is None else loader,
        criterion if loss_fn is None else loss_fn,
        optimizer if opt is None else opt,
    )


def evaluate(net=None, loader=None, loss_fn=None):
    """Evaluate without gradients; return ``(mean test loss, test accuracy)``.

    Omitted arguments fall back, at call time, to the notebook-level ``model``,
    ``test_loader`` and ``criterion``.
    """
    return _run_epoch(
        model if net is None else net,
        test_loader if loader is None else loader,
        criterion if loss_fn is None else loss_fn,
    )


# -----------------------------------------------------------
# Run training: each epoch does one optimisation pass over the
# training split followed by one gradient-free pass over the
# test split; the weights after the last epoch are saved.
# -----------------------------------------------------------

epochs = 30
for epoch in range(epochs):
    (train_loss, train_acc), (test_loss, test_acc) = train_one_epoch(), evaluate()
    print(f"Epoch {epoch+1}/{epochs}  Train Loss {train_loss:.4f}  Train Acc {train_acc:.3f}  Test Loss {test_loss:.4f}  Test Acc {test_acc:.3f}")

# Final-epoch weights only; no checkpoint is selected using the test metrics.
checkpoint_path = "mlp_adni_no_class.pth"
torch.save(model.state_dict(), checkpoint_path)
print("Model saved.")