In [None]:
import os

import jax
import jax.numpy as jnp
import numpy as np

from interp.model.model_loading import load_model
# Load the two-layer attention-only checkpoint.
# models_dir is left off on purpose for rrfs: using the local copy is faster.
model_fn, params, tok = load_model("jan5_attn_only_two_layers/")
model = model_fn.bind(params)
# Map each decoded single-token string to its id. If two ids decode to the
# same string, the later (higher) id wins.
vocaby = {tok.decode([token_id]): token_id for token_id in range(len(tok))}

In [None]:
ovs = np.array(model.get_ov_combined_mats_all_layers())
qks = np.array(model.get_qk_combined_mats_all_layers())
# Virtual OV compositions: ovs_virtual[i, j] = ovs[0][i] @ ovs[1][j] for every
# (layer-0 head i, layer-1 head j) pair. Broadcasting replaces the hard-coded
# reshape(8, 1, 256, 256), so any head count / matrix width works.
ovs_virtual = ovs[0][:, None] @ ovs[1][None, :]
# Number of leading vocab entries used by the bigram / eigen token reports.
N_SMALL_VOCAB = 10_000
emb = np.array(params["params"]["embedding"]["token_embedding"]["embedding"]).astype(np.float16)
# Mean-centre over the vocabulary, then scale each token's row to unit norm.
emb_m0 = emb - np.mean(emb, axis=0)
emb_unit = emb_m0 / np.linalg.norm(emb_m0, axis=1, keepdims=True)
emb_small = emb[:N_SMALL_VOCAB]
emb_small_unit = emb_unit[:N_SMALL_VOCAB]
print(ovs.shape, qks.shape, ovs_virtual.shape, emb.shape, emb.dtype)

In [None]:
def mat_to_bigrams(mat):
    """Score every token pair under `mat` in unit-embedding space.

    With E = `emb_small_unit`, computes (E @ mat @ E.T).T on JAX, with `mat`
    cast to float16 first, so entry [a, b] equals E[b] @ mat @ E[a]. Prints
    `mat.shape` as a progress aid.

    Returns:
        NumPy array of shape (N, N), where N is the number of rows of
        `emb_small_unit`.
    """
    print(mat.shape)
    left_emb = jnp.array(emb_small_unit)
    right_emb = jnp.array(emb_small_unit.T)
    mat_f16 = jnp.array(mat).astype(jnp.float16)
    scores = left_emb @ mat_f16 @ right_emb
    return np.array(scores.T)
    
def bigrams_top(bigrams, k=40):
    """Return the `k` largest and `k` smallest entries of a 2-D score matrix.

    Args:
        bigrams: array-like of shape (n_rows, n_cols); converted with `np.array`.
        k: number of entries to take from each end. Clamped to the number of
            entries; `k <= 0` yields two empty lists.

    Returns:
        (tops, bottoms): lists of `(row, col, value)` tuples. `tops` is sorted
        by value descending and `bottoms` by value ascending, so the most
        extreme entries come first in both lists.
    """
    bigrams = np.array(bigrams)
    flat = bigrams.ravel()
    k = min(k, flat.size)
    if k <= 0:
        return [], []

    def _as_triples(flat_idxs):
        # Map flat indices back to (row, col) and attach the score.
        rows, cols = np.unravel_index(flat_idxs, bigrams.shape)
        return [(int(r), int(c), flat[i]) for r, c, i in zip(rows, cols, flat_idxs)]

    # argpartition isolates the k extremes in linear time but leaves them
    # unordered; sort just those k so the listing is actually ranked.
    top_idxs = np.argpartition(flat, -k)[-k:]
    top_idxs = top_idxs[np.argsort(flat[top_idxs])[::-1]]
    # kth = k - 1 (not k) stays in bounds when k == flat.size.
    bottom_idxs = np.argpartition(flat, k - 1)[:k]
    bottom_idxs = bottom_idxs[np.argsort(flat[bottom_idxs])]
    return _as_triples(top_idxs), _as_triples(bottom_idxs)
    
def str_bigrams(bigrams):
    """Render (from_id, to_id, value) triples, one per line, as "'from' -> 'to' : value".

    Token ids are decoded with the module-level `tok`.
    """
    lines = []
    for src_id, dst_id, score in bigrams:
        src_text = tok.decode([src_id])
        dst_text = tok.decode([dst_id])
        lines.append(f"'{src_text}' -> '{dst_text}' : {score}")
    return "\n".join(lines)
    
