In [1]:
import torch

# Drop cached-but-unused CUDA blocks, then report the bytes still held by live tensors.
torch.cuda.empty_cache()
allocated_bytes = torch.cuda.memory_allocated()
print(allocated_bytes)

0


In [2]:
from datasets import load_dataset, load_from_disk

# Flickr8k rows carry an image id, a local image path and a list of captions.
# (Alternative dataset, previously used: load_dataset('lambdalabs/pokemon-blip-captions').)
dataset = load_dataset(path="atasoglu/flickr8k-dataset", data_dir="data")

In [3]:
# Summary of the training split: feature names and row count.
train_split = dataset.get('train')
train_split

Dataset({
    features: ['image_id', 'image_path', 'captions'],
    num_rows: 6000
})

In [4]:
# One raw training example, to see the row schema.
example_row = dataset.get('train')[110]
example_row

{'image_id': '2730819220_b58af1119a',
 'image_path': '/home/veezbo/.cache/huggingface/datasets/downloads/extracted/8c0281a0d6433d492cf11514ee37574297a203fdd2f114c2b7edb98bf297a371/Flicker8k_Dataset/2730819220_b58af1119a.jpg',
 'captions': ['a little girl ends up at the bottom of the slide .',
  'a little girl is just reaching the bottom of a playground slide',
  'A little girl lands at the bottom of a slide .',
  'A young girl reaches the bottom of a slide .',
  'a young girl sliding down a tan plastic slide']}

In [5]:
# display(dataset.get('train')[110]['image'])

In [6]:
from transformers import GPT2TokenizerFast
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from typing import Literal
import os

In [7]:
# Data-pipeline configuration.
MODEL = 'gpt2-large'   # tokenizer checkpoint; also part of the preprocessed-cache path
DATASET_TO_USE: Literal['pokemon', 'flickr'] = 'flickr'

BATCH_SIZE = 4
CONTEXT_SIZE = 32      # captions are padded/truncated to this many tokens
IMAGE_SIZE = 256       # images are resized and centre-cropped to IMAGE_SIZE x IMAGE_SIZE
# PAD_TOKEN = '[PAD]'

In [8]:
# TODO(1): Is this causing problems? Should we actually use the pad token
# TODO(1): Should we do left padding instead to keep the important data together
# TODO(1): Actually validate that the attention mask coming from the tokenizer is correctly masking just the pad tokens
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL)

# Pad captions on the right, reusing the EOS token as the pad token
# (adding a dedicated '[PAD]' token is the commented-out alternative).
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token
# tokenizer.add_special_tokens({"pad_token": PAD_TOKEN})


# def resize_images(batch):
#     # Specify your target size
#     target_size = (224, 224)
#     # Load the image
#     image = Image.open(batch['image_path'])
#     # Resize the image
#     image = image.resize(target_size)
#     # Convert the image to a numpy array and normalize pixel values to [0, 1]
#     batch['image'] = np.array(image) / 255.0
#     return batch

# def tokenize_and_pad_texts(batch):
#     # Tokenize the texts
#     tokenized_batch = tokenizer(batch['text'], padding='longest', truncation=True, max_length=CONTEXT_SIZE)
#     return tokenized_batch

# Image preprocessing: resize the shorter side to IMAGE_SIZE, centre-crop to a square,
# convert to a float tensor and rescale to [-1, 1].
_image_pipeline = [
    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(IMAGE_SIZE),
    # ToTensor permutes (H, W, C) -> (C, H, W) and scales pixel values to [0, 1].
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1].
    # TODO(2): Adjust normalization if needed. Can do linear transformation using transforms.lambda
    transforms.Normalize([0.5], [0.5]),
]
train_transforms = transforms.Compose(_image_pipeline)

def tokenize_and_resize_pokemon(examples):
    """Turn one Pokemon row into model inputs.

    The row's ``text`` is tokenized and padded/truncated to ``CONTEXT_SIZE`` tokens.
    Its ``image`` is passed through ``train_transforms``.

    Returns a dict with ``input_ids``, ``attention_mask`` and ``image``.
    """
    encoded = tokenizer(
        examples['text'],
        padding='max_length',
        truncation=True,
        max_length=CONTEXT_SIZE,
    )
    image_tensor = train_transforms(examples['image'])
    return dict(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        image=image_tensor,
    )

def tokenize_and_resize_flickr(examples):
    """Turn one Flickr8k row into model inputs.

    Only the first entry of ``captions`` is used. It is tokenized and padded/truncated
    to ``CONTEXT_SIZE`` tokens. The image at ``image_path`` is loaded, converted to RGB
    and passed through ``train_transforms``.

    Returns a dict with ``input_ids``, ``attention_mask``, ``image`` (the transformed
    image) and ``text`` (the caption that was tokenized).
    """
    caption = examples['captions'][0]
    tokenizer_output = tokenizer(caption, padding='max_length', truncation=True, max_length=CONTEXT_SIZE)
    # The context manager closes the file deterministically. convert('RGB') makes
    # grayscale, RGBA and palette images yield the same 3-channel tensor as RGB ones.
    with Image.open(examples['image_path']) as img:
        image = train_transforms(img.convert('RGB'))
    return {
        'input_ids': tokenizer_output['input_ids'],
        'attention_mask': tokenizer_output['attention_mask'],
        'image': image,
        'text': caption,
    }


# Preprocess the chosen dataset once and cache it on disk; later runs reload the cache.
# NOTE: the cache is keyed only on dataset, MODEL, IMAGE_SIZE and CONTEXT_SIZE. Delete it
# after changing the tokenize/transform functions, or the stale cache will be reused.
_PREPROCESS_FNS = {
    'pokemon': tokenize_and_resize_pokemon,
    'flickr': tokenize_and_resize_flickr,
}
if DATASET_TO_USE not in _PREPROCESS_FNS:
    # Fail fast here; otherwise `train_dataset` would never be assigned.
    raise ValueError(f"Unknown DATASET_TO_USE {DATASET_TO_USE!r}; expected one of {sorted(_PREPROCESS_FNS)}")

cache_path = f"./data/{DATASET_TO_USE}_train_dataset_{MODEL}_{IMAGE_SIZE}_{CONTEXT_SIZE}.hf"
if os.path.exists(cache_path):
    train_dataset = load_from_disk(cache_path)
else:
    train_dataset = dataset['train'].map(_PREPROCESS_FNS[DATASET_TO_USE]).with_format('torch')
    train_dataset.save_to_disk(cache_path)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)

In [9]:
# max(len(x.split(' ')) for x in train_dataset['text'])

In [10]:
# Id used for padding; the tokenizer was configured to pad with its EOS token.
pad_id = tokenizer.pad_token_id
print(pad_id)

50256


In [11]:
# Peek at one shuffled batch from the loader.
batch_iterator = iter(train_loader)
next(batch_iterator)

{'image_id': ['3391209042_d2de8a8978',
  '3684518763_f3490b647a',
  '2272426567_9e9fb79db0',
  '362316425_bda238b4de'],
 'image_path': ['/home/veezbo/.cache/huggingface/datasets/downloads/extracted/8c0281a0d6433d492cf11514ee37574297a203fdd2f114c2b7edb98bf297a371/Flicker8k_Dataset/3391209042_d2de8a8978.jpg',
  '/home/veezbo/.cache/huggingface/datasets/downloads/extracted/8c0281a0d6433d492cf11514ee37574297a203fdd2f114c2b7edb98bf297a371/Flicker8k_Dataset/3684518763_f3490b647a.jpg',
  '/home/veezbo/.cache/huggingface/datasets/downloads/extracted/8c0281a0d6433d492cf11514ee37574297a203fdd2f114c2b7edb98bf297a371/Flicker8k_Dataset/2272426567_9e9fb79db0.jpg',
  '/home/veezbo/.cache/huggingface/datasets/downloads/extracted/8c0281a0d6433d492cf11514ee37574297a203fdd2f114c2b7edb98bf297a371/Flicker8k_Dataset/362316425_bda238b4de.jpg'],
 'captions': [('Two figures stand in a snowy setting wearing white and hot pink outfits , gazing towards a mountain .',
   'A guy in gold chains , a black top , and g

In [12]:
# Fields carried by a collated batch.
peek_batch = next(iter(train_loader))
peek_batch.keys()

dict_keys(['image_id', 'image_path', 'captions', 'input_ids', 'attention_mask', 'image', 'text'])

In [13]:
# dtype of the collated image tensors.
peek_images = next(iter(train_loader))['image']
peek_images.dtype

torch.float32

In [14]:
import torch
import os
import glob
import re
from typing import Literal
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Optimisation
NUM_EPOCHS = 4096
LEARNING_RATE = 3e-5
EVAL_ITERATIONS = 100          # steps between printed running-loss averages

# Diffusion
NUM_DIFFUSION_TIMESTEPS = 100
LATENT_DIM = IMAGE_SIZE // 8   # expected spatial side of the encoded image latents
PREDICTION_TYPE: Literal['epsilon', 'sample'] = 'epsilon'

In [15]:
import torch.nn as nn
import torch.nn.functional as F

In [16]:
# Bytes currently occupied by tensors on the default CUDA device.
allocated_bytes = torch.cuda.memory_allocated()
print(allocated_bytes)

0


In [17]:
from diffusers import DDPMScheduler
from diffusers.optimization import get_constant_schedule
from model import LLourney

# Model, noise scheduler, optimiser and (constant) LR schedule.
model = LLourney()
model.to(DEVICE)

scheduler = DDPMScheduler(
    num_train_timesteps=NUM_DIFFUSION_TIMESTEPS,
    prediction_type=PREDICTION_TYPE,
)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)
lr_schedule = get_constant_schedule(optimizer=optimizer)

In [18]:
# Bytes currently occupied by tensors on the default CUDA device.
allocated_bytes = torch.cuda.memory_allocated()
print(allocated_bytes)

3548738560


In [19]:
# Fix one batch; later cells reuse it as prompts and reference images.
thing = next(iter(train_loader))
# To overfit on this single batch instead, move its tensors to DEVICE:
# elimage = thing['image'].to(DEVICE)
# eltext_id = thing["input_ids"].to(DEVICE)
# elpad_mask = thing["attention_mask"].to(DEVICE)
# print(elimage.shape, eltext_id.shape, elpad_mask.shape)

In [20]:
# thing = next(iter(train_loader))
# elimage = torch.stack([thing['image'][0]] * BATCH_SIZE, dim=0).to(DEVICE)
# eltext_id = torch.stack([thing['input_ids'][0]] * BATCH_SIZE, dim=0).to(DEVICE)
# elpad_mask = torch.stack([thing['attention_mask'][0]] * BATCH_SIZE, dim=0).to(DEVICE)

# elimage = torch.ones([BATCH_SIZE, 3, 128, 128]).to(DEVICE)
# eltext_id = torch.stack([thing['input_ids'][0]] * BATCH_SIZE, dim=0).to(DEVICE)
# elpad_mask = torch.stack([thing['attention_mask'][0]] * BATCH_SIZE, dim=0).to(DEVICE)

# print(elimage.shape)
# print(eltext_id.shape)
# print(elpad_mask.shape)

In [21]:
# Resume from the newest checkpoint of this run, if one exists.
step_i = 0
CHECKPOINT_FILE_BASE_PATH = "model_checkpoint_flickr_llm_100difftimestep_epsilon_32context_patchdim2_imgdim256_gptlarge"

# Only files named "<base>_step_<N>.pt" (the format the save cell writes) are considered.
# The pattern is a raw string with an escaped base path, matched against the whole name.
# This avoids the invalid "\d" escape of a plain f-string and the IndexError that unrelated
# files sharing the prefix used to cause.
_checkpoint_name = re.compile(rf"{re.escape(CHECKPOINT_FILE_BASE_PATH)}_step_(\d+)\.pt")
steps = []
for checkpoint_path in glob.glob(f"{CHECKPOINT_FILE_BASE_PATH}_step_*.pt"):
    name_match = _checkpoint_name.fullmatch(checkpoint_path)
    if name_match:
        steps.append(int(name_match.group(1)))

if steps:
    max_step = max(steps)
    CHECKPOINT_FILE_PATH = f"{CHECKPOINT_FILE_BASE_PATH}_step_{max_step}.pt"
    print(f"LOADING CHECKPOINT FROM: {CHECKPOINT_FILE_PATH}")
    # Load onto CPU first; load_state_dict copies the values into the existing model/optimiser.
    checkpoint = torch.load(CHECKPOINT_FILE_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    step_i = checkpoint['step_i']

# model.to(DEVICE)

In [22]:
# Bytes currently occupied by tensors on the default CUDA device.
allocated_bytes = torch.cuda.memory_allocated()
print(allocated_bytes)

3548738560


In [23]:
# TRAINING LOOP
# Each step encodes images to latents and adds noise at a random timestep. The model output
# is then regressed onto either the added noise ('epsilon') or the clean latent ('sample').

if PREDICTION_TYPE not in ('epsilon', 'sample'):
    # Fail fast instead of hitting an undefined `loss_node` inside the loop.
    raise ValueError(f"Unsupported PREDICTION_TYPE: {PREDICTION_TYPE!r}")

model.train()
loss_evals = []
for epoch in range(NUM_EPOCHS):

    for step, batch in enumerate(train_loader):
        step_i += 1

        # Move the randomly sampled batch to the training device.
        image = batch["image"].to(DEVICE)
        text_id = batch["input_ids"].to(DEVICE)
        pad_mask = batch["attention_mask"].to(DEVICE)
        # To overfit a single fixed batch instead, use:
        # image, text_id, pad_mask = elimage, eltext_id, elpad_mask

        batch_size = image.shape[0]

        # Use VAE to encode image in latent space
        latent_image = model.encode_image(image)
        assert latent_image.shape[-1] == LATENT_DIM

        # Gaussian noise with the same shape, dtype and device as the latents.
        Z = torch.randn_like(latent_image)

        # One timestep per item, drawn from the full training range [0, NUM_DIFFUSION_TIMESTEPS).
        # t=0 must be included: the default DDPM inference schedule produced by
        # scheduler.set_timesteps ends at t=0, so the model has to learn that step too.
        T = torch.randint(0, NUM_DIFFUSION_TIMESTEPS, (batch_size,), device=DEVICE, dtype=torch.long)

        # Forward-diffuse the clean latents to the sampled timesteps.
        noisy_latent_image = scheduler.add_noise(latent_image, Z, T)

        # The model predicts the noise ('epsilon') or the clean latent ('sample').
        model_output = model(noisy_latent_image, text_id, T, text_pad_mask=pad_mask)
        target = Z if PREDICTION_TYPE == 'epsilon' else latent_image
        loss_node = F.mse_loss(model_output.float(), target, reduction='mean')

        # Backpropagate
        loss_node.backward()
        optimizer.step()
        lr_schedule.step()
        optimizer.zero_grad()

        # Print the mean loss accumulated since the last report.
        loss_evals.append(loss_node.detach().item())
        if step_i % EVAL_ITERATIONS == 0:
            print(f'step_{step_i}: {sum(loss_evals)/len(loss_evals)}')
            loss_evals = []


step_100: 0.9997423416376114
step_200: 0.7362519389390946
step_300: 0.6154901888966561
step_400: 0.6148268097639084
step_500: 0.5997130024433136
step_600: 0.6181740722060204
step_700: 0.5955809274315834
step_800: 0.5821639862656594
step_900: 0.5743371146917343
step_1000: 0.5724548202753067
step_1100: 0.5857533740997315
step_1200: 0.5876998236775398
step_1300: 0.5632940590381622
step_1400: 0.5459697914123535
step_1500: 0.5556989958882332
step_1600: 0.5453901273012162
step_1700: 0.5867823848128318
step_1800: 0.5449166482686997
step_1900: 0.555638926923275
step_2000: 0.5670486050844192
step_2100: 0.5322461694478988
step_2200: 0.5595515170693397
step_2300: 0.5347107174992561
step_2400: 0.545967561006546
step_2500: 0.5386095994710922
step_2600: 0.5517191290855408
step_2700: 0.5441773712635041
step_2800: 0.569552618265152
step_2900: 0.5321801468729973
step_3000: 0.570259618461132
step_3100: 0.547101169526577
step_3200: 0.5556990075111389
step_3300: 0.5461532717943192
step_3400: 0.54290482074

KeyboardInterrupt: 

In [24]:
# Preview the checkpoint filename for the current step (the same name the save cell uses).
CHECKPOINT_FILE_PATH = "{}_step_{}.pt".format(CHECKPOINT_FILE_BASE_PATH, step_i)
print(CHECKPOINT_FILE_PATH)

In [25]:
# Global training step reached so far (restored from the checkpoint when one was loaded).
step_i

22899

In [26]:
# Save checkpoint
# Write to a temporary file, then atomically rename it. An interrupted save then never leaves
# a truncated "<base>_step_<N>.pt" for the resume cell to pick up as the newest checkpoint.
CHECKPOINT_FILE_PATH = f"{CHECKPOINT_FILE_BASE_PATH}_step_{step_i}.pt"
_tmp_checkpoint_path = CHECKPOINT_FILE_PATH + ".tmp"
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'step_i': step_i,
}, _tmp_checkpoint_path)
os.replace(_tmp_checkpoint_path, CHECKPOINT_FILE_PATH)
print(f"Saved checkpoint to {CHECKPOINT_FILE_PATH}")

Saved checkpoint to model_checkpoint_flickr_llm_100difftimestep_epsilon_32context_patchdim2_imgdim256_gptlarge_step_22899.pt


In [None]:
# Sampling configuration.
DEL_OPTIMIZER: bool = True   # drop the optimiser before sampling to free GPU memory
NUM_INFERENCE_STEPS = 10     # number of scheduler steps used by generate()

In [None]:
# All reference captions for each image in the fixed batch `thing`.
thing['captions']

In [None]:
# Free training-only GPU memory before sampling, then switch the model to eval mode.
# `lr_schedule` (returned by get_constant_schedule) keeps its own reference to the optimiser,
# so both names must be dropped for the AdamW state to be released. The globals() checks
# make the cell safe to re-run after the objects are already gone.
import gc

if DEL_OPTIMIZER:
    # Gradients live on the parameters, not in the optimiser. They can still be populated if
    # training was interrupted between backward() and zero_grad().
    model.zero_grad(set_to_none=True)
    if 'optimizer' in globals():
        del optimizer
    if 'lr_schedule' in globals():
        del lr_schedule
    gc.collect()
    torch.cuda.empty_cache()

model.eval()

In [None]:
# Current and peak CUDA memory. The peak-allocated figure previously re-printed the
# current value; it now comes from max_memory_allocated().
print('memory allocated:', torch.cuda.memory_allocated(), 'max memory allocated:', torch.cuda.max_memory_allocated())
print('memory reserved:', torch.cuda.memory_reserved(), 'max memory reserved:', torch.cuda.max_memory_reserved())

In [None]:
from diffusers.pipelines.pipeline_utils import numpy_to_pil
from diffusers import DDPMScheduler

from model import LLourney

from PIL import Image

# model = LLourney()

# inprogress = []

# if DEL_OPTIMIZER:
#     del optimizer
#     torch.cuda.empty_cache()

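# NOTE: GPU memory use at generation time is much higher than expected. One likely contributor: `lr_schedule` keeps its own reference to the optimizer, so `del optimizer` alone does not release the AdamW state.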

def encode_prompt(prompt):
    """Tokenize prompt(s) the same way the training captions were, on ``DEVICE``.

    Args:
        prompt: a string or a list of strings.

    Returns:
        ``(input_ids, attention_mask)`` as int64 tensors on ``DEVICE``. Each row is
        padded/truncated to ``CONTEXT_SIZE`` tokens.
    """
    tokenizer_output = tokenizer(prompt, padding='max_length', truncation=True, max_length=CONTEXT_SIZE)
    input_ids = torch.tensor(tokenizer_output["input_ids"], dtype=torch.long, device=DEVICE)
    # Keep the mask integer-typed to match training. The batches there come from a
    # `with_format('torch')` dataset, which turns integer lists into int64 tensors. The
    # legacy `torch.Tensor(...)` constructor produced float32 here.
    attention_mask = torch.tensor(tokenizer_output["attention_mask"], dtype=torch.long, device=DEVICE)
    return input_ids, attention_mask

@torch.no_grad()
def generate(prompt: str) -> "Image.Image":
    """Sample one image for ``prompt`` using the global ``model`` and ``scheduler``.

    Starts from Gaussian latents of shape ``(1, 4, LATENT_DIM, LATENT_DIM)`` and runs
    ``NUM_INFERENCE_STEPS`` scheduler steps conditioned on the tokenized prompt. The final
    latents are then decoded with ``model.decode_image_latents``.

    Returns:
        The first decoded image, as produced by ``numpy_to_pil``.
    """
    prompts = [prompt]

    print('memory allocated:', torch.cuda.memory_allocated(), 'max memory allocated:', torch.cuda.max_memory_allocated())
    print('memory reserved:', torch.cuda.memory_reserved(), 'max memory reserved:', torch.cuda.max_memory_reserved())

    # NOTE: to condition generation on a single image from the batch 'thing' instead of
    # pure noise, start from:
    # latent_conditional_image = model.encode_image(elimage[:1, :, :, :])
    latent_images = torch.randn(len(prompts), 4, LATENT_DIM, LATENT_DIM, device=DEVICE)
    scheduler.set_timesteps(NUM_INFERENCE_STEPS)
    timesteps = scheduler.timesteps

    input_ids, attention_mask = encode_prompt(prompts)
    print(f"Latent noise shape: {latent_images.shape}")

    # Report progress about 10 times. max(1, ...) avoids a modulo-by-zero with < 10 steps.
    report_every = max(1, len(timesteps) // 10)
    for i, t in enumerate(timesteps):

        torch.cuda.empty_cache()

        if i % report_every == 0:
            print(f"Diffused to timestep {i}/{len(timesteps)}")

        # scale_model_input only prepares the model's input. scheduler.step must receive the
        # unscaled sample; feeding the scaled one back is wrong for schedulers whose scaling
        # is not the identity.
        model_input = scheduler.scale_model_input(latent_images, t)
        batched_t = t.repeat(len(prompts)).to(DEVICE)

        model_output = model(model_input, input_ids, batched_t, text_pad_mask=attention_mask)
        latent_images = scheduler.step(model_output, t, latent_images, return_dict=False)[0]

        # decoded_imagess = model.decode_image_latents(latent_images)
        # image = [numpy_to_pil(img) for img in decoded_imagess]
        # inprogress.append(image[0][0])

    print(latent_images.shape)

    decoded_images = model.decode_image_latents(latent_images)
    images = [numpy_to_pil(img) for img in decoded_images]

    return images[0][0]

# Generate some images
# for _ in range(10):
#     image = generate("A girl in a cowboy hat with a sheep on a leash")
#     display(image)
# image = generate("a drawing of a red and yellow pokemon character")
# display(image)
# image = generate("a very cute looking pokemon with a big beak")
# display(image)
# Generate an image for each caption of the fixed batch and show it next to the real photo.
for caption, reference_path in zip(thing['text'], thing['image_path']):
    image = generate(caption)
    print(caption)
    display(image)
    display(Image.open(reference_path))

# print(latent_images.shape)
# decoded_images = model.decode_image_latents(latent_images)
# image = [numpy_to_pil(img) for img in decoded_images]

# # Generate some images
# prompt = ["pokeymans", "big dog with cricket bat", "florgs is GAYYYYYY"]
# Z = torch.randn(len(prompt), 4, LATENT_DIM, LATENT_DIM, device=DEVICE)
# scheduler = DDPMScheduler(num_train_timesteps=NUM_TRAIN_TIMESTEPS)
# scheduler.set_timesteps(NUM_INFERENCE_STEPS)
# timesteps = scheduler.timesteps

# def encode_prompt(prompt):
#     tokenizer_output = tokenizer(prompt, padding='max_length', truncation=True, max_length=CONTEXT_SIZE)
#     input_ids = torch.LongTensor(tokenizer_output["input_ids"]).to(DEVICE)
#     attention_mask = torch.Tensor(tokenizer_output["attention_mask"]).to(DEVICE)
#     return input_ids, attention_mask

# input_ids, attention_mask = encode_prompt(prompt)
# latent_images = Z
# print(f"Latent noise shape: {latent_images.shape}")
# for i, t in enumerate(timesteps):
#     batched_t = torch.cat([torch.tensor([t])] * len(prompt), dim=0)
#     print(f"Batched t shape: {batched_t.shape}")
#     print(f"Input ids 666: {input_ids.shape}")
#     print(f"Attention mask 666: {attention_mask.shape}")
#     model_output = model(latent_images, input_ids, batched_t, text_pad_mask=attention_mask)

#     latent_images = scheduler.step(model_output, t, latent_images).prev_sample

# print(latent_images.shape)
# decoded_images = model.decode_image_latents(latent_images)
# image = [numpy_to_pil(img) for img in decoded_images]

# display(image[0])


In [None]:
# First caption of the fixed batch (the one that was tokenized for training).
thing['text'][0]

In [None]:
# Show `image` (a single generated image when the generation loop ran last).
display(image)

In [None]:
# Round-trip the fixed batch through the VAE (encode -> decode) to inspect reconstruction quality.
# Inference only: no_grad avoids building and holding an autograd graph through the VAE.
with torch.no_grad():
    latent_images = model.encode_image(thing["image"].to(DEVICE))
    decoded_images = model.decode_image_latents(latent_images)
image = [numpy_to_pil(img) for img in decoded_images]

In [None]:
# assumes `image` is the nested list from the VAE round-trip cell
first_image = image[0][0]
display(first_image)

In [None]:
# Display intermediate denoising previews, if any were recorded. The code in generate() that
# appends to `inprogress` is commented out, so fall back to an empty list instead of raising NameError.
for preview in globals().get('inprogress', []):
    display(preview)

In [None]:
import datetime

# Save the most recent result as PNG. PIL infers the format from the file extension, so the
# previous extension-less name raised "unknown file extension".
# `image` is either a single image (generation loop) or the nested list from the VAE round-trip cell.
image_to_save = image[0][0] if isinstance(image, list) else image
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
image_to_save.save(f"pred_img_{step_i}_{timestamp}.png")

In [None]:
# Show `image` (a single generated image when the generation loop ran last).
display(image)

In [None]:
# assumes `image` is the nested list from the VAE round-trip cell
first_image = image[0][0]
display(first_image)

In [None]:
# assumes `image` is the nested list from the VAE round-trip cell
second_image = image[1][0]
display(second_image)

In [None]:
# assumes `image` is the nested list from the VAE round-trip cell
third_image = image[2][0]
display(third_image)

In [None]:
from diffusers import DDPMScheduler

# Inspect the inference timestep schedule of a *default* DDPM scheduler.
# Use a separate name: rebinding `scheduler` here would silently replace the training
# scheduler (built with NUM_DIFFUSION_TIMESTEPS and PREDICTION_TYPE) that generate() uses.
demo_scheduler = DDPMScheduler()

demo_scheduler.set_timesteps(100)

timesteps = demo_scheduler.timesteps

print(timesteps)
print(timesteps[0])

In [None]:
# Positional-context limit from the wrapped language model's config.
# NOTE(review): the attribute is named `llama` but MODEL is a GPT-2 checkpoint; presumably it wraps that model — confirm.
model.llama.config.n_positions

In [None]:
# Embedding/hidden width from the wrapped language model's config (GPT-2-style `n_embd` field).
model.llama.config.n_embd

In [None]:
from diffusers.pipelines.pipeline_utils import numpy_to_pil
from diffusers import DDPMScheduler

# Test what adding latent noise makes VAE look like

# # First, display image without any latent noise
# display(Image.open(thing['image_path'][0]))
# latent_images = model.encode_image(thing["image"].to(DEVICE))
# decoded_images = model.decode_image_latents(latent_images)
# image = [numpy_to_pil(img) for img in decoded_images]
# display(image[0][0])

# # Then, display image after having added latent noise
# Z = torch.randn_like(latent_images)
# latent_images_noise_rand = latent_images + Z
# decoded_images_noise_rand = model.decode_image_latents(latent_images_noise_rand)
# image_noisy_rand = [numpy_to_pil(img) for img in decoded_images_noise_rand]
# display(image_noisy_rand[0][0])

# # Then, display image after having added latent noise with scheduler as in the model
# t = torch.randint(0, NUM_DIFFUSION_TIMESTEPS, (1,), device=DEVICE, dtype=torch.long)
# # t = torch.tensor([300], device=DEVICE, dtype=torch.long)
# print("timestep t:", t)
# latent_images_noise_sched = scheduler.add_noise(latent_images, Z, t)
# decoded_images_noise_sched = model.decode_image_latents(latent_images_noise_sched)
# image_noisy_sched = [numpy_to_pil(img) for img in decoded_images_noise_sched]
# display(image_noisy_sched[0][0])

# norm_z = torch.norm(Z)
# norm_latent_images = torch.norm(latent_images)
# norm_latent_images_noise_rand = torch.norm(latent_images_noise_rand)
# print(f"Norm of z: {norm_z}")
# print(f"Norm of latent_images: {norm_latent_images}")
# print(f"Norm of latent_images_noise_rand: {norm_latent_images_noise_rand}")
# mse_dist = F.mse_loss(norm_z, norm_latent_images)
# mse_dist2 = F.mse_loss(norm_z, norm_latent_images_noise_rand)
# print(f"MSE between z and latent_images: {mse_dist}")
# print(f"MSE between z and latent_images_noise_rand: {mse_dist2}")