In [1]:
# https://www.aclweb.org/anthology/P18-1184.pdf
# https://github.com/aykutfirat/pyTorchTree
# https://github.com/liamge/Pytorch_ReNN

# https://github.com/inyukwo1/tree-lstm
# https://github.com/dasguptar/treelstm.pytorch/blob/master/treelstm/model.py

# https://github.com/unbounce/pytorch-tree-lstm

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import random
from treelib import Node, Tree
import networkx as nx

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split

import nltk
from nltk.util import ngrams
from nltk.corpus import stopwords
from nltk.tokenize import RegexpTokenizer
from ekphrasis.classes.preprocessor import TextPreProcessor
from ekphrasis.classes.tokenizer import SocialTokenizer
from ekphrasis.dicts.emoticons import emoticons
import gensim
from gensim.models import KeyedVectors

from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm, trange, tqdm_notebook

from IPython.core.debugger import set_trace

In [3]:
args = {
    'data_dir': '/data/rumor_detection/data/rumor_acl/rumor_detection_acl2017/twitter15/',
    'tweet_content_file': 'tweet_contents.txt',
    'tree_dir': 'tree',
    'label_file': 'label.txt',
    'w2v': 'twitter_preprocess_3_w2c_400.txt',

    'max_graph_size': 50,
    'K': 2,
    'hidden_dim': 200,
    'target_size': 4,
    'batch_size': 8,
    'learning_rate': 1e-3,
    'n_epoches': 70,
    'logging_steps': 100,
    'do_eval': True,
    'aggregator': 'mean',
    'n_splits': 5,
    'seed': 1234,
}

In [4]:
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything(args['seed'])

## Data

In [5]:
def create_text_processor():
    text_processor = TextPreProcessor(
            normalize=['url', 'email', 'percent', 'money', 'phone', 'user',
                       'time', 'date', 'number'],
            fix_html=True,
            segmenter="twitter",
            corrector="twitter",

            unpack_hashtags=True,
            unpack_contractions=True,
            spell_correct_elong=True,

            # tokenizer=SocialTokenizer(lowercase=True).tokenize,
            tokenizer=RegexpTokenizer(r'\w+').tokenize,

            dicts=[emoticons]
        )

    return text_processor

def remove_stopword(tokens):
    stop_words = stopwords.words('english')
#     stop_words.append('url')
    return [word for word in tokens if word not in stop_words]

def stemming(tokens, ps):
    tokens = [ps.stem(token) for token in tokens]
    return tokens

def lemmatizer(tokens, wn):
    tokens = [wn.lemmatize(token) for token in tokens]
    return tokens

def remove_last_url(tokens):
    if len(tokens) > 0 and tokens[-1] == 'url':
        return tokens[:-1]
    else:
        return tokens
    
def pre_process(s):
    text = s.content
    text = text.replace("\/", '/')
    text = text.lower()

    tokens = text_processor.pre_process_doc(text)
    tokens = remove_stopword(tokens)
    tokens = stemming(tokens, ps)
    tokens = lemmatizer(tokens, wn)
    # tokens = remove_last_url(tokens)
    n_grams = set.union(set(ngrams(tokens, 1)), set(ngrams(tokens, 2)))
    return n_grams

In [6]:
word_vectors = KeyedVectors.load_word2vec_format(args['data_dir'] + args['w2v'], binary=False)
embed_dim = word_vectors.vector_size
text_processor = create_text_processor()
ps = nltk.PorterStemmer()
wn = nltk.WordNetLemmatizer()

def load_tweet_content(tweet_content_file):
    def embed_content(s):
        tokens = s.content_tokens
        content_embedding = torch.tensor([word_vectors[token] for token in tokens if token in word_vectors], dtype=torch.float)
        content_embedding = torch.mean(content_embedding, axis=0)
        if torch.isnan(content_embedding).any():
            content_embedding = torch.zeros((embed_dim, ))
        return content_embedding

    content_df = pd.read_csv(tweet_content_file, sep='\t', header=None, names=['id', 'content'])
    content_df['content_tokens'] = content_df.apply(pre_process, axis=1)
    content_df['content_embedding'] = content_df.apply(embed_content, axis=1)
    content_dict = {row['id']:row['content_embedding'] for i, row in content_df.iterrows()}
    
    return content_dict

