In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from mingpt.model import GPT
from mingpt.dataset import MyDataset
import plotly.express as px
from tqdm import tqdm
import numpy as np
import torch as t
from torch.utils.data import DataLoader
from othello import Othello

In [262]:
# Rebuild the GPT architecture that the checkpoint expects, then load its weights.
model_config = GPT.get_default_config()
model_config.model_type = None  # sizes are set explicitly below
model_config.vocab_size = 61 # playable moves, 60 is pass
model_config.block_size = 59  # generated games all have length 60, so 59 inputs & 59 predictions (not trained on first move)
model_config.n_layer = 8
model_config.n_embd = 512
model_config.n_head = 8
model = GPT(model_config)
# map_location='cpu' lets the checkpoint deserialize even if it was saved from a GPU
# and no GPU is visible here; load_state_dict copies the values into the model's own
# tensors, and the model is moved to its target device in a later cell.
model.load_state_dict(
    t.load('./weights/model_state_dict_large_training_2_02.pth', map_location='cpu')
)
# Held-out games as int64 token ids; the last expression displays the first game.
test_data = t.tensor(np.load('./data/test_moves.npy'), dtype=t.long)
test_data[0]

number of parameters: 25.28M


tensor([40, 41, 26, 20, 42, 48, 33, 25, 55, 43, 17,  8, 32, 39, 31, 56,  9, 27,
        28, 54, 13, 23, 29, 18, 50, 30, 57, 47, 53, 49, 34, 51, 58, 10, 46, 45,
         2, 12, 15, 37, 38, 14, 16, 24, 36, 11,  1, 52,  3, 35, 21,  4,  7, 22,
        59,  6,  0, 44,  5, 60])