def str_mat_bigrams(mat):
    """Format the top and bottom token bigrams of `mat` as two labelled sections.

    Sections are headed "top" and "bottom" and separated by two blank lines.
    """
    top, bottom = bigrams_top(mat_to_bigrams(mat))
    sections = [
        f"top\n\n{str_bigrams(top)}",
        f"bottom\n\n{str_bigrams(bottom)}",
    ]
    return "\n\n\n".join(sections)
    
# Top/bottom token bigrams of layer 1, head 5's OV matrix.
print(str_mat_bigrams(ovs[1, 5]))


In [None]:
# Sanity check: shape of a single head's OV matrix (layer 1, head 5).
print(ovs[1, 5].shape)

# I don't think this function is useful anymore
def mat_to_copy_and_noncopy(mat):
    """Split a square matrix into its real-eigenpair and complex-eigenpair parts.

    An eigenpair counts as "real" when both the eigenvalue and every entry of
    its eigenvector are real. The reconstruction uses the inverse of the
    eigenvector matrix, V @ diag(lambda) @ inv(V). V.T is only the inverse for
    orthogonal V, so using it broke the split for any non-symmetric matrix.

    Args:
        mat: square real matrix (assumed diagonalizable).

    Returns:
        (real_mat, imag_mat): real matrices with real_mat + imag_mat ~= mat.
        real_mat is spanned by the real eigenpairs; imag_mat by the complex
        (conjugate-pair) ones.

    Raises:
        numpy.linalg.LinAlgError: if the eigenvector matrix is singular.
    """
    eigenvals, eigenvecs = np.linalg.eig(mat)
    inv_eigenvecs = np.linalg.inv(eigenvecs)
    real_mask = np.isreal(eigenvals) & np.all(np.isreal(eigenvecs), axis=0)
    real_mat = np.real(eigenvecs @ np.diag(eigenvals * real_mask) @ inv_eigenvecs)
    imag_mat = np.real(eigenvecs @ np.diag(eigenvals * ~real_mask) @ inv_eigenvecs)
    return real_mat, imag_mat
    
    
def eig_by_angle(mat, min_abs=1e-4, sort_by="magnitude"):
    """Eigendecompose a square matrix, dropping eigenvalues with |lambda| <= `min_abs`.

    Despite the name, the default ordering is by descending |lambda|. Pass
    ``sort_by="angle"`` to order instead by distance from the real axis. That
    distance is 0 for positive or negative real eigenvalues and pi/2 for purely
    imaginary ones, and is symmetric for conjugate pairs.

    Args:
        mat: square matrix.
        min_abs: eigenvalues whose absolute value is at or below this are discarded.
        sort_by: "magnitude" (default) or "angle".

    Returns:
        (eigenvals, right, left):
        - eigenvals: the kept eigenvalues.
        - right: the right eigenvectors as columns.
        - left: the matching rows of inv(eigenvectors), as columns.
        Together, ``right @ np.diag(eigenvals) @ left.T`` reconstructs `mat`
        restricted to the kept eigenpairs.

    Raises:
        ValueError: for an unknown `sort_by`.
        numpy.linalg.LinAlgError: may be raised when the eigenvector matrix is
            singular (non-diagonalizable `mat`).
    """
    if sort_by not in ("magnitude", "angle"):
        raise ValueError(f"sort_by must be 'magnitude' or 'angle', got {sort_by!r}")
    eigenvals, eigenvecs = np.linalg.eig(mat)
    inv_eigenvecs = np.linalg.inv(eigenvecs)
    keep = np.absolute(eigenvals) > min_abs
    eigenvals = eigenvals[keep]
    eigenvecs = eigenvecs[:, keep]
    inv_eigenvecs = inv_eigenvecs[keep]
    if sort_by == "magnitude":
        order = np.argsort(-np.absolute(eigenvals))
    else:
        # |angle| lies in [0, pi]; the distance to the nearer real half-axis is
        # min(|angle|, pi - |angle|). The old |pi - angle| was wrong for
        # negative angles.
        abs_angles = np.absolute(np.angle(eigenvals))
        order = np.argsort(np.minimum(abs_angles, np.pi - abs_angles))
    return eigenvals[order], eigenvecs[:, order], inv_eigenvecs[order].T
    
    
