In [13]:
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from scipy.sparse.linalg import eigsh
from torch import nn


def make_layer_accessor(path: str):
    """Return a callable that fetches the ``weight`` of the submodule at ``path``.

    ``path`` is a dot-separated chain of attribute names (numeric components such
    as ``"0"`` are resolved with ``getattr`` too). The chain is walked from the
    model passed to the returned callable each time it is invoked.
    """

    def accessor(model: nn.Module) -> torch.Tensor:
        # Reverse the components so pop() yields them in left-to-right order.
        pending = path.split(".")[::-1]
        node = model
        while pending:
            node = getattr(node, pending.pop())
        return node.weight

    return accessor


# TODO: Check that this yields the correct weights
def make_head_accessor(attn_layer: str, head: int, num_heads=2, embed_dim=4, head_size: Optional[int] = None):
    """Make a function to access the weights of a particular attention head.

    The returned callable reads ``<attn_layer>.weight`` from a model and slices
    its first dimension into consecutive chunks of ``3 * head_size * embed_dim``
    entries, returning chunk number ``head``.

    Args:
        attn_layer: Dot-separated path to the attention submodule.
        head: Index of the chunk (head) to return.
        num_heads: Only used to derive ``head_size`` when it is not given.
        embed_dim: Embedding width; used for ``head_size`` and the chunk length.
        head_size: Per-head width; defaults to ``embed_dim // num_heads``.
    """
    if head_size is None:
        head_size = embed_dim // num_heads
    # NOTE(review): this length counts *parameters* (q, k, v for one head) but is
    # applied as a slice along dim 0 of the weight. If the weight is a 2-D fused
    # QKV matrix, this selects rows rather than one head's parameters — confirm
    # the weight layout (see the TODO above).
    span = 3 * head_size * embed_dim

    def accessor(model: nn.Module) -> torch.Tensor:
        weight = make_layer_accessor(attn_layer)(model)
        start = head * span
        return weight[start:(head + 1) * span]

    return accessor


def _head_loc(l: int, h: int):
    return f"{1 + l}.1.{h}:block_{l}/head_{h}"

def _ln_loc(l: int, i: int):
    return f"{1 + l}.{2 * i}:block_{l}/ln_{i}"

def _mlp_loc(l: int, i: int):
    return f"{1 + l}.3.{i}:block_{l}/mlp_{i}"

def make_transformer_accessors(L=2, H=4):
    """Build weight accessors for the components of ``token_sequence_transformer``.

    Each key has the form ``"<position>:<label>"``. ``<position>`` is a
    dot-separated run of integers:

    * ``0.*`` for the embeddings;
    * ``(1 + l).*`` for block ``l`` (layer norms, attention heads, MLP layers);
    * ``(1 + L).*`` for the unembedding.

    Each value is a callable that takes the model and returns the corresponding
    weight tensor.

    Args:
        L: Number of transformer blocks (must be less than 10).
        H: Number of attention heads per block.

    Returns:
        Dict of location key -> accessor, ordered by numeric position.
    """
    assert L < 10, "L must be less than 10"

    root = "token_sequence_transformer"
    accessors = {}

    for l in range(L):
        block = f"{root}.blocks.{l}"
        # NOTE(review): head accessors use make_head_accessor's default
        # num_heads/embed_dim, which do not depend on H — confirm they match
        # the model being inspected.
        for h in range(H):
            accessors[_head_loc(l, h)] = make_head_accessor(f"{block}.attention.attention", h)
        for i in range(2):
            accessors[_mlp_loc(l, i)] = make_layer_accessor(f"{block}.compute.{2 * i}")
            accessors[_ln_loc(l, i)] = make_layer_accessor(f"{block}.attention.layer_norms.{i}")

    accessors.update({
        "0.0:embed/token": make_layer_accessor(f"{root}.token_embedding"),
        "0.1:embed/pos": make_layer_accessor(f"{root}.postn_embedding"),
        f"{1 + L}.0:unembed/ln": make_layer_accessor(f"{root}.unembedding.0"),
        f"{1 + L}.1:unembed/linear": make_layer_accessor(f"{root}.unembedding.1"),
    })

    def position(item):
        # "1.1.3:block_0/head_3" -> (1, 1, 3). Sorting on the raw string
        # misorders multi-digit components. With L=9, the unembedding's "10.*"
        # keys would land between block 0 and block 1. With H >= 10,
        # "1.1.10" would sort before "1.1.2".
        return tuple(int(part) for part in item[0].split(":", 1)[0].split("."))

    return dict(sorted(accessors.items(), key=position))

def make_transformer_accessors_and_interactions(L=2, H=4):
    """Build the weight accessors plus the pairs of components to compare.

    Args:
        L: Number of transformer blocks.
        H: Number of attention heads per block.

    Returns:
        ``(accessors, paths)``. ``accessors`` is ``make_transformer_accessors(L, H)``.
        ``paths`` maps a name to a ``(key_a, key_b)`` pair of accessor keys and
        contains:

        * every accessor paired with itself, named by the accessor key;
        * every head of block ``l`` paired with every head of block ``l + 1``;
        * the two MLP layers of each block paired with each other;
        * the token embedding paired with the unembedding linear layer.

        Cross pairs are named ``"<key_a>-<key_b>"``.
    """
    accessors = make_transformer_accessors(L, H)

    # We want all within-layer covariances
    paths = {key: (key, key) for key in accessors}

    def add_pair(a, b):
        paths[f"{a}-{b}"] = (a, b)

    # Between-head covariances for successive layers (no pairs when L < 2).
    for l in range(L - 1):
        for h1 in range(H):
            for h2 in range(H):
                add_pair(_head_loc(l, h1), _head_loc(l + 1, h2))

    # Between-MLP covariances within a single block: mlp_0 vs mlp_1.
    # Each MLP layer paired with itself is already covered above.
    for l in range(L):
        add_pair(_mlp_loc(l, 0), _mlp_loc(l, 1))

    # And between the token embedding and the unembedding.
    add_pair("0.0:embed/token", f"{1 + L}.1:unembed/linear")

    return accessors, paths

# Inspect the accessors and covariance paths for a 2-block, 4-head configuration;
# as the cell's last expression, its returned (accessors, paths) tuple is displayed.
make_transformer_accessors_and_interactions(2, 4)

({'0.0:embed/token': <function __main__.make_layer_accessor.<locals>.accessor(model: torch.nn.modules.module.Module) -> torch.Tensor>,
  '0.1:embed/pos': <function __main__.make_layer_accessor.<locals>.accessor(model: torch.nn.modules.module.Module) -> torch.Tensor>,
  '1.0:block_0/ln_0': <function __main__.make_layer_accessor.<locals>.accessor(model: torch.nn.modules.module.Module) -> torch.Tensor>,
  '1.1.0:block_0/head_0': <function __main__.make_head_accessor.<locals>.accessor(model: torch.nn.modules.module.Module) -> torch.Tensor>,
  '1.1.1:block_0/head_1': <function __main__.make_head_accessor.<locals>.accessor(model: torch.nn.modules.module.Module) -> torch.Tensor>,
  '1.1.2:block_0/head_2': <function __main__.make_head_accessor.<locals>.accessor(model: torch.nn.modules.module.Module) -> torch.Tensor>,
  '1.1.3:block_0/head_3': <function __main__.make_head_accessor.<locals>.accessor(model: torch.nn.modules.module.Module) -> torch.Tensor>,
  '1.2:block_0/ln_1': <function __main__