# Graph Neural Networks for Social Recommendations

In [1]:
import torch
import torch.nn as nn
import random
import pickle
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

## Create torch dataset and preprocessing functions

In [2]:
class GraphDataset(Dataset):
    def __init__(self, data, u_items_list, u_user_list, u_users_items_list, i_users_list):
        self.data = data
        self.u_items_list = u_items_list
        self.u_users_list = u_user_list
        self.u_users_items_list = u_users_items_list
        self.i_users_list = i_users_list
    
    def __getitem__(self, index):
        uid = self.data[index][0]
        iid = self.data[index][1]
        label = self.data[index][2]
        u_items = self.u_items_list[uid]
        u_users = self.u_users_list[uid]
        u_users_items = self.u_users_items_list[uid]
        i_users = self.i_users_list[iid]

        return (uid, iid, label), u_items, u_users, u_users_items, i_users

    def __len__(self):
        return len(self.data)


In [3]:
truncate_len = 45

def collate_fn(batch_data):

    uids, iids, labels = [], [], []
    u_items, u_users, u_users_items, i_users = [], [], [], []
    u_items_len, u_users_len, i_users_len = [], [], []

    for data, u_items_u, u_users_u, u_users_items_u, i_users_i in batch_data:

        (uid, iid, label) = data
        uids.append(uid)
        iids.append(iid)
        labels.append(label)

        # user-items
        if len(u_items_u) <= truncate_len:
            u_items.append(u_items_u)
        else:
            u_items.append(random.sample(u_items_u, truncate_len))
        u_items_len.append(min(len(u_items_u), truncate_len))
        
        # user-users and user-users-items
        if len(u_users_u) <= truncate_len:
            u_users.append(u_users_u)
            u_u_items = [] 
            for uui in u_users_items_u:
                if len(uui) < truncate_len:
                    u_u_items.append(uui)
                else:
                    u_u_items.append(random.sample(uui, truncate_len))
            u_users_items.append(u_u_items)
        else:
            sample_index = random.sample(list(range(len(u_users_u))), truncate_len)
            u_users.append([u_users_u[si] for si in sample_index])

            u_users_items_u_tr = [u_users_items_u[si] for si in sample_index]
            u_u_items = [] 
            for uui in u_users_items_u_tr:
                if len(uui) < truncate_len:
                    u_u_items.append(uui)
                else:
                    u_u_items.append(random.sample(uui, truncate_len))
            u_users_items.append(u_u_items)

        u_users_len.append(min(len(u_users_u), truncate_len))	

        # item-users
        if len(i_users_i) <= truncate_len:
            i_users.append(i_users_i)
        else:
            i_users.append(random.sample(i_users_i, truncate_len))
        i_users_len.append(min(len(i_users_i), truncate_len))

    batch_size = len(batch_data)

    # padding
    u_items_maxlen = max(u_items_len)
    u_users_maxlen = max(u_users_len)
    i_users_maxlen = max(i_users_len)
    
    u_item_pad = torch.zeros([batch_size, u_items_maxlen, 2], dtype=torch.long)
    for i, ui in enumerate(u_items):
        u_item_pad[i, :len(ui), :] = torch.LongTensor(ui)
    
    u_user_pad = torch.zeros([batch_size, u_users_maxlen], dtype=torch.long)
    for i, uu in enumerate(u_users):
        u_user_pad[i, :len(uu)] = torch.LongTensor(uu)
    
    u_user_item_pad = torch.zeros([batch_size, u_users_maxlen, u_items_maxlen, 2], dtype=torch.long)
    for i, uu_items in enumerate(u_users_items):
        for j, ui in enumerate(uu_items):
            u_user_item_pad[i, j, :len(ui), :] = torch.LongTensor(ui)

    i_user_pad = torch.zeros([batch_size, i_users_maxlen, 2], dtype=torch.long)
    for i, iu in enumerate(i_users):
        i_user_pad[i, :len(iu), :] = torch.LongTensor(iu)

    uids = torch.LongTensor(uids)
    iids = torch.LongTensor(iids)
    labels = torch.FloatTensor(labels)

    return uids, iids, labels, u_item_pad, u_user_pad, u_user_item_pad, i_user_pad