def svd_nont_tiny(mat, min_sv=1e-4):
    """Thin SVD of `mat`, keeping only singular values greater than `min_sv`.

    Args:
        mat: (m, n) matrix.
        min_sv: singular values at or below this are discarded.

    Returns:
        (s, u, v): kept singular values (descending), u of shape (m, r) and
        v of shape (n, r), with ``u @ np.diag(s) @ v.T`` ~= `mat` restricted
        to the kept components.
    """
    # full_matrices=False makes u (m, k) and vh (k, n) with k = min(m, n),
    # matching len(s). With the full SVD, u is (m, m) and vh is (n, n), so the
    # boolean mask below had the wrong length for any non-square input.
    u, s, vh = np.linalg.svd(mat, full_matrices=False)
    keep = s > min_sv
    return s[keep], u[:, keep], vh.T[:, keep]
    
def svd_inout(mat):
    """Factor `mat` ~= A @ B.T with A = U * sqrt(S) and B = V * sqrt(S).

    Only the non-tiny singular components returned by `svd_nont_tiny` are used.
    """
    sing_vals, left_vecs, right_vecs = svd_nont_tiny(mat)
    root = np.sqrt(sing_vals)
    return left_vecs * root, right_vecs * root

def eig_inout(mat):
    """Factor `mat` ~= A @ B.T from its non-tiny eigenpairs (via `eig_by_angle`).

    A = right_eigvecs * sqrt(eigenvals) and B = left_eigvecs * sqrt(eigenvals),
    so A @ B.T = right @ diag(eigenvals) @ left.T. That product reconstructs
    `mat` restricted to the kept eigenpairs.
    """
    s, innie, outie = eig_by_angle(mat)
    # np.linalg.eig returns a real-dtype array when every eigenvalue is real.
    # Plain np.sqrt would then turn negative eigenvalues into NaN;
    # np.emath.sqrt returns the complex root instead.
    root = np.emath.sqrt(s)
    return innie * root, outie * root
    
def format_top_logits(logit_shaped, k=20):
    """Decode the `k` highest-scoring token ids of a 1-D score vector, best first.

    Args:
        logit_shaped: 1-D array of per-token scores; index i is token id i.
        k: how many tokens to list. `k <= 0` lists none.

    Returns:
        A string starting with a newline, followed by one decoded token per line.
    """
    if k <= 0:
        # Without this guard, sorted_ids[-0:] would select the whole vocabulary.
        return "\n"
    # The bottom-k tokens were previously decoded and then discarded; only the
    # top-k are needed for the returned string.
    top_ids = np.argsort(logit_shaped)[-k:][::-1]
    return "\n" + "\n".join(f"{tok.decode([token_id])}" for token_id in top_ids)

def show_in_outies(s, innies, outies, k=5):
    """Build a text report describing each component of a factorization in token space.

    For every column i, the rows of `emb_small_unit` are projected onto
    `innies[:, i]` ("from") and `outies[:, i]` ("to"), each scaled by
    sqrt(s[i]). For both the positive and the negative direction, the report
    gives the max projection and the `k` highest-projecting tokens (via
    `format_top_logits`).

    Args:
        s: per-component scales, e.g. singular values or eigenvalues (may be
            complex or negative).
        innies: matrix whose rows match the embedding width; column i pairs
            with s[i].
        outies: same layout as `innies`.
        k: number of tokens listed per direction (default 5).

    Returns:
        The report string.
    """
    # np.emath.sqrt gives the complex root for negative real scales. np.sqrt
    # would give NaN when eig returns a real-dtype array with negative
    # eigenvalues.
    root = np.emath.sqrt(s)
    emb_eigbasis_to = emb_small_unit @ outies * root
    emb_eigbasis_from = emb_small_unit @ innies * root
    parts = []
    for i in range(emb_eigbasis_to.shape[1]):
        col_from = emb_eigbasis_from[:, i]
        col_to = emb_eigbasis_to[:, i]
        # np.max keeps this on the host; no JAX round-trip is needed here.
        parts.append("\n\nsingular val " + str(s[i]))
        parts.append("\npos")
        parts.append(f"\nfrom {np.max(col_from)}")
        parts.append(format_top_logits(col_from, k))
        parts.append(f"\nto {np.max(col_to)}")
        parts.append(format_top_logits(col_to, k))
        parts.append("\nneg")
        parts.append(f"\nfrom {np.max(-col_from)}")
        parts.append(format_top_logits(-col_from, k))
        parts.append(f"\nto {np.max(-col_to)}")
        parts.append(format_top_logits(-col_to, k))
    return "".join(parts)

