In [None]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt


# Seed every RNG the notebook draws from (weight init, reparameterization noise,
# DataLoader shuffling) so a Restart & Run All reproduces the same results,
# up to non-deterministic CUDA kernels.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)


# class MRIModalityDataset(Dataset):
#     def __init__(self, root_dir, mmap_mode='r'):
#         self.root_dir = root_dir
#         self.mmap_mode = mmap_mode
#         self.files = sorted([
#             f for f in os.listdir(root_dir) if f.endswith('.npy')
#         ])

#     def __len__(self):
#         return len(self.files)

#     def __getitem__(self, idx):
#         file_name = self.files[idx]
#         path = os.path.join(self.root_dir, file_name)
#         arr = np.load(path, mmap_mode=self.mmap_mode) 
#         arr = arr.squeeze(axis=0)                     
#         tensor = torch.from_numpy(arr.astype(np.float32))
#         return tensor


class MRIModalityDataset(Dataset):
    """Map-style dataset over one stacked ``.npy`` array of volumes.

    The file is opened once in ``__init__`` (memory-mapped when ``mmap_mode``
    is set, read-only by default). Axis 0 indexes samples. The original note
    says the layout is (N, D, H, W); TODO confirm against the data files.

    Each item is a float32 copy of one sample with a new leading channel axis,
    e.g. (D, H, W) -> (1, D, H, W).
    """

    def __init__(self, npy_path, mmap_mode='r'):
        self.data = np.load(npy_path, mmap_mode=mmap_mode)

    def __len__(self):
        num_volumes = self.data.shape[0]
        return num_volumes

    def __getitem__(self, idx):
        # astype() always copies, so the returned tensor never aliases the
        # (possibly read-only) memory map.
        volume = self.data[idx].astype(np.float32)
        return torch.from_numpy(volume)[None, ...]

    
def reparameterize(mu, logvar):
    """Draw z = mu + eps * sigma with eps ~ N(0, I) (the reparameterization trick).

    ``sigma`` is recovered from the log-variance as exp(logvar / 2). The noise
    has the same shape, dtype and device as ``logvar``. Sampling this way keeps
    the result differentiable with respect to ``mu`` and ``logvar``.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

def _pooled_spatial_shape(input_shape):
    """Spatial shape after two kernel-2 max-pools (each axis floor-divided by 4).

    Raises:
        ValueError: if ``input_shape`` is not three sizes, or if any axis is
            smaller than 4 (it would pool down to zero voxels).
    """
    shape = tuple(int(s) for s in input_shape)
    if len(shape) != 3 or min(shape) < 4:
        raise ValueError(
            f"input_shape must be three spatial sizes, each >= 4; got {input_shape!r}"
        )
    # floor(floor(s / 2) / 2) == s // 4 for non-negative integers.
    return tuple(s // 4 for s in shape)


class VBMEncoder3D(nn.Module):
    """3-D convolutional encoder producing Gaussian posterior parameters.

    There are two [Conv3d(k=3, pad=1) -> ReLU -> MaxPool3d(2)] stages. The
    padded convs keep the spatial size and each pool floor-halves it. The
    flattened features then feed two linear heads for ``mu`` and ``logvar``.

    Args:
        in_channels: channels of the input volume.
        hidden_channels: channels after the first conv (the second conv doubles it).
        latent_dim: size of the latent vector.
        input_shape: spatial size (D, H, W) of the input volume. The default
            (121, 145, 121) reproduces the previously hard-coded (30, 36, 30)
            feature grid.

    Raises:
        ValueError: if any spatial axis of ``input_shape`` is smaller than 4.
    """

    def __init__(self, in_channels=1, hidden_channels=16, latent_dim=128,
                 input_shape=(121, 145, 121)):
        super().__init__()
        pooled = _pooled_spatial_shape(input_shape)

        self.conv1 = nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool3d(kernel_size=2)  # e.g. (121,145,121) -> (60,72,60)

        self.conv2 = nn.Conv3d(hidden_channels, hidden_channels * 2, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool3d(kernel_size=2)  # e.g. (60,72,60) -> (30,36,30)

        self.feature_shape = (hidden_channels * 2, *pooled)
        self.flat_dim = int(np.prod(self.feature_shape))

        self.fc_mu = nn.Linear(self.flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.flat_dim, latent_dim)

    def forward(self, x):
        """Map x of shape (B, in_channels, *input_shape) to (mu, logvar), each (B, latent_dim)."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        return self.fc_mu(x), self.fc_logvar(x)


