In [1]:
import os
import random
import numpy as np
import torch
import pandas as pd
from MetaMF import *
from collections import defaultdict
import matplotlib.pyplot as plt
from datetime import datetime as dt

In [2]:
# One shared seed for every random number generator the notebook touches.
SEED = 1
random.seed(SEED)                 # Python's built-in RNG
np.random.seed(SEED)              # NumPy's global RNG
torch.manual_seed(SEED)           # PyTorch RNG on the CPU
torch.cuda.manual_seed(SEED)      # PyTorch RNG on the current GPU
torch.cuda.manual_seed_all(SEED)  # PyTorch RNG on all GPUs

In [3]:
# Index of the GPU this notebook prefers to run on.
PREFERRED_GPU = 6

# `use_cuda` is read as a global by the training code further down.
use_cuda = torch.cuda.is_available()
if use_cuda:
    # Selecting a device index that does not exist raises an error, so fall
    # back to the first GPU on machines with fewer devices than the preferred
    # index requires.
    gpu_index = PREFERRED_GPU if PREFERRED_GPU < torch.cuda.device_count() else 0
    torch.cuda.set_device(gpu_index)
print("CUDA available? " + str(use_cuda))
if use_cuda:
    print("Current device: %d" % torch.cuda.current_device())

CUDA available? True
Current device: 6


# Utility Functions
## Read Dataset

In [4]:
def read_dataset(path):
    """Load the train, validation and test splits stored next to ``path``.

    The files ``path + ".train.rating"``, ``path + ".valid.rating"`` and
    ``path + ".test.rating"`` are read as header-less, tab-separated tables.

    Returns
    -------
    tuple
        ``(trainset, valset, testset)``; each one is a list holding one tuple
        per file row.
    """
    def _load_split(suffix):
        # to_records(index=False) drops the pandas index so that every tuple
        # contains only the columns present in the file.
        frame = pd.read_csv(path + suffix, sep="\t", header=None)
        return frame.to_records(index=False).tolist()

    trainset = _load_split(".train.rating")
    valset = _load_split(".valid.rating")
    testset = _load_split(".test.rating")
    return trainset, valset, testset

def read_usergroups(path):
    """Load the three user-group files stored next to ``path``.

    Reads ``path + "_low.userlist"``, ``path + "_med.userlist"`` and
    ``path + "_high.userlist"`` as header-less CSV files.

    Returns
    -------
    tuple
        ``(low_users, med_users, high_users)``. For single-column files each
        element is a flat list of the values in that file.
    """
    # The ``squeeze=True`` keyword of ``pd.read_csv`` was removed in pandas 2.0;
    # ``DataFrame.squeeze("columns")`` is the documented replacement and turns
    # a one-column frame into a Series on old and new pandas versions alike.
    low_users, med_users, high_users = (
        pd.read_csv(path + suffix, header=None).squeeze("columns").values.tolist()
        for suffix in ("_low.userlist", "_med.userlist", "_high.userlist")
    )
    return low_users, med_users, high_users

def read_useranditemlist(path):
    """Load the user list and the item list stored next to ``path``.

    Reads ``path + ".userlist"`` and ``path + ".itemlist"`` as header-less CSV
    files.

    Returns
    -------
    tuple
        ``(userlist, itemlist)``. For single-column files each element is a
        flat list of the values in that file.
    """
    # The ``squeeze=True`` keyword of ``pd.read_csv`` was removed in pandas 2.0;
    # ``DataFrame.squeeze("columns")`` is the documented replacement and turns
    # a one-column frame into a Series on old and new pandas versions alike.
    userlist, itemlist = (
        pd.read_csv(path + suffix, header=None).squeeze("columns").values.tolist()
        for suffix in (".userlist", ".itemlist")
    )
    return userlist, itemlist

## Helpers for Model Training