Reading twitter - 1grams ...
Reading twitter - 2grams ...
Reading twitter - 1grams ...


  regexes = {k.lower(): re.compile(self.expressions[k]) for k, v in


In [19]:
def load_rumor_trees(tree_dir_path, content_dict):
    trees = {}
    for f in os.listdir(tree_dir_path):
        file_path = os.path.join(tree_dir_path, f)

        if os.path.isfile(file_path) and '.txt' in file_path:
            tree = Tree()
            tweet_ids = []
            root_id = int(f.split('.')[0])
            tweet_ids.append(root_id)
            if root_id in content_dict:
                content = content_dict[root_id]
            else:
                content = torch.zeros((embed_dim, ))
            tree.create_node(tag=root_id, identifier=root_id, data=content)
            with open(file_path, 'r') as file:
                for line in file:
                    line_arr = line.split("'")
                    if 'ROOT' not in line:
                        user1 = int(line_arr[1])
                        tweet1 = int(line_arr[3])
                        user2 = int(line_arr[7])
                        tweet2 = int(line_arr[9])
                        
                        if tweet2 not in tweet_ids: 
                            tweet_ids.append(tweet2)
                        
                        if tweet2 not in tree.nodes:
                            if tweet2 in content_dict:
                                content = content_dict[tweet2]
                            else:
                                content = torch.zeros((embed_dim, ))
                            tree.create_node(tag=tweet2, identifier=tweet2, parent=tweet1, data=content)
                
                tweet_ids.reverse()
                trees[root_id] = (tweet_ids, tree)
        
    return trees

In [8]:
def load_labels(label_file):
    label_df = pd.read_csv(label_file, sep=':', header=None, names=['label', 'id'])
    label_df['label'] = label_df['label'].map({'unverified': 0, 'non-rumor': 1, 'true': 2, 'false': 3})
    label_dict = {row['id']:row['label'] for i, row in label_df.iterrows()}
    
    return label_dict

## Model

In [9]:
class RumorDataset(Dataset):
    def __init__(self, ids_list, tree_list, label_list):
        super(RumorDataset, self).__init__()
        self.ids_list = ids_list
        self.tree_list = tree_list
        self.label_list = label_list

    def __getitem__(self, item):
        return (self.ids_list[item], self.tree_list[item], self.label_list[item])

    def __len__(self):
        return len(self.label_list)

In [49]:
class RumorModel(nn.Module):
    def __init__(self, embed_dim, hidden_dim, target_size):
        super(RumorModel, self).__init__()
        
        self.hidden_dim = hidden_dim
        
        self.gru = nn.GRUCell(embed_dim, hidden_dim, bias=False)
        self.linear_out = nn.Linear(hidden_dim, target_size)
        
    def forward(self, node_list, tree):
        node_out_dict = {}
        for node in node_list:
            node_input = tree.get_node(node).data
            node_hidden = torch.zeros((self.hidden_dim, ))
            childrens = tree.children(node)
            for child in childrens:
                node_hidden += node_out_dict[child.identifier]
                
            node_out_dict[node] = self.gru(node_input.unsqueeze(0), node_hidden.unsqueeze(0)).squeeze()
            
        last_node_hidden = node_out_dict[node_list[-1]]
        output = self.linear_out(last_node_hidden)
        return output

## Train

In [47]:
content_dict = load_tweet_content(os.path.join(args['data_dir'], args['tweet_content_file']))
trees = load_rumor_trees(os.path.join(args['data_dir'], args['tree_dir']), content_dict)
label_dict = load_labels(os.path.join(args['data_dir'], args['label_file']))
ids_list = []
tree_list = []
label_list = []

for root, (tweet_ids, tree) in trees.items():
    ids_list.append(tweet_ids)
    tree_list.append(tree)
    label_list.append(label_dict[root])

In [50]:
splits = list(StratifiedKFold(n_splits=args['n_splits'], shuffle=True, random_state=args['seed']).split(tree_list, label_list))

for idx, (train_idx, val_idx) in enumerate(splits):
    print('Train Fold {}'.format(idx))
    
    train_ids_list = [ids_list[i] for i in train_idx]
    train_tree_list = [tree_list[i] for i in train_idx]
    train_label_list = [label_list[i] for i in train_idx]
    
    valid_ids_list = [ids_list[i] for i in val_idx]
    valid_tree_list = [tree_list[i] for i in val_idx]
    valid_label_list = [label_list[i] for i in val_idx]
    
    model = RumorModel(embed_dim=embed_dim, hidden_dim=args['hidden_dim'], 
                       target_size=args['target_size'])
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=args['learning_rate'])

    for epoch in trange(args['n_epoches'], desc='Epoch'):
        model.train()
        tr_loss = 0.0

        for tweet_ids, tree, label in tqdm_notebook(zip(train_ids_list, train_tree_list, train_label_list)):
            preds = model(tweet_ids, tree)
            loss = criterion(preds.unsqueeze(0), torch.tensor(label).unsqueeze(0))

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            tr_loss += loss.item()

        train_loss = tr_loss / len(label_list)
        print(f"Epoch {epoch}, train loss {train_loss}")

