In [None]:
import os
import uuid
import json
import warnings
import random
import copy

import dgl
import torch
import torch_geometric as tg

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import dgl.nn.pytorch as dglnn

from graphlime import GraphLIME
from mumin import MuminDataset

import torchmetrics as tm

from tqdm.notebook import tqdm
from collections import defaultdict

from torch_geometric.data import Data as tgData

from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification

from mumin_explainable.architectures.graphs import GAT
from mumin_explainable.processor.tweetnormalizer import normalizeTweet

from dotenv import load_dotenv

%matplotlib inline

warnings.filterwarnings('ignore')
_= torch.manual_seed(42)

# Set up the MuMiN graph

In [None]:
# Build the MuMiN dataset of the chosen size from its local zip, compile it and
# convert it to a DGL graph. The Twitter bearer token is taken from the environment.
size = 'small'
bearer_token = os.getenv('TWITTER_BEARER_TOKEN')
dataset_zip_path = f'./data/datasets/mumin-{size}.zip'

dataset_mumin = MuminDataset(
    twitter_bearer_token=bearer_token,
    size=size,
    dataset_path=dataset_zip_path,
)
dataset_mumin.compile()
mumin_graph = dataset_mumin.to_dgl()
mumin_graph

# Set up the subgraph

In [None]:
# Canonical edge type linking users to the tweets they posted.
posted_etype = ('user', 'posted', 'tweet')

# NOTE: these are the dataset's own frames, not copies — later column assignments on
# tweet_df also modify dataset_mumin.nodes['tweet'].
user_df, tweet_df = (dataset_mumin.nodes[ntype] for ntype in ('user', 'tweet'))

user_posted_tweet_df = dataset_mumin.rels[posted_etype]
user_posted_tweet_subgraph = dgl.edge_type_subgraph(mumin_graph, etypes=[posted_etype])

## Filter language

In [None]:
LANG = 'multilingual'

# One boolean per tweet node: True when the tweet is kept for the chosen LANG.
# The `lang` column is compared directly (a missing language compares False), so the
# list has exactly one entry per tweet row. The previous version first called
# `.dropna()` on the WHOLE tweet frame, which removed every tweet with a missing value
# in any column and could leave the list shorter than the number of tweet nodes,
# breaking (or misaligning) the `&` with the split masks below.
num_tweet_nodes = user_posted_tweet_subgraph.num_nodes('tweet')
if LANG == 'multilingual':
    lang_tweets = [True] * num_tweet_nodes
else:
    lang_tweets = (dataset_mumin.nodes['tweet']['lang'] == LANG).to_list()

# NOTE(review): assumes row i of the tweet frame corresponds to tweet node i — confirm.
assert len(lang_tweets) == num_tweet_nodes, (
    f'{len(lang_tweets)} language flags for {num_tweet_nodes} tweet nodes'
)
lang_mask = torch.tensor(lang_tweets)

# Restrict the dataset's train/val/test splits to the selected language.
tweet_node_data = user_posted_tweet_subgraph.nodes['tweet'].data
tweet_train_mask = tweet_node_data['train_mask'] & lang_mask
tweet_val_mask = tweet_node_data['val_mask'] & lang_mask
tweet_test_mask = tweet_node_data['test_mask'] & lang_mask

# Set up the dataset

In [None]:
# Convert the user-posted-tweet subgraph into a PyG `Data` object with tweet features,
# labels and the language-filtered split masks.
# NOTE(review): the 'posted' edges go user -> tweet, so the source ids are user node ids
# and the destination ids are tweet node ids; both end up indexing `x`, which only holds
# tweet features — confirm this mixing of id spaces is intended.
posted_src, posted_dst = user_posted_tweet_subgraph.edges(etype='posted')
edge_index = torch.stack((posted_src, posted_dst))  # shape (2, num_edges)

tweet_node_data = user_posted_tweet_subgraph.nodes['tweet'].data
graph_data = tgData(
    x=tweet_node_data['feat'],
    y=tweet_node_data['label'],
    train_mask=tweet_train_mask,
    val_mask=tweet_val_mask,
    test_mask=tweet_test_mask,
    edge_index=edge_index.long(),
)

# Enhance with text-based features

In [None]:
TEXT_DIM = 100

