In [None]:
import os
import sys
from plot_funcs import plot_task
from   matplotlib import colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

%matplotlib widget

# Add the project root directory (parent of the notebook's working directory) to sys.path
# so the `src.*` imports in the next cell resolve. The membership check keeps re-runs of
# this cell from appending duplicate entries.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
from typing import List

import argparse
import torch
import torch.nn.functional as F

from src.model import Transformer
from src.token import SpecialToken, VOCAB_SIZE
from src.load_data import load_from_json, GridDataset
from src.utils.helper import set_deterministic
from src.utils.display_diff import compare_sequences, colorize, split_into_chunks
from checkpoint_plot import format_batch, plot_task, plot_answer
from src.utils.transformer_helper import create_mask

from analyze_tensor import analyze_tensor

# When True, force the CPU device even if CUDA is available (read when `device` is selected).
debug_on_cpu = True

def softmax(logits):
    """Turn ``logits`` into probabilities that sum to 1 over *all* elements.

    The global maximum is subtracted before exponentiating so large logits cannot
    overflow; the shift cancels in the final ratio. Intended for a 1-D logits vector.
    """
    peak = logits.max()
    unnormalised = (logits - peak).exp()
    total = unnormalised.sum()
    return unnormalised / total

def cross_entropy_loss(logits, true_class):
    """Cross-entropy of a single 1-D logits vector against an integer class index.

    Computed with ``log_softmax`` instead of ``log(softmax(...))``. When any class
    probability underflows to 0, the latter yields ``log(0) = -inf``. The one-hot
    product ``0 * -inf`` then turns the whole loss, and its gradient, into NaN.
    Indexing the log-probabilities also avoids creating a one-hot tensor on the CPU,
    which would clash with CUDA logits.

    Args:
        logits: 1-D tensor of unnormalised scores, one per class.
        true_class: index of the correct class (int or 0-d integer tensor).

    Returns:
        0-d tensor ``-log p(true_class)``, differentiable with respect to ``logits``.

    Raises:
        ValueError: if ``true_class`` is outside ``[0, num_classes)``.
    """
    num_classes = logits.size(0)
    class_index = int(true_class)
    # Explicit check: a negative index would otherwise silently select the last class.
    if not 0 <= class_index < num_classes:
        raise ValueError(f"true_class {class_index} out of range for {num_classes} classes")

    log_probabilities = F.log_softmax(logits, dim=0)
    return -log_probabilities[class_index]


def report_importance(model, show_rows_cols=False):
    """Print every parameter's mean absolute gradient, most important first.

    Call after ``loss.backward()``. Parameters whose ``.grad`` is ``None`` are skipped.

    Args:
        model: a ``torch.nn.Module`` whose parameters hold gradients.
        show_rows_cols: when True, also print the per-row and per-column mean absolute
            gradient of every 2-D parameter beneath its summary line.

    Returns:
        List of ``(parameter_name, importance)`` tuples sorted by importance, descending.

    Raises:
        AssertionError: if a parameter that has a gradient has more than two dimensions.
    """
    param_importance = {}
    row_importance = {}
    col_importance = {}

    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        # The row/column breakdown is only defined for tensors of rank <= 2.
        assert param.dim() <= 2, f"Parameter {name} has more than 2 dimensions"

        grad_abs = param.grad.abs()  # computed once, reused for every reduction below
        param_importance[name] = grad_abs.mean().item()

        if show_rows_cols and param.dim() == 2:
            row_importance[name] = grad_abs.mean(dim=1).tolist()
            col_importance[name] = grad_abs.mean(dim=0).tolist()

    sorted_importance = sorted(param_importance.items(), key=lambda item: item[1], reverse=True)

    for name, importance in sorted_importance:
        print(f"{name}: {importance}")
        if name in row_importance:
            for i, imp in enumerate(row_importance[name]):
                print(f"  Row {i}: {imp}")
            for i, imp in enumerate(col_importance[name]):
                print(f"  Column {i}: {imp}")

    return sorted_importance

    
