In [1]:
import argparse
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchaudio

import hydra
from omegaconf import OmegaConf
from torch.distributions import Categorical
from tqdm.auto import tqdm

from src import utils
from src.dataloaders.audio import mu_law_decode
from src.models.baselines.wavenet import WaveNetModel
from train import SequenceLightningModule

import shap
import scipy as sp
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%run -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=20 decode=text

[rank: 0] Global seed set to 1111


Loading model...
Full checkpoint path: /home/ys724/S4/State-Space-Interpretability/state-spaces/checkpoints/s4-wt103.pt
[2023-05-08 00:54:28,054][root][INFO] - Loading cached dataset...
Vocab size: 267735
[2023-05-08 00:54:31,073][src.models.sequence.kernels.ssm][INFO] - Pykeops installation found.
[2023-05-08 00:54:31,092][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-08 00:54:31,115][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-08 00:54:31,203][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-08 00:54:31,225][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-08 00:54:31,311][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-08 00:54:31,333][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-08 00:54:31,419][sr

100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 59.80it/s]


x torch.Size([1, 13]) tensor([[    9,  3074, 18291,     9,    13,  1046,  4193,    47,  1046,    49,
            12,  4193,  1470]])
x_sym [['=', 'Gold', 'Dollar', '=', 'The', 'gold', 'dollar', 'or', 'gold', 'one', '@-@', 'dollar', 'piece']]
y torch.Size([1, 33]) tensor([[ 3074, 18291,     9,    13,  1046,  4193,    47,  1046,    49,    12,
          4193,  1470,     9,     9,     0,     0,    13,  4323, 18291,    11,
          1510,    17,  2793,    26,  7922,     6,  4982,     2,   472,    20,
             1,  4528,  4193]])
y_sym [['Gold', 'Dollar', '=', 'The', 'gold', 'dollar', 'or', 'gold', 'one', '@-@', 'dollar', 'piece', '=', '=', '<eos>', '<eos>', 'The', 'Silver', 'Dollar', 'was', 'struck', 'for', 'circulation', 'from', '1854', 'to', '1895', ',', 'followed', 'by', 'the', 'Peace', 'dollar']]
pd   0     1       2  3    4     5       6   7     8    9    10      11     12
0  =  Gold  Dollar  =  The  gold  dollar  or  gold  one  @-@  dollar  piece


100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 57.37it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 56.66it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 56.58it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 53.38it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 53.91it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 52.57it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 56.97it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 60.93it/s]
100%|███████████████████████████

permutation explainer


Permutation explainer: 2it [00:34, 34.12s/it]                                                                           


shap_values .values =
array([[[ 0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00,
          0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00,
          0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00,
         -9.527500e+02,  1.387775e+04, -1.174825e+04, -1.220575e+04,
         -2.821000e+03,  9.315000e+02,  1.580000e+03,  6.975000e+03,
          8.520500e+03, -7.053500e+03,  1.269750e+03, -2.179000e+03,
         -2.082500e+02, -1.317150e+04, -9.500000e+01, -9.665000e+02,
         -4.955500e+03, -2.606750e+03,  8.573500e+03, -3.475000e+01,
          1.422750e+03],
        [ 3.074000e+03,  0.000000e+00,  0.000000e+00,  0.000000e+00,
          0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00,
          0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00,
          5.097250e+03, -1.982500e+03,  1.996500e+03, -1.696250e+03,
          5.799500e+03,  3.741750e+03,  2.671275e+04, -3.910000e+02,
         -3.728650e+04, -4.539750e+03,  2.371750e+03, -2

success


In [4]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import shap
import torch

In [5]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
# set model decoder to true
model.config.is_decoder=True
# set text-generation params under task_specific_params
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 100,
    "temperature": 0.9,
    "top_k": 50,
    "no_repeat_ngram_size": 2
}
s = ['= Gold Dollar = The gold dollar or gold one @-@ dollar piece']

In [6]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
print(shap_values)
# .values [1, input_len, output_len]
# .base_values [1, output_len]
# .data = [output_len] - str
shap.plots.text(shap_values)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


transformer marker
.values =
array([[[-0.20835433, -0.1051532 ,  0.13181247,  0.26817376,
         -0.02727577],
        [ 0.30752816,  0.03537743, -0.28205281, -0.0491177 ,
         -0.15479989],
        [-0.11170694,  0.2963228 , -0.02762653,  0.02929163,
         -0.03442985],
        [-0.64292733,  0.11487361,  0.74492825,  0.52801265,
         -0.15754432],
        [ 1.45247378,  0.40795844, -0.60887061,  0.14190982,
          1.16245558],
        [ 0.17974687,  0.04086812, -0.17671865, -0.00915309,
          0.06156504],
        [-0.07644959,  0.23733146,  0.15284447,  0.05611296,
          0.0704268 ],
        [-0.06096549,  0.12050801, -0.05214229,  0.02295307,
          0.00960364],
        [-0.27836133,  0.04904373,  0.12374053, -0.07576422,
         -0.18271554],
        [ 0.29612501, -0.25741026,  0.51427652, -0.06724159,
          0.46504156],
        [-0.17658567, -0.01882131,  0.21228288,  0.07124143,
         -0.30806721],
        [ 0.2066903 ,  0.04875015, -0.19779008,

success