Epoch:   0%|          | 0/70 [00:00<?, ?it/s]

Train Fold 0


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:   1%|▏         | 1/70 [00:17<20:16, 17.63s/it]


Epoch 0, train loss 1.1083291183438218


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:   3%|▎         | 2/70 [00:35<20:03, 17.69s/it]


Epoch 1, train loss 1.1082199005945492


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:   4%|▍         | 3/70 [00:52<19:41, 17.64s/it]


Epoch 2, train loss 1.108195432284583


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:   6%|▌         | 4/70 [01:10<19:24, 17.65s/it]


Epoch 3, train loss 1.10818769317246


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:   7%|▋         | 5/70 [01:28<19:01, 17.57s/it]


Epoch 4, train loss 1.108184933984167


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:   9%|▊         | 6/70 [01:45<18:41, 17.52s/it]


Epoch 5, train loss 1.1081838314188959


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  10%|█         | 7/70 [02:04<18:45, 17.87s/it]


Epoch 6, train loss 1.108183347261869


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  11%|█▏        | 8/70 [02:22<18:38, 18.04s/it]


Epoch 7, train loss 1.1081831255744183


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  13%|█▎        | 9/70 [02:40<18:13, 17.93s/it]


Epoch 8, train loss 1.1081830177069032


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  14%|█▍        | 10/70 [02:58<18:04, 18.08s/it]


Epoch 9, train loss 1.1081829648590602


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  16%|█▌        | 11/70 [03:16<17:47, 18.10s/it]


Epoch 10, train loss 1.1081829390384246


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  17%|█▋        | 12/70 [03:37<18:14, 18.87s/it]


Epoch 11, train loss 1.1081829274553359


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  19%|█▊        | 13/70 [03:59<18:49, 19.81s/it]


Epoch 12, train loss 1.1081829222268582


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  20%|██        | 14/70 [04:22<19:29, 20.88s/it]


Epoch 13, train loss 1.1081829186875811


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  21%|██▏       | 15/70 [04:44<19:18, 21.06s/it]


Epoch 14, train loss 1.1081829148265514


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  23%|██▎       | 16/70 [05:04<18:40, 20.74s/it]


Epoch 15, train loss 1.108182915952685


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  24%|██▍       | 17/70 [05:23<17:48, 20.15s/it]


Epoch 16, train loss 1.1081829169179425


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  26%|██▌       | 18/70 [05:42<17:12, 19.85s/it]


Epoch 17, train loss 1.1081829149874276


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:  27%|██▋       | 19/70 [05:59<16:19, 19.20s/it]


Epoch 18, train loss 1.1081829147461133


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))







KeyboardInterrupt: 