## Black Box Interp

### Identity Model

In [4]:
import collections
import copy
import gc
import logging
import math
import os
import pickle
import time
import re

import sys
notebook_dir = os.getcwd()
grandparent_dir = os.path.dirname(os.path.dirname(notebook_dir))
sys.path.append(grandparent_dir)
print(sys.path)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as Fn
from tensordict import TensorDict
from pandas.plotting import table
from datetime import datetime

from core import Config
from models import GPT2, CnnKF
from data_train import set_config_params
from create_plots_with_zero_pred import tf_preds
from linalg_helpers import print_matrix

#set cuda device
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")



['/opt/homebrew/Cellar/python@3.11/3.11.5/Frameworks/Python.framework/Versions/3.11/lib/python311.zip', '/opt/homebrew/Cellar/python@3.11/3.11.5/Frameworks/Python.framework/Versions/3.11/lib/python3.11', '/opt/homebrew/Cellar/python@3.11/3.11.5/Frameworks/Python.framework/Versions/3.11/lib/python3.11/lib-dynload', '', '/Users/alanzu/Library/Python/3.11/lib/python/site-packages', '/opt/homebrew/lib/python3.11/site-packages', '/Users/alanzu/projects/TFs_do_KF_ICL/src', '/var/folders/mn/qf6yrm_s29d029x1q_ntbxcc0000gs/T/tmproniddo6', '/Users/alanzu/projects/TFs_do_KF_ICL/src', '/Users/alanzu/projects/TFs_do_KF_ICL/src', '/Users/alanzu/projects/TFs_do_KF_ICL/src']
/Users/alanzu/projects/TFs_do_KF_ICL/src/interp/black_box
Using device: cpu
CUDA_VISIBLE_DEVICES: None
Using device: cpu


In [7]:
# get model ckpt

model_name = "ident"

if model_name == "ident":
    valA = "ident"
    valC = "_ident_C"
    nx = 5

ckpt_step = 100

ckpt_path = f"/data/shared/ICL_Kalman_Experiments/model_checkpoints/GPT2/250124_052617.8dd0f8_multi_sys_trace_ident_state_dim_5_ident_C_lr_1.584893192461114e-05_num_train_sys_40000/checkpoints/step={ckpt_step}.ckpt"

print(ckpt_path)

config = Config()
output_dir, ckpt_dir, experiment_name = set_config_params(config, model_name)

print(f"ckpt_dir: {ckpt_dir}")
config.override("ckpt_path", ckpt_dir + f"/checkpoints/step={ckpt_step}.ckpt")
print(f"ckpt_path: {config.ckpt_path}")

num_gpu = len(config.devices)
batch_size = config.batch_size
print(f"Number of GPUs: {num_gpu}")
print(f"Batch size: {batch_size}")
print(f"Number of training examples: {ckpt_step*batch_size*num_gpu}")

/data/shared/ICL_Kalman_Experiments/model_checkpoints/GPT2/250124_052617.8dd0f8_multi_sys_trace_ident_state_dim_5_ident_C_lr_1.584893192461114e-05_num_train_sys_40000/checkpoints/step=100.ckpt


IDENTITY MEDIUM MODEL


ckpt_dir: /data/shared/ICL_Kalman_Experiments/model_checkpoints/GPT2/250124_052617.8dd0f8_multi_sys_trace_ident_state_dim_5_ident_C_lr_1.584893192461114e-05_num_train_sys_40000
ckpt_path: /data/shared/ICL_Kalman_Experiments/model_checkpoints/GPT2/250124_052617.8dd0f8_multi_sys_trace_ident_state_dim_5_ident_C_lr_1.584893192461114e-05_num_train_sys_40000/checkpoints/step=100.ckpt
Number of GPUs: 2
Batch size: 512
Number of training examples: 102400


In [8]:
#load the model

