In [None]:
!pip install torch torchvision tqdm pandas

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
# =============================
# STEP 1: INSTALL & IMPORT LIBRARIES
# =============================
!pip install torch torchvision tqdm pandas

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import json
import math
import pandas as pd
import glob
from PIL import Image
from tqdm import tqdm



In [None]:
import os
from google.colab import drive

# ✅ Mount Google Drive only when it is not mounted yet, so re-running this cell is safe
if not os.path.ismount("/content/drive"):
    drive.mount('/content/drive')

# ✅ Set paths (CHANGE THESE IF NEEDED)
BASE_DIR = "/content/drive/MyDrive/DATABase/"

# Ensure BASE_DIR exists (avoid errors)
if not os.path.exists(BASE_DIR):
    raise FileNotFoundError(f"⚠️ Base directory not found: {BASE_DIR}")

IMAGE_PATH = os.path.join(BASE_DIR, "HAM10000_images")
CSV_PATH = os.path.join(BASE_DIR, "ham10000_metadata.csv")  # ✅ CSV file for metadata
SAVE_FOLDER = os.path.join(BASE_DIR, "Generated_images")

# ✅ Create save folder if it doesn't exist
os.makedirs(SAVE_FOLDER, exist_ok=True)

# ✅ Ensure the image folder exists before proceeding.
# Without this check a wrong folder yields an empty dataset (glob finds nothing)
# and the failure only surfaces later, far away from its cause.
if not os.path.isdir(IMAGE_PATH):
    raise FileNotFoundError(f"⚠️ Image directory not found: {IMAGE_PATH}")

# ✅ Ensure the CSV file exists before proceeding
if not os.path.exists(CSV_PATH):
    raise FileNotFoundError(f"⚠️ CSV file not found: {CSV_PATH}")

print("✅ Paths set up successfully! CSV file found!")


Mounted at /content/drive
✅ Paths set up successfully! CSV file found!


In [None]:
# ✅ Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# ================================================
# STEP 2: LOAD DATASET WITH RICHER METADATA (256×256)
# ================================================

class CustomImageDataset(Dataset):
    """Image dataset that pairs every ``*.jpg`` with a metadata conditioning vector.

    Images are discovered with ``glob`` inside ``image_dir``. Metadata rows are
    read from ``csv_path`` and matched to an image through the ``isic_id``
    column, which must equal the image filename without its extension.

    Each item is ``(img, metadata_vector)``. ``metadata_vector`` is a float
    tensor laid out as::

        [melanocytic, sex(female=1), age<40, 40<=age<=65, age>65, *site one-hot]

    Args:
        image_dir: Folder containing the ``.jpg`` images.
        csv_path: CSV file with at least the columns ``isic_id`` and
            ``anatom_site_general``; ``melanocytic``, ``sex`` and
            ``age_approx`` are read per row when present.
        transform: Optional callable applied to the RGB PIL image.
    """

    # Age assumed when the row or its ``age_approx`` value is missing
    DEFAULT_AGE = 50

    def __init__(self, image_dir, csv_path, transform=None):
        # glob returns paths in arbitrary order; sort so index -> image is reproducible
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.jpg")))
        self.transform = transform

        # ✅ Load and clean CSV (one row per isic_id, indexed by isic_id)
        self.metadata = (
            pd.read_csv(csv_path)
            .drop_duplicates(subset="isic_id")
            .set_index("isic_id")
        )

        # ✅ Build lookup for categorical fields (one-hot)
        self.anatom_sites = sorted(self.metadata['anatom_site_general'].dropna().unique().tolist())
        self.site_to_idx = {site: i for i, site in enumerate(self.anatom_sites)}

        # Build the id -> row dict once. Rebuilding it inside __getitem__ would
        # convert the whole table again for every single sample.
        self._metadata_by_id = self.metadata.to_dict(orient="index")

    def __len__(self):
        return len(self.image_paths)

    def _encode_metadata(self, metadata_entry):
        """Turn one metadata row (a dict, possibly empty) into the conditioning tensor.

        Missing values (absent key, ``None`` or NaN) fall back to the same
        defaults as a completely missing row: melanocytic 0, sex 0,
        age ``DEFAULT_AGE``, all-zero site vector.
        """
        raw_melanocytic = metadata_entry.get("melanocytic", False)
        # NaN is truthy, so it must be filtered out before the truth test
        melanocytic = 0 if pd.isna(raw_melanocytic) else (1 if raw_melanocytic else 0)

        sex = 1 if metadata_entry.get("sex") == "female" else 0

        age = metadata_entry.get("age_approx", self.DEFAULT_AGE)
        if pd.isna(age):
            # NaN fails every comparison and would otherwise land in the ">65" bucket
            age = self.DEFAULT_AGE
        if age < 40:
            age_group = [1, 0, 0]
        elif age <= 65:
            age_group = [0, 1, 0]
        else:
            age_group = [0, 0, 1]

        site_vec = [0] * len(self.site_to_idx)
        site = metadata_entry.get("anatom_site_general", None)
        if site in self.site_to_idx:
            site_vec[self.site_to_idx[site]] = 1

        return torch.tensor(
            [melanocytic, sex] + age_group + site_vec,
            dtype=torch.float
        )

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # The context manager releases the file handle as soon as the pixels are copied
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")

        # splitext strips only the trailing extension (replace() would hit every ".jpg")
        filename = os.path.splitext(os.path.basename(img_path))[0]

        # ✅ Load metadata row (empty dict when the image has no CSV entry)
        metadata_entry = self._metadata_by_id.get(filename, {})
        metadata_vector = self._encode_metadata(metadata_entry)

        if self.transform:
            img = self.transform(img)

        return img, metadata_vector