In [5]:
def batchtoinput(batch, use_cuda):
    """Convert a batch of examples into three tensors.

    Parameters
    ----------
    batch : iterable
        Examples that are indexed as ``example[0]`` (user), ``example[1]``
        (item) and ``example[2]`` (rating).
    use_cuda : bool
        If true, the tensors are moved to the GPU with ``.cuda()``.

    Returns
    -------
    tuple
        ``(users, items, ratings)`` where users and items are ``int64``
        tensors and ratings is a ``float32`` tensor.
    """
    rows = list(batch)
    tensors = (
        torch.tensor([row[0] for row in rows], dtype=torch.int64),
        torch.tensor([row[1] for row in rows], dtype=torch.int64),
        torch.tensor([row[2] for row in rows], dtype=torch.float32),
    )
    if use_cuda:
        tensors = tuple(tensor.cuda() for tensor in tensors)
    users, items, ratings = tensors
    return users, items, ratings

def getbatches(traindata, batch_size, use_cuda, shuffle):
    """Yield the data in consecutive mini-batches, converted by ``batchtoinput``.

    Parameters
    ----------
    traindata : list
        The examples; a copy is made, so the caller's list is never reordered.
    batch_size : int
        Maximum number of examples per batch; the last batch may be smaller.
    use_cuda : bool
        Forwarded to ``batchtoinput``.
    shuffle : bool
        If true, the copy is shuffled with ``random.shuffle`` before batching.

    Yields
    ------
    The result of ``batchtoinput`` for each batch.
    """
    examples = traindata.copy()
    if shuffle:
        random.shuffle(examples)
    n_batches = int(np.ceil(len(examples) / batch_size))
    offsets = (index * batch_size for index in range(n_batches))
    for offset in offsets:
        yield batchtoinput(examples[offset:offset + batch_size], use_cuda)
        
def weights_init(m):
    """Initialise a module whose class name contains ``'Linear'``.

    Intended to be passed to ``Module.apply``. For matching modules the weight
    is drawn from a Xavier normal distribution and the bias is set to zero.
    Modules whose class name does not contain ``'Linear'`` are left untouched.

    Parameters
    ----------
    m : torch.nn.Module
        The module visited by ``apply``.
    """
    if 'Linear' not in m.__class__.__name__:
        return

    # A layer built with ``bias=False`` has ``bias`` set to None, so both
    # parameters are checked before they are touched.
    weight = getattr(m, "weight", None)
    if weight is not None:
        torch.nn.init.xavier_normal_(weight)
    bias = getattr(m, "bias", None)
    if bias is not None:
        torch.nn.init.constant_(bias, 0)
        
def get_eval(ratlist, predlist):
    """Compute the mean absolute error and the mean squared error.

    Parameters
    ----------
    ratlist, predlist : numpy.ndarray
        Ground-truth and predicted ratings; they are subtracted element-wise.

    Returns
    -------
    tuple
        ``(mae, mse)``.
    """
    residual = ratlist - predlist
    return np.mean(np.abs(residual)), np.mean(np.square(residual))

## Other Functions

In [6]:
def sampling_procedure(dataset, beta):
    """Randomly keep a fraction ``beta`` of every user's ratings.

    Parameters
    ----------
    dataset : list
        ``(user_id, item_id, rating)`` examples.
    beta : float
        Fraction of each user's ratings to keep. The number of kept ratings
        per user is ``ceil(number_of_ratings * beta)``.

    Returns
    -------
    list
        The sampled examples as tuples, grouped user by user in the order in
        which ``groupby("user_id")`` yields the users.
    """
    ratings = pd.DataFrame(dataset, columns=["user_id", "item_id", "rating"])
    sampled = []
    for _, user_ratings in ratings.groupby("user_id"):
        keep = int(np.ceil(len(user_ratings) * beta))
        sampled += user_ratings.sample(n=keep).to_records(index=False).tolist()
    return sampled

