In [None]:
import os
import sys
from plot_funcs import plot_task
from   matplotlib import colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# %matplotlib widget

# Add the parent of the notebook's working directory to sys.path
# (presumably the project root that holds the `src` package imported by the
# later cells).  The membership check keeps the cell idempotent: re-running
# it no longer appends a duplicate entry every time.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
from typing import List, Optional

import argparse
import torch

from src.model import Transformer
from src.token import VOCAB_SIZE, SpecialToken
from src.load_data import load_from_json, GridDataset
from src.utils.helper import set_deterministic
from src.utils.display_diff import compare_sequences, colorize, split_into_chunks
from src.utils.transformer_helper import create_mask

from checkpoint_plot import format_batch, plot_task, plot_answer

# When True, the device selection in the checkpoint-loading cell falls back to
# the CPU even if CUDA is available.
debug_on_cpu = False

def generate_sample(model, input_sequence, max_length, device, index_to_visualize: List[int], override_this_index: Optional[List[int]] = None, *, mask_hack):
    """Greedily extend ``input_sequence['task']`` one token at a time.

    Every sequence element is a 5-field row.  Generated rows are appended as
    ``[token_id, y, x, -1, -1]`` where ``(y, x)`` is the grid coordinate
    tracked by this function, or ``(-1, -1)`` for tokens that are neither a
    cell value nor a row separator.

    Side effects:
        * clears the module-level ``activation_store`` before every forward
          pass and prints progress for every generated token;
        * at generation step ``index_to_visualize[0]`` writes the captured
          activations, the current ``input_ids`` and the model outputs to
          ``../temp/<step>_activation.pt`` (the directory is created on
          demand).

    Args:
        model: callable as ``model(input_ids, mask)``; must expose ``eval()``
            and ``max_grid_size``.
        input_sequence: mapping whose ``'task'`` entry is the prompt rows.
        max_length: upper bound on the total sequence length.
        device: torch device for the tensors created here.
        index_to_visualize: generation steps of interest.  Only the first
            entry triggers an activation dump; an empty list disables it.
        override_this_index: must currently be ``None`` (enforced by the
            assertion below), so the override branch is effectively disabled.
        mask_hack: forwarded unchanged to ``create_mask``.

    Returns:
        The prompt plus generated rows as a list of 5-element lists.
    """
    assert override_this_index is None, override_this_index
    model.eval()

    # Tolerate an empty list instead of raising IndexError inside the loop.
    snapshot_step = index_to_visualize[0] if index_to_visualize else None

    y = 0
    x = 0
    coord = (-1, -1)
    with torch.no_grad():
        # (1, seq_length, 5): the torch.cat below appends (1, 1, 5) rows on dim 1.
        input_ids = torch.tensor(input_sequence['task'], dtype=torch.long).unsqueeze(0).to(device)

        seq_length = len(input_sequence['task'])
        print('seq_length', seq_length, max_length)
        for generated_token_index in range(max_length - seq_length):
            # Keep only the activations of the current forward pass.
            activation_store.clear()

            # NOTE: the mask is always built with the *prompt* length, not the
            # growing sequence length.
            mask = create_mask(input_ids, device, [seq_length], mask_hack=mask_hack)
            outputs = model(input_ids, mask)
            next_token_logits = outputs[0, -1, :]  # logits for the last position

            if generated_token_index == snapshot_step:
                # visualize_activation(model, input_ids, next_token_logits)
                snapshot_path = f'../temp/{generated_token_index}_activation.pt'
                # torch.save does not create missing parent directories.
                os.makedirs(os.path.dirname(snapshot_path), exist_ok=True)
                torch.save({
                        'activations': activation_store.activations,
                        'input_ids': input_ids,
                        'outputs': outputs
                    },
                    snapshot_path)

            next_token_id = torch.argmax(next_token_logits).item()

            print(f'next_token_id [{seq_length + generated_token_index}]: {next_token_id}')

            # Optional manual override of the greedy choice (unreachable while
            # the assertion at the top insists on override_this_index is None).
            if override_this_index is not None and generated_token_index in index_to_visualize:
                the_index = index_to_visualize.index(generated_token_index)
                next_token_id = override_this_index[the_index]
                print('override_this_index', generated_token_index, next_token_id)

            if next_token_id < SpecialToken.CELL_TOKEN_SIZE.value:
                # Token below CELL_TOKEN_SIZE: record the current position,
                # then advance one column (clamped to the model's grid size).
                coord = (y, x)
                x = x + 1
                x = min(x, model.max_grid_size - 1)
            elif next_token_id == SpecialToken.ROW_SEPARATOR.value:
                # Row separator: it takes the current position, then the
                # cursor moves to the start of the next row (clamped).
                coord = (y, x)
                x = 0
                y = y + 1
                y = min(y, model.max_grid_size - 1)
            else:
                # Any other token resets the cursor and carries no coordinate.
                y = 0
                x = 0
                coord = (-1, -1)

            new_row = torch.tensor([[[next_token_id, coord[0], coord[1], -1, -1]]], dtype=torch.long, device=device)
            input_ids = torch.cat([input_ids, new_row], dim=1)  # (1, seq_length + 1, 5)
            if next_token_id == SpecialToken.END.value:
                break

        return input_ids.squeeze(0).tolist()


