In [1]:
import pandas as pd
import numpy as np
from scipy import sparse, io, stats
import pickle
from tqdm import tqdm
import torch
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F

def save_obj(obj, name):
    """Serialize ``obj`` with pickle into the file at path ``name``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten.
    """
    with open(name, mode='wb') as sink:
        pickle.Pickler(sink).dump(obj)

def load_obj(name):
    """Return the object unpickled from the file at path ``name``.

    NOTE(review): unpickling can execute arbitrary code, so this should only
    be pointed at files from a trusted source.
    """
    with open(name, mode='rb') as source:
        loaded = pickle.Unpickler(source).load()
    return loaded

def get_train_item(size=1024):
    """Sample a mini-batch of (user, positive item, negative item) triples.

    Positive pairs are drawn uniformly with replacement from the nonzero
    entries of the module-level ``train_item`` matrix (``train_item_u`` /
    ``train_item_i``). Each negative is a uniformly drawn item index in
    ``[0, item_shape)`` that is re-drawn until it is not among the nonzero
    columns of that user's ``train_item`` row.

    Args:
        size: number of triples to draw.

    Returns:
        Tuple ``(user, pos, neg)`` of three arrays of length ``size``.

    Note:
        The rejection loop does not terminate for a user whose row has every
        item set.
    """
    picks = np.random.randint(0, high=len(train_item_u), size=size)
    user = train_item_u[picks]
    pos = train_item_i[picks]
    # randint draws directly from [0, item_shape); the previous
    # np.random.choice(range(item_shape)) rebuilt an item_shape-long array
    # on every call, including every retry below.
    neg = np.random.randint(0, high=item_shape, size=size)
    for i in range(size):
        item = train_item[user[i]].nonzero()[1]
        while neg[i] in item:
            neg[i] = np.random.randint(item_shape)
    return user, pos, neg

def get_train_bundle(size=1024):
    """Sample a mini-batch of (user, positive bundle, negative bundle) triples.

    Positive pairs are drawn uniformly with replacement from the nonzero
    entries of the module-level ``train_bundle`` matrix (``train_bundle_u`` /
    ``train_bundle_b``). Each negative is a uniformly drawn bundle index in
    ``[0, bundle_shape)`` that is re-drawn while either

    * it is among the nonzero columns of that user's ``train_bundle`` row, or
    * its row in ``bundle_item`` has no nonzero entries (an empty bundle).

    Args:
        size: number of triples to draw.

    Returns:
        Tuple ``(user, pos, neg)`` of three arrays of length ``size``.
    """
    picks = np.random.randint(0, high=len(train_bundle_u), size=size)
    user = train_bundle_u[picks]
    pos = train_bundle_b[picks]
    # randint draws directly from [0, bundle_shape); the previous
    # np.random.choice(range(bundle_shape)) rebuilt a bundle_shape-long array
    # on every call, including every retry below.
    neg = np.random.randint(0, high=bundle_shape, size=size)
    for i in range(size):
        bundle = train_bundle[user[i]].nonzero()[1]
        while (neg[i] in bundle) or (len(bundle_item[neg[i]].nonzero()[1]) == 0):
            neg[i] = np.random.randint(bundle_shape)
    return user, pos, neg

def measure(user, K=3):
    """Evaluate top-K ranking quality of bundle predictions for ``user``.

    For every user ``u`` the score of the nonzero column(s) of
    ``test_bundle[u]`` is compared with the score of each entry of
    ``negative[u]``. ``rank`` counts the negatives that score strictly higher
    than the positive; counting stops as soon as ``rank`` reaches ``K``.

    Args:
        user: sized iterable of user indices.
        K: cut-off; a user counts as a hit when fewer than ``K`` negatives
            outrank the positive.

    Returns:
        Tuple ``(recall, MAP)``: the fraction of users that are hits, and the
        mean over all users of ``1 / (rank + 1)`` for hits (0 for misses).
        Returns ``(0.0, 0.0)`` when ``user`` is empty.

    Note:
        NOTE(review): the comparison assumes exactly one positive bundle per
        user in ``test_bundle`` - confirm against how the split was built.
    """
    n_users = len(user)
    if n_users == 0:
        # Nothing to evaluate; avoid dividing by zero below.
        return 0.0, 0.0

    recall = 0
    MAP = 0
    # Scores are only compared, never back-propagated, so skip autograd
    # bookkeeping for the whole evaluation pass.
    with torch.no_grad():
        for u in user:
            pos = test_bundle[u].nonzero()[1]
            pos_val = pred_test_bundle(u, pos)
            rank = 0
            for n in negative[u]:
                if pred_test_bundle(u, [n]) > pos_val:
                    rank += 1
                    if rank >= K:
                        # Already outside the top K; the exact rank is irrelevant.
                        break
            if rank < K:
                recall += 1
                MAP += 1 / (rank + 1)
    return (recall / n_users), (MAP / n_users)
    