## Create model classes

In [4]:
class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, input_dim//2, bias=True),
            nn.ReLU(),
            nn.Linear(input_dim//2, output_dim, bias=True)
        )

    def forward(self, x):
        return self.mlp(x)

class Aggregator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Aggregator, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, output_dim, bias=True),
            nn.ReLU()
        )

    def forward(self, x):
        return self.mlp(x)


class UserModel(nn.Module):
    def __init__(self, emb_dim, user_emb, item_emb, rating_emb):
        super(UserModel, self).__init__()
        self.emb_dim = emb_dim
        self.user_emb = user_emb
        self.item_emb = item_emb
        self.rating_emb = rating_emb

        self.g_v = MLP(2*self.emb_dim, self.emb_dim)
        
        self.user_item_attn = MLP(2*self.emb_dim, 1)
        self.aggr_items = Aggregator(self.emb_dim, self.emb_dim)

        self.user_user_attn = MLP(2*self.emb_dim, 1)
        self.aggr_neighbors = Aggregator(self.emb_dim, self.emb_dim)

        self.mlp = nn.Sequential(
            nn.Linear(2*self.emb_dim, self.emb_dim, bias = True),
            nn.ReLU(),
            nn.Linear(self.emb_dim, self.emb_dim, bias = True),
            nn.ReLU(),
            nn.Linear(self.emb_dim, self.emb_dim, bias = True),
            nn.ReLU()
        )

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.eps = 1e-10

    def forward(self, uids, u_item_pad, u_user_pad, u_user_item_pad):

        q_a = self.item_emb(u_item_pad[:,:,0])
        u_item_er = self.rating_emb(u_item_pad[:,:,1])
        x_ia = self.g_v(torch.cat([q_a, u_item_er], dim=2).view(-1, 2*self.emb_dim)).view(q_a.size())
        mask_u = torch.where(u_item_pad[:,:,0]>0, torch.tensor([1.], device=self.device), torch.tensor([0.], device=self.device))
        p_i = mask_u.unsqueeze(2).expand_as(x_ia) * self.user_emb(uids).unsqueeze(1).expand_as(x_ia)
        alpha = self.user_item_attn(torch.cat([x_ia, p_i], dim=2).view(-1, 2*self.emb_dim)).view(mask_u.size())
        alpha = torch.exp(alpha)*mask_u
        alpha = alpha / (torch.sum(alpha, 1).unsqueeze(1).expand_as(alpha) + self.eps)
        h_iI = self.aggr_items(torch.sum(alpha.unsqueeze(2).expand_as(x_ia) * x_ia, 1))


        q_a_s = self.item_emb(u_user_item_pad[:,:,:,0])
        u_user_item_er = self.rating_emb(u_user_item_pad[:,:,:,1])
        x_ia_s = self.g_v(torch.cat([q_a_s, u_user_item_er], dim=2).view(-1, 2*self.emb_dim)).view(q_a_s.size())
        mask_s = torch.where(u_user_item_pad[:,:,:,0]>0, torch.tensor([1.], device=self.device), torch.tensor([0.], device=self.device))
        p_i_s = mask_s.unsqueeze(3).expand_as(x_ia_s) * self.user_emb(u_user_pad).unsqueeze(2).expand_as(x_ia_s)
        alpha_s = self.user_item_attn(torch.cat([x_ia_s, p_i_s], dim=3).view(-1, 2*self.emb_dim)).view(mask_s.size())
        alpha_s = torch.exp(alpha_s)*mask_s
        alpha_s = alpha_s / (torch.sum(alpha_s, 2).unsqueeze(2).expand_as(alpha_s) + self.eps)
        h_oI_temp = torch.sum(alpha_s.unsqueeze(3).expand_as(x_ia_s) * x_ia_s, 2)
        h_oI = self.aggr_items(h_oI_temp.view(-1, self.emb_dim)).view(h_oI_temp.size())
        
        beta = self.user_user_attn(torch.cat([h_oI, self.user_emb(u_user_pad)], dim = 2).view(-1, 2 * self.emb_dim)).view(u_user_pad.size())
        mask_su = torch.where(u_user_pad > 0, torch.tensor([1.], device=self.device), torch.tensor([0.], device=self.device))
        beta = torch.exp(beta) * mask_su
        beta = beta / (torch.sum(beta, 1).unsqueeze(1).expand_as(beta) + self.eps)
        h_iS = self.aggr_neighbors(torch.sum(beta.unsqueeze(2).expand_as(h_oI) * h_oI, 1))

        h_i = self.mlp(torch.cat([h_iI, h_iS], dim = 1))

        return h_i