In [None]:
from analyze_tensor import analyze_tensor


def show_interesting_attention(attention_weights, input_ids, seq_length, grid_index, top_k=5):
    """Print the sequence positions that receive the largest summed attention.

    Args:
        attention_weights: 2-D attention matrix of one head; it is summed over
            dim 0 and the result must have shape ``(seq_length,)``.
        input_ids: token tensor, indexed here as ``input_ids[0, position]``.
        seq_length: expected number of columns of ``attention_weights``.
        grid_index: 1-D tensor of positions used with ``torch.searchsorted``
            (so it is expected to be sorted ascending) to report how many of
            those positions lie at or before each reported column.
        top_k: number of columns to report.  It is clamped to ``seq_length``
            because ``torch.topk`` raises when asked for more elements than
            the tensor holds.

    Returns:
        None; the report is printed.
    """
    # Sum attention weights across rows to get column importance
    column_importance = attention_weights.sum(dim=0)
    assert column_importance.dim() == 1
    assert column_importance.shape == (seq_length,), f"got {column_importance.shape}, not ({seq_length})"

    # Never request more columns than exist (short sequences, large top_k).
    k = min(top_k, seq_length)

    # Get the indices of the top-k interesting columns
    top_k_values, top_k_indices = torch.topk(column_importance, k=k)
    assert top_k_values.shape == top_k_indices.shape == (k,), f"got {top_k_values.shape}"

    print('grid_index', grid_index)
    print(f"Top {k} interesting attention columns for the current layer:")

    for i, (index, value) in enumerate(zip(top_k_indices.tolist(), top_k_values.tolist()), 1):
        token = input_ids[0, index]
        # Number of grid_index entries <= index (right=True counts an exact match).
        grid_position = torch.searchsorted(grid_index, index, right=True).item()

        print(f"{i}. Index: {index}, grid:{grid_position}, Token: '{token}', Attention Sum: {value}")

