<a href="https://colab.research.google.com/github/verma-saloni/Thesis-Work/blob/main/27_10_22_politifact_dgl_GraphSAGE_text_embs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -qq install jsonlines

In [2]:
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


In [3]:
from pathlib import Path
base_dir = Path("/gdrive/MyDrive/ResearchFND")
assert base_dir.exists()

## Data

In [4]:
import pandas as pd
import ast
import json

In [5]:
dataset_id = 'politifact'
text_embeddings = 'ft' # options: sbert, ft

In [6]:
df = pd.read_csv(base_dir/f'{dataset_id}_agg.csv')
df.head(2)

Unnamed: 0,title,text,tweets,retweets,label,url,tweet_ids,num_retweets,log_num_retweets,num_tweets,log_num_tweets
0,Actress Emma Stone ‘For the first time in his...,,[],"['1020554564334964741', '1020817527046197248',...",fake,,[],2911,7.976595,0,0.0
1,Breaking President Trump makes English the of...,,[],[],fake,,[],0,0.0,0,0.0


In [7]:
with open(base_dir/'t2u.json') as f:
    t2u = json.load(f)

with open(base_dir/'users_info.json') as f:
    users_info = json.load(f)

In [8]:
df['tweets'] = df.tweets.map(ast.literal_eval)

In [9]:
users_tweeted = df.tweets.map(lambda x: [int(e['user_id']) for e in x])
users_retweeted = df.retweets.map(lambda x: [t2u[str(e)] for e in x if (str(e) in t2u)])

In [10]:
len(users_tweeted), sum(users_tweeted.map(len) > 0)

(894, 149)

## GNN

### Data

In [11]:
#%%capture
!pip install dgl wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dgl
  Downloading dgl-0.9.1-cp37-cp37m-manylinux1_x86_64.whl (4.9 MB)