In [7]:
def _predict_all(net, data, batch_size):
    """Run ``net`` over ``data`` in unshuffled batches without tracking gradients.

    Returns three parallel lists: the user ids, the ground-truth ratings and
    the predicted ratings, in the order of ``data``. The caller is responsible
    for putting ``net`` into evaluation mode.
    """
    user_ids, groundtruth, estimation = [], [], []
    # Evaluation never needs gradients; no_grad avoids building the graph.
    with torch.no_grad():
        for users, items, ratings in getbatches(data, batch_size, use_cuda, False):
            predictions = net(users, items)
            user_ids.extend(users.tolist())
            groundtruth.extend(ratings.tolist())
            estimation.extend(predictions.tolist())
    return user_ids, groundtruth, estimation


def run(path, traindata, valdata, testdata, userlist, itemlist, low, med, high, hyperparameters, betas=None, disable_meta_learning=False, save=False):
    """Train and evaluate one MetaMF model per training-data fraction ``beta``.

    For every ``beta`` a fresh model is trained on a per-user subsample of
    ``traindata`` (see ``sampling_procedure``), validated on ``valdata`` after
    every epoch, and finally evaluated on ``testdata`` -- overall and
    separately for the ``low`` / ``med`` / ``high`` user groups.

    Relies on the notebook-level global ``use_cuda`` and on the names
    ``MetaMF``, ``nn`` and ``optim`` (presumably provided by
    ``from MetaMF import *`` -- TODO confirm).

    Parameters
    ----------
    path : str
        Output directory for the saved models, loss plots and ``results.csv``.
        Only used when ``save`` is true; it is created if it does not exist.
    traindata, valdata, testdata : list
        ``(user_id, item_id, rating)`` examples.
    userlist, itemlist : list
        Only their lengths are used, as the two arguments of ``MetaMF``.
    low, med, high : iterable
        User ids of the three user groups. A user contained in more than one
        group is counted for the first match in the order low, med, high.
    hyperparameters : dict
        Must contain ``"lr"``, ``"lambda"`` (passed to Adam as weight decay),
        ``"batch_size"`` and ``"n_epochs"``.
    betas : list of float, optional
        Training-data fractions; defaults to 1.0, 0.9, ..., 0.1.
    disable_meta_learning : bool
        If true, ``net.disable_meta_learning()`` is called before training.
    save : bool
        If true, the final model, the loss plot and the accumulated results
        are written below ``path``.

    Returns
    -------
    list of dict
        One dictionary per ``beta`` with the final train/validation MSE and
        the test MSE/MAE overall and per user group. A group without any test
        rating yields NaN metrics.
    """
    if betas is None:
        betas = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]

    if save:
        # Create the output directory up front so that a long training run
        # cannot fail at the very end because the directory is missing.
        os.makedirs(path, exist_ok=True)

    # Sets give constant-time membership tests in the per-rating loop below.
    low_users, med_users, high_users = set(low), set(med), set(high)

    results = []
    for beta in betas:
        results_dict = {"beta": beta}
        # round() guards against float truncation (0.29 * 100 == 28.999...).
        model_name = "beta_" + str(int(round(beta * 100))) + "p"
        print("==========================")
        print(model_name)
        print("==========================")
        starttime = dt.now()
        R_train_beta = sampling_procedure(traindata, beta)

        train_loss, validation_loss = [], []
        net = MetaMF(len(userlist), len(itemlist))
        if disable_meta_learning:
            net.disable_meta_learning()

        net.apply(weights_init)
        if use_cuda:
            net.cuda()

        optimizer = optim.Adam(net.parameters(), lr=hyperparameters["lr"], weight_decay=hyperparameters["lambda"])
        batch_size = hyperparameters["batch_size"]
        n_epochs = hyperparameters["n_epochs"]

        for epoch in range(n_epochs):
            net.train()
            error = 0
            num = 0
            for users, items, ratings in getbatches(R_train_beta, batch_size, use_cuda, True):
                optimizer.zero_grad()
                batch_predictions = net(users, items)

                loss = net.loss(batch_predictions, ratings)
                loss.backward()
                nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # Weight the batch loss by the batch size so that error / num
                # is the per-example average; .item() accumulates in Python
                # floats instead of float32.
                error += loss.item() * len(users)
                num += len(users)
            train_loss.append(error/num)

            net.eval()
            _, groundtruth, estimation = _predict_all(net, valdata, batch_size)
            mae, mse = get_eval(np.array(groundtruth), np.array(estimation))
            validation_loss.append(mse)

            print('Epoch {}/{} - Training Loss: {:.3f}, Validation Loss: {:.3f}, Time Elapsed: {}'.format(epoch+1, n_epochs, error/num, mse, dt.now()-starttime))

            if epoch+1 == n_epochs:
                if save:
                    torch.save(net, path + "/" + model_name + '.model')
                    print("Saved Model to " + path)

                results_dict["train_mse_all"] = error / num
                results_dict["val_mse_all"] = mse

        net.eval()
        plt.figure()
        plt.plot(range(n_epochs), train_loss, label="Train")
        plt.plot(range(n_epochs), validation_loss, label="Val")
        plt.legend()
        plt.ylabel("MSE")
        plt.xlabel("Epoch")
        plt.tight_layout()

        test_users, groundtruth, estimation = _predict_all(net, testdata, batch_size)

        # Split the test predictions by user group. The predictions paired
        # with each rating must be the ones computed on the test batches.
        group_groundtruth = defaultdict(list)
        group_estimation = defaultdict(list)
        for uid, r, p in zip(test_users, groundtruth, estimation):
            if uid in low_users:
                group = "low"
            elif uid in med_users:
                group = "med"
            elif uid in high_users:
                group = "high"
            else:
                continue
            group_groundtruth[group].append(r)
            group_estimation[group].append(p)

        test_mae, test_mse = get_eval(np.array(groundtruth), np.array(estimation))
        low_mae, low_mse = get_eval(np.array(group_groundtruth["low"]), np.array(group_estimation["low"]))
        med_mae, med_mse = get_eval(np.array(group_groundtruth["med"]), np.array(group_estimation["med"]))
        high_mae, high_mse = get_eval(np.array(group_groundtruth["high"]), np.array(group_estimation["high"]))

        results_dict["test_mse_all"] = test_mse
        results_dict["test_mae_all"] = test_mae
        results_dict["test_mse_low"] = low_mse
        results_dict["test_mae_low"] = low_mae
        results_dict["test_mse_med"] = med_mse
        results_dict["test_mae_med"] = med_mae
        results_dict["test_mse_high"] = high_mse
        results_dict["test_mae_high"] = high_mae

        results.append(results_dict)

        if save:
            # results.csv is rewritten after every beta so that partial
            # results survive an interrupted run.
            plt.savefig(path + "/" + model_name + ".png", dpi=300)
            pd.DataFrame(results).to_csv(path + "/results.csv", index=False)
            print("Saved Results to " + path)

    return results