class VBMDecoder3D(nn.Module):
    """Decoder mirroring ``VBMEncoder3D``: latent vector -> volume of ``input_shape``.

    A linear layer expands ``z`` to the encoder's feature grid. Two stride-2
    transposed convs then double each spatial axis. Each uses an
    ``output_padding`` of 0 or 1 per axis, which restores the voxel the
    encoder's floor-halving dropped on odd-sized axes. For the default
    (121, 145, 121) these paddings are (0, 0, 0) and (1, 1, 1), matching the
    previous hard-coded layers.

    Args:
        out_channels: channels of the reconstructed volume.
        hidden_channels: must match the encoder's ``hidden_channels``.
        latent_dim: size of the latent vector.
        dropout_rate: dropout probability after the linear layer and after ``conv1``.
        input_shape: spatial size (D, H, W) of the volume to reconstruct.

    Raises:
        ValueError: if any spatial axis of ``input_shape`` is smaller than 4.
    """

    def __init__(self, out_channels=1, hidden_channels=16, latent_dim=128, dropout_rate=0.0,
                 input_shape=(121, 145, 121)):
        super().__init__()
        pooled = _pooled_spatial_shape(input_shape)
        full = tuple(int(s) for s in input_shape)
        half = tuple(s // 2 for s in full)
        # ConvTranspose3d(k=2, s=2) maps n -> 2n + output_padding.
        pad1 = tuple(h - 2 * p for h, p in zip(half, pooled))
        pad2 = tuple(s - 2 * h for s, h in zip(full, half))

        self.feature_shape = (hidden_channels * 2, *pooled)
        self.flat_dim = int(np.prod(self.feature_shape))

        self.fc = nn.Linear(latent_dim, self.flat_dim)
        self.dropout = nn.Dropout(dropout_rate)

        self.deconv1 = nn.ConvTranspose3d(
            hidden_channels * 2, hidden_channels, kernel_size=2, stride=2, output_padding=pad1
        )
        self.conv1 = nn.Conv3d(hidden_channels, hidden_channels, kernel_size=3, padding=1)

        self.deconv2 = nn.ConvTranspose3d(
            hidden_channels, out_channels, kernel_size=2, stride=2, output_padding=pad2
        )
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, z):
        """Map z of shape (B, latent_dim) to a volume (B, out_channels, *input_shape)."""
        x = self.fc(z)
        x = self.dropout(x)
        x = x.view(z.size(0), *self.feature_shape)

        x = self.deconv1(x)
        x = F.relu(self.conv1(x))
        x = self.dropout(x)

        x = self.deconv2(x)
        # No output activation: the reconstruction is unbounded and is scored with MSE.
        x = self.conv2(x)

        return x


class VAEVBM3D(nn.Module):
    """Variational autoencoder wrapping ``VBMEncoder3D`` and ``VBMDecoder3D``.

    ``input_shape`` is forwarded to both halves so that they agree on the
    feature grid. ``forward`` returns ``(recon, mu, logvar)``.
    """

    def __init__(self, in_channels=1, out_channels=1, hidden_channels=16, latent_dim=128,
                 dropout_rate=0.0, input_shape=(121, 145, 121)):
        super().__init__()
        self.encoder = VBMEncoder3D(in_channels, hidden_channels, latent_dim,
                                    input_shape=input_shape)
        self.decoder = VBMDecoder3D(out_channels, hidden_channels, latent_dim, dropout_rate,
                                    input_shape=input_shape)

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = reparameterize(mu, logvar)
        recon = self.decoder(z)
        return recon, mu, logvar


