In [None]:
# default_exp models.tagnn

# TAGNN
> [Yu et al. Target Attentive Graph Neural Networks for Session-based Recommendation. SIGIR, 2020.](https://arxiv.org/abs/2005.02844)

TAGNN first models all session sequences as session graphs. Then, graph neural networks capture rich item transitions in sessions. Lastly, from one session embedding vector, target-aware attention adaptively activates different user interests concerning varied target items to be predicted.

<img src='https://raw.githubusercontent.com/RecoHut-Projects/sessrec-gnn/main/report/images/img6.png'>

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import math
import datetime
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

from recohut.datasets.session import SampleSessionDataset, GraphData

In [None]:
#export
class GNN(nn.Module):
    """Gated graph layer that repeatedly updates node states with a GRU-style cell.

    Args:
        hidden_size: width of every node state vector.
        step: how many times the cell is applied in ``forward``.

    Note:
        The raw ``nn.Parameter`` tensors are allocated without being initialised
        here; ``TAGNN.reset_parameters`` fills every parameter (including these)
        after construction.
    """

    def __init__(self, hidden_size, step=1):
        super().__init__()
        self.step = step
        self.hidden_size = hidden_size
        # The cell input is two messages of width hidden_size concatenated.
        self.input_size = 2 * hidden_size
        # Three gates (reset, update, candidate) are stacked along dim 0.
        self.gate_size = 3 * hidden_size

        # Registration order is kept stable: it determines state_dict layout
        # and the order in which parameters() yields tensors.
        self.w_ih = nn.Parameter(torch.Tensor(self.gate_size, self.input_size))
        self.w_hh = nn.Parameter(torch.Tensor(self.gate_size, self.hidden_size))
        self.b_ih = nn.Parameter(torch.Tensor(self.gate_size))
        self.b_hh = nn.Parameter(torch.Tensor(self.gate_size))
        self.b_iah = nn.Parameter(torch.Tensor(self.hidden_size))
        self.b_oah = nn.Parameter(torch.Tensor(self.hidden_size))

        self.linear_edge_in = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.linear_edge_out = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        # Not referenced by GNNCell/forward in this class; kept so checkpoints stay compatible.
        self.linear_edge_f = nn.Linear(self.hidden_size, self.hidden_size, bias=True)

    def GNNCell(self, A, hidden):
        """Apply one propagation step and return the updated node states.

        Args:
            A: 3-D tensor; along the last dim the first ``A.shape[1]`` columns and
                the next ``A.shape[1]`` columns are used as two separate adjacency
                blocks (presumably incoming and outgoing edges -- confirm with the
                dataset code).
            hidden: node states whose last dim is ``hidden_size``.
        """
        n_nodes = A.shape[1]
        adj_first = A[:, :, :n_nodes]
        adj_second = A[:, :, n_nodes:2 * n_nodes]

        msg_in = torch.matmul(adj_first, self.linear_edge_in(hidden)) + self.b_iah
        msg_out = torch.matmul(adj_second, self.linear_edge_out(hidden)) + self.b_oah
        messages = torch.cat([msg_in, msg_out], 2)

        from_msg = F.linear(messages, self.w_ih, self.b_ih)
        from_state = F.linear(hidden, self.w_hh, self.b_hh)
        m_reset, m_update, m_cand = from_msg.chunk(3, 2)
        s_reset, s_update, s_cand = from_state.chunk(3, 2)

        reset = torch.sigmoid(m_reset + s_reset)
        update = torch.sigmoid(m_update + s_update)
        candidate = torch.tanh(m_cand + reset * s_cand)
        # Blend: update == 1 keeps the old state, update == 0 takes the candidate.
        return candidate + update * (hidden - candidate)

    def forward(self, A, hidden):
        """Run ``self.step`` propagation steps over ``hidden`` and return the result."""
        for _ in range(self.step):
            hidden = self.GNNCell(A, hidden)
        return hidden

In [None]:
#export
class TAGNN(nn.Module):
    """Target-attentive GNN recommender for sessions.

    ``forward`` embeds the session-graph nodes and propagates them through
    ``GNN``; ``compute_scores`` turns per-position hidden states into one score
    per candidate item. The module also owns its loss, optimizer and LR scheduler.

    Args:
        opt: options object; the attributes read here are ``hiddenSize``,
            ``n_node``, ``batchSize``, ``nonhybrid``, ``step``, ``lr``, ``l2``,
            ``lr_dc_step`` and ``lr_dc``.
    """

    def __init__(self, opt):
        super().__init__()
        self.hidden_size = opt.hiddenSize
        self.n_node = opt.n_node
        self.batch_size = opt.batchSize
        self.nonhybrid = opt.nonhybrid

        dim = self.hidden_size
        # Sub-module creation order is preserved (it fixes parameter ordering).
        self.embedding = nn.Embedding(self.n_node, dim)
        self.gnn = GNN(dim, step=opt.step)
        self.linear_one = nn.Linear(dim, dim, bias=True)
        self.linear_two = nn.Linear(dim, dim, bias=True)
        self.linear_three = nn.Linear(dim, 1, bias=False)
        self.linear_transform = nn.Linear(dim * 2, dim, bias=True)
        self.linear_t = nn.Linear(dim, dim, bias=False)  # projection used by the target attention

        self.loss_function = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=opt.lr, weight_decay=opt.l2)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=opt.lr_dc_step, gamma=opt.lr_dc
        )
        # Re-initialise in place; the optimizer keeps referencing the same tensors.
        self.reset_parameters()

    def reset_parameters(self):
        """Draw every parameter uniformly from ``[-1/sqrt(hidden_size), 1/sqrt(hidden_size)]``."""
        bound = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            param.data.uniform_(-bound, bound)

    def compute_scores(self, hidden, mask):
        """Score every candidate item for each session in the batch.

        Args:
            hidden: batch_size x seq_length x latent_size hidden states.
            mask: batch_size x seq_length; summed per row to locate the last
                valid position and multiplied into ``hidden`` to zero padding.

        Returns:
            batch_size x (n_node - 1) scores; row 0 of the embedding table is
            excluded from the candidates.
        """
        n_sessions = mask.shape[0]
        last_index = torch.sum(mask, 1) - 1
        last_state = hidden[torch.arange(n_sessions).long(), last_index]  # batch_size x latent_size
        latent = last_state.shape[1]

        # Attention of each position against the last valid position.
        query_last = self.linear_one(last_state).view(last_state.shape[0], 1, latent)  # b,1,d
        query_all = self.linear_two(hidden)  # b,s,d
        weights = F.softmax(self.linear_three(torch.sigmoid(query_last + query_all)), 1)  # b,s,1

        keep = mask.view(mask.shape[0], -1, 1).float()  # b,s,1
        session_vec = torch.sum(weights * hidden * keep, 1)  # b,d
        if not self.nonhybrid:
            session_vec = self.linear_transform(torch.cat([session_vec, last_state], 1))

        candidates = self.embedding.weight[1:]  # n_nodes x latent_size

        # Target attention: every candidate attends over the masked session positions.
        hidden = hidden * keep  # b,s,d
        projected = self.linear_t(hidden)  # b,s,d
        target_weights = F.softmax(candidates @ projected.transpose(1, 2), -1)  # b,n,s
        target_vec = target_weights @ hidden  # b,n,d

        combined = session_vec.view(last_state.shape[0], 1, latent) + target_vec  # b,n,d
        return torch.sum(combined * candidates, -1)  # b,n

    def forward(self, inputs, A):
        """Embed ``inputs`` and propagate the embeddings through the GNN using ``A``."""
        return self.gnn(A, self.embedding(inputs))