def visualize_activation(model, input_ids, outputs):
    """Print tensor statistics for the activations captured in ``activation_store``.

    Reads the module-level ``activation_store`` and reports, through
    ``analyze_tensor`` and ``show_interesting_attention``:
    the three single-entry activations (transformer input, embedding output,
    initial dropout output), then for every layer each head's attention
    weights plus the layer's attention / norm1 / norm2 outputs, and finally
    ``outputs``.

    Side effect: switches torch's global print options to ``profile="full"``.

    NOTE(review): ``generate_sample`` reads ``activation_store.activations``
    while this function subscripts the store directly; confirm that
    ``ActivationStore`` supports ``__getitem__``.

    Args:
        model: must expose ``layers`` (its length is the expected layer count).
        input_ids: token tensor; ``input_ids[0, :, 0]`` holds the token ids.
        outputs: tensor analysed last under the label ``output``.
    """
    # Positions of every START_INPUT / START_OUTPUT token, i.e. grid starts.
    token_ids = input_ids[0, :, 0]
    grid_index = torch.where((token_ids == SpecialToken.START_INPUT.value) |
                             (token_ids == SpecialToken.START_OUTPUT.value))[0]

    print('grid_index', grid_index)

    torch.set_printoptions(profile="full")

    # Each of these is expected to have been captured exactly once.
    # ('transformer_input_mask' was disabled in the original and stays so.)
    for name in ('transformer_input_x', 'embedding_out', 'initial_dropout_out'):
        captured = activation_store[name]
        assert len(captured) == 1
        tensor = captured[0]
        analyze_tensor(tensor, f"{name} ({tensor.shape})")

    num_layers = len(model.layers)
    assert num_layers == len(activation_store['attention_weights'])
    assert num_layers == len(activation_store['attention_out'])

    for layer_index in range(num_layers):
        attention_weights = activation_store['attention_weights'][layer_index]

        _, head, seq_length, _ = attention_weights.shape

        for head_index in range(head):
            # Indexing batch 0 and one head already yields a 2-D matrix; no
            # squeeze, which would collapse a length-1 sequence to 0-D.
            attention_weights_slice = attention_weights[0, head_index, :, :]

            assert attention_weights_slice.dim() == 2, f"Got {attention_weights_slice.shape}"

            analyze_tensor(attention_weights_slice, f"[{layer_index}].attention_weights[{head_index}] ({attention_weights_slice.shape})")
            show_interesting_attention(attention_weights_slice, input_ids, seq_length, grid_index, 45)

        attention_out = activation_store['attention_out'][layer_index]
        analyze_tensor(attention_out, f"[{layer_index}].attention_out ({attention_out.shape})")
        norm1_out = activation_store['norm1_out'][layer_index]
        analyze_tensor(norm1_out, f"[{layer_index}].norm1_out ({norm1_out.shape})")
        norm2_out = activation_store['norm2_out'][layer_index]
        analyze_tensor(norm2_out, f"[{layer_index}].norm2_out ({norm2_out.shape})")

    analyze_tensor(outputs, f"output ({outputs.shape})")

In [None]:
from src.activation_hooks import ActivationStore, register_transformer_hooks
from src.checkpoint_handler import CheckpointHandler

set_deterministic()

# Load data
data_sources = ['arc-agi_evaluation'] # ['synth_conditional_logic_test'] 
all_challenges = {}

# NOTE(review): `solutions` keeps only the value returned for the last source
# that loaded, while `all_challenges` merges every source.  With more than one
# entry in data_sources the two would no longer line up.
for source in data_sources:
    try:
        challenges, solutions = load_from_json(source, '../input_data/')
        all_challenges.update(challenges)
    except FileNotFoundError as e:
        print(f"Error loading {source}: {e}. Skipping this data source.")

if not all_challenges:
    # Fail fast: continuing would only crash further down, either with a
    # NameError on `solutions` or on an empty dataset.
    raise RuntimeError("No data could be loaded. Please check the file paths and data sources.")

device = torch.device("cuda" if (torch.cuda.is_available() and not debug_on_cpu) else "cpu")
checkpoint_path='../cloud_runs/69.55.141.119/barc/runs/barc/20241107_224744_nogit_nobranch_lr4e-05_bl1e-06_ssu0_bs16_h4_es888_nl18_we10_as1_ph1_ac1_ad1_scosine_oadam_ge1_mh0_ssnone_ss1e-02_c6/Transformer_best_7738.pt'
model, max_seq_length, checkpoint_args = CheckpointHandler.load_checkpoint_in_production(checkpoint_path, device, adjust_max_length=12000)
mask_hack = checkpoint_args.get('mask_hack', True)

