In [1]:
import os
import shutil
import pandas as pd
import math
import time
import random

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset
from torch_sparse import SparseTensor, spmm

from metrics import *
from GNN import *
from transformer import SentenceTransformer
from graph import *
from mind_utils import *

# --- Dataset locations -------------------------------------------------------
TRAIN_S = './data/mind/training_small'
TRAIN_L = './data/mind/training_large'
# NOTE(review): the three paths below are absolute ('/data/...') while the
# training paths are relative ('./data/...'), and TEST repeats 'mind/' twice.
# This looks like a typo, but the on-disk layout is not visible from this
# notebook, so the values are left untouched -- confirm before relying on them.
VAL_S = '/data/mind/validation_small'
VAL_L = '/data/mind/validation_large'
TEST = '/data/mind/mind/test_large'

# --- Output locations --------------------------------------------------------
log_dir = './runs/tb/'        # TensorBoard event files, one sub-directory per experiment
model_dir = './runs/models/'  # model checkpoints

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Reproducibility ---------------------------------------------------------
# Seed every RNG the notebook has imported. NumPy was previously left unseeded,
# so any NumPy-based randomness differed between runs.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [2]:
def train(exp_name, train_loader, model, epochs, batch_size, test=False):
    """Train ``model`` on impression sessions and log running metrics to TensorBoard.

    Every batch yielded by ``train_loader`` must unpack to
    ``(subgraph, news_x, user_batch, pos_batch, neg_batch)``. The model is run
    once per batch as ``model(news_x, subgraph)`` and must return a mapping with
    ``'users'`` and ``'news'`` embeddings. ``user_batch[i]``, ``pos_batch[i]``
    and ``neg_batch[i]`` are used to index those embeddings, so they have to be
    expressed in the *subgraph-local* numbering of the batch.

    Per session a class-balanced binary cross-entropy is computed (positives
    and negatives each contribute half of the total weight). The session losses
    of a batch are summed and back-propagated once, followed by one Adam step.

    Args:
        exp_name: Experiment name; used for the TensorBoard sub-directory of
            ``log_dir`` and as the checkpoint file prefix inside ``model_dir``.
        train_loader: Iterable of batches as described above.
        model: The network to optimise (updated in place).
        epochs: Number of passes over ``train_loader``.
        batch_size: Kept for backward compatibility. The number of sessions is
            now read from each batch, so a short final batch is handled.
        test: If true, stop each epoch after batch index 200 and do not write
            checkpoints.
    """
    optimizer = torch.optim.Adam(model.parameters())

    # Start every run from an empty TensorBoard directory.
    exp_dir = log_dir + exp_name
    if os.path.exists(exp_dir) and os.path.isdir(exp_dir):
        shutil.rmtree(exp_dir)
    writer = SummaryWriter(log_dir=exp_dir)

    if not test:
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(model_dir + exp_name) or '.', exist_ok=True)

    # evaluate() switches the model to eval mode; make sure we train in train mode.
    model.train()

    sessions_seen = 0  # monotonically increasing TensorBoard step
    try:
        for epoch in range(epochs):
            for b_i, batch in enumerate(train_loader):
                subgraph, news_x, user_batch, pos_batch, neg_batch = batch

                optimizer.zero_grad()
                subgraph_feats = model(news_x, subgraph)

                batch_loss = None
                losses, pos_accs, neg_accs = [], [], []

                # Iterate over the sessions actually present in this batch
                # (the last batch of an epoch may be smaller than batch_size).
                for session in range(len(pos_batch)):
                    user = user_batch[session].reshape(-1)
                    pos = pos_batch[session].reshape(-1)
                    neg = neg_batch[session].reshape(-1)
                    n_pos, n_neg = pos.shape[0], neg.shape[0]
                    if n_pos == 0 or n_neg == 0:
                        # The 1/n class weights below are undefined for an empty class.
                        continue

                    user_emb = subgraph_feats['users'][user]
                    art_emb = subgraph_feats['news'][torch.cat((pos, neg))]

                    # Positives occupy the first n_pos entries of `scores`.
                    scores = angular_distance(user_emb, art_emb)

                    targets = torch.zeros(n_pos + n_neg, device=scores.device)
                    weights = torch.full(targets.shape, 1 / n_neg, device=scores.device)
                    targets[:n_pos] = 1
                    weights[:n_pos] = 1 / n_pos

                    # weights / 2: each class contributes 0.5 of the session loss.
                    loss = F.binary_cross_entropy(scores, targets, weights / 2, reduction='sum')
                    batch_loss = loss if batch_loss is None else batch_loss + loss

                    pos_accs.extend((scores[:n_pos].round() == 1).cpu().numpy())
                    neg_accs.extend((scores[n_pos:].round() == 0).cpu().numpy())
                    losses.append(loss.item())

                sessions_seen += len(pos_batch)

                if batch_loss is not None:
                    # One backward pass over the summed loss yields the same
                    # gradients as accumulating per-session backward calls.
                    batch_loss.backward()
                    optimizer.step()

                    writer.add_scalar('pos_acc', np.mean(pos_accs), sessions_seen)
                    writer.add_scalar('neg_acc', np.mean(neg_accs), sessions_seen)
                    writer.add_scalar('loss', np.mean(losses), sessions_seen)

                if test and b_i >= 200:
                    break

            if not test:
                torch.save(model.state_dict(), f'{model_dir+exp_name}_epoch_{epoch}')
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