class ItemModel(nn.Module):
    def __init__(self, emb_dim, user_emb, item_emb, rating_emb):
        super(ItemModel, self).__init__()
        self.emb_dim = emb_dim
        self.user_emb = user_emb
        self.item_emb = item_emb
        self.rating_emb = rating_emb

        self.g_u = MLP(2*self.emb_dim, self.emb_dim)

        self.item_users_attn = MLP(2*self.emb_dim, 1)
        self.aggr_users = Aggregator(self.emb_dim, self.emb_dim)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.eps = 1e-10
    
    def forward(self, iids, i_user_pad):

        p_t = self.user_emb(i_user_pad[:,:,0])
        i_user_er = self.rating_emb(i_user_pad[:,:,1])
        mask_i = torch.where(i_user_pad[:,:,0] > 0, torch.tensor([1.], device=self.device), torch.tensor([0.], device=self.device))
        f_jt = self.g_u(torch.cat([p_t, i_user_er], dim = 2).view(-1, 2 * self.emb_dim)).view(p_t.size())
        q_j = mask_i.unsqueeze(2).expand_as(f_jt) * self.item_emb(iids).unsqueeze(1).expand_as(f_jt)
        mu_jt = self.item_users_attn(torch.cat([f_jt, q_j], dim = 2).view(-1, 2 * self.emb_dim)).view(mask_i.size())
        mu_jt = torch.exp(mu_jt) * mask_i
        mu_jt = mu_jt / (torch.sum(mu_jt, 1).unsqueeze(1).expand_as(mu_jt) + self.eps)
        
        z_j = self.aggr_users(torch.sum(mu_jt.unsqueeze(2).expand_as(f_jt) * f_jt, 1))

        return z_j
        
    
class GraphRec(nn.Module):
    def __init__(self, n_users, n_items, n_ratings, emb_dim = 64):
        super(GraphRec, self).__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.n_ratings = n_ratings
        self.emb_dim = emb_dim

        self.user_emb = nn.Embedding(self.n_users, self.emb_dim, padding_idx=0)
        self.item_emb = nn.Embedding(self.n_items, self.emb_dim, padding_idx=0)
        self.rating_emb = nn.Embedding(self.n_ratings, self.emb_dim, padding_idx=0)

        self.user_model = UserModel(self.emb_dim, self.user_emb, self.item_emb, self.rating_emb)
        self.item_model = ItemModel(self.emb_dim, self.user_emb, self.item_emb, self.rating_emb)

        self.mlp = nn.Sequential(
            nn.Linear(2*self.emb_dim, self.emb_dim, bias=True),
            nn.ReLU(),
            nn.Linear(self.emb_dim, self.emb_dim, bias=True),
            nn.ReLU(),
            nn.Linear(self.emb_dim, 1)
        )

    def forward(self, uids, iids, u_item_pad, u_user_pad, u_user_item_pad, i_user_pad):

        h_i = self.user_model(uids, u_item_pad, u_user_pad, u_user_item_pad)
        z_j = self.item_model(iids, i_user_pad)

        r_ij = self.mlp(torch.cat([h_i, z_j], dim=1))

        return r_ij

## Set up hyper-parameters

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device - ' + str(device))
batch_size = 128
embed_dim = 64
learning_rate = 0.001
n_epochs = 30

device - cuda


## Read dataset and preprocess it to form batches

In [6]:
with open('data/dataset_epinions.pkl', 'rb') as f:
    train_set = pickle.load(f)
    valid_set = pickle.load(f)
    test_set = pickle.load(f)

with open('data/list_epinions.pkl', 'rb') as f:
    u_items_list = pickle.load(f)
    u_users_list = pickle.load(f)
    u_users_items_list = pickle.load(f)
    i_users_list = pickle.load(f)
    (user_count, item_count, rate_count) = pickle.load(f)

