In [109]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from captum.attr import IntegratedGradients

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from transformers import BertTokenizer, BertForSequenceClassification, BertConfig

from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients
from sklearn.model_selection import train_test_split
from spacy.lang.en import English
from splitbert.textsplit import text_segmentation
from splitbert.splitbert import SplitBertConcatEncoderModel
from splitbert.splitbert import conduct_input_ids_and_attention_masks
from splitbert.splitbert import make_masks
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import tqdm

# Data Preparation

In [92]:
post_df = pd.read_csv('../predicting-satisfaction-using-graphs/csv/dataset/liwc_post.csv', encoding='UTF-8')
comment_df = pd.read_csv('../predicting-satisfaction-using-graphs/csv/dataset/liwc_comment.csv', encoding='UTF-8')
reply_df = pd.read_csv('../predicting-satisfaction-using-graphs/csv/dataset/avg_satisfaction_raw_0-999.csv', encoding='ISO-8859-1')

modes = [['seg', 'seg', 'snt']]

nlp = English()
nlp.add_pipe("sentencizer")

# satisfaction score (y)
satisfactions_float = list(reply_df['satisfy_composite'])
satisfactions = []

for s in satisfactions_float:
    if s < 3.5:
        satisfactions.append(0)
    elif s < 5:
        satisfactions.append(1)
    else:
        satisfactions.append(2)

reply_contents = list(reply_df['replyContent'])
post_contents = list(post_df['content'])
comment_bodies = list(comment_df['content'])


def get_sequences(contents, mode):
    sequences = []

    if mode == 'all':
        for content in contents:
            sequences.append([content])
    elif mode == 'seg':
        for content in contents:
            sentences = list(map(lambda x: str(x), list(nlp(content).sents)))
            sequences.append(text_segmentation(sentences))
    else:  # sentences
        for content in contents:
            sequences.append(list(map(lambda x: str(x), list(nlp(content).sents))))

    return sequences


for mode in modes:
    print(mode)
    post_sequences = get_sequences(post_contents, mode[0])
    comment_sequences = get_sequences(comment_bodies, mode[1])
    reply_sequences = get_sequences(reply_contents, mode[2])

    data = []
    max_post, max_comment, max_reply = 0, 0, 0
    i = 0
    for post, comment, reply, satisfaction, satisfaction_float in zip(post_sequences, comment_sequences,
                                                                          reply_sequences, satisfactions,
                                                                          satisfactions_float):
        if len(post) > max_post:
            max_post = len(post)
        if len(comment) > max_comment:
            max_comment = len(comment)
        if len(reply) > max_reply:
            max_reply = len(reply)

        data.append([i, post, comment, reply, satisfaction, satisfaction_float])
        i += 1

    print(max_post, max_comment, max_reply)
    max_count = max(max_post, max_comment, max_reply)
    print(max_count)

    columns = ['index', 'post_contents', 'comment_contents', 'reply_contents', 'label', 'score']
    df = pd.DataFrame(data, columns=columns)

    # data split (train & test sets)
    idx_train, idx_remain = train_test_split(df.index.values, test_size=0.20, random_state=42)
    idx_val, idx_test = train_test_split(idx_remain, test_size=0.50, random_state=42)

    train_df = df.iloc[idx_train]
    val_df = df.iloc[idx_val]
    test_df = df.iloc[idx_test]

    count_min_label = min(train_df['label'].value_counts())

    labels = [0, 1, 2]

    train_sample_df = pd.DataFrame([], columns=columns)

    for label in labels:
        tmp = train_df[train_df['label'] == label]
        tmp_sampled = tmp.sample(frac=1).iloc[:count_min_label]
        train_sample_df = pd.concat([train_sample_df, tmp_sampled])

    train_sample_df = train_sample_df.sample(frac=1)

['seg', 'seg', 'snt']
10 4 10
10


# Model Preparation

