In [1]:
import numpy as np
import torch
import torch.nn as nn
import random
from datasets import load_from_disk
from loguru import logger
from PIL import Image
from transformers import AutoTokenizer, GPTJForCausalLM
from lmm_synthetic.mm_train.gptj_vlm import GPTJ_VLM
from lmm_synthetic.mm_train.utils import load_vision_encoder
import time

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths 
dataset_path = '/data/lmm/generated/v3_spatial_grid_multimodal'
vlm_path = '/home/allanz/data/vlm_checkpoint/final_model'
lm_path = "/data/lmm/checkpoints/lm/lm-pretrain-only-checkpoint-1953"

# Load dataset
dataset = load_from_disk(dataset_path)
print(dataset)

# Load VLM and CLIP model
def load_model_and_tokenizer(model_path, multimodal=False):
    """
    Load the model and tokenizer from the specified path.
    """
    logger.info(f"Loading model and tokenizer from {model_path}")
    model, tokenizer = None, None
    if multimodal:
        model = GPTJ_VLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model.config.pretrained_lm_path)
    else:
        model = GPTJForCausalLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.eval()
    return model, tokenizer

vlm, vlm_tokenizer = load_model_and_tokenizer(vlm_path, multimodal=True)

clip_vision_model = vlm.vision_encoder.clip_vision_model
encoder, image_transforms, _ = load_vision_encoder("clip")

#Parse grid
def parse_grid(grid_str, K):
    """
    Parse the grid string into a 2D list of grid cells.
    """
    grid_str = '\n'.join(grid_str.split('\n')[:K])
    rows = grid_str.strip().split('\n')
    return [[cell.strip() for cell in row.split('|') if cell.strip()] for row in rows]

  from .autonotebook import tqdm as notebook_tqdm
