In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from captum.attr import IntegratedGradients

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from transformers import BertTokenizer, BertForSequenceClassification, BertConfig

from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients
from sklearn.model_selection import train_test_split

# BertForSequenceClassification Captum

## preparing dataset

In [3]:
post_df = pd.read_csv('../predicting-satisfaction-using-graphs/csv/dataset/liwc_post.csv', encoding='UTF-8')
comment_df = pd.read_csv('../predicting-satisfaction-using-graphs/csv/dataset/liwc_comment.csv', encoding='UTF-8')

In [4]:
# texts (x)
post_contents = list(post_df['content'])
comment_bodies = list(comment_df['content'])

# satisfaction score (y)
satisfactions_float = list(post_df['satisfaction'])
satisfactions = []

for s in satisfactions_float:
    if s < 3.5:
        satisfactions.append(0)
    elif s < 5:
        satisfactions.append(1)
    else:
        satisfactions.append(2)

data = []

for content, body, satisfaction in zip(post_contents, comment_bodies, satisfactions):
    data.append([content + '[SEP]' + body, satisfaction])

columns = ['contents', 'label']
df = pd.DataFrame(data, columns=columns)

# data split (train & test sets)
idx_train, idx_remain = train_test_split(df.index.values, test_size=0.20, random_state=42)
idx_val, idx_test = train_test_split(idx_remain, test_size=0.50, random_state=42)

print(idx_train.shape, idx_val.shape, idx_test.shape)

train_df = df.iloc[idx_train]
val_df = df.iloc[idx_val]
test_df = df.iloc[idx_test]

count_min_label = min(train_df['label'].value_counts())

labels = [0, 1, 2]

train_sample_df = pd.DataFrame([], columns=columns)

for label in labels:
    tmp = train_df[train_df['label'] == label]
    tmp_sampled = tmp.sample(frac=1).iloc[:count_min_label]
    train_sample_df = pd.concat([train_sample_df, tmp_sampled])

(800,) (100,) (100,)


In [5]:
train_sample_df

Unnamed: 0,contents,label
880,For some reason watching a horrifying movie &g...,0
514,My depression/anxiety has been back in full sw...,0
287,I had to be the ugly one.I had to be the depre...,0
505,I would take a bullet for her. I would go as f...,0
435,There's no one I want to be around. I spend t...,0
...,...,...
596,Iâm about three weeks into my antidepressant...,2
154,Or a meeting place where you don't have to soc...,2
266,I love putting on some head phones and blaring...,2
111,I was a 21 year old student studying physics a...,2


## Loading model

In [6]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"

In [7]:
model_path = f'../predicting-satisfaction-using-graphs/model/epoch_10.model'
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3, problem_type='multi_label_classification', output_hidden_states=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [8]:
model.load_state_dict(torch.load(model_path, map_location=device))

<All keys matched successfully>

In [9]:
model.eval()
model.zero_grad()

In [10]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [35]:
inputs = tokenizer(data[0][0], max_length=512, padding='max_length', truncation=True,return_tensors='pt')
inputs.to(device)