In [7]:
train_data = GraphDataset(train_set, u_items_list, u_users_list, u_users_items_list, i_users_list)
valid_data = GraphDataset(valid_set, u_items_list, u_users_list, u_users_items_list, i_users_list)
test_data = GraphDataset(test_set, u_items_list, u_users_list, u_users_items_list, i_users_list)

In [None]:
for i in train_data:
    for j in i:
        print(j)
    break

In [9]:
train_loader = DataLoader(train_data, batch_size = batch_size, shuffle = True, collate_fn = collate_fn)
valid_loader = DataLoader(valid_data, batch_size = batch_size, shuffle = False, collate_fn = collate_fn)
test_loader = DataLoader(test_data, batch_size = batch_size, shuffle = False, collate_fn = collate_fn)

In [10]:
len(train_loader)

5704

In [None]:
for i in train_loader:
    for j in i:
        print(j)
    break

## Create the model and set up training process

In [12]:
model = GraphRec(user_count+1, item_count+1, rate_count+1, embed_dim).to(device)

In [13]:
optimizer = torch.optim.RMSprop(model.parameters(), learning_rate)
criterion = nn.MSELoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 4, gamma = 0.1)

## Train the model

In [14]:
for epoch in range(n_epochs):

    # Training step
    model.train()
    s_loss = 0
    for i, (uids, iids, labels, u_items, u_users, u_users_items, i_users) in tqdm(enumerate(train_loader), total=len(train_loader)):
        uids = uids.to(device)
        iids = iids.to(device)
        labels = labels.to(device)
        u_items = u_items.to(device)
        u_users = u_users.to(device)
        u_users_items = u_users_items.to(device)
        i_users = i_users.to(device)
        
        optimizer.zero_grad()
        outputs = model(uids, iids, u_items, u_users, u_users_items, i_users)
        loss = criterion(outputs, labels.unsqueeze(1))

        loss.backward()
        optimizer.step()

        loss_val = loss.item()
        s_loss += loss_val

        iter_num = epoch * len(train_loader) + i + 1

    # Validate step
    model.eval()
    errors = []
    with torch.no_grad():
        for uids, iids, labels, u_items, u_users, u_users_items, i_users in tqdm(valid_loader):
            uids = uids.to(device)
            iids = iids.to(device)
            labels = labels.to(device)
            u_items = u_items.to(device)
            u_users = u_users.to(device)
            u_users_items = u_users_items.to(device)
            i_users = i_users.to(device)
            preds = model(uids, iids, u_items, u_users, u_users_items, i_users)
            error = torch.abs(preds.squeeze(1) - labels)
            errors.extend(error.data.cpu().numpy().tolist())
    
    mae = np.mean(errors)
    rmse = np.sqrt(np.mean(np.power(errors, 2)))

    scheduler.step()

    ckpt_dict = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }

    torch.save(ckpt_dict, 'trained models epinions/latest_checkpoint.pth')

    if epoch == 0:
        best_mae = mae
    elif mae < best_mae:
        best_mae = mae
        torch.save(ckpt_dict, 'trained models epinions/best_checkpoint_{}.pth'.format(embed_dim))

    print('Epoch {} validation: MAE: {:.4f}, RMSE: {:.4f}, Best MAE: {:.4f}'.format(epoch+1, mae, rmse, best_mae))

100%|██████████| 2852/2852 [08:30<00:00,  5.58it/s]
100%|██████████| 357/357 [00:48<00:00,  7.32it/s]


Epoch 1 validation: MAE: 2.1474, RMSE: 4.8366, Best MAE: 2.1474


100%|██████████| 2852/2852 [08:36<00:00,  5.53it/s]
100%|██████████| 357/357 [00:42<00:00,  8.30it/s]


Epoch 2 validation: MAE: 1.7126, RMSE: 4.5229, Best MAE: 1.7126


100%|██████████| 2852/2852 [08:01<00:00,  5.92it/s]
100%|██████████| 357/357 [00:45<00:00,  7.93it/s]


Epoch 3 validation: MAE: 1.7524, RMSE: 4.3608, Best MAE: 1.7126


