In [None]:
import ast
import os

import dgl
import torch
import torch_geometric as tg

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import dgl.nn.pytorch as dglnn

from graphlime import GraphLIME
from mumin import MuminDataset

import torchmetrics as tm

from tqdm.notebook import tqdm
from collections import defaultdict

from torch_geometric.data import Data as tgData

from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification

from mumin_explainable.architectures.graphs import GAT
from mumin_explainable.processor.tweetnormalizer import normalizeTweet

from dotenv import load_dotenv

%matplotlib inline

# Set up the MuMiN graph

In [None]:
# Read secrets such as TWITTER_BEARER_TOKEN from a local .env file into
# os.environ. `load_dotenv` is imported at the top of the notebook; without this
# call the token is only found if it was already exported in the shell.
load_dotenv()

size = 'small'  # MuMiN dataset size; also selects the local zip file below

dataset_mumin = MuminDataset(
    twitter_bearer_token=os.getenv('TWITTER_BEARER_TOKEN'),
    size=size,
    dataset_path=f'./data/datasets/mumin-{size}.zip'
)
dataset_mumin.compile()
mumin_graph = dataset_mumin.to_dgl()
mumin_graph

In [None]:
# Draw the schema of the full heterograph (DGL `metagraph()`), one labelled
# node per node type.
fig, ax = plt.subplots(figsize=(10, 7))
metagraph = mumin_graph.metagraph()
nx.draw_networkx(
    metagraph,
    pos=nx.shell_layout(metagraph),
    ax=ax,
    node_color='white',
    node_size=3000,
    arrows=False,
)

# Set up the user-posted-tweet subgraph

In [None]:
# Restrict the heterograph to the user -posted-> tweet relation and pull the
# tweet-level split masks stored on that subgraph.
rel = ('user', 'posted', 'tweet')
posted_subgraph = dgl.edge_type_subgraph(mumin_graph, etypes=[rel])
tweet_node_data = posted_subgraph.nodes['tweet'].data
train_mask, val_mask, test_mask = (
    tweet_node_data[split] for split in ('train_mask', 'val_mask', 'test_mask')
)

# Schema of the restricted graph.
fig, ax = plt.subplots(figsize=(10, 7))
posted_metagraph = posted_subgraph.metagraph()
nx.draw_networkx(
    posted_metagraph,
    pos=nx.shell_layout(posted_metagraph),
    ax=ax,
    node_color='white',
    node_size=3000,
    arrows=False,
)

# Set up the PyG dataset

In [None]:
# PyG wants a (2, num_edges) index tensor: row 0 = sources, row 1 = destinations.
posted_src, posted_dst = posted_subgraph.edges(etype='posted')
edges_index = torch.stack([posted_src, posted_dst], dim=0)

# NOTE(review): posted_src are *user* node ids and posted_dst are *tweet* node
# ids, but `x` below holds tweet features only, so PyG will read both rows as
# tweet indices. Confirm this homogeneous interpretation is intended.
tweet_node_data = posted_subgraph.nodes['tweet'].data
data = tgData(
    x=tweet_node_data['feat'],
    y=tweet_node_data['label'],
    train_mask=train_mask,
    val_mask=val_mask,
    test_mask=test_mask,
    edge_index=edges_index.long(),
)

# One row per (reply, quoted tweet) pair, carrying both nodes' attributes.
tweet_df = dataset_mumin.nodes['tweet']
reply_df = dataset_mumin.nodes['reply']
quote_of_df = dataset_mumin.rels[('reply', 'quote_of', 'tweet')]
reply_quote_tweet_df = (
    reply_df
    .merge(quote_of_df, left_index=True, right_on='src')
    .merge(tweet_df, left_on='tgt', right_index=True)
    .reset_index(drop=True)
)

# Train GAT

In [None]:
# Seed torch so weight initialisation (and any stochastic layers inside GAT)
# is reproducible across Restart & Run All.
torch.manual_seed(42)

hparams = {
    'input_dim': data.num_node_features,
    'hidden_dim': 16,
    # one output per class; assumes labels are 0..K-1
    'output_dim': int(data.y.max()) + 1
}

model = GAT(**hparams).double()

lr = 0.005
epochs = 400

model.train()
optimizer = optim.Adam(model.parameters(), lr=lr)

accuracy = tm.Accuracy()
stats_score = tm.StatScores()
precision_recall = tm.functional.precision_recall

