In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import umap
import time

# 1. Set random seeds for reproducibility
# Seed the global PyTorch and NumPy random number generators with a fixed
# value so repeated runs draw the same random numbers.
torch.manual_seed(42)
np.random.seed(42)
# Request deterministic cuDNN kernels and turn off the cuDNN benchmark
# (autotuning) mode; both flags are set for repeatability of GPU runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# 2. Dataset class (optimized path handling)
class OASISDataset(Dataset):
    """Unlabeled dataset of grayscale ``.nii.png`` images found under a directory.

    Every file whose name ends in ``.nii.png`` (case-insensitive) anywhere
    below ``root_dir`` becomes one sample. Indexing returns only the image
    (after the optional transform); there are no labels.

    Args:
        root_dir: Directory that is searched recursively for images.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # os.walk yields directories and files in an arbitrary,
        # filesystem-dependent order. Sorting makes the index -> file mapping
        # stable, so a seeded random split selects the same files on every run.
        self.image_paths = sorted(
            os.path.join(dirpath, name)
            for dirpath, _, filenames in os.walk(root_dir)
            for name in filenames
            if name.lower().endswith(".nii.png")
        )

        if not self.image_paths:
            print(f"Warning: No .nii.png files found in {root_dir}")

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as grayscale and apply the transform, if any.

        Loading is best-effort: on any failure the error is printed and a
        zero tensor of shape ``(1, 128, 128)`` is returned instead of raising.
        """
        img_path = self.image_paths[idx]
        try:
            # The context manager releases the file handle as soon as the
            # pixel data has been converted.
            with Image.open(img_path) as raw:
                img = raw.convert("L")  # Convert to grayscale
            if self.transform:
                img = self.transform(img)
            return img
        except Exception as e:
            print(f"Error loading {img_path}: {str(e)}")
            # NOTE: the placeholder size is fixed; it only matches real samples
            # when the transform also produces 1x128x128 tensors.
            return torch.zeros(1, 128, 128)  # Return placeholder

# 3. VAE model (combining best architectural features)
class BrainVAE(nn.Module):
    """Convolutional variational autoencoder for 1-channel 128x128 images.

    The encoder halves the spatial size four times (128 -> 8) while widening
    the channels 1 -> 32 -> 64 -> 128 -> 256, then two linear heads produce
    the latent mean and log-variance. The decoder mirrors this path and ends
    in a sigmoid, so reconstructions lie in [0, 1].

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim

        widths = (1, 32, 64, 128, 256)
        flat_features = widths[-1] * 8 * 8  # 256 channels on an 8x8 grid

        # Encoder: four stride-2 conv stages, each halving height and width
        down_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            down_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        down_layers.append(nn.Flatten())
        self.encoder = nn.Sequential(*down_layers)

        # Latent space heads
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Decoder: project back to 256x8x8, then four stride-2 upsampling stages
        self.decoder_input = nn.Linear(latent_dim, flat_features)

        reversed_widths = widths[::-1]  # (256, 128, 64, 32, 1)
        up_layers = [nn.Unflatten(1, (256, 8, 8))]
        for c_in, c_out in zip(reversed_widths[:-2], reversed_widths[1:-1]):
            up_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        # Final stage maps to one channel (64x64 -> 128x128) and squashes to [0, 1]
        up_layers += [
            nn.ConvTranspose2d(
                reversed_widths[-2], reversed_widths[-1], kernel_size=4, stride=2, padding=1
            ),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*up_layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of images."""
        features = self.encoder(x)
        mean = self.fc_mu(features)
        # Clamp the log-variance to [-20, 20] to prevent extreme values
        log_variance = self.fc_logvar(features).clamp(-20, 20)
        return mean, log_variance

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors back to images."""
        hidden = self.decoder_input(z)
        hidden = F.leaky_relu(hidden, negative_slope=0.2)
        return self.decoder(hidden)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch of images."""
        mean, log_variance = self.encode(x)
        latent = self.reparameterize(mean, log_variance)
        return self.decode(latent), mean, log_variance