In [93]:
def forward_func_ig(inputs):
    now = 0

    for embeddings, count in zip(inputs, sentence_counts):
        embeddings = embeddings.unsqueeze(0)
        embeddings = embeddings.swapaxes(0, 1)

        embeddings = model.pe(embeddings)

        src_mask, src_key_padding_mask = make_masks(model.max_len, count, device)

        encoder_output = model.encoder(embeddings, mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        encoder_output = torch.mean(encoder_output[:count], dim=0)

        if now == 0:
            result_outputs = encoder_output
        else:
            result_outputs = torch.cat([result_outputs, encoder_output], dim=1)
        now += 1
    
    outputs = model.classifier1(result_outputs)
    logits = model.classifier2(outputs)
    return logits

In [94]:
def forward_func_ig2(inputs, sentence_counts):
    now = 0

    for encoder_output, count in zip(inputs, sentence_counts):
        encoder_output = encoder_output.swapaxes(0, 1)
        encoder_output = torch.mean(encoder_output[:count], dim=0)

        if now == 0:
            result_outputs = encoder_output
        else:
            result_outputs = torch.cat([result_outputs, encoder_output], dim=1)
        now += 1
        
    outputs = model.classifier1(result_outputs)
    logits = model.classifier2(outputs)
    return logits

ig = IntegratedGradients(forward_func_ig2)

In [95]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device('cpu')

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

model_path = f'../predicting-satisfaction-using-graphs/splitbert/model/seg_seg_snt/epoch_6.model'
model = SplitBertConcatEncoderModel(num_labels=len(labels), embedding_size=384, max_len=max_count, device='cpu', pc_segmentation=False)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to('cpu')
model.eval()

for param in model.sbert.parameters():
    param.requires_grad = False
    
for param in model.bert.parameters():
    param.requires_grad = False

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [96]:
def construct_input_ref_pair(triplet):
    input_ids_list, ref_input_ids_list, attention_masks_list, sentence_count_list = [], [], [], []
    
    for contents in triplet:
        result = tokenizer(contents, pad_to_max_length=True, truncation=True, max_length=256, return_tensors='pt')
        
        input_ids = result['input_ids']
        sentence_count_list.append(torch.tensor(len(input_ids)).unsqueeze(0))
        attention_masks = result['attention_mask']
        
        pad = (0, 0, 0, max_count-len(input_ids))
        input_ids = nn.functional.pad(input_ids, pad, "constant", 0)
        attention_masks = nn.functional.pad(attention_masks, pad, "constant", 0)
        ref_input_ids = torch.zeros_like(input_ids)

        input_ids_list.append(input_ids.unsqueeze(0))
        ref_input_ids_list.append(ref_input_ids.unsqueeze(0))
        attention_masks_list.append(attention_masks.unsqueeze(0))
    
    return input_ids_list, ref_input_ids_list, attention_masks_list, sentence_count_list

In [97]:
def summarize_attributions(attribution):
    attributions = attribution.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [98]:
def main(post, comment, reply, p_sentences, c_sentences, r_sentences, label):
    input_ids, ref_input_ids, attention_masks, sentence_counts = construct_input_ref_pair([post, comment, reply])
    
    
    one_hot_labels = torch.nn.functional.one_hot(torch.tensor(label), num_classes=len(labels))
    inputs = {'labels': one_hot_labels.type(torch.float).to(device),
          'input_ids1': input_ids[0].to(device),
          'input_ids2': input_ids[1].to(device),
          'input_ids3': input_ids[2].to(device),
          'attention_mask1': attention_masks[0].to(device),
          'attention_mask2': attention_masks[1].to(device),
          'attention_mask3': attention_masks[2].to(device),
          'sentence_count1': sentence_counts[0].to(device),
          'sentence_count2': sentence_counts[1].to(device),
          'sentence_count3': sentence_counts[2].to(device),
          'mode': 'triplet'
         }
    
    with torch.no_grad():
        embeddings = model(**inputs).hidden_states
        
    inputs = torch.stack(embeddings, dim=0)
    pred = forward_func_ig2(inputs, sentence_counts)
    print(f'answer: {label}, predict: {torch.argmax(pred)}')
    
    for i in range(len(labels)):
        attribution, delta = ig.attribute(inputs=inputs, target=i, additional_forward_args=sentence_counts, return_convergence_delta=True)
        attributions = summarize_attributions(attribution)
        f_attributions = torch.flatten(attributions)
        f_attributions = f_attributions[f_attributions.nonzero()].squeeze(1)
        
        all_sentences = [post, comment, reply]
        all_tokens = [item for all_sentences in all_sentences for item in all_sentences]
        
        score_vis = viz.VisualizationDataRecord(f_attributions,
                                        torch.max(torch.softmax(pred, dim=0)),
                                        torch.argmax(pred),  # predicted label
                                        label,  # true label
                                        p_sentences + ' ' + c_sentences + ' ' + r_sentences,
                                        f_attributions.sum(),
                                        all_tokens,
                                        delta)
        print(f_attributions)
        print('\033[1m', 'Visualization For Score', '\033[0m')
        viz.visualize_text([score_vis])

In [118]:
N = 23
main(post_sequences[N], comment_sequences[N], reply_sequences[N], post_contents[N], comment_bodies[N], reply_contents[N], satisfactions[N])

answer: 0, predict: 0
tensor([0.2182, 0.6925, 0.3547, 0.0668, 0.2128, 0.2503, 0.3623, 0.3215],
       dtype=torch.float64)
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"I'm wondering if anyone here has thought about suicide methods that come close to killing you but can't. I'm not ready to die but I want to hurt. I want people to see physically how much I'm hurting emotionally.I've thought about jumping in front of cars or falling from non-lethal heights. Those aren't the nicest methods but I'll go with them if I can't find a better one. Donât do these things please. If other people donât believe how hurt you are emotionally, try your best to let it not affect you. No it needs to be seen. I've talked with people about it all I can. They don't think I'm being realistic. I'm capable of killing myself and one day I will if I don't get help. Maybe sooner rather than later. Believe it or not this is the responsible choice.",2.48,"I'm wondering if anyone here has thought about suicide methods that come close to killing you but can't. I'm not ready to die but I want to hurt. I want people to see physically how much I'm hurting emotionally. I've thought about jumping in front of cars or falling from non-lethal heights. Those aren't the nicest methods but I'll go with them if I can't find a better one. Donât do these things please. If other people donât believe how hurt you are emotionally, try your best to let it not affect you. No it needs to be seen. I've talked with people about it all I can. They don't think I'm being realistic. I'm capable of killing myself and one day I will if I don't get help. Maybe sooner rather than later. Believe it or not this is the responsible choice."
,,,,


tensor([-0.4113,  0.8138, -0.1699,  0.1730,  0.2604,  0.0038, -0.1420, -0.1478],
       dtype=torch.float64)
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"I'm wondering if anyone here has thought about suicide methods that come close to killing you but can't. I'm not ready to die but I want to hurt. I want people to see physically how much I'm hurting emotionally.I've thought about jumping in front of cars or falling from non-lethal heights. Those aren't the nicest methods but I'll go with them if I can't find a better one. Donât do these things please. If other people donât believe how hurt you are emotionally, try your best to let it not affect you. No it needs to be seen. I've talked with people about it all I can. They don't think I'm being realistic. I'm capable of killing myself and one day I will if I don't get help. Maybe sooner rather than later. Believe it or not this is the responsible choice.",0.38,"I'm wondering if anyone here has thought about suicide methods that come close to killing you but can't. I'm not ready to die but I want to hurt. I want people to see physically how much I'm hurting emotionally. I've thought about jumping in front of cars or falling from non-lethal heights. Those aren't the nicest methods but I'll go with them if I can't find a better one. Donât do these things please. If other people donât believe how hurt you are emotionally, try your best to let it not affect you. No it needs to be seen. I've talked with people about it all I can. They don't think I'm being realistic. I'm capable of killing myself and one day I will if I don't get help. Maybe sooner rather than later. Believe it or not this is the responsible choice."
,,,,


tensor([ 0.0453, -0.8829, -0.1641, -0.1360, -0.2736, -0.1821, -0.2018, -0.1559],
       dtype=torch.float64)
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"I'm wondering if anyone here has thought about suicide methods that come close to killing you but can't. I'm not ready to die but I want to hurt. I want people to see physically how much I'm hurting emotionally.I've thought about jumping in front of cars or falling from non-lethal heights. Those aren't the nicest methods but I'll go with them if I can't find a better one. Donât do these things please. If other people donât believe how hurt you are emotionally, try your best to let it not affect you. No it needs to be seen. I've talked with people about it all I can. They don't think I'm being realistic. I'm capable of killing myself and one day I will if I don't get help. Maybe sooner rather than later. Believe it or not this is the responsible choice.",-1.95,"I'm wondering if anyone here has thought about suicide methods that come close to killing you but can't. I'm not ready to die but I want to hurt. I want people to see physically how much I'm hurting emotionally. I've thought about jumping in front of cars or falling from non-lethal heights. Those aren't the nicest methods but I'll go with them if I can't find a better one. Donât do these things please. If other people donât believe how hurt you are emotionally, try your best to let it not affect you. No it needs to be seen. I've talked with people about it all I can. They don't think I'm being realistic. I'm capable of killing myself and one day I will if I don't get help. Maybe sooner rather than later. Believe it or not this is the responsible choice."
,,,,