def get_hit():
    """Run ``measure`` over ``all_user`` with a cut-off of 5 and print both metrics."""
    scores = measure(all_user, 5)
    print("Recall: ", scores[0], "MAP: ", scores[1])
    
def pred_test_bundle(u, bid):
    """Score one user against one or more bundles.

    Each bundle is represented as an attention-weighted sum of the embeddings
    of its items (the nonzero columns of its ``bundle_item`` row) plus the
    bundle's own row of ``bundle_embeds``. The attention weights are a softmax,
    taken over the bundle's items, of the summed elementwise product between
    the user embedding and the items' rows of ``A``.

    Args:
        u: a single user index.
        bid: non-empty iterable of bundle indices.

    Returns:
        Output of ``dam(..., True)`` on the concatenated [user, bundle] rows,
        one row per entry of ``bid``.
    """
    user_row = user_embeds[[u]]  # one row, kept 2-D so it broadcasts below
    be = []
    for b in bid:
        sample = bundle_item[b].nonzero()[1]
        bi = item_embeds[sample]
        # Broadcasting the single user row over the item rows gives the same
        # products as gathering the user row once per item.
        w = F.softmax(torch.sum(user_row * A[sample], 1, True), 0)
        be.append(torch.sum(bi * w, 0) + bundle_embeds[b])
    be = torch.stack(be)
    # Repeat the user row once per bundle so the concatenation is valid for
    # any number of bundles, not only for a single one.
    user_rows = user_row.expand(be.shape[0], -1)
    return dam(torch.cat((user_rows, be), 1), True)
    
def pred_bundle(u, bid):
    """Score a batch of (user, bundle) pairs.

    ``u[i]`` is paired with ``bid[i]``. Each bundle is represented as an
    attention-weighted sum of the embeddings of its items (the nonzero columns
    of its ``bundle_item`` row) plus the bundle's own row of ``bundle_embeds``.
    The attention weights are a softmax, taken over the bundle's items, of the
    summed elementwise product between the paired user's embedding and the
    items' rows of ``A``.

    Args:
        u: sequence/array of user indices, aligned with ``bid``.
        bid: non-empty sequence/array of bundle indices.

    Returns:
        Output of ``dam(..., True)`` on the concatenated [user, bundle] rows,
        one row per pair.
    """
    # Index with an explicit long tensor. The previous user_embeds[[u]] wrapped
    # the whole index array in a list and only produced one row per user
    # through legacy "sequence treated as tuple" indexing.
    user_rows = user_embeds[torch.as_tensor(u, dtype=torch.long)]
    be = []
    for i, b in enumerate(bid):
        sample = bundle_item[b].nonzero()[1]
        bi = item_embeds[sample]
        # The 1-D user row broadcasts across the item rows of A.
        w = F.softmax(torch.sum(user_rows[i] * A[sample], 1, True), 0)
        be.append(torch.sum(bi * w, 0) + bundle_embeds[b])
    be = torch.stack(be)
    return dam(torch.cat((user_rows, be), 1), True)

def get_bundle_loss():
    """Return the pairwise ranking loss on one sampled bundle mini-batch.

    Draws (user, positive bundle, negative bundle) triples with
    ``get_train_bundle``, scores both bundles with ``pred_bundle`` and returns
    the summed negative log-sigmoid of (positive score - negative score).
    """
    u, pb, nb = get_train_bundle()
    upb = pred_bundle(u, pb)
    unb = pred_bundle(u, nb)
    # logsigmoid computes log(sigmoid(x)) without first forming sigmoid(x),
    # which can underflow to 0 and make the log return -inf.
    return -torch.sum(F.logsigmoid(upb - unb))


def get_item_loss():
    """Return the pairwise ranking loss on one sampled item mini-batch.

    Draws (user, positive item, negative item) triples with
    ``get_train_item``, scores [user, item] concatenations with ``dam`` (item
    head) and returns the summed negative log-sigmoid of
    (positive score - negative score).
    """
    u, p, n = get_train_item()
    # Index with explicit long tensors. The previous table[[idx]] form wrapped
    # the whole index array in a list and only produced one row per index
    # through legacy "sequence treated as tuple" indexing.
    u_vec = user_embeds[torch.as_tensor(u, dtype=torch.long)]
    p_vec = item_embeds[torch.as_tensor(p, dtype=torch.long)]
    n_vec = item_embeds[torch.as_tensor(n, dtype=torch.long)]

    up = dam(torch.cat((u_vec, p_vec), 1))
    un = dam(torch.cat((u_vec, n_vec), 1))

    # logsigmoid computes log(sigmoid(x)) without first forming sigmoid(x),
    # which can underflow to 0 and make the log return -inf.
    return -torch.sum(F.logsigmoid(up - un))