def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Sum-reduced beta-VAE objective.

    The objective is the summed squared error between ``recon_x`` and ``x``
    plus ``beta`` times the closed-form KL divergence
    KL(N(mu, exp(logvar)) || N(0, I)), also summed over batch and latent dims.

    Returns:
        (total, recon, kld): ``total`` is a differentiable tensor; ``recon`` and
        ``kld`` are Python floats for logging.
    """
    reconstruction = F.mse_loss(recon_x, x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * torch.sum(kl_terms)
    total = reconstruction + beta * divergence
    return total, reconstruction.item(), divergence.item()


def _run_epoch(model, loader, loss_fn, beta, device, optimizer=None):
    """Return the summed loss over ``loader``; take optimizer steps only if ``optimizer`` is given."""
    training = optimizer is not None
    model.train(training)
    total = 0.0
    # Gradients are only tracked for the training pass.
    with torch.set_grad_enabled(training):
        for batch_data in loader:
            batch_data = batch_data.to(device)
            if training:
                optimizer.zero_grad()
            recon, mu, logvar = model(batch_data)
            loss, _, _ = loss_fn(recon, batch_data, mu, logvar, beta)
            if training:
                loss.backward()
                optimizer.step()
            total += loss.item()
    return total


def train_vae(
    model,
    train_loader,
    val_loader,
    lr=1e-4,
    device='cuda',
    beta=1.0,
    patience=5,
    min_delta=0.0,
    max_epochs=100,
    loss_fn=None
):
    """Train ``model`` with Adam and early stopping, then restore the best weights.

    Each epoch runs one training pass and one evaluation pass. The summed loss
    of each pass is divided by its dataset size. Training stops once the
    per-sample validation loss has failed to beat the best value by more than
    ``min_delta`` for ``patience`` consecutive epochs, or after ``max_epochs``.
    The weights from the best validation epoch are loaded back into ``model``
    before returning.

    Args:
        model: module whose forward returns ``(recon, mu, logvar)``.
        train_loader, val_loader: loaders yielding input tensors directly.
        lr: Adam learning rate.
        device: device each batch is moved to.
        beta: KL weight forwarded to ``loss_fn``.
        patience: epochs without improvement tolerated before stopping.
        min_delta: margin by which validation loss must improve.
        max_epochs: hard cap on the number of epochs.
        loss_fn: callable ``(recon, x, mu, logvar, beta) -> (loss, recon, kld)``.
            Defaults to ``vae_loss``.

    Returns:
        (train_loss_history, val_loss_history): per-epoch per-sample losses.

    Raises:
        ValueError: if either loader's dataset is empty.
    """
    if loss_fn is None:
        loss_fn = vae_loss

    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_loss_history = []
    val_loss_history = []

    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)
    if n_train == 0 or n_val == 0:
        raise ValueError(
            f"train and validation datasets must be non-empty (got {n_train} and {n_val})"
        )

    print(f"Number of training samples: {n_train}")
    print(f"Number of validation samples: {n_val}\n")

    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(1, max_epochs + 1):
        avg_train_loss = _run_epoch(model, train_loader, loss_fn, beta, device, optimizer) / n_train
        avg_val_loss = _run_epoch(model, val_loader, loss_fn, beta, device) / n_val

        train_loss_history.append(avg_train_loss)
        val_loss_history.append(avg_val_loss)

        print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss + min_delta < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() tensors share storage with the live parameters, so
            # clone them; otherwise later optimizer steps overwrite the snapshot.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            print(f"EarlyStopping counter: {epochs_no_improve} out of {patience}")

        if epochs_no_improve >= patience:
            print("\nEarly stopping triggered.")
            break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return train_loss_history, val_loss_history

In [None]:
def plot_training_loss(train_loss_history, val_loss_history):
    """Plot per-epoch training (solid) and validation (dashed) loss on one axes.

    The "seaborn-v0_8-whitegrid" style is applied through a style context. This
    keeps it active while the figure and axes are created and stops it from
    leaking into rcParams for later figures, which ``plt.style.use`` would do.
    Each curve gets its own 1-based epoch axis, so histories of unequal length
    still plot.
    """
    with plt.style.context("seaborn-v0_8-whitegrid"):
        fig, ax = plt.subplots(figsize=(8, 6))

        ax.plot(range(1, len(train_loss_history) + 1), train_loss_history,
                label="Training Loss", color="tab:blue", linewidth=2)
        ax.plot(range(1, len(val_loss_history) + 1), val_loss_history,
                label="Validation Loss", color="tab:orange", linewidth=2, linestyle='--')

        ax.set_xlabel("Epoch", fontsize=14)
        ax.set_ylabel("Loss", fontsize=14)
        ax.set_title("VAE Training and Validation Loss Over Epochs", fontsize=16)
        ax.legend(fontsize=12)
        ax.tick_params(axis="both", labelsize=12)

        fig.tight_layout()
        plt.show()

### Train the VAE on all VBM training data

In [None]:
# Only the training batches are shuffled. The validation loss is a sum over all
# samples, so its order is irrelevant and a fixed order keeps evaluation deterministic.
train_vbm_ds = MRIModalityDataset("features_3d_train_vbm.npy")
train_vbm_loader = DataLoader(train_vbm_ds, batch_size=2, shuffle=True, num_workers=0)

val_vbm_ds = MRIModalityDataset("features_3d_val_vbm.npy")
val_vbm_loader = DataLoader(val_vbm_ds, batch_size=2, shuffle=False, num_workers=0)


In [None]:
# Single-channel VBM VAE with a 128-d latent space and no dropout, placed on the
# selected device before training.
model_vbm = VAEVBM3D(in_channels=1, out_channels=1, hidden_channels=16,
                     latent_dim=128, dropout_rate=0.0).to(device)

print("=== Training VBM VAE ===")
# Early stopping: halt after 5 epochs in which the per-sample validation loss
# does not improve on the best value by more than 0.01.
train_loss, val_loss = train_vae(
    model_vbm,
    train_loader=train_vbm_loader,
    val_loader=val_vbm_loader,
    lr=1e-4,
    device=device,
    beta=1.0,
    patience=5,
    min_delta=0.01,
    max_epochs=100,
)

In [None]:
# Persist the per-epoch loss curves so the plot can be redrawn without retraining.
train_loss = np.asarray(train_loss)
val_loss = np.asarray(val_loss)

np.save(file="VAE_train_loss.npy", arr=train_loss)
np.save(file="VAE_val_loss.npy", arr=val_loss)

# To reload instead of retraining:
# train_loss = np.load("VAE_train_loss.npy")
# val_loss = np.load("VAE_val_loss.npy")

In [None]:
plot_training_loss(train_loss_history=train_loss, val_loss_history=val_loss)

### Extracting Latent Features

In [None]:
def extract_latents(model, dataloader, device='cuda'):
    """Return ``model.encoder``'s ``mu`` output for every batch in ``dataloader``.

    The model is switched to eval mode and run without gradient tracking. Each
    batch's ``mu`` is moved to the CPU, and the batches are concatenated along
    dim 0 in loader order. The log-variance output is discarded.
    """
    model.eval()
    with torch.no_grad():
        latent_batches = [
            model.encoder(batch.to(device))[0].cpu()
            for batch in dataloader
        ]
    return torch.cat(latent_batches, dim=0)

In [None]:
# Unshuffled loaders keep each latent row aligned with the sample order of its .npy file.
train_feature_loader = DataLoader(dataset=train_vbm_ds, batch_size=2, shuffle=False)
val_feature_loader = DataLoader(dataset=val_vbm_ds, batch_size=2, shuffle=False)

VBM_train_features = extract_latents(model_vbm, train_feature_loader, device=device)
VBM_val_features = extract_latents(model_vbm, val_feature_loader, device=device)

In [None]:
print(VBM_train_features.size())
print(VBM_val_features.size())

In [None]:
# Store the encoder mean vectors for both splits in one file, keyed by split name.
torch.save(dict(train=VBM_train_features, val=VBM_val_features), 'vbm_features.pt')
