# ELIXRB 32x768 BASED (NOTE: the cells below actually load CLIP vision embeddings of shape 1x768)

In [1]:
import os
import sys
import h5py
from dotenv import load_dotenv 
from pathlib import Path

# Make the repository root (two levels above this notebook's working directory)
# importable so that `src.utils` can be imported in the next cell. Guarded so
# that re-running the cell does not keep appending the same entry.
project_root = Path(os.path.abspath('')).parent.parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Pull variables from a .env file (if present) into os.environ.
load_dotenv()

NIH_CXR14_DATASET_DIR = os.getenv("NIH_CXR14_DATASET_DIR")
if not NIH_CXR14_DATASET_DIR:
    # Fail here with a clear message instead of a confusing TypeError later in
    # os.path.join(NIH_CXR14_DATASET_DIR, ...).
    raise RuntimeError(
        "NIH_CXR14_DATASET_DIR is not set; define it in the environment or in a .env file"
    )
print(NIH_CXR14_DATASET_DIR)

# Must be set before CUDA is initialised; torch is first imported in the next cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '2' 


/home/yasin/Lfstorage/datasets/nih-cxr14


In [2]:
from src.utils import Unet2DConditionalTrainerV2, TrainConfigV2, get_validation_samples_v2, create_validation_dataloader
from diffusers import DDPMScheduler, UNet2DConditionModel
import torch
from torch.utils.data import Dataset
import torch.optim as optim


import numpy as np
import os
import torch
from torch.utils.data import Dataset
import h5py
import pandas as pd

class EmbeddingLatentDataset(Dataset):
    """Pair VAE latents with conditioning embeddings, both stored in H5 files.

    Only image IDs that are listed in the CSV's ``Image Index`` column AND have
    both a latent and an embedding are exposed, in CSV row order. At
    construction time only an ID -> (file, row) index is built; the actual
    arrays are read from disk on demand in ``__getitem__``.
    """

    def __init__(self, latent_dir, embedding_dir, csv_path):
        """
        Index latents and embeddings on disk and filter them by a CSV of image IDs.

        Args:
            latent_dir (str): Directory containing latent H5 files (dataset "latents").
            embedding_dir (str): Directory containing embedding H5 files
                (dataset "CLIP Embeddings").
            csv_path (str): CSV file with an "Image Index" column listing the
                image IDs to keep.
        """
        # Store directories
        self.latent_dir = latent_dir
        self.embedding_dir = embedding_dir

        # Read CSV file with selected image IDs
        self.selected_image_df = pd.read_csv(csv_path)
        self.selected_image_ids = self.selected_image_df['Image Index'].tolist()
        print(f"Loaded {len(self.selected_image_ids)} image IDs from CSV")

        # Sorted so that, when an ID occurs in several files, which file wins
        # in the ID map is deterministic instead of depending on os.listdir order.
        self.latent_files = sorted(
            os.path.join(latent_dir, f) for f in os.listdir(latent_dir) if f.endswith(".h5")
        )
        self.embedding_files = sorted(
            os.path.join(embedding_dir, f) for f in os.listdir(embedding_dir) if f.endswith(".h5")
        )

        # Build index of image IDs to file locations
        self.latent_id_map = self._build_id_map(self.latent_files, "latents")
        self.embedding_id_map = self._build_id_map(self.embedding_files, "CLIP Embeddings")

        csv_set = set(self.selected_image_ids)
        latent_set = set(self.latent_id_map)
        embedding_set = set(self.embedding_id_map)

        # IDs present in the CSV and in both H5 collections.
        matched_ids = csv_set & latent_set & embedding_set

        # Keep the CSV order. Membership is tested against the set: testing
        # against a list here made this loop O(n^2) (~1e9 comparisons for ~48k IDs).
        self.common_ids = [img_id for img_id in self.selected_image_ids if img_id in matched_ids]

        print(f"Total matching IDs after filtering: {len(self.common_ids)}")

        # Statistics
        print(f"IDs in CSV but missing latents: {len(csv_set - latent_set)}")
        print(f"IDs in CSV but missing embeddings: {len(csv_set - embedding_set)}")

    def _build_id_map(self, files, dataset_name):
        """
        Map each image ID to the H5 file and row index where it is stored.

        Args:
            files (list): H5 file paths to index.
            dataset_name (str): Name of the data dataset these files hold
                ("latents" or "CLIP Embeddings"); used only in error messages.

        Returns:
            dict: image ID -> (file_path, row_index). If an ID occurs in several
            files, the entry from the last file in ``files`` is kept.

        Raises:
            KeyError: if a file has neither an "Image index" nor an "Image Index" dataset.
        """
        id_map = {}

        for file_path in files:
            with h5py.File(file_path, "r") as hf:
                # Both capitalisations of the index dataset name are accepted.
                if "Image index" in hf:
                    image_indices = hf["Image index"][:]
                elif "Image Index" in hf:
                    image_indices = hf["Image Index"][:]
                else:
                    raise KeyError(
                        f"{file_path} ({dataset_name}): no 'Image index' or 'Image Index' dataset"
                    )

            # IDs may be stored as byte strings; decode so they match the CSV's str IDs.
            # The length check avoids an IndexError on an empty index dataset.
            if len(image_indices) > 0 and isinstance(image_indices[0], bytes):
                image_indices = [idx.decode("utf-8") for idx in image_indices]

            # Map each ID to its file and position
            for row, img_id in enumerate(image_indices):
                id_map[img_id] = (file_path, row)

        return id_map

    def __len__(self):
        """Return the number of matched image IDs."""
        return len(self.common_ids)

    def __getitem__(self, idx):
        """
        Load the latent / embedding pair for the idx-th matched image ID.

        Args:
            idx (int): Position in ``self.common_ids``.

        Returns:
            tuple: (latent, cond_embed) as float32 tensors. ``cond_embed`` gets a
            leading length-1 axis (e.g. (768,) -> (1, 768) in this notebook) so it
            can be passed to the UNet's cross-attention as a one-token sequence.
        """
        img_id = self.common_ids[idx]
        latent_file, latent_idx = self.latent_id_map[img_id]
        embedding_file, embedding_idx = self.embedding_id_map[img_id]

        # Files are opened per call rather than kept open.
        # NOTE(review): this costs one open/close per file per sample. Caching
        # handles per DataLoader worker would be faster if I/O becomes a bottleneck.
        with h5py.File(latent_file, "r") as hf:
            latent = hf["latents"][latent_idx]

        with h5py.File(embedding_file, "r") as hf:
            cond_embed = hf["CLIP Embeddings"][embedding_idx]

        latent = torch.tensor(latent, dtype=torch.float32)
        cond_embed = torch.tensor(cond_embed, dtype=torch.float32)

        # Add the sequence axis expected by the UNet cross-attention.
        cond_embed = cond_embed.unsqueeze(0)

        return latent, cond_embed