## Training a session-based recommender using TAGNN

In [None]:
class Args:
    """Hyper-parameters and run options for the TAGNN training demo below."""

    # --- data ---
    dataset = 'sample'
    n_node = 310          # size of the item embedding table handed to TAGNN
    validation = True     # evaluate on a held-out slice of the training set
    valid_portion = 0.1   # fraction of the training set used for that slice

    # --- model ---
    hiddenSize = 100      # hidden state size
    step = 1              # GNN propagation steps
    nonhybrid = True      # True: compute_scores skips mixing in the last-item state

    # --- optimisation ---
    batchSize = 100       # input batch size
    epoch = 30            # maximum number of training epochs
    lr = 0.001            # learning rate
    lr_dc = 0.1           # learning-rate decay factor
    lr_dc_step = 3        # scheduler steps between learning-rate decays
    l2 = 1e-5             # L2 penalty (Adam weight_decay)
    patience = 10         # non-improving epochs tolerated before early stop


args = Args()

In [None]:
def trans_to_cuda(variable):
    """Return ``variable.cuda()`` when CUDA is available, otherwise ``variable`` unchanged."""
    return variable.cuda() if torch.cuda.is_available() else variable


def trans_to_cpu(variable):
    """Return ``variable.cpu()`` when CUDA is available, otherwise ``variable`` unchanged."""
    if not torch.cuda.is_available():
        return variable
    return variable.cpu()