# ✅ Define updated image preprocessing (256×256)
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # ⬅️ High-res update
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)  # maps [0, 1] pixels to [-1, 1]
])

# ✅ Load dataset
dataset = CustomImageDataset(IMAGE_PATH, CSV_PATH, transform=transform)

# Decode/augment images in background worker processes so loading overlaps with
# the training step (the default num_workers=0 does all of it on the training
# process). Pinned host memory only helps when batches are copied to a GPU.
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    pin_memory=(device.type == "cuda"),
)

print(f"✅ Dataset loaded! Total images: {len(dataset)}")


Using device: cuda
✅ Dataset loaded! Total images: 10015


In [None]:
def get_timestep_embedding(timesteps, dim=64):
    """Build sinusoidal embeddings for a batch of diffusion timesteps.

    Args:
        timesteps: 1-D tensor of shape (B,) holding the timestep of each sample.
        dim: Width of the returned embedding.

    Returns:
        Float tensor of shape (B, dim). The first ``dim // 2`` columns are
        sines and the next ``dim // 2`` are cosines of the timestep scaled by
        geometrically spaced frequencies from 1 down to 1/10000. When ``dim``
        is odd, a final zero column pads the result to exactly ``dim``.
    """
    half_dim = dim // 2
    # max(..., 1) keeps dim in (2, 3) from dividing by zero; for half_dim >= 2
    # this is the same step as before.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, device=timesteps.device) * -log_step)
    angles = timesteps[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if dim % 2 == 1:
        # sin/cos only fill 2 * half_dim columns; pad so the width matches `dim`
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=1)
    return emb  # Shape: (B, dim)

