In [1]:
import os
import shutil
import pandas as pd
import math
import time
import random

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset
from torch_sparse import SparseTensor, spmm

from metrics import *
from transformer import SentenceTransformer
from mind_utils import *

# Dataset directories (small / large training and validation splits, plus test).
# NOTE(review): TRAIN_* are relative ('./data/...') while VAL_S, VAL_L and TEST
# are absolute ('/data/...'), and TEST repeats the 'mind' component. This looks
# like a typo - confirm the intended locations before relying on these paths.
TRAIN_S = './data/mind/training_small'
TRAIN_L = './data/mind/training_large'
VAL_S = '/data/mind/validation_small'
VAL_L = '/data/mind/validation_large'
TEST = '/data/mind/mind/test'

# Output roots: TensorBoard runs and saved model state dicts. Both end with '/'
# because later cells build paths by plain string concatenation.
log_dir = './runs/tb/'
model_dir = './runs/models/'

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch and the stdlib RNG. numpy's RNG is not seeded here.
torch.manual_seed(42)
random.seed(42)

In [2]:
from collections import defaultdict
from mind_utils import *
import torch

class DataGraph(object):
    """Bipartite news/user graph built from one dataset directory.

    Attributes set by ``__init__``:
        device:     device the tensors are moved to.
        mapping:    ``{'news': {news_id: idx}, 'users': {user_id: idx}}``.
        news_x:     standardized text features returned by ``get_news_x``.
        edge_index: ``(2, n_edges)`` long tensor; row 0 holds news indices,
                    row 1 holds user indices.
        n_nodes:    ``{'news': n_news, 'users': n_users}``.
    """

    def __init__(self, set_dir, text_enc, device):
        """Load news and user histories from ``set_dir`` and build the graph.

        Args:
            set_dir:  directory containing ``news.tsv`` (and whatever
                      ``get_behaviors`` reads).
            text_enc: object exposing ``encode(texts)``; its result is used as
                      a 2-D tensor (``.std(dim=0)``, column masking).
            device:   torch device for ``news_x`` and ``edge_index``.
        """
        self.device = device

        news_df = self.get_news(set_dir)
        hist_df = self.get_user_history(set_dir)

        self.mapping = {}
        self.mapping['news'] = self.map_col(news_df.news_id)
        self.mapping['users'] = self.map_col(hist_df.user_id)

        self.news_x = self.get_news_x(text_enc, news_df).to(device)
        # One edge per (clicked article, user) pair of every history row.
        self.edge_index = self.load_edge_index(hist_df.explode('history'), 'history', 'user_id', 'users')

        self.n_nodes = {key: len(self.mapping[key]) for key in self.mapping.keys()}

    def map_col(self, column):
        """Map each distinct value of ``column`` to a dense index in first-seen order."""
        return {index: i for i, index in enumerate(column.unique())}

    def load_edge_index(self, df, news_col, atr_col, atr_map):
        """Translate two id columns of ``df`` into a ``(2, n_edges)`` index tensor.

        Args:
            df:       frame with one row per edge.
            news_col: column holding news ids (looked up in ``mapping['news']``).
            atr_col:  column holding the ids of the other node type.
            atr_map:  key of ``self.mapping`` used to look up ``atr_col`` values.

        Raises:
            KeyError: if an id is missing from the corresponding mapping.
        """
        news_edge = [self.mapping['news'][index] for index in df[news_col]]
        atr_edge = [self.mapping[atr_map][index] for index in df[atr_col]]

        # Force an integer dtype: with no edges torch would otherwise infer
        # float32 from the empty lists, which cannot be used for indexing.
        return torch.tensor([news_edge, atr_edge], dtype=torch.long).to(self.device)

    def get_news_x(self, text_enc, news_df):
        """Encode ``title + ' ' + abstract`` and standardize every feature column.

        Columns whose standard deviation is exactly zero are dropped before
        standardization so the division is well defined.
        """
        news_text = news_df['title'] + ' ' + news_df['abstract'].fillna('')
        news_x = text_enc.encode(news_text.values)

        # Compute the per-column std once and reuse it for both the mask and
        # the scaling (the kept columns have the same std after masking).
        col_std = news_x.std(dim=0)
        keep = col_std != 0.
        news_x = news_x[:, keep]

        return (news_x - news_x.mean(dim=0)) / col_std[keep]

    def get_user_history(self, set_dir):
        """Return distinct ``(user_id, history)`` rows with ``history`` split into a list.

        Rows without a history are dropped. ``get_behaviors`` comes from
        ``mind_utils``; it is assumed to return a frame with ``user_id`` and a
        space-separated ``history`` column - TODO confirm.
        """
        df = get_behaviors(set_dir)
        user_hist = df[['user_id', 'history']].drop_duplicates().dropna(subset=['history'])
        user_hist['history'] = user_hist['history'].str.split(' ')
        return user_hist

    def get_news(self, set_dir):
        """Read ``news.tsv`` from ``set_dir`` and parse ``title_entities`` as JSON."""
        # Imported here because the notebook never imports these modules
        # explicitly; it only worked if a star-import happened to leak them.
        import csv
        import json

        news_path = os.path.join(set_dir, 'news.tsv')
        news = pd.read_table(news_path,
                             quoting=csv.QUOTE_NONE,
                             header=None,
                             names=['news_id', 'category', 'subcategory', 'title', 'abstract', 'url',
                                    'title_entities', 'abstract_entities'])
        # NOTE(review): json.loads raises on missing (NaN) cells - confirm the
        # column is always populated in the files this is used with.
        news['title_entities'] = news['title_entities'].apply(json.loads)
        return news


