In [3]:
import numpy as np
import torch
import torch.nn as nn
import random
from datasets import load_from_disk
from loguru import logger
from PIL import Image
from transformers import AutoTokenizer, GPTJForCausalLM
from lmm_synthetic.mm_train.gptj_vlm import GPTJ_VLM
from lmm_synthetic.mm_train.utils import load_vision_encoder
import time

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Machine-specific absolute paths: the on-disk dataset, the VLM checkpoint and
# a language-model checkpoint.
# NOTE(review): lm_path is not referenced by any cell visible here — confirm it
# is still needed.
dataset_path = '/data/lmm/generated/v3_spatial_grid_multimodal'
vlm_path = '/home/allanz/data/vlm_checkpoint/final_model'
lm_path = "/data/lmm/checkpoints/lm/lm-pretrain-only-checkpoint-1953"

# Load the saved dataset from disk and print its splits and features.
dataset = load_from_disk(dataset_path)
print(dataset)

# Load VLM and CLIP model
def load_model_and_tokenizer(model_path, multimodal=False):
    """
    Load a checkpoint and its tokenizer, and switch the model to eval mode.

    Args:
        model_path: Checkpoint location handed to ``from_pretrained``.
        multimodal: When truthy, load a ``GPTJ_VLM`` and read the tokenizer
            from the path stored in ``model.config.pretrained_lm_path``;
            otherwise load a ``GPTJForCausalLM`` and read the tokenizer from
            ``model_path`` itself.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    logger.info(f"Loading model and tokenizer from {model_path}")
    if not multimodal:
        model = GPTJForCausalLM.from_pretrained(model_path)
        tokenizer_source = model_path
    else:
        model = GPTJ_VLM.from_pretrained(model_path)
        # The VLM config records which language model its tokenizer belongs to.
        tokenizer_source = model.config.pretrained_lm_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    model.eval()
    return model, tokenizer

# Load the multimodal checkpoint together with its tokenizer.
vlm, vlm_tokenizer = load_model_and_tokenizer(vlm_path, multimodal=True)

# Handle to the CLIP vision model held inside the VLM's vision encoder.
clip_vision_model = vlm.vision_encoder.clip_vision_model
# NOTE(review): `encoder` comes from a separate load_vision_encoder("clip")
# call rather than from `vlm` — presumably a freshly loaded CLIP encoder;
# confirm this is the encoder the probes are meant to read.
# The third return value is discarded.
encoder, image_transforms, _ = load_vision_encoder("clip")

#Parse grid
def parse_grid(grid_str, K):
    """
    Turn the first ``K`` lines of a pipe-delimited grid string into rows of cells.

    Only the first ``K`` newline-separated lines are kept. Surrounding
    whitespace is then trimmed from that text, each remaining line is split on
    ``|``, and blank cells are dropped.

    Args:
        grid_str (str): Text whose leading lines hold the grid.
        K (int): Number of leading lines to keep.

    Returns:
        list[list[str]]: One list of stripped, non-empty cell strings per line.
    """
    leading_lines = grid_str.split('\n')[:K]
    trimmed = '\n'.join(leading_lines).strip()

    parsed = []
    for line in trimmed.split('\n'):
        cells = []
        for piece in line.split('|'):
            piece = piece.strip()
            if piece:
                cells.append(piece)
        parsed.append(cells)
    return parsed

  from .autonotebook import tqdm as notebook_tqdm
[32m2024-12-17 16:05:54.445[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_model_and_tokenizer[0m:[36m29[0m - [1mLoading model and tokenizer from /home/allanz/data/vlm_checkpoint/final_model[0m


DatasetDict({
    train: Dataset({
        features: ['text', 'prompt', 'conversations', 'image'],
        num_rows: 100000
    })
    validation: Dataset({
        features: ['text', 'prompt', 'conversations', 'image'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['text', 'prompt', 'conversations', 'image'],
        num_rows: 1000
    })
})


[32m2024-12-17 16:05:56.491[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_vision_encoder[0m:[36m26[0m - [1mLoading vision encoder: clip[0m
[32m2024-12-17 16:05:56.496[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_vision_encoder[0m:[36m29[0m - [1mUsing CLIP model as the vision encoder[0m
[32m2024-12-17 16:05:57.021[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_multimodal_projector[0m:[36m89[0m - [1mLoading multimodal projector: linear[0m
[32m2024-12-17 16:05:57.177[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_vision_encoder[0m:[36m26[0m - [1mLoading vision encoder: clip[0m
[32m2024-12-17 16:05:57.179[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_vision_encoder[0m:[36m29[0m - [1mUsing CLIP model as the vision encoder[0m


In [4]:
# Linear layer
class LinearLayer(nn.Module):
    """Thin module wrapping a single ``nn.Linear`` projection."""

    def __init__(self, input_size, output_size):
        """Create the affine map from ``input_size`` to ``output_size`` features."""
        super().__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Apply the linear projection to ``x`` and return the result."""
        projected = self.linear(x)
        return projected

