# ELIXRB 32x768 BASED

In [2]:
import os
import sys
from dotenv import load_dotenv 
from pathlib import Path

# Repository root: this notebook is assumed to sit two directories below it,
# so that the `src.*` package imported in later cells is importable.
project_root = Path(os.path.abspath('')).parent.parent
sys.path.append(str(project_root))

# Dataset locations come from a .env file or the process environment.
load_dotenv()

NIH_CXR14_ELIXR_DIR = os.getenv("NIH_CXR14_ELIXR_DIR")
if not NIH_CXR14_ELIXR_DIR:
    # Fail here with a clear message rather than with a TypeError later,
    # when this value is concatenated with a sub-directory string.
    raise RuntimeError(
        "NIH_CXR14_ELIXR_DIR is not set; define it in your .env file or environment."
    )
INTERMEDIATE_DATA_DIR = project_root / "data"
print(NIH_CXR14_ELIXR_DIR)

# Restrict to GPU 0. This must happen before CUDA is initialised
# (torch is only imported in a later cell).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


/home/yasin/Lfstorage/datasets/nih-cxr14/elixr


In [3]:
import h5py
import pickle
import numpy as np
import pandas as pd


In [4]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset


In [5]:
from diffusers import DDPMScheduler, UNet2DConditionModel
from src.utils import Unet2DConditionalTrainer, TrainConfig, get_validation_samples, create_validation_dataloader


2025-03-13 00:00:08.937758: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741813208.962628 2234473 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741813208.970421 2234473 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-13 00:00:08.996827: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Prepare Dataset

In [6]:
from src.datasets import EmbeddingDataset

# Location of the cached ELIXR-B contrastive embeddings for NIH CXR14.
contrastive_root = NIH_CXR14_ELIXR_DIR + "/elixrb/contrastive"
contrastive_dataset = EmbeddingDataset(
    root=contrastive_root,
    name="contrastive",
    use_cache=True,
)

print(contrastive_dataset.list_columns())

['Image Index', 'all_contrastive_img_emb', 'contrastive_img_emb', 'contrastive_txt_emb']


In [7]:
# Keep only the two fields the training dataset reads: the image id and the
# per-token image embedding used as cross-attention conditioning.
contrastive_dataset.set_active_columns(["Image Index", "all_contrastive_img_emb"])
for summary in (contrastive_dataset.list_columns(), len(contrastive_dataset)):
    print(summary)

['Image Index', 'all_contrastive_img_emb']
112120


In [10]:
# Image ids kept by the upstream filtering step; these define the training set.
filtered_df = pd.read_csv(INTERMEDIATE_DATA_DIR / "filtered_nihcxr14.csv")
filtered_idxs = filtered_df["Image Index"].values

# Precomputed latents, looked up by image id when building training pairs.
# NOTE: pickle.load can execute arbitrary code; only load files this project produced.
latents_path = INTERMEDIATE_DATA_DIR / "latents_64.pkl"
with latents_path.open("rb") as latents_file:
    latents = pickle.load(latents_file)

    

In [11]:
class CustomDataset(Dataset):
    """Pairs each filtered image's latent with its conditioning embedding.

    Args:
        latents: Mapping from image id (the "Image Index" value) to a latent array.
        contrastive_dataset: Indexable by image id. Each item is a mapping with an
            "Image Index" entry (indexed with [0] to get the id) and the embedding
            field named by ``embedding_column``.
        filtered_idxs: Sequence of image ids to expose, in order.
        embedding_column: Item field used as the conditioning embedding
            (defaults to "all_contrastive_img_emb").

    ``__getitem__`` returns ``(latent, cond_embed)`` as float32 tensors.
    """

    def __init__(self, latents, contrastive_dataset, filtered_idxs,
                 embedding_column="all_contrastive_img_emb"):
        self.latents = latents
        self.contrastive_dataset = contrastive_dataset
        self.filtered_idxs = filtered_idxs
        self.embedding_column = embedding_column

    def __len__(self):
        return len(self.filtered_idxs)

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError(
                f"index {idx} out of range for dataset of size {len(self)}"
            )

        item = self.contrastive_dataset[self.filtered_idxs[idx]]
        cond_embed = item[self.embedding_column]
        # "Image Index" is presumably a length-1 sequence in the embedding store
        # (hence [0]); its element is the key into `latents`.
        image_id = item["Image Index"][0]
        latent = self.latents[image_id]

        latent_tensor = torch.tensor(latent, dtype=torch.float32)
        cond_embed_tensor = torch.tensor(cond_embed, dtype=torch.float32)
        return latent_tensor, cond_embed_tensor
    
    