In [3]:
def train(exp_name, train_loader, model, epochs, batch_size, test=False, log_every=1000):
    """Train ``model`` on impression batches and log running metrics to TensorBoard.

    Each batch is ``(subgraph, news_x, users, articles, n_pos, n_neg)`` as
    produced by ``Impressions.collate``. For every session in the batch the
    user embedding is scored against its candidate articles and a
    class-balanced binary cross-entropy is accumulated; one optimizer step is
    taken per batch.

    Args:
        exp_name:     run name; used for the TensorBoard directory
                      (``log_dir + exp_name``, wiped if it already exists) and
                      for checkpoint file names.
        train_loader: iterable of batches in the format above.
        model:        module called as ``model(news_x, subgraph)`` returning a
                      dict with ``'users'`` and ``'news'`` embeddings.
        epochs:       number of passes over ``train_loader``.
        batch_size:   nominal batch size. Kept for backward compatibility; the
                      real number of sessions is taken from each batch so a
                      smaller final batch is handled.
        test:         if true, stop each epoch after batch index 200 and skip
                      checkpointing.
        log_every:    write averaged metrics every this many batches.

    Note:
        ``angular_distance`` must return values in [0, 1] because they are fed
        to ``F.binary_cross_entropy``.
    """
    optimizer = torch.optim.Adam(model.parameters())

    exp_dir = log_dir + exp_name
    if os.path.exists(exp_dir) and os.path.isdir(exp_dir):
        shutil.rmtree(exp_dir)
    writer = SummaryWriter(log_dir=exp_dir)

    # Make sure training-mode layers are active even if the model was
    # previously switched to eval mode.
    model.train()

    samples_seen = 0
    losses, pos_accs, neg_accs = [], [], []
    try:
        for epoch in range(epochs):
            for b_i, batch in enumerate(train_loader):
                subgraph, news_x, b_users, b_articles, b_n_pos, b_n_neg = batch
                subgraph_feats = model(news_x, subgraph)

                # The last batch of an epoch may hold fewer sessions than
                # batch_size, so size the loop from the batch itself.
                n_sessions = len(b_n_pos)

                optimizer.zero_grad()
                batch_loss = 0.
                for i in range(n_sessions):
                    user, articles = b_users[i], b_articles[i]
                    n_pos, n_neg = b_n_pos[i], b_n_neg[i]

                    user_emb = subgraph_feats['users'][user]
                    art_emb = subgraph_feats['news'][articles]

                    scores = angular_distance(user_emb, art_emb)

                    # Positives come first in `articles`; weight both classes
                    # so each contributes half of the session loss.
                    targets = torch.zeros(n_pos + n_neg, device=scores.device)
                    weights = torch.full(targets.shape, 1 / max(n_neg, 1), device=scores.device)
                    targets[:n_pos] = 1
                    weights[:n_pos] = 1 / max(n_pos, 1)

                    loss = F.binary_cross_entropy(scores, targets, (weights / 2), reduction='sum')
                    batch_loss = batch_loss + loss

                    # Comparison yields a bool tensor; cast before averaging
                    # because mean() is not defined for bool.
                    if n_pos:
                        pos_accs.append((scores[:n_pos].round() == 1).float().mean().item())
                    if n_neg:
                        neg_accs.append((scores[n_pos:].round() == 0).float().mean().item())
                    losses.append(loss.item())

                # One backward pass over the summed session losses yields the
                # same accumulated gradients as per-session backward calls
                # with retain_graph=True.
                batch_loss.backward()
                optimizer.step()
                samples_seen += n_sessions

                if (b_i + 1) % log_every == 0 and losses:
                    writer.add_scalar('pos_acc', np.mean(pos_accs), samples_seen)
                    writer.add_scalar('neg_acc', np.mean(neg_accs), samples_seen)
                    writer.add_scalar('loss', np.mean(losses), samples_seen)
                    losses, pos_accs, neg_accs = [], [], []

                if b_i == 200 and test:
                    break

            if not test:
                os.makedirs(model_dir, exist_ok=True)
                torch.save(model.state_dict(), f'{model_dir+exp_name}_epoch_{epoch}')
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

