# Interpreting Identity

# Setup
(No need to read)

In [2]:
TRAIN_MODEL = True

In [3]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
import os

DEVELOPMENT_MODE = True
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
# Upgrade pip
%pip install --upgrade pip
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")
    ipython.run_line_magic("pip", "install ipympl")
    ipython.run_line_magic("pip", "install scipy")
    ipython.run_line_magic("pip", "install manim")
    ipython.run_line_magic("pip", "install torch")
    ipython.run_line_magic("pip", "install numpy<2")
    ipython.run_line_magic("pip", "install einops")
    ipython.run_line_magic("pip", "install transformer_lens")

if IN_COLAB or IN_GITHUB:
    %pip install transformer_lens
    %pip install circuitsvis
    %pip install ipympl

Note: you may need to restart the kernel to use updated packages.
Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
/bin/bash: Zeile 1: 2: Datei oder Verzeichnis nicht gefunden
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [4]:
# Import stuff
import torch
import numpy as np
import einops
import os
import tqdm.auto as tqdm
from pathlib import Path

import copy

from transformer_lens import HookedTransformer, HookedTransformerConfig, ActivationCache

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
# Define the location to save the model, using a relative path
PTH_LOCATION = "workspace/_scratch/identity.pth"

# Create the directory if it does not exist
os.makedirs(Path(PTH_LOCATION).parent, exist_ok=True)

cpu


  return torch._C._cuda_getDeviceCount() > 0


# Model Training

## Config

In [5]:
INPUT_DIM = 3
OUTPUT_DIM = 3
frac_train = 1

# Optimizer config
lr = 1e-3
wd = 1e-2
betas = (0.9, 0.999)

num_epochs = 10000
checkpoint_every = 500

DATA_SEED = 599

## Define Task
* Define random function
* Define the dataset & labels

Input format:
|a|b|

Convert this to a train + test set - 30% in the training set

In [29]:
def get_training_data(input_dim = INPUT_DIM, output_dim = OUTPUT_DIM, data_seed = DATA_SEED):
    # TODO: Define new data set as random two-parameter function

    torch.manual_seed(data_seed)
    a_vector = torch.arange(input_dim)
    dataset = torch.cartesian_prod(a_vector, a_vector).to(device)

    labels = torch.randint(0, output_dim, (dataset.shape[0],), device=device)
    train_data = dataset
    train_labels = labels
    # For now no test data
    test_data = dataset[0:0]
    test_labels = labels[0:0]
    print(train_data)
    print(train_labels)
    print(train_data.shape)
    print(test_data[:5])
    print(test_labels[:5])
    print(test_data.shape)
    return train_data, train_labels, test_data, test_labels

train_data, train_labels, test_data, test_labels = get_training_data()

tensor([[0, 0],
        [0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 2],
        [2, 0],
        [2, 1],
        [2, 2]])
tensor([0, 1, 2, 1, 1, 0, 2, 2, 0])
torch.Size([9, 2])
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], dtype=torch.int64)
torch.Size([0, 2])


## Define Model

In [7]:

def get_seeded_model(seed = 999, input_dim = INPUT_DIM, output_dim = OUTPUT_DIM):
    cfg = HookedTransformerConfig(
        n_layers = 1,
        n_heads = 1,
        d_model = 2,
        d_head = 2,
        d_mlp = None,
        act_fn = "relu",
        normalization_type=None,
        d_vocab=input_dim,
        d_vocab_out=output_dim,
        n_ctx=2,
        init_weights=True,
        device=device,
        seed = seed,
        attn_only=True
    )
    model = HookedTransformer(cfg)
    # Biases are enabled by default
    # for name, param in model.named_parameters():
    #     if "b_" in name:
    #         param.requires_grad = False
    return model

model = get_seeded_model(seed = 993)

Disable the biases, as we don't need them for this task and it makes things easier to interpret.

## Define Optimizer + Loss

In [8]:
print(train_data)
print(model(train_data))

tensor([[0, 0],
        [0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 2],
        [2, 0],
        [2, 1],
        [2, 2]])