In [None]:
# Ciao: load the splits, the user/item lists and the user groups once, then
# train first with meta learning enabled and afterwards with it disabled.
train, val, test = read_dataset("data/ciao")
users, items = read_useranditemlist("data/ciao")
low, med, high = read_usergroups("data/ciao")

for output_dir, without_meta in (("experiments/meta/ciao", False), ("experiments/nometa/ciao", True)):
    run(
        output_dir, train, val, test, users, items, low, med, high,
        hyperparameters={"lr": 0.0001, "lambda": 0.001, "batch_size": 64, "n_epochs": 100},
        disable_meta_learning=without_meta,
        save=True,
    )

beta_100p
Epoch 1/100 - Training Loss: 1.369, Validation Loss: 1.017, Time Elapsed: 0:11:25.996626
Epoch 2/100 - Training Loss: 0.796, Validation Loss: 0.993, Time Elapsed: 0:22:39.187458
Epoch 3/100 - Training Loss: 0.557, Validation Loss: 1.030, Time Elapsed: 0:33:53.176392
Epoch 4/100 - Training Loss: 0.477, Validation Loss: 1.053, Time Elapsed: 0:45:08.207205
Epoch 5/100 - Training Loss: 0.418, Validation Loss: 1.069, Time Elapsed: 0:56:23.807940
Epoch 6/100 - Training Loss: 0.395, Validation Loss: 1.083, Time Elapsed: 1:07:39.890439
Epoch 7/100 - Training Loss: 0.399, Validation Loss: 1.071, Time Elapsed: 1:18:55.645740
Epoch 8/100 - Training Loss: 0.431, Validation Loss: 1.057, Time Elapsed: 1:30:10.246548
Epoch 9/100 - Training Loss: 0.474, Validation Loss: 1.017, Time Elapsed: 1:41:23.436225
Epoch 10/100 - Training Loss: 0.514, Validation Loss: 0.987, Time Elapsed: 1:52:35.301145
Epoch 11/100 - Training Loss: 0.543, Validation Loss: 0.965, Time Elapsed: 2:03:46.001713
Epoch 12/