In [2]:
# Pickled interaction data; the row/column indexing used elsewhere in this
# notebook suggests 2-D sparse matrices - TODO confirm.
# NOTE(review): these files are unpickled, so they must come from a trusted source.
user_item = load_obj('data/Youshu/user_item')
user_bundle = load_obj('data/Youshu/user_list')
bundle_item = load_obj('data/Youshu/list_item')

# Train/test splits; within each pair the test file is read first.
test_item = load_obj('data/Youshu/test_item')
train_item = load_obj('data/Youshu/train_item')
test_bundle = load_obj('data/Youshu/test')
train_bundle = load_obj('data/Youshu/train')

# Looked up per user during evaluation to obtain the negative bundle ids.
negative = load_obj('data/Youshu/neg')

# Coordinates of the nonzero training entries, used to sample positives.
train_item_u, train_item_i = train_item.nonzero()
train_bundle_u, train_bundle_b = train_bundle.nonzero()

# item_shape is assigned twice; the value taken from bundle_item is the one kept.
user_shape, item_shape = user_item.shape
bundle_shape, item_shape = bundle_item.shape

# Users whose test_bundle row sums to a nonzero value.
all_user, _ = np.nonzero(np.sum(test_bundle, 1))

In [3]:
# Width of every embedding row (users, items, bundles and the attention table A).
embed_shape = 5

# Trainable tables drawn from N(0, 0.001**2). Built with torch.empty(...) and
# requires_grad_() instead of the deprecated Variable wrapper around the legacy
# FloatTensor(rows, cols) constructor; dtype is pinned to float32 to match it.
user_embeds = torch.empty(user_shape, embed_shape, dtype=torch.float32).normal_(0, 0.001).requires_grad_()
item_embeds = torch.empty(item_shape, embed_shape, dtype=torch.float32).normal_(0, 0.001).requires_grad_()
bundle_embeds = torch.empty(bundle_shape, embed_shape, dtype=torch.float32).normal_(0, 0.001).requires_grad_()

# Per-item table used to compute the attention weights over a bundle's items.
A = torch.empty(item_shape, embed_shape, dtype=torch.float32).normal_(0, 0.001).requires_grad_()

class DAM(torch.nn.Module):
    """Two shared hidden layers followed by separate item and bundle heads.

    The input is a row-wise concatenation of two embeddings, so every hidden
    layer has ``embed_shape * 2`` features. Each head maps to a single score.
    """

    def __init__(self):
        super().__init__()
        width = embed_shape * 2
        # Shared trunk.
        self.sdense = torch.nn.Linear(width, width)
        self.dense = torch.nn.Linear(width, width)
        # Task-specific scoring heads: items and bundles.
        self.ipred = torch.nn.Linear(width, 1)
        self.bpred = torch.nn.Linear(width, 1)

    def forward(self, x, bundle=False):
        """Score the rows of ``x`` with the bundle head if ``bundle`` else the item head.

        Dropout (p=0.5) after each hidden layer is active only in training mode.
        The returned scores are passed through ReLU, so they are non-negative.
        """
        for layer in (self.sdense, self.dense):
            x = F.dropout(torch.relu(layer(x)), p=0.5, training=self.training)
        head = self.bpred if bundle else self.ipred
        return torch.relu(head(x))

In [4]:
dam = DAM()
dam.train()

# Optimise the network weights together with the embedding tables.
para = [*dam.parameters(), user_embeds, item_embeds, bundle_embeds, A]
opt = torch.optim.Adam(para, lr=0.005, weight_decay=0.001)

for epoch in tqdm(range(800)):
    # One optimiser step on the item loss, then one on the bundle loss.
    for compute_loss in (get_item_loss, get_bundle_loss):
        loss = compute_loss()
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Periodic evaluation (epochs 500, 600, 700) with dropout switched off.
    if epoch > 400 and epoch % 100 == 0:
        dam.eval()
        get_hit()
        dam.train()

# Final evaluation after the last epoch.
dam.eval()
get_hit()

 63%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                               | 501/800 [16:14<3:22:42, 40.68s/it]

Recall:  0.5932782682996297 MAP:  0.3906484382417172


 75%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                          | 601/800 [21:18<2:17:06, 41.34s/it]

Recall:  0.599544289376246 MAP:  0.39711383271622563


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                     | 701/800 [26:20<1:08:34, 41.56s/it]

Recall:  0.5955568214183993 MAP:  0.3981154466913511


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 800/800 [29:06<00:00,  1.55s/it]


Recall:  0.5998291085160923 MAP:  0.40353650431975713