tensor([[[-0.0232, -0.3831,  0.0734],
         [ 0.3973,  0.0077,  0.2572]],

        [[-0.0232, -0.3831,  0.0734],
         [ 0.1674, -0.1676,  0.1479]],

        [[-0.0232, -0.3831,  0.0734],
         [ 0.1926,  0.8592, -0.0728]],

        [[-0.3110, -0.5506, -0.0756],
         [ 0.3419,  0.0151,  0.2194]],

        [[-0.3110, -0.5506, -0.0756],
         [ 0.1117, -0.1603,  0.1098]],

        [[-0.3110, -0.5506, -0.0756],
         [ 0.1047,  0.8789, -0.1347]],

        [[-0.0824,  0.2363, -0.1083],
         [ 0.4788, -0.1707,  0.3516]],

        [[-0.0824,  0.2363, -0.1083],
         [ 0.2537, -0.3527,  0.2469]],

        [[-0.0824,  0.2363, -0.1083],
         [ 0.3217,  0.6311,  0.0640]]], grad_fn=<ViewBackward0>)


In [9]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)
def loss_fn(logits, labels):
    if len(logits.shape)==3:
        logits = logits[:, -1]
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()
train_logits = model(train_data)
train_loss = loss_fn(train_logits, train_labels)
print(train_loss)
test_logits = model(test_data)
test_loss = loss_fn(test_logits, test_labels)
print(test_loss)
print("Uniform loss:")
print(np.log(INPUT_DIM * INPUT_DIM))

tensor(1.2079, dtype=torch.float64, grad_fn=<NegBackward0>)
tensor(nan, dtype=torch.float64, grad_fn=<NegBackward0>)
Uniform loss:
2.1972245773362196


## Actually Train

**Weird Decision:** Training the model with full batch training rather than stochastic gradient descent. We do this so to make training smoother and reduce the number of slingshots.

In [10]:
def print_stats(model):
    # Extract the p 2-dimensional tensors, vector i is vec[:, i]
    vec = model.W_U.data

    # Function to compute the angle between two vectors
    def compute_angle(v1, v2):
        cos_theta = torch.dot(v1, v2) / (torch.norm(v1) * torch.norm(v2))
        angle = torch.acos(cos_theta) * (180.0 / np.pi)
        return angle

    # Compute pairwise angles
    # for i in range(vec.shape[1]):
    #     for j in range(i+1, vec.shape[1]):
    #         angle = compute_angle(vec[:, i], vec[:, j])
    #         print(f"Angle between {i} and {j}: {angle.item():.2f}°")
    #     print(f"Norm of vector {i}: {torch.norm(vec[:, i]):.2f}")

print_stats(model)

In [11]:
def train_model(model, train_data, train_labels, test_data, test_labels, num_epochs = num_epochs, loss_target = None):
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, weight_decay=wd, betas=betas
    )
    train_losses = []
    test_losses = []
    model_checkpoints = []
    checkpoint_epochs = []
    if TRAIN_MODEL:
        for epoch in tqdm.tqdm(range(num_epochs)):
            train_logits = model(train_data)
            train_loss = loss_fn(train_logits, train_labels)
            train_loss.backward()
            train_losses.append(train_loss.item())
            if loss_target is not None and train_loss.item() < loss_target:
                print(f"Loss target {loss_target} reached with loss {train_loss.item()} at epoch {epoch}")
                break

            optimizer.step()
            optimizer.zero_grad()

            with torch.inference_mode():
                test_logits = model(test_data)
                test_loss = loss_fn(test_logits, test_labels)
                test_losses.append(test_loss.item())

            if ((epoch+1)%checkpoint_every)==0:
                checkpoint_epochs.append(epoch)
                model_checkpoints.append(copy.deepcopy(model.state_dict()))
                print_stats(model)
                print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")
    if TRAIN_MODEL:
        torch.save(
            {
                "model":model.state_dict(),
                "config": model.cfg,
                "checkpoints": model_checkpoints,
                "checkpoint_epochs": checkpoint_epochs,
                "test_losses": test_losses,
                "train_losses": train_losses,
            },
            PTH_LOCATION)

train_model(
    model, train_data, train_labels, test_data, test_labels
)

  0%|          | 0/10000 [00:00<?, ?it/s]