def forward(model, i, data):
    """Run ``model`` on one batch slice of ``data`` and score the candidates.

    Args:
        model: object callable as ``model(items, A)`` that also provides
            ``compute_scores(seq_hidden, mask)``.
        i: batch selector passed straight to ``data.get_slice``.
        data: dataset exposing ``get_slice(i)`` returning
            ``(alias_inputs, A, items, mask, targets)``.

    Returns:
        ``(targets, scores)`` where ``targets`` is returned exactly as
        ``get_slice`` produced it and ``scores`` comes from ``model.compute_scores``.
    """
    alias_inputs, A, items, mask, targets = data.get_slice(i)

    # Build index tensors directly as int64. Going through a float32 tensor
    # first (torch.Tensor(...).long()) silently corrupts ids above 2**24.
    alias_inputs = trans_to_cuda(torch.as_tensor(np.asarray(alias_inputs), dtype=torch.long))
    items = trans_to_cuda(torch.as_tensor(np.asarray(items), dtype=torch.long))
    A = trans_to_cuda(torch.as_tensor(np.asarray(A), dtype=torch.float))
    mask = trans_to_cuda(torch.as_tensor(np.asarray(mask), dtype=torch.long))

    hidden = model(items, A)

    # For each session b, pick rows alias_inputs[b] out of hidden[b] in a single
    # gather; equivalent to stacking hidden[b][alias_inputs[b]] over the batch.
    rows = torch.arange(alias_inputs.shape[0], device=alias_inputs.device).unsqueeze(1)
    seq_hidden = hidden[rows, alias_inputs]

    return targets, model.compute_scores(seq_hidden, mask)


def train_test(model, train_data, test_data):
    """Train ``model`` for one epoch, then evaluate it.

    Args:
        model: network exposing ``optimizer``, ``scheduler``, ``loss_function``
            and ``batch_size``.
        train_data: dataset used for the optimisation pass; must provide
            ``generate_batch`` and ``get_slice``.
        test_data: dataset used for evaluation; same interface.

    Returns:
        ``(hit, mrr)``: hit rate and mean reciprocal rank over the top-k scored
        items (k = 20, or fewer if there are fewer candidates), as percentages.
    """
    print('start training: ', datetime.datetime.now())
    model.train()
    total_loss = 0.0
    slices = train_data.generate_batch(model.batch_size)
    log_every = int(len(slices) / 5 + 1)
    for j, i in enumerate(slices):
        model.optimizer.zero_grad()
        targets, scores = forward(model, i, train_data)
        # int64 built directly: a float32 round-trip would lose precision for large ids.
        targets = trans_to_cuda(torch.as_tensor(np.asarray(targets), dtype=torch.long))
        # Score column k corresponds to item id k + 1, hence the shift.
        loss = model.loss_function(scores, targets - 1)
        loss.backward()
        model.optimizer.step()
        total_loss += loss.item()
        if j % log_every == 0:
            print('[%d/%d] Loss: %.4f' % (j, len(slices), loss.item()))
    # Advance the LR schedule only after this epoch's optimizer steps. Stepping
    # it first skips the initial learning-rate value and triggers a PyTorch warning.
    model.scheduler.step()
    print('\tLoss:\t%.3f' % total_loss)

    print('start predicting: ', datetime.datetime.now())
    model.eval()
    hit, mrr = [], []
    slices = test_data.generate_batch(model.batch_size)
    with torch.no_grad():  # evaluation needs no autograd graph
        for i in slices:
            targets, scores = forward(model, i, test_data)
            top_k = min(20, scores.shape[1])  # topk raises if k exceeds the candidate count
            sub_scores = trans_to_cpu(scores.topk(top_k)[1]).detach().numpy()
            for score, target in zip(sub_scores, targets):
                positions = np.where(score == target - 1)[0]
                found = len(positions) > 0
                hit.append(found)
                mrr.append(1 / (positions[0] + 1) if found else 0)
    hit = np.mean(hit) * 100
    mrr = np.mean(mrr) * 100
    return hit, mrr