def generate_sample(model, input_sequence, max_length, device, expected_sequence, index_to_visualize: List[int], *, mask_hack):
    """Greedily decode a continuation of ``input_sequence['task']`` with ``model``.

    Each step runs the model on the tokens so far and takes the argmax of the last
    position's logits. It appends that token as a 5-field row
    ``[token_id, y, x, -1, -1]``, where ``(y, x)`` tracks the grid cell being written.
    Decoding stops after ``max_length`` total tokens, or once ``SpecialToken.END`` has
    been appended.

    At every generated-token index listed in ``index_to_visualize``, the forward pass is
    run with gradients (in train mode). The cross-entropy against
    ``expected_sequence[seq_length + 1][0]`` is back-propagated, and
    ``report_importance`` prints the per-parameter gradient magnitudes.

    Args:
        model: transformer called as ``model(input_ids, mask)``; must expose ``max_grid_size``.
        input_sequence: mapping whose ``'task'`` entry is the prompt, presumably a list of
            5-field token rows with the token id in field 0 (it must concatenate with the
            generated 5-field rows) — TODO confirm against GridDataset.
        max_length: total token budget (prompt + generated).
        device: device the token tensor is placed on.
        expected_sequence: reference token rows; ``row[0]`` is the token id.
        index_to_visualize: 0-based generated-token indices at which to backprop.
        mask_hack: forwarded unchanged to ``create_mask``.

    Returns:
        The prompt plus generated tokens as a list of 5-element lists.

    Notes:
        * Reads the module-level ``activation_store``, which is created in a later cell;
          the call raises ``NameError`` if that cell has not been run.
        * The model's train/eval mode is restored on exit, so a gradient step no longer
          leaves the model in train mode.
    """
    was_training = model.training
    y = 0
    x = 0
    coord = (-1, -1)
    generated_token_index = 0
    input_ids = torch.tensor(input_sequence['task'], dtype=torch.long).unsqueeze(0).to(device)  # (1, seq_length, 5)

    try:
        for _ in range(max_length - len(input_sequence['task'])):
            seq_length = input_ids.shape[1]

            # NOTE(review): logits at the last position predict position `seq_length`;
            # the +1 looks like an off-by-one unless expected_sequence is shifted by one
            # relative to the prompt — confirm before relying on the reported loss.
            target_token_index = seq_length + 1

            requires_grad = generated_token_index in index_to_visualize

            activation_store.clear()
            mask = create_mask(input_ids, device, [seq_length], mask_hack)

            if requires_grad:
                model.train()
                model.zero_grad()  # drop gradients left over from an earlier visualized step
                outputs = model(input_ids, mask)  # (1, seq_length, vocab_size)
                the_index = index_to_visualize.index(generated_token_index)
                print('index_to_visualize[?] and model.training?', the_index, model.training)
            else:
                model.eval()
                with torch.no_grad():
                    outputs = model(input_ids, mask)  # (1, seq_length, vocab_size)

            next_token_logits = outputs[0, -1, :]  # (vocab_size)

            if requires_grad and target_token_index < len(expected_sequence):
                expected_token = expected_sequence[target_token_index][0]
                print('next_token_logits, expexted_token', next_token_logits, next_token_logits.requires_grad, expected_token, outputs.requires_grad, input_ids.requires_grad)
                loss = cross_entropy_loss(next_token_logits, expected_token)
                loss.backward()
                report_importance(model)

            next_token_id = torch.argmax(next_token_logits).item()

            if next_token_id < SpecialToken.CELL_TOKEN_SIZE.value:
                # Cell token: record the current cell, then advance one column (clamped).
                coord = (y, x)
                x = min(x + 1, model.max_grid_size - 1)
            elif next_token_id == SpecialToken.ROW_SEPARATOR.value:
                # Row separator: record the current cell, then move to the next row (clamped).
                coord = (y, x)
                x = 0
                y = min(y + 1, model.max_grid_size - 1)
            else:
                # Any other special token resets the grid cursor.
                y = 0
                x = 0
                coord = (-1, -1)

            new_row = torch.tensor([[[next_token_id, coord[0], coord[1], -1, -1]]], dtype=torch.long, device=device)  # (1, 1, 5)
            input_ids = torch.cat([input_ids, new_row], dim=1)  # (1, seq_length + 1, 5)
            if next_token_id == SpecialToken.END.value:
                break

            generated_token_index += 1
    finally:
        # The gradient branch switches to train mode; hand the model back as we found it.
        model.train(was_training)

    return input_ids.squeeze(0).tolist()  # seq_length rows of 5 ints


