<a href="https://colab.research.google.com/github/victorlf4/MineRLDecisionTransformers/blob/master/notebook/Decision_Transformer_Hackathon.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
import plotly.io as pio

# Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
pio.renderers.default = "colab"

In [2]:
!pip install git+https://github.com/neelnanda-io/Easy-Transformer.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/Easy-Transformer.git
  Cloning https://github.com/neelnanda-io/Easy-Transformer.git to /tmp/pip-req-build-f9j92f1m
  Running command git clone -q https://github.com/neelnanda-io/Easy-Transformer.git /tmp/pip-req-build-f9j92f1m
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[K     |████████████████████████████████| 41 kB 522 kB/s 
Collecting datasets
  Downloading datasets-2.6.1-py3-none-any.whl (441 kB)
[K     |████████████████████████████████| 441 kB 54.1 MB/s 
[?25hCollecting transformers
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 52.2 MB/s 
Collecting wandb
  Downloading wandb-0.13.5-py2.py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 49.6 MB/s 
[?25hCollecting fancy_einsum
  Downloading fancy_einsum-0.0.3-py3-none-any.

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm

import random
import time

# from google.colab import drive
from pathlib import Path
import pickle
import os


import matplotlib.pyplot as plt

%matplotlib inline
import plotly.express as px
import plotly.graph_objects as go

from torch.utils.data import DataLoader

from functools import *
import pandas as pd
import gc
import collections
import copy

# import comet_ml
import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets

import gym

In [4]:
from easy_transformer.utils import (
  gelu_new,
  to_numpy,
  get_corner,
  lm_cross_entropy_loss,
)  # Helper functions
from easy_transformer.hook_points import (
  HookedRootModule,
  HookPoint,
)  # Hooking utilities
from easy_transformer import EasyTransformer, EasyTransformerConfig
import easy_transformer
from easy_transformer.experiments import (
  ExperimentMetric,
  AblationConfig,
  EasyAblation,
  EasyPatching,
  PatchingConfig,
)

In [46]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
print(device)

cuda


Some plotting code. Wrappers around Plotly, not important to understand.

In [7]:
def imshow(tensor, yaxis="", xaxis="", **kwargs):
  tensor = to_numpy(tensor)
  plot_kwargs = {"color_continuous_scale":"RdBu", "color_continuous_midpoint":0.0, "labels":{"x":xaxis, "y":yaxis}}
  plot_kwargs.update(kwargs)
  px.imshow(tensor, **plot_kwargs).show()

#UTILS

In [114]:
def gymEnvSample(env, batch_size=5, sequence_size=10,reward2go=False):
  batch = []

  for i in range(batch_size):
    env.reset()
    
    steps = []
    for j in range(sequence_size):
      action = env.action_space.sample()
      obs, reward, done, info = env.step(action)
      steps.append([reward, obs, action])

    #totalReward= np.sum(steps)
    #for step in steps:
    batch.append(steps)

  tensor = torch.as_tensor(batch)
  return tensor.flatten(start_dim=1)

In [44]:
def array2text(array):
  str_array=[]
  assert len(array)% 3 == 0, "array lenght must be divisible by 3"
  for i in range(int(len(array)/3)):
      str_array.append("S"+str(i)+"_R:"+str(array[i]))
      str_array.append("S"+str(i)+"_O:"+str(array[i+1]))
      str_array.append("S"+str(i)+"_A:"+str(array[i+1]))
  return str_array

print(array2text([1,2,3]))

['S0_R:1', 'S0_O:2', 'S0_A:2']


# Multi-armed bandits environment

In [9]:
class MultiarmedBanditsEnv(gym.Env):
  """Environment for multiarmed bandits"""
  metadata = {'render.modes': ['human']}

  def __init__(self, nr_arms=10):
    super(MultiarmedBanditsEnv, self).__init__()
    self.action_space = gym.spaces.Discrete(nr_arms)
    self.observation_space = gym.spaces.Discrete(1)
    self.state = 0
    self.reset()

  def step(self, action):
    assert self.action_space.contains(action)
    reward = self.values[action]

    return self.state, reward, False, {self.optimal}

  def reset(self):
    self.values = list(range(self.action_space.n))
    np.random.shuffle(self.values)
    self.optimal = np.argmax(self.values)
    return self.state

  def render(self, mode='human', close=False):
    print("You are playing a %d-armed bandit" % self.action_space.n)

In [10]:
gymEnvSample(MultiarmedBanditsEnv())

tensor([[2, 0, 4, 0, 0, 2, 7, 0, 8, 4, 0, 1, 8, 0, 3, 0, 0, 2, 0, 0, 2, 8, 0, 3,
         0, 0, 2, 3, 0, 9],
        [9, 0, 1, 3, 0, 3, 0, 0, 0, 1, 0, 5, 0, 0, 0, 8, 0, 8, 4, 0, 6, 2, 0, 4,
         4, 0, 6, 1, 0, 5],
        [1, 0, 1, 4, 0, 0, 3, 0, 8, 3, 0, 8, 3, 0, 8, 9, 0, 6, 0, 0, 7, 4, 0, 0,
         3, 0, 8, 3, 0, 8],
        [4, 0, 1, 9, 0, 5, 2, 0, 6, 7, 0, 4, 8, 0, 0, 6, 0, 3, 6, 0, 3, 1, 0, 9,
         9, 0, 5, 6, 0, 3],
        [4, 0, 4, 1, 0, 0, 9, 0, 5, 8, 0, 8, 9, 0, 5, 6, 0, 7, 0, 0, 3, 2, 0, 6,
         3, 0, 2, 6, 0, 7]])

# Treasure hunt environment

In [99]:
class TreasureHuntEnv(gym.Env):
  """Environment for treasure hunt"""
  metadata = {'render.modes': ['human']}

  def __init__(self, nr_chests=10):
    super(TreasureHuntEnv, self).__init__()
    self.action_space = gym.spaces.Discrete(nr_chests)
    self.observation_space = gym.spaces.Discrete(1)
    self.state = 7
    self.reset()

  def step(self, action):
    assert self.action_space.contains(action)
    reward = self.values[action]

    return self.state, reward, False, {self.optimal}

  def reset(self):
    self.values = np.zeros(self.action_space.n, dtype=np.int32)
    self.values[np.random.randint(0, self.action_space.n)] = 1
    self.optimal = np.argmax(self.values)
    return self.state

  def render(self, mode='human', close=False):
    print("You are playing treasure hunt")

In [12]:
gymEnvSample(TreasureHuntEnv())

tensor([[0, 0, 0, 0, 0, 4, 0, 0, 4, 1, 0, 2, 0, 0, 6, 0, 0, 8, 0, 0, 3, 0, 0, 3,
         0, 0, 6, 0, 0, 7],
        [0, 0, 8, 0, 0, 1, 0, 0, 1, 0, 0, 3, 0, 0, 9, 0, 0, 4, 1, 0, 0, 0, 0, 3,
         0, 0, 3, 0, 0, 1],
        [1, 0, 6, 0, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 5, 0, 0, 5, 0, 0, 9,
         0, 0, 3, 0, 0, 2],
        [0, 0, 0, 0, 0, 7, 0, 0, 3, 1, 0, 2, 0, 0, 6, 0, 0, 7, 0, 0, 5, 0, 0, 5,
         0, 0, 5, 0, 0, 0],
        [0, 0, 7, 0, 0, 8, 0, 0, 4, 0, 0, 1, 0, 0, 9, 0, 0, 7, 0, 0, 9, 1, 0, 2,
         1, 0, 2, 0, 0, 5]])

# TestEnviroment

In [100]:
class TestEnv(gym.Env):
  """Environment for treasure hunt"""
  metadata = {'render.modes': ['human']}

  def __init__(self, nr_chests=10):
    super(TestEnv, self).__init__()
    self.action_space = gym.spaces.Discrete(1)
    self.observation_space = gym.spaces.Discrete(1)
    self.state = 1
    self.reset()

  def step(self, action):
    assert self.action_space.contains(action)
    reward = action

    return self.state, reward, False,[0]

  def reset(self):
    return self.state

  def render(self, mode='human', close=False):
    print("test")

In [75]:
gymEnvSample(TestEnv())


tensor([[0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0,
         0, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0,
         0, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0,
         0, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0,
         0, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0,
         0, 1, 0, 0, 1, 0]])

# Training the model

In [111]:
tiny_cfg = EasyTransformerConfig(
  d_model=32,
  d_head=16,
  n_heads=2,
  d_mlp=128,
  n_layers=2,
  n_ctx=60,
  act_fn="solu_ln",
  d_vocab=10,
  normalization_type="LN",
  seed=23,  # Now we're training a custom model, it's good to set the seed to get reproducible results. It defaults to 42.
)

tiny_model = EasyTransformer(tiny_cfg).to(device)
tiny_optimizer = torch.optim.Adam(tiny_model.parameters(), lr=1e-3)
batch_size = 128
num_epochs = 301
env = TreasureHuntEnv()

assert tiny_model.cfg.n_ctx % 3 == 0, "n_ctx must be divisible by 3"

Moving model to device:  cuda
Moving model to device:  cuda


AssertionError: ignored

In [110]:
losses = []

# TODO maybe use a learning rate scheduler
for epoch in tqdm.tqdm(range(num_epochs)):
  batch = gymEnvSample(env, batch_size=batch_size, sequence_size=int(tiny_model.cfg.n_ctx / 3))
  #batch = torch.zeros((batch_size,tiny_model.cfg.n_ctx),dtype=torch.long) #test zeros
  loss = tiny_model(batch, return_type="loss")
  loss.backward()
  tiny_optimizer.step()
  tiny_optimizer.zero_grad()
  losses.append(loss.item())
  if epoch % 100 == 0:
    print(f"Epoch: {epoch}. Loss: {loss}")

px.line(losses, labels={"x": "Epoch", "y": "Loss"})

  0%|          | 0/301 [00:00<?, ?it/s]

Epoch: 0. Loss: 2.585441827774048
Epoch: 100. Loss: 0.8914234638214111
Epoch: 200. Loss: 0.8806241750717163
Epoch: 300. Loss: 0.8800145983695984


In [112]:
logits, tiny_cache = tiny_model.run_with_cache(batch)
print("Loss:", losses[-1])
print("Cache:", tiny_cache)

RuntimeError: ignored

In [113]:
with torch.no_grad():
  thing = torch.as_tensor([0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0,
         0, 1, 0, 0, 1, 0])
  logit = tiny_model(thing)
  print(thing.shape, np.argmax(logit.cpu().numpy(), axis=2).shape)
  print(logit.shape, np.argmax(logit.cpu().numpy(), axis=2).shape)
  print(np.argmax(logit.cpu().numpy(), axis=2))
#for i in tqdm.tqdm(range(100)):
#    test_tokens = reference_gpt2.to_tokens(test_string).cuda()
#   demo_logits = demo_gpt2(test_tokens)
#    test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())
#print(test_string)

torch.Size([30]) (1, 30)
torch.Size([1, 30, 10]) (1, 30)
[[3 0 8 8 3 3 3 3 3 3 8 3 3 3 3 3 3 8 8 8 3 3 3 8 3 3 3 3 3 3]]


#Experiments

##Utils

In [93]:
tokens_test=[1,0,1,0,5,0,1,0,0]

##Visualising attention patterns

In [84]:
for layer in range(tiny_model.cfg.n_layers):
  for head in range(tiny_model.cfg.n_heads):
    imshow(to_numpy(tiny_cache[f'blocks.{layer}.attn.hook_attn'].mean(0)[head]), title=f'Layer {layer} Attention Pattern, Head {head}', height=500, width=500, xaxis="Token", yaxis="Activation")

In [20]:
# Install an older version of node, so that Svelte works (a web dev framework)
!curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
!pip install git+https://github.com/neelnanda-io/PySvelte.git
import sys
sys.path.append('/content/PySvelte')


## Installing the NodeSource Node.js 16.x repo...


## Populating apt-get cache...

+ apt-get update
Get:1 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ InRelease [3,626 B]
Hit:2 http://archive.ubuntu.com/ubuntu bionic InRelease
Get:3 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ Packages [101 kB]
Get:4 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu bionic InRelease [15.9 kB]
Get:5 http://security.ubuntu.com/ubuntu bionic-security InRelease [88.7 kB]
Get:6 http://archive.ubuntu.com/ubuntu bionic-updates InRelease [88.7 kB]
Hit:7 http://ppa.launchpad.net/cran/libgit2/ubuntu bionic InRelease
Get:8 http://archive.ubuntu.com/ubuntu bionic-backports InRelease [83.3 kB]
Ign:9 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  InRelease
Hit:10 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  InRelease
Hit:11 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  

In [21]:
tokens_test=[1,0,1,0,5,0,1,0,0]

['S0_R:1', 'S0_O:2', 'S0_A:2']


In [85]:
import pysvelte
test_tokens_str=array2text(tokens_test)
test_tokens=torch.as_tensor(tokens_test)
logits, tiny_cache = tiny_model.run_with_cache(test_tokens)

for layer in range(tiny_model.cfg.n_layers):
    print("Attention for layer", layer)
    
    print(tiny_cache[f'blocks.{layer}.attn.hook_attn'][0].shape)
    pysvelte.AttentionMulti(attention=tiny_cache[f'blocks.{layer}.attn.hook_attn'][0].permute(1, 2, 0),tokens=test_tokens_str).show() #.attn.hook_attn']

Attention for layer 0
torch.Size([2, 9, 9])


Attention for layer 1
torch.Size([2, 9, 9])


##Direct logit atribution
###Neels explanation

The residual stream is the sum of the output of each layer, so we can map each layer output to a set of logits, and look at the contribution to the logit diff to figure out how much it contributes.

Further, as the output of the attn layer is the sum of each head (plus a bias term), we can look at each head's contribution.

This is called the direct logit attribution - direct since we're looking at the logits, so there's just a linear map (holding layernorm scale fixed) (Exercise: Why is this not direct if we're looking at the loss or log probs?)

We can calculate this more efficiently by dotting with W_U[John] - W_U[Mary], which just a direction in the residual stream

In [92]:
#TODO modify this to work on our code

# Set use attn result to True - this gives us a hook for the result of each head, 
# ie the d_model length vectors whose sum makes up attn_out and is added to the residual stream
tiny_model.cfg.use_attn_result = True

logit_diff_direction = tiny_model.unembed.W_U[:, 0] - tiny_model.unembed.W_U[:, 1]
# Take the scaling factor of the layernorm pre-unembed on the final token, so our logit attrs are on the same scale
final_layer_norm_scale = (tiny_cache['ln_final.hook_scale'][0, -1])
print("Final layer norm scaling factor:", final_layer_norm_scale.item())

direct_logit_attr = torch.zeros(tiny_model.cfg.n_layers, tiny_model.cfg.n_heads).to(device)
def calc_direct_logit_attr(result, hook):
    layer = int(hook.name.split('.')[1])
    final_token_result = result[0, -1]
    direct_logit_attr[layer] = (final_token_result @ (logit_diff_direction))/final_layer_norm_scale

tiny_model.run_with_hooks(tokens_test, fwd_hooks = [(lambda name:name.endswith('hook_result'), calc_direct_logit_attr)])

imshow(direct_logit_attr, xaxis='Head', yaxis='Layer', title='Direct Logit Attribution')

# Switch use_attn_result back off, since it consumes a lot of memory (Exercise: Why is this more expensive than eg calculating the value?)
tiny_model.cfg.use_attn_result = False

Final layer norm scaling factor: 3.6594507694244385


AssertionError: ignored

##Ablating Heads

In [None]:
def ablate_head_hook(value, hook, head_index):
    # Shape of value: batch x position x head_index x d_head
    value[:, :, head_index] = 0.
    return value
head_ablation = torch.zeros((tiny_model.cfg.n_layers, tiny_model.cfg.n_heads))
for layer in tqdm.tqdm(range(tiny_model.cfg.n_layers)):
    for head_index in range(tiny_model.cfg.n_heads):
        logits = tiny_model.run_with_hooks(example_text, fwd_hooks=[(f"blocks.{layer}.attn.hook_v", partial(ablate_head_hook, head_index=head_index))])
        ablated_logit_diff = get_logit_diff(logits)
        change_in_logit_diff = ablated_logit_diff - example_logit_diff #Negative = strong effect
        head_ablation[layer, head_index]=((change_in_logit_diff))
imshow(head_ablation, title='Effect of ablating each head', xaxis='Head', yaxis='Layer')

# Quincallada

In [None]:
imshow(to_numpy(tiny_cache["attn", 0].mean([0, 1])), title="Layer 0 Attention Pattern", height=500, width=500)

In [None]:
env = TreasureHuntEnv()
for i in range(2000):
  action = int(input("Action: "))
  obs, rewards, done, info = env.step(action)
  print(rewards)