In [1]:
import os
import sys
sys.path.append(os.path.abspath("..")) 


from tqdm.notebook import tqdm
from dataclasses import dataclass
from dotenv import load_dotenv
load_dotenv()



import pickle

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset

from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup

from src.models import FeatureConditionedUNet, FeatureConditionedUNetConfig



## Configurations

In [2]:
unet_config = {
    # Architecture parameters
    "sample_size": 28,  # for 224x224 images (224 = 28 * 8)
    "in_channels": 4,
    "out_channels": 4,
    "layers_per_block": 2,
    "block_out_channels": (320, 640, 1280, 1280),
    "down_block_types": (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D"
    ),
    "up_block_types": (
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D"
    ),
    
    # Attention parameters
    "attention_head_dim": 8,
    "cross_attention_dim": 768,
    
    # Normalization and activation
    "norm_num_groups": 32,
    "norm_eps": 1e-05,
    "act_fn": "silu",
    
    # Additional configuration
    "center_input_sample": False,
    "downsample_padding": 1,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "mid_block_scale_factor": 1
}


# To create the model:


@dataclass
class train_config:
    device = "cuda:2"
    num_workers = 24
    batch_size = 16
    mixed_precision = "fp16"    
    output_dir = "output"
    save_model_epochs = 3
    num_epochs = 20
    num_train_timesteps = 1000
    learning_rate = 1e-5
    lr_warmup_steps:int = 500
    unet_config = unet_config
    feature_dim = 4096
    projection_config = {
        "hidden_dim": 1024,
        "num_hidden_layers": 2
    }


config = train_config()

# Datasets

In [3]:
class LatentCondEmbedingDataset(Dataset):
    def __init__(self, latent, condembeds):
        self.latent = latent
        self.condembeds = condembeds

    def __len__(self):
        return len(self.latent)

    def __getitem__(self, idx):
        return self.latent[idx], self.condembeds[idx]
    



file_path = './data/nih-cxr14-latent-embed.pkl'

with open(file_path, 'rb') as f:
    data = pickle.load(f)

latent = data['latents']
condembeds = data['embeddings']


dataset = LatentCondEmbedingDataset(latent, condembeds)


dataloader = DataLoader(dataset, 
                        batch_size=config.batch_size, 
                        shuffle=True,
                        num_workers=config.num_workers,
                        pin_memory=True,
                        )

dataset_size = len(dataset)

# Creating Model, Noise Scheduler And Optimizer

In [4]:

model_config = FeatureConditionedUNetConfig(
    feature_dim=config.feature_dim,
    unet_config=config.unet_config,
    projection_config=config.projection_config
)

model = FeatureConditionedUNet(model_config)

noise_scheduler = DDPMScheduler(num_train_timesteps=config.num_train_timesteps)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)


lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(dataset_size * config.num_epochs))



In [5]:
from accelerate import Accelerator
import tensorboard

torch.cuda.set_device(2)

accelerator = Accelerator(
    mixed_precision=config.mixed_precision,
    gradient_accumulation_steps=1,
    log_with='tensorboard',
    project_dir=os.path.join(config.output_dir, "logs")
    )


if accelerator.is_main_process:
    if not os.path.exists(config.output_dir):
        os.makedirs(config.output_dir)
    accelerator.init_trackers("train_example")




def save_model_checkpoint(model, accelerator, output_dir, epoch):
    unwrapped_model = accelerator.unwrap_model(model)

    if accelerator.is_main_process:
        os.makedirs(os.path.join(output_dir, f"checkpoint_{epoch:03d}"), exist_ok=True)

        unwrapped_model.save_pretrained(
            os.path.join(output_dir, f"checkpoint_{epoch:03d}")
        )





In [9]:
def train(config, model, optimizer, noise_scheduler, data_loader, lr_scheduler, accelerator):
    
    global_step = 0

    # Progress bar for epochs
    progress_bar = tqdm(range(config.num_epochs), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Epochs")


    for epoch in range(config.num_epochs):
        model.train()

        step_progress_bar = tqdm(total=len(data_loader), disable=not accelerator.is_local_main_process)
        step_progress_bar.set_description(f"Epoch {epoch}")


        for step, batch in enumerate(data_loader):
            latents, condembeds = batch

            # Reshape condembeds from (1, 32, 128) to (batch_size, 4096)
            condembeds = condembeds.reshape(condembeds.shape[0], -1)  # -1 otomatik olarak 32*128=4096 yapacak


            noise = torch.rand_like(latents)
            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device).long()

            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_latents, timesteps, features=condembeds, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)

            
                #Bacpropagation
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            
            # Logging
            if accelerator.is_main_process:
                logs = {
                    "loss": loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                    "step": global_step,
                    "epoch": epoch,
                }
                step_progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

            global_step += 1
            step_progress_bar.update(1)

        # Save checkpoint at specified intervals
        if (epoch + 1) % config.save_model_epochs == 0:
            save_model_checkpoint(model, accelerator, config.output_dir, epoch + 1)
            
        step_progress_bar.close()
        progress_bar.update(1)
    
    # Save final model
    save_model_checkpoint(model, accelerator, config.output_dir, "final")
    
    return global_step

In [10]:
#Move model to accelerator
model, optimizer, lr_scheduler, dataloader, accelerator = accelerator.prepare(
    model, optimizer, lr_scheduler, dataloader, accelerator
)




In [11]:
train(config, model, optimizer, noise_scheduler, dataloader, lr_scheduler, accelerator)


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1834 [00:00<?, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x2048 and 4096x1024)