In [4]:
def evaluate(exp_name, model, writer, iteration):
    """Score every validation session and log ranking metrics to TensorBoard.

    Relies on the notebook globals ``val_data`` and ``val_sessions``.
    NOTE(review): ``val_data`` is assumed to be a ``DataGraph`` (attributes
    ``news_x``, ``edge_index``, ``n_nodes``) and ``val_sessions`` to yield
    ``(user, pos, neg)`` index tensors - neither is defined in the visible
    cells, confirm before use.

    Args:
        exp_name:  run name; used for the default log directory
                   (``log_dir + exp_name + '_eval'``) and, when ``model`` is
                   None, for the checkpoint path ``model_dir + exp_name``.
        model:     trained module, or None to load one from disk.
        writer:    ``SummaryWriter`` to log into. When None a fresh writer is
                   created (wiping any previous ``_eval`` directory) and closed
                   before returning.
        iteration: global step attached to every logged scalar.
    """
    # Only create (and later close) a writer when the caller did not pass one;
    # otherwise every call would wipe the previous evaluation logs.
    owns_writer = writer is None
    if owns_writer:
        exp_dir = log_dir + exp_name + '_eval'
        if os.path.exists(exp_dir) and os.path.isdir(exp_dir):
            shutil.rmtree(exp_dir)
        writer = SummaryWriter(log_dir=exp_dir)

    try:
        if model is None:
            # NOTE(review): GCN's layer classes are redefined with different
            # signatures in a later cell, and the driver cell trains GCN2 -
            # confirm which architecture the checkpoint belongs to.
            model = GCN(val_data.news_x.shape[1]).to(device)
            model.load_state_dict(torch.load(model_dir + exp_name), strict=False)

        was_training = model.training
        model.eval()

        aucs, mrrs, ndcg5s, ndcg10s = [], [], [], []
        pos_accs, neg_accs = [], []

        with torch.no_grad():
            # The models take node features plus a SparseTensor adjacency
            # (news x users), not the DataGraph object and a dense edge list.
            adj = SparseTensor.from_edge_index(
                val_data.edge_index,
                sparse_sizes=(val_data.n_nodes['news'], val_data.n_nodes['users']))
            feats = model(val_data.news_x, adj)

            for batch in val_sessions:
                user, pos, neg = batch
                pos, neg = pos.reshape(-1), neg.reshape(-1)
                n_pos = pos.shape[0]

                user_emb = feats['users'][user]
                art_emb = feats['news'][torch.cat((pos, neg))]

                scores = angular_distance(user_emb, art_emb)

                # Positives are concatenated first, so they occupy [:n_pos].
                targets = torch.zeros(scores.shape, device=scores.device)
                targets[:n_pos] = 1

                pos_accs.append((torch.round(scores[:n_pos]) == 1).float().mean().cpu().numpy())
                neg_accs.append((torch.round(scores[n_pos:]) == 0).float().mean().cpu().numpy())

                # argsort gives the item order; applying argsort again turns
                # that order into each item's own 1-based rank.
                order = torch.argsort(scores, descending=True)
                ranks = (torch.argsort(order) + 1).cpu().numpy()
                y_score = 1. / ranks
                y_true = targets.cpu().numpy()

                aucs.append(roc_auc_score(y_true, y_score))
                mrrs.append(mrr_score(y_true, y_score))
                ndcg5s.append(ndcg_score(y_true, y_score, 5))
                ndcg10s.append(ndcg_score(y_true, y_score, 10))

        # Leave the model in the mode the caller had it in.
        model.train(was_training)

        writer.add_scalar('auc', np.mean(aucs), iteration)
        writer.add_scalar('mrr', np.mean(mrrs), iteration)
        writer.add_scalar('ndcg5', np.mean(ndcg5s), iteration)
        writer.add_scalar('ndcg10', np.mean(ndcg10s), iteration)
        writer.add_scalar('val_pos_acc', np.mean(pos_accs), iteration)
        writer.add_scalar('val_neg_acc', np.mean(neg_accs), iteration)
    finally:
        if owns_writer:
            writer.close()

    return