# 4. Loss function (using MSE)
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Compute the per-sample VAE objective for one batch.

    Args:
        recon_x: Reconstructed batch.
        x: Target batch; its first dimension is used as the batch size.
        mu: Latent means.
        logvar: Latent log-variances.
        beta: Weight applied to the KL term.

    Returns:
        Tuple ``(total, recon_loss, kl_loss)`` where ``total`` is
        ``recon_loss + beta * kl_loss``; both terms are summed over all
        elements and divided by the batch size.
    """
    n_samples = x.size(0)

    # Summed squared error, normalized per sample
    squared_error = F.mse_loss(recon_x, x, reduction='sum')
    recon_term = squared_error / n_samples

    # KL divergence between N(mu, exp(logvar)) and the standard normal
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_inner) / n_samples

    total = recon_term + beta * kl_term
    return total, recon_term, kl_term

# 5. Training function (with early stopping and model saving)
def _mean_batch_loss(model, loader, device):
    """Return the VAE loss averaged over the batches of ``loader`` without gradients."""
    model.eval()
    running = 0.0
    with torch.no_grad():
        for images in loader:
            images = images.to(device)
            recon_images, mu, logvar = model(images)
            loss, _, _ = vae_loss(recon_images, images, mu, logvar)
            running += loss.item()
    return running / len(loader)


def train_and_evaluate(data_dir, img_size=128, batch_size=16, latent_dim=256, epochs=100, lr=1e-4):
    """Train a BrainVAE on the images under ``data_dir`` and evaluate it.

    The images are split 80/10/10 into train/validation/test subsets. Training
    uses Adam with gradient-norm clipping, saves the weights with the lowest
    validation loss to ``best_brain_vae.pth``, and stops early after 10 epochs
    without improvement. The best weights are then reloaded and scored on the
    test subset.

    Args:
        data_dir: Directory searched recursively for ``.nii.png`` images.
        img_size: Side length images are resized to; must be 128 because the
            model's layer sizes are fixed for 128x128 inputs.
        batch_size: Batch size for all three data loaders.
        latent_dim: Size of the VAE latent vector.
        epochs: Maximum number of training epochs.
        lr: Adam learning rate.

    Returns:
        Tuple ``(model, history, test_loader)``; ``history`` maps
        ``'train_loss'``, ``'val_loss'``, ``'recon_loss'`` and ``'kl_loss'``
        to lists of per-epoch averages.

    Raises:
        ValueError: If ``img_size`` is not 128, or if too few images are found
            to give non-empty train and validation subsets.
    """
    # Fail early with a clear message: the encoder/decoder hard-code an 8x8
    # bottleneck and a 128x128 output, so any other size breaks inside the model.
    if img_size != 128:
        raise ValueError(f"BrainVAE only supports img_size=128, got {img_size}")

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create dataset
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor()
    ])

    full_dataset = OASISDataset(data_dir, transform=transform)

    # Split dataset (80% train, 10% validation, 10% test)
    train_size = int(0.8 * len(full_dataset))
    val_size = int(0.1 * len(full_dataset))
    test_size = len(full_dataset) - train_size - val_size

    # With fewer than 10 images the validation subset is empty, which would
    # otherwise surface later as a division by zero when averaging.
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"Need at least 10 images for an 80/10/10 split, "
            f"found {len(full_dataset)} in {data_dir}"
        )

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size, test_size]
    )

    print(f"Dataset sizes: Train={len(train_dataset)}, Val={len(val_dataset)}, Test={len(test_dataset)}")

    # Create data loaders; pinned memory only helps host-to-GPU copies
    pin = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=pin)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=pin)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=pin)

    # Initialize model
    model = BrainVAE(latent_dim=latent_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    # Training loop
    best_val_loss = float('inf')
    patience, patience_counter = 10, 0
    checkpoint_saved = False  # True once this run has written the checkpoint
    history = {'train_loss': [], 'val_loss': [], 'recon_loss': [], 'kl_loss': []}

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss, recon_loss, kl_loss = 0, 0, 0

        for images in train_loader:
            images = images.to(device)
            optimizer.zero_grad()

            recon_images, mu, logvar = model(images)
            loss, r_loss, k_loss = vae_loss(recon_images, images, mu, logvar)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()
            recon_loss += r_loss.item()
            kl_loss += k_loss.item()

        # Validation phase
        avg_val_loss = _mean_batch_loss(model, val_loader, device)

        # Record history (averages are per batch)
        avg_train_loss = train_loss / len(train_loader)
        avg_recon = recon_loss / len(train_loader)
        avg_kl = kl_loss / len(train_loader)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['recon_loss'].append(avg_recon)
        history['kl_loss'].append(avg_kl)

        print(f"Epoch {epoch+1:03d}/{epochs} | "
              f"Train: {avg_train_loss:.4f} (Recon: {avg_recon:.4f}, KL: {avg_kl:.4f}) | "
              f"Val: {avg_val_loss:.4f}")

        # Early stopping and model saving
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_brain_vae.pth')
            checkpoint_saved = True
            print("Saved best model")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    # Load best model for testing. Only reload a checkpoint written by this
    # run: if the validation loss never improved (e.g. NaN losses or
    # epochs <= 0) the file is either missing or left over from another run.
    if checkpoint_saved:
        model.load_state_dict(torch.load('best_brain_vae.pth', map_location=device))
    else:
        print("Warning: no checkpoint was saved in this run; evaluating the current weights")
    model.eval()

    # Test set evaluation
    test_loss = _mean_batch_loss(model, test_loader, device)

    print(f"Test loss: {test_loss:.4f}")

    return model, history, test_loader

# 6. Visualization functions (combining best features)
def _save_training_curves(history):
    """Plot the per-epoch loss curves from ``history`` to training_curves.png."""
    plt.figure(figsize=(12, 8))
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.plot(history['recon_loss'], '--', label='Reconstruction Loss')
    plt.plot(history['kl_loss'], '--', label='KL Loss')
    plt.title('Training Curves')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('training_curves.png', dpi=300, bbox_inches='tight')
    plt.close()


def _save_reconstructions(model, images):
    """Plot up to 8 images above their reconstructions to reconstructions.png."""
    n_show = min(8, images.size(0))  # the batch may hold fewer than 8 images
    with torch.no_grad():
        recon_images, _, _ = model(images[:n_show])

    # squeeze=False keeps axes two-dimensional even when n_show == 1
    fig, axes = plt.subplots(2, n_show, figsize=(20, 5), squeeze=False)
    for i in range(n_show):
        axes[0, i].imshow(images[i].cpu().squeeze(), cmap='gray')
        axes[0, i].set_title(f'Original {i+1}')
        axes[0, i].axis('off')

        axes[1, i].imshow(recon_images[i].cpu().squeeze(), cmap='gray')
        axes[1, i].set_title(f'Reconstructed {i+1}')
        axes[1, i].axis('off')

    plt.suptitle('Original vs Reconstructed Images')
    plt.tight_layout()
    plt.savefig('reconstructions.png', dpi=300, bbox_inches='tight')
    plt.close()


def _save_generated_samples(model, device):
    """Decode 16 standard-normal latent vectors to generated_samples.png."""
    with torch.no_grad():
        z = torch.randn(16, model.latent_dim).to(device)
        generated = model.decode(z)

    fig, axes = plt.subplots(4, 4, figsize=(10, 10))
    for i in range(16):
        ax = axes[i // 4, i % 4]
        ax.imshow(generated[i].cpu().squeeze(), cmap='gray')
        ax.axis('off')
        ax.set_title(f'Sample {i+1}', fontsize=8)

    plt.suptitle('Generated Brain MRI Samples')
    plt.tight_layout()
    plt.savefig('generated_samples.png', dpi=300, bbox_inches='tight')
    plt.close()


def _save_latent_umap(model, test_loader, device):
    """Embed the latent means of all test images with UMAP to latent_space_umap.png."""
    latent_vectors = []
    with torch.no_grad():
        for images in test_loader:
            images = images.to(device)
            mu, _ = model.encode(images)
            latent_vectors.append(mu.cpu().numpy())

    latent_vectors = np.concatenate(latent_vectors)

    # NOTE(review): with n_neighbors=15, UMAP presumably needs more than a
    # handful of samples to give a meaningful embedding -- confirm for tiny sets.
    print("Applying UMAP dimensionality reduction...")
    reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=15, min_dist=0.1)
    embedding = reducer.fit_transform(latent_vectors)

    plt.figure(figsize=(12, 10))
    plt.scatter(embedding[:, 0], embedding[:, 1], alpha=0.6, s=10,
                c=np.arange(len(embedding)), cmap='viridis')
    plt.colorbar(label='Sample Index')
    plt.title('VAE Latent Space UMAP Visualization')
    plt.xlabel('UMAP Dimension 1')
    plt.ylabel('UMAP Dimension 2')
    plt.grid(True, alpha=0.2)
    plt.savefig('latent_space_umap.png', dpi=300, bbox_inches='tight')
    plt.close()


def _save_interpolation(model, images):
    """Decode 10 linear blends between the first two images' latent means."""
    if images.size(0) < 2:
        print("Skipping latent interpolation: need at least 2 test images")
        return

    with torch.no_grad():
        mu1, _ = model.encode(images[0].unsqueeze(0))
        mu2, _ = model.encode(images[1].unsqueeze(0))

        fig, axes = plt.subplots(1, 10, figsize=(20, 3))
        for i, alpha in enumerate(np.linspace(0, 1, 10)):
            z = (1-alpha)*mu1 + alpha*mu2
            recon = model.decode(z)
            axes[i].imshow(recon[0].cpu().squeeze(), cmap='gray')
            axes[i].set_title(f'α={alpha:.1f}')
            axes[i].axis('off')

        plt.suptitle('Latent Space Interpolation')
        plt.tight_layout()
        plt.savefig('latent_interpolation.png', dpi=300, bbox_inches='tight')
        plt.close()


