# Sanity Check For Rectified Flow

This notebook is a sanity check to confirm that my Rectified Flow training code works.

In [None]:
# Define dataset and collator

import random

import torch
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

class PokemonDataset(Dataset):
    """Map-style dataset over the captioned rows of a Hugging Face dataset split.

    The requested split is loaded with ``load_dataset`` and filtered once, at
    construction time, into an in-memory list that keeps only rows with a
    non-empty ``'text'`` entry.

    Args:
        dataset_name: Identifier passed to ``load_dataset``.
        split: Key of the split to read from the loaded dataset.
    """

    def __init__(self, dataset_name="lambdalabs/pokemon-blip-captions", split="train"):
        super().__init__()

        rows = load_dataset(dataset_name)[split]
        # Keep only rows that carry a usable caption; the collator reads
        # row['text'] unconditionally.
        self.ds = [row for row in rows if 'text' in row and row['text']]

    def __getitem__(self, idx):
        """Return the row stored at position ``idx``."""
        return self.ds[idx]

    def __len__(self):
        """Return the number of rows that survived the caption filter."""
        return len(self.ds)

class PokemonCollator:
    """Collate dataset rows into a batch of mock videos and tokenized captions.

    Each row must provide an ``'image'`` and a ``'text'`` entry. Images are
    resized, converted to tensors, normalized with mean/std 0.5 per channel and
    randomly flipped horizontally, then repeated along a new frame axis.

    Args:
        tokenizer: Callable applied to the list of captions; its result must
            expose ``'input_ids'`` and ``'attention_mask'``.
        image_size: Side length that every image is resized to.
        cfg_prob: Probability of replacing a caption with the empty string
            (presumably caption dropout for classifier-free guidance).
        target_frames: Number of copies of each image along the frame axis.
            NOTE(review): the original code read ``self.target_frames`` without
            ever setting it; the default of 1 is an assumption -- confirm.
    """

    def __init__(self, tokenizer, image_size=32, cfg_prob=0.1, target_frames=1):
        self.tokenizer = tokenizer
        self.image_size = image_size
        self.cfg_prob = cfg_prob
        self.target_frames = target_frames
        # Built on first use so that constructing a collator stays cheap and
        # does not touch the image-transform library.
        self._transform = None

    def _build_transform(self):
        """Build the per-image preprocessing pipeline."""
        # NOTE(review): `transforms` is presumably `torchvision.transforms`;
        # no import for it is visible in this notebook -- confirm it is in scope.
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.RandomHorizontalFlip(),
        ])

    def __call__(self, batch):
        """Turn a list of rows into model inputs.

        Returns:
            A dict with ``"pixel_values"`` (images repeated along a new frame
            axis at dim 1), ``"input_ids"`` and ``"attention_mask"``.
        """
        if self._transform is None:
            self._transform = self._build_transform()

        # Blank out each caption independently with probability cfg_prob.
        captions = [
            "" if random.random() < self.cfg_prob else row['text']
            for row in batch
        ]
        images = torch.stack([self._transform(row['image']) for row in batch])

        tokenizer_out = self.tokenizer(
            captions,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=77,
        )

        # Turn the images into mock videos: (b, c, h, w) -> (b, n, c, h, w).
        videos = images.unsqueeze(1).repeat(1, self.target_frames, 1, 1, 1)

        return {
            "pixel_values": videos,
            "input_ids": tokenizer_out['input_ids'],
            "attention_mask": tokenizer_out['attention_mask'],
        }

In [None]:
from common.nn.denoisers import (
    ConditionedRectFlowTransformer,
    CLIPConditioner
)
from common.configs import ViTConfig

from transformers import CLIPTextModel, CLIPTokenizer

# Load the CLIP tokenizer and text model from the same pretrained checkpoint.
clip_id = "openai/clip-vit-base-patch32"
tokenizer = CLIPTokenizer.from_pretrained(clip_id)
clip_lm = CLIPTextModel.from_pretrained(clip_id)

# Wrap the CLIP text model so it can be handed to the denoiser as its conditioner.
text_encoder = CLIPConditioner(clip_lm, tokenizer, layer_skip=-2, hidden_size=512)

# Transformer backbone settings for 3x32x32 inputs with 4x4 patching.
config = ViTConfig(
    n_layers=12,
    n_heads=12,
    hidden_size=768,
    input_shape=(3, 32, 32),
    patching=(4, 4),
)

denoiser = ConditionedRectFlowTransformer(config, text_encoder)

In [None]:
# Smoke test: push one collated batch through the denoiser and print the loss.
import torch.utils.data

ds = PokemonDataset()
dc = PokemonCollator(tokenizer)
loader = DataLoader(ds, batch_size=4, collate_fn=dc)

# Only the first batch is needed; the denoiser returns the loss directly.
batch = next(iter(loader))
loss = denoiser(**batch)
print(loss.item())

In [None]:
from common.trainer import Trainer
from common.configs import ProjectConfig, TrainConfig, LoggingConfig

# Project-level configuration for the Trainer.
# NOTE: this rebinds `config`, which previously held the ViTConfig; the
# denoiser has already been constructed from that earlier value.
config = ProjectConfig(
    TrainConfig(
        batch_size=4,
        target_batch=4,
        epochs=20,
        save_every=9999,
        sample_every=101,
        eval_every=1000,
        checkpoint_dir="./pokemon_rft_out",
        train_state_checkpoint="./trainer_state",
        resume=False,
    ),
    LoggingConfig(
        run_name="Pokemon Test",
    ),
)

# The original cell left `config =` without a value, which is a syntax error;
# pass the ProjectConfig built above.
trainer = Trainer(
    denoiser,
    ds,
    dc,
    config=config,
)

In [None]:
import torch

# Minimal hand-written training loop over the same loader, printing the loss
# after every optimizer step.
EPOCHS = 10
opt = torch.optim.AdamW(denoiser.parameters(), lr=1e-4)

for epoch in range(EPOCHS):
    for batch in loader:
        # Clear gradients from the previous step before the new backward pass.
        opt.zero_grad()
        loss = denoiser(**batch)
        loss.backward()
        opt.step()

        print(loss.item())