{'input_ids': tensor([[  101,  2353,  2095,  2267,  3076,  1012,  2026,  6245,  2038,  2467,
         27674,  4321,  2013,  2591, 10089,  1010, 10261,  1045,  2018,  2814,
          2021,  2025,  4209,  2129,  2000,  2191,  2068,  1012,  1045,  2985,
          2048,  2086,  1999, 19568,  2015,  1998,  2196,  3764,  2000,  2151,
          1997,  2026, 10638,  1998,  2134,  1005,  1056,  2428,  3113, 10334,
          1012,  2197,  2305,  1010,  2005,  1996,  2034,  2051,  1999,  2026,
          2166,  1010,  1045,  2001,  4778,  2000,  1037,  2283,  1006,  2025,
          7714,  2568,  2017,  1025,  1045,  2001,  1037,  2112,  1997,  1037,
          2177,  2008,  2288, 13643,  4778,  1007,  1012,  1045, 10749,  1037,
          2210,  2978,  1010,  2021,  2288, 10247,  2100,  2855,  1010,  1998,
          2025,  4209,  2026,  6537,  1045,  3030,  1012,  1045,  2001,  3294,
         17358,  2306,  1996,  2034,  3178,  1012,  4661, 25795,  1998,  2019,
          8138, 14806,  1010,  1996,  

## Custom forward function

In [47]:
def predict(inputs):
    output = model(**inputs)
    return output.logits

In [56]:
def classification_forward_func(inputs):
    pred = predict(inputs)
    pred = np.argmax(pred.cpu().detach().numpy(), axis=1)
    
    return pred[0]

In [57]:
print(classification_forward_func(inputs))

0


## Define baselines

In [58]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference (zero padding)
sep_token_id = tokenizer.sep_token_id # A token used as a separator between post and comment and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated post-comment word sequence

In [59]:
print(ref_token_id, sep_token_id, cls_token_id)

0 102 101


In [69]:
def construct_input_ref_pair(post, comment, ref_token_id, sep_token_id, cls_token_id):
    post_ids = tokenizer.encode(post, add_special_tokens=False)
    comment_ids = tokenizer.encode(comment, add_special_tokens=False)

    # construct input token ids
    input_ids = torch.tensor([[cls_token_id] + post_ids + [sep_token_id] + comment_ids + [sep_token_id]], device=device)
    seq_len = input_ids.size(1)
    
    # token_type_ids is used both embeddings
    token_type_ids = torch.tensor([[0 for i in range(seq_len)]], device=device)

    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(post_ids) + [sep_token_id] + \
        [ref_token_id] * len(comment_ids) + [sep_token_id]

    return input_ids, torch.tensor([ref_input_ids], device=device), token_type_ids, len(post_ids)

In [70]:
print(construct_input_ref_pair(post_contents[0], comment_bodies[0], ref_token_id, sep_token_id, cls_token_id))

(tensor([[  101,  2353,  2095,  2267,  3076,  1012,  2026,  6245,  2038,  2467,
         27674,  4321,  2013,  2591, 10089,  1010, 10261,  1045,  2018,  2814,
          2021,  2025,  4209,  2129,  2000,  2191,  2068,  1012,  1045,  2985,
          2048,  2086,  1999, 19568,  2015,  1998,  2196,  3764,  2000,  2151,
          1997,  2026, 10638,  1998,  2134,  1005,  1056,  2428,  3113, 10334,
          1012,  2197,  2305,  1010,  2005,  1996,  2034,  2051,  1999,  2026,
          2166,  1010,  1045,  2001,  4778,  2000,  1037,  2283,  1006,  2025,
          7714,  2568,  2017,  1025,  1045,  2001,  1037,  2112,  1997,  1037,
          2177,  2008,  2288, 13643,  4778,  1007,  1012,  1045, 10749,  1037,
          2210,  2978,  1010,  2021,  2288, 10247,  2100,  2855,  1010,  1998,
          2025,  4209,  2026,  6537,  1045,  3030,  1012,  1045,  2001,  3294,
         17358,  2306,  1996,  2034,  3178,  1012,  4661, 25795,  1998,  2019,
          8138, 14806,  1010,  1996,  2069,  2111, 

In [12]:
inputs = inputs.to(device)

In [13]:
outputs = model(**inputs)

In [65]:
print(outputs.hidden_states[12])

tensor([[[-0.2771,  0.6628,  0.6646,  ...,  0.4570, -0.3214,  0.4301],
         [-0.6448,  0.5613,  1.9034,  ...,  1.4501,  0.6793, -0.8462],
         [-0.3622,  1.1128,  1.3065,  ...,  0.9648,  0.6278, -0.7356],
         ...,
         [-0.1370,  1.3976,  0.9012,  ...,  1.1594, -0.1056, -0.3206],
         [ 0.1244,  1.1471,  0.3618,  ...,  1.0564, -0.4413, -0.4324],
         [-0.1614,  0.8678,  0.9781,  ...,  0.9212, -0.3137, -0.6701]]],
       grad_fn=<NativeLayerNormBackward>)


In [83]:
print(data[0][1])

0


In [14]:
model.bert.embeddings

BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [None]:
# forward_func을 뭔가 바꿔야 할 듯. baseline 필요!! 이게 아마 0 vector.

In [15]:
lig = LayerIntegratedGradients(model, model.bert.embeddings)

In [16]:
attribution = lig.attribute(inputs.input_ids, target=0)

AttributeError: 'SequenceClassifierOutput' object has no attribute 'shape'

In [27]:
# padding = 0
ref_token_id = tokenizer.pad_token_id

In [28]:
print(ref_token_id)

0


In [21]:
def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    output = model(inputs, token_type_ids=token_type_ids,
                 position_ids=position_ids, attention_mask=attention_mask, )
    return output.start_logits, output.end_logits

In [22]:
def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
    pred = predict(inputs,
                   token_type_ids=token_type_ids,
                   position_ids=position_ids,
                   attention_mask=attention_mask)
    pred = pred[position]
    return pred.max(1).values

In [23]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

In [None]:
def construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id):
    question_ids = tokenizer.encode(question, add_special_tokens=False)
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    # construct input token ids
    input_ids = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]

    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(question_ids) + [sep_token_id] + \
        [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(question_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                    token_type_ids=None, ref_token_type_ids=None, \
                                    position_ids=None, ref_position_ids=None):
    input_embeddings = model.bert.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = model.bert.embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)
    
    return input_embeddings, ref_input_embeddings