In [1]:
import os
import uuid
import json
import warnings
import random
import copy
import pickle
import torch

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchmetrics as tm

from graphlime import GraphLIME
from tqdm.notebook import tqdm
from torch_geometric.data import Data as tgData
from mumin_explainable.architectures.graphs import GAT

%matplotlib inline

warnings.filterwarnings('ignore')

In [2]:
# Seed every RNG the experiment draws from so Restart & Run All reproduces the
# random feature splits (`random`), any NumPy sampling, and model init (torch).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)           # seeds the torch RNG on all devices
torch.cuda.manual_seed_all(SEED)  # explicit CUDA seeding (harmless without a GPU)
# cuDNN: use deterministic kernels and disable autotuning, which may otherwise
# select different (non-deterministic) algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Load data

In [3]:
LANGUAGE = 'multilingual'
MODALITY = 'graph'
TEXT_DIM = 100  # not used to load data here; recorded with the experiment results

# NOTE: pickle.load can execute arbitrary code -- only load preload files
# produced by this project, never files from an untrusted source.
DATA_PATH = f'./data/preload/data-jupyter_{MODALITY}_{LANGUAGE}.pickle'
with open(DATA_PATH, 'rb') as fh:  # `with` closes the handle (the bare open() leaked it)
    data = pickle.load(fh)
data.x.shape

torch.Size([4061, 812])

# Trustworthiness batch experiment

In [139]:
# Trustworthiness experiment. For each TOP_K, repeat LOOPS times:
#   1. randomly mark 30% of the feature columns as "untrustworthy" (zeroed in a copy),
#   2. train a GAT on the ORIGINAL features and reload its best checkpoint,
#   3. oracle: a test node should NOT be trusted if zeroing the untrustworthy
#      features changes the model's prediction for it,
#   4. decide trust per test node from GraphLIME coefficients, then score that
#      decision against the oracle (positive class = "trusted").
# Mean precision / recall / F1 per TOP_K go into fpr / fre / ff1 and are
# appended to RESULTS_PATH under a fresh EXP_ID.
RESULTS_PATH = './data/results/trustworthiness-v1-batch-jupyter.json'
TRUSTWORTHY_FRACTION = 0.7  # fraction of feature columns kept trustworthy

fpr = {}
fre = {}
ff1 = {}
LOOPS = 5

lr = 0.005
epochs = 400
MODEL_NAME = f'{LANGUAGE}_{MODALITY}_trustworthy_batch'
n_features = data.x.shape[1]