In [None]:
class UNetBlock(nn.Module):
    """3×3 convolution → batch norm → ReLU.

    The convolution uses padding 1 with the default stride, so the spatial
    size of the input is preserved and only the channel count changes
    (``in_ch`` → ``out_ch``).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        norm = nn.BatchNorm2d(out_ch)
        activation = nn.ReLU(inplace=True)
        # Kept as one Sequential named `block` so parameter names stay block.0 / block.1
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        features = self.block(x)
        return features

class SimpleUNet(nn.Module):
    """Noise-prediction network conditioned on a timestep and a metadata vector.

    The timestep embedding (64 values after ``time_mlp``) and the projected
    metadata (64 values after ``meta_fc``) are broadcast over the image plane
    and concatenated to the input channels. Four encoder blocks, a bottleneck
    and four decoder blocks follow, with each decoder block receiving the
    matching encoder output as a skip connection.

    Note: there is no pooling or upsampling layer in this class and every
    ``UNetBlock`` preserves H×W, so all feature maps stay at the input
    resolution; "down"/"up" only describe the channel widths and skip wiring.

    Args:
        img_channels: Number of channels of the input and of the prediction.
        metadata_size: Length of the metadata vector given to ``forward``.
    """

    def __init__(self, img_channels=3, metadata_size=12):
        super().__init__()

        # Conditioning branches: each one produces 64 values per sample
        self.meta_fc = nn.Linear(metadata_size, 64)
        self.time_mlp = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
        )

        # Encoder ("down" path); the first block also sees the 64 + 64 conditioning channels
        cond_channels = 64 + 64
        self.down1 = UNetBlock(img_channels + cond_channels, 64)
        self.down2 = UNetBlock(64, 128)
        self.down3 = UNetBlock(128, 256)
        self.down4 = UNetBlock(256, 512)

        # Bottleneck
        self.bottleneck = UNetBlock(512, 512)

        # Decoder ("up" path); inputs are previous output + skip, hence the sums
        self.up1 = UNetBlock(512 + 512, 256)
        self.up2 = UNetBlock(256 + 256, 128)
        self.up3 = UNetBlock(128 + 128, 64)
        self.up4 = UNetBlock(64 + 64, 64)

        # 1×1 projection back to image channels
        self.final = nn.Conv2d(64, img_channels, kernel_size=1)

    def forward(self, x, t, metadata):
        """Predict a tensor shaped like ``x`` from ``x``, timesteps ``t`` and ``metadata``."""
        batch, _, height, width = x.shape

        def as_plane(vec, channels):
            # (B, C) -> (B, C, 1, 1) -> broadcast view of shape (B, C, H, W)
            return vec.view(batch, -1, 1, 1).expand(batch, channels, height, width)

        time_plane = as_plane(self.time_mlp(get_timestep_embedding(t, dim=64)), -1)
        meta_plane = as_plane(self.meta_fc(metadata), 64)

        h = torch.cat([x, time_plane, meta_plane], dim=1)

        # Encoder: remember every stage output for the skip connections
        skips = []
        for stage in (self.down1, self.down2, self.down3, self.down4):
            h = stage(h)
            skips.append(h)

        h = self.bottleneck(h)

        # Decoder: consume the skips deepest-first (d4, d3, d2, d1)
        for stage in (self.up1, self.up2, self.up3, self.up4):
            h = stage(torch.cat([h, skips.pop()], dim=1))

        return self.final(h)


# ✅ Read one sample so the conditioning width follows the dataset's encoding
sample_img, sample_metadata = dataset[0]
metadata_size = sample_metadata.size(0)
print(f"Detected metadata size: {metadata_size}")

# ✅ Build the model, then move its parameters to the selected device
diffusion_model = SimpleUNet(img_channels=3, metadata_size=metadata_size)
diffusion_model = diffusion_model.to(device)


Detected metadata size: 12


In [None]:
# =============================
# STEP 4: TRAINING LOOP (256×256 + DEEP UNET + FULL CONDITIONING)
# =============================

import torch.nn.functional as F

# Number of diffusion steps in the noise schedule
timesteps = 1000
# Derive the checkpoint folder from BASE_DIR so changing BASE_DIR in the path
# setup cell moves the checkpoints too (same resulting path for the current BASE_DIR).
MODEL_SAVE_PATH = os.path.join(BASE_DIR, "trained_diffusion_model")
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

#Noise schedule
def cosine_noise_schedule(timesteps, s=0.005):
    """Return per-step betas derived from a squared-cosine ``alpha_bar`` curve.

    ``alpha_bar`` is evaluated on ``timesteps + 1`` evenly spaced points in
    [0, 1] (offset by ``s``); each beta is ``1 - alpha_bar[i+1] / alpha_bar[i]``,
    clamped to the range [1e-5, 0.02]. The result has ``timesteps`` entries and
    is created on the module-level ``device``.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, device=device) / timesteps
    phase = (grid + s) / (1 + s) * math.pi / 2
    alpha_bar = torch.cos(phase) ** 2
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    betas = 1 - step_ratio
    return torch.clamp(betas, min=1e-5, max=0.02)

#Precompute the schedule: betas, alphas and the running product of alphas
betas = cosine_noise_schedule(timesteps)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

#Optimizer
optimizer = optim.Adam(diffusion_model.parameters(), lr=2e-4)

#Early Stopping
patience = 15        # epochs without improvement tolerated before stopping
min_delta = 1e-4     # minimum decrease of the epoch loss that counts as improvement
best_loss = float('inf')
counter = 0
num_epochs = 300  # Increased for higher-res