# Hook the model so every forward pass records its activations in
# `activation_store`; generate_sample reads and clears this global.
activation_store = ActivationStore()
register_transformer_hooks(model, activation_store)

model.activate_attention_weights_output(True)
model.eval()

# Create the dataset: `dataset_ref` is built with the solutions, `dataset`
# without them (None).
dataset_ref = GridDataset.load_from_paired_file(all_challenges, solutions, second_only=False)
dataset = GridDataset.load_from_paired_file(all_challenges, None, second_only=False)

print('dataset', len(dataset), type(dataset[0]) if len(dataset) > 0 else '?')
print('dataset_ref', len(dataset_ref), type(dataset_ref[0]))
print('max_seq_length', max_seq_length)

mismatch_count = 0    
tested_count = 0
# Generate samples
start_index = 109
for i in range(start_index, len(dataset)):
    print('index', i)
    input_sequence = dataset[i]

    if max_seq_length > len(input_sequence['task']):
        tested_count += 1

        # [1]: dump the activations at generation step 1
        # (written to ../temp/1_activation.pt by generate_sample).
        sample = generate_sample(model, input_sequence, max_seq_length, device, [1], mask_hack=mask_hack) # [2, 5, 7, 8], override_this_index=[8, 8, 0, 8])
        # plot_answer(input_sequence['task'], sample, i)

        expected_sequence = dataset_ref[i]
        # plot_answer(expected_sequence['task'], expected_sequence['task'], i)
    else:
        print('too long to evaluate')

    # Only the task at start_index is processed.
    break

In [None]:
# Reload the activation snapshot from disk onto the CPU, then pull out the two
# pieces the interactive viewer below reads as globals.
checkpoint = torch.load('../temp/1_activation.pt', map_location='cpu')
input_ids = checkpoint['input_ids']
attention_weights = checkpoint['activations']['attention_weights']

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
import copy
from matplotlib.backend_bases import MouseButton

from src.token import SpecialToken
from src.utils.helper import detokenize_grid

# Output widget shown above the figures; the click handler defined in
# plot_one is wrapped with `out.capture()` so its output lands here.
out = widgets.Output()
display(out)

# One colour per cell value 0..9.
cmap = colors.ListedColormap([
    '#000000',  # 0
    '#0074D9',  # 1
    '#FF4136',  # 2
    '#2ECC40',  # 3
    '#FFDC00',  # 4
    '#AAAAAA',  # 5
    '#F012BE',  # 6
    '#FF851B',  # 7
    '#7FDBFF',  # 8
    '#870C25',  # 9
])
# Fixed 0..9 range so a value always maps to the same palette entry.
norm = colors.Normalize(vmin=0, vmax=9)

def to_grid_index(i, train_or_test, input_or_output):
    """Map a pair number and grid kind to a 1-based grid id.

    Pair ``i`` owns two consecutive ids: ``2*i + 1`` for its input grid and
    ``2*i + 2`` for its output grid.  ``train_or_test`` is accepted but does
    not influence the result.
    """
    offset = 2 if input_or_output == 'output' else 1
    return i * 2 + offset

def find_cell_index(input_sequence, y, x, grid_index):
    """Return the position of the first element located at ``(y, x)`` in a grid.

    Each element is indexed as ``cell[1]`` (row), ``cell[2]`` (column) and
    ``cell[4]`` (grid id); the first element matching all three wins.

    Raises:
        ValueError: if no element matches.  (Previously a bare ``next()``
            leaked a message-less ``StopIteration`` from the click handler.)
    """
    for index, cell in enumerate(input_sequence):
        if cell[1] == y and cell[2] == x and cell[4] == grid_index:
            return index
    raise ValueError(f"no cell at (y={y}, x={x}) in grid {grid_index}")

