In [None]:
import os
import sys

# Make the project root (parent of the notebook's working directory) importable
# so the `src.*` imports in the next cell resolve. The membership check keeps
# this cell idempotent: re-running it does not pile up duplicate sys.path entries.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from src.model import Transformer
from src.token import VOCAB_SIZE

# Create an instance of the Transformer model to inspect.
# Seed first so any randomly initialized weights (and hence the forward-pass
# tensors plotted further down) are identical across Restart & Run All.
SEED = 0
embed_size = 240
num_layers = 10
heads = 6
use_grid_encoder = True
max_length = 2048  # number of positions in the positional-encoding plots below

torch.manual_seed(SEED)
# NOTE(review): jupyter_debug=True presumably makes the model keep intermediates
# such as `initial_tensor` (read in a later cell) — confirm in src.model.
model = Transformer(VOCAB_SIZE, embed_size, num_layers, heads, use_grid_encoder, max_length, jupyter_debug=True)

def visualize_encoding(encoding, title):
    """Show an encoding tensor as a heatmap, positions on x, embedding dims on y.

    Args:
        encoding: Tensor that is 2-D after dropping size-1 dimensions, with
            positions along dim 0 and embedding dimensions along dim 1 (as the
            axis labels imply; TODO confirm for each caller).
        title: Figure title.
    """
    # detach() lets tensors that require grad (e.g. learned encodings) be
    # converted to NumPy; for tensors that don't, the values are unchanged.
    matrix = encoding.detach().squeeze().transpose(0, 1).cpu().numpy()
    _, ax = plt.subplots(figsize=(12, 8))
    sns.heatmap(matrix, cmap='viridis', ax=ax)
    ax.set_title(title)
    ax.set_ylabel('Embedding Dimension')
    ax.set_xlabel('Position')
    plt.show()

def plot_encoding_values(encoding, title, step=1):
    """Line-plot how selected embedding dimensions vary across positions.

    One line is drawn per selected dimension, starting at the last dimension and
    stepping backwards by ``step`` (so the final dimension is always shown).

    Args:
        encoding: Tensor that is 2-D after dropping size-1 dimensions, with
            positions along dim 0 and embedding dimensions along dim 1 (as the
            axis labels imply; TODO confirm for each caller).
        title: Figure title.
        step: Plot every ``step``-th dimension; must be >= 1.

    Raises:
        ValueError: if ``step`` < 1 (a negative step would otherwise draw nothing).
    """
    if step < 1:
        raise ValueError(f'step must be >= 1, got {step}')
    # detach() lets tensors that require grad be converted to NumPy.
    values = encoding.detach().squeeze().cpu().numpy()
    fig, ax = plt.subplots(figsize=(12, 6))
    for dim in range(values.shape[1] - 1, -1, -step):
        ax.plot(values[:, dim], label=f'Dim {dim}')
    ax.set_title(title)
    ax.set_xlabel('Position')
    ax.set_ylabel('Encoding Value')
    # Anchor the legend to the right of the axes so it does not cover the lines.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()
    plt.show()

def compare_encodings(encoding, title, step=1):
    """Overlay the encoding vectors of selected positions for comparison.

    Each selected position's vector is drawn as a curve with the encoding value
    on the x-axis and the embedding-dimension index on the y-axis. Positions are
    taken from the last one backwards in strides of ``step``.

    Args:
        encoding: Tensor that is 2-D after dropping size-1 dimensions, with
            positions along dim 0 and embedding dimensions along dim 1 (as the
            axis labels imply; TODO confirm for each caller).
        title: Base title; " - Comparison" is appended.
        step: Plot every ``step``-th position; must be >= 1.

    Raises:
        ValueError: if ``step`` < 1 (a negative step would otherwise draw nothing).
    """
    if step < 1:
        raise ValueError(f'step must be >= 1, got {step}')
    # detach() lets tensors that require grad be converted to NumPy.
    values = encoding.detach().squeeze().cpu().numpy()
    dims = range(values.shape[1])  # y-coordinates shared by every curve
    _, ax = plt.subplots(figsize=(12, 6))
    for pos in range(values.shape[0] - 1, -1, -step):
        ax.plot(values[pos], dims, label=f'Position {pos}')
    ax.set_title(f'{title} - Comparison')
    ax.set_ylabel('Embedding Dimension')
    ax.set_xlabel('Encoding Value')
    ax.legend()
    plt.show()

# Report the shape of each encoding the model exposes.
for attr_name in ('positional_encoding', 'grid_encoding'):
    print(f'model.{attr_name}', getattr(model, attr_name).shape)

# One row per encoding: (title, model attribute, stride over dimensions for the
# value plot, stride over positions for the comparison plot).
encoding_plot_specs = (
    ('Positional Encoding', 'positional_encoding', 50, 1100),
    ('Grid Encoding', 'grid_encoding', 13, 13),
)
for enc_title, attr_name, value_step, compare_step in encoding_plot_specs:
    visualize_encoding(getattr(model, attr_name), enc_title)
    plot_encoding_values(getattr(model, attr_name), f'{enc_title} Values', step=value_step)
    compare_encodings(getattr(model, attr_name), enc_title, step=compare_step)

In [None]:
def visualize_encoding(t, title):
    """Render a 2-D tensor as a viridis heatmap with positions along the x-axis.

    This redefinition does not squeeze its input: callers must pass a 2-D tensor
    whose dim 0 is position and dim 1 is embedding dimension (per the labels).
    The tensor is detached, so inputs that require grad are accepted.
    """
    plt.figure(figsize=(12, 8))
    heat_values = t.cpu().detach().transpose(0, 1).numpy()
    sns.heatmap(heat_values, cmap='viridis')
    plt.title(title)
    plt.xlabel('Position')
    plt.ylabel('Embedding Dimension')
    plt.show()


# Create a simple input where we can predict the effect of grid encoding.
# NOTE(review): each row holds 4 integer fields whose meaning (token id vs. grid
# coordinates, etc.) is defined by src.model — confirm there.
x = torch.tensor([[
    [0, 0, 10, 1],  # row 0
    [0, 0, 10, 1],
    [2, 0, 11, 0],
    [2, 0, 11, 0],
    [2, 5, 10, 0],
    [1, 5, 10, 2],
    [1, 5, 1, 2],
    [0, 5, 1, 3],
    [0, 12, 2, 3],
    [0, 12, 2, 4],
    [0, 12, 2, 5],
    [0, 12, 2, 6],
    [0, 12, 2, 7],
]])
# Derive the mask size from the input so editing the rows above cannot leave
# the mask with a stale, mismatched length.
seq_len = x.shape[1]

# (seq_len, seq_len) boolean matrix, True strictly above the diagonal (future
# positions). NOTE(review): presumably True means "masked out" — confirm in src.model.
mask = torch.ones(seq_len, seq_len).triu(1).bool()

with torch.no_grad():
    # Overwrite the grid scale in place: fill_ keeps the parameter object, its
    # dtype and device, unlike rebinding `.data` to a freshly built tensor.
    # NOTE: this persistently modifies the model for all later cells.
    model.grid_scale.fill_(25.0)
    out = model(x, mask)

print('model.initial_tensor', model.initial_tensor.shape)
visualize_encoding(model.initial_tensor.squeeze(0), 'xy')