def next_type_copying_metric(ov, qk):
    """Print the eigen report (`show_in_outies` of `eig_by_angle`) for ``qk @ ov.T``."""
    report = show_in_outies(*eig_by_angle(qk @ ov.T))
    print(report)
    
# Eigen report of qk @ ov.T for layer 1, head 5.
# To inspect the OV matrix on its own instead:
#   print(show_in_outies(*eig_by_angle(ovs[1][5])))
next_type_copying_metric(ov=ovs[1][5], qk=qks[1][5])


In [None]:
# Write an eigen report (show_in_outies of eig_by_angle) for every head's QK
# and OV matrix: one text file per (matrix kind, layer, head).
TOKMAP_DIR = "/home/ubuntu/tokmaps_jan5"  # NOTE(review): machine-specific absolute path
os.makedirs(TOKMAP_DIR, exist_ok=True)
n_layers, n_heads = ovs.shape[:2]
for layer in range(n_layers):
    for head in range(n_heads):
        for qkovname, qkov in [("qk", qks), ("ov", ovs)]:
            fname = os.path.join(TOKMAP_DIR, f"{qkovname}_{layer}.{head}.txt")
            strang = show_in_outies(*eig_by_angle(qkov[layer, head]))
            # `with` closes the file deterministically instead of relying on GC.
            with open(fname, "w") as f:
                f.write(strang)

In [None]:
# I want to know what subspace is copied, what subspace is moved to what subspace, and what subspace is ignored. (4 subspaces).
# Vocab strings containing exactly one distinct character from []()<>{},
# grouped by which of "[", "]", "(", ")" that character is.
open_square_strs, close_square_strs, open_strs, close_strs = (
    [s for s in vocaby if ch in s and len(set("[]()<>{}").intersection(s)) == 1]
    for ch in "[]()"
)
print(open_square_strs, close_square_strs, open_strs, close_strs)

In [None]:
# 100 tokens whose unit embedding has the highest dot product with "}"'s.
# NOTE(review): the "}" vector is added to itself, which only doubles it and
# cannot change the ranking. Possibly a second, different token was intended.
print(format_top_logits((2 * emb_unit[vocaby["}"]]) @ emb_unit.T, 100))

In [None]:
from functools import lru_cache
import itertools
import torch
import os
from interp.tools.data_loading import DATA_DIR


@lru_cache()
def get_val_seqs(n_files=2, max_size=511):
    """Load fixed-length validation token sequences as a 2-D int64 NumPy array.

    Steps:
    - Read the first `n_files` files from `{DATA_DIR}/owt_tokens_int16_val`,
      in sorted filename order.
    - Split each file's `tokens` tensor into documents using its `lens`.
    - Keep documents with at least `max_size` tokens and truncate them to
      `max_size`.
    - Add 32768 to every value.

    NOTE(review): the +32768 shift presumably undoes int16 storage of token ids
    (mapping [-32768, 32767] onto [0, 65535]); confirm against the data writer.

    Results are memoized by `lru_cache` per argument tuple, so the same array
    object is returned on repeated calls; mutating it affects later calls.

    Args:
        n_files: number of files to read.
        max_size: sequence length to keep.

    Returns:
        np.ndarray of shape (n_docs, max_size).

    Raises:
        ValueError: if no document is at least `max_size` tokens long.
    """
    val_dir = f"{DATA_DIR}/owt_tokens_int16_val"
    # os.listdir order is filesystem-dependent; sort so the chosen files are
    # reproducible across machines.
    fnames = sorted(os.listdir(val_dir))[:n_files]
    all_tokens = [torch.load(f"{val_dir}/{f}") for f in fnames]
    docs = list(
        itertools.chain.from_iterable(
            torch.split(x["tokens"], x["lens"].tolist()) for x in all_tokens
        )
    )
    long_docs = [doc[:max_size].to(torch.int64) + 32768 for doc in docs if doc.size(0) >= max_size]
    if not long_docs:
        raise ValueError(f"no sequences of length >= {max_size} in {fnames}")
    data = torch.stack(long_docs, dim=0).numpy()
    print("data shape", data.shape)
    return data

In [None]:
# Load the validation sequences (memoized by lru_cache); the returned array is
# this cell's displayed output.
get_val_seqs()