In [1]:
import argparse
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchaudio

import hydra
from omegaconf import OmegaConf
from torch.distributions import Categorical
from tqdm.auto import tqdm

from src import utils
from src.dataloaders.audio import mu_law_decode
from src.models.baselines.wavenet import WaveNetModel
from train import SequenceLightningModule

import shap
import scipy as sp
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Execute the `generate` module as a script inside this kernel (`%run -m`), passing the
# key=value arguments below: the lm/s4-wt103 experiment and its checkpoint file, a single
# sample (n_samples=1), and text decoding (decode=text).
# NOTE(review): judging by the recorded output, l_prefix=5 is the number of prompt tokens
# and l_sample=10 the number of generated positions — confirm against the generate script.
%run -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=10 l_prefix=5 decode=text

[rank: 0] Global seed set to 1111


Loading model...
Full checkpoint path: /home/ys724/S4/State-Space-Interpretability/state-spaces/checkpoints/s4-wt103.pt
[2023-05-05 22:19:43,996][root][INFO] - Loading cached dataset...
Vocab size: 267735
[2023-05-05 22:19:46,970][src.models.sequence.kernels.ssm][INFO] - Pykeops installation found.
[2023-05-05 22:19:46,986][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-05 22:19:47,009][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-05 22:19:47,097][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-05 22:19:47,118][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-05 22:19:47,204][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-05 22:19:47,226][src.models.sequence.kernels.ssm][INFO] - Constructing S4 (H, N, L) = (1024, 32, 8192)
[2023-05-05 22:19:47,312][sr

100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 55.20it/s]

x tensor([[    0,     9, 96220,  ...,  5610,     2,   417]]) tensor([[     0,      9,  96220, 198382,      9]]) [['<eos>', '=', 'Homarus', 'gammarus', '=']] ['<eos> = Homarus gammarus =']
y tensor([[     9,  96220, 198382,      9,      0,      0,  96220, 198382,      2,
              1]]) tensor([[     9,  96220, 198382,      9,      0,      0,  96220, 198382,      2,
              1]]) [['=', 'Homarus', 'gammarus', '=', '<eos>', '<eos>', 'Homarus', 'gammarus', ',', 'the']] ['= Homarus gammarus = <eos> <eos> Homarus gammarus , the']





       0  1        2         3  4
0  <eos>  =  Homarus  gammarus  =
x ['<eos>' '=' 'Homarus' 'gammarus' '=']
mask [False False False False False]
x_tensor tensor([     0,      9,  96220, 198382,      9])
output tensor([[0, 0, 0, 0, 0]])
explainer start with input [[0 0 0 0 0]]


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 52.53it/s]


shap forward pass done with output torch.Size([1, 10]) (1,)
x ['<eos>' '=' 'Homarus' 'gammarus' '=']
mask [ True False False False False]
x_tensor tensor([     0,      9,  96220, 198382,      9])
output tensor([[0, 0, 0, 0, 0]])
explainer start with input [[0 0 0 0 0]]


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 53.90it/s]


shap forward pass done with output torch.Size([1, 10]) (1,)
x ['<eos>' '=' 'Homarus' 'gammarus' '=']
mask [ True False  True False False]
x_tensor tensor([     0,      9,  96220, 198382,      9])
output tensor([[    0,     0, 96220,     0,     0]])
explainer start with input [[    0     0 96220     0     0]]


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 57.24it/s]


shap forward pass done with output torch.Size([1, 10]) (1,)
x ['<eos>' '=' 'Homarus' 'gammarus' '=']
mask [ True  True  True False False]
x_tensor tensor([     0,      9,  96220, 198382,      9])
output tensor([[    0,     9, 96220,     0,     0]])
explainer start with input [[    0     9 96220     0     0]]


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 53.55it/s]


shap forward pass done with output torch.Size([1, 10]) (1,)
x ['<eos>' '=' 'Homarus' 'gammarus' '=']
mask [ True  True  True  True False]
x_tensor tensor([     0,      9,  96220, 198382,      9])
output tensor([[     0,      9,  96220, 198382,      0]])
explainer start with input [[     0      9  96220 198382      0]]


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 56.29it/s]


shap forward pass done with output torch.Size([1, 10]) (1,)
x ['<eos>' '=' 'Homarus' 'gammarus' '=']
mask [ True  True  True  True  True]
x_tensor tensor([     0,      9,  96220, 198382,      9])
output tensor([[     0,      9,  96220, 198382,      9]])
explainer start with input [[     0      9  96220 198382      9]]


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 54.95it/s]


shap forward pass done with output torch.Size([1, 10]) (1,)
x ['<eos>' '=' 'Homarus' 'gammarus' '=']
mask [False  True  True  True  True]
x_tensor tensor([     0,      9,  96220, 198382,      9])
output tensor([[     0,      9,  96220, 198382,      9]])
explainer start with input [[     0      9  96220 198382      9]]


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 57.11it/s]


shap forward pass done with output torch.Size([1, 10]) (1,)
x ['<eos>' '=' 'Homarus' 'gammarus' '=']
mask [False  True False  True  True]
x_tensor tensor([     0,      9,  96220, 198382,      9])
output tensor([[     0,      9,      0, 198382,      9]])
explainer start with input [[     0      9      0 198382      9]]


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 56.85it/s]


