In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to the local
# project modules imported by this notebook are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
import argparse
import random
from model import KGCN
from data_loader import DataLoader
import torch
import torch.optim as optim
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score

In [3]:
# Hyper-parameters are declared through argparse so the notebook shares its
# configuration surface with a command-line script. Each row below is
# (flag, type, default, help text).
_OPTION_TABLE = (
    ('--dataset', str, 'movie', 'which dataset to use'),
    ('--aggregator', str, 'sum', 'which aggregator to use'),
    ('--n_epochs', int, 5, 'the number of epochs'),
    ('--neighbor_sample_size', int, 4, 'the number of neighbors to be sampled'),
    ('--dim', int, 32, 'dimension of user and entity embeddings'),
    ('--n_iter', int, 2, 'number of iterations when computing entity representation'),
    ('--batch_size', int, 8192, 'batch size'),
    ('--l2_weight', float, 1e-4, 'weight of l2 regularization'),
    ('--lr', float, 2e-2, 'learning rate'),
    ('--ratio', float, 0.8, 'size of training dataset'),
)

parser = argparse.ArgumentParser()
for _flag, _type, _default, _help in _OPTION_TABLE:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

# An explicit argv list is passed so argparse does not try to parse the
# Jupyter kernel's own command line.
args = parser.parse_args(['--l2_weight', '1e-4'])

In [4]:
# Load the knowledge graph and the interaction dataframe for the dataset
# selected in `args` (both come from the project-local DataLoader).
data_loader = DataLoader(args.dataset)
kg = data_loader.load_kg()
df_dataset = data_loader.load_data()

# Work on a random half of the interactions to keep the run time manageable.
# The subsample is seeded so that "Restart Kernel -> Run All" always trains and
# evaluates on the same rows; without a seed every run used a different half.
SAMPLE_SEED = 42
df_dataset = df_dataset.sample(frac=0.5, replace=False, random_state=SAMPLE_SEED)
df_dataset.head()

Construct knowledge graph ... Done
Build dataset dataframe ... Done


Unnamed: 0,index,userID,itemID,label
1923284,4160929,2391,3,1
11293178,3461789,902,5,1
371596,14777886,1138,4,1
4194371,11810670,18638,4,1
180133,8506288,11673,4,1


In [5]:
class KGCNDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(user_id, item_id, label)`` for one dataframe row.

    The ``userID``, ``itemID`` and ``label`` columns are extracted into NumPy
    arrays once, at construction time. Looking a sample up is then a plain
    array index instead of building a pandas row Series three times per item.
    Reading each column on its own also keeps the ids in their column dtype:
    a row Series is upcast to one common dtype, so a single float column in
    the frame would otherwise turn the integer ids into floats.

    Args:
        df: dataframe with at least ``userID``, ``itemID`` and ``label`` columns.
            Samples are addressed by position (0 .. len(df) - 1), not by index label.

    Raises:
        KeyError: if one of the three required columns is missing.

    Note:
        The column arrays are captured in ``__init__``; changes made to ``df``
        afterwards are not guaranteed to be reflected in the samples.
    """

    def __init__(self, df):
        # Kept for backward compatibility with code that reads `dataset.df`.
        self.df = df
        self._user_ids = df['userID'].to_numpy()
        self._item_ids = df['itemID'].to_numpy()
        self._labels = df['label'].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Wrap in np.array so each field is a 0-d array, exactly as before.
        user_id = np.array(self._user_ids[idx])
        movie_id = np.array(self._item_ids[idx])
        # float32 labels match the float32 predictions fed to the loss.
        label = np.array(self._labels[idx], dtype=np.float32)
        return user_id, movie_id, label

In [6]:
# Split the interactions into train / test partitions; `args.ratio` is the
# fraction kept for training. The split is seeded so the partition membership
# is the same on every run of the notebook.
SPLIT_SEED = 42
x_train, x_test, y_train, y_test = train_test_split(
    df_dataset,
    df_dataset['label'],
    test_size=1 - args.ratio,
    random_state=SPLIT_SEED,
)
train_dataset = KGCNDataset(x_train)
test_dataset = KGCNDataset(x_test)

# Reshuffle the training samples at every epoch so that successive epochs do
# not replay the exact same sequence of mini-batches. The test loader keeps a
# fixed order: evaluation does not depend on it.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size)

In [7]:
# Re-import of the project-local model class (already imported at the top of the notebook).
from model import KGCN

# Vocabulary sizes and the fitted encoders come from the project-local DataLoader.
num_user, num_entity, num_relation = data_loader.get_num()
user_encoder, entity_encoder, relation_encoder = data_loader.get_encoders()

net = KGCN(num_user, num_entity, num_relation, kg, args)

# BCELoss takes probabilities, so `net` is presumably expected to emit values
# in [0, 1] -- TODO confirm against model.py.
criterion = torch.nn.BCELoss()

# Apply the `--l2_weight` hyper-parameter as L2 regularisation through Adam's
# weight decay. It was parsed but not consumed by any cell of this notebook.
# NOTE(review): if KGCN already adds its own L2 penalty from `args`, drop
# `weight_decay` here to avoid regularising twice.
optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.l2_weight)

In [None]:
# Train for `args.n_epochs` epochs. Every `pn` mini-batches the mean training
# loss over those batches is logged and the model is evaluated on the whole
# test set (mean per-batch BCE loss and mean per-batch ROC AUC).
pn = 50  # logging / evaluation interval, in training mini-batches
loss_list = []        # mean train loss per logging interval
test_loss_list = []   # mean test loss at each evaluation
auc_score_list = []   # mean test ROC AUC at each evaluation
for epoch in range(args.n_epochs):
    net.train()
    running_loss = 0.0
    for i, (user_id, movie_id, label) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(user_id, movie_id)
        loss = criterion(outputs, label)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        if i % pn == pn - 1:    # log and evaluate every `pn` mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch+1, i + 1, running_loss / pn))
            loss_list.append(running_loss / pn)
            running_loss = 0.0

            # Evaluation mode; this only changes behaviour if the model has
            # layers such as dropout or batch-norm.
            net.eval()
            with torch.no_grad():
                test_loss = 0.0
                total_roc = 0.0
                n_test_batches = 0
                for (uid, mid, test_label) in test_loader:
                    out = net(uid, mid)
                    test_loss += criterion(out, test_label).item()
                    total_roc += roc_auc_score(test_label.detach().numpy(), out.detach().numpy())
                    n_test_batches += 1
            net.train()

            # Average over the number of *test* batches. Dividing by `pn`
            # (the training logging interval) scaled both metrics by an
            # unrelated constant. max(..., 1) guards an empty test loader.
            n_test_batches = max(n_test_batches, 1)
            print('test_loss: ', test_loss / n_test_batches)
            test_loss_list.append(test_loss / n_test_batches)
            auc_score_list.append(total_roc / n_test_batches)

In [None]:
# Plot the three metric histories recorded by the training loop on one axes.
# Each curve is labelled so they can be told apart, and the figure carries a
# title and axis labels so it is readable on its own.
fig, ax = plt.subplots()
ax.plot(loss_list, label='train loss')
ax.plot(test_loss_list, label='test loss')
ax.plot(auc_score_list, label='test ROC AUC')
ax.set_xlabel('evaluation step')
ax.set_ylabel('metric value')
ax.set_title('KGCN training progress')
ax.legend()
plt.show()