In [5]:
from torch_geometric.nn import LayerNorm, GraphNorm, InstanceNorm
from torch_geometric.utils import degree
import torch.nn as nn
import torch


class MLP(nn.Module):
    """Two-layer perceptron: Linear, Tanh, Linear."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.lin1 = nn.Linear(in_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, out_channels)
        self.act = nn.Tanh()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear layers (first ``lin1``, then ``lin2``)."""
        for linear in (self.lin1, self.lin2):
            linear.reset_parameters()

    def forward(self, x):
        """Return ``lin2(tanh(lin1(x)))``."""
        hidden = self.act(self.lin1(x))
        return self.lin2(hidden)

    

###
# Layers
###

class FirstConv(nn.Module):
    """Initialise user features as an MLP of the mean of their neighbouring news features.

    NOTE: a class with the same name is defined again in a later cell and
    replaces this one once that cell has run.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.FFN = MLP(in_channels, hidden_channels, out_channels)

    def forward(self, subgraph_feats, edge_index):
        """Aggregate ``subgraph_feats['news']`` onto users and transform the result.

        ``edge_index`` is used as a sparse (news x users) adjacency: its
        transpose is multiplied with the news features.
        """
        user_degree = degree(edge_index.storage.col(), num_nodes=edge_index.sparse_size(1))
        summed_news = edge_index.t() @ subgraph_feats['news']
        # The epsilon keeps users without neighbours at zero instead of NaN.
        mean_news = summed_news / (user_degree.unsqueeze(1) + 1e-08)
        return self.FFN(mean_news)
        


class NewsConv(nn.Module):
    """Update news features from their own features concatenated with the mean user feature.

    NOTE: a class with the same name is defined again in a later cell and
    replaces this one once that cell has run.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.FFN = MLP(in_channels, hidden_channels, out_channels)

    def forward(self, subgraph_feats, edge_index):
        """Return ``FFN([news, mean(neighbouring users)])`` for every news node."""
        news_degree = degree(edge_index.storage.row(), num_nodes=edge_index.sparse_size(0))
        summed_users = edge_index @ subgraph_feats['users']
        # The epsilon keeps news nodes without neighbours at zero instead of NaN.
        mean_users = summed_users / (news_degree.unsqueeze(1) + 1e-08)

        stacked = torch.cat((subgraph_feats['news'], mean_users), dim=1)
        return self.FFN(stacked)
        