shap forward pass done with output torch.Size([1, 10]) (1,)
x ['<eos>' '=' 'Homarus' 'gammarus' '=']
mask [False False False  True  True]
x_tensor tensor([     0,      9,  96220, 198382,      9])
output tensor([[     0,      0,      0, 198382,      9]])
explainer start with input [[     0      0      0 198382      9]]


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 55.10it/s]


shap forward pass done with output torch.Size([1, 10]) (1,)
x ['<eos>' '=' 'Homarus' 'gammarus' '=']
mask [False False False False  True]
x_tensor tensor([     0,      9,  96220, 198382,      9])
output tensor([[0, 0, 0, 0, 9]])
explainer start with input [[0 0 0 0 9]]


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 56.36it/s]


shap forward pass done with output torch.Size([1, 10]) (1,)
x ['<eos>' '=' 'Homarus' 'gammarus' '=']
mask [False False False False False]
x_tensor tensor([     0,      9,  96220, 198382,      9])
output tensor([[0, 0, 0, 0, 0]])
explainer start with input [[0 0 0 0 0]]


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 57.46it/s]


shap forward pass done with output torch.Size([1, 10]) (1,)
permutation explainer


!!!!!!!!!!


In [1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import shap
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pretrained GPT-2 tokenizer (fast implementation) and causal language model.
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
# Use the GPU when one is present, otherwise stay on the CPU so the cell still runs
# on machines without CUDA (an unconditional `.cuda()` raises there).
model = AutoModelForCausalLM.from_pretrained("gpt2").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
# Flag the model as a decoder in its config.
model.config.is_decoder = True
# `task_specific_params` is None on configs that do not define it; make sure it is a
# dict before adding the text-generation entry.
if model.config.task_specific_params is None:
    model.config.task_specific_params = {}
# Generation settings stored under the "text-generation" task entry of the config.
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 50,
    "no_repeat_ngram_size": 2,
}
# Prompt(s) to explain; kept as a list because it is passed to the explainer as a batch.
s = ['I enjoy walking with my cute dog']

Downloading (…)lve/main/config.json: 100%|█████████████████████████████████████████████| 665/665 [00:00<00:00, 64.1kB/s]
Downloading (…)olve/main/vocab.json: 100%|█████████████████████████████████████████| 1.04M/1.04M [00:00<00:00, 19.0MB/s]
Downloading (…)olve/main/merges.txt: 100%|███████████████████████████████████████████| 456k/456k [00:00<00:00, 12.8MB/s]
Downloading (…)/main/tokenizer.json: 100%|█████████████████████████████████████████| 1.36M/1.36M [00:00<00:00, 24.4MB/s]
Downloading pytorch_model.bin: 100%|██████████████████████████████████████████████████| 548M/548M [00:02<00:00, 230MB/s]
Downloading (…)neration_config.json: 100%|█████████████████████████████████████████████| 124/124 [00:00<00:00, 15.4kB/s]


In [4]:
# Build a SHAP explainer from the GPT-2 model and its tokenizer defined in the previous cell.
explainer = shap.Explainer(model, tokenizer)
# Run the explainer on the prompt list `s`.
shap_values = explainer(s)
# Dump the raw explanation object as text (the recorded output starts with `.values`).
print(shap_values)
# Author's notes on the layout of the explanation fields.
# NOTE(review): not verified here; `.data` looks like it may hold the input tokens
# rather than output tokens — confirm.
# .values [1, input_len, output_len]
# .base_values [1, output_len]
# .data = [output_len] - str
# Render the explanation with SHAP's text plot.
shap.plots.text(shap_values)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Using `max_length`'s default (20) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.


.values =
array([[[-1.16613934e-01,  8.72616787e-01,  1.00514289e+00,
          1.25749471e-01,  1.99217282e-01, -4.10349673e-01,
         -6.36086927e-02,  4.05762883e-01, -1.30096765e-01,
          1.27657733e-02,  7.76433187e-02, -3.32962458e-01,
          8.88581558e-02],
        [-4.31012734e-01,  1.01470339e+00,  2.32480506e-02,
         -1.58169040e-01,  1.67172922e-01,  1.49489320e-01,
          3.62267154e-02,  3.55961711e-01,  4.57240805e-01,
          1.14920521e-02,  1.59498516e-02, -2.32643776e-02,
          5.82410425e-02],
        [-4.26616845e-01, -2.93118360e-02, -1.94200333e-03,
         -3.55014850e-02, -1.96080540e-01, -3.22592551e-01,
          1.87846905e-01,  1.71433277e-01, -4.59277155e-02,
          2.31572030e-01,  3.91687669e-01,  2.02801050e-01,
          1.70823823e-01],
        [ 7.23258793e-02,  3.10825985e-01, -1.06642422e-01,
         -8.20662956e-02,  1.83227621e-02,  4.08413111e-01,
          5.25015842e-02, -9.38722999e-02, -4.91522092e-03,
         

!!!!!!!!!!
