In [1]:
import torch
from transformer import Transformer
import matplotlib.pyplot as plt
import os
import math
import re
dropout = 0
N = 20
hidden_dim = 2
proj_dim = 10
output_dim = 1
num_heads = 1
num_layers = 1
ff_dim = 128
ln_eps = 1e-5
ln = False
device = "cpu"


model = Transformer(dropout, N, hidden_dim, proj_dim, output_dim, num_heads, num_layers, ff_dim, ln_eps, device, ln)
print(model)

x = torch.randint(0, 2**N, (1000,)).to(device)
y = model(x)
y

Transformer(
  (embeddings): Embedding(2, 2)
  (transformer): AttentionBlock(
    (attn): CustomMHA(
      (out_proj): None
    )
    (linear): Sequential(
      (0): Linear(in_features=1, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
torch.Size([1000, 20, 20])
torch.Size([1000, 20, 10])
torch.Size([20, 10])
tensor([[-9.3919e-02, -1.1117e-01, -3.4433e-01, -5.8144e-02, -1.0872e-01,
          4.4123e-01,  9.0699e-02,  5.7264e-02,  1.9023e-01,  3.9258e-01],
        [-5.8898e-01,  2.9353e-01,  4.3331e-01, -1.1082e-01,  1.3656e-01,
         -1.5010e-01,  3.0561e-01, -4.4309e-01,  4.1347e-02, -1.1216e-01],
        [-2.3517e-01, -6.6879e-01, -1.5947e-01,  3.0832e-01,  1.6363e-01,
          3.4300e-01,  2.1560e-01,  5.3285e-01, -9.1335e-02,  3.8381e-02],
        [ 2.3413e-01, -5.6890e-01,  1.9724e-02, -7.5098e-02,  3.1639e-01,
         -1.2160e-01,  3.2292e-01, -3.8569e-02, -2.8024e-01,  3.0186e-01],
        [ 3.2589e-0

tensor([-4.9281, -5.0303, -4.9098, -4.9026, -5.0359, -4.9043, -5.0497, -4.9406,
        -5.0159, -4.9145, -5.0401, -4.9940, -4.8554, -4.8585, -4.9175, -4.9612,
        -4.9009, -4.9656, -4.9554, -4.9410, -4.8952, -4.9150, -4.9184, -5.0665,
        -4.9334, -5.0048, -5.0011, -4.9922, -5.0232, -4.8730, -5.0273, -4.9840,
        -5.0823, -5.0073, -4.9392, -4.9891, -4.9025, -4.9391, -4.9576, -5.0216,
        -4.9641, -4.9562, -4.9429, -4.8206, -4.9390, -4.8388, -4.9670, -4.9294,
        -4.8946, -4.9320, -4.9794, -4.8745, -4.8661, -5.0393, -4.9067, -5.0232,
        -5.0391, -4.9670, -4.9484, -4.9236, -4.9175, -4.9453, -4.9437, -5.0408,
        -4.9691, -4.7712, -5.0239, -4.8501, -4.9772, -4.8968, -5.0026, -4.9186,
        -4.9661, -5.1696, -4.9472, -4.9687, -4.9928, -4.9304, -4.9844, -4.9757,
        -4.9530, -4.9564, -4.9445, -4.8696, -4.8265, -4.9226, -5.0096, -5.0189,
        -4.9741, -4.9873, -4.9430, -4.9538, -4.8164, -4.8701, -5.0304, -4.9130,
        -4.9459, -4.9172, -4.9605, -4.91

In [63]:
import plotly.express as px
import plotly.subplots as sp
import plotly.graph_objects as go
import pandas as pd

# l = ["HYPERPARAM_TESTS_MECHINTERP2/degree-3/width-3/func-1/",
# "HYPERPARAM_TESTS_MECHINTERP2/degree-3/width-5/func-2/",
# "HYPERPARAM_TESTS_MECHINTERP2/degree-3/width-4/func-1/"]

folder = "MINTERP"


for root, dirs, files in os.walk(f"{folder}"):
    if not files: continue
    epochs = [x for x in files if "epoch-" in x]
    if not epochs: continue
    new_epochs = []
    for epoch in epochs:
        try:
            model.load_state_dict(torch.load(f"{root}/{epoch}", map_location=device))
            new_epochs.append(epoch)
        except Exception as e:
            print(e)
            continue

    epochs = new_epochs
    if not epochs: continue
    epochs.sort(key=lambda x: int(re.search(r"epoch-(\d+).pt$", x).group(1)))

    # Load coeffs and combs
    pattern = rf"{folder}/degree-(\d+)/width-(\d+)/func-(\d+)"
    comb_path = rf"{folder}/combs_func\3_deg\1_width\2.pt"
    coeff_path = rf"{folder}/coefs_func\3_deg\1_width\2.pt"

    combs_path = re.sub(pattern, comb_path, root)
    coeff_path = re.sub(pattern, coeff_path, root)

    combs = torch.load(combs_path, map_location=device)
    coeffs = torch.load(coeff_path, map_location=device)
    

    # Create string representation of the coefficients
    s = ""
    for (i, c) in enumerate(coeffs):
        s += f"{c:.3F}({combs[i].tolist()})+"

    # Take final epoch
    epochs = [epochs[-1]]

 
    with torch.no_grad():
        for (i, epoch) in enumerate(epochs):
            # Load model
            ckpt_point = f"{root}/{epoch}"
            print(s[:-1])

            print(ckpt_point)
            model.load_state_dict(torch.load(ckpt_point, map_location=device))
            model.eval()

            in_proj = model.transformer.attn.in_proj_weight.detach().T
            lin1 = model.transformer.linear[0].weight.detach()
            lin1b = model.transformer.linear[0].bias.detach()
            lin2 = model.transformer.linear[2].weight.detach()
            lin2b = model.transformer.linear[2].bias.detach()
            out_proj = model.output_proj.detach()
            embed_dim = dim + N 
            w_q, w_k, w_v = in_proj[:, :embed_dim], in_proj[:, embed_dim:2*embed_dim], in_proj[:, 2*embed_dim:].T
            w_qk = w_q @ w_k.T
            # Use the indices in combs to set the values in actual_combs to one
            actual_combs = torch.ones(combs.shape[0], N+2)*torch.min(w_qk)
            for j in range(combs.shape[0]):
                actual_combs[j, combs[j]] = torch.max(w_qk)



            # Vector containing all multiples of 1/D_f
            D_f = 1/1000
            inputs = torch.tensor([D_f * i - 1 for i in range(1, 2*int(1/D_f)+1)]).unsqueeze(1).to(device)
            print(inputs.shape)
            
            # Create inputs ranging several orders of magnitude
            # inputs = torch.exp(torch.tensor([(D_f * i - 2) for i in range(1, 4*int(1/D_f)+1)]).unsqueeze(1).to(device))-3


            outputs0 = model.transformer.linear[0](inputs)
            outputs1 = model.transformer.linear[1](outputs0)
            output1 = outputs1[outputs1.sum(dim=1) != 0]
            outputs2 = model.transformer.linear[2](outputs1)



            fig = sp.make_subplots(rows=len(epochs), cols=2, subplot_titles=("QK.T","F(relu(MX+G))+b"), shared_yaxes=False, horizontal_spacing=0.10)
            colorscale = "PiYG"
            color = px.colors.sequential.Cividis_r
            fig.add_trace(go.Heatmap(z=torch.concat((actual_combs, w_qk), axis=0), y=list(range(-actual_combs.shape[0], w_qk.shape[0])), coloraxis='coloraxis2', showscale=True), row=i+1, col=1)
            # fig.add_trace(go.Heatmap(z=outputs0, coloraxis='coloraxis2'), row=i+1, col=2)
            # fig.add_trace(go.Heatmap(z=outputs1, coloraxis='coloraxis3'), row=i+1, col=3)
            # Plot output2 as a line
            fig.add_trace(
                go.Scatter(x=inputs.squeeze().numpy(), y=outputs2.squeeze().tolist(), mode='lines', line=dict(color='blue')), 
                row=i+1, col=2
            )
            fig.add_hline(y=-0.5, row=i+1, col=1, line_width=3)

            # fig.add_trace(go.Heatmap(z=actual_combs, coloraxis='coloraxis1', showscale=False), row=i+1, col=1)
  
    fig.update_layout(
        coloraxis1=dict(colorscale="ylgn", showscale=False),
        coloraxis2=dict(colorscale="ylgn", colorbar=dict(x=0.44)),
        # coloraxis3=dict(colorscale="ylgn", colorbar=dict(x=0.73)),
        # coloraxis4=dict(colorscale="ylgn", colorbar=dict(x=1)),
        height=350, width=750,
        margin=dict(l=10, r=10, t=30, b=30)
    )

    fig.show()




Error(s) in loading state_dict for Transformer:
	size mismatch for transformer.linear.0.weight: copying a param with shape torch.Size([32, 1]) from checkpoint, the shape in current model is torch.Size([128, 1]).
	size mismatch for transformer.linear.0.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.linear.2.weight: copying a param with shape torch.Size([1, 32]) from checkpoint, the shape in current model is torch.Size([1, 128]).
0.089([0, 6, 15])+-0.186([8, 9, 15])+0.097([5, 10, 15])
MINTERP/degree-3/width-3/func-3/epoch-233.pt
torch.Size([2000, 1])


-0.208([1, 2, 13])+0.776([6, 11, 12])+-0.568([9, 10, 13])
MINTERP/degree-3/width-3/func-4/epoch-2997.pt
torch.Size([2000, 1])


Error(s) in loading state_dict for Transformer:
	size mismatch for transformer.linear.0.weight: copying a param with shape torch.Size([32, 1]) from checkpoint, the shape in current model is torch.Size([128, 1]).
	size mismatch for transformer.linear.0.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.linear.2.weight: copying a param with shape torch.Size([1, 32]) from checkpoint, the shape in current model is torch.Size([1, 128]).
0.436([6, 11, 19])+0.354([3, 9, 16])+-0.790([4, 10, 18])
MINTERP/degree-3/width-3/func-2/epoch-305.pt
torch.Size([2000, 1])


Error(s) in loading state_dict for Transformer:
	size mismatch for transformer.linear.0.weight: copying a param with shape torch.Size([32, 1]) from checkpoint, the shape in current model is torch.Size([128, 1]).
	size mismatch for transformer.linear.0.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.linear.2.weight: copying a param with shape torch.Size([1, 32]) from checkpoint, the shape in current model is torch.Size([1, 128]).
0.239([2, 8, 16])+0.467([4, 13, 18])+-0.706([0, 1, 6])
MINTERP/degree-3/width-3/func-1/epoch-78254.pt
torch.Size([2000, 1])


Error(s) in loading state_dict for Transformer:
	size mismatch for transformer.linear.0.weight: copying a param with shape torch.Size([32, 1]) from checkpoint, the shape in current model is torch.Size([128, 1]).
	size mismatch for transformer.linear.0.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.linear.2.weight: copying a param with shape torch.Size([1, 32]) from checkpoint, the shape in current model is torch.Size([1, 128]).
Error(s) in loading state_dict for Transformer:
	size mismatch for transformer.linear.0.weight: copying a param with shape torch.Size([32, 1]) from checkpoint, the shape in current model is torch.Size([128, 1]).
	size mismatch for transformer.linear.0.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.linear.2.weight: copying a param with shape torch.Size([1, 32]) from 

In [75]:
import plotly.express as px
import plotly.subplots as sp
import plotly.graph_objects as go
import pandas as pd

# l = ["HYPERPARAM_TESTS_MECHINTERP2/degree-3/width-3/func-1/",
# "HYPERPARAM_TESTS_MECHINTERP2/degree-3/width-5/func-2/",
# "HYPERPARAM_TESTS_MECHINTERP2/degree-3/width-4/func-1/"]

folder = "MINTERP"


for root, dirs, files in os.walk(f"{folder}"):
    if not files: continue
    epochs = [x for x in files if "epoch-" in x]
    if not epochs: continue
    new_epochs = []
    for epoch in epochs:
        try:
            model.load_state_dict(torch.load(f"{root}/{epoch}", map_location=device))
            new_epochs.append(epoch)
        except Exception as e:
            # print(e)
            continue

    epochs = new_epochs
    if not epochs: continue
    epochs.sort(key=lambda x: int(re.search(r"epoch-(\d+).pt$", x).group(1)))

    # Load coeffs and combs
    pattern = rf"{folder}/degree-(\d+)/width-(\d+)/func-(\d+)"
    comb_path = rf"{folder}/combs_func\3_deg\1_width\2.pt"
    coeff_path = rf"{folder}/coefs_func\3_deg\1_width\2.pt"

    combs_path = re.sub(pattern, comb_path, root)
    coeff_path = re.sub(pattern, coeff_path, root)

    combs = torch.load(combs_path, map_location=device)
    coeffs = torch.load(coeff_path, map_location=device)
    

    # Create string representation of the coefficients
    s = ""
    for (i, c) in enumerate(coeffs):
        s += f"{c:.3F}({combs[i].tolist()})+"

    # Take final epoch
    epochs = [epochs[-1]]

 
    with torch.no_grad():
        for (i, epoch) in enumerate(epochs):
            # Load model
            ckpt_point = f"{root}/{epoch}"
            print(s[:-1])

            print(ckpt_point)
            model.load_state_dict(torch.load(ckpt_point, map_location=device))
            model.eval()

            in_proj = model.transformer.attn.in_proj_weight.detach().T
            lin1 = model.transformer.linear[0].weight.detach()
            lin1b = model.transformer.linear[0].bias.detach()
            lin2 = model.transformer.linear[2].weight.detach()
            lin2b = model.transformer.linear[2].bias.detach()
            out_proj = model.output_proj.detach()
            embed_dim = dim + N 
            w_q, w_k, w_v = in_proj[:, :embed_dim], in_proj[:, embed_dim:2*embed_dim], in_proj[:, 2*embed_dim:].T
            w_qk = w_q @ w_k.T

            # Use the indices in combs to set the values in actual_combs to one
            actual_combs = torch.zeros(combs.shape[0], N+2)
            for j in range(combs.shape[0]):
                actual_combs[j, combs[j]] = 1

            # For each comb in actual_combs, calculate the cosine similarity with each row in w_qk
            print(w_qk.shape, actual_combs.shape)

            actual_combs = torch.nn.functional.normalize(actual_combs, dim=1)
            w_qk = torch.nn.functional.normalize(w_qk, dim=1)
            cosine_similarities = torch.mm(actual_combs, w_qk.T)
            cosine_similarities = torch.abs(cosine_similarities)



            # Calculate cosine similarities with random vectors
            num_randoms = 10
            k = combs.shape[1]
            random_combs = torch.zeros(num_randoms, N+2)
            for j in range(num_randoms):
                indices = torch.randperm(N+2)[:k]
                random_combs[j, indices] = 1
        
            random_combs = torch.nn.functional.normalize(random_combs, dim=1)
            random_cosine_similarities = torch.mm(random_combs, w_qk.T)
            random_cosine_similarities = torch.abs(random_cosine_similarities)



            # actual_combs = torch.max(w_qk) * actual_combs
            # random_cosine_similarities = torch.max(w_qk) * random_cosine_similarities
            # cosine_similarities = torch.max(w_qk) * cosine_similarities

            # Mean across each row 
            # cosine_similarities = torch.mean(cosine_similarities, dim=1).unsqueeze(1)
            # random_cosine_similarities = torch.mean(random_cosine_similarities, dim=0).unsqueeze(0)
            # print(random_cosine_similarities.shape)




















            fig = sp.make_subplots(rows=len(epochs), cols=3, subplot_titles=("QK.T","Cos. Sim.", "Cos. Sim. Rand."), shared_yaxes=False, horizontal_spacing=0.10)
            colorscale = "PiYG"
            color = px.colors.sequential.Cividis_r


            fig.add_trace(
                        go.Heatmap(z=torch.concat((actual_combs, w_qk), axis=0), 
                            y=list(range(-actual_combs.shape[0], w_qk.shape[0])), 
                            coloraxis='coloraxis1', showscale=True),
                        row=i+1, col=1)
            fig.add_hline(y=-0.5, row=i+1, col=1, line_width=3)

            # Plot cosine_similarities as a heatmap
            fig.add_trace(
                go.Heatmap(z=cosine_similarities.T, coloraxis='coloraxis1', showscale=True),
                row=i+1, col=2
            )

            fig.add_trace(
                go.Heatmap(z=random_cosine_similarities.T, coloraxis='coloraxis1', showscale=True),
                row=i+1, col=3
            )

            # fig.add_trace(go.Heatmap(z=actual_combs, coloraxis='coloraxis1', showscale=False), row=i+1, col=1)
  
    fig.update_layout(
        coloraxis1=dict(colorscale="ylgn", showscale=True),
        height=350, width=1000,
        margin=dict(l=10, r=10, t=30, b=30)
    )

    fig.show()




0.089([0, 6, 15])+-0.186([8, 9, 15])+0.097([5, 10, 15])
MINTERP/degree-3/width-3/func-3/epoch-233.pt
torch.Size([22, 22]) torch.Size([3, 22])


-0.208([1, 2, 13])+0.776([6, 11, 12])+-0.568([9, 10, 13])
MINTERP/degree-3/width-3/func-4/epoch-2997.pt
torch.Size([22, 22]) torch.Size([3, 22])


0.436([6, 11, 19])+0.354([3, 9, 16])+-0.790([4, 10, 18])
MINTERP/degree-3/width-3/func-2/epoch-305.pt
torch.Size([22, 22]) torch.Size([3, 22])


0.239([2, 8, 16])+0.467([4, 13, 18])+-0.706([0, 1, 6])
MINTERP/degree-3/width-3/func-1/epoch-78254.pt
torch.Size([22, 22]) torch.Size([3, 22])