# Hugging Face checkpoint and tokenizer options per language. Only the entry for LANG is
# instantiated below; previously every model was loaded eagerly ('vinai/bertweet-base'
# twice) although only one was ever used.
# NOTE(review): 'multilingual' reuses the English BERTweet checkpoint — confirm intended.
LANG_TOOL_MAP = {
    'multilingual': {'checkpoint': 'vinai/bertweet-base', 'tokenizer_kwargs': {'use_fast': False}},
    'en': {'checkpoint': 'vinai/bertweet-base', 'tokenizer_kwargs': {'use_fast': False}},
    'pt': {'checkpoint': 'melll-uff/bertweetbr', 'tokenizer_kwargs': {'normalization': True}},
    'es': {'checkpoint': 'pysentimiento/robertuito-base-cased', 'tokenizer_kwargs': {}},
}

_lang_tools = LANG_TOOL_MAP[LANG]
bertweet = AutoModel.from_pretrained(_lang_tools['checkpoint'])
tokenizer = AutoTokenizer.from_pretrained(_lang_tools['checkpoint'], **_lang_tools['tokenizer_kwargs'])
bertweet.eval()

# One projection layer per output size, shared by every tweet. The previous version built a
# new randomly initialised nn.Linear on EVERY call, so each tweet was projected by a different
# random matrix and the resulting features were not comparable across tweets.
_TEXT_PROJECTIONS = {}


def _text_projection(text_dim):
    """Return the cached linear map from the encoder's hidden size to `text_dim`."""
    if text_dim not in _TEXT_PROJECTIONS:
        _TEXT_PROJECTIONS[text_dim] = nn.Linear(bertweet.config.hidden_size, text_dim)
    return _TEXT_PROJECTIONS[text_dim]


def _encode_pooled(text):
    """Tokenize `text` (truncated to the tokenizer's max length) and return the pooled output."""
    token_ids = tokenizer.encode(text, truncation=True)
    return bertweet(torch.tensor([token_ids])).pooler_output


def tweetencoder(x, text_dim):
    """Encode one tweet into a list of `text_dim` floats.

    The text is normalised with `normalizeTweet`, encoded by the language model, and its
    `pooler_output` is projected with a layer shared across calls with the same `text_dim`.
    If normalisation, tokenisation or the forward pass fails (presumably for non-string
    text such as NaN — unverified), the empty string is encoded instead.
    """
    projection = _text_projection(text_dim)
    with torch.no_grad():  # inference only: don't build an autograd graph per tweet
        try:
            pooled = _encode_pooled(normalizeTweet(x))
        except Exception:
            pooled = _encode_pooled('')
        return projection(pooled).tolist()[0]

In [None]:
# Encode every tweet of the selected language; when LANG is a single language, tweets in
# other languages get an all-zero vector. Zeros are stored as real lists (not their string
# repr), so no string parsing is needed when the column is expanded later.
# NOTE: tweet_df is dataset_mumin.nodes['tweet'] itself, so this adds the column to the dataset frame.
with torch.no_grad():  # inference only
    tweet_df['text_encoding'] = [
        tweetencoder(text, TEXT_DIM) if LANG == 'multilingual' or lang == LANG else [0] * TEXT_DIM
        for text, lang in tqdm(
            zip(tweet_df['text'], tweet_df['lang']), total=len(tweet_df), desc='encoding tweets'
        )
    ]

In [None]:
import ast

new_embedding_columns = [f'emb{i}' for i in range(TEXT_DIM)]

# Expand each encoding into TEXT_DIM numeric columns. Entries are lists, or (from the older
# encoding cell) the string repr of a list; ast.literal_eval parses those without executing code.
tweet_embeddings_split_df = pd.DataFrame(
    [ast.literal_eval(x) if isinstance(x, str) else x for x in tweet_df['text_encoding'].tolist()],
    index=tweet_df.index,
    columns=new_embedding_columns,
)
display(tweet_embeddings_split_df.head())

# Align the embeddings with graph node ids 0..N-1; tweets without an embedding get zeros.
# The previous `tweet_df.dropna(inplace=True)` removed tweets with a missing value in ANY column
# (not only the embeddings), which left NaN rows here that then propagate NaN through the GNN.
# NOTE(review): assumes tweet_df's index equals the DGL tweet node id — confirm.
text_features_df = (
    pd.DataFrame(index=range(graph_data.x.shape[0]))
    .join(tweet_embeddings_split_df)
    .fillna(0)
)
text_features_tensor = torch.tensor(text_features_df.values).double()