class UserConv(nn.Module):
    """Update user features from their own features concatenated with the mean news feature.

    NOTE: a class with the same name is defined again in a later cell and
    replaces this one once that cell has run.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.FFN = MLP(in_channels, hidden_channels, out_channels)

    def forward(self, subgraph_feats, edge_index):
        """Return ``FFN([users, mean(neighbouring news)])`` for every user node."""
        user_degree = degree(edge_index.storage.col(), num_nodes=edge_index.sparse_size(1))
        summed_news = edge_index.t() @ subgraph_feats['news']
        # The epsilon keeps users without neighbours at zero instead of NaN.
        mean_news = summed_news / (user_degree.unsqueeze(1) + 1e-08)

        stacked = torch.cat((subgraph_feats['users'], mean_news), dim=1)
        return self.FFN(stacked)

    

## Net    
class GCN(nn.Module):
    """Three-step bipartite message passing network over news and user nodes.

    All layers keep the input feature width (``news_dim``) as their output
    width; the hidden width is 256 (doubled for the news update).

    NOTE(review): ``FirstConv``, ``NewsConv`` and ``UserConv`` are redefined
    with one-argument constructors in a later cell; once that cell has run,
    constructing this class calls them with three arguments.
    """

    def __init__(self, news_dim):
        super().__init__()
        width = news_dim
        hidden = 256

        self.first_conv = FirstConv(width, hidden, width)
        self.news_conv = NewsConv(width + width, hidden * 2, width)
        self.user_conv = UserConv(width * 2, hidden, width)

    def forward(self, news_x, edge_index):
        """Return ``{'news': ..., 'users': ...}`` embeddings for the given subgraph."""
        feats = {'news': news_x}

        steps = (
            ('users', self.first_conv),  # propagate news features to the empty user nodes
            ('news', self.news_conv),    # update news nodes
            ('users', self.user_conv),   # update user nodes
        )
        for target, conv in steps:
            feats[target] = conv(feats, edge_index)

        return feats










In [6]:
from torch_geometric.utils.mask import index_to_mask
from torch_geometric.utils import dropout_edge

class Impressions(Dataset):
    """Dataset of impression sessions plus a subgraph sampler used as ``collate_fn``.

    Each item is one session: a user index and the indices of the clicked
    (positive) and non-clicked (negative) articles. ``collate`` turns a list of
    sessions into a sampled, relabelled subgraph of the full news/user graph.
    """

    def __init__(self, data, set_dir, device=device):
        """Index the sessions found under ``set_dir`` with the mappings of ``data``.

        Args:
            data:    graph object providing ``mapping``, ``edge_index``,
                     ``n_nodes`` and ``news_x``.
            set_dir: directory passed to ``get_behaviors``; rows without a
                     history are dropped. The frame is expected to expose
                     ``user_id``, ``pos`` and ``neg`` columns.
            device:  device on which item tensors are created.
        """
        self.device = device
        df = get_behaviors(set_dir).dropna(subset=['history'])

        self.user = [data.mapping['users'][user_id] for user_id in df.user_id]
        self.pos = [[data.mapping['news'][article] for article in session] for session in df.pos]
        self.neg = [[data.mapping['news'][article] for article in session] for session in df.neg]

        self.edge_index = data.edge_index

        self.n_nodes = data.n_nodes
        self.news_x = data.news_x

    def __len__(self):
        return len(self.pos)

    def __getitem__(self, idx):
        """Return ``(user, articles, n_pos, n_neg, n_pos + n_neg)`` for session ``idx``.

        ``articles`` lists the positives first, followed by the negatives.
        """
        user = torch.tensor(self.user[idx], dtype=torch.long, device=self.device)

        pos = self.pos[idx]
        neg = self.neg[idx]

        n_pos = len(pos)
        n_neg = len(neg)

        articles = torch.tensor(pos + neg, dtype=torch.long, device=self.device)

        return user, articles, n_pos, n_neg, n_pos+n_neg

    def relabel_nodes(self, subset, edge_index, size, targets):
        '''
        Relabel the sampled edge_index so that node indices range over the
        sampled nodes ([0...n_sampled_nodes)) instead of all nodes.

        Input:
        subset: tuple of (news, user) node indices to be included
        edge_index: edges from sampling
        size: tuple with the number of (news, user) nodes in the full graph
        targets: tuple of (article, user) index tensors to translate as well

        Returns the relabelled edge_index, articles and users.
        '''

        news_subset, user_subset = subset
        articles, users = targets

        news_subset = index_to_mask(news_subset, size=size[0])
        user_subset = index_to_mask(user_subset, size=size[1])

        # Lookup tables: old index -> position among the kept nodes.
        node_idx_news = edge_index.new_zeros(news_subset.size(0))
        node_idx_user = edge_index.new_zeros(user_subset.size(0))

        node_idx_news[news_subset] = torch.arange(int(news_subset.sum()),
                                            device=node_idx_news.device)

        node_idx_user[user_subset] = torch.arange(int(user_subset.sum()),
                                            device=node_idx_user.device)

        articles = node_idx_news[articles]
        users = node_idx_user[users]

        edge_index = torch.stack([
            node_idx_news[edge_index[0]],
            node_idx_user[edge_index[1]],
        ], dim=0)

        return edge_index, articles, users

    def dropout_nodes(self, p, edge_index, n_nodes, dim):
        """Drop every edge whose endpoint on row ``dim`` belongs to a randomly dropped node.

        Each of the ``n_nodes`` nodes is dropped independently with
        probability ``p``.
        """
        prob = torch.rand(n_nodes, device=edge_index.device)
        node_mask = prob > p

        edge_mask = node_mask[edge_index[dim]]
        edge_index = edge_index[:, edge_mask]
        return edge_index

    def sampler(self, users, articles):
        """Sample a three-hop subgraph around the target users and articles.

        Returns the subgraph as a ``SparseTensor`` (news x users), the feature
        rows of the sampled news nodes, and the target article / user indices
        relabelled into the subgraph.
        """

        # Hop 1: article history for target user
        sample = self.edge_index[:, torch.isin(self.edge_index[1], users)]
        new_index = sample
        to_sample = torch.cat((sample[0], articles))

        # Hop 2: users connected to article history + target articles
        sample = self.edge_index[:, torch.isin(self.edge_index[0], to_sample)]
        sample = self.dropout_nodes(0.7, sample, self.n_nodes['users'], 1)
        new_index = torch.cat((new_index, sample), dim=1)
        to_sample = sample[1]

        # Hop 3: grab articles from these users
        sample = self.edge_index[:, torch.isin(self.edge_index[1], to_sample)]
        # The dropped-out result was previously bound to a misspelled name and
        # discarded, so the un-dropped edges were used.
        sample = self.dropout_nodes(0.5, sample, self.n_nodes['news'], 0)
        new_index = torch.cat((new_index, sample), dim=1)

        # The hops overlap, so the same edge can be collected several times;
        # keep one copy so mean aggregation and degree counts are not skewed.
        new_index = torch.unique(new_index, dim=1)

        # relabel
        news_nodes = torch.cat((new_index[0], articles)).unique()
        user_nodes = new_index[1].unique()

        index, articles, users = self.relabel_nodes((news_nodes, user_nodes),
                                                        new_index,
                                                        (self.n_nodes['news'], self.n_nodes['users']),
                                                        (articles, users))

        edge_index = SparseTensor.from_edge_index(index, sparse_sizes=(news_nodes.shape[0], user_nodes.shape[0]))

        news_x = self.news_x[news_nodes]

        return edge_index, news_x, articles, users

    def collate(self, batch):
        """Merge sessions into ``(subgraph, news_x, users, articles, n_pos, n_neg)``.

        ``articles`` is a tuple with one relabelled index tensor per session;
        ``n_pos`` and ``n_neg`` are tuples of per-session counts.
        """

        users, articles, n_pos, n_neg, n_all = list(zip(*batch))

        subgraph, news_x, articles, users = self.sampler(torch.stack(users), torch.cat(articles))

        # Undo the concatenation so every session gets its own candidate list.
        articles = torch.split(articles, n_all)

        return subgraph, news_x, users, articles, n_pos, n_neg


    

In [19]:
from torch_geometric.nn import LayerNorm, InstanceNorm
from torch_geometric.utils import degree
import torch.nn as nn
import torch


class NewsEnc(nn.Module):
    """Two-layer encoder: Linear, PReLU, Linear."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.lin1 = nn.Linear(in_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, out_channels)
        self.act = nn.PReLU()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear layers; the PReLU slope is left untouched."""
        for linear in (self.lin1, self.lin2):
            linear.reset_parameters()

    def forward(self, x):
        """Return ``lin2(prelu(lin1(x)))``."""
        hidden = self.act(self.lin1(x))
        return self.lin2(hidden)


class GNNLayer(nn.Module):
    """Width-preserving transform applied to node features: Linear, InstanceNorm, PReLU."""

    def __init__(self, out_channels):
        super().__init__()
        self.lin = nn.Linear(out_channels, out_channels)
        self.norm = InstanceNorm(out_channels)
        self.act = nn.PReLU()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear layer and then the normalisation layer."""
        for module in (self.lin, self.norm):
            module.reset_parameters()

    def forward(self, x):
        """Return ``prelu(norm(lin(x)))``."""
        projected = self.lin(x)
        normalised = self.norm(projected)
        return self.act(normalised)