for TOP_K in tqdm([1, 2, 3, 5, 10], desc='TOP_K'):

    EXP_ID = uuid.uuid4().hex

    prs = []
    res = []
    f1s = []

    for _ in tqdm(range(LOOPS), desc=f'loops (TOP_K={TOP_K})', leave=False):

        # --- Setup trustworthiness -------------------------------------------
        trustworthy_features_list = random.sample(range(n_features), k=int(n_features * TRUSTWORTHY_FRACTION))
        trustworthy_features_set = set(trustworthy_features_list)  # O(1) membership tests
        untrustworthy_features_list = [i for i in range(n_features) if i not in trustworthy_features_set]
        untrustworthy_features_set = set(untrustworthy_features_list)

        untrustworthy_features = copy.deepcopy(data.x)
        untrustworthy_features[:, untrustworthy_features_list] = 0

        untrustworthy_features_data = tgData(
            x=untrustworthy_features,
            y=data.y,
            train_mask=data.train_mask,
            val_mask=data.val_mask,
            test_mask=data.test_mask,
            edge_index=data.edge_index
        )

        # --- Train model (on the original features) --------------------------
        hparams = {
            'input_dim': data.num_node_features,
            'hidden_dim': 16,
            'output_dim': max(untrustworthy_features_data.y).item() + 1
        }

        model = GAT(**hparams).double()
        model.train()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Class count follows the labels instead of a hard-coded 2.
        f1_score = tm.classification.f_beta.F1Score(task='multiclass', num_classes=hparams['output_dim'], average='none')
        best_f1macro = -1

        for epoch in tqdm(range(epochs), desc='train', leave=False):
            optimizer.zero_grad()

            output = model(data.x, data.edge_index)
            loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])

            # Score and checkpoint BEFORE optimizer.step(): `output` was produced by
            # the current weights, so those are the weights that earned this F1.
            # NOTE(review): selection uses training-set macro-F1 from a train-mode
            # forward pass; data.val_mask is available if validation-based
            # selection is intended.
            train_f1 = f1_score(output[data.train_mask].detach(), data.y[data.train_mask])
            f1macro = torch.mean(train_f1)
            if f1macro > best_f1macro:
                best_f1macro = f1macro
                torch.save(model.state_dict(), f'./data/models/{MODEL_NAME}.pth')

            loss.backward()
            optimizer.step()

        # Load best model into a fresh instance for evaluation.
        model = GAT(**hparams).double()
        model.load_state_dict(torch.load(f'./data/models/{MODEL_NAME}.pth'))
        model.eval()

        # --- Oracle on the test nodes ----------------------------------------
        graphlime = GraphLIME(model, hop=2, rho=0.1, cached=True)

        # The model outputs log-probabilities (trained with nll_loss), so exp()
        # gives class probabilities. Compute once; reused per node below.
        with torch.no_grad():
            original_probs = model(data.x, data.edge_index).exp()
            untrustworthy_probs = model(untrustworthy_features_data.x, untrustworthy_features_data.edge_index).exp()
        original_preds = torch.argmax(original_probs[data.test_mask], dim=1)
        untrustworthy_preds = torch.argmax(untrustworthy_probs[untrustworthy_features_data.test_mask], dim=1)
        # Global node ids of the test nodes: position j in *_preds is node nodes_set[j].
        nodes_set = np.where(untrustworthy_features_data.test_mask.cpu().numpy())[0]

        trust_preds = original_preds == untrustworthy_preds     # True if trustworthy | l_i in paper oracle
        mistrust_preds = original_preds != untrustworthy_preds  # True if untrustworthy | l_i in paper oracle

        # argwhere yields positions *within the test subset*; map them to global
        # node ids so they share an index space with `trust` / `mistrust` below.
        mistrust_idx = np.argwhere(mistrust_preds.cpu().numpy()).flatten()
        shouldnt_trust = set(nodes_set[mistrust_idx].tolist())
        mistrust = set()
        trust = set()
        # Trust iff both scores fall on the same side of 0.5.
        trust_fn = lambda prev, curr: (prev > 0.5 and curr > 0.5) or (prev <= 0.5 and curr <= 0.5)
        trust_fn_all = lambda exp, unt: len([x[0] for x in exp if x[0] in unt]) == 0  # not used in this cell

        # --- Explanation-based trust decision per test node ------------------
        for node_idx in tqdm(nodes_set, desc='explain', leave=False):
            node_idx = int(node_idx)
            coefs_original = graphlime.explain_node(node_idx, data.x, data.edge_index)
            # Top-K features by coefficient, keeping only strictly positive ones.
            feat_indices_original = coefs_original.argsort()[-TOP_K:]
            feat_indices_original = [idx for idx in feat_indices_original if coefs_original[idx] > 0.0]

            prev_tot = original_probs[node_idx]
            coef_sum = sum(coefs_original)
            prev_tot2 = coef_sum + coef_sum / len(coefs_original)  # sum + mean of the coefficients
            # Remove the contribution of top-K features that are untrustworthy.
            tot = prev_tot2 - sum(coefs_original[i] for i in feat_indices_original if i in untrustworthy_features_set)
            if trust_fn(tot, max(prev_tot)):
                trust.add(node_idx)
            else:
                mistrust.add(node_idx)

        # --- Score the trust decision against the oracle ---------------------
        false_positives = trust & shouldnt_trust
        true_positives = trust - shouldnt_trust
        false_negatives = mistrust - shouldnt_trust
        true_negatives = mistrust & shouldnt_trust

        n_tp, n_fp, n_fn = len(true_positives), len(false_positives), len(false_negatives)
        pr = n_tp / (n_tp + n_fp) if n_tp + n_fp > 0 else 0
        re = n_tp / (n_tp + n_fn) if n_tp + n_fn > 0 else 0
        f1 = 2 * (pr * re) / (pr + re) if pr + re > 0 else 0

        prs.append(pr)
        res.append(re)
        f1s.append(f1)

    fpr[TOP_K] = sum(prs) / len(prs)
    fre[TOP_K] = sum(res) / len(res)
    ff1[TOP_K] = sum(f1s) / len(f1s)

    # Append this TOP_K's averages to the shared results file, keyed by EXP_ID.
    try:
        with open(RESULTS_PATH, 'r') as f:
            trustworthiness_dict = json.load(f)
    except FileNotFoundError:
        trustworthiness_dict = {}
    except json.JSONDecodeError:
        print(f'WARNING: {RESULTS_PATH} is not valid JSON; starting a fresh results dict.')
        trustworthiness_dict = {}

    trustworthiness_dict[EXP_ID] = {
        'TOPK': TOP_K,
        'LOOPS': LOOPS,
        'modality': MODALITY,
        'language': LANGUAGE,
        'text_dim': TEXT_DIM,
        'pr': float(fpr[TOP_K]),
        're': float(fre[TOP_K]),
        'f1': float(ff1[TOP_K])
    }
    with open(RESULTS_PATH, 'w') as f:
        json.dump(trustworthiness_dict, f)

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/439 [00:00<?, ?it/s]