def find_grid_width(input_sequence_segment):
    """Return the offset of the first row-separator token in the segment.

    An element counts as a separator when its first field equals
    ``SpecialToken.ROW_SEPARATOR.value``.  If the segment contains none, the
    ``StopIteration`` raised by ``next`` propagates to the caller.
    """
    separator_positions = (
        position
        for position, cell in enumerate(input_sequence_segment)
        if cell[0] == SpecialToken.ROW_SEPARATOR.value
    )
    return next(separator_positions)
    
def apply_attention_overlay(y, x, i, train_or_test, input_or_output, context, all_axes):
    """Overlay one row of the attention matrix on every plotted grid.

    The clicked cell ``(y, x)`` of grid ``(i, train_or_test, input_or_output)``
    is located in ``context['input_sequence']``; the matching row of
    ``context['attention_head']`` is then painted, grid by grid, over every
    image in ``all_axes`` that carries ``original_data`` (set by ``plot_one``).

    All overlays share one colour scale: values are divided by the maximum of
    the whole attention row and drawn with ``vmin=0, vmax=1``.  Without fixed
    limits ``imshow`` rescales each grid to its own min/max, which silently
    undoes the normalisation and makes grids incomparable.

    Args:
        y, x: row / column of the clicked cell.
        i, train_or_test, input_or_output: identify the clicked grid.
        context: dict with ``'input_sequence'`` and ``'attention_head'``.
        all_axes: axes whose images should receive an overlay.
    """
    source_grid_index = to_grid_index(i, train_or_test, input_or_output)
    cell_index = find_cell_index(context['input_sequence'], y, x, source_grid_index)

    attention_vector = context['attention_head'][cell_index, :]
    min_att = attention_vector.min().item()
    assert min_att >= 0
    max_att = attention_vector.max().item()
    # An all-zero row would otherwise produce 0/0 = NaN overlays.
    scale = max_att if max_att > 0 else 1.0

    for ax in all_axes:
        for img in ax.get_images():
            if not hasattr(img, 'original_data'):
                # Skip overlays left over from earlier clicks.
                continue

            # Separate names: the loop must not rebind this function's parameters.
            target_i, target_split, target_kind = img.original_data[:3]
            target_grid_index = to_grid_index(target_i, target_split, target_kind)
            target_grid_cell_start_index = find_cell_index(context['input_sequence'], 0, 0, target_grid_index)
            input_matrix = img.original_data[3]
            height = len(input_matrix)
            width = len(input_matrix[0])

            attention_seq = attention_vector[target_grid_cell_start_index:]

            # Reshape attention weights to match grid shape.  The stride is
            # width + 1: presumably one separator token follows every row.
            attention_grid = np.zeros((height, width))
            for h in range(height):
                for w in range(width):
                    attention_grid[h][w] = attention_seq[h * (width + 1) + w]

            # Normalize attention weights to [0,1]
            attention_grid = attention_grid / scale

            # Remove existing attention overlay if any
            for artist in ax.get_children():
                if getattr(artist, 'is_attention', False):
                    artist.remove()

            # Add new attention overlay on the shared [0, 1] scale
            attention_overlay = ax.imshow(attention_grid,
                                          cmap='Reds',
                                          alpha=0.999,
                                          interpolation='nearest',
                                          vmin=0.0,
                                          vmax=1.0)
            attention_overlay.is_attention = True