In [None]:
MODE = 'text'  # one of 'graph', 'text', 'multimodal'

if MODE == 'graph':
    modality_features = graph_data.x
elif MODE == 'text':
    modality_features = text_features_tensor
elif MODE == 'multimodal':
    modality_features = torch.cat([graph_data.x.double(), text_features_tensor.double()], dim=1)
else:
    raise ValueError(f"Unknown MODE {MODE!r}; expected 'graph', 'text' or 'multimodal'")

# The GAT below is cast with `.double()`, so features must be float64 in every mode.
# graph_data.x is never cast to double earlier in the notebook, so the 'graph' mode used to
# hand the model features in their original dtype.
modality_features = modality_features.double()

# Same labels and language-filtered split masks as graph_data.
modality_data = tgData(
    x=modality_features,
    y=graph_data.y,
    train_mask=tweet_train_mask,
    val_mask=tweet_val_mask,
    test_mask=tweet_test_mask,
    edge_index=graph_data.edge_index)

### Generate new noisy features

In [None]:
EXP_ID = uuid.uuid4().hex
PER_NOISY = 0.0 # [0.0, 0.01, 0.1, 0.25, 0.5, 0.75, 1]
NUM_NOISY_FEATURES = int(modality_data.x.shape[1] * PER_NOISY)

# Gaussian noise features appended after the modality features, in the same dtype so the
# concatenation (and the float64 model) sees a single dtype.
noisy_features = torch.randn(
    (modality_data.x.shape[0], NUM_NOISY_FEATURES), dtype=modality_data.x.dtype
)
# Centre each noisy feature over the nodes (dim 0). The previous per-row centring (dim 1)
# made a single noisy feature (e.g. PER_NOISY=0.01 with 100 features) identically zero.
noisy_features = noisy_features - noisy_features.mean(0, keepdim=True)

noisy_features_data = tgData(
    x=torch.cat([modality_data.x, noisy_features], dim=-1),
    y=modality_data.y,
    train_mask=modality_data.train_mask,
    val_mask=modality_data.val_mask,
    test_mask=modality_data.test_mask,
    edge_index=modality_data.edge_index)

### Train noisy model

In [None]:
hparams = {
    'input_dim': noisy_features_data.num_node_features,
    'hidden_dim': 16,
    'output_dim': int(noisy_features_data.y.max().item()) + 1
}
num_classes = hparams['output_dim']  # keep metrics consistent with the model's output size

MODEL_NAME = f'{LANG}_{MODE}_nosy'
model_path = f'./data/models/{MODEL_NAME}.pth'
os.makedirs(os.path.dirname(model_path), exist_ok=True)

model = GAT(**hparams).double()

lr = 0.005
epochs = 400

optimizer = optim.Adam(model.parameters(), lr=lr)

accuracy = tm.Accuracy(task='multiclass', num_classes=num_classes, average='none')
stats_score = tm.StatScores(task='multiclass', num_classes=num_classes, average='none')
precision = tm.Precision(task='multiclass', num_classes=num_classes, average='none')
recall = tm.Recall(task='multiclass', num_classes=num_classes, average='none')
f1_score = tm.classification.f_beta.F1Score(task='multiclass', num_classes=num_classes, average='none')
# Separate metric object so validation scores don't accumulate into the training F1 state.
val_f1_score = tm.classification.f_beta.F1Score(task='multiclass', num_classes=num_classes, average='none')
best_f1macro = -1

# NOTE: updated every epoch, so its estimate pools training predictions from all epochs so far.
bootstrap = tm.BootStrapper(
    f1_score, num_bootstraps=200, sampling_strategy='multinomial', quantile=torch.tensor([0.05, 0.95])
)

train_mask = noisy_features_data.train_mask
y_true = noisy_features_data.y
# Checkpoints are selected on the validation split (previously on training F1, which rewards
# overfitting). Fall back to the training split only if no validation tweets remain.
select_mask = noisy_features_data.val_mask if noisy_features_data.val_mask.any() else train_mask