Epoch 92/100 - Training Loss: 0.387, Validation Loss: 1.111, Time Elapsed: 17:09:25.344999
Epoch 93/100 - Training Loss: 0.386, Validation Loss: 1.131, Time Elapsed: 17:20:36.222064
Epoch 94/100 - Training Loss: 0.387, Validation Loss: 1.110, Time Elapsed: 17:31:47.084030
Epoch 95/100 - Training Loss: 0.387, Validation Loss: 1.100, Time Elapsed: 17:42:57.995836
Epoch 96/100 - Training Loss: 0.387, Validation Loss: 1.105, Time Elapsed: 17:54:09.007871
Epoch 97/100 - Training Loss: 0.386, Validation Loss: 1.108, Time Elapsed: 18:05:20.114289
Epoch 98/100 - Training Loss: 0.386, Validation Loss: 1.111, Time Elapsed: 18:16:31.168728
Epoch 99/100 - Training Loss: 0.385, Validation Loss: 1.109, Time Elapsed: 18:27:42.278307
Epoch 100/100 - Training Loss: 0.386, Validation Loss: 1.109, Time Elapsed: 18:38:53.388336
Saved Model to experiments/meta/ciao
Saved Results to experiments/meta/ciao
beta_90p
Epoch 1/100 - Training Loss: 1.423, Validation Loss: 1.043, Time Elapsed: 0:10:36.600237
Epoch 

Epoch 82/100 - Training Loss: 0.249, Validation Loss: 1.182, Time Elapsed: 14:02:38.423460
Epoch 83/100 - Training Loss: 0.249, Validation Loss: 1.188, Time Elapsed: 14:12:53.898851
Epoch 84/100 - Training Loss: 0.248, Validation Loss: 1.210, Time Elapsed: 14:23:08.051560
Epoch 85/100 - Training Loss: 0.248, Validation Loss: 1.173, Time Elapsed: 14:33:21.326383
Epoch 86/100 - Training Loss: 0.248, Validation Loss: 1.202, Time Elapsed: 14:43:34.572610
Epoch 87/100 - Training Loss: 0.247, Validation Loss: 1.178, Time Elapsed: 14:53:47.891076
Epoch 88/100 - Training Loss: 0.248, Validation Loss: 1.188, Time Elapsed: 15:04:01.113679
Epoch 89/100 - Training Loss: 0.247, Validation Loss: 1.184, Time Elapsed: 15:14:13.808905
Epoch 90/100 - Training Loss: 0.246, Validation Loss: 1.199, Time Elapsed: 15:24:26.445062
Epoch 91/100 - Training Loss: 0.245, Validation Loss: 1.195, Time Elapsed: 15:34:39.048338
Epoch 92/100 - Training Loss: 0.246, Validation Loss: 1.211, Time Elapsed: 15:44:51.662013

