In [1]:
import numpy as np
import torch
import torch.nn as nn
from datasets import load_from_disk
from loguru import logger
from PIL import Image
from transformers import AutoTokenizer, GPTJForCausalLM
from lmm_synthetic.mm_train.gptj_vlm import GPTJ_VLM
from lmm_synthetic.mm_train.utils import load_vision_encoder
import time

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths (absolute and machine-specific; edit these before running elsewhere).
dataset_path = '/data/lmm/generated/v3_spatial_grid_multimodal'
vlm_path = '/home/allanz/data/vlm_checkpoint/final_model'
# NOTE(review): lm_path is not referenced by any cell visible in this notebook excerpt.
lm_path = "/data/lmm/checkpoints/lm/lm-pretrain-only-checkpoint-1953"

# Load the dataset saved on disk and print a summary of its splits and features.
dataset = load_from_disk(dataset_path)
print(dataset)

# Load VLM and CLIP model
def load_model_and_tokenizer(model_path, multimodal=False):
    """
    Return a ``(model, tokenizer)`` pair loaded from ``model_path``.

    Args:
        model_path: checkpoint location handed to ``from_pretrained``.
        multimodal: when truthy, the checkpoint is loaded as a ``GPTJ_VLM`` and
            the tokenizer is loaded from ``model.config.pretrained_lm_path``;
            otherwise a ``GPTJForCausalLM`` and its tokenizer are both loaded
            from ``model_path``.

    Returns:
        tuple: the model (already switched to eval mode) and its tokenizer.
    """
    logger.info(f"Loading model and tokenizer from {model_path}")

    # Pick the model class first; the tokenizer source depends on the loaded model's config.
    model_cls = GPTJ_VLM if multimodal else GPTJForCausalLM
    model = model_cls.from_pretrained(model_path)

    tokenizer_source = model.config.pretrained_lm_path if multimodal else model_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

    model.eval()
    return model, tokenizer

# Load the multimodal checkpoint; with multimodal=True the tokenizer comes from the
# language-model path stored in the VLM config.
vlm, vlm_tokenizer = load_model_and_tokenizer(vlm_path, multimodal=True)

# Handle to the CLIP vision model held by the VLM's vision encoder.
# NOTE(review): not referenced by any later cell visible in this excerpt.
clip_vision_model = vlm.vision_encoder.clip_vision_model
# Separately load a "clip" vision encoder and its image transforms; the third return value is discarded.
# NOTE(review): presumably this is the same encoder configuration the VLM was trained with — confirm.
encoder, image_transforms, _ = load_vision_encoder("clip")

#Parse grid
def parse_grid(grid_str, K):
    """
    Turn the first K lines of a '|'-delimited grid string into a list of rows.

    The text is truncated to its first K lines *before* surrounding whitespace
    is stripped, so a leading blank line counts towards K. Each remaining line
    is split on '|'; cells are whitespace-trimmed and empty cells are dropped.

    Args:
        grid_str (str): text whose leading lines describe the grid.
        K (int): number of leading lines to keep.

    Returns:
        list[list[str]]: one list of cell strings per kept line.
    """
    head_lines = grid_str.split('\n')[:K]
    trimmed = '\n'.join(head_lines).strip()

    parsed_rows = []
    for line in trimmed.split('\n'):
        cells = [piece.strip() for piece in line.split('|')]
        parsed_rows.append([cell for cell in cells if cell])
    return parsed_rows

  from .autonotebook import tqdm as notebook_tqdm
