In [None]:
import os
import sys

import pickle

from pathlib import Path
from dataclasses import dataclass
from tqdm.notebook import tqdm


from diffusers import DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_cosine_schedule_with_warmup

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset


from PIL import Image
import numpy as np
import math



# Get absolute path to project root
# os.path.abspath('') resolves to the current working directory, so the root is taken to be
# two levels above wherever the notebook kernel was started.
project_root = Path(os.path.abspath('')).parent.parent
# Make the root importable; the `src.pipelines` import below presumably relies on this.
sys.path.append(str(project_root))

# Load environment variables (python-dotenv) before they are read below.
from dotenv import load_dotenv
load_dotenv()

# None when NIH_CXR14_DATASET_DIR is not set in the environment.
nih_dataset_root_dir = os.getenv("NIH_CXR14_DATASET_DIR")

# Directory holding latents.pkl and clip_text_embeds.pkl.
# NOTE: relative to the working directory, not to project_root.
main_data_dir = "../data"



# Project-local helper; generate_sample uses it to decode latents.
from src.pipelines import VaeProcessor

## Configuration

In [None]:
# Keyword arguments for diffusers' UNet2DConditionModel; the model is built with
# UNet2DConditionModel(**unet_config).
unet_config = {
    # Architecture parameters
    "sample_size": 28,  # for 224x224 images (224 = 28 * 8)
    # NOTE(review): 4 in / 4 out presumably matches the channel count of the stored
    # latents — confirm against latents.pkl.
    "in_channels": 4,
    "out_channels": 4,
    "layers_per_block": 2,
    "block_out_channels": (320, 640, 1280, 1280),
    # One entry per resolution level, in the same order as block_out_channels.
    "down_block_types": (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D"
    ),
    # Mirror image of down_block_types.
    "up_block_types": (
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D"
    ),
    
    # Attention parameters
    "attention_head_dim": 8,
    # Width of the text embeddings passed as encoder_hidden_states
    # (CustomDataset pads them with 768-wide zero rows).
    "cross_attention_dim": 768,
    
    # Normalization and activation
    "norm_num_groups": 32,
    "norm_eps": 1e-05,
    "act_fn": "silu",
    
    # Additional configuration
    "center_input_sample": False,
    "downsample_padding": 1,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "mid_block_scale_factor": 1
}


# To create the model from this config: UNet2DConditionModel(**unet_config)


@dataclass
class train_config:
    """Training hyper-parameters and runtime settings for the diffusion UNet.

    Every annotated attribute is a real dataclass field, so individual settings
    can be overridden per run (e.g. ``train_config(batch_size=16)``) and show up
    in ``repr(config)``. Defaults are also readable on the class itself
    (e.g. ``train_config.device``).
    """

    # These two were the only annotated fields originally; they stay first so
    # that positional construction keeps its meaning.
    lr_warmup_steps: int = 500
    gradient_accumulation_steps: int = 1

    device: str = "cuda:2"            # device string, e.g. "cuda:2"
    num_workers: int = 32             # DataLoader worker processes
    batch_size: int = 32
    mixed_precision: str = "fp16"     # forwarded to accelerate.Accelerator
    output_dir: str = "output"        # checkpoints, logs and sample grids
    save_model_epochs: int = 1        # checkpoint + sample every N epochs
    num_epochs: int = 20
    num_train_timesteps: int = 1000   # diffusion steps of the training scheduler
    learning_rate: float = 1e-5

    # Deliberately left unannotated: a dict is a mutable default, which
    # @dataclass rejects for fields, so this stays a shared class attribute
    # pointing at the module-level `unet_config`.
    unet_config = unet_config

    

# Instantiate the run configuration, then make its GPU the current CUDA device.
config = train_config()

# torch.device parses "cuda:2", "cuda" and "cpu" alike; only pin a device when
# an explicit CUDA index was requested (int("cuda") / int("cpu") used to crash).
_train_device = torch.device(config.device)
if _train_device.type == "cuda" and _train_device.index is not None:
    torch.cuda.set_device(_train_device)

In [None]:
# Load the pre-computed latents and text embeddings; CustomDataset looks up both
# mappings with the keys of `latents`.
# NOTE(security): pickle.load can execute arbitrary code — only load files you
# produced yourself or otherwise trust.
with open(f"{main_data_dir}/latents.pkl", "rb") as latents_file:
    latents = pickle.load(latents_file)

