# ELIXR-B (32x768 embedding)-conditioned latent UNet training

In [None]:
import os
import sys
import h5py
from dotenv import load_dotenv 
from pathlib import Path

project_root = Path(os.path.abspath('')).parent.parent
# Make the repository root importable (for `src.utils`); skip the append when
# the cell is re-run so sys.path does not accumulate duplicate entries.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Pull variables such as NIH_CXR14_DATASET_DIR from a local .env file.
load_dotenv()

NIH_CXR14_DATASET_DIR = os.getenv("NIH_CXR14_DATASET_DIR")
if NIH_CXR14_DATASET_DIR is None:
    # Every dataset path below is joined onto this directory; fail here with a
    # clear message instead of a TypeError from os.path.join in a later cell.
    raise RuntimeError(
        "NIH_CXR14_DATASET_DIR is not set; define it in the environment or a .env file."
    )
print(NIH_CXR14_DATASET_DIR)
# Only takes effect if set before CUDA is initialised; torch is first imported
# in the next cell, so setting it here is early enough.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


In [None]:
from src.utils import Unet2DConditionalTrainerV2, TrainConfigV2, get_validation_samples_v2, create_validation_dataloader
from diffusers import DDPMScheduler, UNet2DConditionModel
import torch
from torch.utils.data import Dataset
import torch.optim as optim


import numpy as np
import os
import torch
from torch.utils.data import Dataset
import h5py
import pandas as pd

class EmbeddingLatentDataset(Dataset):
    """Pair latent vectors with conditioning embeddings, reading each item from disk on demand.

    The image IDs to use come from the ``'Image Index'`` column of a CSV file.
    Only IDs that also appear in both the latent H5 files and the embedding H5
    files are kept, in CSV order. Items are read lazily in ``__getitem__``, so
    memory use does not grow with the number of samples.
    """

    def __init__(self, latent_dir, embedding_dir, csv_path):
        """
        Index the H5 files and compute the list of usable image IDs.

        Args:
            latent_dir (str): Directory containing latent H5 files (array dataset "latents").
            embedding_dir (str): Directory containing embedding H5 files (array dataset "img_emb").
            csv_path (str): CSV file whose 'Image Index' column lists the image IDs to keep.
        """
        self.latent_dir = latent_dir
        self.embedding_dir = embedding_dir

        # Image IDs selected for training, in CSV order.
        self.selected_image_df = pd.read_csv(csv_path)
        self.selected_image_ids = self.selected_image_df['Image Index'].tolist()
        print(f"Loaded {len(self.selected_image_ids)} image IDs from CSV")

        # Every *.h5 file in each directory is scanned.
        self.latent_files = [os.path.join(latent_dir, f) for f in os.listdir(latent_dir) if f.endswith(".h5")]
        self.embedding_files = [os.path.join(embedding_dir, f) for f in os.listdir(embedding_dir) if f.endswith(".h5")]

        # image ID -> (file_path, row) for each source.
        self.latent_id_map = self._build_id_map(self.latent_files, "latents")
        self.embedding_id_map = self._build_id_map(self.embedding_files, "img_emb")

        csv_set = set(self.selected_image_ids)
        latent_set = set(self.latent_id_map)
        embedding_set = set(self.embedding_id_map)

        # IDs present in the CSV and in both H5 sources.
        available_ids = csv_set & latent_set & embedding_set

        # Keep the CSV order (including any repeated IDs). Membership is tested
        # against a set so this stays linear in the number of CSV rows.
        self.common_ids = [img_id for img_id in self.selected_image_ids if img_id in available_ids]

        print(f"Total matching IDs after filtering: {len(self.common_ids)}")

        # Statistics
        print(f"IDs in CSV but missing latents: {len(csv_set - latent_set)}")
        print(f"IDs in CSV but missing embeddings: {len(csv_set - embedding_set)}")

    def _build_id_map(self, files, dataset_name):
        """
        Map every image ID found in ``files`` to the file and row that hold it.

        Args:
            files (list): H5 file paths to scan.
            dataset_name (str): Name of the data array in these files ("latents" or
                "img_emb"). Not used here; only the image-ID dataset is read.

        Returns:
            dict: image ID -> (file_path, row_index). If an ID occurs in more than
            one file, the entry from the file scanned last is kept.

        Raises:
            KeyError: if a file has neither an "Image index" nor an "Image Index" dataset.
        """
        id_map = {}

        for file_path in files:
            with h5py.File(file_path, "r") as hf:
                # Files name the ID dataset with either capitalisation.
                try:
                    image_indices = hf["Image index"][:]
                except KeyError:
                    image_indices = hf["Image Index"][:]

            # IDs are already in memory ([:] copies them), so the file can be
            # closed. Decode per element so empty files need no special case.
            for i, img_id in enumerate(image_indices):
                if isinstance(img_id, bytes):
                    img_id = img_id.decode("utf-8")
                id_map[img_id] = (file_path, i)

        return id_map

    def __len__(self):
        """Return the number of matched image pairs."""
        return len(self.common_ids)

    def __getitem__(self, idx):
        """
        Load the latent and embedding pair at position ``idx``.

        The H5 files are opened on each call, so no file handles are kept on the
        dataset object.

        Args:
            idx (int): Index into ``self.common_ids``.

        Returns:
            tuple: (latent, cond_embed) as float32 torch tensors.
        """
        img_id = self.common_ids[idx]

        latent_file, latent_idx = self.latent_id_map[img_id]
        embedding_file, embedding_idx = self.embedding_id_map[img_id]

        with h5py.File(latent_file, "r") as hf:
            latent = hf["latents"][latent_idx]

        with h5py.File(embedding_file, "r") as hf:
            cond_embed = hf["img_emb"][embedding_idx]

        latent = torch.tensor(latent, dtype=torch.float32)
        cond_embed = torch.tensor(cond_embed, dtype=torch.float32)

        # Drop a leading singleton axis if present, so the embedding is passed to
        # the UNet as a token sequence. Presumably stored as (1, 32, 768) ->
        # (32, 768) per the notebook title; confirm against the H5 files.
        cond_embed = cond_embed.squeeze(0)

        return latent, cond_embed