100%|██████████| 2852/2852 [07:57<00:00,  5.97it/s]
100%|██████████| 357/357 [00:43<00:00,  8.16it/s]


Epoch 4 validation: MAE: 1.4650, RMSE: 4.1344, Best MAE: 1.4650


100%|██████████| 2852/2852 [08:14<00:00,  5.76it/s]
100%|██████████| 357/357 [00:41<00:00,  8.50it/s]


Epoch 5 validation: MAE: 1.4697, RMSE: 4.1269, Best MAE: 1.4650


100%|██████████| 2852/2852 [07:32<00:00,  6.30it/s]
100%|██████████| 357/357 [00:46<00:00,  7.70it/s]


Epoch 6 validation: MAE: 1.4452, RMSE: 4.1138, Best MAE: 1.4452


100%|██████████| 2852/2852 [07:42<00:00,  6.17it/s]
100%|██████████| 357/357 [00:40<00:00,  8.74it/s]


Epoch 7 validation: MAE: 1.4397, RMSE: 4.1021, Best MAE: 1.4397


100%|██████████| 2852/2852 [07:23<00:00,  6.43it/s]
100%|██████████| 357/357 [00:40<00:00,  8.86it/s]


Epoch 8 validation: MAE: 1.4319, RMSE: 4.0940, Best MAE: 1.4319


100%|██████████| 2852/2852 [07:23<00:00,  6.43it/s]
100%|██████████| 357/357 [00:39<00:00,  8.94it/s]


Epoch 9 validation: MAE: 1.4318, RMSE: 4.0959, Best MAE: 1.4318


100%|██████████| 2852/2852 [07:25<00:00,  6.40it/s]
100%|██████████| 357/357 [00:43<00:00,  8.15it/s]


Epoch 10 validation: MAE: 1.4270, RMSE: 4.0926, Best MAE: 1.4270


100%|██████████| 2852/2852 [07:46<00:00,  6.12it/s]
100%|██████████| 357/357 [00:40<00:00,  8.81it/s]


Epoch 11 validation: MAE: 1.4257, RMSE: 4.0904, Best MAE: 1.4257


100%|██████████| 2852/2852 [07:30<00:00,  6.33it/s]
100%|██████████| 357/357 [00:41<00:00,  8.59it/s]


Epoch 12 validation: MAE: 1.4278, RMSE: 4.0885, Best MAE: 1.4257


100%|██████████| 2852/2852 [08:01<00:00,  5.93it/s]
100%|██████████| 357/357 [00:46<00:00,  7.62it/s]


Epoch 13 validation: MAE: 1.4245, RMSE: 4.0888, Best MAE: 1.4245


100%|██████████| 2852/2852 [07:26<00:00,  6.38it/s]
100%|██████████| 357/357 [00:40<00:00,  8.87it/s]


Epoch 14 validation: MAE: 1.4244, RMSE: 4.0889, Best MAE: 1.4244


100%|██████████| 2852/2852 [07:45<00:00,  6.12it/s]
100%|██████████| 357/357 [00:40<00:00,  8.84it/s]


Epoch 15 validation: MAE: 1.4246, RMSE: 4.0888, Best MAE: 1.4244


100%|██████████| 2852/2852 [07:40<00:00,  6.19it/s]
100%|██████████| 357/357 [00:43<00:00,  8.25it/s]


Epoch 16 validation: MAE: 1.4247, RMSE: 4.0884, Best MAE: 1.4244


100%|██████████| 2852/2852 [07:31<00:00,  6.32it/s]
100%|██████████| 357/357 [00:40<00:00,  8.83it/s]


Epoch 17 validation: MAE: 1.4244, RMSE: 4.0888, Best MAE: 1.4244


100%|██████████| 2852/2852 [07:26<00:00,  6.38it/s]
100%|██████████| 357/357 [00:42<00:00,  8.34it/s]


Epoch 18 validation: MAE: 1.4246, RMSE: 4.0889, Best MAE: 1.4244


100%|██████████| 2852/2852 [07:24<00:00,  6.41it/s]
100%|██████████| 357/357 [00:40<00:00,  8.86it/s]


Epoch 19 validation: MAE: 1.4244, RMSE: 4.0887, Best MAE: 1.4244