for epoch in tqdm(range(epochs)):
    model.train()
    optimizer.zero_grad()

    # `F.nll_loss` below implies the model returns log-probabilities (argmax-compatible for metrics).
    output = model(noisy_features_data.x, noisy_features_data.edge_index)
    loss = F.nll_loss(output[train_mask], y_true[train_mask])

    loss.backward()
    optimizer.step()

    train_pred = output[train_mask].detach()
    train_true = y_true[train_mask]
    f1 = f1_score(train_pred, train_true)
    f1macro = torch.mean(f1)
    bootstrap.update(train_pred, train_true)

    # Score the *updated* weights on the selection split, so the saved checkpoint is the one measured.
    model.eval()
    with torch.no_grad():
        eval_output = model(noisy_features_data.x, noisy_features_data.edge_index)
    val_f1macro = torch.mean(val_f1_score(eval_output[select_mask], y_true[select_mask]))

    if val_f1macro > best_f1macro:
        best_f1macro = val_f1macro
        torch.save(model.state_dict(), model_path)

    if epoch % 10 == 0:
        metrics = {
            'Epoch': epoch,
            'Accuracy': accuracy(train_pred, train_true),
            'Precision': precision(train_pred, train_true),
            'Recall': recall(train_pred, train_true),
            'Stats_Score': stats_score(train_pred, train_true),
            'F1': f1,
            'F1-macro': f1macro,
            'Val-F1-macro': val_f1macro,
            'Bootstrap': bootstrap.compute()
        }
        print(metrics)

model.eval()

### Eval noisy model

In [None]:
model.load_state_dict(torch.load(f'./data/models/{MODEL_NAME}.pth'))
model.eval()

num_classes = hparams['output_dim']
accuracy = tm.Accuracy(task='multiclass', num_classes=num_classes, average='none')
stats_score = tm.StatScores(task='multiclass', num_classes=num_classes, average='none')
precision = tm.Precision(task='multiclass', num_classes=num_classes, average='none')
recall = tm.Recall(task='multiclass', num_classes=num_classes, average='none')
f1_score = tm.classification.f_beta.F1Score(task='multiclass', num_classes=num_classes, average='none')

# Fresh bootstrapper over the TEST predictions. The training-loop `bootstrap` only ever saw
# training-split predictions, so reporting it here labelled training intervals as test results.
test_bootstrap = tm.BootStrapper(
    tm.classification.f_beta.F1Score(task='multiclass', num_classes=num_classes, average='none'),
    num_bootstraps=200, sampling_strategy='multinomial', quantile=torch.tensor([0.05, 0.95])
)

print('test')

with torch.no_grad():  # inference only
    output = model(noisy_features_data.x, noisy_features_data.edge_index)

test_pred = output[noisy_features_data.test_mask]
test_true = noisy_features_data.y[noisy_features_data.test_mask]

f1 = f1_score(test_pred, test_true)
f1macro = torch.mean(f1)
test_bootstrap.update(test_pred, test_true)

metrics = {
    'Accuracy': accuracy(test_pred, test_true),
    'Precision': precision(test_pred, test_true),
    'Recall': recall(test_pred, test_true),
    'Stats_Score': stats_score(test_pred, test_true),
    'F1': f1,
    'F1-macro': f1macro,
    'Bootstrap': test_bootstrap.compute()
}

print(metrics)

In [None]:
# Number of highest-weighted features inspected per explanation.
TOP_K = 10
# Node indices of the test split: these are the nodes explained in the next cell.
test_nodes = noisy_features_data.test_mask.nonzero(as_tuple=True)[0].numpy()
graphlime = GraphLIME(model, hop=2, rho=0.1, cached=True)

In [None]:
num_noise_feats = []
per_noise_feats = []
# Noisy features were appended after the modality features, so any column index at or
# beyond this boundary is an injected noise feature.
first_noise_col = modality_data.x.shape[1]

for node_idx in tqdm(test_nodes):
    coefs = graphlime.explain_node(int(node_idx), noisy_features_data.x, noisy_features_data.edge_index)

    # Indices of the TOP_K largest coefficients, keeping only strictly positive ones.
    top_feats = [idx for idx in coefs.argsort()[-TOP_K:] if coefs[idx] > 0.0]
    num_noise_feat = sum(idx >= first_noise_col for idx in top_feats)
    num_noise_feats.append(num_noise_feat)
    per_noise_feats.append(num_noise_feat / len(top_feats) if top_feats else 0)