class PostLayer(nn.Module):
    """Width-preserving output head: Linear, Tanh, Linear."""

    def __init__(self, out_channels):
        super().__init__()
        self.lin1 = nn.Linear(out_channels, out_channels)
        self.lin2 = nn.Linear(out_channels, out_channels)
        self.act = nn.Tanh()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear layers (first ``lin1``, then ``lin2``)."""
        for linear in (self.lin1, self.lin2):
            linear.reset_parameters()

    def forward(self, x):
        """Return ``lin2(tanh(lin1(x)))``."""
        hidden = self.act(self.lin1(x))
        return self.lin2(hidden)




###
# Layers
###

class FirstConv(nn.Module):
    """Initialise user features as the mean of their transformed neighbouring news features.

    Replaces the earlier three-argument ``FirstConv`` defined in a previous cell.
    """

    def __init__(self, out_channels):
        super().__init__()
        self.layer = GNNLayer(out_channels)

    def forward(self, subgraph_feats, edge_index):
        """Return one feature row per user.

        ``edge_index`` is used as a sparse (news x users) adjacency: its
        transpose is multiplied with the transformed news features and the sum
        is divided by each user's degree.
        """
        n_neighbors = degree(edge_index.storage.col(), num_nodes=edge_index.storage._sparse_sizes[1])

        x = self.layer(subgraph_feats['news'])

        agr_feats = edge_index.t() @ x
        # Same epsilon as NewsConv / UserConv: a user without any edge gets a
        # zero vector instead of 0/0 = NaN, which would otherwise propagate
        # through every later layer.
        out = agr_feats / (n_neighbors[:, None] + 1e-08)

        return out
        


class NewsConv(nn.Module):
    """Residual news update: add the mean of the transformed neighbouring user features."""

    def __init__(self, out_channels):
        super().__init__()
        self.layer = GNNLayer(out_channels)

    def forward(self, subgraph_feats, edge_index):
        """Return ``news + mean(layer(users))`` over each news node's neighbours."""
        news_degree = degree(edge_index.storage.row(), num_nodes=edge_index.sparse_size(0))

        user_messages = self.layer(subgraph_feats['users'])

        # The epsilon keeps news nodes without neighbours at zero instead of NaN.
        mean_messages = (edge_index @ user_messages) / (news_degree.unsqueeze(1) + 1e-08)

        return subgraph_feats['news'] + mean_messages
        