2025-03-03 14:03:35.426379: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1740999815.448629  771707 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1740999815.455403  771707 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-03 14:03:35.478302: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Inputs for EmbeddingLatentDataset: VAE latents, CLIP vision embeddings, and
# the CSV of image IDs to train on.
latent_dir = os.path.join(NIH_CXR14_DATASET_DIR, "vae_latents2")
clip_vision_embeds = os.path.join(NIH_CXR14_DATASET_DIR, "clip_vision_embeddings")
csv_path = "/home/yasin/Lfstorage/Projects/cxr-diffusion/intermediate_data/filtered_findings_label_data.csv"


In [4]:
dataset = EmbeddingLatentDataset(
    latent_dir=latent_dir,
    embedding_dir=clip_vision_embeds,
    csv_path=csv_path,
)
# Displayed as the cell output: number of matched (latent, embedding) pairs.
len(dataset)

Loaded 48311 image IDs from CSV
Total matching IDs after filtering: 48311
IDs in CSV but missing latents: 0
IDs in CSV but missing embeddings: 0


48311

In [5]:
# Sanity check: shapes of the first (latent, conditioning) pair.
sample = dataset[0]
print(*(tensor.shape for tensor in sample))

torch.Size([4, 64, 64]) torch.Size([1, 768])


In [6]:
config = TrainConfigV2(
    batch_size=4,
    mixed_precision=True,
    learning_rate=1e-4,
    num_epochs=20,
    # Effective batch size = batch_size * gradient_accumulation_steps = 4 * 64 = 256,
    # assuming the trainer steps the optimizer once per 64 micro-batches.
    gradient_accumulation_steps=64,
    scheduler_type="cosine",
    # NOTE(review): the training log only reports the training "Avg loss" when
    # selecting the best model and counting early-stopping patience.
    early_stopping_patience=5,
    use_timestep_weights=True
)