In [None]:
# ELIXR-B image-embedding H5 files (conditioning inputs).
elixrb_dir = os.path.join(NIH_CXR14_DATASET_DIR, "elixr", "elixrb", "img_emb")
# Pre-computed latent H5 files.
latent_dir = os.path.join(NIH_CXR14_DATASET_DIR, "vae_latents2")
# CSV listing the image IDs to train on. Set FILTERED_LABELS_CSV (e.g. in .env)
# to run on another machine; otherwise the original absolute path is used.
csv_path = os.getenv(
    "FILTERED_LABELS_CSV",
    "/home/yasin/Lfstorage/Projects/cxr-diffusion/intermediate_data/filtered_findings_label_data.csv",
)


In [None]:
# Build the lazily-loaded training set; the trailing expression displays its size.
dataset = EmbeddingLatentDataset(
    latent_dir=latent_dir,
    embedding_dir=elixrb_dir,
    csv_path=csv_path,
)
len(dataset)

In [None]:
# Effective batch size = batch_size * gradient_accumulation_steps = 4 * 64 = 256.
# If an effective batch of 128 was intended, use gradient_accumulation_steps=32.
config = TrainConfigV2(
    batch_size=4,
    mixed_precision=True,
    learning_rate=1e-5,
    num_epochs=20,
    gradient_accumulation_steps=64,  # effective batch size of 256 (4 * 64)
    scheduler_type="cosine",
    early_stopping_patience=5,
    use_timestep_weights=True,
)

In [None]:
# Conditional UNet operating on 4-channel latents. cross_attention_dim is the
# width of the conditioning token vectors; presumably the 768-dim ELIXR-B
# embedding named in the notebook title (confirm against the H5 data).
unet = UNet2DConditionModel(
    # Input / output geometry
    sample_size=64,  # generated samples are 512x512 (presumably 64x64 latents decoded by the VAE)
    in_channels=4,
    out_channels=4,
    center_input_sample=False,
    # Encoder / decoder layout
    block_out_channels=(320, 640, 1280, 1280),
    layers_per_block=2,
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    downsample_padding=1,
    mid_block_scale_factor=1,
    # Attention / conditioning
    attention_head_dim=8,
    cross_attention_dim=768,
    # Normalisation, activation, timestep embedding
    norm_num_groups=32,
    norm_eps=1e-05,
    act_fn="silu",
    flip_sin_to_cos=True,
    freq_shift=0,
)

In [None]:
# DDPM noise scheduler with 1000 training timesteps; other settings use library defaults.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
# AdamW over every UNet parameter at the configured learning rate.
optimizer = torch.optim.AdamW(params=unet.parameters(), lr=config.learning_rate)

In [None]:
# Shuffled mini-batches of (latent, cond_embed) pairs. EmbeddingLatentDataset
# opens its H5 files inside __getitem__, so worker processes share no open handles.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
)


In [None]:
# Fixed set of dataset samples handed to the trainer for validation.
num_validation_samples = 16
validation_samples = get_validation_samples_v2(dataset, num_samples=num_validation_samples)

In [None]:
# Sanity check: shape of the first validation sample.
first_validation_sample = validation_samples[0]
print(first_validation_sample.shape)

In [None]:
# Not referenced directly; presumably imported so a missing TensorBoard
# install fails here rather than inside the trainer. Confirm before removing.
import tensorboard

trainer = Unet2DConditionalTrainerV2(
    unet=unet,
    optimizer=optimizer,
    noise_scheduler=noise_scheduler,
    train_config=config,
)

In [None]:
# Run the training loop with the shuffled loader and the fixed validation set.
trainer.train(
    dataloader=train_dataloader,
    validation_samples=validation_samples,
)