class UserConv(nn.Module):
    """Residual user update: add the mean of the transformed neighbouring news features."""

    def __init__(self, out_channels):
        super().__init__()
        self.layer = GNNLayer(out_channels)

    def forward(self, subgraph_feats, edge_index):
        """Return ``users + mean(layer(news))`` over each user's neighbours."""
        user_degree = degree(edge_index.storage.col(), num_nodes=edge_index.sparse_size(1))

        news_messages = self.layer(subgraph_feats['news'])

        # The epsilon keeps users without neighbours at zero instead of NaN.
        mean_messages = (edge_index.t() @ news_messages) / (user_degree.unsqueeze(1) + 1e-08)

        return subgraph_feats['users'] + mean_messages

    

## Net    
class GCN2(nn.Module):
    """Encoder, three message-passing steps and per-node-type output heads.

    News features are first projected to 128 dimensions by an ``MLP`` with a
    256-wide hidden layer; every later layer keeps that width.
    """

    def __init__(self, news_dim):
        super().__init__()
        width = 128

        self.news_enc = MLP(news_dim, 256, width)

        self.first_conv = FirstConv(width)
        self.news_conv = NewsConv(width)
        self.user_conv = UserConv(width)

        self.news_post = PostLayer(width)
        self.user_post = PostLayer(width)

    def forward(self, news_x, edge_index):
        """Return ``{'news': ..., 'users': ...}`` embeddings for the given subgraph."""
        feats = {'news': self.news_enc(news_x)}

        steps = (
            ('users', self.first_conv),  # propagate news features to the empty user nodes
            ('news', self.news_conv),    # update news nodes
            ('users', self.user_conv),   # update user nodes
        )
        for target, conv in steps:
            feats[target] = conv(feats, edge_index)

        # Output heads, applied to news first and users second.
        feats['news'] = self.news_post(feats['news'])
        feats['users'] = self.user_post(feats['users'])

        return feats