Epoch 72/100 - Training Loss: 0.212, Validation Loss: 1.245, Time Elapsed: 10:55:29.013294
Epoch 73/100 - Training Loss: 0.212, Validation Loss: 1.227, Time Elapsed: 11:04:34.233409
Epoch 74/100 - Training Loss: 0.211, Validation Loss: 1.254, Time Elapsed: 11:13:39.359139
Epoch 75/100 - Training Loss: 0.212, Validation Loss: 1.249, Time Elapsed: 11:22:44.523439
Epoch 76/100 - Training Loss: 0.210, Validation Loss: 1.250, Time Elapsed: 11:31:49.637828
Epoch 77/100 - Training Loss: 0.209, Validation Loss: 1.251, Time Elapsed: 11:40:54.719257
Epoch 78/100 - Training Loss: 0.210, Validation Loss: 1.246, Time Elapsed: 11:49:59.813554
Epoch 79/100 - Training Loss: 0.209, Validation Loss: 1.243, Time Elapsed: 11:59:04.935491
Epoch 80/100 - Training Loss: 0.209, Validation Loss: 1.248, Time Elapsed: 12:08:10.151175
Epoch 81/100 - Training Loss: 0.208, Validation Loss: 1.245, Time Elapsed: 12:17:15.272757
Epoch 82/100 - Training Loss: 0.208, Validation Loss: 1.233, Time Elapsed: 12:26:20.466984

Epoch 62/100 - Training Loss: 0.173, Validation Loss: 1.250, Time Elapsed: 8:16:38.511102
Epoch 63/100 - Training Loss: 0.172, Validation Loss: 1.269, Time Elapsed: 8:24:38.032785
Epoch 64/100 - Training Loss: 0.172, Validation Loss: 1.250, Time Elapsed: 8:32:37.724807
Epoch 65/100 - Training Loss: 0.171, Validation Loss: 1.257, Time Elapsed: 8:40:37.261665
Epoch 66/100 - Training Loss: 0.170, Validation Loss: 1.256, Time Elapsed: 8:48:36.817454
Epoch 67/100 - Training Loss: 0.170, Validation Loss: 1.257, Time Elapsed: 8:56:36.433636
Epoch 68/100 - Training Loss: 0.168, Validation Loss: 1.247, Time Elapsed: 9:04:35.929987
Epoch 69/100 - Training Loss: 0.168, Validation Loss: 1.265, Time Elapsed: 9:12:35.531460
Epoch 70/100 - Training Loss: 0.167, Validation Loss: 1.266, Time Elapsed: 9:20:35.105636
Epoch 71/100 - Training Loss: 0.167, Validation Loss: 1.262, Time Elapsed: 9:28:34.626348
Epoch 72/100 - Training Loss: 0.167, Validation Loss: 1.260, Time Elapsed: 9:36:34.226772
Epoch 73/1

Epoch 52/100 - Training Loss: 0.132, Validation Loss: 1.302, Time Elapsed: 5:58:56.714101
Epoch 53/100 - Training Loss: 0.131, Validation Loss: 1.287, Time Elapsed: 6:05:49.684204
Epoch 54/100 - Training Loss: 0.130, Validation Loss: 1.306, Time Elapsed: 6:12:42.823995
Epoch 55/100 - Training Loss: 0.129, Validation Loss: 1.296, Time Elapsed: 6:19:35.796979
Epoch 56/100 - Training Loss: 0.127, Validation Loss: 1.307, Time Elapsed: 6:26:28.947634
Epoch 57/100 - Training Loss: 0.126, Validation Loss: 1.304, Time Elapsed: 6:33:21.932831
Epoch 58/100 - Training Loss: 0.126, Validation Loss: 1.308, Time Elapsed: 6:40:14.914324
Epoch 59/100 - Training Loss: 0.125, Validation Loss: 1.299, Time Elapsed: 6:47:08.284314
Epoch 60/100 - Training Loss: 0.124, Validation Loss: 1.293, Time Elapsed: 6:54:01.360395
Epoch 61/100 - Training Loss: 0.123, Validation Loss: 1.310, Time Elapsed: 7:00:54.266085
Epoch 62/100 - Training Loss: 0.122, Validation Loss: 1.309, Time Elapsed: 7:07:47.169544
Epoch 63/1