def visualize_results(model, test_loader, history):
    """Save diagnostic figures for a trained VAE to PNG files.

    Writes, in order: ``training_curves.png``, ``reconstructions.png``,
    ``generated_samples.png``, ``latent_space_umap.png`` and
    ``latent_interpolation.png`` into the current working directory.
    Figures that need test images are skipped (with a printed message) when
    the loader yields too few of them.

    Args:
        model: Trained model exposing ``encode``, ``decode`` and ``latent_dim``;
            calling it must return ``(reconstruction, mu, logvar)``.
        test_loader: Loader yielding batches of image tensors.
        history: Dict with per-epoch lists under ``'train_loss'``,
            ``'val_loss'``, ``'recon_loss'`` and ``'kl_loss'``.
    """
    device = next(model.parameters()).device
    model.eval()

    # 1. Training curves
    _save_training_curves(history)

    # Fetch the first test batch once and reuse it for the reconstruction
    # and interpolation figures.
    first_batch = next(iter(test_loader), None)
    images = None if first_batch is None else first_batch.to(device)

    # 2. Reconstruction visualization
    if images is not None:
        _save_reconstructions(model, images)
    else:
        print("Test loader is empty; skipping figures that need test images")

    # 3. Generate new samples
    _save_generated_samples(model, device)

    if images is not None:
        # 4. Latent space UMAP visualization
        _save_latent_umap(model, test_loader, device)

        # 5. Latent space interpolation
        _save_interpolation(model, images)

