In [91]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from captum.attr import IntegratedGradients

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from transformers import BertTokenizer, BertForSequenceClassification, BertConfig

from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients
from sklearn.model_selection import train_test_split
from spacy.lang.en import English
from splitbert.textsplit import text_segmentation
from splitbert.splitbert import SplitBertConcatEncoderModel
from splitbert.splitbert import conduct_input_ids_and_attention_masks
from splitbert.splitbert import make_masks
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import tqdm

# Data Preparation

In [92]:
post_df = pd.read_csv('../predicting-satisfaction-using-graphs/csv/dataset/liwc_post.csv', encoding='UTF-8')
comment_df = pd.read_csv('../predicting-satisfaction-using-graphs/csv/dataset/liwc_comment.csv', encoding='UTF-8')
reply_df = pd.read_csv('../predicting-satisfaction-using-graphs/csv/dataset/avg_satisfaction_raw_0-999.csv', encoding='ISO-8859-1')

modes = [['seg', 'seg', 'snt']]

nlp = English()
nlp.add_pipe("sentencizer")

# satisfaction score (y)
satisfactions_float = list(reply_df['satisfy_composite'])
satisfactions = []

for s in satisfactions_float:
    if s < 3.5:
        satisfactions.append(0)
    elif s < 5:
        satisfactions.append(1)
    else:
        satisfactions.append(2)

reply_contents = list(reply_df['replyContent'])
post_contents = list(post_df['content'])
comment_bodies = list(comment_df['content'])


def get_sequences(contents, mode):
    sequences = []

    if mode == 'all':
        for content in contents:
            sequences.append([content])
    elif mode == 'seg':
        for content in contents:
            sentences = list(map(lambda x: str(x), list(nlp(content).sents)))
            sequences.append(text_segmentation(sentences))
    else:  # sentences
        for content in contents:
            sequences.append(list(map(lambda x: str(x), list(nlp(content).sents))))

    return sequences


for mode in modes:
    print(mode)
    post_sequences = get_sequences(post_contents, mode[0])
    comment_sequences = get_sequences(comment_bodies, mode[1])
    reply_sequences = get_sequences(reply_contents, mode[2])

    data = []
    max_post, max_comment, max_reply = 0, 0, 0
    i = 0
    for post, comment, reply, satisfaction, satisfaction_float in zip(post_sequences, comment_sequences,
                                                                          reply_sequences, satisfactions,
                                                                          satisfactions_float):
        if len(post) > max_post:
            max_post = len(post)
        if len(comment) > max_comment:
            max_comment = len(comment)
        if len(reply) > max_reply:
            max_reply = len(reply)

        data.append([i, post, comment, reply, satisfaction, satisfaction_float])
        i += 1

    print(max_post, max_comment, max_reply)
    max_count = max(max_post, max_comment, max_reply)
    print(max_count)

    columns = ['index', 'post_contents', 'comment_contents', 'reply_contents', 'label', 'score']
    df = pd.DataFrame(data, columns=columns)

    # data split (train & test sets)
    idx_train, idx_remain = train_test_split(df.index.values, test_size=0.20, random_state=42)
    idx_val, idx_test = train_test_split(idx_remain, test_size=0.50, random_state=42)

    train_df = df.iloc[idx_train]
    val_df = df.iloc[idx_val]
    test_df = df.iloc[idx_test]

    count_min_label = min(train_df['label'].value_counts())

    labels = [0, 1, 2]

    train_sample_df = pd.DataFrame([], columns=columns)

    for label in labels:
        tmp = train_df[train_df['label'] == label]
        tmp_sampled = tmp.sample(frac=1).iloc[:count_min_label]
        train_sample_df = pd.concat([train_sample_df, tmp_sampled])

    train_sample_df = train_sample_df.sample(frac=1)

['seg', 'seg', 'snt']
10 4 10
10


# Model Preparation

