<a href="https://colab.research.google.com/github/tthoraldson/Interpretability_Hackathon/blob/main/notebooks/Decision_Transformer_Hackathon.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
import plotly.io as pio

# Plotly figures show up in either Colab or VS Code depending on the selected renderer, not in both.
# The renderer is hard-coded to "colab" here; change this value when running the notebook elsewhere.
pio.renderers.default = "colab"

In [2]:
!pip install git+https://github.com/neelnanda-io/Easy-Transformer.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/Easy-Transformer.git
  Cloning https://github.com/neelnanda-io/Easy-Transformer.git to /tmp/pip-req-build-ows22p1m
  Running command git clone -q https://github.com/neelnanda-io/Easy-Transformer.git /tmp/pip-req-build-ows22p1m
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[K     |████████████████████████████████| 41 kB 561 kB/s 
Collecting datasets
  Downloading datasets-2.6.1-py3-none-any.whl (441 kB)
[K     |████████████████████████████████| 441 kB 26.5 MB/s 
[?25hCollecting transformers
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 55.5 MB/s 
Collecting wandb
  Downloading wandb-0.13.5-py2.py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 42.2 MB/s 
[?25hCollecting fancy_einsum
  Downloading fancy_einsum-0.0.3-py3-none-any.

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm

import random
import time

# from google.colab import drive
from pathlib import Path
import pickle
import os


import matplotlib.pyplot as plt

%matplotlib inline
import plotly.express as px
import plotly.graph_objects as go

from torch.utils.data import DataLoader

from functools import *
import pandas as pd
import gc
import collections
import copy

# import comet_ml
import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets

import gym

In [4]:
from easy_transformer.utils import (
  gelu_new,
  to_numpy,
  get_corner,
  lm_cross_entropy_loss,
)  # Helper functions
from easy_transformer.hook_points import (
  HookedRootModule,
  HookPoint,
)  # Hooking utilities
from easy_transformer import EasyTransformer, EasyTransformerConfig
import easy_transformer
from easy_transformer.experiments import (
  ExperimentMetric,
  AblationConfig,
  EasyAblation,
  EasyPatching,
  PatchingConfig,
)

In [5]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [6]:
# Show which device string was selected above.
print(device)

cuda


Some plotting code. Wrappers around Plotly, not important to understand.

In [7]:
def imshow(tensor, yaxis="", xaxis="", **kwargs):
  """Display ``tensor`` as a Plotly heatmap.

  The input is converted with ``to_numpy`` and drawn with ``px.imshow`` using
  a red/blue colour scale centred on 0.0. ``xaxis`` and ``yaxis`` become the
  axis labels. Any extra keyword arguments are forwarded to ``px.imshow`` and
  take precedence over these defaults.
  """
  options = dict(
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    labels={"x": xaxis, "y": yaxis},
  )
  # Caller-supplied options win over the defaults.
  options.update(kwargs)
  figure = px.imshow(to_numpy(tensor), **options)
  figure.show()

In [131]:
def gymEnvSample(env, batch_size=5, sequence_size=10):
  """Roll out random actions in ``env`` and pack the results into one tensor.

  For each of ``batch_size`` rollouts the environment is reset and then
  stepped ``sequence_size`` times with actions drawn from
  ``env.action_space.sample()``. Every step contributes the triple
  ``(reward, observation, action)``, in that order. The ``done`` flag and the
  ``info`` value returned by ``env.step`` are ignored.

  Returns:
    The tensor built from the nested list of triples, with every dimension
    after the first flattened together, so each row is one rollout laid out
    as ``reward, obs, action, reward, obs, action, ...``.
  """
  def rollout():
    # One rollout: reset, then take `sequence_size` random steps.
    env.reset()
    triples = []
    for _ in range(sequence_size):
      chosen = env.action_space.sample()
      observation, payoff, _done, _info = env.step(chosen)
      triples.append([payoff, observation, chosen])
    return triples

  rollouts = [rollout() for _ in range(batch_size)]
  return torch.as_tensor(rollouts).flatten(start_dim=1)

# Multi-armed bandits environment

In [43]:
class MultiarmedBanditsEnv(gym.Env):
  """Environment for multiarmed bandits.

  Every arm pays a fixed reward. The rewards are a random permutation of
  ``0 .. nr_arms - 1`` that is redrawn (via NumPy's global RNG) on each
  ``reset()``. The observation is always ``0`` and ``step`` never reports
  the episode as done.
  """
  metadata = {'render.modes': ['human']}

  def __init__(self, nr_arms=10):
    """Create a bandit with ``nr_arms`` arms and draw the initial rewards."""
    super().__init__()
    self.action_space = gym.spaces.Discrete(nr_arms)
    self.observation_space = gym.spaces.Discrete(1)
    self.state = 0
    self.reset()

  def step(self, action):
    """Pull arm ``action``.

    Returns:
      ``(observation, reward, done, info)`` where ``observation`` is the
      constant state, ``done`` is always ``False`` and ``info`` is a dict
      holding the index of the best arm under the key ``"optimal"``.
    """
    assert self.action_space.contains(action)
    reward = self.values[action]

    # `info` must be a dict; the previous `{self.optimal}` built a set.
    return self.state, reward, False, {"optimal": self.optimal}

  def reset(self):
    """Reshuffle the arm rewards and return the (constant) observation."""
    self.values = list(range(self.action_space.n))
    np.random.shuffle(self.values)
    self.optimal = int(np.argmax(self.values))
    return self.state

  def render(self, mode='human', close=False):
    """Print a one-line description of the environment."""
    print("You are playing a %d-armed bandit" % self.action_space.n)

In [None]:
# Sample random rollouts from a fresh bandit environment using the default batch and sequence sizes; the resulting tensor is the cell output.
gymEnvSample(MultiarmedBanditsEnv())

# Treasure hunt environment

In [142]:
class TreasureHuntEnv(gym.Env):
  """Environment for treasure hunt.

  Exactly one of ``nr_chests`` chests pays a reward of 1; all others pay 0.
  The winning chest is redrawn (via NumPy's global RNG) on each ``reset()``.
  The observation is always ``0`` and ``step`` never reports the episode as
  done.
  """
  metadata = {'render.modes': ['human']}

  def __init__(self, nr_chests=10):
    """Create a hunt over ``nr_chests`` chests and hide the initial treasure."""
    super().__init__()
    self.action_space = gym.spaces.Discrete(nr_chests)
    self.observation_space = gym.spaces.Discrete(1)
    self.state = 0
    self.reset()

  def step(self, action):
    """Open chest ``action``.

    Returns:
      ``(observation, reward, done, info)`` where ``observation`` is the
      constant state, ``done`` is always ``False`` and ``info`` is a dict
      holding the index of the treasure chest under the key ``"optimal"``.
    """
    assert self.action_space.contains(action)
    reward = self.values[action]

    # `info` must be a dict; the previous `{self.optimal}` built a set.
    return self.state, reward, False, {"optimal": self.optimal}

  def reset(self):
    """Hide the treasure in a new random chest and return the (constant) observation."""
    self.values = np.zeros(self.action_space.n, dtype=np.int32)
    self.values[np.random.randint(0, self.action_space.n)] = 1
    self.optimal = int(np.argmax(self.values))
    return self.state

  def render(self, mode='human', close=False):
    """Print a one-line description of the environment."""
    print("You are playing treasure hunt")

In [143]:
# Sample random rollouts from a fresh treasure hunt environment using the default batch and sequence sizes; each row reads reward, obs, action, reward, obs, action, ...
gymEnvSample(TreasureHuntEnv())

tensor([[0, 0, 7, 0, 0, 9, 0, 0, 1, 0, 0, 3, 0, 0, 0, 0, 0, 2, 0, 0, 3, 0, 0, 8,
         0, 0, 6, 0, 0, 9],
        [0, 0, 4, 0, 0, 3, 0, 0, 3, 0, 0, 7, 0, 0, 6, 0, 0, 7, 0, 0, 2, 0, 0, 8,
         0, 0, 8, 0, 0, 4],
        [1, 0, 1, 0, 0, 3, 0, 0, 9, 0, 0, 8, 1, 0, 1, 0, 0, 0, 0, 0, 3, 0, 0, 2,
         0, 0, 3, 0, 0, 4],
        [0, 0, 3, 0, 0, 6, 0, 0, 2, 1, 0, 4, 1, 0, 4, 0, 0, 9, 0, 0, 0, 0, 0, 0,
         0, 0, 7, 0, 0, 6],
        [0, 0, 3, 0, 0, 2, 0, 0, 4, 0, 0, 4, 0, 0, 8, 0, 0, 7, 0, 0, 1, 1, 0, 6,
         0, 0, 5, 0, 0, 4]])

# Training the model

In [157]:
SEED = 23  # Shared by the model config and the environment RNGs below.

tiny_cfg = EasyTransformerConfig(
  d_model=32,
  d_head=16,
  n_heads=2,
  d_mlp=128,
  n_layers=1,
  n_ctx=60,
  act_fn="solu_ln",
  d_vocab=150,
  normalization_type="LN",
  seed=SEED,  # Now we're training a custom model, it's good to set the seed to get reproducible results. It defaults to 42.
)

tiny_model = EasyTransformer(tiny_cfg).to(device)
tiny_optimizer = torch.optim.Adam(tiny_model.parameters(), lr=1e-3)
batch_size = 20
num_epochs = 301

# The environment draws from NumPy's global RNG in reset(), and random actions
# come from the action space's own sampler. Seed both so that the training
# data, not just the model initialisation, is reproducible. The global seed is
# set before constructing the env because its constructor already calls reset().
np.random.seed(SEED)
env = TreasureHuntEnv()
env.action_space.seed(SEED)

# Each environment step contributes three tokens (reward, observation, action),
# so the context length has to hold a whole number of steps.
assert tiny_model.cfg.n_ctx % 3 == 0, "n_ctx must be divisible by 3"

Moving model to device:  cuda
Moving model to device:  cuda


In [158]:
losses = []
# Every environment step yields three tokens (reward, observation, action).
steps_per_sequence = tiny_model.cfg.n_ctx // 3

# TODO maybe use a learning rate scheduler
# Each "epoch" here is a single optimisation step on one freshly sampled batch.
# `batch` from the final iteration is reused later by the activation-caching cell.
for epoch in tqdm.tqdm(range(num_epochs)):
  batch = gymEnvSample(env, batch_size=batch_size, sequence_size=steps_per_sequence)
  loss = tiny_model(batch, return_type="loss")
  loss.backward()
  tiny_optimizer.step()
  tiny_optimizer.zero_grad()
  losses.append(loss.item())
  if epoch % 100 == 0:
    print(f"Epoch: {epoch}. Loss: {loss}")

# Pass the losses as `y=` so the columns are literally named "x" and "y" and the
# axis labels below apply. Passing the list positionally makes Plotly treat it
# as wide-form data (index/value columns), which silently ignores these labels.
px.line(y=losses, labels={"x": "Epoch", "y": "Loss"})

  0%|          | 0/301 [00:00<?, ?it/s]

Epoch: 0. Loss: 5.4372639656066895
Epoch: 100. Loss: 1.1288789510726929
Epoch: 200. Loss: 0.9185519814491272
Epoch: 300. Loss: 0.9038007259368896


In [159]:
# Run the model on `batch` (the last batch sampled by the training loop) and keep the recorded activations in `tiny_cache`.
logits, tiny_cache = tiny_model.run_with_cache(batch)
# Note: this prints the loss recorded at the final training step, not a loss computed from the forward pass above.
print("Loss:", losses[-1])
print("Cache:", tiny_cache)

Loss: 0.9038007259368896
Cache: ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_attn', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_mid', 'blocks.0.mlp.ln.hook_scale', 'blocks.0.mlp.ln.hook_normalized', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'ln_final.hook_scale', 'ln_final.hook_normalized']


In [170]:
# Smoke test: feed a single token (id 1) through the model without tracking gradients and print output shapes.
with torch.no_grad():
  thing = torch.as_tensor([1])
  logit = tiny_model(thing)
  # NOTE(review): the recorded output shows logits of shape [1, 1, 150], so argmax over axis=1 reduces the
  # length-1 middle axis and yields shape (1, 150). If the goal was the most likely token id, axis=-1 is
  # presumably what was intended — confirm before relying on this.
  print(logit.shape, np.argmax(logit.cpu().numpy(), axis=1).shape)

torch.Size([1, 1, 150]) (1, 150)


# Visualising attention patterns

In [160]:
# One heatmap per (layer, head): the cached attention pattern averaged over dimension 0 (presumably the batch dimension).
attention_heads = itertools.product(range(tiny_model.cfg.n_layers), range(tiny_model.cfg.n_heads))
for layer, head in attention_heads:
  mean_pattern = tiny_cache[f'blocks.{layer}.attn.hook_attn'].mean(0)
  imshow(
    to_numpy(mean_pattern[head]),
    title=f'Layer {layer} Attention Pattern, Head {head}',
    height=500,
    width=500,
    xaxis="Token",
    yaxis="Activation",
  )

# No quincallada

# Quincallada

In [None]:
# Install an older version of node, so that Svelte works (a web dev framework)
!curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
!pip install git+https://github.com/neelnanda-io/PySvelte.git
import sys
sys.path.append('/content/PySvelte')

In [None]:
import pysvelte
for layer in range(tiny_model.cfg.n_layers):
  print("Attention for layer", layer)
  # The other cells treat this activation as 4-D (they reduce dim 0 and then index a head),
  # so a three-index permute applied to it directly would raise. Average over dim 0 first,
  # as the Plotly heatmap cell does, which leaves a 3-D tensor whose head axis is moved last.
  layer_pattern = tiny_cache[f'blocks.{layer}.attn.hook_attn'].mean(0)
  pysvelte.AttentionMulti(attention=layer_pattern.permute(1, 2, 0)).show()

In [None]:
# Layer 0 attention pattern averaged over dims 0 and 1 (presumably batch and head — confirm), shown as a single heatmap.
imshow(to_numpy(tiny_cache["attn", 0].mean([0, 1])), title="Layer 0 Attention Pattern", height=500, width=500)

In [129]:
# Manual play-testing loop for the treasure hunt environment.
# Enter a chest index to see its reward. Submit an empty line (or "q") to stop,
# instead of having to interrupt the kernel. Invalid input is reported and
# re-prompted rather than crashing the cell.
# Note: this rebinds the global `env` that the training cell also uses.
env = TreasureHuntEnv()
for i in range(2000):
  raw_action = input("Action: ").strip()
  if raw_action.lower() in ("", "q"):
    break
  try:
    action = int(raw_action)
  except ValueError:
    print(f"Not an integer: {raw_action!r}")
    continue
  # step() asserts that the action is valid; check first so a typo does not end the session.
  if not env.action_space.contains(action):
    print(f"Action must be between 0 and {env.action_space.n - 1}")
    continue
  obs, rewards, done, info = env.step(action)
  print(rewards)

Action: 0
0.0
Action: 1
0.0
Action: 2
0.0
Action: 3
1.0
Action: 4
0.0
Action: 3
1.0
Action: 3
1.0
Action: 3
Action: 2
1.0
0.0
Action: 4
0.0


KeyboardInterrupt: ignored