In [94]:
# Index table: nine lists of nine token indices each. The indices are those of
# 3x3 windows taken at stride 2 over a row-major 7x7 index grid (49 tokens),
# listed window by window in row-major order.
patches = [
    [(top + dr) * 7 + (left + dc) for dr in range(3) for dc in range(3)]
    for top in (0, 2, 4)
    for left in (0, 2, 4)
]

def position_concat(image_tensor, patch = patches):
    """
    Build one flattened feature row per index list in ``patch``.

    For every index list, the entries ``image_tensor[0][i]`` are concatenated
    end to end and given a leading singleton dimension; the per-list results
    are then stacked along a new first dimension. With the default table this
    gives 9 x 1 x (9 * token_width), e.g. 9 x 1 x 6912 for 768-wide tokens.

    Args:
        image_tensor: Tensor whose first entry along dim 0 holds the tokens;
            only ``image_tensor[0]`` is read.
        patch: Iterable of index lists; defaults to the module-level table.

    Returns:
        torch.Tensor: Stacked rows, one per index list.
    """
    def _flatten_window(indices):
        # Concatenate the selected token vectors into one long row.
        pieces = [image_tensor[0][i] for i in indices]
        return torch.cat(pieces).unsqueeze(0)

    return torch.stack([_flatten_window(indices) for indices in patch])