Epoch 42/100 - Training Loss: 0.125, Validation Loss: 1.239, Time Elapsed: 4:01:01.923410
Epoch 43/100 - Training Loss: 0.122, Validation Loss: 1.227, Time Elapsed: 4:06:45.823418
Epoch 44/100 - Training Loss: 0.122, Validation Loss: 1.239, Time Elapsed: 4:12:29.296867
Epoch 45/100 - Training Loss: 0.118, Validation Loss: 1.263, Time Elapsed: 4:18:13.088354
Epoch 46/100 - Training Loss: 0.117, Validation Loss: 1.209, Time Elapsed: 4:23:57.122852
Epoch 47/100 - Training Loss: 0.115, Validation Loss: 1.230, Time Elapsed: 4:29:41.081940
Epoch 48/100 - Training Loss: 0.113, Validation Loss: 1.251, Time Elapsed: 4:35:25.489774
Epoch 49/100 - Training Loss: 0.111, Validation Loss: 1.255, Time Elapsed: 4:41:09.874561
Epoch 50/100 - Training Loss: 0.110, Validation Loss: 1.280, Time Elapsed: 4:46:54.297502
Epoch 51/100 - Training Loss: 0.108, Validation Loss: 1.245, Time Elapsed: 4:52:38.775406
Epoch 52/100 - Training Loss: 0.107, Validation Loss: 1.275, Time Elapsed: 4:58:23.254629
Epoch 53/1

Epoch 32/100 - Training Loss: 0.097, Validation Loss: 1.218, Time Elapsed: 2:30:30.625188
Epoch 33/100 - Training Loss: 0.096, Validation Loss: 1.218, Time Elapsed: 2:35:12.905731
Epoch 34/100 - Training Loss: 0.096, Validation Loss: 1.231, Time Elapsed: 2:39:55.129621
Epoch 35/100 - Training Loss: 0.095, Validation Loss: 1.245, Time Elapsed: 2:44:37.398316
Epoch 36/100 - Training Loss: 0.093, Validation Loss: 1.220, Time Elapsed: 2:49:19.557451
Epoch 37/100 - Training Loss: 0.093, Validation Loss: 1.251, Time Elapsed: 2:54:01.636464
Epoch 38/100 - Training Loss: 0.091, Validation Loss: 1.248, Time Elapsed: 2:58:43.695468
Epoch 39/100 - Training Loss: 0.092, Validation Loss: 1.264, Time Elapsed: 3:03:25.703323
Epoch 40/100 - Training Loss: 0.089, Validation Loss: 1.269, Time Elapsed: 3:08:07.682171
Epoch 41/100 - Training Loss: 0.090, Validation Loss: 1.231, Time Elapsed: 3:12:49.640630
Epoch 42/100 - Training Loss: 0.088, Validation Loss: 1.250, Time Elapsed: 3:17:31.444983
Epoch 43/1

Epoch 22/100 - Training Loss: 0.071, Validation Loss: 1.163, Time Elapsed: 1:19:23.284412
Epoch 23/100 - Training Loss: 0.067, Validation Loss: 1.169, Time Elapsed: 1:22:59.993695
Epoch 24/100 - Training Loss: 0.067, Validation Loss: 1.185, Time Elapsed: 1:26:37.126236
Epoch 25/100 - Training Loss: 0.067, Validation Loss: 1.177, Time Elapsed: 1:30:14.579466
Epoch 26/100 - Training Loss: 0.066, Validation Loss: 1.193, Time Elapsed: 1:33:52.077279
Epoch 27/100 - Training Loss: 0.064, Validation Loss: 1.201, Time Elapsed: 1:37:29.583328
Epoch 28/100 - Training Loss: 0.063, Validation Loss: 1.204, Time Elapsed: 1:41:07.302227
Epoch 29/100 - Training Loss: 0.063, Validation Loss: 1.196, Time Elapsed: 1:44:45.025594
Epoch 30/100 - Training Loss: 0.063, Validation Loss: 1.201, Time Elapsed: 1:48:22.717045
Epoch 31/100 - Training Loss: 0.063, Validation Loss: 1.221, Time Elapsed: 1:52:00.342410
Epoch 32/100 - Training Loss: 0.062, Validation Loss: 1.206, Time Elapsed: 1:55:38.114722
Epoch 33/1

