In [7]:
from transformers import GPTJForCausalLM, AutoTokenizer
import torch

# Checkpoint directory from which both the tokenizer and the model are loaded.
model_path = "/home/sjoshi/lmm/lm-train/checkpoints/v3_spatial_grid_gptj/checkpoint-1953"  # Replace with the path where your model is saved

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load the weights and switch straight to evaluation mode
# (`.eval()` returns the module itself, so `model` is the loaded model).
model = GPTJForCausalLM.from_pretrained(model_path).eval()

# Define a function to generate text
def generate_text(prompt, max_new_tokens=50):
    """Generate text from ``prompt`` with the module-level ``model``/``tokenizer``.

    Args:
        prompt: Text the model is conditioned on.
        max_new_tokens: Maximum number of tokens to generate after the prompt.

    Returns:
        The decoded first returned sequence with special tokens removed.
        For a causal LM this sequence starts with the prompt tokens.
    """
    # Keep the whole encoding (input ids *and* attention mask) and place it on
    # the same device as the model so generation also works on GPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate text using the model
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            # Passed explicitly: pad_token_id equals eos_token_id below, so the
            # mask cannot be reliably inferred from the input ids.
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=tokenizer.eos_token_id,
            # `early_stopping` was dropped: it only applies to beam search and
            # no beams are requested here.
        )

    # Decode the generated text
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [3]:
from datasets import load_from_disk

# On-disk location of the saved dataset.
# NOTE(review): the checkpoint loaded in this notebook lives under
# "v3_spatial_grid_gptj" while this dataset is "v2_spatial_grid" -- confirm
# that evaluating across versions is intentional.
dataset_path = "/home/sjoshi/lmm/data/generated/v2_spatial_grid"

# Read the saved dataset back from disk.
dataset = load_from_disk(dataset_path)

# Print a summary of what was loaded.
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1000000
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['text'],
        num_rows: 1000
    })
})


In [8]:
import re
from tqdm import tqdm 

def parse_answer(text):
    """Extract the answer that follows the first ``A:`` marker in ``text``.

    Whitespace directly after the marker (including newlines) is skipped; the
    answer is the rest of the line on which the answer text starts.

    Args:
        text: Text to search for an ``A:`` marker.

    Returns:
        The answer with surrounding whitespace removed, or ``None`` when
        ``text`` contains no ``A:`` marker.
    """
    match = re.search(r'A:\s*(.*)', text)
    if match is None:
        return None
    # `.` never matches '\n', so the capture is already a single line.
    # Strip trailing whitespace (spaces, '\r') so the answer compares cleanly
    # against the stripped cells produced by parse_grid.
    return match.group(1).strip()

def parse_grid(grid_str, n_rows=3):
    """Parse the leading lines of ``grid_str`` as a ``|``-delimited grid.

    Args:
        grid_str: Text whose first ``n_rows`` lines hold the grid; anything
            after those lines is ignored.
        n_rows: Number of lines to read as grid rows (defaults to 3).

    Returns:
        A list of rows, each a list of stripped, non-empty cell strings.
        Empty cells (e.g. those produced by leading/trailing ``|``) are
        dropped, so a row can hold fewer entries than it has delimiters.
    """
    # Strip before slicing so leading blank lines do not use up the row budget.
    head = grid_str.strip().split('\n')[:n_rows]
    # Strip again after truncation so trailing blank lines inside the window
    # are not returned as rows.
    rows = '\n'.join(head).strip().split('\n')
    # Remove any empty strings resulting from splitting and strip each element
    return [
        [cell.strip() for cell in row.split('|') if cell.strip()]
        for row in rows
    ]

# Running tallies for the per-cell question-answering evaluation.
total = 0
correct = 0
accuracy = 0.0
wrong_per_pos = {}  # (row, column) -> number of wrong answers at that cell

pbar = tqdm(dataset['validation'])
for example in pbar:
    # Keep the text up to the first ']' and close it off again with '].'.
    prompt = example["text"].split(']')[0] + '].'
    grid = parse_grid(prompt)

    # Ask one question for every (row, column) pair of a 3x3 grid.
    for i, j in ((r, c) for r in range(3) for c in range(3)):
        total += 1
        current_prompt = prompt + f"\nWhat object is in row {i}, column {j}?"
        completion = generate_text(current_prompt, max_new_tokens=5)
        parsed_answer = parse_answer(completion)

        if parsed_answer == grid[i][j]:
            correct += 1
        else:
            wrong_per_pos[(i, j)] = wrong_per_pos.get((i, j), 0) + 1

        # Refresh the running accuracy shown on the progress bar.
        accuracy = correct / total
        pbar.set_description(f'Accuracy: {accuracy:.3f}')

Accuracy: 0.63:  10%|█         | 104/1000 [04:51<41:52,  2.80s/it]


KeyboardInterrupt: 