In [20]:
# Driver cell: build the training graph and loader, then run a short training pass.
torch.manual_seed(42)
random.seed(42)

TRAIN_DIR = TRAIN_L
VAL_DIR = VAL_L

# Sentence encoder used for the news text; 'all-MiniLM-L6-v2' is a smaller
# alternative that was tried here.
st_model = 'all-mpnet-base-v2'
# Follow the globally selected device instead of hard-coding 'cuda', so the
# cell also runs on a CPU-only machine (str(device) is still 'cuda' on GPU).
# NOTE(review): assumes SentenceTransformer accepts 'cpu' as device string.
encoder = SentenceTransformer(st_model, str(device))

batch_size = 8
epochs = 1

# Build the data objects in this cell so it works on a fresh kernel; they were
# commented out and the cell relied on leftover kernel state.
# Encoding every news text is expensive - expect this to take a while.
train_data = DataGraph(TRAIN_DIR, encoder, device)
train_impressions = Impressions(train_data, TRAIN_DIR)
train_loader = DataLoader(train_impressions, batch_size, shuffle=True, collate_fn=train_impressions.collate)

# Re-seed so model initialisation does not depend on the work done above.
torch.manual_seed(42)
random.seed(42)

model = GCN2(train_data.news_x.shape[1]).to(device)
print('training')
train('L_biglm', train_loader, model, epochs, batch_size, test=True)




training


In [13]:
# Pull batches from the loader without training on them; the loop stops when
# the index reaches 200, i.e. after 201 batches have been fetched.
# NOTE(review): this cell's execution count (13) is lower than the preceding
# cell's (20), so it was run out of order and depends on `train_loader`
# already existing in the kernel.
for i, batch in enumerate(train_loader):
    if i == 200:
        break