# Training dataset restricted to the filtered image ids.
dataset = CustomDataset(
    latents=latents,
    contrastive_dataset=contrastive_dataset,
    filtered_idxs=filtered_idxs,
)

print(len(dataset))


48311


In [12]:
# Sanity-check one training pair: (latent, conditioning embedding).
sample_latent, sample_cond_embed = dataset[0]
for sample_tensor in (sample_latent, sample_cond_embed):
    print(sample_tensor.shape)

torch.Size([4, 64, 64])
torch.Size([32, 128])


In [13]:
# Training hyperparameters. With a single visible GPU, each optimizer update sees
# batch_size * gradient_accumulation_steps = 4 * 64 = 256 samples.
config = TrainConfig(
    batch_size=4,
    num_workers=24,
    mixed_precision=True,
    learning_rate=1e-5,
    weight_decay=1e-6,
    num_epochs=20,
    gradient_accumulation_steps=64,  # Effective batch size of 256 (4 * 64); use 32 if 128 was intended
    scheduler_type="cosine",
    early_stopping_patience=5,
    use_timestep_weights=True
)

In [14]:
# Latent-space UNet conditioned via cross-attention on the contrastive embeddings.
unet = UNet2DConditionModel(
    # Spatial / channel layout of the 4-channel latents.
    sample_size=64,  # 64x64 latents; generated samples are 512x512
    in_channels=4,
    out_channels=4,
    # Encoder/decoder topology.
    layers_per_block=2,
    block_out_channels=(320, 640, 1280, 1280),
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    # Conditioning: must match the last dim of the conditioning embeddings.
    cross_attention_dim=128,
    attention_head_dim=8,
    # Normalisation / activation / embedding details.
    act_fn="silu",
    norm_num_groups=32,
    norm_eps=1e-05,
    center_input_sample=False,
    downsample_padding=1,
    flip_sin_to_cos=True,
    freq_shift=0,
    mid_block_scale_factor=1,
)

In [15]:
# Pass the configured weight decay explicitly. Without it, AdamW silently uses its
# own default (0.01) and `config.weight_decay` is never applied to this optimizer.
optimizer = torch.optim.AdamW(
    unet.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

In [16]:
# Shuffled training loader; pinned memory speeds host-to-GPU copies.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
)


In [19]:
# Fixed set of conditioning samples handed to the trainer for validation.
# NOTE(review): the recorded output of the next cell shows a sample shape of (32, 768),
# while training conditioning embeddings are (32, 128) and the UNet uses
# cross_attention_dim=128. Confirm get_validation_samples returns embeddings compatible
# with the UNet; that printed output may be stale from an earlier run.
validation_samples = get_validation_samples(dataset, num_samples=16, max_seq_len=32)

In [20]:
# Inspect one validation sample; its last dim should equal the UNet's cross_attention_dim.
first_validation_sample = validation_samples[0]
print(first_validation_sample.shape)

torch.Size([32, 768])


In [21]:
import tensorboard
# Bundle model, config, noise scheduler and optimizer into the project trainer;
# it reports its log and output directories on construction.
trainer = Unet2DConditionalTrainer(
    unet=unet,
    train_config=config,
    noise_scheduler=noise_scheduler,
    optimizer=optimizer,
)

Log directory is output/logs
Output directory is output


In [None]:
# Launch the training run; the validation conditioning samples are passed through to the trainer.
trainer.train(dataloader=train_dataloader, validation_samples=validation_samples)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/12078 [00:00<?, ?it/s]