In [1]:
%load_ext autoreload
%autoreload 2

In [14]:
from mingpt.model import GPT, Intervention_GPT
from mingpt.dataset import MyDataset, BoardDataset
import plotly.express as px
from tqdm import tqdm
import numpy as np
import torch as t
from torch.utils.data import DataLoader
from othello import Othello
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

TRAINING = False

In [3]:
# Build the GPT architecture that matches the checkpoint, then load its weights for inference.
model_config = GPT.get_default_config()
for key, value in dict(
    model_type=None,
    vocab_size=61,  # playable moves, 60 is pass
    block_size=59,  # generated games all have length 60, so 59 inputs & 59 predictions (not trained on first move)
    n_layer=8,
    n_embd=512,
    n_head=8,
).items():
    setattr(model_config, key, value)

model = GPT(model_config)
checkpoint = t.load('./weights/model_state_dict_large_training_2_02.pth')
model.load_state_dict(checkpoint)
model.to('cuda')
model.eval()

number of parameters: 25.28M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(61, 512)
    (wpe): Embedding(59, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=512, out_features=1536, bias=True)
          (c_proj): Linear(in_features=512, out_features=512, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): ModuleDict(
          (c_fc): Linear(in_features=512, out_features=2048, bias=True)
          (c_proj): Linear(in_features=2048, out_features=512, bias=True)
          (act): NewGELU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_feat

In [4]:
class Board_Probe(nn.Module):
    """Independent linear probes that decode the board from each layer's activations.

    Each layer gets its own ``nn.Linear(h_dim, board_dim * categories)``. Its output is
    split (board-major) into ``board_dim`` cells, each holding ``categories`` logits.

    Args:
        layers: number of probed layers (one linear probe per layer).
        h_dim: width of the activations fed to each probe.
        board_dim: number of board cells predicted per position.
        categories: number of classes per cell.
    """

    def __init__(self, layers, h_dim, board_dim, categories):
        super().__init__()
        self.layer_probes = nn.ModuleList(nn.Linear(h_dim, board_dim * categories) for _ in range(layers))
        self.categories = categories
        self.layers = layers
        self.board_dim = board_dim
        self.h_dim = h_dim

    def _check_input(self, x):
        """Validate that ``x`` is (layer, batch, sequence, h_dim) and matches this probe."""
        l, _, _, h = x.shape
        assert l == self.layers, f'expected {self.layers} layers, got {l}'
        assert h == self.h_dim, f'expected hidden size {self.h_dim}, got {h}'

    def _layer_logits(self, layer_probe, x_layer):
        """Apply one layer probe and split its output into (..., board_dim, categories)."""
        out = layer_probe(x_layer)
        # '(board categories)' is board-major, so a plain reshape gives the same split.
        return out.reshape(*out.shape[:-1], self.board_dim, self.categories)

    def forward(self, x, targets = None):
        """Run every layer probe and optionally compute the pooled cross-entropy loss.

        Args:
            x: activations of shape (layer, batch, sequence, h_dim).
            targets: optional class indices of shape (batch, sequence, board_dim),
                shared by all layers.

        Returns:
            ``predictions`` of shape (layer, batch, sequence, board_dim, categories) if
            ``targets`` is None; otherwise ``(predictions, loss)``. Here ``loss`` is the mean
            cross-entropy over all layers, positions and cells.
            Outputs live on ``x``'s device (no hard-coded device).
        """
        self._check_input(x)
        predictions = t.stack([self._layer_logits(layer_probe, x[i]) for i, layer_probe in enumerate(self.layer_probes)])
        if targets is None:
            return predictions
        # Tile the targets once per layer (layer-major) so they line up with the flattened predictions.
        flat_targets = targets.reshape(-1).repeat(self.layers)
        loss = F.cross_entropy(predictions.reshape(-1, self.categories), flat_targets)
        return predictions, loss

    def loss_by_layer(self, x, targets):
        """Cross-entropy of each layer's probe separately.

        Args:
            x: activations of shape (layer, batch, sequence, h_dim).
            targets: class indices of shape (batch, sequence, board_dim).

        Returns:
            1-D tensor of length ``layers`` on ``x``'s device.
        """
        self._check_input(x)
        flat_targets = targets.reshape(-1)
        return t.stack([
            F.cross_entropy(self._layer_logits(layer_probe, x[i]).reshape(-1, self.categories), flat_targets)
            for i, layer_probe in enumerate(self.layer_probes)
        ])


In [5]:
# Buffer the forward hooks write into: (layer, batch, sequence, n_embd); 32 is the training batch size.
# The hooks resolve the global name `activations` at call time, so later cells can rebind it
# (e.g. to a buffer sized for a different batch) and the hooks will write into the new tensor.
activations = t.zeros(model_config.n_layer, 32, model_config.block_size, model_config.n_embd, device = 'cuda')

def hook_layer(layer):
    """Return a forward hook that copies transformer block ``layer``'s output into ``activations[layer]``.

    The block output must be broadcastable to ``activations[layer]``.
    """
    def hook_fn(module, inputs, output):
        # detach: the probes only need the values, not a graph back into the GPT
        activations[layer] = output.detach()
    return hook_fn

layer_hooks = [
    block.register_forward_hook(hook_layer(layer))
    for layer, block in enumerate(model.transformer['h'])
]

In [6]:
# One linear probe per transformer block, reading n_embd-wide activations and predicting
# 8 * 8 board cells with 3 classes each (presumably empty / black / white — confirm against BoardDataset).
probe = Board_Probe(model_config.n_layer, model_config.n_embd, 8 * 8, 3)
probe.to('cuda')

Board_Probe(
  (layer_probes): ModuleList(
    (0-7): 8 x Linear(in_features=512, out_features=192, bias=True)
  )
)

In [7]:
# Reset the hook buffer (layer, batch, sequence, n_embd) to zeros; the forward hooks write into this global.
activations = t.zeros(model_config.n_layer, 32, model_config.block_size, model_config.n_embd, device = 'cuda')

In [15]:
## Train or load weights

test_dataset = BoardDataset('./data/test_moves_2.npy')

if TRAINING:

    n_epochs = 100

    ## Board states are cached next to the move file so they don't have to be regenerated.
    ## Only a missing cache triggers regeneration; any other error is surfaced instead of swallowed.
    try:
        train_dataset = BoardDataset('./data/moves.npy', './data/moves_flattened_boards.npy', alternating=True)
    except FileNotFoundError:
        # Regenerate from the same move file the cache is paired with, so the cached boards
        # always correspond to './data/moves.npy'.
        train_dataset = BoardDataset('./data/moves.npy', alternating=True)
        np.save('./data/moves_flattened_boards.npy', train_dataset.flattened_boards.numpy())

    optimizer = t.optim.Adam(probe.parameters())

    model.to('cuda')
    probe.to('cuda')

    # shuffle=True reshuffles on every pass, so a single loader serves all epochs.
    train_loader = DataLoader(
        train_dataset,
        shuffle=True,
        pin_memory=True,
        batch_size=32,
        num_workers=0,
    )

    for epoch in (pbar_epoch := tqdm(range(n_epochs))):
        for x, _, z in tqdm(train_loader, leave=False):
            b_dim, t_dim = x.shape
            # Fresh buffer sized to this batch; the forward hooks write into the global `activations`.
            activations = t.zeros(model_config.n_layer, b_dim, t_dim, model_config.n_embd, device = 'cuda')
            x = x.to('cuda')
            z = z.to('cuda')
            optimizer.zero_grad()
            with t.no_grad():
                model(x)  # populates `activations` via the forward hooks
            _, loss = probe(activations, z)
            # Per-layer losses are only displayed, so compute them without building a graph.
            with t.no_grad():
                loss_by_layer = probe.loss_by_layer(activations, z)
            loss.backward()
            optimizer.step()
            pbar_epoch.set_description(' '.join(f'{v:.3f}' for v in loss_by_layer.tolist()))

    print(loss)
    t.save(probe.state_dict(), './weights/probe_state_dict.pth')

else:
    probe.load_state_dict(t.load('./weights/probe_state_dict.pth'))
    

In [13]:
## Check the probe output looks correct: probe-decoded board vs. the true board for a few turns of test game 0.

PROBE_LAYER = 5  # index of the hooked transformer block whose probe is shown
game = test_dataset[0][0]

# Re-run the model on this game so the hook buffer actually holds its activations.
# Without this, `activations` holds zeros (fresh run) or an unrelated training batch.
activations = t.zeros(model_config.n_layer, 1, game.shape[0], model_config.n_embd, device = 'cuda')
with t.no_grad():
    model(game.unsqueeze(0).to('cuda'))
    predicted_boards = probe(activations)[PROBE_LAYER, 0].argmax(dim = -1)  # (sequence, board cells)

for turn in range(39, 42):
    fig = px.imshow(predicted_boards[turn].view(8, 8).cpu().numpy())
    Othello.board_state(game, turn + 1)
    fig.update_layout(width=400, height=400, margin=dict(l=0, r=0, b=0, t=0))
    fig.show()

In [9]:
##normalized directions from probes: one unit vector per (board cell, category), taken from one layer's probe

PROBE_LAYER = 5
# Index the probe directly rather than counting through probe.parameters() (where this weight
# is item 10). Detach so the directions carry no autograd graph into later interventions.
probe_weight = probe.layer_probes[PROBE_LAYER].weight.detach()
print(probe_weight.shape)
directions = rearrange(probe_weight, '(board categories) hdim -> board categories hdim', categories = probe.categories)
directions = F.normalize(directions, dim = -1)

torch.Size([192, 512])


In [18]:
##manual intervention to flip cell 17 i.e (2,1) from white to black

move = 17
cell = 17
intervention_scale = 3  # step size along (directions[cell, 1] - directions[cell, 2])


def _print_projections(vec):
    """Print the projection of ``vec`` onto each category direction of ``cell``."""
    for category in range(directions.shape[1]):
        print(t.einsum('i,i -> ', vec, directions[cell, category]))


def intervention_fn(input):
    """Push position ``move`` toward category 1 and away from category 2 for ``cell``.

    Prints the input shape and the projections onto ``cell``'s category directions,
    before and after the edit. Returns an edited copy; ``input`` itself is not modified.
    """
    print(input.shape)
    _print_projections(input[0, move])
    output = input.clone().detach()
    output[0, move] += intervention_scale * (directions[cell, 1] - directions[cell, 2])
    _print_projections(output[0, move])
    return output


def plot_move_probabilities(net):
    """Show ``net``'s next-move distribution at position ``move`` of test game 0 as a heatmap."""
    with t.no_grad():
        logits = net(test_dataset[0][0].unsqueeze(dim=0).to(device = 'cuda'))[0]
    probs = logits[0, move, :].softmax(dim=-1)
    fig = px.imshow(Othello.tokens_to_ij(probs).detach().numpy())
    fig.update_layout(width=400, height=400, margin=dict(l=0, r=0, b=0, t=0))
    fig.show()
    return fig


# The layer argument (5) presumably selects the block to intervene on; it should match the
# probe layer `directions` was read from — confirm against Intervention_GPT.
intervention_model = Intervention_GPT(model_config, 5, intervention_fn).to('cuda')
intervention_model.load_state_dict(t.load('./weights/model_state_dict_large_training_2_02.pth'))
model.eval()  # harmless re-assertion: both models compared in inference mode
intervention_model.eval()

fig1 = plot_move_probabilities(model)
fig = plot_move_probabilities(intervention_model)

Othello.board_state(test_dataset[0][0], move + 1)

number of parameters: 25.28M


torch.Size([1, 59, 512])
tensor(-0.9954, device='cuda:0', grad_fn=<ViewBackward0>)
tensor(0.0078, device='cuda:0', grad_fn=<ViewBackward0>)
tensor(1.4052, device='cuda:0', grad_fn=<ViewBackward0>)
tensor(-1.3035, device='cuda:0', grad_fn=<ViewBackward0>)
tensor(2.9501, device='cuda:0', grad_fn=<ViewBackward0>)
tensor(-1.5371, device='cuda:0', grad_fn=<ViewBackward0>)
