In [None]:
import time
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.data import Data
from model import MLP, GCN, FCN_LP
import random

# Experiment configuration: every hyper-parameter read by the later cells lives in this dict.
args = {
    'seed': 42,            # forwarded to setup_seed()
    # Prefer the first CUDA device and fall back to CPU when none is available.
    # NOTE(review): an earlier comment here mentioned mixed-precision training on GPU,
    # but no mixed-precision code is visible in this notebook - confirm before relying on it.
    'device': "cuda:0" if torch.cuda.is_available() else "cpu",
    'hidden': 32,          # hidden size handed to the model constructor
    'num_classes': 2,      # labels are encoded as 2-column one-hot vectors
    'dropout': 0.5,        # dropout value handed to the model constructor
    'lpaiters': 1,         # passed to FCN_LP
    'gcnnum': 4,           # passed to FCN_LP
    'epochs': 500,         # number of training-loop iterations
    'lr': 5e-5,            # AdamW learning rate
    'weight_decay': 5e-4,  # AdamW weight decay
}
def setup_seed(an_int):
    """Seed every random number generator used by the notebook and make cuDNN deterministic.

    Args:
        an_int: Integer seed applied to Python's `random`, NumPy and PyTorch (CPU and CUDA).

    Side effects:
        Mutates global RNG state and the global cuDNN backend flags.
    """
    random.seed(an_int)
    np.random.seed(an_int)
    torch.manual_seed(an_int)
    # manual_seed_all seeds every CUDA device (including the current one) and is
    # safe to call when CUDA is not available.
    torch.cuda.manual_seed_all(an_int)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick a different convolution algorithm from run to run, so it
    # must be off for `deterministic` to actually give reproducible results.
    torch.backends.cudnn.benchmark = False
# Seed all RNGs with the configured seed before anything else in the notebook runs.
setup_seed(args['seed'])
# Name of the dataset to use: it selects the 'dataset/<name>/' sub-directory the files are
# read from and which seen/unseen split logic is applied further down.
dataset_name = 'twitter' # twitter or weibo or pheme

In [None]:
from torch_geometric.data import Data
# Load the train/test tables, the tweet embeddings and the tweet graph of the selected dataset.
dataset_dir = 'dataset/' + dataset_name
train_data = pd.read_csv(dataset_dir + '/dataforGCN_train.csv')
test_data = pd.read_csv(dataset_dir + '/dataforGCN_test.csv')
# NOTE(review): torch.load unpickles its input - only point this at trusted files.
# map_location='cpu' lets tensors that were saved from a GPU be loaded on a CPU-only
# machine; the assembled Data object is moved to args['device'] below.
tweet_embeds = torch.load(dataset_dir + '/ALLCAT_embeds_cross95.pt', map_location='cpu')
tweet_graph = torch.load(dataset_dir + '/edge_cross95.pt', map_location='cpu')

label_list_train = train_data["label"].tolist()
event_list_train = train_data["event"].tolist()

label_list_test = test_data["label"].tolist()
event_list_test = test_data["event"].tolist()

# One-hot encode the labels, train rows first and test rows after them:
# label == 1 -> [1, 0], every other label -> [0, 1].
is_label_one = torch.tensor(
    [label == 1 for label in label_list_train + label_list_test], dtype=torch.bool
)
labels = torch.stack([is_label_one, ~is_label_one], dim=1).float()

# The first len(label_list_train) rows are the training tweets, the remainder the test tweets.
num_train = len(label_list_train)
train_mask = torch.zeros(len(labels), dtype=torch.bool)
train_mask[:num_train] = True

# Coalesce the sparse graph once and reuse it for both the indices and the values.
graph_coalesced = tweet_graph.coalesce()

# Create data object
data = Data(
    x=tweet_embeds.float(),
    edge_index=graph_coalesced.indices(),
    edge_attr=graph_coalesced.values().unsqueeze(-1),
    train_mask=train_mask,
    test_mask=~train_mask,
    y=labels
).to(args['device'])


In [None]:
# Split training tweets into seen / unseen events, separately for each label.
def get_data_splits(label_list_train, event_list_train, selected_events, unselected_events):
    """Partition training indices by event membership and label.

    Args:
        label_list_train: Per-tweet labels; a label equal to 1 goes to the "real"
            group, any other value to the "fake" group.
        event_list_train: Per-tweet event names, aligned with `label_list_train`.
        selected_events: Events whose tweets count as "seen".
        unselected_events: Events whose tweets count as "unseen".

    Returns:
        Tuple `(seen_real, seen_fake, unseen_real, unseen_fake)` of index lists.
        Indices follow the order of the event lists and, within one event, the
        order of appearance in the training data.

    Raises:
        KeyError: If a requested event never occurs in `event_list_train`.
    """
    # event name -> (indices with label == 1, indices with any other label)
    indices_by_event = {}
    for idx, (lbl, evt) in enumerate(zip(label_list_train, event_list_train)):
        real_idx, fake_idx = indices_by_event.setdefault(evt, ([], []))
        bucket = real_idx if lbl == 1 else fake_idx
        bucket.append(idx)

    def _collect(events):
        """Concatenate the per-event index lists for the given events."""
        real_all, fake_all = [], []
        for evt in events:
            real_idx, fake_idx = indices_by_event[evt]
            real_all += real_idx
            fake_all += fake_idx
        return real_all, fake_all

    seen_real, seen_fake = _collect(selected_events)
    unseen_real, unseen_fake = _collect(unselected_events)
    return seen_real, seen_fake, unseen_real, unseen_fake


