In [None]:
import pickle
import pandas as pd
import numpy as np
from sklearn import model_selection, metrics, preprocessing
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

In [3]:
class Model(torch.nn.Module):
    """Matrix-factorisation recommender.

    Scores a (user, item) pair as the dot product of the user's and the
    item's learned embedding vectors.

    Args:
        n_users: number of rows in the user embedding table.
        n_items: number of rows in the item (movie) embedding table.
        n_factors: embedding dimension shared by users and items.

    NOTE(review): the attribute names ``user_factors`` / ``movie_factors``
    determine the parameter names (``user_factors.weight``,
    ``movie_factors.weight``) and must match any previously pickled
    instance, so keep them stable.
    """

    def __init__(self, n_users, n_items, n_factors=20):
        super().__init__()
        # One learned vector of length n_factors per user and per item.
        self.user_factors = torch.nn.Embedding(n_users, n_factors)
        self.movie_factors = torch.nn.Embedding(n_items, n_factors)
        # Start from small non-negative values drawn uniformly from [0, 0.05).
        # torch.nn.init fills the weights under no_grad, which is the supported
        # replacement for mutating `.weight.data` directly.
        torch.nn.init.uniform_(self.user_factors.weight, 0, 0.05)
        torch.nn.init.uniform_(self.movie_factors.weight, 0, 0.05)

    def forward(self, data):
        """Predict one score per (user, item) index pair.

        Args:
            data: integer index tensor; column 0 holds user indices and
                column 1 holds item indices (presumably shape (batch, 2) —
                TODO confirm against the training code).

        Returns:
            Tensor with one score per row of ``data``.
        """
        users = self.user_factors(data[:, 0])
        items = self.movie_factors(data[:, 1])
        # Element-wise product summed over the factor axis, i.e. a row-wise
        # dot product between each user vector and its paired item vector.
        return (users * items).sum(1)

In [4]:
# Restore the trained recommender from disk.
# SECURITY: pickle.load can execute arbitrary code embedded in the file, so only
# load recSys.pkl from a trusted source. The pickle presumably references this
# notebook's `Model` class, so that class must be defined before this cell runs.
with open('recSys.pkl', 'rb') as f:
    model = pickle.load(f)

# The file handle is closed at this point; inspect the restored trainable parameters.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)

user_factors.weight tensor([[ 1.2879e+00,  1.2351e+00,  1.4475e+00,  ...,  8.7125e-01,
          1.2981e+00,  1.4134e+00],
        [-4.7784e-01,  1.9576e+00,  8.5370e-01,  ...,  1.1529e+00,
          1.2158e+00,  1.2119e+00],
        [ 1.6130e+00, -4.9406e-01,  1.2355e+00,  ...,  4.9674e-01,
          5.0186e-01, -5.5076e-01],
        ...,
        [ 1.6056e+00, -5.8401e-04, -6.1805e-01,  ...,  1.2846e+00,
          1.2682e+00,  7.6898e-01],
        [ 9.3464e-01,  1.5208e+00,  3.5752e-01,  ...,  1.1465e+00,
          6.3008e-01,  1.1093e+00],
        [ 9.2243e-01,  5.6559e-01,  1.7987e+00,  ...,  1.8294e+00,
          4.6336e-01,  1.0299e+00]])
movie_factors.weight tensor([[ 0.2463,  0.5204,  0.7100,  ...,  0.6783,  0.3932,  0.7677],
        [ 0.1969,  0.6254,  0.5293,  ...,  0.2248,  0.6878,  0.3593],
        [ 0.1623,  0.5284,  0.5087,  ...,  0.1088, -0.0968,  1.0053],
        ...,
        [ 0.4141,  0.4185,  0.3778,  ...,  0.3825,  0.4037,  0.4099],
        [ 0.4147,  0.3899,  0.4097

  return torch.load(io.BytesIO(b))