In [93]:
def forward_func_ig(inputs):
    now = 0

    for embeddings, count in zip(inputs, sentence_counts):
        embeddings = embeddings.unsqueeze(0)
        embeddings = embeddings.swapaxes(0, 1)

        embeddings = model.pe(embeddings)

        src_mask, src_key_padding_mask = make_masks(model.max_len, count, device)

        encoder_output = model.encoder(embeddings, mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        encoder_output = torch.mean(encoder_output[:count], dim=0)

        if now == 0:
            result_outputs = encoder_output
        else:
            result_outputs = torch.cat([result_outputs, encoder_output], dim=1)
        now += 1
    
    outputs = model.classifier1(result_outputs)
    logits = model.classifier2(outputs)
    return logits

In [94]:
def forward_func_ig2(inputs, sentence_counts):
    now = 0

    for encoder_output, count in zip(inputs, sentence_counts):
        encoder_output = encoder_output.swapaxes(0, 1)
        encoder_output = torch.mean(encoder_output[:count], dim=0)

        if now == 0:
            result_outputs = encoder_output
        else:
            result_outputs = torch.cat([result_outputs, encoder_output], dim=1)
        now += 1
        
    outputs = model.classifier1(result_outputs)
    logits = model.classifier2(outputs)
    return logits

ig = IntegratedGradients(forward_func_ig2)

In [95]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device('cpu')

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

model_path = f'../predicting-satisfaction-using-graphs/splitbert/model/seg_seg_snt/epoch_4.model'
model = SplitBertConcatEncoderModel(num_labels=len(labels), embedding_size=384, max_len=max_count, device='cpu', pc_segmentation=False)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to('cpu')
model.eval()

for param in model.sbert.parameters():
    param.requires_grad = False
    
for param in model.bert.parameters():
    param.requires_grad = False

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [96]:
def construct_input_ref_pair(triplet):
    input_ids_list, ref_input_ids_list, attention_masks_list, sentence_count_list = [], [], [], []
    
    for contents in triplet:
        result = tokenizer(contents, pad_to_max_length=True, truncation=True, max_length=256, return_tensors='pt')
        
        input_ids = result['input_ids']
        sentence_count_list.append(torch.tensor(len(input_ids)).unsqueeze(0))
        attention_masks = result['attention_mask']
        
        pad = (0, 0, 0, max_count-len(input_ids))
        input_ids = nn.functional.pad(input_ids, pad, "constant", 0)
        attention_masks = nn.functional.pad(attention_masks, pad, "constant", 0)
        ref_input_ids = torch.zeros_like(input_ids)

        input_ids_list.append(input_ids.unsqueeze(0))
        ref_input_ids_list.append(ref_input_ids.unsqueeze(0))
        attention_masks_list.append(attention_masks.unsqueeze(0))
    
    return input_ids_list, ref_input_ids_list, attention_masks_list, sentence_count_list

In [97]:
def summarize_attributions(attribution):
    attributions = attribution.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [214]:
def splitbert_integrated_gradient(index, post, comment, reply, p_sentences, c_sentences, r_sentences, label, score):
    
    def post_or_comment_or_reply(index):
        for i, sentences in enumerate(all_sentences):
            if all_tokens[index] in sentences:
                if i == 0:
                    return 'post'
                elif i == 1:
                    return 'comment'
                else:
                    return 'reply'
    
    input_ids, ref_input_ids, attention_masks, sentence_counts = construct_input_ref_pair([post, comment, reply])
    
    
    one_hot_labels = torch.nn.functional.one_hot(torch.tensor(label), num_classes=len(labels))
    inputs = {'labels': one_hot_labels.type(torch.float).to(device),
          'input_ids1': input_ids[0].to(device),
          'input_ids2': input_ids[1].to(device),
          'input_ids3': input_ids[2].to(device),
          'attention_mask1': attention_masks[0].to(device),
          'attention_mask2': attention_masks[1].to(device),
          'attention_mask3': attention_masks[2].to(device),
          'sentence_count1': sentence_counts[0].to(device),
          'sentence_count2': sentence_counts[1].to(device),
          'sentence_count3': sentence_counts[2].to(device),
          'mode': 'triplet'
         }
    
    with torch.no_grad():
        embeddings = model(**inputs).hidden_states
        
    inputs = torch.stack(embeddings, dim=0)
    pred = forward_func_ig2(inputs, sentence_counts)
    # print(f'answer: {label}, predict: {torch.argmax(pred)}')
    
    result = []
    
    attribution, delta = ig.attribute(inputs=inputs, target=torch.argmax(pred), additional_forward_args=sentence_counts, return_convergence_delta=True)
    attributions = summarize_attributions(attribution)
    f_attributions = torch.flatten(attributions)
    f_attributions = f_attributions[f_attributions.nonzero()].squeeze(1)
    abs_attributions = list(map(abs, map(float, f_attributions)))
    idx_attributions = []
    for j in range(len(abs_attributions)):
        idx_attributions.append((j, abs_attributions[j], f_attributions[j].item()))
    idx_attributions.sort(key=lambda x:x[1], reverse=True)
        
    top3 = idx_attributions[:3]
        
    all_sentences = [post, comment, reply]
    all_tokens = [item for all_sentences in all_sentences for item in all_sentences]
        
    score_vis = viz.VisualizationDataRecord(f_attributions,
                                            torch.max(torch.softmax(pred, dim=0)),
                                            torch.argmax(pred),  # predicted label
                                            label,  # true label
                                            p_sentences + ' ' + c_sentences + ' ' + r_sentences,
                                            f_attributions.sum(),
                                            all_tokens,
                                            delta)
    raw_text = ' '.join(post) + ' '.join(comment) + ' '.join(reply)
        
    where = []
        
    for j in range(len(top3)):
        where.append(post_or_comment_or_reply(top3[j][0]))
        
    # print(max_where, min_where)
        
    # print(f_attributions)
    # print('\033[1m', 'Visualization For Score', '\033[0m')
    # viz.visualize_text([score_vis])
        
    result.append([index, post, comment, reply, score, all_tokens[top3[0][0]], top3[0][2], where[0]])
    result.append([index, post, comment, reply, score, all_tokens[top3[1][0]], top3[1][2], where[1]])
    result.append([index, post, comment, reply, score, all_tokens[top3[2][0]], top3[2][2], where[2]])
    
    return result, label, torch.argmax(pred).item()

In [202]:
train_index_df = pd.read_csv('../predicting-satisfaction-using-graphs/csv/dataset/train_index.csv', encoding='UTF-8')
val_index_df = pd.read_csv('../predicting-satisfaction-using-graphs/csv/dataset/val_index.csv', encoding='UTF-8')

train_index = sorted(list(train_index_df.idx.values))
val_index = sorted(list(val_index_df.idx.values))

In [203]:
print(len(train_index))
print(len(val_index))

546
100


In [204]:
def main(index_list):
    result_list = []
    label_pred_list = []
    
    for i in range(len(index_list)):
        result, label, pred = splitbert_integrated_gradient(i, post_sequences[i], comment_sequences[i], reply_sequences[i], post_contents[i], comment_bodies[i], reply_contents[i], satisfactions[i], satisfactions_float[i])
        result_list.extend(result)
        
        for i in range(3):
            label_pred_list.append((label, pred))
        
    return result_list, label_pred_list   

In [215]:
if __name__ == "__main__":
    result_list, label_pred_list = main(val_index)
    path = '../predicting-satisfaction-using-graphs/csv/integrated_gradient_result/'
    columns = ['idx', 'post_text', 'comment_text', 'reply_text', 'score', 'attr_sentence', 'attr_score', 'origin']
    
    print(len(result_list), len(label_pred_list))
    
    for i in range(3):
        for j in range(3):
            result = []
            for k in range(len(result_list)):
                if label_pred_list[k] == (i, j):
                    result.append(result_list[k])
            result_df = pd.DataFrame(result, columns=columns)
            result_df.to_csv(path + f'label_{i}_pred_{j}_attribution.csv')



300 300


In [216]:
for i in range(3):
    for j in range(3):
        print(f'label: {i}, pred: {j}')
        
        result_df = pd.read_csv(f'../predicting-satisfaction-using-graphs/csv/integrated_gradient_result/label_{i}_pred_{j}_attribution.csv', encoding='UTF-8')
        print(result_df.origin.value_counts())
        print()

label: 0, pred: 0
reply      24
comment    17
post       10
Name: origin, dtype: int64

label: 0, pred: 1
reply      5
comment    3
post       1
Name: origin, dtype: int64

label: 0, pred: 2
Series([], Name: origin, dtype: int64)

label: 1, pred: 0
reply      16
comment    13
post        7
Name: origin, dtype: int64

label: 1, pred: 1
reply      67
post       18
comment    17
Name: origin, dtype: int64

label: 1, pred: 2
reply      11
comment     6
post        1
Name: origin, dtype: int64

label: 2, pred: 0
reply      2
comment    1
Name: origin, dtype: int64

label: 2, pred: 1
reply      12
post        2
comment     1
Name: origin, dtype: int64

label: 2, pred: 2
reply      47
comment    12
post        7
Name: origin, dtype: int64



In [212]:
result_df = pd.DataFrame(result_list, columns=columns)
result_df.origin.value_counts()

reply      184
comment     70
post        46
Name: origin, dtype: int64

In [217]:
result_df

Unnamed: 0.1,Unnamed: 0,idx,post_text,comment_text,reply_text,score,attr_sentence,attr_score,origin
0,0,2,"["" I am 30 years old and my girlfriend is 24 a...","[""Depression is not an excuse for this behavio...","['Thank you.', 'This is completely new to me a...",5.35,Thank you.,0.775491,reply
1,1,2,"["" I am 30 years old and my girlfriend is 24 a...","[""Depression is not an excuse for this behavio...","['Thank you.', 'This is completely new to me a...",5.35,This is completely new to me and I wanted to k...,0.566843,reply
2,2,2,"["" I am 30 years old and my girlfriend is 24 a...","[""Depression is not an excuse for this behavio...","['Thank you.', 'This is completely new to me a...",5.35,"Depression is not an excuse for this behavior,...",-0.252374,comment
3,3,5,"[' Hello. I have had insecurities, anxiety, de...","[""I didn't mean necessarily trying harder. Jus...","['Yeah that would be good.', ""I'll try to find...",5.05,Thank you.,0.678371,reply
4,4,5,"[' Hello. I have had insecurities, anxiety, de...","[""I didn't mean necessarily trying harder. Jus...","['Yeah that would be good.', ""I'll try to find...",5.05,I didn't mean necessarily trying harder. Just ...,-0.499317,comment
...,...,...,...,...,...,...,...,...,...
61,61,97,"[""I made it though the day. I had a breakdown ...","['Hey, if you want to talk about it just messa...","[""Thank you, if it gets too hard and I'm actua...",5.65,I made it though the day. I had a breakdown un...,-0.142034,post
62,62,97,"[""I made it though the day. I had a breakdown ...","['Hey, if you want to talk about it just messa...","[""Thank you, if it gets too hard and I'm actua...",5.65,"Hey, if you want to talk about it just message...",-0.124270,comment
63,63,99,"[""Hi I'm 25.. and am a failure in life..I have...","["" Doesn't sound like your rubbish. It sounds ...","[""I really appreciate your comment.. I know th...",6.15,I really appreciate your comment.. I know ther...,0.994649,reply
64,64,99,"[""Hi I'm 25.. and am a failure in life..I have...","["" Doesn't sound like your rubbish. It sounds ...","[""I really appreciate your comment.. I know th...",6.15,Doesn't sound like your rubbish. It sounds li...,-0.086146,comment


In [220]:
result_df1 = pd.read_csv(f'../predicting-satisfaction-using-graphs/csv/integrated_gradient_result/label_0_pred_0_attribution.csv', encoding='UTF-8')
result_df1

Unnamed: 0.1,Unnamed: 0,idx,post_text,comment_text,reply_text,score,attr_sentence,attr_score,origin
0,0,3,"["" Almost every day for months, I've gone to a...",['Then you have nothing to lose. Tell her dude.'],"[""I just can't, I'm just way too shy for it.""]",3.3,"I just can't, I'm just way too shy for it.",0.898281,reply
1,1,3,"["" Almost every day for months, I've gone to a...",['Then you have nothing to lose. Tell her dude.'],"[""I just can't, I'm just way too shy for it.""]",3.3,Then you have nothing to lose. Tell her dude.,0.437093,comment
2,2,3,"["" Almost every day for months, I've gone to a...",['Then you have nothing to lose. Tell her dude.'],"[""I just can't, I'm just way too shy for it.""]",3.3,"I've considered that the easiest way, and the...",-0.043678,post
3,3,9,"["" I was never really a jealous person. I like...",['you could try to explain this dilemma to the...,"['Normally they believe my pretending.', ""Some...",3.25,Normally they believe my pretending.,0.704133,reply
4,4,9,"["" I was never really a jealous person. I like...",['you could try to explain this dilemma to the...,"['Normally they believe my pretending.', ""Some...",3.25,you could try to explain this dilemma to them....,0.693229,comment
5,5,9,"["" I was never really a jealous person. I like...",['you could try to explain this dilemma to the...,"['Normally they believe my pretending.', ""Some...",3.25,Sometimes it's just really difficult and tirin...,0.134231,reply
6,6,10,['Im the guy no one suspects is as messed up a...,"['As bad as I feel about the depression, and I...","['no he came out swinging first, he has had an...",2.45,"no he came out swinging first, he has had and ...",0.75338,reply
7,7,10,['Im the guy no one suspects is as messed up a...,"['As bad as I feel about the depression, and I...","['no he came out swinging first, he has had an...",2.45,"As bad as I feel about the depression, and I a...",0.646939,comment
8,8,10,['Im the guy no one suspects is as messed up a...,"['As bad as I feel about the depression, and I...","['no he came out swinging first, he has had an...",2.45,Im the guy no one suspects is as messed up as ...,0.11785,post
9,9,11,[' I finally did it. It took me a week to work...,"[""You are saying this because you don't know h...","['Bruh, I went years without having anybody to...",3.1,"Bruh, I went years without having anybody to s...",0.553607,reply


In [223]:
result_df2 = pd.read_csv(f'../predicting-satisfaction-using-graphs/csv/integrated_gradient_result/label_1_pred_1_attribution.csv', encoding='UTF-8')
result_df2

Unnamed: 0.1,Unnamed: 0,idx,post_text,comment_text,reply_text,score,attr_sentence,attr_score,origin
0,0,1,"[""I don't really know why, and maybe it's just...",['I use to love taking cold showers when I was...,"[""I hate it when I'm not really depressed."", ""...",4.25,How'd you end up taking cold showers when you ...,0.687577,reply
1,1,1,"[""I don't really know why, and maybe it's just...",['I use to love taking cold showers when I was...,"[""I hate it when I'm not really depressed."", ""...",4.25,"I'm always a little depressed, but when it get...",0.508866,reply
2,2,1,"[""I don't really know why, and maybe it's just...",['I use to love taking cold showers when I was...,"[""I hate it when I'm not really depressed."", ""...",4.25,I use to love taking cold showers when I was y...,0.391237,comment
3,3,4,"[""Can you love when you're depressed? When dep...","[""Yes, of course. I think it is because we lov...","[""It's romantic relationships I struggle with....",3.80,I'm ok for about 4 weeks then I just go numb a...,0.710685,reply
4,4,4,"[""Can you love when you're depressed? When dep...","[""Yes, of course. I think it is because we lov...","[""It's romantic relationships I struggle with....",3.80,My emotions just seem to shut down and then I ...,0.500283,reply
...,...,...,...,...,...,...,...,...,...
97,97,93,"["" I went to my doctor and told him I want to ...","[""No shit small world Haha. What do you do for...","['I work at a best buy.', 'How do like doing t...",4.60,I work at a best buy.,0.507964,reply
98,98,93,"["" I went to my doctor and told him I want to ...","[""No shit small world Haha. What do you do for...","['I work at a best buy.', 'How do like doing t...",4.60,No shit small world Haha. What do you do for w...,0.189059,comment
99,99,96,"["" Of course, these aren't the only emotions I...",['Which med are you on? What is it officially ...,"['150 mg Bupropion.', 'To increase dopamine an...",3.90,Which med are you on? What is it officially pr...,0.744541,comment
100,100,96,"["" Of course, these aren't the only emotions I...",['Which med are you on? What is it officially ...,"['150 mg Bupropion.', 'To increase dopamine an...",3.90,That's what my doctor said.,0.390975,reply


In [222]:
result_df3 = pd.read_csv(f'../predicting-satisfaction-using-graphs/csv/integrated_gradient_result/label_2_pred_2_attribution.csv', encoding='UTF-8')
result_df3

Unnamed: 0.1,Unnamed: 0,idx,post_text,comment_text,reply_text,score,attr_sentence,attr_score,origin
0,0,2,"["" I am 30 years old and my girlfriend is 24 a...","[""Depression is not an excuse for this behavio...","['Thank you.', 'This is completely new to me a...",5.35,Thank you.,0.775491,reply
1,1,2,"["" I am 30 years old and my girlfriend is 24 a...","[""Depression is not an excuse for this behavio...","['Thank you.', 'This is completely new to me a...",5.35,This is completely new to me and I wanted to k...,0.566843,reply
2,2,2,"["" I am 30 years old and my girlfriend is 24 a...","[""Depression is not an excuse for this behavio...","['Thank you.', 'This is completely new to me a...",5.35,"Depression is not an excuse for this behavior,...",-0.252374,comment
3,3,5,"[' Hello. I have had insecurities, anxiety, de...","[""I didn't mean necessarily trying harder. Jus...","['Yeah that would be good.', ""I'll try to find...",5.05,Thank you.,0.678371,reply
4,4,5,"[' Hello. I have had insecurities, anxiety, de...","[""I didn't mean necessarily trying harder. Jus...","['Yeah that would be good.', ""I'll try to find...",5.05,I didn't mean necessarily trying harder. Just ...,-0.499317,comment
...,...,...,...,...,...,...,...,...,...
61,61,97,"[""I made it though the day. I had a breakdown ...","['Hey, if you want to talk about it just messa...","[""Thank you, if it gets too hard and I'm actua...",5.65,I made it though the day. I had a breakdown un...,-0.142034,post
62,62,97,"[""I made it though the day. I had a breakdown ...","['Hey, if you want to talk about it just messa...","[""Thank you, if it gets too hard and I'm actua...",5.65,"Hey, if you want to talk about it just message...",-0.124270,comment
63,63,99,"[""Hi I'm 25.. and am a failure in life..I have...","["" Doesn't sound like your rubbish. It sounds ...","[""I really appreciate your comment.. I know th...",6.15,I really appreciate your comment.. I know ther...,0.994649,reply
64,64,99,"[""Hi I'm 25.. and am a failure in life..I have...","["" Doesn't sound like your rubbish. It sounds ...","[""I really appreciate your comment.. I know th...",6.15,Doesn't sound like your rubbish. It sounds li...,-0.086146,comment


In [118]:
main(train_index, '../predicting-satisfaction-using-graphs/csv/integrated_gradient_result/train_set_result_abs.csv')
main(val_index, '../predicting-satisfaction-using-graphs/csv/integrated_gradient_result/test_set_result_abs.csv')

In [108]:
train_result_df = pd.read_csv('../predicting-satisfaction-using-graphs/csv/integrated_gradient_result/train_set_result.csv', encoding='UTF-8')
test_result_df = pd.read_csv('../predicting-satisfaction-using-graphs/csv/integrated_gradient_result/test_set_result.csv', encoding='UTF-8')

In [109]:
print(train_result_df.max_where.value_counts())
print(train_result_df.min_where.value_counts())

reply      937
post       372
comment    329
Name: max_where, dtype: int64
reply      671
comment    568
post       399
Name: min_where, dtype: int64


In [110]:
print(test_result_df.max_where.value_counts())
print(test_result_df.min_where.value_counts())

reply      168
post        74
comment     58
Name: max_where, dtype: int64
reply      119
comment    107
post        74
Name: min_where, dtype: int64


In [127]:
train_result_df2 = pd.read_csv('../predicting-satisfaction-using-graphs/csv/integrated_gradient_result/train_set_result_abs.csv', encoding='UTF-8')
test_result_df2 = pd.read_csv('../predicting-satisfaction-using-graphs/csv/integrated_gradient_result/test_set_result_abs.csv', encoding='UTF-8')

In [132]:
train_result_df2.columns = ['', 'idx', 'raw_text', 'pred', 'label', 'target', 'attr_sentence', 'max_where']
test_result_df2.columns = ['', 'idx', 'raw_text', 'pred', 'label', 'target', 'attr_sentence', 'max_where']

In [133]:
print(train_result_df2.max_where.value_counts())

reply      1474
comment     103
post         61
Name: max_where, dtype: int64


In [131]:
print(test_result_df2.max_where.value_counts())

reply      269
comment     18
post        13
Name: max_where, dtype: int64


In [134]:
correct_df = test_result_df2[test_result_df2['pred'] == test_result_df2['label']]