# Event-based seen/unseen splits: dataset name -> (seen events, unseen events).
EVENT_SPLITS = {
    'twitter': (
        ['boston','columbianChemicals', 'nepal',  'pigFish', 'bringback', 'sochi', 'malaysia', 'sandy', 'passport', 'underwater', 'livr'],
        ['elephant', 'garissa', 'eclipse', 'samurai'],
    ),
    'pheme': (
        ['Ottawa Shooting', 'sydney siege', 'Charlie Hebdo', 'GermanwingsCrash'],
        ['Ferguson'],
    ),
}

if dataset_name == 'weibo':
    # No event split is used here: one third of the training tweets is drawn at
    # random as "unseen", the rest is "seen".
    num_train = len(label_list_train)
    all_tweets = set(range(num_train))
    # random.sample needs a sequence (sets raise TypeError on Python >= 3.11),
    # so sample from the range instead of from the set.
    unseen = set(random.sample(range(num_train), num_train // 3))
    seen = sorted(all_tweets - unseen)
    seen_real = [idx for idx in seen if label_list_train[idx] == 1]
    seen_fake = [idx for idx in seen if label_list_train[idx] == 0]
    unseen_real = [idx for idx in sorted(unseen) if label_list_train[idx] == 1]
    unseen_fake = [idx for idx in sorted(unseen) if label_list_train[idx] == 0]
elif dataset_name in EVENT_SPLITS:
    # Tweets of the selected events are "seen", tweets of the unselected events "unseen".
    selected_events, unselected_events = EVENT_SPLITS[dataset_name]
    seen_real, seen_fake, unseen_real, unseen_fake = get_data_splits(label_list_train, event_list_train, selected_events, unselected_events)
    seen = seen_real + seen_fake
else:
    # Fail here rather than with a NameError on `seen` deep inside the training loop.
    raise ValueError(
        "Unknown dataset_name {!r}; expected 'twitter', 'weibo' or 'pheme'".format(dataset_name)
    )

In [None]:
from mmd import MMDLoss
def accuracy(output, labels):
    """Compute accuracy, precision, recall and F1 for binary predictions.

    Args:
        output: Either a 1-D tensor of scores, thresholded at 0.5, or a 2-D
            tensor with one column per class, where the predicted class is the
            arg-max of each row.
        labels: For 1-D `output`, a 1-D tensor of 0/1 targets; for 2-D
            `output`, a 2-D tensor whose row-wise arg-max is the target class
            (e.g. one-hot rows).

    Returns:
        Tuple `(accuracy, precision, recall, f1)` of 0-dim tensors. Class
        index / value 1 is the positive class for precision, recall and F1.
    """
    if output.ndim == 1:
        preds = (output >= 0.5).long()
        targets = labels.long()
    else:
        preds = output.max(1)[1].type_as(labels)
        targets = labels.max(1)[1]

    # Confusion-matrix counts, with 1 as the positive class.
    tp = torch.sum(preds * targets)
    fp = torch.sum(preds * (1 - targets))
    fn = torch.sum((1 - preds) * targets)
    tn = torch.sum((1 - preds) * (1 - targets))

    acc = (tp + tn) / (tp + tn + fp + fn)
    # The small epsilon keeps the ratios finite when a denominator would be zero.
    precision = tp / (tp + fp + 1e-10)
    recall = tp / (tp + fn + 1e-10)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)
    return acc, precision, recall, f1
# Alternative models that can be swapped in for FCN_LP:
#model = GCN(tweet_embeds.shape[1], args['hidden'], args['num_classes'], args['dropout']).to(args['device'])
#model = MLP(tweet_embeds.shape[1], args['hidden'], args['num_classes'], args['dropout']).to(args['device'])
model = FCN_LP(tweet_embeds.shape[1], args['hidden'], args['num_classes'], args['dropout'], data.num_edges,
                args['lpaiters'], args['gcnnum']).to(args['device'])

optimizer = torch.optim.AdamW(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
crition1 = nn.CrossEntropyLoss()              # classification loss against the one-hot targets in data.y
crition2 = MMDLoss(kernel_type = 'linear')    # MMD loss between two sets of feature rows

# NOTE(review): never updated in the visible code; kept in case later cells read it.
best_score = 0

# Anomaly detection is a debugging aid that slows every backward pass, so it is
# opt-in: set args['detect_anomaly'] = True to hunt for NaNs in the gradients.
detect_anomaly = args.get('detect_anomaly', False)

for epoch in range(args['epochs']):
    t = time.time()

    # ---- training step ----
    model.train()
    optimizer.zero_grad()
    with torch.autograd.set_detect_anomaly(detect_anomaly):
        # The model returns three tensors; their exact meaning is defined in model.py
        # (presumably classifier output, label-propagation output and hidden features - confirm).
        out, yhat, x = model(data)
        # Supervised losses use only the tweets of "seen" events.
        FCN_loss = crition1(out[seen], data.y[seen])
        LPN_loss = crition1(yhat[seen], data.y[seen])
        # Align unseen-event features with seen-event features, separately per label group.
        MMD_loss = crition2(x[unseen_real], x[seen_real]) + crition2(x[unseen_fake], x[seen_fake])
        loss_train = FCN_loss + LPN_loss + MMD_loss
        acc_train, precision, recall, f1 = accuracy(yhat[data.train_mask], data.y[data.train_mask])
        loss_train.backward()
    optimizer.step()

    # ---- evaluation ----
    model.eval()
    # No gradients are needed for evaluation; skipping the autograd graph saves memory and time.
    with torch.no_grad():
        out, yhat, x = model(data)
        acc_test, precision_test, recall_test, f1_test = accuracy(yhat[data.test_mask], data.y[data.test_mask])

    # Report progress every 10th epoch.
    if epoch % 10 == 0:
        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.item()),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'acc_test: {:.4f}'.format(acc_test.item()),
              'time: {:.4f}s'.format(time.time() - t))