# `sns.distplot` is deprecated (seaborn >= 0.11); histplot with kde and density scaling replaces it.
ax = sns.histplot(num_noise_feats, kde=True, stat='density', color='red')
ax.set(
    xlabel=f'noise features among top-{TOP_K} GraphLIME features',
    ylabel='density',
    title=f'Noise features in explanations (PER_NOISY={PER_NOISY})',
);

In [None]:
RESULTS_PATH = './data/results/robustness-v2-jupyter.json'


def _mean_or_none(values):
    """Mean of `values` as a float, or None when empty (e.g. no test nodes were explained)."""
    return float(sum(values)) / len(values) if values else None


# Build the record BEFORE opening the file for writing. Previously the file was opened with
# 'w' (truncated) first, so any error while building the record left it empty.
experiment_record = {
    'PER_NOISE': PER_NOISY,
    'NUM_NOISE_FEATURES': NUM_NOISY_FEATURES,
    'num_noise_feats': [int(x) for x in num_noise_feats],
    'avg_num_noise_feats': _mean_or_none(num_noise_feats),
    'per_noise_feats': [float(x) for x in per_noise_feats],
    'avg_per_noise_feats': _mean_or_none(per_noise_feats),
    'modality': MODE,
    'language': LANG
}

# Only a missing file starts a fresh results dict. A corrupted file now raises instead of
# being silently overwritten (the old bare `except` discarded all previous experiments).
try:
    with open(RESULTS_PATH, 'r') as f:
        robustness_dict = json.load(f)
except FileNotFoundError:
    robustness_dict = {}

robustness_dict[EXP_ID] = experiment_record

# Write to a temporary file and atomically swap it in, so an interrupted dump can't corrupt results.
os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)
tmp_path = RESULTS_PATH + '.tmp'
with open(tmp_path, 'w') as f:
    json.dump(robustness_dict, f)
os.replace(tmp_path, RESULTS_PATH)

print(robustness_dict[EXP_ID])

## Joint analysis

In [None]:
with open('./data/results/robustness-v2-jupyter.json', 'r') as results_file:
    robustness_dict = json.load(results_file)

# One row per experiment id, one column per recorded field (field names upper-cased).
robustness_df = pd.DataFrame(robustness_dict).T.rename(columns=str.upper)
robustness_df

In [None]:
# TODO: rethink this visualization in terms of TPR and FPR (ROC), Precision and Recall (PR), or AUC.
# Parity-style plot: injected noise share (PER_NOISE) against the measured average share of
# noise features among GraphLIME's top features, with the identity line for reference.
# Columns are cast to float because the transposed results frame may hold object-dtype
# columns (None averages become NaN and are skipped). `palette` is omitted: seaborn
# ignores it when no `hue` is given.
parity_df = robustness_df[['AVG_PER_NOISE_FEATS', 'PER_NOISE']].astype(float)

fig, ax = plt.subplots()
sns.lineplot(data=parity_df, x='AVG_PER_NOISE_FEATS', y='PER_NOISE', label='measured', ax=ax)
sns.lineplot(data=parity_df, x='PER_NOISE', y='PER_NOISE', label='identity (y = x)', ax=ax)
ax.set(
    xlabel='share of noise features',
    ylabel='injected noise share (PER_NOISE)',
    title='Noise features in explanations vs. injected noise',
)
plt.show()

In [None]:
# Keep only experiments that injected noise, then give every per-node value its own row.
with_noise = robustness_df['PER_NOISE'] != 0.0
robustness_expand_df = (
    robustness_df[with_noise]
    .explode(['NUM_NOISE_FEATS', 'PER_NOISE_FEATS'])
)

# Shared arguments for both distribution plots (one curve per injected noise level).
displot_kwargs = dict(
    data=robustness_expand_df,
    x='PER_NOISE_FEATS',
    label='PER_NOISE',
    hue='PER_NOISE',
    palette='rocket',
)
sns.displot(kind='kde', fill=True, **displot_kwargs)
sns.displot(kde=True, **displot_kwargs)