In [1]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    image_size:int =  224
    train_batch_size:int =  128
    eval_batch_size:int =  128
    num_epochs:int =  20
    learning_rate:float =  1e-4
    lr_warmup_steps:int = 500
    save_image_epochs:int = 10
    save_model_epochs:int = 10
    mixed_precision:str = "fp16"
    output_dir:str = "output"
    save_model_epochs: int = 2
    gradient_accumulation_steps: int = 1




config = TrainingConfig()


In [2]:

import torch
from torch.utils.data import Dataset, DataLoader

class TextLatentDataset(Dataset):
    def __init__(self, text_embeds, latents):
        self.text_embeds = text_embeds
        self.latents = latents

    def __len__(self):
        return len(self.text_embeds)

    def __getitem__(self, idx):
        # Text embedding al
        text_embed = torch.tensor(self.text_embeds[idx], dtype=torch.float32)
        
        # Eğer 77'den küçükse padding yap
        if text_embed.shape[0] < 77:
            padding = torch.zeros((77 - text_embed.shape[0], 768), dtype=torch.float32)
            text_embed = torch.cat([text_embed, padding], dim=0)
        # Eğer 77'den büyükse kes
        elif text_embed.shape[0] > 77:
            text_embed = text_embed[:77, :]

        # Latent al
        latent = torch.tensor(self.latents[idx], dtype=torch.float32)
        
        return text_embed, latent

In [3]:
import pickle 

file_path = "./data/embeddings.pkl"

with open(file_path, "rb") as f:
    embeddings = pickle.load(f)

text_embeds = embeddings["texts"]
latents = embeddings["latents"]

# Convert dictionary to list based on sorted keys
text_embeds_list = [text_embeds[i] for i in sorted(text_embeds.keys())]
latents_list = [latents[i] for i in sorted(latents.keys())]

# Now we can slice
val_text_embeds = text_embeds_list[0:1000]
val_latents = latents_list[0:1000]

train_text_embeds = text_embeds_list[1000:]
train_latents = latents_list[1000:]

train_dataset_len = len(train_text_embeds)

train_dataset = TextLatentDataset(train_text_embeds, train_latents)
val_dataset = TextLatentDataset(val_text_embeds, val_latents)

train_loader = DataLoader(
    train_dataset, 
    batch_size=config.train_batch_size, 
    shuffle=True, 
    num_workers=24, 
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset, 
    batch_size=config.eval_batch_size, 
    shuffle=False, 
    num_workers=24, 
    pin_memory=True
)

In [4]:
from diffusers import UNet2DConditionModel  

model = UNet2DConditionModel(
    sample_size=config.image_size,
    in_channels=4,  # latent channels
    out_channels=4, # latent channels
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    cross_attention_dim=768,  # CLIP hidden size
)



In [5]:

from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

In [6]:
import torch
from diffusers.optimization import get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(train_dataset_len * config.num_epochs),
)

In [7]:
import os
from accelerate import Accelerator
import torch
import tensorboard  # Add this import

torch.cuda.set_device(2) 

# Initialize accelerator and tensorboard logging
accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
if accelerator.is_main_process:
    if config.output_dir is not None:
        os.makedirs(config.output_dir, exist_ok=True)
    accelerator.init_trackers("train_example")


model, optimizer, lr_scheduler, train_loader, val_loader = accelerator.prepare(
    model, optimizer, lr_scheduler, train_loader, val_loader
)

In [8]:
from tqdm.notebook import tqdm
import torch.nn.functional as F
import os
from diffusers import UNet2DModel
from accelerate.utils import ProjectConfiguration
from huggingface_hub import HfFolder, Repository

def save_model_checkpoint(model, accelerator, output_dir, epoch):
    # Unwrap the model from accelerator to get the original model
    unwrapped_model = accelerator.unwrap_model(model)
    
    # Save the model state
    if accelerator.is_main_process:
        # Create pipeline folder
        os.makedirs(os.path.join(output_dir, f"checkpoint-{epoch}"), exist_ok=True)
        
        # Save the model in diffusers format
        unwrapped_model.save_pretrained(os.path.join(output_dir, f"checkpoint-{epoch}/unet"))
        
        # Save scheduler and other components if needed
        # noise_scheduler.save_pretrained(os.path.join(output_dir, f"checkpoint-{epoch}/scheduler"))

def train_loop(
    config,
    model,
    noise_scheduler,
    optimizer,
    train_loader,
    val_loader,
    lr_scheduler,
    accelerator
):
    global_step = 0
    
    # Progress bar for epochs
    progress_bar = tqdm(range(config.num_epochs), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Epochs")

    for epoch in range(config.num_epochs):
        model.train()
        
        # Progress bar for steps
        step_progress_bar = tqdm(total=len(train_loader), disable=not accelerator.is_local_main_process)
        step_progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_loader):
            encoder_hidden_states, latents = batch

            # Generate noise
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device).long()

            # Add noise to the latents according to the noise magnitude at each timestep
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                
                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Logging
            if accelerator.is_main_process:
                logs = {
                    "loss": loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                    "step": global_step,
                    "epoch": epoch,
                }
                step_progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

            global_step += 1
            step_progress_bar.update(1)

        # Save checkpoint at specified intervals
        if (epoch + 1) % config.save_model_epochs == 0:
            save_model_checkpoint(model, accelerator, config.output_dir, epoch + 1)
            
        step_progress_bar.close()
        progress_bar.update(1)
    
    # Save final model
    save_model_checkpoint(model, accelerator, config.output_dir, "final")
    
    return global_step

In [9]:
global_step = train_loop(
        config=config,
        model=model,
        noise_scheduler=noise_scheduler,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        lr_scheduler=lr_scheduler,
        accelerator=accelerator
    )

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]

  0%|          | 0/462 [00:00<?, ?it/s]