def split_validation(train_set, valid_portion):
    """Randomly split ``(xs, ys)`` into a training part and a validation part.

    Args:
        train_set: pair ``(xs, ys)`` of equally long, index-aligned sequences.
        valid_portion: fraction of the samples assigned to the validation part.

    Returns:
        ``((train_xs, train_ys), (valid_xs, valid_ys))`` as lists; the x/y
        pairing is preserved. The shuffle uses NumPy's global random state.
    """
    features, labels = train_set
    total = len(features)

    order = np.arange(total, dtype='int32')
    np.random.shuffle(order)
    cut = int(np.round(total * (1. - valid_portion)))

    def pick(indices):
        # Select the same positions from both sequences so pairs stay aligned.
        return [features[k] for k in indices], [labels[k] for k in indices]

    valid_part = pick(order[cut:])
    train_part = pick(order[:cut])
    return train_part, valid_part

In [None]:
import pickle
import time

# Instantiating the dataset is presumably what materialises ./session_ds/processed
# (confirm against recohut); the returned object itself is not used.
_ = SampleSessionDataset('./session_ds')

# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
# The context manager closes the file handle deterministically.
with open('./session_ds/processed/train.txt', 'rb') as train_file:
    train_data = pickle.load(train_file)

if args.validation:
    # Evaluate on a random hold-out of the training set instead of the test file.
    train_data, valid_data = split_validation(train_data, args.valid_portion)
    test_data = valid_data
else:
    with open('./session_ds/processed/test.txt', 'rb') as test_file:
        test_data = pickle.load(test_file)

train_data = GraphData(train_data, shuffle=True)
test_data = GraphData(test_data, shuffle=False)

model = trans_to_cuda(TAGNN(args))

start = time.time()
best_result = [0, 0]  # best [hit, mrr] seen so far
best_epoch = [0, 0]   # epoch at which each best value was reached
bad_counter = 0       # number of epochs in which neither metric matched or beat its best

for epoch in range(args.epoch):
    print('-------------------------------------------------------')
    print('epoch: ', epoch)
    hit, mrr = train_test(model, train_data, test_data)
    flag = 0
    if hit >= best_result[0]:
        best_result[0] = hit
        best_epoch[0] = epoch
        flag = 1
    if mrr >= best_result[1]:
        best_result[1] = mrr
        best_epoch[1] = epoch
        flag = 1
    print('Best Result:')
    print('\tRecall@20:\t%.4f\tMMR@20:\t%.4f\tEpoch:\t%d,\t%d'% (best_result[0], best_result[1], best_epoch[0], best_epoch[1]))
    # Early stopping: the counter accumulates over the whole run and is never reset.
    bad_counter += 1 - flag
    if bad_counter >= args.patience:
        break
print('-------------------------------------------------------')
end = time.time()
print("Run time: %f s" % (end - start))

-------------------------------------------------------
epoch:  0
start training:  2021-12-23 11:17:17.358225
[0/11] Loss: 5.7136




[3/11] Loss: 5.7028
[6/11] Loss: 5.7000
[9/11] Loss: 5.6981
	Loss:	62.736
start predicting:  2021-12-23 11:17:18.800668
Best Result:
	Recall@20:	62.8099	MMR@20:	46.2166	Epoch:	0,	0
-------------------------------------------------------
epoch:  1
start training:  2021-12-23 11:17:18.881197
[0/11] Loss: 5.6988
[3/11] Loss: 5.6828
[6/11] Loss: 5.6606
[9/11] Loss: 5.6704
	Loss:	62.421
start predicting:  2021-12-23 11:17:20.583203
Best Result:
	Recall@20:	62.8099	MMR@20:	46.2166	Epoch:	0,	0
-------------------------------------------------------
epoch:  2
start training:  2021-12-23 11:17:20.668376
[0/11] Loss: 5.6598
[3/11] Loss: 5.6544
[6/11] Loss: 5.6495
[9/11] Loss: 5.6483
	Loss:	62.155
start predicting:  2021-12-23 11:17:22.019096
Best Result:
	Recall@20:	62.8099	MMR@20:	46.2166	Epoch:	0,	0
-------------------------------------------------------
epoch:  3
start training:  2021-12-23 11:17:22.106588
[0/11] Loss: 5.6384
[3/11] Loss: 5.6427
[6/11] Loss: 5.6393
[9/11] Loss: 5.6504
	Loss:	

---

In [None]:
#hide
%reload_ext watermark
%watermark -a "Sparsh A." -m -iv -u -t -d

Author: Sparsh A.

Last updated: 2021-12-23 11:17:51

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.4.144+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit

IPython: 5.5.0
numpy  : 1.19.5
torch  : 1.10.0+cu111