Epoch 499 Train Loss 0.38312511649032743 Test Loss nan
Epoch 999 Train Loss 0.059466672574566455 Test Loss nan
Epoch 1499 Train Loss 0.017834876413487258 Test Loss nan
Epoch 1999 Train Loss 0.008198291490770912 Test Loss nan
Epoch 2499 Train Loss 0.0045474221205888754 Test Loss nan
Epoch 2999 Train Loss 0.002792479299131615 Test Loss nan
Epoch 3499 Train Loss 0.0018253827444404303 Test Loss nan
Epoch 3999 Train Loss 0.001243454240354598 Test Loss nan
Epoch 4499 Train Loss 0.0008715261628919556 Test Loss nan
Epoch 4999 Train Loss 0.000623363184824029 Test Loss nan
Epoch 5499 Train Loss 0.00045251579865528753 Test Loss nan
Epoch 5999 Train Loss 0.0003321229238846405 Test Loss nan
Epoch 6499 Train Loss 0.0002457877433733275 Test Loss nan
Epoch 6999 Train Loss 0.00018305811923903094 Test Loss nan
Epoch 7499 Train Loss 0.00013701582525633196 Test Loss nan
Epoch 7999 Train Loss 0.00010295625678972151 Test Loss nan
Epoch 8499 Train Loss 7.760370500183763e-05 Test Loss nan
Epoch 8999 Train Los

In [12]:
if not TRAIN_MODEL:
    cached_data = torch.load(PTH_LOCATION)
    model.load_state_dict(cached_data['model'])
    model_checkpoints = cached_data["checkpoints"]
    checkpoint_epochs = cached_data["checkpoint_epochs"]
    test_losses = cached_data['test_losses']
    train_losses = cached_data['train_losses']

# Activations

In [13]:
def create_cache(model):
    input = train_data
    print("Transposed Input:\n", input.transpose(0, 1))
    logits, cache = model.run_with_cache(input)
    print("Labels: ", train_labels)
    print("Logits of last token:\n", logits[:, -1, :])
    print("Unembed:\n", model.W_U.data)
    print("Last layer before unembed:\n", cache.cache_dict["blocks.0.hook_resid_post"][:, -1, :])
    return cache

cache = create_cache(model)

Transposed Input:
 tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
        [0, 1, 2, 0, 1, 2, 0, 1, 2]])
Labels:  tensor([0, 1, 2, 1, 1, 0, 2, 2, 0])
Logits of last token:
 tensor([[ 10.4829,  -0.3025,  -2.2474],
        [  6.5879,  17.4972, -25.4383],
        [ 13.5779, -21.1048,  24.9873],
        [ 11.3726,  29.8343, -42.1891],
        [  6.9325,  32.8404, -45.7635],
        [ 15.0433,  -1.8849,  -0.5669],
        [  7.1881, -14.3199,  16.5910],
        [  2.7490, -11.2691,  12.9574],
        [ 13.9139,  -5.9930,   4.9690]], grad_fn=<SliceBackward0>)
Unembed:
 tensor([[ 1.8803,  1.9334, -2.7273],
        [-1.8391,  0.6233, -0.6581]])
Last layer before unembed:
 tensor([[  1.4772,  -3.7749],
        [  7.8878,   4.8970],
        [ -6.2070, -13.3137],
        [ 13.3178,   7.8469],
        [ 13.9018,  10.8582],
        [  1.4628,  -6.2691],
        [ -4.4100,  -8.0023],
        [ -3.8085,  -4.9736],
        [ -0.2841,  -7.4411]])


In [14]:
-1.3451 * -3.4785 + 0.7905 * 5.5609

9.074821799999999

In [15]:
# Get a list of all activations stored in the cache, especially their names
cache.cache_dict