In [7]:
# Conditional UNet over 4-channel 64x64 latents with cross-attention on
# 768-dim conditioning embeddings.
unet = UNet2DConditionModel(
    # Input / output geometry
    sample_size=64,  # latent side length; generated samples are 512x512
    in_channels=4,
    out_channels=4,
    center_input_sample=False,
    # Encoder / decoder layout
    block_out_channels=(320, 640, 1280, 1280),
    layers_per_block=2,
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    downsample_padding=1,
    mid_block_scale_factor=1,
    # Conditioning: width of the embeddings fed to cross-attention
    # (768, matching the dataset's (1, 768) conditioning)
    cross_attention_dim=768,
    attention_head_dim=8,
    # Timestep embedding
    flip_sin_to_cos=True,
    freq_shift=0,
    # Normalisation / activation
    norm_num_groups=32,
    norm_eps=1e-05,
    act_fn="silu",
)

In [8]:
# DDPM forward-noising schedule with 1000 training timesteps.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
# AdamW over every UNet parameter at the configured learning rate.
optimizer = torch.optim.AdamW(params=unet.parameters(), lr=config.learning_rate)

In [9]:
# Shuffled mini-batches of (latent, cond_embed); with 48311 samples and
# batch_size=4 this gives 12078 batches per epoch.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
)


In [10]:
# NOTE(review): the next cell prints (77, 768) for validation_samples[0], while
# the training dataset yields conditioning of shape (1, 768). Confirm that
# get_validation_samples_v2 produces conditioning comparable to training.
NUM_VALIDATION_SAMPLES = 16
validation_samples = get_validation_samples_v2(dataset, num_samples=NUM_VALIDATION_SAMPLES)

In [11]:
# Shape of the first validation conditioning tensor.
print(validation_samples[0].size())

torch.Size([77, 768])


In [12]:
# Not referenced in this cell; presumably imported so a missing tensorboard
# install fails here rather than inside the trainer's logging — TODO confirm.
import tensorboard
trainer = Unet2DConditionalTrainerV2(
    train_config=config,
    unet=unet,
    optimizer=optimizer,
    noise_scheduler=noise_scheduler,
)

Log directory is output/logs
Output directory is output