train_y = data.y[data.train_mask]
for epoch in tqdm(range(epochs)):
    optimizer.zero_grad()

    # presumably log-probabilities: paired with nll_loss here and .exp() later
    output = model(data.x, data.edge_index)
    train_out = output[data.train_mask]
    loss = F.nll_loss(train_out, train_y)

    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        acc = accuracy(train_out, train_y)
        pr = precision_recall(train_out, train_y)
        ss = stats_score(train_out, train_y)
        print('Epoch: {:3d}, acc = {:.3f}, pr = {}, ss = {}'.format(epoch, acc, pr, ss))

model.eval()

In [None]:
# Test-split evaluation. Predictions are recomputed with the final weights in
# eval mode: the `output` left over from the training loop was produced in
# train mode (dropout active, if GAT has any) and before the last optimizer step.
model.eval()
with torch.no_grad():
    test_output = model(data.x, data.edge_index)
test_pred = test_output[data.test_mask]
test_y = data.y[data.test_mask]
n_classes = test_output.shape[1]  # one output column per class

accuracy = tm.Accuracy()
stats_score = tm.StatScores()
precision_recall = tm.functional.precision_recall

print('test')
acc = accuracy(test_pred, test_y)
pr = precision_recall(test_pred, test_y, average='micro', num_classes=n_classes)
ss = stats_score(test_pred, test_y)
print('Epoch: {:3d}, acc = {:.3f}, pr = {}, ss = {}'.format(epoch, acc, pr, ss))

In [None]:
def get_all_nodes_explanation(model, explainer, data, dataset_mumin):
    """Print the users who quoted tweets whose explanation relies on quotes.

    For every tweet node with at least one quote tweet whose explanation gives a
    non-zero coefficient to the ``num_quote_tweets`` feature, walk the MuMiN
    relations by hand:

        tweet <-quote_of- reply <-posted- user -posted-> tweet

    and print the ids of the quoting users if any of them posted tweets.

    Args:
        model: the trained GNN. Not called here (the explainer wraps it); kept so
            existing call sites keep working.
        explainer: object with ``explain_node(node_idx, x, edge_index)`` returning
            one coefficient per node feature (e.g. ``GraphLIME``).
        data: graph exposing ``x`` (node features) and ``edge_index``. Row ``i`` of
            ``data.x`` is assumed to be row ``i`` of ``dataset_mumin.nodes['tweet']``.
        dataset_mumin: compiled ``MuminDataset`` providing ``nodes`` and ``rels``.

    Returns:
        None; results are printed.
    """
    # NOTE(review): assumes these are the first three columns of data.x, in this
    # order — confirm against the tweet feature layout.
    feature_id_map = {
        'num_retweets': 0,
        'num_replies': 1,
        'num_quote_tweets': 2
    }
    quote_feature_idx = feature_id_map['num_quote_tweets']

    user_df = dataset_mumin.nodes['user']
    tweet_df = dataset_mumin.nodes['tweet']
    reply_df = dataset_mumin.nodes['reply']

    # reply -quote_of-> tweet with both sides' attributes; since both frames have
    # `tweet_id`, tweet_id_x is the reply's id and tweet_id_y the quoted tweet's.
    quote_of_rel = dataset_mumin.rels[('reply', 'quote_of', 'tweet')]
    reply_quoteof_tweet_df = (reply_df.merge(quote_of_rel, left_index=True, right_on='src')
                              .merge(tweet_df, left_on='tgt', right_index=True)
                              .reset_index(drop=True))

    # user -posted-> tweet with user and tweet attributes.
    posted_tweet_rel = dataset_mumin.rels[('user', 'posted', 'tweet')]
    user_posted_tweet_df = (user_df.merge(posted_tweet_rel, left_index=True, right_on='src')
                            .merge(tweet_df, left_on='tgt', right_index=True)
                            .reset_index(drop=True))

    # user -posted-> reply. `tgt` indexes *reply* nodes, so it must not be joined
    # against tweet_df; only `user_id` and `tgt` are needed below.
    posted_reply_rel = dataset_mumin.rels[('user', 'posted', 'reply')]
    user_posted_reply_df = (user_df.merge(posted_reply_rel, left_index=True, right_on='src')
                            .reset_index(drop=True))

    for node_idx in range(data.x.shape[0]):
        tweet = tweet_df.iloc[node_idx]
        # Cheap filter first: only run the (expensive) explainer on quoted tweets.
        if tweet['num_quote_tweets'] == 0:
            continue

        coefs = explainer.explain_node(node_idx, data.x, data.edge_index)
        # Keep only inferences explained by quotes (non-zero quote coefficient).
        if quote_feature_idx not in np.flatnonzero(np.asarray(coefs)):
            continue

        tgt_tweet_id = tweet['tweet_id']

        # manual traverse: quoted tweet -> quoting replies -> their authors -> authors' tweets
        replies_src = reply_quoteof_tweet_df.loc[
            reply_quoteof_tweet_df['tweet_id_y'] == tgt_tweet_id, 'src']
        quoters_ids = user_posted_reply_df.loc[
            user_posted_reply_df['tgt'].isin(replies_src), 'user_id'].tolist()
        quoters_posts = user_posted_tweet_df[user_posted_tweet_df['user_id'].isin(quoters_ids)]

        if not quoters_posts.empty:
            print(quoters_ids)