100%|██████████| 2852/2852 [07:25<00:00,  6.40it/s]
100%|██████████| 357/357 [00:40<00:00,  8.81it/s]


Epoch 20 validation: MAE: 1.4247, RMSE: 4.0892, Best MAE: 1.4244


100%|██████████| 2852/2852 [07:32<00:00,  6.30it/s]
100%|██████████| 357/357 [00:44<00:00,  7.97it/s]


Epoch 21 validation: MAE: 1.4244, RMSE: 4.0890, Best MAE: 1.4244


100%|██████████| 2852/2852 [08:25<00:00,  5.64it/s]
100%|██████████| 357/357 [00:50<00:00,  7.09it/s]


Epoch 22 validation: MAE: 1.4242, RMSE: 4.0888, Best MAE: 1.4242


100%|██████████| 2852/2852 [09:02<00:00,  5.25it/s]
100%|██████████| 357/357 [00:50<00:00,  7.08it/s]


Epoch 23 validation: MAE: 1.4243, RMSE: 4.0888, Best MAE: 1.4242


100%|██████████| 2852/2852 [09:06<00:00,  5.22it/s]
100%|██████████| 357/357 [00:50<00:00,  7.08it/s]


Epoch 24 validation: MAE: 1.4243, RMSE: 4.0888, Best MAE: 1.4242


100%|██████████| 2852/2852 [09:08<00:00,  5.20it/s]
100%|██████████| 357/357 [00:51<00:00,  6.94it/s]


Epoch 25 validation: MAE: 1.4243, RMSE: 4.0886, Best MAE: 1.4242


100%|██████████| 2852/2852 [09:05<00:00,  5.23it/s]
100%|██████████| 357/357 [00:50<00:00,  7.04it/s]


Epoch 26 validation: MAE: 1.4245, RMSE: 4.0890, Best MAE: 1.4242


100%|██████████| 2852/2852 [09:01<00:00,  5.27it/s]
100%|██████████| 357/357 [00:50<00:00,  7.11it/s]


Epoch 27 validation: MAE: 1.4247, RMSE: 4.0889, Best MAE: 1.4242


100%|██████████| 2852/2852 [09:02<00:00,  5.26it/s]
100%|██████████| 357/357 [00:50<00:00,  7.14it/s]


Epoch 28 validation: MAE: 1.4244, RMSE: 4.0889, Best MAE: 1.4242


100%|██████████| 2852/2852 [09:04<00:00,  5.23it/s]
100%|██████████| 357/357 [00:50<00:00,  7.05it/s]


Epoch 29 validation: MAE: 1.4244, RMSE: 4.0888, Best MAE: 1.4242


100%|██████████| 2852/2852 [09:02<00:00,  5.25it/s]
100%|██████████| 357/357 [00:50<00:00,  7.13it/s]


Epoch 30 validation: MAE: 1.4244, RMSE: 4.0887, Best MAE: 1.4242


## Test the model

In [14]:
embed_dim = 64
checkpoint = torch.load('trained models epinions/best_checkpoint_{}.pth'.format(embed_dim))
model = GraphRec(user_count+1, item_count+1, rate_count+1, embed_dim).to(device)
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [15]:
model.eval()
test_errors = []
with torch.no_grad():
    for uids, iids, labels, u_items, u_users, u_users_items, i_users in tqdm(test_loader):
        uids = uids.to(device)
        iids = iids.to(device)
        labels = labels.to(device)
        u_items = u_items.to(device)
        u_users = u_users.to(device)
        u_users_items = u_users_items.to(device)
        i_users = i_users.to(device)
        preds = model(uids, iids, u_items, u_users, u_users_items, i_users)
        error = torch.abs(preds.squeeze(1) - labels)
        test_errors.extend(error.data.cpu().numpy().tolist())

test_mae = np.mean(test_errors)
test_rmse = np.sqrt(np.mean(np.power(test_errors, 2)))
print('Test: MAE: {:.4f}, RMSE: {:.4f}'.format(test_mae, test_rmse))

100%|██████████| 713/713 [04:41<00:00,  2.53it/s]

Test: MAE: 1.2101, RMSE: 3.0421