# 7. Main function
def main():
    """Train the VAE on the OASIS PNG slices, then save all visualizations."""
    # Location of the OASIS training slices (a path under /content/drive)
    data_dir = "/content/drive/MyDrive/OASIS/keras_png_slices_train"

    # Hyperparameters for this run
    settings = dict(
        img_size=128,    # side length images are resized to
        batch_size=32,   # adjust based on GPU memory
        latent_dim=256,  # latent space dimension
        epochs=50,
        lr=1e-4,
    )

    # Train and evaluate model
    model, history, test_loader = train_and_evaluate(data_dir, **settings)

    # Visualize results
    visualize_results(model, test_loader, history)
    print("All visualizations saved!")

# Run the full pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Using device: cuda
Dataset sizes: Train=7731, Val=966, Test=967
Epoch 001/50 | Train: 1058.3377 (Recon: 1016.8213, KL: 41.5164) | Val: 383.3327
Saved best model
Epoch 002/50 | Train: 249.8542 (Recon: 220.3515, KL: 29.5027) | Val: 169.6264
Saved best model
Epoch 003/50 | Train: 132.5908 (Recon: 110.5072, KL: 22.0836) | Val: 125.9742
Saved best model
Epoch 004/50 | Train: 110.1345 (Recon: 88.8610, KL: 21.2735) | Val: 106.7082
Saved best model
Epoch 005/50 | Train: 101.6709 (Recon: 81.2563, KL: 20.4145) | Val: 98.6926
Saved best model
Epoch 006/50 | Train: 96.5979 (Recon: 76.5166, KL: 20.0812) | Val: 95.6821
Saved best model
Epoch 007/50 | Train: 93.7890 (Recon: 73.8883, KL: 19.9007) | Val: 93.6851
Saved best model
Epoch 008/50 | Train: 90.5566 (Recon: 71.2077, KL: 19.3489) | Val: 90.1019
Saved best model
Epoch 009/50 | Train: 88.5290 (Recon: 69.5001, KL: 19.0289) | Val: 86.7307
Saved best model
Epoch 010/50 | Train: 86.4846 (Recon: 67.7098, KL: 18.7749) | Val: 86.9081
Epoch 011/50 | Trai

  warn(


All visualizations saved!
