In [1]:
import os
import json
import warnings
import random
import copy

import dgl
import torch
import torch_geometric as tg

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import dgl.nn.pytorch as dglnn

from graphlime import GraphLIME
from mumin import MuminDataset

import torchmetrics as tm

from tqdm.notebook import tqdm
from collections import defaultdict

from torch_geometric.data import Data as tgData

from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification

from mumin_explainable.architectures.graphs import GAT
from mumin_explainable.processor.tweetnormalizer import normalizeTweet

from dotenv import load_dotenv

import mlflow


%matplotlib inline

warnings.filterwarnings('ignore')
_= torch.manual_seed(42)

# Setup mumin graph

In [2]:
size= 'small'
dataset_mumin = MuminDataset(
    twitter_bearer_token=os.getenv('TWITTER_BEARER_TOKEN'),
    size=size,
    dataset_path=f'./data/datasets/mumin-{size}.zip'
)
dataset_mumin.compile()
mumin_graph = dataset_mumin.to_dgl()
mumin_graph

2023-08-25 08:30:53,600 [INFO] Loading dataset
2023-08-25 08:31:56,689 [INFO] Outputting to DGL


Graph(num_nodes={'article': 1446, 'claim': 2083, 'hashtag': 27802, 'image': 1015, 'reply': 177816, 'tweet': 4061, 'user': 152038},
      num_edges={('article', 'has_article_inv', 'tweet'): 1890, ('claim', 'discusses_inv', 'tweet'): 4749, ('hashtag', 'has_hashtag_inv', 'tweet'): 2284, ('hashtag', 'has_hashtag_inv', 'user'): 49626, ('image', 'has_image_inv', 'tweet'): 1019, ('reply', 'posted_inv', 'user'): 177816, ('reply', 'quote_of', 'tweet'): 88495, ('reply', 'reply_to', 'tweet'): 78576, ('tweet', 'discusses', 'claim'): 4749, ('tweet', 'has_article', 'article'): 1890, ('tweet', 'has_hashtag', 'hashtag'): 2284, ('tweet', 'has_image', 'image'): 1019, ('tweet', 'mentions', 'user'): 1112, ('tweet', 'posted_inv', 'user'): 4061, ('tweet', 'quote_of_inv', 'reply'): 88495, ('tweet', 'reply_to_inv', 'reply'): 78576, ('tweet', 'retweeted_inv', 'user'): 12800, ('user', 'follows', 'user'): 17974, ('user', 'follows_inv', 'user'): 17974, ('user', 'has_hashtag', 'hashtag'): 49626, ('user', 'mentions

# Setup subgraph

In [3]:
user_df = dataset_mumin.nodes['user']
tweet_df = dataset_mumin.nodes['tweet']

user_posted_tweet_df = dataset_mumin.rels[('user', 'posted', 'tweet')]
user_posted_tweet_subgraph = dgl.edge_type_subgraph(mumin_graph, etypes=[('user', 'posted', 'tweet')])

## Filter language

In [4]:
LANG = 'multilingual'
tweet_ds_copy = dataset_mumin.nodes['tweet'].dropna()
lang_tweets = (tweet_ds_copy['lang'] == LANG).to_list()
del(tweet_ds_copy)

if LANG == 'multilingual':
    lang_tweets = [True] * len(lang_tweets)
tweet_train_mask = user_posted_tweet_subgraph.nodes['tweet'].data['train_mask'] & torch.tensor(lang_tweets)
tweet_val_mask = user_posted_tweet_subgraph.nodes['tweet'].data['val_mask'] & torch.tensor(lang_tweets)
tweet_test_mask = user_posted_tweet_subgraph.nodes['tweet'].data['test_mask'] & torch.tensor(lang_tweets)

# Setup dataset

In [5]:
edges_index = torch.cat([
    user_posted_tweet_subgraph.edges(etype='posted')[0].unsqueeze(0),
    user_posted_tweet_subgraph.edges(etype='posted')[1].unsqueeze(0)
], dim=0)
data = tgData(
    x=user_posted_tweet_subgraph.nodes['tweet'].data['feat'],
    y=user_posted_tweet_subgraph.nodes['tweet'].data['label'],
    train_mask=tweet_train_mask,
    val_mask=tweet_val_mask,
    test_mask=tweet_test_mask,
    edge_index=edges_index.long())

# Enhance with text-based features

In [6]:
from transformers import AutoModelForSequenceClassification

TEXT_DIM = 100

LANG_TOOL_MAP = {
    'multilingual': {
        'bertweet': AutoModel.from_pretrained('vinai/bertweet-base'),
        'tokenizer': AutoTokenizer.from_pretrained('vinai/bertweet-base', use_fast=False)
    },
    'en': {
        'bertweet': AutoModel.from_pretrained('vinai/bertweet-base'),
        'tokenizer': AutoTokenizer.from_pretrained('vinai/bertweet-base', use_fast=False)
    },
    'pt': {
        'bertweet': AutoModel.from_pretrained('melll-uff/bertweetbr'),
        'tokenizer': AutoTokenizer.from_pretrained('melll-uff/bertweetbr', normalization=True)
    },
    'es': {
        'bertweet': AutoModel.from_pretrained('pysentimiento/robertuito-base-cased'),
        'tokenizer': AutoTokenizer.from_pretrained('pysentimiento/robertuito-base-cased')
    }
}

bertweet = LANG_TOOL_MAP[LANG]['bertweet']
tokenizer = LANG_TOOL_MAP[LANG]['tokenizer']

def tweetencoder(x, text_dim):
    try:
        x = bertweet(torch.tensor([tokenizer.encode(normalizeTweet(x))])).pooler_output
    except:
        x = bertweet(torch.tensor([tokenizer.encode('')])).pooler_output
    return nn.Linear(768, text_dim)(x).tolist()[0]

In [7]:
tweet_df['text_encoding'] = str([0] * TEXT_DIM)
if LANG == 'multilingual':
    tweet_df['text_encoding'] = [tweetencoder(text, TEXT_DIM) for text in tweet_df['text']]
else:
    tweet_df['text_encoding'] = [tweetencoder(text, TEXT_DIM) if lang == LANG else str([0] * TEXT_DIM) for text,lang in zip(tweet_df['text'], tweet_df['lang'])]

In [8]:
new_embedding_columns = [f'emb{i}' for i in range(TEXT_DIM)]

tweet_embeddings_split_df = pd.DataFrame(
    [x if not isinstance(x, str) else eval(x) for x in tweet_df['text_encoding'].tolist()],
    index=tweet_df.index,
    columns=new_embedding_columns
)
tweet_df = pd.concat([tweet_df, tweet_embeddings_split_df], axis=1)
tweet_df.dropna(inplace=True)
display(tweet_df)

Unnamed: 0,tweet_id,text,created_at,lang,source,num_retweets,num_replies,num_quote_tweets,text_emb,lang_emb,...,emb90,emb91,emb92,emb93,emb94,emb95,emb96,emb97,emb98,emb99
0,1238947475471454220,Antes de llegar a los pulmones dura 4 días en ...,2020-03-14 21:57:51,es,Twitter for Android,8,3,0,"[-0.0467078, 0.25795, 0.119816095, 0.4975067, ...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...",...,0.126746,0.073426,0.018710,-0.030212,0.033471,-0.057429,-0.227508,-0.088288,0.034988,0.114474
1,1295062953000042496,Aeroporto de Dubai em chamas. 🤕😧 https://t.co/...,2020-08-16 18:20:43,pt,Twitter for Android,6,0,5,"[-0.04832051, 0.22119147, 0.10080599, 0.506648...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",...,0.070786,0.184929,-0.058850,0.004570,0.048659,0.096716,0.030813,-0.073876,0.015296,0.127499
2,1294614020008312832,Fogo 🔥 no aeroporto de Dubai 😱😱 https://t.co/2...,2020-08-15 12:36:49,pt,Twitter for Android,24,11,7,"[-0.049368992, 0.20724605, 0.09472715, 0.51769...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",...,-0.081196,0.013280,0.029083,0.077871,0.031103,0.023791,0.020638,0.101853,-0.089112,-0.006730
3,1294701863489744896,Fogo no aeroporto de Dubai. https://t.co/yhQDe...,2020-08-15 18:25:53,pt,Twitter for Android,15,7,4,"[-0.054105558, 0.22766814, 0.09675161, 0.51384...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",...,-0.020154,-0.046229,-0.083547,0.131931,0.195300,-0.098313,-0.019943,-0.003674,0.061658,-0.061630
4,1295124644085805057,Incendio en el aeropuerto de Dubai https://t.c...,2020-08-16 22:25:52,es,Twitter for Android,33,5,3,"[-0.043678686, 0.23453882, 0.11631639, 0.50836...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...",...,0.037292,0.022898,-0.079963,0.111332,-0.128751,0.083455,-0.012020,-0.061838,-0.022083,-0.029660
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4056,783488210454376448,"Yes, @ClintonFdn has accepted millions from f...",2016-10-05 02:05:20,en,Twitter Web Client,270,10,3,"[-0.055992845, 0.23851915, 0.12183203, 0.51381...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...",...,-0.045623,-0.122109,-0.088662,-0.004221,0.017393,-0.072896,-0.119485,0.111239,-0.140027,-0.025499
4057,783486380777373696,"#VPDebate fact check: Yes, the Clinton Foundat...",2016-10-05 01:58:04,en,SocialFlow,227,23,23,"[-0.066264085, 0.2269559, 0.11466829, 0.508162...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...",...,0.143425,0.006828,0.082407,-0.148371,-0.060529,0.064475,0.102566,0.012263,-0.128234,-0.075311
4058,783493825931206656,"Yes, the Clinton Foundation has accepted milli...",2016-10-05 02:27:39,en,Twitter Web Client,6,0,1,"[-0.054452505, 0.23137671, 0.1135482, 0.503517...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...",...,0.033770,-0.064164,0.016560,0.049262,-0.016212,-0.099194,0.010516,-0.031040,0.071047,0.125053
4059,1337737596881911809,📌 สถานการณ์โรคติดเชื้อไวรัสโคโรนา 2019 (COVID-...,2020-12-12 12:34:31,th,Twitter for Android,39,0,5,"[-0.06284507, 0.2391499, 0.11289272, 0.5124814...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",...,0.041474,0.088375,0.057698,0.025241,0.161460,-0.095974,0.194330,-0.131195,0.001408,0.137069


In [9]:
MODE = 'multimodal'

new_features_df = pd.DataFrame(index=range(user_posted_tweet_subgraph.nodes['tweet'].data['feat'].shape[0]))
new_features_df = new_features_df.join(tweet_df[new_embedding_columns])#.fillna(0)
new_features_tensor = torch.tensor(new_features_df.values).double()

mixed_features = new_features_tensor
if MODE == 'multimodal':
    mixed_features = torch.cat([user_posted_tweet_subgraph.nodes['tweet'].data['feat'], new_features_tensor], axis = 1).double()

edges_index = torch.cat([
    user_posted_tweet_subgraph.edges(etype='posted')[0].unsqueeze(0),
    user_posted_tweet_subgraph.edges(etype='posted')[1].unsqueeze(0)
], dim=0)
new_features_data = tgData(
    x=mixed_features,
    y=user_posted_tweet_subgraph.nodes['tweet'].data['label'],
    train_mask=user_posted_tweet_subgraph.nodes['tweet'].data['train_mask'] & torch.tensor(lang_tweets),
    val_mask=user_posted_tweet_subgraph.nodes['tweet'].data['val_mask'] & torch.tensor(lang_tweets),
    test_mask=user_posted_tweet_subgraph.nodes['tweet'].data['test_mask'] & torch.tensor(lang_tweets),
    edge_index=edges_index.long())

# Trustworthy?

In [10]:
# new_features_data_x_mim = torch.ones((new_features_data.x.shape[0], new_features_data.x.shape[1], 2))
# new_features_data_x_mim[:,:,0] = new_features_data.x
# new_features_data_x_mim.shape

In [11]:
hparams = {
    'input_dim': new_features_data.num_node_features,
    'hidden_dim': 16,
    'output_dim': max(new_features_data.y).item() + 1
}

model = GAT(**hparams).double()

MODEL_NAME = f'multilingual_multimodal'

model.load_state_dict(torch.load(f'./data/models/{MODEL_NAME}.pth'))
model.eval()

print('test original_features')

original_features_classification_output = model(new_features_data.x, new_features_data.edge_index)
probas = original_features_classification_output.exp()
original_pred_label = torch.argmax(probas, dim=1)

test original_features


## Filter features

In [12]:
trustworthy_features_list = random.sample(range(new_features_data.x.shape[1]), k=int(new_features_data.x.shape[1] * 0.7))
untrustworthy_features_list = [i for i in range(new_features_data.x.shape[1]) if i not in trustworthy_features_list]

# trustworthy_features_mask = torch.zeros(new_features_data.x.shape[1], dtype=torch.bool)
# untrustworthy_features_mask = torch.zeros(new_features_data.x.shape[1], dtype=torch.bool)
# trustworthy_features_mask[trustworthy_features_list] = True
# untrustworthy_features_mask[untrustworthy_features_list] = True

# trustworthy_features = new_features_data.x[:, trustworthy_features_list]
# untrustworthy_features = new_features_data.x[:, untrustworthy_features_list]

untrustworthy_features = copy.deepcopy(new_features_data.x)
untrustworthy_features[:,untrustworthy_features_list] = 0

untrustworthy_features_data = tgData(
    x=untrustworthy_features,
    y=user_posted_tweet_subgraph.nodes['tweet'].data['label'],
    train_mask=user_posted_tweet_subgraph.nodes['tweet'].data['train_mask'] & torch.tensor(lang_tweets),
    val_mask=user_posted_tweet_subgraph.nodes['tweet'].data['val_mask'] & torch.tensor(lang_tweets),
    test_mask=user_posted_tweet_subgraph.nodes['tweet'].data['test_mask'] & torch.tensor(lang_tweets),
    edge_index=edges_index.long())

### test mask

In [13]:
original_preds = torch.argmax(model(new_features_data.x, new_features_data.edge_index)[new_features_data.test_mask].exp(), dim=1)
untrustworthy_preds = torch.argmax(model(untrustworthy_features_data.x, untrustworthy_features_data.edge_index)[untrustworthy_features_data.test_mask].exp(), dim=1)

### no mask

In [14]:
original_preds = torch.argmax(model(new_features_data.x, new_features_data.edge_index).exp(), dim=1)
untrustworthy_preds = torch.argmax(model(untrustworthy_features_data.x, untrustworthy_features_data.edge_index).exp(), dim=1)

In [None]:
flipped_preds_size = []
mistrust_idx = np.argwhere(original_preds != untrustworthy_preds).flatten()
print(f'Number of suspect predictions {len(mistrust_idx)}')
shouldnt_trust = set(mistrust_idx)
flipped_preds_size.append(len(shouldnt_trust))

In [None]:
import collections
mistrust = collections.defaultdict(lambda:set())
trust = collections.defaultdict(lambda: set())
trust_fn = lambda prev, curr: (prev > 0.5 and curr > 0.5) or (prev <= 0.5 and curr <= 0.5)
trust_fn_all = lambda exp, unt: len([x[0] for x in exp if x[0] in unt]) == 0

graphlime = GraphLIME(model, hop=2, rho=0.1, cached=True)

In [None]:
topk = 100
num_noise_feats = []

for i in range(new_features_data.x.shape[0]):
    
    coefs_originals = graphlime.explain_node(i, untrustworthy_features_data.x, untrustworthy_features_data.edge_index)
    coefs_originals = np.abs(coefs_originals)

    coefs_untrust = graphlime.explain_node(i, untrustworthy_features_data.x, untrustworthy_features_data.edge_index)
    coefs_untrust = np.abs(coefs_untrust)

    ### debug
    if np.where(coefs_originals != 0)[0] != []:
        print(feat_indices)
        print(num_noise_feat)
        print([untrustworthy_features_data.x.shape[0] for idx in feat_indices])

In [None]:
list(filter(lambda x: x != 0, num_noise_feats))

In [None]:
mistrust_idx.shape, untrustworthy_features_data.x.shape, new_features_data.x.shape

In [None]:
mistrust_idx

In [None]:
shouldnt_trust = set(mistrust_idx)

In [None]:
shouldnt_trust

# ignore below

In [None]:
model_filtered.load_state_dict(torch.load(f'./data/models/{MODEL_NAME}.pth'))
model_filtered.eval()

accuracy = tm.Accuracy(task='multiclass', num_classes=2, average='none')
stats_score = tm.StatScores(task='multiclass', num_classes=2, average='none')
precision = tm.Precision(task='multiclass', num_classes=2, average='none')
recall = tm.Recall(task='multiclass', num_classes=2, average='none')
f1_score = tm.classification.f_beta.F1Score(task='multiclass', num_classes=2, average='none')

print('test')

output = model_filtered(trustworthy_features_data.x, trustworthy_features_data.edge_index)

f1 = f1_score(output[trustworthy_features_data.test_mask], trustworthy_features_data.y[trustworthy_features_data.test_mask])
f1macro = torch.mean(f1)

metrics = {
    'Accuracy': accuracy(output[trustworthy_features_data.test_mask], trustworthy_features_data.y[trustworthy_features_data.test_mask]),
    'Precision': precision(output[new_features_data.test_mask], trustworthy_features_data.y[trustworthy_features_data.test_mask]),
    'Recall': recall(output[trustworthy_features_data.test_mask], trustworthy_features_data.y[trustworthy_features_data.test_mask]),
    'Stats_Score': stats_score(output[trustworthy_features_data.test_mask], trustworthy_features_data.y[trustworthy_features_data.test_mask]),
    'F1': f1,
    'F1-macro': f1macro,
    'Bootstrap': bootstrap.compute()
}

print(metrics)

In [None]:
MODEL_NAME = f'{LANG}_{MODE}_filtered'

model_filtered.load_state_dict(torch.load(f'./data/models/{MODEL_NAME}.pth'))
model_filtered.eval()

print('test trustworthy_features')

trustworthy_features_classification_output = model_filtered(trustworthy_features_data.x, trustworthy_features_data.edge_index)
trustworthy_pred_label = torch.argmax(trustworthy_features_classification_output.exp(), dim=1)

In [None]:
np.where((original_pred_label == trustworthy_pred_label).numpy() == False)[0].shape, np.where((original_pred_label == trustworthy_pred_label).numpy() == True)[0].shape

In [None]:
explainer_original = GraphLIME(model, hop=2, rho=0.1, cached=True)
explainer_filtered = GraphLIME(model_filtered, hop=2, rho=0.1, cached=True)

count_untrustworthy = 0
for i in list(np.stack(np.argwhere((original_pred_label == trustworthy_pred_label).numpy() == False),axis=1)[0]):
    
    coefs_original = explainer_original.explain_node(int(i), new_features_data.x, new_features_data.edge_index)
    coefs_filtered = explainer_filtered.explain_node(int(i), trustworthy_features_data.x, trustworthy_features_data.edge_index)
    
    count_untrustworthy += 1 if np.where(coefs_original != 0)[0].shape == np.where(coefs_filtered != 0)[0].shape else 0

count_trustworthy = int(np.stack(np.argwhere((original_pred_label == trustworthy_pred_label).numpy() == False),axis=1)[0].shape[0]) - count_untrustworthy
print(count_untrustworthy, count_trustworthy)

100 - count_untrustworthy/len(original_pred_label)*100, 100 - count_trustworthy/len(original_pred_label)*100

In [None]:
explainer_original = GraphLIME(model, hop=2, rho=0.1, cached=True)

with open('count_untrustworthy.txt', 'w') as f:
    f.write('idx,count_untrustworthy,count_trustworthy\n')

MODEL_NAME = f'{LANG}_{MODE}_filtered'
lr = 0.005
epochs = 400
MAX_SAMPLES = 100

for sample_iter in tqdm(range(MAX_SAMPLES)):
    trustworthy_features_list = random.sample(range(new_features_data.x.shape[1]), k=int(new_features_data.x.shape[1] * 0.7))
    untrustworthy_features_list = [i for i in range(new_features_data.x.shape[1]) if i not in trustworthy_features_list]

    trustworthy_features = new_features_data.x[:, trustworthy_features_list]
    untrustworthy_features = new_features_data.x[:, untrustworthy_features_list]

    trustworthy_features_data = tgData(
        x=trustworthy_features,
        y=user_posted_tweet_subgraph.nodes['tweet'].data['label'],
        train_mask=user_posted_tweet_subgraph.nodes['tweet'].data['train_mask'] & torch.tensor(lang_tweets),
        val_mask=user_posted_tweet_subgraph.nodes['tweet'].data['val_mask'] & torch.tensor(lang_tweets),
        test_mask=user_posted_tweet_subgraph.nodes['tweet'].data['test_mask'] & torch.tensor(lang_tweets),
        edge_index=edges_index.long())

    hparams = {
        'input_dim': trustworthy_features_data.num_node_features,
        'hidden_dim': 16,
        'output_dim': max(trustworthy_features_data.y).item() + 1
    }

    model_filtered = GAT(**hparams).double()

    model_filtered.train()
    optimizer = optim.Adam(model_filtered.parameters(), lr=lr)

    f1_score = tm.classification.f_beta.F1Score(task='multiclass', num_classes=2, average='none')
    best_f1macro = -1

    for epoch in range(epochs):
        optimizer.zero_grad()
        
        output = model_filtered(trustworthy_features_data.x, trustworthy_features_data.edge_index)
        loss = F.nll_loss(output[trustworthy_features_data.train_mask], trustworthy_features_data.y[trustworthy_features_data.train_mask])
        
        loss.backward()
        optimizer.step()
        
        f1 = f1_score(output[trustworthy_features_data.train_mask], trustworthy_features_data.y[trustworthy_features_data.train_mask])
        f1macro = torch.mean(f1)

        if f1macro > best_f1macro:
            best_f1macro = f1macro
            torch.save(model_filtered.state_dict(), f'./data/models/{MODEL_NAME}.pth')

    model_filtered.eval()

    trustworthy_features_classification_output = model_filtered(trustworthy_features_data.x, trustworthy_features_data.edge_index)
    trustworthy_pred_label = torch.argmax(trustworthy_features_classification_output.exp(), dim=1)

    explainer_filtered = GraphLIME(model_filtered, hop=2, rho=0.1, cached=True)

    count_untrustworthy = 0
    for i in list(np.stack(np.argwhere((original_pred_label == trustworthy_pred_label).numpy() == False),axis=1)[0]):
        
        coefs_original = explainer_original.explain_node(int(i), new_features_data.x, new_features_data.edge_index)
        coefs_filtered = explainer_filtered.explain_node(int(i), trustworthy_features_data.x, trustworthy_features_data.edge_index)
        
        count_untrustworthy += 1 if np.where(coefs_original != 0)[0].shape == np.where(coefs_filtered != 0)[0].shape else 0

    count_trustworthy = int(np.stack(np.argwhere((original_pred_label == trustworthy_pred_label).numpy() == False),axis=1)[0].shape[0]) - count_untrustworthy

    result_count_untrustworthy = 100 - count_untrustworthy/len(original_pred_label)
    result_count_trustworthy = 100 - count_trustworthy/len(original_pred_label)
    with open('count_untrustworthy.txt', 'a') as f:
        f.write(f'{sample_iter},{result_count_untrustworthy},{result_count_trustworthy}\n')

f.close()

In [None]:
coefs_original = explainer_original.explain_node(25, new_features_data.x, new_features_data.edge_index)
coefs_filtered = explainer_filtered.explain_node(25, trustworthy_features_data.x, trustworthy_features_data.edge_index)

In [None]:
np.where(coefs_original != 0)[0].shape == np.where(coefs_filtered != 0)[0].shape

In [None]:
explainer = GraphLIME(model, hop=2, rho=0.1, cached=True)

for i in new_features_data.test_mask.nonzero().view(-1).tolist():
    coefs = explainer.explain_node(i, new_features_data.x, new_features_data.edge_index)
    if np.where(coefs != 0)[0].tolist() != []:
        print(i)
        print(coefs)
        break

In [None]:
explainer = GraphLIME(model, hop=2, rho=0.1, cached=False)

probas = model(new_features_data.x, new_features_data.edge_index).exp()
coefs = explainer.explain_node(2, new_features_data.x, new_features_data.edge_index)
print(coefs, len(coefs))
print(np.where(coefs != 0))
np.where(coefs != 0)


In [None]:
# list(tweet_df[tweet_df['lang'] == 'pt'].index)

In [None]:
explainer = GraphLIME(model, hop=2, rho=0.1, cached=True)

# for node_idx in range(data.x.shape[0]):
#     coefs = explainer.explain_node(node_idx, data.x, data.edge_index)
#     if len(set(np.where(coefs != 0)[0]).intersection(set([3,4,5]))) != 0:
#         print(node_idx)

# try: 102 | 118 | 127
node_idx = 91

probas = model(new_features_data.x, new_features_data.edge_index).exp()
print(probas[node_idx], torch.argmax(probas[node_idx]).item())
coefs = explainer.explain_node(node_idx, new_features_data.x, new_features_data.edge_index)

print(coefs, len(coefs))
print(np.where(coefs != 0))

In [None]:
k = 91
print(tweet_df.loc[k].text)
tweet_df.loc[k]

In [None]:
tweet_df = dataset_mumin.nodes['tweet']
claim_df = dataset_mumin.nodes['claim']
x = dataset_mumin.rels[('tweet', 'discusses', 'claim')]
y = (tweet_df.merge(x, left_index=True, right_on='src')
                          .merge(claim_df, left_on='tgt', right_index=True)
                          .reset_index(drop=True))

y[y['tweet_id'] == 1334273990039375876]

In [None]:
tweet_df = dataset_mumin.nodes['tweet']
reply_df = dataset_mumin.nodes['reply']
reply_quoteof_tweet_df = dataset_mumin.rels[('reply', 'reply_to', 'tweet')]
reply_quoteof_tweet_df = (reply_df.merge(quote_of_df, left_index=True, right_on='src')
                        .merge(tweet_df, left_on='tgt', right_index=True)
                        .reset_index(drop=True))

reply_quoteof_tweet_df[reply_quoteof_tweet_df['tweet_id_y'] == 1334273990039375876]

In [None]:
from transformers_interpret import SequenceClassificationExplainer, MultiLabelClassificationExplainer
# cls_explainer = SequenceClassificationExplainer(bertweet,tokenizer)
cls_explainer = SequenceClassificationExplainer(bertweet,tokenizer)

In [None]:
word_attributions = cls_explainer(normalizeTweet(tweet_df.loc[127].text))

In [None]:
cls_explainer.predicted_class_name

In [None]:
cls_explainer.visualize()

In [None]:
word_attributions = cls_explainer(normalizeTweet('Head of Pfizer Research: Covid Vaccine is Female Sterilization – Health and Money News https://t.co/IDRLSVmkLz'))

In [None]:
cls_explainer.visualize()

In [None]:
bertweet = AutoModelForSequenceClassification.from_pretrained("vinai/bertweet-base", num_labels=2)
F.softmax(bertweet(torch.tensor([tokenizer.encode(normalizeTweet('my favourite text'))])).logits)

In [None]:
F.softmax(bertweet(torch.tensor([tokenizer.encode(normalizeTweet('my favourite text'))])).logits, dim=1)

In [None]:
cls_explainer = SequenceClassificationExplainer(bertweet,tokenizer)

In [None]:
word_attributions = cls_explainer(normalizeTweet(tweet_df.loc[102].text))
cls_explainer.visualize()

In [None]:
dataset_mumin.nodes['claim'].loc[0]['embedding'].shape

In [None]:
dataset_mumin.nodes['tweet']

In [None]:
bertweet.config.id2label

In [None]:
tweet_discusses_claim_df

In [None]:
dataset_mumin.nodes['claim']