[K     |████████████████████████████████| 4.9 MB 5.0 MB/s 
[?25hCollecting wandb
  Downloading wandb-0.13.4-py2.py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 52.0 MB/s 
Collecting psutil>=5.8.0
  Downloading psutil-5.9.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (291 kB)
[K     |████████████████████████████████| 291 kB 62.9 MB/s 
Collecting setproctitle
  Downloading setproctitle-1.3.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.10.1-py2.py3-none-any.whl (166 kB)
[K     |████████████████████████████████| 166 kB 16.4 MB/s 
[?25hCollecting

In [12]:
%env DGLBACKEND=pytorch

env: DGLBACKEND=pytorch


In [13]:
import os
import json
import jsonlines
import numpy as np
import torch
import dgl

import wandb
import IPython.display as ipd

In [14]:
u2i = {}

follow_src = []
follow_dst = []
with jsonlines.open(base_dir/"followers.jsonl") as reader:
    for line in reader:
        v = line["user_id"]
        if v not in u2i:
            u2i[v] = len(u2i)
        for u in line["followers"]:
            if u not in u2i:
                u2i[u] = len(u2i)
            follow_src.append(u2i[u])
            follow_dst.append(u2i[v])

In [15]:
with jsonlines.open(base_dir/"following.jsonl") as reader:
    for line in reader:
        u = line["user_id"]
        if u not in u2i:
            u2i[u] = len(u2i)
        for v in line["following"]:
            if v not in u2i:
                u2i[v] = len(u2i)
            follow_src.append(u2i[u])
            follow_dst.append(u2i[v])

In [16]:
for u, info in users_info.items():
    u = int(u)
    if u not in u2i:
        u2i[u] = len(u2i)
    for v in info['followers']:
        v = int(v)
        if v not in u2i:
            u2i[v] = len(u2i)
        follow_src.append(u2i[v])
        follow_dst.append(u2i[u])
    for v in info['friends']:
        v = int(v)
        if v not in u2i:
            u2i[v] = len(u2i)
        follow_src.append(u2i[u])
        follow_dst.append(u2i[v])

In [17]:
tweet_src = []
tweet_dst = []

for v, l in users_tweeted.iteritems():
    if not len(l):
        continue
    for u in l:
        u = int(u)
        if u in u2i:
            tweet_src.append(u2i[u])
            tweet_dst.append(v)

In [18]:
for v, l in users_retweeted.iteritems():
    if not len(l):
        continue
    
    for u in l:
        u = int(u)
        if u not in u2i:
            u2i[u] = len(u2i)
        tweet_src.append(u2i[u])
        tweet_dst.append(v)

In [19]:
text_embs = np.load(base_dir/f'{dataset_id}_{text_embeddings}_fulltext_embeddings.npy')
text_embs.shape

(894, 300)

In [20]:
follow_src = torch.tensor(follow_src)
follow_dst = torch.tensor(follow_dst)
tweet_src = torch.tensor(tweet_src)
tweet_dst = torch.tensor(tweet_dst)

graph = dgl.heterograph({
    ('user', 'follow', 'user'): (follow_src, follow_dst),
    ('user', 'followed-by', 'user'): (follow_dst, follow_src),
    ('user', 'tweet', 'article'): (tweet_src, tweet_dst),
    ('article', 'tweeted-by', 'user'): (tweet_dst, tweet_src)})

graph.nodes['user'].data['feat'] = torch.arange(graph.num_nodes('user'))
graph.nodes['article'].data['feat'] = torch.tensor(text_embs)
graph.nodes['article'].data['label'] = torch.tensor((df.label=="real").to_numpy()).long()

In [21]:
graph

Graph(num_nodes={'article': 894, 'user': 639764},
      num_edges={('article', 'tweeted-by', 'user'): 4587, ('user', 'follow', 'user'): 695148, ('user', 'followed-by', 'user'): 695148, ('user', 'tweet', 'article'): 4587},
      metagraph=[('article', 'user', 'tweeted-by'), ('user', 'user', 'follow'), ('user', 'user', 'followed-by'), ('user', 'article', 'tweet')])

In [22]:
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(shuffle=True, random_state=124)

In [23]:
labels = graph.ndata['label']['article']

train_idx, valid_idx = next(skf.split(labels, labels))

In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [25]:
sampler = dgl.dataloading.NeighborSampler([10, 10])
train_loader = dgl.dataloading.DataLoader(
    graph,
    {'article':train_idx},
    sampler,
    device=device,
    batch_size=64,
    shuffle=True,
    drop_last=False,
    num_workers=0
)

In [26]:
eval_sampler = dgl.dataloading.NeighborSampler([-1, -1])
eval_loader = dgl.dataloading.DataLoader(
    graph,
    {'article':valid_idx},
    eval_sampler,
    device=device,
    batch_size=64,
    shuffle=False,
    drop_last=False,
    num_workers=0
)

In [27]:
batch = next(iter(train_loader))



### Model

In [28]:
from collections import defaultdict
import torch.nn as nn
import torch.nn.functional as F

d_emb_dict = defaultdict(lambda: 64)

def flatten_dict(d):
    for k, v in d.items():
        d[k] = v.flatten(1)
    return d

In [29]:
class NodeEmbedding(nn.Module):

    def __init__(self, n_nodes:dict, d_in:dict, d_emb:dict, proj_nodes=None, embed_nodes=None):
        super().__init__()
        self.proj_nodes = proj_nodes if proj_nodes is not None else list(d_in.keys())
        self.embed_nodes = embed_nodes if embed_nodes is not None else list(n_nodes.keys())
        self.emb = nn.ModuleDict({k:nn.Embedding(n_nodes[k], d_emb) for k in self.embed_nodes})
        self.proj = nn.ModuleDict({k:nn.Linear(d_in[k], d_emb, bias=False) for k in self.proj_nodes})
        self.init()

    def forward(self, nx):
        out = {}
        for k, m  in self.emb.items():
            out[k] = m(nx[k])
        for k, m  in self.proj.items():
            out[k] = m(nx[k])
        return out

    def init(self):
        for _, m in self.emb.items():
            torch.nn.init.xavier_uniform_(m.weight)
        for _, m in self.proj.items():
            torch.nn.init.xavier_uniform_(m.weight)

In [30]:
class Residual(nn.Module):

    def __init__(self, conv):
        super().__init__()
        self.conv = conv

    def forward(self, graph, x):
        h = self.conv(graph, x)
        res = x[1]
        return h + res

In [31]:
in_proj = NodeEmbedding({k:graph.num_nodes(k) for k in ["user"]}, {"article":text_embs.shape[1]}, 64)
conv = dgl.nn.HeteroGraphConv({rel:dgl.nn.SAGEConv(64, 64, 'pool') for rel in graph.etypes})

In [32]:
blocks = batch[-1]
block = blocks[0]
x = block.ndata['feat']

with torch.no_grad():
    h = in_proj(x)
    res = conv(block, h)

In [33]:
res['article'].shape, res['user'].shape

(torch.Size([64, 64]), torch.Size([71, 64]))

In [34]:
class Encoder(torch.nn.Module):

    def __init__(self, d_in, d_h, etypes, dropout=0.0, agg='pool'):
        super().__init__()
        self.layers = nn.ModuleList([
            dgl.nn.HeteroGraphConv({
                rel : dgl.nn.SAGEConv(d_in, d_h, agg) for rel in etypes
            }),
            dgl.nn.HeteroGraphConv({
                rel : dgl.nn.SAGEConv(d_h, d_h, agg) for rel in etypes
            })
        ])

    def forward(self, blocks, x):
        
        for layer, block in zip(self.layers, blocks):
            x = layer(block, x)
        return x

In [35]:
class GNN(nn.Module):

    def __init__(self, g, d_h:int, tgt_ntype:str, emb_nodes:list=['user'], proj_nodes:list=['article']):
        super().__init__()
        self.tgt_ntype = tgt_ntype
        self.in_proj = NodeEmbedding(
            {k:g.num_nodes(k) for k in emb_nodes}, 
            {k:graph.ndata['feat'][k].shape[1] for k in proj_nodes},
            d_h
        )
        self.encoder = Encoder(d_h, d_h, g.etypes)
        self.head = nn.Linear(d_h, 2)

    def forward(self, blocks, x):
        h = self.in_proj(x)
        h = self.encoder(blocks, h)
        return self.head(h[self.tgt_ntype])

    @torch.no_grad()
    def get_embeddings(self, graph, x):
        h = self.emb(x)
        h = self.encoder(graph, h)
        return h[self.tgt_ntype]

In [36]:
model = GNN(graph, 128, 'article')

with torch.no_grad():
    logits = model(blocks, x)

In [37]:
logits.shape

torch.Size([64, 2])

In [38]:
def accuracy(logits, labels):
    return (logits.argmax(-1) == labels).float().mean()

In [39]:
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(shuffle=True, random_state=124)

In [40]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

metrics = [accuracy_score, f1_score, precision_score, recall_score]
def get_name(score_func):
    return score_func.__name__.split("_")[0]

In [41]:
class AverageMeter:

    def __init__(self, store_vals=False, store_avgs=False):
        self.store_vals = store_vals
        self.store_avgs = store_avgs
        if store_vals: self.values = []
        if store_avgs: self.avgs = []
        self.tot, self.n = 0, 0

    def update(self, v, n=1):
        if self.store_vals: self.values.append(v)
        self.n += n
        self.tot += v*n

    @property
    def avg(self):
        if self.n == 0:
            return
        return self.tot / self.n

    def reset(self):
        if self.store_avgs and self.avg: self.avgs.append(self.avg)
        self.tot, self.n = 0, 0

In [42]:
def train(fold, train_idx, valid_idx, params):

    model = GNN(graph, 128, 'article')
    opt = torch.optim.Adam(model.parameters(), params['lr'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=10, verbose=True)

    train_loss = AverageMeter(store_avgs=True)
    train_acc = AverageMeter(store_avgs=True)
    valid_loss = AverageMeter(store_avgs=True)
    valid_acc = AverageMeter(store_avgs=True)

    best_acc = 0
    for epoch in range(params['n_epochs']):
        model.train()
        for batch in train_loader:
            blocks = batch[-1]
            x = blocks[0].ndata['feat']
            logits = model(blocks, x)
            
            labels = blocks[-1].dstdata['label']['article']
            loss = F.cross_entropy(logits, labels)
            acc = accuracy(logits, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            train_loss.update(loss.item(), len(labels))
            train_acc.update(acc, len(labels))

        model.eval()
        for i, batch in enumerate(eval_loader):
            blocks = batch[-1]
            x = blocks[0].ndata['feat']
            with torch.no_grad():
                logits = model(blocks, x)

                labels = blocks[-1].dstdata['label']['article']
                val_loss = F.cross_entropy(logits, labels)
                val_acc = accuracy(logits, labels)

                valid_loss.update(val_loss.item(), len(labels))
                valid_acc.update(val_acc, len(labels))
        
        scheduler.step(valid_loss.avg)
        wandb.log({'train_loss':loss.item(), 'train_acc':acc, 'valid_loss':val_loss.item(), 'valid_acc':val_acc}, step=epoch)
        print(f"{epoch+1:>3}: Train loss {train_loss.avg:.4f}, acc {train_acc.avg:.4f}%; validation loss {valid_loss.avg:.4f}, acc {valid_acc.avg:.4f}%")
        
        if valid_acc.avg >= best_acc:
            best_acc = valid_acc.avg
            torch.save(model.state_dict(), f'models/model-{fold}.pt')
        
        train_loss.reset()
        train_acc.reset()
        valid_loss.reset()
        valid_acc.reset()

        

    # load best model and evaluate
    model.load_state_dict(torch.load(f'models/model-{fold}.pt'))
    model.eval()
    preds = []
    targs = []
    for i, batch in enumerate(eval_loader):
        blocks = batch[-1]
        x = blocks[0].ndata['feat']
        with torch.no_grad():
            logits = model(blocks, x)

            labels = blocks[-1].dstdata['label']['article']
            
        preds.append(logits.argmax(-1).cpu().numpy())
        targs.append(labels.cpu().numpy())
    preds = np.concatenate(preds)
    targs = np.concatenate(targs)
    eval_results = {get_name(f):f(y_pred=preds, y_true=targs) for f in metrics}
    print("Final evaluation results:")
    for k,v in eval_results.items():
        print(f"{k:<16}{v:.4f}")
    
    wandb.log(eval_results)
    wandb.log({"conf_mat" : wandb.plot.confusion_matrix(probs=None,
                            y_true=targs, preds=preds,
                            class_names=["Fake", "Real"])})

    return {
        'train_loss':train_loss,
        'train_acc':train_acc,
        'valid_loss':valid_loss,
        'valid_acc':valid_acc
    }

In [43]:
params = {
    "n_epochs":100,
    'bs': 16,
    'lr':1e-2,
    "seed":124,
}

labels = graph.ndata['label']['article']

In [45]:
if not os.path.exists('models'):
    os.mkdir('models')

WANDB_ENTITY = 'saloniteam'
WANDB_PROJECT = 'fnd'
GROUP = f"{dataset_id}-{text_embeddings}-fulltext-graphsage"
for fold_id, (train_idx, valid_idx) in enumerate(skf.split(labels, labels)):
    ipd.clear_output()
    with wandb.init(entity=WANDB_ENTITY, project=WANDB_PROJECT, group=GROUP, name=f"{GROUP}-fold-{fold_id}") as run:
        log = train(fold_id, train_idx, valid_idx, params)
    break

[34m[1mwandb[0m: Currently logged in as: [33msaloni[0m ([33msaloniteam[0m). Use [1m`wandb login --relogin`[0m to force relogin




  1: Train loss 0.5628, acc 0.6783%; validation loss 0.4586, acc 0.7821%
  2: Train loss 0.4407, acc 0.8238%; validation loss 0.4032, acc 0.8771%


0,1
train_acc,▁█
train_loss,█▁
valid_acc,▁█
valid_loss,▁█

0,1
train_acc,1.0
train_loss,0.36462
valid_acc,0.86275
valid_loss,0.38857


RuntimeError: ignored

In [None]:
model = GNN(graph, 128, "article")

model.load_state_dict(torch.load(f'models/model-{fold_id}.pt'))