for epoch in range(num_epochs):
    diffusion_model.train()
    epoch_loss = 0.0

    for real_images, metadata in tqdm(dataloader):
        real_images = real_images.to(device)
        metadata = metadata.to(device)
        # Replace any NaN entry of the conditioning vector with 0
        metadata = torch.where(torch.isnan(metadata), torch.zeros_like(metadata), metadata)

        noise = torch.randn_like(real_images)
        # Sample over every schedule index 0..timesteps-1 (randint's upper bound
        # is exclusive). Starting at 1 would leave index 0 untrained.
        t = torch.randint(0, timesteps, (real_images.shape[0],), device=device).long()

        # Forward diffusion: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise
        sqrt_alpha_cumprod_t = torch.sqrt(alphas_cumprod[t]).view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_cumprod_t = torch.sqrt(1 - alphas_cumprod[t]).view(-1, 1, 1, 1)
        noisy_images = sqrt_alpha_cumprod_t * real_images + sqrt_one_minus_alpha_cumprod_t * noise

        # The model is trained to recover the noise that was added
        predicted_noise = diffusion_model(noisy_images, t, metadata)
        loss = F.mse_loss(predicted_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(diffusion_model.parameters(), max_norm=1.0)
        optimizer.step()

        epoch_loss += loss.item()

    epoch_loss /= len(dataloader)
    print(f"✅ Epoch [{epoch+1}/{num_epochs}] - Avg Loss: {epoch_loss:.6f}")

    # ✅ Best model saving
    if epoch_loss < best_loss - min_delta:
        best_loss = epoch_loss
        counter = 0
        best_model_path = os.path.join(MODEL_SAVE_PATH, "best_model.pth")
        torch.save(diffusion_model.state_dict(), best_model_path)
        print(f"💾 Best model saved to: {best_model_path}")
    else:
        counter += 1

    # ✅ Save full checkpoint every 10 epochs.
    # Done before the early-stopping check so a stop that lands on such an
    # epoch still writes its checkpoint.
    if (epoch + 1) % 10 == 0:
        checkpoint_path = os.path.join(MODEL_SAVE_PATH, f"diffusion_model_epoch_{epoch+1}.pth")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': diffusion_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, checkpoint_path)
        print(f"💾 Full checkpoint saved: {checkpoint_path}")

    # ✅ Early stopping
    if counter >= patience:
        print(f"⏹️ Early stopping triggered at epoch {epoch+1}")
        break


100%|██████████| 2504/2504 [1:35:22<00:00,  2.29s/it]


✅ Epoch [1/300] - Avg Loss: 0.044751
💾 Best model saved to: /content/drive/MyDrive/DATABase/trained_diffusion_model/best_model.pth


100%|██████████| 2504/2504 [1:29:45<00:00,  2.15s/it]


✅ Epoch [2/300] - Avg Loss: 0.025341
💾 Best model saved to: /content/drive/MyDrive/DATABase/trained_diffusion_model/best_model.pth


100%|██████████| 2504/2504 [1:30:22<00:00,  2.17s/it]


✅ Epoch [3/300] - Avg Loss: 0.024173
💾 Best model saved to: /content/drive/MyDrive/DATABase/trained_diffusion_model/best_model.pth


100%|██████████| 2504/2504 [1:30:20<00:00,  2.16s/it]


✅ Epoch [4/300] - Avg Loss: 0.021514
💾 Best model saved to: /content/drive/MyDrive/DATABase/trained_diffusion_model/best_model.pth


100%|██████████| 2504/2504 [1:30:11<00:00,  2.16s/it]


✅ Epoch [5/300] - Avg Loss: 0.019528
💾 Best model saved to: /content/drive/MyDrive/DATABase/trained_diffusion_model/best_model.pth


100%|██████████| 2504/2504 [1:30:29<00:00,  2.17s/it]


✅ Epoch [6/300] - Avg Loss: 0.020283


100%|██████████| 2504/2504 [1:30:26<00:00,  2.17s/it]


✅ Epoch [7/300] - Avg Loss: 0.019541


100%|██████████| 2504/2504 [1:30:07<00:00,  2.16s/it]


✅ Epoch [8/300] - Avg Loss: 0.019333
💾 Best model saved to: /content/drive/MyDrive/DATABase/trained_diffusion_model/best_model.pth


100%|██████████| 2504/2504 [1:30:11<00:00,  2.16s/it]


✅ Epoch [9/300] - Avg Loss: 0.019535


100%|██████████| 2504/2504 [1:30:08<00:00,  2.16s/it]


✅ Epoch [10/300] - Avg Loss: 0.019258
💾 Full checkpoint saved: /content/drive/MyDrive/DATABase/trained_diffusion_model/diffusion_model_epoch_10.pth


100%|██████████| 2504/2504 [1:30:08<00:00,  2.16s/it]


✅ Epoch [11/300] - Avg Loss: 0.018891
💾 Best model saved to: /content/drive/MyDrive/DATABase/trained_diffusion_model/best_model.pth


100%|██████████| 2504/2504 [1:30:03<00:00,  2.16s/it]


✅ Epoch [12/300] - Avg Loss: 0.017563
💾 Best model saved to: /content/drive/MyDrive/DATABase/trained_diffusion_model/best_model.pth


100%|██████████| 2504/2504 [1:30:03<00:00,  2.16s/it]


✅ Epoch [13/300] - Avg Loss: 0.017227
💾 Best model saved to: /content/drive/MyDrive/DATABase/trained_diffusion_model/best_model.pth


 23%|██▎       | 574/2504 [20:41<1:09:17,  2.15s/it]