In [1]:
import argparse
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchaudio

import hydra
from omegaconf import OmegaConf
from torch.distributions import Categorical
from tqdm.auto import tqdm

from src import utils
from src.dataloaders.audio import mu_law_decode
from src.models.baselines.wavenet import WaveNetModel
from train import SequenceLightningModule

import shap
import scipy as sp
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%run -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=20 decode=text

[rank: 0] Global seed set to 1111


Loading model...
Full checkpoint path: /home/ys724/S4/State-Space-Interpretability/state-spaces/checkpoints/s4-wt103.pt
[2023-05-05 23:09:20,389][root][INFO] - Loading cached dataset...
Vocab size: 267735
[2023-05-05 23:09:23,310][src.models.sequence.kernels.ssm][INFO] - Pykeops installation found.
[2023-05-05 23:09:23,328][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-05 23:09:23,351][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-05 23:09:23,439][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-05 23:09:23,460][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-05 23:09:23,546][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-05 23:09:23,568][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-05 23:09:23,654][sr

100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 69.62it/s]


x torch.Size([1, 7]) tensor([[   68,  5763,  4490,    19,   617, 16082,  3225]])
x_sym [['I', 'enjoy', 'walking', 'with', 'my', 'cute', 'dog']]
y torch.Size([1, 27]) tensor([[ 5763,  4490,    19,   617, 16082,  3225,   204,    68,    32,    25,
          3225,   166,    17,     8,  1232,     4,    27,   612,    49,  1184,
           596,   149,     3,  1676,   245, 17024,    37]])
y_sym [['enjoy', 'walking', 'with', 'my', 'cute', 'dog', 'since', 'I', 'had', 'his', 'dog', 'around', 'for', 'a', 'rate', 'of', 'at', 'least', 'one', 'hour', 'every', 'day', '.', 'And', 'family', 'pets', 'are']]
pd    0      1        2     3   4     5    6
0  I  enjoy  walking  with  my  cute  dog


100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 72.48it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 72.22it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 56.58it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 60.21it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 55.56it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 57.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 57.71it/s]
100%|███████████████████████████████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 58.78it/s]
100%|███████████████████████████

permutation explainer


Permutation explainer: 2it [00:20, 20.89s/it]                                                                           


shap_values .values =
array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         -7.19883333e+03,  9.59216667e+03, -8.18050000e+03,
         -7.20950000e+03, -9.72166667e+02, -1.96753333e+04,
         -3.51616667e+03,  6.43366667e+03,  5.37200000e+03,
          1.64458333e+04,  9.17666667e+02,  2.92400000e+03,
         -6.15500000e+02, -1.22333333e+02, -3.02500000e+02,
          1.95020000e+04,  3.70383333e+03,  5.13833333e+02,
         -1.91533333e+03,  5.60333333e+02,  1.97266667e+03],
        [ 5.76300000e+03,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          2.08231667e+04,  6.04663333e+04,  6.95261667e+04,
          1.71068333e+04, -4.40166667e+02,  4.26083333e+03,
          2.03383333e+03, -2.28328333e+04, -2.23241667e+04,
         -1.82333333e+02, -1.88000000e+02,  7.51666667e+02,
          2.46333333e+03,  3.07666667e+03,  1.04035000e+04,
         -2.02016

success


In [3]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import shap
import torch

In [4]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
# set model decoder to true
model.config.is_decoder=True
# set text-generation params under task_specific_params
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.0,
    "top_k": 50,
    "no_repeat_ngram_size": 2
}
s = ['I enjoy walking with my cute dog']

In [5]:
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
print(shap_values)
# .values [1, input_len, output_len]
# .base_values [1, output_len]
# .data = [output_len] - str
shap.plots.text(shap_values)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


transformer marker
.values =
array([[[-1.15705233e-01,  8.71425777e-01,  1.00390459e+00,
          1.25499957e-01,  2.00091085e-01, -4.11364949e-01,
         -6.48017131e-02,  4.05273380e-01, -1.29734015e-01,
          1.24277879e-02,  7.72957342e-02, -3.30968202e-01,
          8.98718722e-02],
        [-4.32002024e-01,  1.01270052e+00,  2.39886401e-02,
         -1.57598729e-01,  1.66682330e-01,  1.48772218e-01,
          3.77640636e-02,  3.56483063e-01,  4.56275874e-01,
          1.12930683e-02,  1.63692694e-02, -2.45283591e-02,
          6.12490102e-02],
        [-4.26505100e-01, -3.00411366e-02, -2.38224289e-03,
         -3.53173411e-02, -1.95921073e-01, -3.22619288e-01,
          1.87529437e-01,  1.71811006e-01, -4.50298186e-02,
          2.32262599e-01,  3.91155737e-01,  2.04094502e-01,
          1.69829878e-01],
        [ 7.30720785e-02,  3.11213760e-01, -1.05739072e-01,
         -8.19162751e-02,  1.83993413e-02,  4.09421432e-01,
          5.28721934e-02, -9.34302704e-02, -5.1341

success