model = GPT2.load_from_checkpoint(config.ckpt_path,
                                n_dims_in=config.n_dims_in, n_positions=config.n_positions,
                                n_dims_out=config.n_dims_out, n_embd=config.n_embd,
                                n_layer=config.n_layer, n_head=config.n_head, use_pos_emb=config.use_pos_emb, map_location=device, strict=True).eval().to(
    device)

In [17]:
print(f"model: {model}")

model: GPT2(
  (_read_in): Linear(in_features=57, out_features=128, bias=True)
  (_backbone): GPT2Model(
    (wte): Embedding(50257, 128)
    (wpe): Embedding(2048, 128)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=384, nx=128)
          (c_proj): Conv1D(nf=128, nx=128)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=512, nx=128)
          (c_proj): Conv1D(nf=128, nx=512)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (_read_out): Linear(in_features=128, out_features=5, bias=True)
)

In [14]:
#get interleaved test data
num_sys_haystack = 1

config.override("n_positions", num_sys_haystack*12 + 12)


data_path = f"/data/shared/ICL_Kalman_Experiments/train_and_test_data/{valA}/interleaved_traces_{valA}{valC}_state_dim_{nx}_num_sys_haystack_{num_sys_haystack}.pkl"


with open(data_path, "rb") as f:
    data_dict = pickle.load(f)
    for key in data_dict.keys():
        print(f"{key}: {data_dict[key]}\n")

multi_sys_ys: [[[[ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.3653136  -1.0134141
     0.7160135 ]
   ...
   [ 0.          0.          0.         ...  0.3653136  -1.0134141
     0.7160135 ]
   [ 0.          0.          0.         ...  0.3653136  -1.0134141
     0.7160135 ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]]

  [[ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]
   [ 0.          0.          0.         ... -0.08013932  0.7014316
    -0.3362667 ]
   ...
   [ 0.          0.          0.         ... -0.08013932  0.7014316
    -0.3362667 ]
   [ 0.          0.          0.         ... -0.08013932  0.7014316
    -0.3362667 ]
   [ 0.          0.          0.         ...  0.          0.
     0.        ]]

  

In [29]:
# input data into model


preds_tf = tf_preds(data_dict["multi_sys_ys"], model, device, config)

# print_matrix()
# print(f"preds_tf: {preds_tf[0,0].T.shape}")
print_matrix(preds_tf[0,0].T, "preds_tf[0,0].T")

print(f"multi_sys_ys: {data_dict['multi_sys_ys'][0,0,:,-5:].shape}")
print_matrix(data_dict["multi_sys_ys"][0,0,:,-5:].T, "data_dict['multi_sys_ys'][0,0,:,-5:].T")

Matrix preds_tf[0,0].T:
    0.0000     0.0469     0.0564     0.4407     0.3394     0.3724     0.3378     0.3906     0.4034     0.3687     0.3910     0.3664     0.3764     0.3789     0.0444     0.3863     0.3446     0.4039     0.3791     0.4389     0.3787     0.3601     0.3754     0.3754     0.3779 
    0.0000     0.1251    -0.0600    -0.4705    -0.4672    -0.5215    -0.5567    -0.4810    -0.5076    -0.4692    -0.4909    -0.4855    -0.4452     0.2945    -0.0494    -0.4905    -0.4631    -0.4054    -0.4269    -0.5038    -0.5402    -0.4747    -0.4651    -0.4452    -0.4987 
    0.0000     0.0393     0.1003     0.3814     0.3712     0.3782     0.3829     0.3966     0.3528     0.2918     0.3788     0.3235     0.4351     0.1984     0.0971     0.3442     0.4097     0.3775     0.3514     0.2775     0.4476     0.3742     0.4250     0.3965     0.3958 
    0.0000    -0.2247    -0.0730    -0.9643    -0.9586    -0.9911    -0.9497    -0.9918    -1.0205    -1.0287    -0.9849    -0.9841    -1.0017     0