def plot_one(ax, task_data, i, train_or_test, input_or_output, context):
    """Draw one grid on ``ax`` and make it clickable.

    The grid ``task_data[train_or_test][i][input_or_output]`` is drawn with
    the module-level ``cmap`` / ``norm`` and every cell is annotated with its
    value.  A ``button_press_event`` handler is registered on the figure:
    a left click inside ``ax`` overlays the attention of the clicked cell on
    all grids (``apply_attention_overlay``); any other button inside ``ax``
    removes every overlay.

    If the requested grid does not exist (missing key or index), the axes are
    left untouched and nothing is registered.

    Args:
        ax: matplotlib axes to draw on.
        task_data: ``{'train': [...], 'test': [...]}`` with per-pair dicts.
        i: pair number inside ``task_data[train_or_test]``.
        train_or_test: ``'train'`` or ``'test'``.
        input_or_output: ``'input'`` or ``'output'``.
        context: passed through to ``apply_attention_overlay``.
    """
    try:
        input_matrix = task_data[train_or_test][i][input_or_output]
    except (KeyError, IndexError):
        # Only "this grid does not exist" is expected here; anything else
        # (including KeyboardInterrupt) should not be swallowed.
        return

    fig = ax.figure

    im = ax.imshow(input_matrix, cmap=cmap, norm=norm)
    ax.grid(True, which='both', color='lightgrey', linewidth=0.5)

    # Hide tick labels on every axes of this figure (the axes' own figure is
    # used rather than whichever figure pyplot considers current).
    plt.setp(fig.get_axes(), xticklabels=[], yticklabels=[])
    # Ticks sit on cell borders so the grid lines separate the cells.
    ax.set_xticks([x-0.5 for x in range(1 + len(input_matrix[0]))])
    ax.set_yticks([x-0.5 for x in range(1 + len(input_matrix))])

    # Calculate font size based on grid dimensions
    grid_size = max(len(input_matrix), len(input_matrix[0]))
    base_font_size = 10  # Base font size for small grids
    min_font_size = 4    # Minimum font size to ensure readability
    font_size = max(base_font_size - (grid_size - 5) * 0.5, min_font_size)

    # Add text annotations with adjusted font size
    for y in range(len(input_matrix)):
        for x in range(len(input_matrix[0])):
            value = input_matrix[y][x]
            text_color = 'white' if value > 5 or value == 0 else 'black'
            ax.text(x, y, str(value), ha='center', va='center', color=text_color, fontsize=font_size)

    ax.set_title(f'{train_or_test} {input_or_output}', color='black' if train_or_test == 'train' else 'red')
    # Tag the image so the overlay code can tell base grids from overlays.
    im.original_data = (i, train_or_test, input_or_output, input_matrix)

    # Click handler; its prints and tracebacks are captured by the `out` widget.
    @out.capture()
    def onclick(event):
        # The figure dispatches the event to every registered handler; only
        # the handler belonging to the clicked axes reacts.
        if event.inaxes == ax:
            if event.button is MouseButton.LEFT:
                x, y = int(round(event.xdata)), int(round(event.ydata))
                if 0 <= x < len(input_matrix[0]) and 0 <= y < len(input_matrix):
                    all_axes = fig.get_axes()
                    for img in ax.get_images():
                        if hasattr(img, 'original_data'):
                            apply_attention_overlay(y, x, *img.original_data[:3], context, all_axes)
                            break
                    fig.canvas.draw_idle()

            else:
                # Any other mouse button clears all attention overlays.
                for i_ax in fig.get_axes():
                    for artist in i_ax.get_children():
                        if getattr(artist, 'is_attention', False):
                            artist.remove()
                fig.canvas.draw_idle()

    fig.canvas.mpl_connect("button_press_event", onclick)