with open(f"{main_data_dir}/clip_text_embeds.pkl", "rb") as text_embeds_file:
    text_embeds = pickle.load(text_embeds_file)



In [None]:
class CustomDataset(Dataset):
    """Pairs each stored latent with the text embedding saved under the same key.

    Args:
        latents: mapping ``key -> latent`` (numpy array or tensor). Its keys
            define the dataset order and length.
        text_embeds: mapping ``key -> text embedding`` (numpy array or tensor),
            looked up with the keys of ``latents``.
        latent_transform: optional callable applied to the latent tensor.
        text_embed_transform: optional callable applied to the text embedding
            after it has been padded / trimmed.
        max_text_len: number of rows (first dimension) every text embedding is
            padded or trimmed to. Defaults to 77, the previously hard-coded value.
    """

    def __init__(self, latents, text_embeds, latent_transform=None, text_embed_transform=None,
                 max_text_len=77):
        self.latents = latents
        self.text_embeds = text_embeds
        self._keys = list(latents.keys())
        self.text_transform = text_embed_transform
        self.latent_transform = latent_transform
        self.max_text_len = max_text_len

    def __len__(self):
        return len(self._keys)

    @staticmethod
    def _as_float_tensor(value):
        """Convert numpy arrays to float32 tensors; leave anything else untouched."""
        if isinstance(value, np.ndarray):
            return torch.from_numpy(value).float()
        return value

    def _fit_text_length(self, text_embed):
        """Zero-pad or trim ``text_embed`` along dim 0 to ``self.max_text_len`` rows."""
        missing = self.max_text_len - text_embed.shape[0]
        if missing > 0:
            # Derive trailing shape, dtype and device from the embedding itself
            # instead of assuming a 768-wide float32 CPU tensor.
            padding = text_embed.new_zeros((missing, *text_embed.shape[1:]))
            return torch.cat((text_embed, padding), 0)
        if missing < 0:
            return text_embed[:self.max_text_len]
        return text_embed

    def __getitem__(self, idx):
        """Return ``(latent, text_embed)`` for the ``idx``-th key."""
        key = self._keys[idx]

        text_embed = self._fit_text_length(self._as_float_tensor(self.text_embeds[key]))
        if self.text_transform:
            text_embed = self.text_transform(text_embed)

        latent = self._as_float_tensor(self.latents[key])
        if self.latent_transform:
            latent = self.latent_transform(latent)

        return latent, text_embed