In [263]:
# Move the model to the GPU and switch it to inference mode; both calls return the
# module itself, so the chained expression still displays the model repr.
model.to('cuda').eval()

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(61, 512)
    (wpe): Embedding(59, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=512, out_features=1536, bias=True)
          (c_proj): Linear(in_features=512, out_features=512, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): ModuleDict(
          (c_fc): Linear(in_features=512, out_features=2048, bias=True)
          (c_proj): Linear(in_features=2048, out_features=512, bias=True)
          (act): NewGELU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_feat

In [232]:
# Inspect the first transformer block (index 0 of the `h` ModuleList).
model.transformer['h'][0]

Block(
  (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (attn): CausalSelfAttention(
    (c_attn): Linear(in_features=512, out_features=1536, bias=True)
    (c_proj): Linear(in_features=512, out_features=512, bias=True)
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (mlp): ModuleDict(
    (c_fc): Linear(in_features=512, out_features=2048, bias=True)
    (c_proj): Linear(in_features=2048, out_features=512, bias=True)
    (act): NewGELU()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [264]:
# Registry filled by the forward hooks: layer index -> detached block output.
activations = {}

In [270]:
def hook_layer(layer, store=None):
    """Build a forward hook that records a module's output under key ``layer``.

    Args:
        layer: Key under which the captured output is stored.
        store: Optional dict to write into. When omitted, the notebook-level
            ``activations`` dict is looked up each time the hook fires, so
            rebinding ``activations = {}`` between forward passes is honoured.

    Returns:
        A callable with the ``(module, input, output)`` signature expected by
        ``register_forward_hook``.
    """
    def hook_fn(self, input, output):
        sink = activations if store is None else store
        # detach() drops the autograd graph. The tensor is left on whatever device
        # the module produced it on, so the hook also works without a GPU.
        sink[layer] = output.detach()
    return hook_fn

In [271]:
# Attach one forward hook per transformer block, keyed by the block's index.
# Iterating over the ModuleList avoids hard-coding the number of layers.
# Each handle is kept so the hook can be removed later; re-running this cell
# without removing the old handles first registers a second set of hooks.
layer_hooks = [
    block.register_forward_hook(hook_layer(layer))
    for layer, block in enumerate(model.transformer['h'])
]

In [236]:
# Forward pass on the first test game with its final token dropped, as a batch of one.
# This is not wrapped in no_grad, so the returned logits carry a grad_fn; the second
# element of the returned tuple is None because no targets are passed.
model(test_data[0,:-1].unsqueeze(dim = 0).to('cuda'))

(tensor([[[-0.8208, -0.5618, -1.4155,  ..., -1.0816, -0.5452, -1.3990],
          [-1.8787, -0.7847, -0.5960,  ..., -1.1909, -0.8045, -2.6238],
          [-1.6215, -1.5087, -0.9066,  ..., -1.9895, -0.7249, -2.2398],
          ...,
          [ 2.2550, -0.3504,  1.9273,  ..., -1.5460, -0.7663,  5.0170],
          [-1.2137, -1.2391,  2.1715,  ...,  0.8388,  1.1657, -3.0377],
          [-0.1534,  1.6688,  0.9831,  ..., -1.4201,  0.2443, 11.7916]]],
        device='cuda:0', grad_fn=<UnsafeViewBackward0>),
 None)

In [237]:
# Display whatever the forward hooks have stored so far (one entry per hooked layer).
activations

{0: tensor([[[ 1.2135e+01, -2.7764e-01,  8.2318e-01,  ..., -1.9390e-01,
            1.0715e+00,  4.8527e-01],
          [ 6.5625e-01,  1.1775e-01,  2.3534e-01,  ...,  3.0556e-01,
           -1.2226e+00, -5.1176e-01],
          [ 1.7558e-01, -5.6305e-02, -7.5093e-01,  ...,  4.2980e-02,
            4.4480e-01,  2.0154e-01],
          ...,
          [ 1.8869e-01, -2.2275e-01, -3.4362e-02,  ...,  3.0657e-01,
           -1.6302e-01,  2.5002e-01],
          [-1.1167e-01, -4.2898e-02, -4.1120e-02,  ..., -5.8610e-01,
            9.6107e-02,  8.4802e-03],
          [ 1.9607e-01, -2.1042e-01, -1.5930e-01,  ..., -1.3228e-01,
           -7.7680e-02,  7.8192e-01]]], device='cuda:0'),
 1: tensor([[[ 1.3324e+01, -4.1872e-01,  9.6616e-01,  ..., -1.1323e-01,
            5.8864e-01,  5.2013e-01],
          [ 3.9493e-01,  1.2693e-01,  7.6263e-02,  ...,  1.8179e-01,
           -9.3505e-01, -6.0920e-01],
          [ 3.9375e-01, -2.2670e-01, -9.1336e-01,  ..., -7.8271e-02,
            3.3576e-01, -8.7583e-0

In [269]:
# Unregister every forward hook through its handle so that later forward passes
# no longer write into `activations`.
for hook in layer_hooks:
    hook.remove()

In [239]:
import torch.nn as nn
import torch.nn.functional as F

In [256]:
# Scratch check: view -> flatten of the last two dims -> reshape gives back the flat
# 0..19 vector, i.e. these reshapes preserve element order.
x = t.arange(20).view(2,2,5).flatten(-2,-1).reshape(20)

In [257]:
x

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])

In [298]:
class Probe(nn.Module):
    """Linear probe over concatenated per-layer activations.

    A single ``nn.Linear`` maps ``layers * in_features`` inputs to
    ``out_categories`` logits.
    """

    def __init__(self, layers, in_features, out_categories):
        super().__init__()
        self.layers = layers
        self.out_categories = out_categories
        probe_width = layers * in_features
        self.linlayer = nn.Linear(probe_width, out_categories)

    def forward(self, x, targets = None):
        """Apply the linear layer to ``x``.

        Returns the logits alone when ``targets`` is None. Otherwise returns
        ``(logits, loss)``, where ``loss`` is the cross-entropy between the
        logits flattened to ``(-1, out_categories)`` and the flattened targets.
        """
        logits = self.linlayer(x)
        if targets is None:
            return logits
        flat_logits = logits.view(-1, self.out_categories)
        return logits, F.cross_entropy(flat_logits, targets.flatten())

In [299]:
# Device string, and a probe whose input is 8 layers x 512 features with 2 output classes.
device = 'cuda'
current_player_probe = Probe(8, 512, 2)

In [261]:
# Parity labels 0,1,0,1,... over the 59 sequence positions, broadcast (without copying)
# to a fixed batch of 32 rows.
target = (t.arange(59) % 2).expand(32, 59)

In [311]:
# Train the current-player probe on frozen GPT activations.
# Requires the forward hooks to be registered on `model`: each forward pass must
# repopulate `activations`, otherwise the lookup below raises KeyError.
train_dataset = MyDataset('./data/moves.npy')
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    pin_memory=True,
    batch_size=32,
    num_workers=0,
)

optimizer = t.optim.Adam(current_player_probe.parameters())

pbar = tqdm(train_loader)
model.to('cuda')
# Make inference mode explicit here rather than relying on an earlier cell:
# with dropout active, the captured activations would be noisy.
model.eval()
current_player_probe.to('cuda')
target = target.to('cuda')

for x, y in pbar:
    x = x.to('cuda')
    y = y.to('cuda')
    optimizer.zero_grad()
    with t.no_grad():
        # Fresh dict for each batch; the hooks look up this global when they fire.
        activations = {}
        model(x,y)
    # Concatenate the per-layer outputs along the feature axis, in layer order.
    probe_inputs = t.cat(
        [activations[i] for i in range(current_player_probe.layers)], dim = 2
    )
    # Slice the fixed-size label template to this batch, so a short final batch
    # (dataset size not divisible by the batch size) does not break the loss.
    batch_target = target[:x.shape[0], :x.shape[1]]
    logits, loss = current_player_probe(probe_inputs, batch_target)
    loss.backward()
    optimizer.step()
    pbar.set_description(f'loss = {loss}')

print(loss)

logits.shape

loss = 3.1570254743762405e-10: 100%|██████████| 3125/3125 [49:07<00:00,  1.06it/s]

tensor(3.1570e-10, device='cuda:0', grad_fn=<NllLossBackward0>)





torch.Size([32, 59, 2])

In [301]:
# Print the probe's parameters (the linear layer's weight matrix, then its bias).
for param in current_player_probe.parameters():
    print(param)

Parameter containing:
tensor([[ 0.0710, -0.0560, -0.0245,  ..., -0.0017, -0.0121,  0.0246],
        [-0.0680,  0.0408,  0.0455,  ...,  0.0122,  0.0188, -0.0150]],
       device='cuda:0', requires_grad=True)
Parameter containing:
tensor([ 0.0186, -0.0184], device='cuda:0', requires_grad=True)


In [302]:
# Element-wise sum of the two weight rows (class 0 + class 1). Entries near zero
# indicate the two rows point in roughly opposite directions.
list(current_player_probe.parameters())[0].detach()[0]+list(current_player_probe.parameters())[0].detach()[1]

tensor([ 0.0030, -0.0152,  0.0210,  ...,  0.0105,  0.0067,  0.0096],
       device='cuda:0')

In [330]:
# Cosine similarity between the two class rows, restricted to slice 3 after reshaping
# each row to (8, 512). The probe input is the layer outputs concatenated in layer
# order, so slice 3 holds the weights applied to the layer-3 activations.
F.cosine_similarity(list(current_player_probe.parameters())[0].detach()[0].reshape(8,512)[3], list(current_player_probe.parameters())[0].detach()[1].reshape(8,512)[3], dim = 0)

tensor(-0.7917, device='cuda:0')

In [355]:
# Split the probe weight matrix (2 classes x 8*512 inputs) into one (8, 512) table per
# class: row 0 -> white_directions, row 1 -> black_directions. Entry [l] of each table
# is the slice of weights applied to layer l's activations.
white_directions, black_directions = (
    current_player_probe.linlayer.weight.detach().view(2, 8, 512)
)

In [342]:
# Shape of the captured layer-4 output.
activations[4].shape

torch.Size([32, 59, 512])

In [357]:
# Dot product of each position's layer-3 activation (first sample of the last batch)
# with the layer-3 slice of the class-1 probe row: one scalar per sequence position.
t.einsum('t c, c -> t',activations[3][0], black_directions[3])

tensor([-9.7992,  1.5291, -2.3716,  0.6193, -2.8316,  0.4949, -2.5077,  0.4332,
        -2.2310,  0.3860, -2.6180,  0.3776, -2.2986,  0.5967, -2.3090,  0.8789,
        -2.4481,  0.3369, -2.4326,  0.7508, -2.6387, -0.2188, -2.3764,  0.2636,
        -2.9054,  0.8000, -3.1111,  0.5353, -3.2226,  0.2312, -2.4706,  0.5090,
        -2.5921, -0.0259, -2.6566,  0.4443, -2.4886,  0.7498, -1.9179,  1.1109,
        -1.8992,  0.4646, -2.6522,  0.6891, -2.9005,  0.6179, -2.1779,  0.5044,
        -2.3521,  0.2104, -2.5466,  0.6627, -3.3349,  0.3953, -2.6592, -0.5509,
        -2.9354, -0.4950, -2.4276], device='cuda:0')

In [337]:
# Persist the trained probe weights so later sessions can skip retraining.
t.save(current_player_probe.state_dict(), './data/current_player_probe_state_dict.pth')

In [16]:
# Restore probe weights from disk; `current_player_probe` must already be constructed.
current_player_probe.load_state_dict(t.load('./data/current_player_probe_state_dict.pth'))

<All keys matched successfully>

In [213]:
from mingpt.model import Intervention_GPT

In [18]:
# Row 0 of the probe weight matrix: the full row across all concatenated layers,
# not a per-layer slice.
# NOTE(review): another cell in this notebook binds row 0 to `white_directions` and
# row 1 to `black_directions`; confirm which class index corresponds to which player.
black_direction = list(current_player_probe.parameters())[0].detach()[0]

In [19]:
# Unit-length version of `black_direction` (divide by its Euclidean norm).
b = black_direction / (black_direction * black_direction).sum().sqrt()

In [13]:
# Sanity check: the squared norm of `b` should be 1.
(b*b).sum()

tensor(1.)

In [29]:
# Tensor.to returns a new tensor, so the result is rebound to `b`.
b = b.to('cuda')

In [38]:
b.device

device(type='cuda', index=0)

In [20]:
def intervention(input, position=15, direction=None, verbose=False):
    """Flip the component of one sequence position along a direction.

    For every sample in the batch, the activation at ``position`` is replaced by
    ``x - 2 * (x . d) * d``. If ``d`` has unit norm this negates the component
    of ``x`` along ``d`` and leaves the orthogonal part unchanged.

    Args:
        input: Rank-3 tensor indexed as ``input[:, position, :]``; its last axis
            must have the same length as ``direction``.
        position: Index along the second axis to edit (default 15).
        direction: 1-D direction vector, assumed to be unit-norm. Defaults to
            the notebook-level ``b``.
        verbose: When True, print the per-sample dot products with the
            direction before and after the edit.

    Returns:
        A new tensor with the same shape as ``input``; ``input`` itself is not
        modified.
    """
    d = b if direction is None else direction
    selected = input[:, position, :]
    # keepdim keeps the dot product aligned with the batch axis, shape (batch, 1),
    # so each sample is edited with its own projection rather than sample 0's.
    along = (selected * d).sum(dim=-1, keepdim=True)
    edited = input.clone()
    edited[:, position, :] = selected - 2 * along * d
    if verbose:
        after = (edited[:, position, :] * d).sum(dim=-1, keepdim=True)
        print(f'b dot input before = {along}')
        print(f'b dot input after = {after}')
    return edited

In [42]:
def intervention_einsum(input):
    """Placeholder intervention: print two shapes, then return ``input`` times zero.

    Prints the shape of ``input`` and of the notebook-level ``b``. This name is
    defined a second time later in the notebook; the later definition replaces
    this one once its cell has run.
    """
    for shape in (input.shape, b.shape):
        print(shape)

    return input * 0

In [109]:
# Scratch tensor. NOTE(review): this rebinds the notebook-level name `x`.
x = t.arange(7)
x

tensor([0, 1, 2, 3, 4, 5, 6])

In [None]:
# NOTE(review): this binds the `torch.tensor` function itself (it is never called);
# looks like an unfinished scratch cell.
y = t.tensor

In [43]:
# Build an Intervention_GPT from the same config and checkpoint as `model`, passing
# `intervention_einsum` as its callback.
# The meaning of the second argument (None here) is defined in mingpt.model and is
# not visible in this notebook - TODO confirm.
intervention_model = Intervention_GPT(model_config, None, intervention_einsum)
intervention_model.load_state_dict(t.load('./weights/model_state_dict_large_training_2_02.pth'))
intervention_model.to('cuda')

number of parameters: 25.28M


Intervention_GPT(
  (transformer): ModuleDict(
    (wte): Embedding(61, 512)
    (wpe): Embedding(59, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=512, out_features=1536, bias=True)
          (c_proj): Linear(in_features=512, out_features=512, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): ModuleDict(
          (c_fc): Linear(in_features=512, out_features=2048, bias=True)
          (c_proj): Linear(in_features=2048, out_features=512, bias=True)
          (act): NewGELU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): L

In [23]:
# Test games loaded with the dtype stored in the .npy file (no dtype conversion).
test_moves = t.tensor(np.load('./data/test_moves.npy'))

In [54]:
# Second, independent load of the same file, used as a reference copy.
test_moves_clean = t.tensor(np.load('./data/test_moves.npy'))

In [56]:
# True when `test_moves` still matches the freshly loaded copy element for element.
t.all(test_moves == test_moves_clean)

tensor(True)

In [205]:
# Call the project helper on test game 15 with second argument 5.
# presumably the board after 5 moves; verify against othello.Othello.board_state.
Othello.board_state(test_moves[15],5)

In [369]:
# Normalise each layer's direction to unit length: the einsum computes the squared
# norm per layer (row), and unsqueeze makes it broadcast across the feature axis.
black_hat = black_directions / t.sqrt(t.einsum('l c, l c -> l', black_directions, black_directions).unsqueeze(dim = 1))
white_hat = white_directions / t.sqrt(t.einsum('l c, l c -> l', white_directions, white_directions).unsqueeze(dim = 1))

In [402]:
# Index into black_hat / white_hat selecting which layer's unit directions are used.
layer = 3

def intervention_einsum(input):
    """Add a fixed offset along the probe directions at one sequence position.

    Only sample 0, position 4 is edited: ``-3 * white_hat[layer] +
    10 * black_hat[layer]`` is added to its activation. The components of that
    activation along both unit directions are printed before and after the
    edit. Reads the notebook-level ``layer``, ``black_hat`` and ``white_hat``.

    Returns a new tensor; ``input`` is left unmodified.
    """
    token_pos = 4
    black_dir = black_hat[layer]
    white_dir = white_hat[layer]

    def component(acts, direction):
        # Projection of sample 0's activation at `token_pos` onto `direction`.
        return t.einsum("t c, c -> t", acts[0], direction)[token_pos]

    print(f'input shape is {input.shape}')
    print(f'black component before is {component(input, black_dir)}')
    print(f'white component before is {component(input, white_dir)}')

    template = t.zeros_like(input)
    template[0, token_pos, :] = 1*(-3*white_dir + 10*black_dir)
    patched = input + template

    print(f'black component after is {component(patched, black_dir)}')
    print(f'white component after is {component(patched, white_dir)}')
    return patched


# Fresh Intervention_GPT using the callback defined above. The second argument (4) is
# interpreted by mingpt.model - presumably the layer at which the callback is applied;
# TODO confirm.
intervention_model = Intervention_GPT(model_config, 4, intervention_einsum)
intervention_model.load_state_dict(t.load('./weights/model_state_dict_large_training_2_02.pth'))
intervention_model.to('cuda').eval()

# Logits of test game 15 (final token dropped) at sequence index 5 - 1, mapped by the
# project helper Othello.tokens_to_ij and shown as a heat map.
fig = px.imshow(
    Othello.tokens_to_ij(
        intervention_model(test_data[15,:-1].unsqueeze(dim=0).to(device = 'cuda'))[0][0,5 - 1,:]
    ).detach().numpy()
)

# Small square figure with no margins.
fig.update_layout(width=400, height=400, margin=dict(l=0,r=0,b=0,t=0))

fig.show()

number of parameters: 25.28M
input shape is torch.Size([1, 59, 512])
black component before is -2.7800729274749756
white component before is 0.7711783647537231
black component after is 9.594934463500977
white component after is -10.145515441894531


In [179]:
# Logits produced by the intervention model for test game 15 at sequence index 6 - 1.
intervention_model(test_data[15,:-1].unsqueeze(dim=0).to(device = 'cuda'))[0][0,6 - 1,:]

torch.Size([1, 59, 512])
torch.Size([512])


tensor([ 3.3138, -0.9150,  1.8627, -1.9997, -1.5406,  3.2887,  1.1535, -1.6492,
         0.8260, -0.3423,  0.7529, -2.0722,  0.2037, -0.7998, -0.4436, -2.6264,
        -1.2772, -0.0347, -2.8659, -0.5119, -0.0630, -2.0859,  2.3057, -2.6996,
        -1.1755,  1.2278, -2.6508,  9.6554,  2.2354, -1.4875, -0.8964, -0.2006,
        -1.8900,  2.7460,  2.4992, -1.1967, -1.2600,  3.1163,  2.7699,  7.6051,
         2.8035, 10.7764,  1.2490, -1.7645, -0.0422, -2.1621, -0.2611,  1.0203,
        -1.4473, -3.4068, -0.2482, -0.2082, -1.6040,  0.5504, -1.4883,  2.8089,
        -1.6981, -3.6737, -2.1738, -1.1330, -4.6043], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [85]:
# Dump every parameter tensor of the intervention model.
# NOTE(review): very large output; consider inspecting shapes or a single parameter instead.
list(intervention_model.parameters())

[Parameter containing:
 tensor([[ 0.0834, -0.0524,  0.0030,  ..., -0.1855, -0.0207, -0.1003],
         [-0.0584, -0.0330,  0.1053,  ...,  0.2168,  0.0822, -0.0995],
         [ 0.0433,  0.0641, -0.3092,  ..., -0.0396,  0.0213, -0.0739],
         ...,
         [ 0.0250, -0.0779,  0.4678,  ...,  0.3177, -0.0897, -0.0996],
         [-0.0255, -0.2313,  0.0761,  ..., -0.0997, -0.0171, -0.1082],
         [ 0.1741,  0.1394,  0.0229,  ...,  0.3320, -0.5934,  0.0196]],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([[ 8.4453e+00,  2.8448e-03,  6.9660e-02,  ...,  6.0375e-02,
           1.4339e+00,  2.4229e-01],
         [ 6.5520e-01, -3.3817e-02, -1.7010e-01,  ..., -1.2982e-01,
          -1.5948e+00, -4.2678e-01],
         [-1.2839e-01, -8.2243e-03, -9.3632e-02,  ..., -3.6962e-02,
           2.6830e-01,  7.6189e-02],
         ...,
         [-2.5456e-02, -1.0982e-02,  4.3018e-03,  ...,  4.9805e-02,
          -1.3216e-01,  4.6654e-01],
         [-8.0499e-02, -1.6650e-0

In [88]:
# Same parameter dump as an earlier cell in this notebook; the displayed leading
# values are identical in both outputs.
list(intervention_model.parameters())

[Parameter containing:
 tensor([[ 0.0834, -0.0524,  0.0030,  ..., -0.1855, -0.0207, -0.1003],
         [-0.0584, -0.0330,  0.1053,  ...,  0.2168,  0.0822, -0.0995],
         [ 0.0433,  0.0641, -0.3092,  ..., -0.0396,  0.0213, -0.0739],
         ...,
         [ 0.0250, -0.0779,  0.4678,  ...,  0.3177, -0.0897, -0.0996],
         [-0.0255, -0.2313,  0.0761,  ..., -0.0997, -0.0171, -0.1082],
         [ 0.1741,  0.1394,  0.0229,  ...,  0.3320, -0.5934,  0.0196]],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([[ 8.4453e+00,  2.8448e-03,  6.9660e-02,  ...,  6.0375e-02,
           1.4339e+00,  2.4229e-01],
         [ 6.5520e-01, -3.3817e-02, -1.7010e-01,  ..., -1.2982e-01,
          -1.5948e+00, -4.2678e-01],
         [-1.2839e-01, -8.2243e-03, -9.3632e-02,  ..., -3.6962e-02,
           2.6830e-01,  7.6189e-02],
         ...,
         [-2.5456e-02, -1.0982e-02,  4.3018e-03,  ...,  4.9805e-02,
          -1.3216e-01,  4.6654e-01],
         [-8.0499e-02, -1.6650e-0

In [77]:
# Token sequence of test game 15 without its final token (the model input used above).
test_data[15,:-1]

tensor([33, 41, 19, 20, 21, 27, 40, 32, 26, 12, 49, 50, 13, 25, 17, 57,  5, 11,
        18,  8, 48, 34, 24, 31, 39, 47, 58,  6, 36,  4, 55, 10, 16, 38, 45, 54,
         2,  9, 51, 59,  1, 30, 29, 43, 56, 42, 37,  3, 28, 35, 14, 15, 23, 22,
         7, 53, 44, 52, 46])

In [26]:
# NOTE(review): Tensor.to returns a new tensor and the result is not assigned here,
# so `test_moves` itself stays on its original device; this cell only displays a GPU copy.
test_moves.to('cuda')

tensor([[40, 41, 26,  ..., 44,  5, 60],
        [33, 41, 26,  ..., 45, 52, 36],
        [26, 20, 21,  ...,  3, 56, 44],
        ...,
        [33, 39, 18,  ..., 29,  2, 23],
        [26, 18, 10,  ...,  5, 43, 53],
        [33, 41, 26,  ..., 52, 57, 60]], device='cuda:0', dtype=torch.int32)

In [27]:
# NOTE(review): the result is not assigned, so `b` is unchanged by this cell; a
# separate cell rebinds it with `b = b.to('cuda')`.
b.to('cuda')

tensor([-1.7305e-02,  9.6176e-03, -2.9521e-02, -4.9176e-02, -1.8288e-02,
        -5.8259e-02, -2.1841e-02,  1.8219e-03, -1.4584e-02, -5.6416e-02,
        -6.9991e-02, -2.8398e-02, -1.8896e-02, -1.5648e-02, -2.5691e-02,
        -2.2644e-02, -2.9278e-02, -3.0237e-02, -1.2285e-02,  1.1931e-02,
        -1.0810e-02, -3.1307e-02, -1.2830e-02, -3.2311e-02, -2.8417e-02,
        -2.7700e-02, -7.6280e-03,  1.0053e-01, -3.7349e-02, -1.9853e-02,
        -4.5630e-02, -3.8352e-02,  9.8387e-03, -1.3102e-02, -5.6305e-02,
        -1.6317e-02, -4.5630e-02, -2.4721e-02,  4.0237e-03, -1.0709e-02,
        -1.2749e-02,  2.7516e-02, -2.9780e-02, -5.6682e-02, -1.6142e-02,
        -7.2649e-03, -2.7361e-02, -3.9503e-02, -2.8272e-02,  1.7102e-02,
        -1.7522e-02, -2.5213e-02, -1.0411e-02, -2.4180e-02,  1.6439e-03,
        -1.8348e-02,  1.3425e-02, -1.4769e-02, -4.1878e-02, -8.7471e-03,
        -2.8541e-02, -1.0034e-02, -2.3373e-02, -2.7516e-02, -5.3712e-02,
        -3.9535e-02, -2.7202e-02,  3.3971e-01, -1.2