In [None]:
# Explain every tweet node with GraphLIME (2-hop neighbourhood, rho=0.1) and
# print the users behind quote-driven predictions.
explainer = GraphLIME(model, hop=2, rho=0.1, cached=True)
user_df = dataset_mumin.nodes['user']
get_all_nodes_explanation(model, explainer, data, dataset_mumin)

# Enhance with text-based features

In [None]:
# Size of the projected text embedding used as tweet node features.
TEXT_DIM = 3

# NOTE(review): `vinai/bertweet-base` is loaded with a 2-label
# sequence-classification head. Unless that head exists in the checkpoint
# (transformers warns about newly initialised weights otherwise), its
# predictions — and any attributions of them below — come from an untrained
# classifier until it is fine-tuned.
bertweet = AutoModelForSequenceClassification.from_pretrained("vinai/bertweet-base", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", use_fast=False)

In [None]:
# One projection per (encoder width, text_dim), created lazily and then reused,
# so every tweet is mapped into the same low-dimensional space.
_TWEET_PROJECTIONS = {}


def tweetencoder(x, text_dim):
    """Encode one tweet's text as a list of ``text_dim`` floats.

    The text is normalised with ``normalizeTweet``, tokenised, run through the
    BERTweet encoder (``bertweet.base_model``), and the hidden state of the first
    token is projected to ``text_dim`` dimensions by a projection shared across
    all calls.

    The classification model's own output (logits) has no ``pooler_output``, so
    the encoder's ``last_hidden_state`` is used instead.

    Args:
        x: raw tweet text.
        text_dim: length of the returned encoding.

    Returns:
        list[float] of length ``text_dim``.
    """
    def _first_token_state(text):
        # truncation=True caps the input at the tokenizer's model_max_length
        token_ids = tokenizer.encode(text, truncation=True)
        hidden = bertweet.base_model(torch.tensor([token_ids])).last_hidden_state
        return hidden[:, 0, :]

    with torch.no_grad():  # inference only: no autograd graph needed
        try:
            features = _first_token_state(normalizeTweet(x))
        except (RuntimeError, IndexError):
            # Inputs the encoder cannot handle fall back to the empty tweet.
            features = _first_token_state('')

        key = (features.shape[-1], text_dim)
        if key not in _TWEET_PROJECTIONS:
            _TWEET_PROJECTIONS[key] = nn.Linear(*key)
        return _TWEET_PROJECTIONS[key](features).tolist()[0]

In [None]:
# Encode every tweet's text. Encodings are stored as their string repr (parsed
# back in the next cell); rows that fail keep the all-zero default, but the
# failures are counted and reported instead of being silently swallowed.
tweet_df['text_encoding'] = str([0] * TEXT_DIM)
failed_rows = []
for row_idx, text in tqdm(tweet_df['text'].items(), total=len(tweet_df)):
    try:
        tweet_df.at[row_idx, 'text_encoding'] = str(tweetencoder(text, TEXT_DIM))
    except Exception as exc:  # best-effort: keep the default, remember why
        failed_rows.append((row_idx, repr(exc)))

if failed_rows:
    print(f'{len(failed_rows)}/{len(tweet_df)} tweets could not be encoded; '
          f'first failure: {failed_rows[0]}')

In [None]:
new_embedding_columns = [f'emb{i}' for i in range(TEXT_DIM)]

# Parse the stored string reprs back into lists. ast.literal_eval only accepts
# Python literals, whereas eval would execute arbitrary expressions.
tweet_embeddings_split_df = pd.DataFrame(
    [ast.literal_eval(x) for x in tweet_df['text_encoding'].tolist()],
    index=tweet_df.index,
    columns=new_embedding_columns
)
# Drop emb* columns left by a previous run so re-executing this cell does not
# append duplicate columns.
tweet_df = pd.concat(
    [tweet_df.drop(columns=new_embedding_columns, errors='ignore'), tweet_embeddings_split_df],
    axis=1,
)
display(tweet_df.head())

In [None]:
# One row per tweet node in the subgraph; tweets without an embedding get zeros.
num_tweet_nodes = posted_subgraph.nodes['tweet'].data['feat'].shape[0]
new_features_df = (
    pd.DataFrame(index=range(num_tweet_nodes))
    .join(tweet_df[new_embedding_columns])
    .fillna(0)
)
new_features_tensor = torch.tensor(new_features_df.values).double()
# Only the text embeddings are used as node features. Alternative kept for
# experiments — concatenate with the original tweet features:
#   torch.cat([posted_subgraph.nodes['tweet'].data['feat'], new_features_tensor], axis=1).double()
mixed_features = new_features_tensor

# Same (2, num_edges) edge index as the first dataset.
posted_src, posted_dst = posted_subgraph.edges(etype='posted')
edges_index = torch.stack([posted_src, posted_dst], dim=0)
tweet_node_data = posted_subgraph.nodes['tweet'].data
new_features_data = tgData(
    x=mixed_features,
    y=tweet_node_data['label'],
    train_mask=tweet_node_data['train_mask'],
    val_mask=tweet_node_data['val_mask'],
    test_mask=tweet_node_data['test_mask'],
    edge_index=edges_index.long(),
)

In [None]:
# Seed torch so weight initialisation (and any stochastic layers inside GAT)
# is reproducible across Restart & Run All.
torch.manual_seed(42)

hparams = {
    'input_dim': new_features_data.num_node_features,
    'hidden_dim': 16,
    # one output per class; assumes labels are 0..K-1
    'output_dim': int(new_features_data.y.max()) + 1
}

model = GAT(**hparams).double()

lr = 0.005
epochs = 400

model.train()
optimizer = optim.Adam(model.parameters(), lr=lr)

accuracy = tm.Accuracy()
precision_recall = tm.functional.precision_recall

train_y = new_features_data.y[new_features_data.train_mask]
for epoch in tqdm(range(epochs)):
    optimizer.zero_grad()

    # presumably log-probabilities: paired with nll_loss here and .exp() later
    output = model(new_features_data.x, new_features_data.edge_index)
    train_out = output[new_features_data.train_mask]
    loss = F.nll_loss(train_out, train_y)

    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        acc = accuracy(train_out, train_y)
        pr = precision_recall(train_out, train_y)
        print('Epoch: {:3d}, acc = {:.3f}, pr = {}'.format(epoch, acc, pr))

model.eval()

In [None]:
# Test-split evaluation. Predictions are recomputed with the final weights in
# eval mode: the `output` left over from the training loop was produced in
# train mode (dropout active, if GAT has any) and before the last optimizer step.
model.eval()
with torch.no_grad():
    test_output = model(new_features_data.x, new_features_data.edge_index)
test_pred = test_output[new_features_data.test_mask]
test_y = new_features_data.y[new_features_data.test_mask]
n_classes = test_output.shape[1]  # one output column per class

accuracy = tm.Accuracy()
stats_score = tm.StatScores()
precision_recall = tm.functional.precision_recall
f1_score = tm.classification.f_beta.F1Score(num_classes=n_classes, average='none')

print('test')
acc = accuracy(test_pred, test_y)
pr = precision_recall(test_pred, test_y, average='micro', num_classes=n_classes)
ss = stats_score(test_pred, test_y)
f1 = f1_score(test_pred, test_y)
print('Epoch: {:3d}, acc = {:.3f}, pr = {}, f1 = {}, ss = {}'.format(epoch, acc, pr, f1, ss))

In [None]:
explainer = GraphLIME(model, hop=2, rho=0.1, cached=True)

# Tweet node to inspect; 102, 118 and 127 are other nodes worth trying. To find
# candidates, scan explainer.explain_node over all nodes for non-zero
# coefficients on the features of interest.
node_idx = 420

# Inference only: no autograd graph is needed for the class probabilities.
# .exp() turns the model's (presumed) log-probabilities into probabilities.
model.eval()
with torch.no_grad():
    probas = model(new_features_data.x, new_features_data.edge_index).exp()
print(probas[node_idx], torch.argmax(probas[node_idx]).item())

coefs = explainer.explain_node(node_idx, new_features_data.x, new_features_data.edge_index)
print(coefs, len(coefs))
print(np.where(coefs != 0))

In [None]:
# Inspect one tweet: print its text, then display the whole row.
k = 205
tweet_row = tweet_df.loc[k]
print(tweet_row['text'])
tweet_row

In [None]:
# Claims discussed by one specific tweet. Local names are used so the enriched
# `tweet_df` (with emb* columns) built above is not overwritten by the raw frame.
INSPECT_TWEET_ID = 1334273990039375876
raw_tweet_df = dataset_mumin.nodes['tweet']
claim_df = dataset_mumin.nodes['claim']
discusses_df = dataset_mumin.rels[('tweet', 'discusses', 'claim')]
tweet_claim_df = (raw_tweet_df.merge(discusses_df, left_index=True, right_on='src')
                  .merge(claim_df, left_on='tgt', right_index=True)
                  .reset_index(drop=True))

tweet_claim_df[tweet_claim_df['tweet_id'] == INSPECT_TWEET_ID]

In [None]:
# Replies to the same tweet, via the `reply_to` relation fetched here (not the
# module-level `quote_of_df`). Local names keep the enriched `tweet_df` intact.
INSPECT_TWEET_ID = 1334273990039375876
raw_tweet_df = dataset_mumin.nodes['tweet']
reply_df = dataset_mumin.nodes['reply']
reply_to_rel_df = dataset_mumin.rels[('reply', 'reply_to', 'tweet')]
# Both frames have `tweet_id`: tweet_id_x is the reply's, tweet_id_y the target tweet's.
reply_to_tweet_df = (reply_df.merge(reply_to_rel_df, left_index=True, right_on='src')
                     .merge(raw_tweet_df, left_on='tgt', right_index=True)
                     .reset_index(drop=True))

reply_to_tweet_df[reply_to_tweet_df['tweet_id_y'] == INSPECT_TWEET_ID]

In [None]:
from transformers_interpret import SequenceClassificationExplainer, MultiLabelClassificationExplainer

# Word-level attribution explainer wrapping the current `bertweet` classifier.
cls_explainer = SequenceClassificationExplainer(bertweet, tokenizer)

In [None]:
tweet_127_text = normalizeTweet(tweet_df.at[127, 'text'])
word_attributions = cls_explainer(tweet_127_text)

In [None]:
# Label the explainer reports for the most recently explained text (presumably the argmax class).
cls_explainer.predicted_class_name

In [None]:
# Render the word attributions from the most recent cls_explainer call.
cls_explainer.visualize()

In [None]:
# Attribute a hand-written example headline instead of a dataset tweet.
example_text = 'Head of Pfizer Research: Covid Vaccine is Female Sterilization – Health and Money News https://t.co/IDRLSVmkLz'
word_attributions = cls_explainer(normalizeTweet(example_text))

In [None]:
# Render the word attributions for the example headline explained above.
cls_explainer.visualize()

In [None]:
# Reload BERTweet with a 2-label classification head. This rebinds `bertweet`;
# explainers created earlier still hold the previous model object.
bertweet = AutoModelForSequenceClassification.from_pretrained("vinai/bertweet-base", num_labels=2)
# Class probabilities for a sample text: softmax over the label dimension
# (logits are (batch, num_labels)); an implicit dim is deprecated.
F.softmax(bertweet(torch.tensor([tokenizer.encode(normalizeTweet('my favourite text'))])).logits, dim=1)

In [None]:
# Class probabilities for a sample text under the current `bertweet`.
sample_token_ids = torch.tensor([tokenizer.encode(normalizeTweet('my favourite text'))])
F.softmax(bertweet(sample_token_ids).logits, dim=1)

In [None]:
# Rebuild the explainer so it wraps the reloaded `bertweet`; the earlier
# cls_explainer instance still references the previous model object.
cls_explainer = SequenceClassificationExplainer(bertweet,tokenizer)

In [None]:
tweet_102_text = normalizeTweet(tweet_df.at[102, 'text'])
word_attributions = cls_explainer(tweet_102_text)
cls_explainer.visualize()

In [None]:
# Shape of the stored embedding of the first claim node.
dataset_mumin.nodes['claim'].at[0, 'embedding'].shape

In [None]:
# Preview the raw tweet node table; the full frame is too large to render usefully.
dataset_mumin.nodes['tweet'].head()

In [None]:
# Label names configured for the 2-label classification head loaded above.
bertweet.config.id2label