Epoch 12/100 - Training Loss: 0.140, Validation Loss: 1.103, Time Elapsed: 0:29:46.193229
Epoch 13/100 - Training Loss: 0.132, Validation Loss: 1.112, Time Elapsed: 0:32:12.351340
Epoch 14/100 - Training Loss: 0.127, Validation Loss: 1.108, Time Elapsed: 0:34:38.736096
Epoch 15/100 - Training Loss: 0.121, Validation Loss: 1.109, Time Elapsed: 0:37:05.102549
Epoch 16/100 - Training Loss: 0.115, Validation Loss: 1.113, Time Elapsed: 0:39:31.486933
Epoch 17/100 - Training Loss: 0.115, Validation Loss: 1.134, Time Elapsed: 0:41:57.956112
Epoch 18/100 - Training Loss: 0.103, Validation Loss: 1.128, Time Elapsed: 0:44:24.165681
Epoch 19/100 - Training Loss: 0.108, Validation Loss: 1.121, Time Elapsed: 0:46:49.908840
Epoch 20/100 - Training Loss: 0.101, Validation Loss: 1.130, Time Elapsed: 0:49:15.897266
Epoch 21/100 - Training Loss: 0.091, Validation Loss: 1.121, Time Elapsed: 0:51:41.422070
Epoch 22/100 - Training Loss: 0.095, Validation Loss: 1.125, Time Elapsed: 0:54:37.587833
Epoch 23/1

Epoch 2/100 - Training Loss: 0.975, Validation Loss: 1.419, Time Elapsed: 0:03:11.216632
Epoch 3/100 - Training Loss: 0.593, Validation Loss: 1.359, Time Elapsed: 0:04:34.330349
Epoch 4/100 - Training Loss: 0.376, Validation Loss: 1.328, Time Elapsed: 0:05:56.879167
Epoch 5/100 - Training Loss: 0.278, Validation Loss: 1.280, Time Elapsed: 0:07:19.489024
Epoch 6/100 - Training Loss: 0.197, Validation Loss: 1.250, Time Elapsed: 0:08:41.711910
Epoch 7/100 - Training Loss: 0.151, Validation Loss: 1.235, Time Elapsed: 0:10:03.607676
Epoch 8/100 - Training Loss: 0.129, Validation Loss: 1.224, Time Elapsed: 0:11:25.458032
Epoch 9/100 - Training Loss: 0.118, Validation Loss: 1.205, Time Elapsed: 0:12:47.145465
Epoch 10/100 - Training Loss: 0.115, Validation Loss: 1.183, Time Elapsed: 0:14:08.868709
Epoch 11/100 - Training Loss: 0.110, Validation Loss: 1.170, Time Elapsed: 0:15:30.673906
Epoch 12/100 - Training Loss: 0.106, Validation Loss: 1.171, Time Elapsed: 0:16:52.443317
Epoch 13/100 - Tra

In [None]:
# Jester: load the splits, the user/item lists and the user groups once, then
# train first with meta learning enabled and afterwards with it disabled.
train, val, test = read_dataset("data/jester")
users, items = read_useranditemlist("data/jester")
low, med, high = read_usergroups("data/jester")

for output_dir, without_meta in (("experiments/meta/jester", False), ("experiments/nometa/jester", True)):
    run(
        output_dir, train, val, test, users, items, low, med, high,
        hyperparameters={"lr": 0.0001, "lambda": 0.001, "batch_size": 64, "n_epochs": 100},
        disable_meta_learning=without_meta,
        save=True,
    )