{'hook_embed': tensor([[[-1.9119, -1.4572],
          [-1.9119, -1.4572]],
 
         [[-1.9119, -1.4572],
          [-1.3104,  1.5715]],
 
         [[-1.9119, -1.4572],
          [ 2.2279, -0.8822]],
 
         [[-1.3104,  1.5715],
          [-1.9119, -1.4572]],
 
         [[-1.3104,  1.5715],
          [-1.3104,  1.5715]],
 
         [[-1.3104,  1.5715],
          [ 2.2279, -0.8822]],
 
         [[ 2.2279, -0.8822],
          [-1.9119, -1.4572]],
 
         [[ 2.2279, -0.8822],
          [-1.3104,  1.5715]],
 
         [[ 2.2279, -0.8822],
          [ 2.2279, -0.8822]]]),
 'hook_pos_embed': tensor([[[-0.3115,  0.2175],
          [-0.7833, -1.3808]],
 
         [[-0.3115,  0.2175],
          [-0.7833, -1.3808]],
 
         [[-0.3115,  0.2175],
          [-0.7833, -1.3808]],
 
         [[-0.3115,  0.2175],
          [-0.7833, -1.3808]],
 
         [[-0.3115,  0.2175],
          [-0.7833, -1.3808]],
 
         [[-0.3115,  0.2175],
          [-0.7833, -1.3808]],
 
         [[-0.3115,  0.

# Animations

In [16]:
from manim import *

config.media_width = "80%"
config.verbosity = "WARNING"
config.preview = False

In [25]:
class VectorParams:
    def __init__(self, values = [], color = WHITE, label = ""):
        self.values = values
        self.color = color
        self.label = label
    def __repr__(self) -> str:
        return str(self.values) + "(" + str(self.color) + ")"

class Data:
    def __init__(self):
        self.vectors: list[list[VectorParams]] = [[]]
        self.steps = 0
        self.current_labels = set()

    def add_vector(self, vector, color = WHITE, label = ""):
        if label not in self.current_labels:
            self.current_labels.add(label)
            self.vectors[self.steps].append(VectorParams(values = vector, color = color, label = label))

    def next_step(self):
        self.steps += 1
        self.vectors.append([])
        self.current_labels = set()

    def add_vectors_at_hook(self, c: ActivationCache, hook: str, color0 = WHITE, color1 = WHITE, input_labels = None, input_colors = None):
        if input_labels is None:
            input_labels = ["" for i in range(c.cache_dict[hook].shape[0])]
        for i in range(c.cache_dict[hook].shape[0]):
            self.add_vector(c.cache_dict[hook][i][0].cpu(), color = color0, label = input_labels[i][:1])
            self.add_vector(
                c.cache_dict[hook][i][1].cpu(), color=color1 if input_colors is None else input_colors[i], label=input_labels[i][:2]
            )


def compile_data_vectors(cache, input_labels=None, input_colors=None):
    # Set default value as list of empty strings
    vectors = Data()
    vectors.add_vectors_at_hook(cache, "hook_embed", color1 = GRAY, input_labels=input_labels, input_colors=input_colors)
    vectors.next_step()
    vectors.add_vectors_at_hook(cache, "blocks.0.hook_resid_pre", input_labels=input_labels, input_colors=input_colors)
    vectors.next_step()
    vectors.add_vectors_at_hook(cache, "blocks.0.hook_resid_post", color0 = GRAY, input_labels=input_labels, input_colors=input_colors)
    # vectors.add_vectors_at_hook(cache, "blocks.0.hook_resid_mid")

    print(vectors.vectors)
    return vectors


vectors = compile_data_vectors(cache)

[[tensor([ 0.8873, -1.3817])(#FFFFFF)], [tensor([ 2.7090, -0.1558])(#FFFFFF)], [tensor([6.0623, 5.5371])(#888888)]]


In [18]:
def change_font_size(labeled_arrow: LabeledArrow, new_size):
    # print(labeled_arrow, labeled_arrow.submobjects)
    # print(labeled_arrow.submobjects[-1].font_size)
    if not isinstance(labeled_arrow, LabeledArrow):
        return
    label = labeled_arrow.submobjects[-1]
    box = labeled_arrow.submobjects[-2]
    if not isinstance(box, BackgroundRectangle):
        box = labeled_arrow.submobjects[-3]
    coords = label.get_center()
    # print(new_size)
    labeled_arrow.submobjects[-1] = MathTex(
        label.get_tex_string(), color=label.color, font_size=new_size
    )
    # print("size=", labeled_arrow.submobjects[-1].font_size)
    label = labeled_arrow.submobjects[-1]
    label.move_to(coords)
    box.width = label.width + 2 * box.buff
    box.height = label.height + 2 * box.buff

In [49]:
DOT_SCALE = 0.05
class VisualizeTransformer(MovingCameraScene):
    def construct(self):
        print("v=", vectors.vectors)
        axes = Axes(
            x_range = [-20, 20, 1],
            y_range = [-20, 20, 1],
            x_axis_config={
                "numbers_to_include": np.arange(-18, 18.1, 3),
                "font_size": 24
            },
            y_axis_config={
                "numbers_to_include": np.arange(-18, 18.1, 3), 
                "font_size": 24            
            },
            x_length = 40,
            y_length = 40,
            axis_config={"color": GREEN}
        )

        scale = ValueTracker(2)

        dots = VGroup()
        def update_scale(self):
            return
            # TODO: Make the scaling nicer
            self.stroke_width = 6 * scale.get_value()
            change_font_size(self, 48 * scale.get_value())
            # print("New font size: ", self.font_size)

        # Embedding arrows
        for i, t in enumerate(vectors.vectors[0]):
            # print(t, t.numpy())
            # arrow = LabeledArrow(
            #     start=ORIGIN,
            #     end=np.append(t.values.numpy(), 0),
            #     buff = 0,
            #     label = t.label,
            #     label_frame = False,
            #     label_color=YELLOW,
            #     color = t.color,
            #     max_stroke_width_to_length_ratio = 100,
            # )

            # arrow.add_updater(update_scale)
            # arrows.add(arrow)
            dot = LabeledDot(
                point=np.append(t.values.numpy(), 0),
                label=t.label,
                color=t.color,
                radius=DOT_SCALE * (len(vectors.vectors[0]) - i) + 0.2,
            )

            dot.add_updater(update_scale)
            dots.add(dot)

        # Transitioning the arrows through the model
        self.add(axes, axes.get_axis_labels(), dots)
        for step in range(1, len(vectors.vectors)):
            new_dots = VGroup()
            transition_arrows = VGroup()
            for i, t in enumerate(vectors.vectors[step]):
                # print(t, t.numpy())
                # new_arrow = LabeledArrow(
                #     start=ORIGIN,
                #     end=np.append(t.values.numpy(), 0),
                #     buff=0,
                #     label=t.label,
                #     label_frame=False,
                #     label_color=YELLOW,
                #     color=t.color,
                #     max_stroke_width_to_length_ratio=100,
                # )
                # new_arrow.add_updater(update_scale)
                # new_arrows.add(new_arrow)
                new_dot = LabeledDot(
                    point=np.append(t.values.numpy(), 0),
                    label=t.label,
                    color=t.color,
                    radius = DOT_SCALE * (len(vectors.vectors[step]) - i) + 0.2,
                )
                new_dot.add_updater(update_scale)
                new_dots.add(new_dot)

                transition_arrow = Arrow(
                    start=dots[i].arc_center,
                    end=new_dots[i].arc_center,
                    buff=0,
                    color=RED,
                )
                transition_arrow.add_updater(update_scale)
                transition_arrows.add(transition_arrow)

            view = SurroundingRectangle(new_dots)
            factor = max(
                view.width / self.camera.frame_width,
                view.height / self.camera.frame_height,
            )
            print(
                factor,
                self.camera.frame_width, view.width,
                self.camera.frame_height, view.height,
            )
            self.wait()
            self.play(FadeIn(transition_arrows), self.camera.auto_zoom(view, margin = 2), scale.animate.set_value(scale.get_value() * factor))
            self.wait()
            self.play(
                ReplacementTransform(dots, new_dots)
            )
            self.wait()
            self.play(FadeOut(transition_arrows))
            self.wait()
            dots = new_dots

        # Unembedding Arrows
        embedding_arrows = VGroup()
        data = model.W_U.data
        print("unembed: ", data)
        for i in range(model.W_U.data.size()[1]):
            embedding_arrow = LabeledArrow(
                start=ORIGIN,
                end=[data[0, i].item(), data[1, i].item(), 0],
                label=str(i),
                color=LIGHT_PINK,
                buff=0,
                max_stroke_width_to_length_ratio=100,
            )
            embedding_arrows.add(embedding_arrow)
        self.play(FadeIn(embedding_arrows))
        self.wait()

# v = VisualizeTransformer()
# v.construct()

In [40]:
dot = LabeledDot(label="A")
dot

LabeledDot

In [50]:
%%manim -qh Video

input_dim = 4
output_dim = 2
train_data, train_labels, test_data, test_labels = get_training_data(input_dim, output_dim, data_seed=999)
# model = get_seeded_model(998, input_dim, output_dim)
# train_model(model, train_data, train_labels, test_data, test_labels, num_epochs = 10000, loss_target = 1/(input_dim ** 2 * 4))
cache = create_cache(model)
arrow_labels = ["".join([str(d.item()) for d in v]) for v in train_data]
colors = [BLUE, YELLOW, GREEN, RED]
arrow_colors = [colors[l] for l in train_labels]
vectors = compile_data_vectors(cache, input_labels=arrow_labels, input_colors = arrow_colors)
print("Labels: ", arrow_labels)

class Video(VisualizeTransformer):
    def construct(self):
        VisualizeTransformer.construct(self)

tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [1, 0],
        [1, 1],
        [1, 2],
        [1, 3],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 3],
        [3, 0],
        [3, 1],
        [3, 2],
        [3, 3]])
tensor([0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0])
torch.Size([16, 2])
tensor([], size=(0, 2), dtype=torch.int64)
tensor([], dtype=torch.int64)
torch.Size([0, 2])
Transposed Input:
 tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]])
Labels:  tensor([0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0])
Logits of last token:
 tensor([[  0.5407,  -0.1525],
        [  0.5307,  -0.1625],
        [ -7.3317,   7.2997],
        [-10.1159,  11.8442],
        [  0.5406,  -0.1524],
        [  0.5307,  -0.1625],
        [ -7.0416,   6.9686],
        [-10.5721,  12.3541],
        [ -5.5939,   6.6703],
        [ -5.5159,   6.5641],
        [  4.4402,  -6.0507],
        [ -9.5549, 

v= [[tensor([-1.3298,  1.1375])(#FFFFFF), tensor([-1.3298,  1.1375])(#58C4DD), tensor([-1.3044,  1.1679])(#58C4DD), tensor([1.0662, 2.0783])(#FFFF00), tensor([-0.1918, -1.5593])(#FFFF00), tensor([-1.3044,  1.1679])(#FFFFFF), tensor([-1.3298,  1.1375])(#58C4DD), tensor([-1.3044,  1.1679])(#FFFF00), tensor([1.0662, 2.0783])(#FFFF00), tensor([-0.1918, -1.5593])(#FFFF00), tensor([1.0662, 2.0783])(#FFFFFF), tensor([-1.3298,  1.1375])(#FFFF00), tensor([-1.3044,  1.1679])(#FFFF00), tensor([1.0662, 2.0783])(#58C4DD), tensor([-0.1918, -1.5593])(#FFFF00), tensor([-0.1918, -1.5593])(#FFFFFF), tensor([-1.3298,  1.1375])(#FFFF00), tensor([-1.3044,  1.1679])(#58C4DD), tensor([1.0662, 2.0783])(#58C4DD), tensor([-0.1918, -1.5593])(#58C4DD)], [tensor([-3.6370,  0.9759])(#FFFFFF), tensor([-1.2176,  0.3486])(#58C4DD), tensor([-1.1922,  0.3789])(#58C4DD), tensor([1.1784, 1.2893])(#FFFF00), tensor([-0.0796, -2.3483])(#FFFF00), tensor([-3.6116,  1.0062])(#FFFFFF), tensor([-1.2176,  0.3486])(#58C4DD), tensor

                                                                                                            

1.967785768951384 14.515331734551324 9.190085887908936 8.16487410068512 16.06672306060791


                                                                                                             

unembed:  tensor([[-0.3928, -0.0660],
        [ 1.2664, -1.6290]])


                                                                                              

In [21]:
model.W_U.data

tensor([[ 1.9119, -1.8526, -0.0829],
        [ 0.0683, -2.1787,  2.0396]])

In [22]:
vectors.vectors

[[tensor([ 0.8166, -1.0588])(#FFFFFF),
  tensor([ 0.8166, -1.0588])(#58C4DD),
  tensor([ 0.8166, -1.0588])(#FFFFFF),
  tensor([-1.3558, -1.9105])(#FFFF00),
  tensor([ 0.8166, -1.0588])(#FFFFFF),
  tensor([0.8130, 2.3552])(#83C167),
  tensor([-1.3558, -1.9105])(#FFFFFF),
  tensor([ 0.8166, -1.0588])(#FFFF00),
  tensor([-1.3558, -1.9105])(#FFFFFF),
  tensor([-1.3558, -1.9105])(#FFFF00),
  tensor([-1.3558, -1.9105])(#FFFFFF),
  tensor([0.8130, 2.3552])(#58C4DD),
  tensor([0.8130, 2.3552])(#FFFFFF),
  tensor([ 0.8166, -1.0588])(#83C167),
  tensor([0.8130, 2.3552])(#FFFFFF),
  tensor([-1.3558, -1.9105])(#83C167),
  tensor([0.8130, 2.3552])(#FFFFFF),
  tensor([0.8130, 2.3552])(#58C4DD)],
 [tensor([ 2.1039, -0.3284])(#FFFFFF),
  tensor([-0.4866, -1.6315])(#58C4DD),
  tensor([ 2.1039, -0.3284])(#FFFFFF),
  tensor([-2.6590, -2.4832])(#FFFF00),
  tensor([ 2.1039, -0.3284])(#FFFFFF),
  tensor([-0.4902,  1.7825])(#83C167),
  tensor([-0.0685, -1.1802])(#FFFFFF),
  tensor([-0.4866, -1.6315])(#FFFF00