def plot_task(task, i, t, context):
    """Plot every train and test pair of ``task`` in one two-row figure.

    Inputs go on the top row and outputs on the bottom row; train pairs come
    first, followed by the test pairs.

    NOTE: this definition replaces the ``plot_task`` imported earlier in the
    notebook from ``plot_funcs`` / ``checkpoint_plot``.

    Args:
        task: ``{'train': [...], 'test': [...]}`` with per-pair dicts.
        i: set number shown in the figure title.
        t: free-form label shown in the figure title.
        context: passed through to ``plot_one``.
    """
    num_train = len(task['train'])
    num_test  = len(task['test'])

    w = num_train + num_test
    if w == 0:
        # plt.subplots rejects zero columns; there is nothing to draw anyway.
        print(f'Set #{i}, {t}: nothing to plot')
        return

    # squeeze=False keeps `axs` 2-D even for a single column, so axs[row, col]
    # below works for every w >= 1.
    fig, axs = plt.subplots(2, w, figsize=(3*w, 3*2), squeeze=False)
    fig.suptitle(f'Set #{i}, {t}:', fontsize=20, fontweight='bold', y=1)

    for j in range(num_train):
        plot_one(axs[0, j], task, j, 'train', 'input', context)
        plot_one(axs[1, j], task, j, 'train', 'output', context)
    for j in range(num_test):
        plot_one(axs[0, num_train + j], task, j, 'test', 'input', context)
        plot_one(axs[1, num_train + j], task, j, 'test', 'output', context)

    fig.patch.set_linewidth(5)
    fig.patch.set_edgecolor('black')
    fig.patch.set_facecolor('#dddddd')

    fig.tight_layout()

    plt.show()

def plot_interactive_activation(input_sequence, attention_head, attention_index):
    """Rebuild the grids encoded in ``input_sequence`` and plot them interactively.

    The sequence is scanned once.  Tokens between a START_INPUT / START_OUTPUT
    marker and the matching END marker are collected and turned into a grid
    with ``detokenize_grid``; inputs open a new train pair and outputs are
    attached to the most recent pair.  The result is handed to ``plot_task``
    together with a context holding the sequence and the attention matrix.

    Args:
        input_sequence: list of token rows; ``row[0]`` is the token id.
        attention_head: attention matrix of one head, stored in the context.
        attention_index: number shown in the figure title.
    """
    context = {
        'attention_head': attention_head,
        'input_sequence': input_sequence
    }
    task = {'train': [], 'test': []}

    section = None
    grid_tokens = []
    total_length = len(input_sequence)

    for position, element in enumerate(input_sequence):
        token = element[0]

        if token == SpecialToken.START_INPUT.value:
            section, grid_tokens = 'input', []
            continue

        if token == SpecialToken.END_INPUT.value:
            if section == 'input':
                task['train'].append({'input': detokenize_grid(grid_tokens)})
            continue

        if token == SpecialToken.START_OUTPUT.value:
            section, grid_tokens = 'output', []
            continue

        if token == SpecialToken.END_OUTPUT.value:
            if section == 'output':
                output_data = {'output': detokenize_grid(grid_tokens)}
                # NOTE(review): `position` comes from enumerating the same
                # sequence, so it is always < total_length and the else
                # branch below can never run.  It looks like the boundary was
                # meant to be the prompt length -- confirm.
                if position < total_length:
                    task['train'][-1].update(output_data)
                else:
                    print ('new answer', output_data)
                    answered_pair = task['train'].pop()
                    answered_pair.update(output_data)
                    task['test'].append(answered_pair)
            continue

        # Remaining tokens: keep those below 10 and row separators.
        if token < 10 or token == SpecialToken.ROW_SEPARATOR.value:
            grid_tokens.append(token)

    # Plot task
    plot_task(task, 0, f"attention index {attention_index}", context)

def main():
    """Open one interactive figure for every head of every attention tensor.

    Reads the globals ``attention_weights`` (a sequence of tensors indexed as
    ``[batch, head, ...]``, presumably one per layer) and ``input_ids`` that
    were loaded from the activation snapshot.
    """
    for attention_index, attention in enumerate(attention_weights):
        assert attention.shape[0] == 1 # batch is 1

        num_heads = attention.shape[1]
        for head_index in range(num_heads):
            # Distinct name for the slice so the loop counter is not rebound.
            head_weights = attention[0, head_index]
            assert input_ids.shape[1] == head_weights.shape[0]

            plot_interactive_activation(input_ids[0].tolist(), head_weights, attention_index)


main()