In [3]:
def evaluate(exp_name, model, writer, iteration):
    """Score the validation sessions and log ranking metrics to TensorBoard.

    Relies on notebook-level globals defined outside this cell: ``val_data``,
    ``val_sessions``, ``empty_nodes``, ``atr_nodes``, ``log_dir``,
    ``model_dir`` and ``device``.

    Args:
        exp_name: Experiment name. Used to locate the checkpoint when ``model``
            is ``None`` and to name the ``<exp_name>_eval`` log directory when
            ``writer`` is ``None``.
        model: Model to evaluate. When ``None`` a ``GCN`` is built and its
            weights are loaded from ``model_dir + exp_name``.
        writer: ``SummaryWriter`` to log into. When ``None`` a fresh writer is
            created on an emptied ``<exp_name>_eval`` directory and closed
            before returning. A caller-supplied writer is flushed, never closed.
        iteration: Global step under which the scalars are logged.

    Logged scalars: ``auc``, ``mrr``, ``ndcg5``, ``ndcg10``, ``val_pos_acc``
    and ``val_neg_acc``, each averaged over the evaluated sessions.
    """
    # Only create (and wipe the directory for) a writer when the caller did
    # not pass one; otherwise repeated calls would erase earlier eval points.
    owns_writer = writer is None
    if owns_writer:
        exp_dir = log_dir + exp_name + '_eval'
        if os.path.exists(exp_dir) and os.path.isdir(exp_dir):
            shutil.rmtree(exp_dir)
        writer = SummaryWriter(log_dir=exp_dir)

    if model is None:
        model = GCN(val_data.features['news'].shape[1], empty_nodes, atr_nodes).to(device)
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
        # strict=False silently ignores missing/unexpected keys (kept from the original).
        model.load_state_dict(torch.load(model_dir + exp_name, map_location=device), strict=False)

    was_training = model.training
    model.eval()

    aucs, mrrs, ndcg5s, ndcg10s = [], [], [], []
    pos_accs, neg_accs = [], []

    try:
        with torch.no_grad():
            # One forward pass over the full validation graph; sessions only index into it.
            feats = model(val_data, val_data.edge_index)

            for batch in val_sessions:
                user, pos, neg = batch
                pos, neg = pos.reshape(-1), neg.reshape(-1)
                n_pos, n_neg = pos.shape[0], neg.shape[0]
                if n_pos == 0 or n_neg == 0:
                    # Ranking metrics (AUC in particular) need both classes.
                    continue

                user_emb = feats['users'][user]
                art_emb = feats['news'][torch.cat((pos, neg))]

                # Positives occupy the first n_pos entries of `scores`.
                scores = angular_distance(user_emb, art_emb)

                targets = torch.zeros(scores.shape, device=scores.device)
                targets[:n_pos] = 1

                pos_accs.append((torch.round(scores[:n_pos]) == 1).float().mean().item())
                neg_accs.append((torch.round(scores[n_pos:]) == 0).float().mean().item())

                # Rank of every item (1 = highest score). A single argsort only
                # gives the sort permutation; the argsort of that permutation
                # gives each item's position, which keeps y_score aligned with y_true.
                order = torch.argsort(scores, descending=True)
                ranks = torch.argsort(order) + 1
                y_score = (1.0 / ranks.float()).cpu().numpy()
                y_true = targets.cpu().numpy()

                aucs.append(roc_auc_score(y_true, y_score))
                mrrs.append(mrr_score(y_true, y_score))
                ndcg5s.append(ndcg_score(y_true, y_score, 5))
                ndcg10s.append(ndcg_score(y_true, y_score, 10))

        writer.add_scalar('auc', np.mean(aucs), iteration)
        writer.add_scalar('mrr', np.mean(mrrs), iteration)
        writer.add_scalar('ndcg5', np.mean(ndcg5s), iteration)
        writer.add_scalar('ndcg10', np.mean(ndcg10s), iteration)
        writer.add_scalar('val_pos_acc', np.mean(pos_accs), iteration)
        writer.add_scalar('val_neg_acc', np.mean(neg_accs), iteration)
    finally:
        # Give the model back in the mode it arrived in so an enclosing
        # training loop is not silently left in eval mode.
        model.train(was_training)
        if owns_writer:
            writer.close()
        else:
            writer.flush()

    return