def save_images_grid(images, output_path, grid_size=None, padding=10):
    """Tile ``images`` on a white RGB canvas, save it and return the canvas.

    Args:
        images: non-empty sequence of image objects exposing ``width`` and
            ``height`` that can be pasted onto a PIL image.
        output_path: where the grid is written.
        grid_size: optional ``(rows, cols)``; when omitted a near-square layout
            is derived from the number of images.
        padding: gap in pixels between neighbouring cells.

    Returns:
        The assembled grid image.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if not images:
        raise ValueError("Image list is empty")

    if grid_size is not None:
        n_rows, n_cols = grid_size
    else:
        n_cols = math.ceil(math.sqrt(len(images)))
        n_rows = math.ceil(len(images) / n_cols)

    # Every cell is as large as the largest image; smaller ones get centred.
    cell_w = max(im.width for im in images)
    cell_h = max(im.height for im in images)
    stride_x = cell_w + padding
    stride_y = cell_h + padding

    canvas = Image.new(
        'RGB',
        (stride_x * n_cols - padding, stride_y * n_rows - padding),
        color='white',
    )

    for position, im in enumerate(images):
        r, c = divmod(position, n_cols)
        left = c * stride_x + (cell_w - im.width) // 2
        top = r * stride_y + (cell_h - im.height) // 2
        canvas.paste(im, (left, top))

    canvas.save(output_path, quality=95)
    return canvas
def generate_sample(model, samples, device, num_inference_steps=25):
    """Generate one decoded image per sample by denoising from Gaussian noise.

    Args:
        model: noise-prediction network, called as
            ``model(latent, t, encoder_hidden_states=text_embed, return_dict=False)[0]``.
            It is switched to eval mode and left there.
        samples: iterable of ``(latent, text_embed)`` pairs. The latent only
            provides shape and dtype for the starting noise; its values are not
            used, otherwise the reverse process would start from training data
            instead of noise.
        device: device the sampling runs on; also passed to ``VaeProcessor``.
        num_inference_steps: number of scheduler timesteps (default 25, the
            previously hard-coded value).

    Returns:
        list with one ``VaeProcessor.decode_latent`` result per sample
        (presumably PIL images, since callers pass them to save_images_grid —
        confirm against VaeProcessor).
    """
    scheduler = DDPMScheduler()
    scheduler.set_timesteps(num_inference_steps)

    model.eval()
    vae_processor = VaeProcessor(device)
    images = []

    with torch.no_grad():
        for latent, text_embed in samples:
            # Ensure correct shape for latents [batch_size, channels, height, width]
            if len(latent.shape) == 3:
                latent = latent.unsqueeze(0)  # Add batch dimension if missing

            # Move to device
            latent = latent.to(device)
            text_embed = text_embed.to(device)

            # Ensure text embeddings have correct shape [batch_size, sequence_length, hidden_size]
            if len(text_embed.shape) == 2:
                text_embed = text_embed.unsqueeze(0)

            # Start from pure Gaussian noise shaped like the reference latent.
            latent = torch.randn_like(latent) * scheduler.init_noise_sigma

            # Denoising loop
            for t in scheduler.timesteps:
                latent_model_input = scheduler.scale_model_input(latent, t)

                noise_pred = model(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embed,
                    return_dict=False
                )[0]

                latent = scheduler.step(noise_pred, t, latent).prev_sample

            # Decode and append the generated image
            images.append(vae_processor.decode_latent(latent))

    del vae_processor
    return images

In [None]:
dataset = CustomDataset(latents, text_embeds, latent_transform=None, text_embed_transform=None)

# Pick the validation samples with a seeded generator so the "fixed" set is the
# same on every run. randperm gives distinct indices and also copes with
# datasets that hold fewer than 10 items.
validation_rng = torch.Generator().manual_seed(0)
random_idx = torch.randperm(len(dataset), generator=validation_rng)[:10]

fixed_validate_samples = [dataset[int(idx)] for idx in random_idx]

dataloader = DataLoader(dataset,
                        batch_size=config.batch_size,
                        shuffle=True,
                        num_workers=config.num_workers,
                        pin_memory=True)



In [None]:
# Sanity check: shapes of the (latent, text_embed) pair of the first validation sample.
print(*(fixed_validate_samples[0][part].shape for part in (0, 1)))

In [None]:
# UNet that predicts the noise added to a latent, conditioned on text embeddings.
model = UNet2DConditionModel(**unet_config)

# Scheduler used to add noise to the latents during training.
noise_scheduler = DDPMScheduler(num_train_timesteps=config.num_train_timesteps)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# train() calls lr_scheduler.step() once per batch, so the schedule length must
# be counted in batches (len(dataloader)), not in samples (len(dataset)).
# Counting samples stretched the cosine decay by a factor of about batch_size.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=len(dataloader) * config.num_epochs,
)



In [None]:
from accelerate import Accelerator
import tensorboard as tb

# Accelerator handles mixed precision and gradient accumulation and logs to
# TensorBoard under <output_dir>/logs.
accelerator = Accelerator(
    mixed_precision=config.mixed_precision,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    log_with="tensorboard",
    project_dir=os.path.join(config.output_dir, "logs"),
)

# Only the main process creates the output directory and starts the trackers.
if accelerator.is_main_process and config.output_dir is not None:
    os.makedirs(config.output_dir, exist_ok=True)
if accelerator.is_main_process:
    accelerator.init_trackers("train_example")



In [None]:
# Hand model, optimizer, LR scheduler and dataloader to accelerate.
# The names are rebound to the prepared objects.
(model, optimizer, lr_scheduler, dataloader) = accelerator.prepare(
    model,
    optimizer,
    lr_scheduler,
    dataloader,
)

In [None]:
def save_model_checkpoint(model, accelerator, output_dir, epoch):
    """Write the unwrapped model to ``<output_dir>/checkpoint-<epoch>/unet``.

    Args:
        model: model as returned by ``accelerator.prepare``.
        accelerator: provides ``unwrap_model`` and ``is_main_process``.
        output_dir: parent directory for the checkpoint folder.
        epoch: label used in the folder name (an epoch number or e.g. "final").

    Nothing is written on processes other than the main one.
    """
    # Strip the accelerate wrapper to get at the underlying model.
    raw_model = accelerator.unwrap_model(model)

    if not accelerator.is_main_process:
        return

    checkpoint_dir = os.path.join(output_dir, f"checkpoint-{epoch}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Save in diffusers format.
    raw_model.save_pretrained(os.path.join(output_dir, f"checkpoint-{epoch}/unet"))

In [None]:
def train(config, model, noise_scheduler, dataloader, optimizer, lr_scheduler, accelerator):
    """
    Training loop for diffusion model.

    For every batch of ``(latents, text_embeds)`` the latents are noised at
    random timesteps, the model predicts that noise and is optimised with an
    MSE loss. Every ``config.save_model_epochs`` epochs a checkpoint and a grid
    of validation samples are written to ``<config.output_dir>/.cache``; a final
    checkpoint and grid go to ``config.output_dir``.

    Args:
        config: settings object; reads ``num_epochs``, ``save_model_epochs``
            and ``output_dir``.
        model: noise-prediction network (already prepared by ``accelerator``).
        noise_scheduler: scheduler providing ``add_noise`` and
            ``config.num_train_timesteps``.
        dataloader: yields ``(latents, text_embeds)`` batches.
        optimizer: optimizer for ``model``.
        lr_scheduler: learning-rate scheduler, stepped once per batch.
        accelerator: ``accelerate.Accelerator`` driving backward, gradient
            clipping, accumulation and logging.

    Returns:
        int: number of batches processed over the whole run.

    Note:
        Validation images are generated from the module-level
        ``fixed_validate_samples``.
    """
    global_step = 0
    device = accelerator.device
    # Read through `.config`: accessing scheduler config values as plain
    # attributes is deprecated in diffusers.
    num_train_timesteps = noise_scheduler.config.num_train_timesteps

    # Progress bar for epochs
    progress_bar = tqdm(total=config.num_epochs, disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Epochs")

    for epoch in range(config.num_epochs):
        # generate_sample() switches the model to eval mode, so re-enable
        # training mode at the start of every epoch.
        model.train()

        # Progress bar for steps
        step_progress_bar = tqdm(total=len(dataloader), disable=not accelerator.is_local_main_process)
        step_progress_bar.set_description(f"Epoch {epoch}")

        # Track epoch metrics
        epoch_loss = 0

        for step, (latents, text_embeds) in enumerate(dataloader):
            # Prepare inputs
            latents = latents.to(device)
            text_embeds = text_embeds.to(device)

            # Sample noise and one random timestep per batch element
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, num_train_timesteps, (latents.shape[0],), device=latents.device
            ).long()

            # Add noise to latents
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Training step
            with accelerator.accumulate(model):
                # Predict noise
                noise_pred = model(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=text_embeds,
                    return_dict=False
                )[0]

                # Calculate loss; read it back from the device only once per step
                loss = F.mse_loss(noise_pred, noise)
                loss_value = loss.detach().item()
                epoch_loss += loss_value

                # Backpropagation
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Logging
            if accelerator.is_main_process:
                logs = {
                    "loss": loss_value,
                    "avg_loss": epoch_loss / (step + 1),
                    "lr": lr_scheduler.get_last_lr()[0],
                    "epoch": epoch,
                    "step": global_step,
                }
                step_progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

            global_step += 1
            step_progress_bar.update(1)

        # Save checkpoint and generate samples
        if (epoch + 1) % config.save_model_epochs == 0:
            # Intermediate artefacts live in a ".cache" sub-folder of the output directory
            output_dir = os.path.join(config.output_dir, ".cache")
            os.makedirs(output_dir, exist_ok=True)

            # Save model
            save_model_checkpoint(model, accelerator, output_dir, epoch)

            # Generate and save samples
            with torch.no_grad():
                images = generate_sample(model, fixed_validate_samples, device)
                save_images_grid(images, os.path.join(output_dir, f"sample_{epoch}.png"))

        step_progress_bar.close()
        progress_bar.update(1)

    # Final artefacts: sample on the accelerator's device rather than a
    # hard-coded "cuda:2", so the run works on any device.
    save_model_checkpoint(model, accelerator, config.output_dir, "final")
    images = generate_sample(model, fixed_validate_samples, device)
    save_images_grid(images, os.path.join(config.output_dir, "final_sample.png"))

    progress_bar.close()

    return global_step

In [None]:
# Launch training; the returned global step count is shown as the cell output.
train(
    config=config,
    model=model,
    noise_scheduler=noise_scheduler,
    dataloader=dataloader,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    accelerator=accelerator,
)