In [None]:
from src.activation_hooks import ActivationStore, register_transformer_hooks
from src.utils.transformer_helper import count_parameters
from src.checkpoint_handler import CheckpointHandler

set_deterministic()

# --- Load data ---------------------------------------------------------------
# Challenges AND solutions are merged across every data source, so adding a second
# source does not silently drop the first source's solutions.
data_sources = ['arc-agi_evaluation']
all_challenges = {}
all_solutions = {}

for source in data_sources:
    try:
        challenges, solutions = load_from_json(source, '../input_data/')
    except FileNotFoundError as e:
        print(f"Error loading {source}: {e}. Skipping this data source.")
        continue
    all_challenges.update(challenges)
    if solutions:  # NOTE(review): assumes solutions is a dict keyed like challenges — confirm
        all_solutions.update(solutions)

if not all_challenges:
    # Nothing below can run without data; stop the cell with a clear message.
    raise RuntimeError("No data could be loaded. Please check the file paths and data sources.")

# --- Model -------------------------------------------------------------------
device = torch.device("cuda" if (torch.cuda.is_available() and not debug_on_cpu) else "cpu")

# NOTE(review): machine-specific checkpoint path; consider moving it to a config cell.
checkpoint_path='../cloud_runs/69.55.141.119/barc/runs/barc/20241107_224744_nogit_nobranch_lr4e-05_bl1e-06_ssu0_bs16_h4_es888_nl18_we10_as1_ph1_ac1_ad1_scosine_oadam_ge1_mh0_ssnone_ss1e-02_c6/Transformer_best_7738.pt'

model, max_seq_length, checkpoint_args = CheckpointHandler.load_checkpoint_in_production(checkpoint_path, device, adjust_max_length=12000)
mask_hack = checkpoint_args.get('mask_hack', True)

count_parameters(model)

# Hooks record activations into this store; generate_sample reads it as a global.
activation_store = ActivationStore()
register_transformer_hooks(model, activation_store)
model.activate_attention_weights_output(True)

model.eval()

# --- Datasets ----------------------------------------------------------------
# dataset_ref carries the solutions (reference answers); dataset is prompt-only.
# `or None` keeps the original "no solutions" call shape when none were loaded.
dataset_ref = GridDataset.load_from_paired_file(all_challenges, all_solutions or None, second_only=True)
dataset = GridDataset.load_from_paired_file(all_challenges, None, second_only=True)

print('dataset', len(dataset), type(dataset[0]))
print('dataset_ref', len(dataset_ref), type(dataset_ref[0]))
print('max_seq_length', max_seq_length)

# --- Generate samples --------------------------------------------------------
mismatch_count = 0
tested_count = 0
start_index = 4
num_samples = 1  # number of consecutive dataset entries to evaluate from start_index

for i in range(start_index, min(len(dataset), start_index + num_samples)):
    print('dataset index', i)
    input_sequence = dataset[i]

    if max_seq_length > len(input_sequence['task']):
        tested_count += 1

        expected_sequence = dataset_ref[i]['task']

        activation_store.clear()
        sample = generate_sample(model, input_sequence, max_seq_length, device, expected_sequence, [2], mask_hack=mask_hack)
        plot_answer(input_sequence, sample, i)

        # NOTE(review): the first argument is a token list here but the `input_sequence`
        # mapping in the call above — confirm plot_answer's expected signature.
        plot_answer(expected_sequence, expected_sequence, i)
    else:
        print('too long to evaluate')