[32m2024-12-17 14:20:51.533[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_model_and_tokenizer[0m:[36m29[0m - [1mLoading model and tokenizer from /home/allanz/data/vlm_checkpoint/final_model[0m


DatasetDict({
    train: Dataset({
        features: ['text', 'prompt', 'conversations', 'image'],
        num_rows: 100000
    })
    validation: Dataset({
        features: ['text', 'prompt', 'conversations', 'image'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['text', 'prompt', 'conversations', 'image'],
        num_rows: 1000
    })
})


[32m2024-12-17 14:20:53.402[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_vision_encoder[0m:[36m26[0m - [1mLoading vision encoder: clip[0m
[32m2024-12-17 14:20:53.403[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_vision_encoder[0m:[36m29[0m - [1mUsing CLIP model as the vision encoder[0m
[32m2024-12-17 14:20:53.925[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_multimodal_projector[0m:[36m89[0m - [1mLoading multimodal projector: linear[0m
[32m2024-12-17 14:20:54.066[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_vision_encoder[0m:[36m26[0m - [1mLoading vision encoder: clip[0m
[32m2024-12-17 14:20:54.067[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_vision_encoder[0m:[36m29[0m - [1mUsing CLIP model as the vision encoder[0m


In [2]:
# Linear layer
class LinearLayer(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearLayer, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

In [3]:
patches = [[1, 2, 3, 8, 9, 10, 15, 16, 17], [3, 4, 5, 10 , 11, 12, 17, 18, 19], [5, 6, 7, 12, 13, 14, 19, 20, 21],
           [15, 16, 17, 22, 23, 24, 29, 30, 31], [17, 18, 19, 24, 25, 26, 31, 32, 33], [19, 20, 21, 26, 27, 28, 33, 34, 35],
           [29, 30, 31, 36, 37, 38, 43, 44, 45], [31, 32, 33, 38, 39, 40, 45, 46, 47], [33, 34, 35, 40, 41, 42, 47, 48, 49]]

def position_concat(image_tensor, patch = patches):
    """
    Concats tensors, should return shape 1x9x6912 
    """
    concat = []
    for position in patch:
        concat.append(torch.cat([image_tensor[0][i] for i in position], dim = 0))
    
    return torch.stack(concat).unsqueeze(0)

In [4]:
def prepare_data_no_batch(set_type, num_samples):
    """
    Prepares data to feed into model by turning images 
    into 1 x 9 x 6912 tensors, grid cell information 
    into 1 x 9 x 10 tensors 
    """ 
    ANIMALS = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder.to(device)

    reformatted_data = []
    t_0 = time.perf_counter()
    
    # Extract image tensors and grid cells in each batch
    for i in range(num_samples):
        image = Image.open(f"/data/lmm/generated/v3_spatial_grid_multimodal/images/{set_type}_{i}.png")
        image_tensor = image_transforms(image).unsqueeze(0).to(device) 
        with torch.no_grad():
            image_tensor_tokens = encoder(image_tensor)

        grid = parse_grid(dataset[set_type][i]['text'], 3)

        temp = []
        for row in grid:
            for animal in row:
                temp_tensor = torch.zeros(1,10)
                temp_tensor[0][ANIMALS.index(animal)] = 1
                temp.append(temp_tensor)
        
        reformatted_data.append((position_concat(image_tensor_tokens), torch.stack(temp, dim = 1))) 

        if i % 50 == 0:
            print(f"Processed {i} samples") 
        torch.cuda.empty_cache()
    
    t_3 = time.perf_counter()
    print(f"Finished preparing data in {t_3 - t_0} seconds")
    return reformatted_data

In [5]:
data = prepare_data_no_batch("train", 5000)

Processed 0 samples
Processed 50 samples
Processed 100 samples
Processed 150 samples
Processed 200 samples
Processed 250 samples
Processed 300 samples
Processed 350 samples
Processed 400 samples
Processed 450 samples
Processed 500 samples
Processed 550 samples
Processed 600 samples
Processed 650 samples
Processed 700 samples
Processed 750 samples
Processed 800 samples
Processed 850 samples
Processed 900 samples
Processed 950 samples
Processed 1000 samples
Processed 1050 samples
Processed 1100 samples
Processed 1150 samples
Processed 1200 samples
Processed 1250 samples
Processed 1300 samples
Processed 1350 samples
Processed 1400 samples
Processed 1450 samples
Processed 1500 samples
Processed 1550 samples
Processed 1600 samples
Processed 1650 samples
Processed 1700 samples
Processed 1750 samples
Processed 1800 samples
Processed 1850 samples
Processed 1900 samples
Processed 1950 samples
Processed 2000 samples
Processed 2050 samples
Processed 2100 samples
Processed 2150 samples
Processed 2

In [6]:
print(data[4999][0].shape)
print(data[4999][1].shape)

torch.Size([1, 9, 6912])
torch.Size([1, 9, 10])


In [7]:
import torch.nn as nn
import torch.nn.functional as F

def train(model, dataset = data, num_epochs = 10, learning_rate = 1e-4):
    """
    Train linear layer on data. After running image 
    tensors through linear layer, apply softmax and
    use cross entropy loss to calculate loss.
    """
    linear = model
    softmax = nn.Softmax(dim=2)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    linear.to(device)
    linear.train()
    
    optimizer = torch.optim.Adam(linear.parameters(), lr=learning_rate)
    cross_entropy_loss = F.cross_entropy
    
    total_loss = 0
    
    t_0 = time.perf_counter()
    for epoch in range(num_epochs):
        t_1 = time.perf_counter()
        epoch_loss = 0
        for i in range(len(dataset)):
            image_tensor, grid_tensor = dataset[i]
            output = linear(image_tensor)
            prediction = softmax(output)

            loss = cross_entropy_loss(prediction, grid_tensor.to(device))
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
       
        t_2 = time.perf_counter()
        print(f"Epoch {epoch} loss: {epoch_loss} | completed in {t_2-t_1} seconds")
        print(f"Average loss: {epoch_loss/len(dataset)}")     

        total_loss += epoch_loss
    t_3 = time.perf_counter()
    print(f"Training completed in {t_3 - t_0} seconds")

               

In [8]:
linear = LinearLayer(6912, 10)
train(linear, data, 10, 1e-4)

Epoch 0 loss: 7223.131538510323 | completed in 4.361877731978893 seconds
Average loss: 1.4446263077020645
Epoch 1 loss: 7153.993084669113 | completed in 4.31763705983758 seconds
Average loss: 1.4307986169338227
Epoch 2 loss: 7153.667843103409 | completed in 4.304949581623077 seconds
Average loss: 1.4307335686206817
Epoch 3 loss: 7153.640385866165 | completed in 4.300205413252115 seconds
Average loss: 1.4307280771732331
Epoch 4 loss: 7153.637858390808 | completed in 4.299834690988064 seconds
Average loss: 1.4307275716781616
Epoch 5 loss: 7153.6376321315765 | completed in 4.283599615097046 seconds
Average loss: 1.4307275264263153
Epoch 6 loss: 7153.637609958649 | completed in 4.280631445348263 seconds
Average loss: 1.4307275219917297
Epoch 7 loss: 7153.63760304451 | completed in 4.277513798326254 seconds
Average loss: 1.430727520608902
Epoch 8 loss: 7153.637602686882 | completed in 4.277612302452326 seconds
Average loss: 1.4307275205373764
Epoch 9 loss: 7153.63760137558 | completed in 4.