In [13]:
# Run the full training loop; validation sample grids are generated each epoch.
trainer.train(
    dataloader=train_dataloader,
    validation_samples=validation_samples,
)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 1/20 - Avg loss: 0.165842, Avg PSNR: 22.45
Saving model to output/best/best
New best model with loss: 0.165842
Saving model to output/checkpoints/epoch_1
Generating validation samples for epoch 1...
Initialized VaeProcessor for sample generation


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_1.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 2/20 - Avg loss: 0.102723, Avg PSNR: 24.56
Saving model to output/best/best
New best model with loss: 0.102723
Saving model to output/checkpoints/epoch_2
Generating validation samples for epoch 2...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_2.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 3/20 - Avg loss: 0.096640, Avg PSNR: 24.90
Saving model to output/best/best
New best model with loss: 0.096640
Saving model to output/checkpoints/epoch_3
Generating validation samples for epoch 3...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_3.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 4/20 - Avg loss: 0.091587, Avg PSNR: 25.17
Saving model to output/best/best
New best model with loss: 0.091587
Saving model to output/checkpoints/epoch_4
Generating validation samples for epoch 4...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_4.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 5/20 - Avg loss: 0.089256, Avg PSNR: 25.26
Saving model to output/best/best
New best model with loss: 0.089256
Saving model to output/checkpoints/epoch_5
Generating validation samples for epoch 5...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_5.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 6/20 - Avg loss: 0.086625, Avg PSNR: 25.40
Saving model to output/best/best
New best model with loss: 0.086625
Saving model to output/checkpoints/epoch_6
Generating validation samples for epoch 6...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_6.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 7/20 - Avg loss: 0.086041, Avg PSNR: 25.46
Saving model to output/best/best
New best model with loss: 0.086041
Saving model to output/checkpoints/epoch_7
Generating validation samples for epoch 7...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_7.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 8/20 - Avg loss: 0.085452, Avg PSNR: 25.50
Saving model to output/best/best
New best model with loss: 0.085452
Saving model to output/checkpoints/epoch_8
Generating validation samples for epoch 8...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_8.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 9/20 - Avg loss: 0.084354, Avg PSNR: 25.60
Saving model to output/best/best
New best model with loss: 0.084354
Saving model to output/checkpoints/epoch_9
Generating validation samples for epoch 9...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_9.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 10/20 - Avg loss: 0.082448, Avg PSNR: 25.67
Saving model to output/best/best
New best model with loss: 0.082448
Saving model to output/checkpoints/epoch_10
Generating validation samples for epoch 10...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_10.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 11/20 - Avg loss: 0.083510, Avg PSNR: 25.60
No improvement. Early stopping counter: 1/5
Saving model to output/checkpoints/epoch_11
Generating validation samples for epoch 11...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_11.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 12/20 - Avg loss: 0.081691, Avg PSNR: 25.69
Saving model to output/best/best
New best model with loss: 0.081691
Saving model to output/checkpoints/epoch_12
Generating validation samples for epoch 12...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_12.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 13/20 - Avg loss: 0.081701, Avg PSNR: 25.73
No improvement. Early stopping counter: 1/5
Saving model to output/checkpoints/epoch_13
Generating validation samples for epoch 13...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_13.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 14/20 - Avg loss: 0.081063, Avg PSNR: 25.77
Saving model to output/best/best
New best model with loss: 0.081063
Saving model to output/checkpoints/epoch_14
Generating validation samples for epoch 14...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_14.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 15/20 - Avg loss: 0.081508, Avg PSNR: 25.76
No improvement. Early stopping counter: 1/5
Saving model to output/checkpoints/epoch_15
Generating validation samples for epoch 15...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_15.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 16/20 - Avg loss: 0.080880, Avg PSNR: 25.80
Saving model to output/best/best
New best model with loss: 0.080880
Saving model to output/checkpoints/epoch_16
Generating validation samples for epoch 16...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_16.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 17/20 - Avg loss: 0.079760, Avg PSNR: 25.83
Saving model to output/best/best
New best model with loss: 0.079760
Saving model to output/checkpoints/epoch_17
Generating validation samples for epoch 17...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_17.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 18/20 - Avg loss: 0.079759, Avg PSNR: 25.85
Saving model to output/best/best
New best model with loss: 0.079759
Saving model to output/checkpoints/epoch_18
Generating validation samples for epoch 18...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_18.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 19/20 - Avg loss: 0.079039, Avg PSNR: 25.93
Saving model to output/best/best
New best model with loss: 0.079039
Saving model to output/checkpoints/epoch_19
Generating validation samples for epoch 19...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_19.png


  0%|          | 0/12078 [00:00<?, ?it/s]

Epoch 20/20 - Avg loss: 0.079677, Avg PSNR: 25.87
No improvement. Early stopping counter: 1/5
Saving model to output/checkpoints/epoch_20
Generating validation samples for epoch 20...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/samples_epoch_20.png
Saving model to output/final
Generating samples with final model...


Generating sample 1/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 2/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 3/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 4/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 5/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 6/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 7/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 8/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 9/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 10/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 11/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 12/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 13/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 14/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 15/16:   0%|          | 0/50 [00:00<?, ?it/s]

Generating sample 16/16:   0%|          | 0/50 [00:00<?, ?it/s]

Saved sample grid to output/samples/final_samples.png
Loading best model
Error loading model: Error no file named config.json found in directory output/best.


Traceback (most recent call last):
  File "/home/yasin/Lfstorage/Projects/cxr-diffusion/src/utils/train_v2.py", line 166, in load
    self.unet.from_pretrained(path)
  File "/home/yasin/Lfstorage/Projects/cxr-diffusion/.venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/yasin/Lfstorage/Projects/cxr-diffusion/.venv/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 687, in from_pretrained
    config, unused_kwargs, commit_hash = cls.load_config(
  File "/home/yasin/Lfstorage/Projects/cxr-diffusion/.venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/yasin/Lfstorage/Projects/cxr-diffusion/.venv/lib/python3.10/site-packages/diffusers/configuration_utils.py", line 373, in load_config
    raise EnvironmentError(
OSError: Error no file named config.json found in directory output/best.


UNet2DConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0-1): 2 x Transformer2DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=320, out_features=320, bias=False)
                (to_v): Linear(in_features=320, out_fe