[32m2024-12-17 21:03:06.375[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_model_and_tokenizer[0m:[36m28[0m - [1mLoading model and tokenizer from /home/allanz/data/vlm_checkpoint/final_model[0m


DatasetDict({
    train: Dataset({
        features: ['text', 'prompt', 'conversations', 'image'],
        num_rows: 100000
    })
    validation: Dataset({
        features: ['text', 'prompt', 'conversations', 'image'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['text', 'prompt', 'conversations', 'image'],
        num_rows: 1000
    })
})


[32m2024-12-17 21:03:08.083[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_vision_encoder[0m:[36m26[0m - [1mLoading vision encoder: clip[0m
[32m2024-12-17 21:03:08.085[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_vision_encoder[0m:[36m29[0m - [1mUsing CLIP model as the vision encoder[0m
[32m2024-12-17 21:03:09.145[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_multimodal_projector[0m:[36m89[0m - [1mLoading multimodal projector: linear[0m
[32m2024-12-17 21:03:09.279[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_vision_encoder[0m:[36m26[0m - [1mLoading vision encoder: clip[0m
[32m2024-12-17 21:03:09.280[0m | [1mINFO    [0m | [36mlmm_synthetic.mm_train.utils[0m:[36mload_vision_encoder[0m:[36m29[0m - [1mUsing CLIP model as the vision encoder[0m


In [40]:
# Inspect the VLM's projector; the recorded output shows a single Linear layer (768 -> 768, with bias).
# NOTE(review): this cell's execution count (40) is out of order relative to its neighbours.
vlm.multimodal_projector

Linear(in_features=768, out_features=768, bias=True)

In [3]:
# Token indices for nine overlapping 3x3 windows. Indices are laid out as
# row * 7 + col and window origins step by 2 along both axes, giving one
# window per cell of a 3x3 layout (listed row by row).
# NOTE(review): presumably these index a 7x7 grid of image patch tokens —
# confirm that index 0 is a patch token rather than a class token.
patches = [
    [(top + d_row) * 7 + (left + d_col) for d_row in range(3) for d_col in range(3)]
    for top in (0, 2, 4)
    for left in (0, 2, 4)
]

def position_concat(image_tensor, patch = patches):
    """
    Build one concatenated feature vector per window of token indices.

    Only the first entry along dim 0 of ``image_tensor`` is used. For every
    index list in ``patch`` the selected token vectors are concatenated end to
    end and given a leading dimension of size 1; the per-window results are
    then stacked.

    Args:
        image_tensor: tensor indexed as ``image_tensor[0][token_index]``.
        patch: iterable of index lists, one per window (defaults to ``patches``).

    Returns:
        Tensor of shape (num_windows, 1, tokens_per_window * token_dim); with
        the default windows and 768-dim tokens this is 9 x 1 x 6912.
    """
    tokens = image_tensor[0]
    per_window = [
        torch.cat([tokens[index] for index in window]).unsqueeze(0)
        for window in patch
    ]
    return torch.stack(per_window)


In [13]:
#refactor 
def prepare_data_no_batch(set_type, num_samples, prints = False):
    """
    Encode the first ``num_samples`` images of a split into per-position features and labels.

    For each index ``i`` in ``range(num_samples)`` the image
    ``{dataset_path}/images/{set_type}_{i}.png`` is transformed, passed through
    ``encoder`` and ``vlm.multimodal_projector`` (without gradients), and
    grouped per grid position by ``position_concat``. Labels come from the
    first three lines of the sample's ``'text'`` field, parsed with
    ``parse_grid`` and flattened row by row.

    Args:
        set_type (str): split name; used both as the key into ``dataset`` and
            as the image filename prefix (e.g. "train", "test").
        num_samples (int): number of samples to process, starting at index 0.
        prints (bool): when truthy, report progress every 200 images.

    Returns:
        list[tuple]: one ``(features, labels)`` pair per sample. ``features``
        is the ``position_concat`` output (9 x 1 x 6912 in the recorded run)
        and stays on the compute device; ``labels`` is a flat list of class
        indices, one per parsed grid cell.

    Raises:
        ValueError: if a grid cell name is not one of the known class names.
    """
    # Class names in label-index order (the name is historical; not all are animals).
    ANIMALS = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder.to(device)
    multimodal_projector = vlm.multimodal_projector
    multimodal_projector.to(device)

    # Resolve the split once instead of on every iteration.
    split = dataset[set_type]

    reformatted_data = []
    t_0 = time.perf_counter()

    for i in range(num_samples):
        # Images live next to the dataset on disk; derive the path from dataset_path
        # so the two cannot drift apart.
        image_path = f"{dataset_path}/images/{set_type}_{i}.png"
        # The context manager closes the file handle once the pixels have been
        # turned into a tensor, so long runs do not accumulate open files.
        with Image.open(image_path) as image:
            image_tensor = image_transforms(image).unsqueeze(0).to(device)

        with torch.no_grad():
            embeddings = encoder(image_tensor)
            image_tensor_tokens = multimodal_projector(embeddings)

        grid = parse_grid(split[i]['text'], 3)
        labels = [ANIMALS.index(animal) for row in grid for animal in row]

        # Tuple of (per-position image features, list of class indices).
        reformatted_data.append((position_concat(image_tensor_tokens.to(device)), labels))

        if prints and i % 200 == 0:
            print(f"Processed {i} images")

    # Release cached GPU blocks once at the end; doing this on every iteration
    # only slows the loop down.
    torch.cuda.empty_cache()

    t_3 = time.perf_counter()
    print(f"Finished preparing data in {t_3 - t_0} seconds")
    return reformatted_data

In [4]:
def train_lbfgs_linear_classifier(position, dataset, num_iterations = 100, lr = 1.0, num_classes = 10):
    """
    Train a linear classifier for one grid position using L-BFGS.

    Args:
        position (int): position of the grid cell to train on (0-8 for the
            default 3x3 layout).
        dataset (list): one ``(features, labels)`` pair per sample, where
            ``features[position]`` is a 1 x D tensor and ``labels[position]``
            is the integer class index for that position.
        num_iterations (int): ``max_iter`` passed to ``torch.optim.LBFGS``; all
            iterations run inside a single ``optimizer.step`` call.
        lr (float): learning rate for L-BFGS.
        num_classes (int): number of output classes (defaults to 10).

    Returns:
        model (nn.Module): trained classifier; the layer is exposed as ``model.linear``.
        log_losses (list): loss recorded at every closure evaluation, in order.

    Raises:
        ValueError: if the dataset is empty, the stacked features/labels do not
            form an (N, D) / (N,) pair, or a label falls outside
            ``[0, num_classes)``.
    """
    if len(dataset) == 0:
        raise ValueError("dataset is empty; nothing to train on")

    # Stack the per-sample 1 x D rows for this position into an N x D matrix.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = torch.concat([sample[0][position] for sample in dataset], dim = 0).to(device)
    labels = torch.tensor([sample[1][position] for sample in dataset]).to(device)

    # Fail early with a readable message instead of continuing into an opaque shape error.
    if not (len(data) == len(labels) and data.dim() == 2 and labels.dim() == 1):
        raise ValueError(
            f"Look over format of dataset: data shape {tuple(data.shape)}, "
            f"labels shape {tuple(labels.shape)}"
        )
    if labels.min().item() < 0 or labels.max().item() >= num_classes:
        raise ValueError(f"labels must lie in [0, {num_classes}); got values outside that range")

    print("Data is formatted correctly")
    print(f"Data shape: {data.shape}")
    print(f"Labels shape: {labels.shape}")

    class LinearClassifier(nn.Module):
        """Single affine layer mapping input features to class logits."""

        def __init__(self, input_dim, num_classes):
            super().__init__()
            self.linear = nn.Linear(input_dim, num_classes)

        def forward(self, x):
            return self.linear(x)

    input_dim = data.shape[1]
    print(f"Input dim: {input_dim}, classes: {num_classes}")

    model = LinearClassifier(input_dim, num_classes).to(device)

    # Use cross entropy loss for classification
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.LBFGS(model.parameters(), lr = lr, max_iter = num_iterations)

    log_losses = []

    def closure():
        # L-BFGS re-evaluates the full-batch loss through this callback.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, labels)
        loss.backward()
        log_losses.append(loss.item())
        return loss

    optimizer.step(closure)

    print(f"Training complete for position {position}. Final loss: {log_losses[-1]:.4f}")
    return model, log_losses
    

In [8]:
def accuracy(position, dataset, num_samples = 1000, premodel = None, data = None):
    """
    Compute, print and return a position classifier's accuracy on a split.

    Args:
        position (int): grid position whose classifier and labels are evaluated.
        dataset (str): split name, e.g. "test" or "validation"; passed to
            ``prepare_data_no_batch`` and shown in the printed message.
        num_samples (int): number of samples to prepare; ignored when ``data``
            is supplied.
        premodel (nn.Module, optional): classifier to evaluate. When ``None``
            the classifier is looked up as ``models[position]``.
        data (list, optional): already-prepared ``(features, labels)`` pairs in
            the format returned by ``prepare_data_no_batch``. Supplying it
            avoids re-encoding the split on every call.

    Returns:
        float: fraction of samples whose predicted class matches the label.

    Raises:
        ValueError: if there are no samples to evaluate.
    """
    model = models[position] if premodel is None else premodel

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    if data is None:
        data = prepare_data_no_batch(dataset, num_samples)
    if len(data) == 0:
        raise ValueError("no samples to evaluate")

    # Stack the 1 x D rows for this position into an N x D matrix with matching labels.
    cleaned_data = torch.concat([sample[0][position] for sample in data], dim = 0).to(device)
    labels = torch.tensor([sample[1][position] for sample in data]).to(device)

    # Score the whole split in one forward pass rather than sample by sample.
    with torch.no_grad():
        predictions = model(cleaned_data).argmax(dim = 1)

    total = len(labels)
    incorrect = int((predictions != labels).sum().item())
    result = 1 - incorrect / total

    print(f"Accuracy for position {position} on {dataset}: {result:.4f}")
    return result
    


In [14]:
# Encode the first 10,000 training images once; every position's classifier reuses these features
# (the recorded run took about 123 seconds).
train_data = prepare_data_no_batch("train", 10000)

Finished preparing data in 122.77460218966007 seconds


In [15]:
# Fit one linear probe per grid position (0-8) on the shared training features.
# A loop replaces nine copy-pasted calls so the settings cannot drift apart.
NUM_POSITIONS = 9
models = {}
training_logs = {}
for position in range(NUM_POSITIONS):
    models[position], training_logs[position] = train_lbfgs_linear_classifier(
        position, train_data, num_iterations = 100, lr = 1.0
    )

# Keep the original per-position names so any cell that still refers to them keeps working.
linear0, linear1, linear2, linear3, linear4, linear5, linear6, linear7, linear8 = (
    models[p] for p in range(NUM_POSITIONS)
)
(linear0_log, linear1_log, linear2_log, linear3_log, linear4_log,
 linear5_log, linear6_log, linear7_log, linear8_log) = (
    training_logs[p] for p in range(NUM_POSITIONS)
)

Data is formatted correctly
Data shape: torch.Size([10000, 6912])
Labels shape: torch.Size([10000])
6912
10
Training complete for position 0. Final loss: 0.0000
Data is formatted correctly
Data shape: torch.Size([10000, 6912])
Labels shape: torch.Size([10000])
6912
10
Training complete for position 1. Final loss: 0.0000
Data is formatted correctly
Data shape: torch.Size([10000, 6912])
Labels shape: torch.Size([10000])
6912
10
Training complete for position 2. Final loss: 0.0000
Data is formatted correctly
Data shape: torch.Size([10000, 6912])
Labels shape: torch.Size([10000])
6912
10
Training complete for position 3. Final loss: 0.0000
Data is formatted correctly
Data shape: torch.Size([10000, 6912])
Labels shape: torch.Size([10000])
6912
10
Training complete for position 4. Final loss: 0.0000
Data is formatted correctly
Data shape: torch.Size([10000, 6912])
Labels shape: torch.Size([10000])
6912
10
Training complete for position 5. Final loss: 0.0000
Data is formatted correctly
Data s

In [17]:
# Report test-set accuracy for each of the nine position classifiers; with premodel=None
# accuracy() looks the classifier up in `models`.
# NOTE(review): every accuracy() call runs prepare_data_no_batch on the test split again, so the
# 1000 test images are re-encoded nine times (see the repeated "Finished preparing data" lines).
for i in range(9):
    accuracy(i, "test", 1000, premodel = None)

Finished preparing data in 12.15949971973896 seconds
Accuracy for position 0 on test: 1.0000
Finished preparing data in 12.347612246870995 seconds
Accuracy for position 1 on test: 1.0000
Finished preparing data in 11.875313021242619 seconds
Accuracy for position 2 on test: 1.0000
Finished preparing data in 12.323789902031422 seconds
Accuracy for position 3 on test: 1.0000
Finished preparing data in 11.902767654508352 seconds
Accuracy for position 4 on test: 1.0000
Finished preparing data in 12.121486186981201 seconds
Accuracy for position 5 on test: 1.0000
Finished preparing data in 12.358559433370829 seconds
Accuracy for position 6 on test: 1.0000
Finished preparing data in 12.432540699839592 seconds
Accuracy for position 7 on test: 1.0000
Finished preparing data in 11.95513591915369 seconds
Accuracy for position 8 on test: 1.0000