In [113]:
def prepare_data_no_batch(set_type, num_samples):
    """
    Encode the first ``num_samples`` images of a split and pair each with its labels.

    For every index ``i`` the image ``{set_type}_{i}.png`` under
    ``{dataset_path}/images`` is transformed, passed through ``encoder``
    without gradient tracking, and regrouped by ``position_concat``. The labels
    come from the first three lines of the sample's ``text`` field, parsed by
    ``parse_grid`` and flattened row by row.

    Args:
        set_type (str): Split name; used both as the key into ``dataset`` and
            as the image file-name prefix.
        num_samples (int): Number of samples to process, starting at index 0.

    Returns:
        list[tuple]: One ``(features, labels)`` pair per sample. ``features``
            is the ``position_concat`` output (9 x 1 x 6912 for the data in
            this notebook) and stays on the device the encoder ran on;
            ``labels`` is a flat list of class indices, one per grid cell.

    Raises:
        ValueError: If a parsed cell is not one of the known class names.
    """
    # Class names in label-index order (the list is called ANIMALS although it
    # also contains vehicle classes).
    ANIMALS = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder.to(device)

    # Look the split up once instead of re-indexing the DatasetDict per sample.
    split = dataset[set_type]
    # Derive the image folder from the dataset root so the two cannot drift apart.
    image_dir = f"{dataset_path}/images"

    reformatted_data = []
    t_0 = time.perf_counter()

    for i in range(num_samples):
        # The context manager closes the image file as soon as the transform
        # has turned it into a tensor.
        with Image.open(f"{image_dir}/{set_type}_{i}.png") as image:
            image_tensor = image_transforms(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_tensor_tokens = encoder(image_tensor)

        grid = parse_grid(split[i]['text'], 3)
        # Flatten the grid row by row into class indices.
        labels = [ANIMALS.index(animal) for row in grid for animal in row]

        # Tuple of (regrouped image tokens, list of class indices).
        reformatted_data.append((position_concat(image_tensor_tokens), labels))

        if i % 50 == 0:
            print(f"Processed {i} images")

    # Release cached allocator blocks once at the end. Doing this on every
    # iteration only forces the allocator to re-request memory and does not
    # make more memory available to PyTorch itself.
    torch.cuda.empty_cache()

    t_3 = time.perf_counter()
    print(f"Finished preparing data in {t_3 - t_0} seconds")
    return reformatted_data

In [114]:
# Encode the first 10000 images of the "train" split (the split has 100000 rows).
train_data = prepare_data_no_batch("train", 10000)

Processed 0 images
Processed 50 images
Processed 100 images
Processed 150 images
Processed 200 images
Processed 250 images
Processed 300 images
Processed 350 images
Processed 400 images
Processed 450 images
Processed 500 images
Processed 550 images
Processed 600 images
Processed 650 images
Processed 700 images
Processed 750 images
Processed 800 images
Processed 850 images
Processed 900 images
Processed 950 images
Processed 1000 images
Processed 1050 images
Processed 1100 images
Processed 1150 images
Processed 1200 images
Processed 1250 images
Processed 1300 images
Processed 1350 images
Processed 1400 images
Processed 1450 images
Processed 1500 images
Processed 1550 images
Processed 1600 images
Processed 1650 images
Processed 1700 images
Processed 1750 images
Processed 1800 images
Processed 1850 images
Processed 1900 images
Processed 1950 images
Processed 2000 images
Processed 2050 images
Processed 2100 images
Processed 2150 images
Processed 2200 images
Processed 2250 images
Processed 2

In [123]:
# Shape of the position-1 feature row of the last prepared sample.
train_data[9999][0][1].shape

torch.Size([1, 6912])

In [155]:
# Label list (class indices, one per grid position) of the second sample.
train_data[1][1]

[2, 4, 2, 3, 3, 9, 3, 4, 4]

In [139]:
# Scratch check: stack the position-0 feature row of each of the 10000 samples
# into one matrix with one row per sample.
test = torch.concat([train_data[i][0][0] for i in range(10000)], dim = 0)

In [161]:
# Feature width of the stacked matrix.
test.shape[1]

6912

In [154]:
# Scratch check: collect the label at grid position `pos` from each of the
# 10000 samples.
labels = []
pos = 0
for i in range(10000):
    labels.append(train_data[i][1][pos])

# Number of labels collected, then the label of the second sample.
print(len(labels))
print(labels[1])

10000
2


In [156]:
# Convert the Python label list into a 1-D tensor.
labels2 = torch.tensor(labels)

In [157]:
# Shape of the label tensor (one entry per collected label).
labels2.shape

torch.Size([10000])

In [160]:
# Number of distinct class indices present among the collected labels.
len(torch.unique(labels2))

10

In [175]:
def train_lbfgs_linear_classifier(position, dataset, num_iterations = 100, lr = 1.0):
    """
    Trains a linear classifier for one grid position using L-BFGS.

    Args:
        position (int): position of the grid cell to train on (0-8)
        dataset (list): each tuple should contain a 9 x 1 x 6912 tensor and a list with 9 values:
            a 1 x 6912 tensor for each position in patch, and the class index of the animal at each position
        num_iterations (int): maximum number of L-BFGS iterations (passed as ``max_iter``)
        lr (float): learning rate for L-BFGS optimization

    Returns:
        model (nn.Module): trained linear classifier, on the GPU when one is available
        log_losses (list): loss recorded at every closure evaluation made by L-BFGS
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One feature row per sample. The features are fixed inputs, so detach them
    # (the closure calls backward() many times) and move them to the same
    # device as the model and labels; previously only the labels were moved.
    data = torch.cat([sample[0][position] for sample in dataset], dim = 0).detach().to(device)
    labels = torch.tensor([sample[1][position] for sample in dataset]).to(device)

    # Check inputs
    if len(data) == len(labels) and data.dim() == 2 and labels.dim() == 1:
        print("Data is formatted correctly")
    else:
        print("Look over format of dataset")

    print(f"Data shape: {data.shape}")
    print(f"Labels shape: {labels.shape}")

    class LinearClassifier(nn.Module):
        """Single linear layer mapping features to class logits."""

        def __init__(self, input_dim, num_classes):
            super(LinearClassifier, self).__init__()
            self.linear = nn.Linear(input_dim, num_classes)

        def forward(self, x):
            return self.linear(x)

    input_dim = data.shape[1]
    print(input_dim)
    # Size the output layer from the largest class index rather than from the
    # number of distinct labels: if some class never occurs at this position,
    # counting distinct labels makes the layer too small and CrossEntropyLoss
    # rejects the out-of-range targets.
    num_classes = int(labels.max().item()) + 1
    print(num_classes)

    model = LinearClassifier(input_dim, num_classes).to(device)

    # Use cross entropy loss for classification
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.LBFGS(model.parameters(), lr = lr, max_iter = num_iterations)

    log_losses = []

    def closure():
        # Re-evaluates the full-batch loss; L-BFGS may call this several times
        # within the single step() below.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, labels)
        loss.backward()
        log_losses.append(loss.item())
        return loss

    # A single step() runs the whole optimization (up to max_iter iterations).
    optimizer.step(closure)

    print("Training complete. Final loss: {:.4f}".format(log_losses[-1]))
    return model, log_losses
    

In [181]:
# Fit the position-0 probe with up to 1000 L-BFGS iterations at lr 0.01.
linear0, linear0_log = train_lbfgs_linear_classifier(0, train_data, num_iterations = 1000, lr = 0.01)

Data is formatted correctly
Data shape: torch.Size([10000, 6912])
Labels shape: torch.Size([10000])
6912
10
Training complete. Final loss: 0.0000


In [186]:
# Run the position-0 probe on the first training sample's position-0 features.
# Tensor.to() returns the moved tensor instead of modifying it in place, so the
# result must be kept; the original call discarded it.
# NOTE: `input` shadows the Python builtin; the name is kept because other
# cells may refer to it.
input = train_data[0][0][0].to(DEVICE)
output = linear0(input)


In [192]:
# Label list of the first sample (one class index per grid position).
train_data[0][1]

[4, 3, 2, 4, 3, 3, 2, 5, 2]

In [187]:
# Display the classifier output stored in `output`.
output

tensor([[-1.3596, -0.9070, -1.2077, -1.2198, 11.8554, -1.0516, -1.8016, -1.0037,
         -1.8169, -1.5476]], device='cuda:0', grad_fn=<AddmmBackward0>)