In [4]:
from torch_geometric.utils.mask import index_to_mask
from torch_geometric.utils import dropout_edge

class Impressions(Dataset):
    """Impression sessions plus a neighbourhood sampler over the news graph.

    Each item is one session ``(user, pos, neg)`` expressed in global node
    indices. ``collate`` turns a list of sessions into a sampled subgraph and
    re-expresses the session indices in that subgraph's local numbering.

    NOTE(review): assumes ``data.mapping[...]`` maps raw ids to integer node
    indices and ``data.edge_index[type]`` is a ``2 x E`` tensor of
    ``(news, type-node)`` pairs -- confirm against ``graph`` / ``mind_utils``.
    """

    def __init__(self, data, set_dir, device=device):
        """Read the behaviours in ``set_dir`` and translate ids to node indices.

        Rows without a ``history`` value are dropped.
        """
        self.device = device
        df = get_behaviors(set_dir).dropna(subset=['history'])

        self.user = [data.mapping['users'][user_id] for user_id in df.user_id]
        self.pos = [[data.mapping['news'][article] for article in session] for session in df.pos]
        self.neg = [[data.mapping['news'][article] for article in session] for session in df.neg]

        self.edge_index = data.edge_index
        self.atr_nodes = data.atr_nodes
        self.n_nodes = data.n_nodes
        self.news_x = data.features['news']

    def __len__(self):
        """Number of sessions."""
        return len(self.pos)

    def __getitem__(self, idx):
        """Return session ``idx`` as ``(user, pos, neg)`` long tensors on ``self.device``."""
        user = torch.tensor(self.user[idx], dtype=torch.long, device=self.device)
        pos = torch.tensor(self.pos[idx], dtype=torch.long, device=self.device)
        neg = torch.tensor(self.neg[idx], dtype=torch.long, device=self.device)
        return (user, pos, neg)

    def bipartite_subgraph(self, subset, edge_index, size, targets):
        '''
        Relabels the sampled edge_index so that node indices range over
        [0...n_sampled_nodes) instead of [0...n_all_nodes).

        Input:
        subset: tuple of (news, atr) global node indices to be included
        edge_index: 2 x E edges from sampling, in global indices
        size: tuple (n_news, n_atr), node counts of the non-sampled graph
        targets: tuple (news_targets, atr_targets) of global indices to translate

        Returns:
        (edge_index, new_news_targets, new_atr_targets). The translated targets
        keep the order, duplicates and shape of the inputs, so callers can split
        them back per session. Targets must be contained in the matching subset;
        a target outside it maps to local index 0.
        '''
        src_subset, dst_subset = subset
        t_src, t_dst = targets

        src_mask = index_to_mask(src_subset, size=size[0])
        dst_mask = index_to_mask(dst_subset, size=size[1])

        # Lookup tables global index -> local index (0 for nodes outside the subset).
        node_idx_i = edge_index.new_zeros(src_mask.size(0))
        node_idx_j = edge_index.new_zeros(dst_mask.size(0))

        node_idx_i[src_mask] = torch.arange(int(src_mask.sum()),
                                            device=node_idx_i.device)
        node_idx_j[dst_mask] = torch.arange(int(dst_mask.sum()),
                                            device=node_idx_j.device)

        edge_index = torch.stack([
            node_idx_i[edge_index[0]],
            node_idx_j[edge_index[1]],
        ], dim=0)

        # Index the lookup tables directly. Going through a boolean mask would
        # return the targets sorted and de-duplicated, losing the session order.
        new_articles = node_idx_i[t_src]
        new_users = node_idx_j[t_dst]

        return edge_index, new_articles, new_users

    def sampler(self, users, articles):
        """Sample a 3-hop subgraph around ``users`` and ``articles``.

        Args:
            users: Global user indices of the sessions in the batch.
            articles: Global news indices of all candidates in the batch.

        Returns:
            ``(edge_index, news_x, articles, users)``: a dict with one
            ``SparseTensor`` of shape ``(n_news_sampled, n_type_sampled)`` per
            edge type, the feature rows of the sampled news nodes, and the two
            inputs translated to subgraph-local indices in their original order.
            All edge types share one news numbering, which matches the rows of
            ``news_x``.
        """
        new_index = {node_type: [] for node_type in self.atr_nodes}
        to_sample = {node_type: [] for node_type in self.atr_nodes}

        # Hop 1: grab article history of users for which to predict
        index = self.edge_index['users']
        sample = index[:, torch.isin(index[1], users)]
        new_index['users'] = [sample]
        to_sample['news'] = torch.cat((sample[0], articles))

        # Hop 2: from user history and session positive/negative articles grab
        # the connected users/attributes
        for atr in self.atr_nodes:
            index = self.edge_index[atr]
            sample = index[:, torch.isin(index[0], to_sample['news'])]
            new_index[atr] += [sample]
            to_sample[atr] = sample[1]

        # Hop 3: grab articles from these users/attributes, keeping a random 30 %
        for atr in self.atr_nodes:
            index = self.edge_index[atr]
            sample = index[:, torch.isin(index[1], to_sample[atr])]
            sample, _ = dropout_edge(sample, p=0.7)
            new_index[atr] += [sample]

        # Merge the hops per edge type and drop duplicate edges.
        merged = {atr: torch.cat(parts, dim=1).unique(dim=1)
                  for atr, parts in new_index.items()}

        # One shared set of news nodes for every edge type, including the
        # candidate articles even if they have no sampled edge, so that rows of
        # every adjacency matrix line up with the rows of news_x.
        news_nodes = torch.cat(
            [articles.reshape(-1)] + [index[0] for index in merged.values()]
        ).unique()
        news_x = self.news_x[news_nodes]

        no_targets = users.new_empty(0)
        edge_index = {}
        new_articles, new_users = None, None
        for atr, index in merged.items():
            atr_nodes = index[1].unique()
            if atr == 'users':
                # Keep the target users addressable even without sampled edges.
                atr_nodes = torch.cat((atr_nodes, users.reshape(-1))).unique()
                dst_targets = users
            else:
                # User ids are meaningless for other node types.
                dst_targets = no_targets

            index, new_articles, relabelled = self.bipartite_subgraph(
                (news_nodes, atr_nodes),
                index,
                (self.n_nodes['news'], self.n_nodes[atr]),
                (articles, dst_targets))
            if atr == 'users':
                new_users = relabelled

            edge_index[atr] = SparseTensor.from_edge_index(
                index, sparse_sizes=(news_nodes.shape[0], atr_nodes.shape[0]))

        return edge_index, news_x, new_articles, new_users

    def collate(self, batch):
        """Collate sessions into ``(subgraph, news_x, users, pos, neg)``.

        ``users`` is a tensor with one subgraph-local user index per session;
        ``pos`` and ``neg`` are tuples holding, per session, the subgraph-local
        indices of its positive and negative articles.
        """
        users, pos, neg = list(zip(*batch))
        subgraph, news_x, articles, users = self.sampler(torch.stack(users), torch.cat((*pos, *neg)))

        # `articles` is the relabelled concatenation (*pos, *neg); cut it back
        # into the per-session pieces in the same order.
        sizes = [p.shape[0] for p in pos] + [n.shape[0] for n in neg]
        pieces = torch.split(articles, sizes)
        pos, neg = pieces[:len(batch)], pieces[len(batch):]

        return subgraph, news_x, users, pos, neg