In [50]:
from datasets import load_dataset, load_from_disk
import torch as t
import numpy as np

from othello_world.mechanistic_interpretability.mech_interp_othello_utils import (
    plot_board,
    plot_single_board,
    plot_board_log_probs,
    to_string,
    to_int,
    int_to_label,
    string_to_label,
    OthelloBoardState
)

In [51]:
# Load the first 1000 games of the token dataset and build a second encoding of the
# same games via `to_string` (helper from mech_interp_othello_utils).
token_dataset = load_from_disk("dummy_othellogpt_dataset")['tokens'][:1000]
token_dataset_string = t.tensor(to_string(token_dataset))

# Sanity-check both encodings (previously the first shape was printed twice,
# so the string-encoded tensor was never inspected).
print(token_dataset.shape)
print(token_dataset_string.shape)


# Number of games to analyse and index of the feature under study.
num_games = 100
feature = 88

focus_games_int = token_dataset[:num_games]
focus_games_string = token_dataset_string[:num_games]

print(focus_games_int.shape)
print(focus_games_string.shape)

# Keep only the activations of the chosen feature for the focus games.
# Note: the saved activations have one fewer position per game than the token
# tensors (59 vs 60 in the recorded output below).
activations = t.load("saved_feat_acts/all_feat_acts_0-99.pt")[:num_games, :, feature]
print(activations.shape)

torch.Size([1000, 60])
torch.Size([1000, 60])
torch.Size([100, 60])
torch.Size([100, 60])
torch.Size([100, 59])


In [52]:
def one_hot(list_of_ints, num_classes=64):
    '''
    Returns a float32 vector of length `num_classes` that is 1 at every index listed in
    `list_of_ints` and 0 everywhere else (a multi-hot encoding of the given indices).
    '''
    encoding = t.zeros(num_classes, dtype=t.float32)
    encoding[list_of_ints] = 1.0
    return encoding

# Replay every focus game move by move, recording the board state and the set of
# valid moves after each move.
# The game length is taken from the data instead of being hard-coded in two places.
num_positions = focus_games_string.shape[1]

focus_states = np.zeros((num_games, num_positions, 8, 8), dtype=np.float32)
focus_valid_moves = t.zeros((num_games, num_positions, 64), dtype=t.float32)


for i in range(num_games):
    board = OthelloBoardState()
    # NOTE(review): the final move of each game is not replayed, so the last slot of
    # `focus_states` / `focus_valid_moves` stays all-zero. Later cells drop that slot
    # with `[:, :-1]`; confirm this is intended before using the last position.
    for j in range(num_positions - 1):
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())


print("focus states:", focus_states.shape)
print("focus_valid_moves", tuple(focus_valid_moves.shape))

focus states: (100, 60, 8, 8)
focus_valid_moves (100, 60, 64)


In [53]:
rows = 8
cols = 8

def state_stack_to_one_hot(state_stack):
    '''
    Creates an int tensor of shape (games, moves, rows, cols, options=3), where the [g, m, r, c, :]-th
    entry is a one-hot encoded vector for the state of game g at move m, at row r and column c. In other
    words, this vector equals (1, 0, 0) when the state is empty (0), (0, 1, 0) when the state is "their"
    (-1), and (0, 0, 1) when the state is "my" (+1).

    The board dimensions are read from `state_stack` itself, so the function does not depend on the
    notebook-level `rows` / `cols` variables. The result lives on the same device as `state_stack`.
    '''
    encoded = t.zeros(
        (*state_stack.shape, 3), # (games, moves, rows, cols) + the options: empty, their, my
        device=state_stack.device,
        dtype=t.int,
    )
    # Boolean masks are cast to int on assignment.
    encoded[..., 0] = state_stack == 0
    encoded[..., 1] = state_stack == -1
    encoded[..., 2] = state_stack == 1

    return encoded

# Per-position sign: -1 at even move indices, +1 at odd ones. Multiplying the boards by it
# re-expresses them as my (+1) / their (-1) instead of black / white.
alternating = np.array([1 if pos % 2 else -1 for pos in range(focus_games_int.shape[1])])
flipped_focus_states = focus_states * alternating.reshape(1, -1, 1, 1)

# One-hot encode each square (empty / their / my) ...
focus_states_flipped_one_hot = state_stack_to_one_hot(t.tensor(flipped_focus_states))

# ... and collapse back to the index of the active option (0 = empty, 1 = their, 2 = my)
focus_states_flipped_value = focus_states_flipped_one_hot.argmax(dim=-1)
print(focus_states_flipped_value.shape)

torch.Size([100, 60, 8, 8])


In [54]:
# Mask of the (game, position) pairs whose activation lies strictly above the
# 0.99 quantile of all activations for this feature.
activation_cutoff = activations.quantile(0.99)
top_moves = activations > activation_cutoff
n_moves = int(top_moves.sum())
print('number of moves:', n_moves)
print(top_moves.shape)

number of moves: 31
torch.Size([100, 59])


In [55]:
from plotly_utils import imshow
alpha = "ABCDEFGH"

def plot_square_as_board(state, filename, diverging_scale=True, **kwargs):
    '''
    Plots a square (8 by 8) input as a board, with rows labelled A-H and columns labelled 0-7.
    A stack of boards can be plotted by passing facet_col=0.

    With diverging_scale=True the "RdBu" colour scale centred on 0 is used, otherwise "Blues"
    with no midpoint. Any extra keyword arguments are forwarded to `imshow` and take precedence
    over these defaults.
    '''
    if diverging_scale:
        colour_scale, midpoint = "RdBu", 0.
    else:
        colour_scale, midpoint = "Blues", None

    board_defaults = dict(
        y=list(alpha),
        x=[str(col) for col in range(8)],
        color_continuous_scale=colour_scale,
        color_continuous_midpoint=midpoint,
        aspect="equal",
    )
    # Caller-supplied options win over the defaults above.
    imshow(state, filename=filename, **{**board_defaults, **kwargs})

In [56]:
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
layer = 5

focus_states_flipped_value = focus_states_flipped_value.to(device)

# For each option, in the order 2 ("my"), 1 ("their"), 0 (empty): drop the last position so the
# boards line up with `top_moves`, select the top-activating moves, and average over them.
# Each slice is therefore the fraction of top moves in which a square holds that option.
board_state_at_top_moves = t.stack([
    (focus_states_flipped_value == option)[:, :-1][top_moves].float().mean(0)
    for option in (2, 1, 0)
])

print(board_state_at_top_moves.shape)

# plot_square_as_board(
#     board_state_at_top_moves,
#     filename=f"top_{n_moves}_moves_layer{layer}_feature{feature}", 
#     facet_col=0,
#     facet_labels=["Mine", "Theirs", "Blank"],
#     title=f"Aggregated top {n_moves} moves for neuron L{layer}F{feature}", 
# )

torch.Size([3, 8, 8])


In [57]:
# Count the (option, square) entries whose mean over the top moves exceeds 0.75
board_state_cutoff = 0.75
is_board_state_feature = board_state_at_top_moves > board_state_cutoff
print(is_board_state_feature.sum())

tensor(1, device='cuda:0')


In [58]:
# plot_square_as_board(
#     is_board_state_feature,
#     filename=f"top_{n_moves}_moves_layer{layer}_feature{feature}", 
#     facet_col=0,
#     facet_labels=["Mine", "Theirs", "Blank"],
#     title=f"Aggregated top {n_moves} moves for neuron L{layer}F{feature}", 
# )