In [16]:
import numpy as np
import pandas as pd
import torch
# Load the movie catalogue and the user/movie rating log from their CSV files.
movies_df, ratings_df = (pd.read_csv(csv_path) for csv_path in ('movies.csv', 'ratings.csv'))

In [17]:
# Report the (rows, columns) shape of both frames in one print call.
print(
    'The dimensions of movies dataframe are:',
    movies_df.shape,
    '\nThe dimensions of ratings dataframe are:',
    ratings_df.shape,
)

The dimensions of movies dataframe are: (9742, 3) 
The dimensions of ratings dataframe are: (100836, 4)


In [18]:
# Preview the first five rows of the movie catalogue.
movies_df.head(n=5)

Unnamed: 0,movieId,title,genres
0,1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy
1,2,Jumanji (1995),Adventure|Children|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama|Romance
4,5,Father of the Bride Part II (1995),Comedy


In [19]:
# Peek at the first five rows of the ratings table.
ratings_df.head(5)

Unnamed: 0,userId,movieId,rating,timestamp
0,1,1,4.0,964982703
1,1,3,4.0,964981247
2,1,6,4.0,964982224
3,1,47,5.0,964983815
4,1,50,5.0,964982931


In [20]:
# movieId -> title lookup, used later to label cluster members.
movie_names = movies_df.set_index('movieId')['title'].to_dict()

# Distinct users/movies that actually appear in the rating log; these size
# the embedding tables of the factorization model built later.
n_users = ratings_df['userId'].nunique()
n_items = ratings_df['movieId'].nunique()
n_cells = n_users * n_items
n_ratings = len(ratings_df)

print("Number of unique users:", n_users)
print("Number of unique movies:", n_items)
print("The full rating matrix will have:", n_cells, 'elements.')
print('----------')
print("Number of ratings:", n_ratings)
print("Therefore: ", n_ratings / n_cells * 100, '% of the matrix is filled.')
print("We have an incredibly sparse matrix to work with here.")
# The dense matrix has n_users * n_items cells, so it grows quadratically
# (~n^2) when users and products grow together -- not by "n*2".
print("And... as you can imagine, as the number of users and products grow, the number of elements grows as n_users * n_items (roughly n^2 if both grow together)")
print("You are going to need a lot of memory to work with global scale... storing a full matrix in memory would be a challenge.")
print("One advantage here is that matrix factorization represents the rating matrix implicitly (as the product of two small factor matrices), so we never need to materialize the full matrix")

Number of unique users: 610
Number of unique movies: 9724
The full rating matrix will have: 5931640 elements.
----------
Number of ratings: 100836
Therefore:  1.6999683055613624 % of the matrix is filled.
We have an incredibly sparse matrix to work with here.
And... as you can imagine, as the number of users and products grow, the number of elements will increase by n*2
You are going to need a lot of memory to work with global scale... storing a full matrix in memory would be a challenge.
One advantage here is that matrix factorization can realize the rating matrix implicitly, thus we don't need all the data


In [22]:
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader


class Loader(Dataset):
    """Map-style dataset of (user index, movie index) pairs and their ratings.

    Raw ``userId`` / ``movieId`` values are remapped to dense 0-based indices
    in order of first appearance, so they can index embedding tables directly.
    ``idx2userid`` / ``idx2movieid`` map the indices back to the original ids.

    Args:
        ratings: DataFrame with at least ``userId``, ``movieId`` and ``rating``
            columns. When omitted, the notebook-level ``ratings_df`` is used
            (the original behaviour).

    Attributes:
        x: int64 tensor of shape (n, 2) holding [user index, movie index].
        y: tensor of shape (n,) holding the ratings.
    """

    def __init__(self, ratings=None):
        source = ratings_df if ratings is None else ratings
        # Work on a copy so the caller's frame keeps its raw ids.
        self.ratings = source.copy()

        users = source['userId'].unique()
        movies = source['movieId'].unique()
        self.userid2idx = {uid: i for i, uid in enumerate(users)}
        self.movieid2idx = {mid: i for i, mid in enumerate(movies)}

        self.idx2userid = {i: uid for uid, i in self.userid2idx.items()}
        self.idx2movieid = {i: mid for mid, i in self.movieid2idx.items()}

        # Vectorised dict lookup; equivalent to the per-row lambda but much faster.
        self.ratings['userId'] = source['userId'].map(self.userid2idx)
        self.ratings['movieId'] = source['movieId'].map(self.movieid2idx)

        # Select the feature columns explicitly (user first, then movie) rather
        # than dropping 'rating'/'timestamp': extra CSV columns cannot leak into
        # x, and a frame without 'timestamp' does not raise.
        self.x = torch.tensor(self.ratings[['userId', 'movieId']].values)
        self.y = torch.tensor(self.ratings['rating'].values)

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.ratings)

In [23]:
# Hyper-parameters, named so they can be tuned in one place.
SEED = 0
N_FACTORS = 8
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
num_epochs = 128

# Seed torch so embedding initialisation and DataLoader shuffling are
# reproducible across Restart & Run All.
torch.manual_seed(SEED)

cuda = torch.cuda.is_available()
print("Is running on GPU:", cuda)

# NOTE(review): MatrixFactorization is not defined in any visible cell
# (execution count [21] is missing); make sure its definition cell runs first.
model = MatrixFactorization(n_users, n_items, n_factors=N_FACTORS)
print(model)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)

# Move the model before building the optimizer so it tracks the device-resident parameters.
if cuda:
    model = model.cuda()

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_set = Loader()
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

Is running on GPU: False
MatrixFactorization(
  (user_factors): Embedding(610, 8)
  (item_factors): Embedding(9724, 8)
)
user_factors.weight tensor([[0.0090, 0.0422, 0.0428,  ..., 0.0323, 0.0233, 0.0267],
        [0.0210, 0.0101, 0.0456,  ..., 0.0394, 0.0337, 0.0369],
        [0.0248, 0.0178, 0.0171,  ..., 0.0250, 0.0088, 0.0217],
        ...,
        [0.0060, 0.0430, 0.0345,  ..., 0.0243, 0.0360, 0.0111],
        [0.0163, 0.0353, 0.0425,  ..., 0.0491, 0.0334, 0.0096],
        [0.0236, 0.0109, 0.0111,  ..., 0.0223, 0.0257, 0.0412]])
item_factors.weight tensor([[0.0311, 0.0057, 0.0391,  ..., 0.0061, 0.0314, 0.0444],
        [0.0397, 0.0284, 0.0215,  ..., 0.0198, 0.0097, 0.0420],
        [0.0473, 0.0287, 0.0430,  ..., 0.0451, 0.0484, 0.0131],
        ...,
        [0.0292, 0.0118, 0.0404,  ..., 0.0185, 0.0261, 0.0354],
        [0.0409, 0.0311, 0.0285,  ..., 0.0271, 0.0263, 0.0242],
        [0.0131, 0.0291, 0.0415,  ..., 0.0407, 0.0051, 0.0197]])


In [24]:
from tqdm.auto import tqdm

# Make train-mode behaviour explicit (matters if the model ever gains dropout/batch-norm).
model.train()
for it in tqdm(range(num_epochs)):
    losses = []
    for x, y in train_loader:
        if cuda:
            x, y = x.cuda(), y.cuda()
        optimizer.zero_grad()
        outputs = model(x)
        # reshape(-1) instead of squeeze(): squeeze() would turn a size-1 final
        # batch into a 0-d tensor and make MSELoss broadcast against y of shape (1,).
        # Assumes the model returns one prediction per row, (B,) or (B, 1) -- TODO confirm.
        loss = loss_fn(outputs.reshape(-1), y.to(torch.float32))
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    # tqdm.write prints above the progress bar instead of breaking it into many partial bars.
    tqdm.write("iter #{} Loss: {}".format(it, sum(losses) / len(losses)))

  1%|▋                                                                                 | 1/128 [00:01<03:19,  1.57s/it]

iter #0 Loss: 11.063116820330547


  2%|█▎                                                                                | 2/128 [00:03<03:16,  1.56s/it]

iter #1 Loss: 4.739772123431191


  2%|█▉                                                                                | 3/128 [00:04<03:10,  1.52s/it]

iter #2 Loss: 2.4748059207110233


  3%|██▌                                                                               | 4/128 [00:06<03:03,  1.48s/it]

iter #3 Loss: 1.7216763558424064


  4%|███▏                                                                              | 5/128 [00:07<03:00,  1.46s/it]

iter #4 Loss: 1.3460094182170588


  5%|███▊                                                                              | 6/128 [00:08<02:57,  1.46s/it]

iter #5 Loss: 1.1283574264666756


  5%|████▍                                                                             | 7/128 [00:10<02:56,  1.46s/it]

iter #6 Loss: 0.9912560458867078


  6%|█████▏                                                                            | 8/128 [00:11<02:52,  1.44s/it]

iter #7 Loss: 0.9004685515679683


  7%|█████▊                                                                            | 9/128 [00:13<02:50,  1.43s/it]

iter #8 Loss: 0.8375387200092906


  8%|██████▎                                                                          | 10/128 [00:14<02:49,  1.44s/it]

iter #9 Loss: 0.7924832529222905


  9%|██████▉                                                                          | 11/128 [00:16<02:47,  1.44s/it]

iter #10 Loss: 0.7593215110186998


  9%|███████▌                                                                         | 12/128 [00:17<02:47,  1.44s/it]

iter #11 Loss: 0.7347150303915068


 10%|████████▏                                                                        | 13/128 [00:19<02:49,  1.47s/it]

iter #12 Loss: 0.7159909804914203


 11%|████████▊                                                                        | 14/128 [00:20<02:47,  1.47s/it]

iter #13 Loss: 0.7017569793813725


 12%|█████████▍                                                                       | 15/128 [00:21<02:42,  1.44s/it]

iter #14 Loss: 0.6905749632698025


 12%|██████████▏                                                                      | 16/128 [00:23<02:37,  1.41s/it]

iter #15 Loss: 0.6814725599842628


 13%|██████████▊                                                                      | 17/128 [00:24<02:36,  1.41s/it]

iter #16 Loss: 0.6750362537007042


 14%|███████████▍                                                                     | 18/128 [00:26<02:37,  1.44s/it]

iter #17 Loss: 0.6698792765134483


 15%|████████████                                                                     | 19/128 [00:27<02:36,  1.43s/it]

iter #18 Loss: 0.6659562979555372


 16%|████████████▋                                                                    | 20/128 [00:28<02:34,  1.43s/it]

iter #19 Loss: 0.663022628630781


 16%|█████████████▎                                                                   | 21/128 [00:30<02:31,  1.42s/it]

iter #20 Loss: 0.6605822234816358


 17%|█████████████▉                                                                   | 22/128 [00:31<02:32,  1.44s/it]

iter #21 Loss: 0.6589100635445058


 18%|██████████████▌                                                                  | 23/128 [00:33<02:33,  1.46s/it]

iter #22 Loss: 0.6577334813206329


 19%|███████████████▏                                                                 | 24/128 [00:34<02:35,  1.49s/it]

iter #23 Loss: 0.6568893155243796


 20%|███████████████▊                                                                 | 25/128 [00:36<02:34,  1.50s/it]

iter #24 Loss: 0.6559200096251395


 20%|████████████████▍                                                                | 26/128 [00:37<02:33,  1.51s/it]

iter #25 Loss: 0.6552231367393798


 21%|█████████████████                                                                | 27/128 [00:39<02:30,  1.49s/it]

iter #26 Loss: 0.6546192309502418


 22%|█████████████████▋                                                               | 28/128 [00:40<02:30,  1.50s/it]

iter #27 Loss: 0.6535344957775876


 23%|██████████████████▎                                                              | 29/128 [00:42<02:27,  1.49s/it]

iter #28 Loss: 0.6528411105955918


 23%|██████████████████▉                                                              | 30/128 [00:43<02:27,  1.50s/it]

iter #29 Loss: 0.6515228750590746


 24%|███████████████████▌                                                             | 31/128 [00:45<02:25,  1.50s/it]

iter #30 Loss: 0.6503880463805295


 25%|████████████████████▎                                                            | 32/128 [00:46<02:22,  1.49s/it]

iter #31 Loss: 0.6488460227847099


 26%|████████████████████▉                                                            | 33/128 [00:48<02:19,  1.47s/it]

iter #32 Loss: 0.6467615088381743


 27%|█████████████████████▌                                                           | 34/128 [00:49<02:16,  1.46s/it]

iter #33 Loss: 0.6442524165974051


 27%|██████████████████████▏                                                          | 35/128 [00:51<02:13,  1.43s/it]

iter #34 Loss: 0.6418197549463529


 28%|██████████████████████▊                                                          | 36/128 [00:52<02:09,  1.40s/it]

iter #35 Loss: 0.6384381148491414


 29%|███████████████████████▍                                                         | 37/128 [00:53<02:09,  1.42s/it]

iter #36 Loss: 0.6344715735407045


 30%|████████████████████████                                                         | 38/128 [00:55<02:08,  1.42s/it]

iter #37 Loss: 0.6295737778096635


 30%|████████████████████████▋                                                        | 39/128 [00:56<02:07,  1.44s/it]

iter #38 Loss: 0.6239618022883604


 31%|█████████████████████████▎                                                       | 40/128 [00:58<02:03,  1.41s/it]

iter #39 Loss: 0.6174772752798753


 32%|█████████████████████████▉                                                       | 41/128 [00:59<02:02,  1.40s/it]

iter #40 Loss: 0.6104328827885201


 33%|██████████████████████████▌                                                      | 42/128 [01:00<02:00,  1.41s/it]

iter #41 Loss: 0.6022612472173526


 34%|███████████████████████████▏                                                     | 43/128 [01:02<01:59,  1.40s/it]

iter #42 Loss: 0.593788508117804


 34%|███████████████████████████▊                                                     | 44/128 [01:03<01:56,  1.39s/it]

iter #43 Loss: 0.5845566473258328


 35%|████████████████████████████▍                                                    | 45/128 [01:05<01:58,  1.42s/it]

iter #44 Loss: 0.5750234633079035


 36%|█████████████████████████████                                                    | 46/128 [01:06<01:57,  1.43s/it]

iter #45 Loss: 0.5652648476779764


 37%|█████████████████████████████▋                                                   | 47/128 [01:08<01:55,  1.42s/it]

iter #46 Loss: 0.5556493582686192


 38%|██████████████████████████████▍                                                  | 48/128 [01:09<01:55,  1.44s/it]

iter #47 Loss: 0.5453902565056298


 38%|███████████████████████████████                                                  | 49/128 [01:11<01:54,  1.45s/it]

iter #48 Loss: 0.535953953257067


 39%|███████████████████████████████▋                                                 | 50/128 [01:12<01:53,  1.46s/it]

iter #49 Loss: 0.5266761875659378


 40%|████████████████████████████████▎                                                | 51/128 [01:14<01:53,  1.48s/it]

iter #50 Loss: 0.5170424266288123


 41%|████████████████████████████████▉                                                | 52/128 [01:15<01:52,  1.47s/it]

iter #51 Loss: 0.5084192784607108


 41%|█████████████████████████████████▌                                               | 53/128 [01:16<01:49,  1.46s/it]

iter #52 Loss: 0.5000643443486412


 42%|██████████████████████████████████▏                                              | 54/128 [01:18<01:47,  1.46s/it]

iter #53 Loss: 0.4917770084088224


 43%|██████████████████████████████████▊                                              | 55/128 [01:19<01:47,  1.48s/it]

iter #54 Loss: 0.4839574977528625


 44%|███████████████████████████████████▍                                             | 56/128 [01:21<01:44,  1.46s/it]

iter #55 Loss: 0.47665746608361376


 45%|████████████████████████████████████                                             | 57/128 [01:22<01:44,  1.47s/it]

iter #56 Loss: 0.46978960360216004


 45%|████████████████████████████████████▋                                            | 58/128 [01:24<01:41,  1.45s/it]

iter #57 Loss: 0.4630909523156089


 46%|█████████████████████████████████████▎                                           | 59/128 [01:25<01:38,  1.43s/it]

iter #58 Loss: 0.45681559941187727


 47%|█████████████████████████████████████▉                                           | 60/128 [01:26<01:35,  1.41s/it]

iter #59 Loss: 0.450991304941135


 48%|██████████████████████████████████████▌                                          | 61/128 [01:28<01:35,  1.42s/it]

iter #60 Loss: 0.445165133415745


 48%|███████████████████████████████████████▏                                         | 62/128 [01:29<01:34,  1.43s/it]

iter #61 Loss: 0.439579105694887


 49%|███████████████████████████████████████▊                                         | 63/128 [01:31<01:33,  1.45s/it]

iter #62 Loss: 0.4344859973050011


 50%|████████████████████████████████████████▌                                        | 64/128 [01:32<01:31,  1.44s/it]

iter #63 Loss: 0.4301805656573494


 51%|█████████████████████████████████████████▏                                       | 65/128 [01:34<01:30,  1.43s/it]

iter #64 Loss: 0.42523440313974614


 52%|█████████████████████████████████████████▊                                       | 66/128 [01:35<01:28,  1.43s/it]

iter #65 Loss: 0.42092997043719754


 52%|██████████████████████████████████████████▍                                      | 67/128 [01:37<01:28,  1.45s/it]

iter #66 Loss: 0.4167523538778881


 53%|███████████████████████████████████████████                                      | 68/128 [01:38<01:26,  1.45s/it]

iter #67 Loss: 0.41271177102542167


 54%|███████████████████████████████████████████▋                                     | 69/128 [01:39<01:24,  1.43s/it]

iter #68 Loss: 0.4090719327028028


 55%|████████████████████████████████████████████▎                                    | 70/128 [01:41<01:23,  1.43s/it]

iter #69 Loss: 0.4054371822568668


 55%|████████████████████████████████████████████▉                                    | 71/128 [01:42<01:22,  1.45s/it]

iter #70 Loss: 0.40185176098936704


 56%|█████████████████████████████████████████████▌                                   | 72/128 [01:44<01:21,  1.46s/it]

iter #71 Loss: 0.39869451025491437


 57%|██████████████████████████████████████████████▏                                  | 73/128 [01:45<01:20,  1.47s/it]

iter #72 Loss: 0.39549354119651814


 58%|██████████████████████████████████████████████▊                                  | 74/128 [01:47<01:19,  1.48s/it]

iter #73 Loss: 0.39256577182390967


 59%|███████████████████████████████████████████████▍                                 | 75/128 [01:48<01:17,  1.47s/it]

iter #74 Loss: 0.38971010368638836


 59%|████████████████████████████████████████████████                                 | 76/128 [01:50<01:15,  1.44s/it]

iter #75 Loss: 0.38670087091359995


 60%|████████████████████████████████████████████████▋                                | 77/128 [01:51<01:12,  1.43s/it]

iter #76 Loss: 0.3843127844221701


 61%|█████████████████████████████████████████████████▎                               | 78/128 [01:52<01:11,  1.43s/it]

iter #77 Loss: 0.3818089823205459


 62%|█████████████████████████████████████████████████▉                               | 79/128 [01:54<01:09,  1.43s/it]

iter #78 Loss: 0.3794003644873043


 62%|██████████████████████████████████████████████████▋                              | 80/128 [01:55<01:07,  1.41s/it]

iter #79 Loss: 0.37694390099212


 63%|███████████████████████████████████████████████████▎                             | 81/128 [01:57<01:06,  1.42s/it]

iter #80 Loss: 0.37478105825972435


 64%|███████████████████████████████████████████████████▉                             | 82/128 [01:58<01:05,  1.41s/it]

iter #81 Loss: 0.3727397968812945


 65%|████████████████████████████████████████████████████▌                            | 83/128 [02:00<01:05,  1.45s/it]

iter #82 Loss: 0.3707157607346319


 66%|█████████████████████████████████████████████████████▏                           | 84/128 [02:01<01:03,  1.45s/it]

iter #83 Loss: 0.3687416552022326


 66%|█████████████████████████████████████████████████████▊                           | 85/128 [02:03<01:02,  1.46s/it]

iter #84 Loss: 0.36670757124796133


 67%|██████████████████████████████████████████████████████▍                          | 86/128 [02:04<01:01,  1.46s/it]

iter #85 Loss: 0.3651521331843386


 68%|███████████████████████████████████████████████████████                          | 87/128 [02:06<01:00,  1.48s/it]

iter #86 Loss: 0.36344187822060536


 69%|███████████████████████████████████████████████████████▋                         | 88/128 [02:07<00:58,  1.46s/it]

iter #87 Loss: 0.3616442721015608


 70%|████████████████████████████████████████████████████████▎                        | 89/128 [02:08<00:56,  1.45s/it]

iter #88 Loss: 0.36008065621259855


 70%|████████████████████████████████████████████████████████▉                        | 90/128 [02:10<00:55,  1.45s/it]

iter #89 Loss: 0.3585280609614958


 71%|█████████████████████████████████████████████████████████▌                       | 91/128 [02:11<00:53,  1.45s/it]

iter #90 Loss: 0.3569983201016327


 72%|██████████████████████████████████████████████████████████▏                      | 92/128 [02:13<00:51,  1.43s/it]

iter #91 Loss: 0.355699038653053


 73%|██████████████████████████████████████████████████████████▊                      | 93/128 [02:14<00:49,  1.40s/it]

iter #92 Loss: 0.3544030066787591


 73%|███████████████████████████████████████████████████████████▍                     | 94/128 [02:15<00:47,  1.40s/it]

iter #93 Loss: 0.35294853035492946


 74%|████████████████████████████████████████████████████████████                     | 95/128 [02:17<00:46,  1.41s/it]

iter #94 Loss: 0.35171736479940147


 75%|████████████████████████████████████████████████████████████▊                    | 96/128 [02:18<00:46,  1.46s/it]

iter #95 Loss: 0.3506532364009601


 76%|█████████████████████████████████████████████████████████████▍                   | 97/128 [02:20<00:46,  1.50s/it]

iter #96 Loss: 0.3493669829354976


 77%|██████████████████████████████████████████████████████████████                   | 98/128 [02:22<00:45,  1.52s/it]

iter #97 Loss: 0.34809601392010747


 77%|██████████████████████████████████████████████████████████████▋                  | 99/128 [02:23<00:44,  1.53s/it]

iter #98 Loss: 0.3472444500944336


 78%|██████████████████████████████████████████████████████████████▌                 | 100/128 [02:25<00:43,  1.54s/it]

iter #99 Loss: 0.3461260007041965


 79%|███████████████████████████████████████████████████████████████▏                | 101/128 [02:26<00:42,  1.57s/it]

iter #100 Loss: 0.34499869009640616


 80%|███████████████████████████████████████████████████████████████▊                | 102/128 [02:28<00:40,  1.57s/it]

iter #101 Loss: 0.34387275788384647


 80%|████████████████████████████████████████████████████████████████▍               | 103/128 [02:29<00:38,  1.53s/it]

iter #102 Loss: 0.34321106181758915


 81%|█████████████████████████████████████████████████████████████████               | 104/128 [02:31<00:36,  1.50s/it]

iter #103 Loss: 0.3422373269074762


 82%|█████████████████████████████████████████████████████████████████▋              | 105/128 [02:32<00:34,  1.48s/it]

iter #104 Loss: 0.34102479164219146


 83%|██████████████████████████████████████████████████████████████████▎             | 106/128 [02:34<00:32,  1.47s/it]

iter #105 Loss: 0.34052435261419584


 84%|██████████████████████████████████████████████████████████████████▉             | 107/128 [02:35<00:31,  1.48s/it]

iter #106 Loss: 0.33946177448491155


 84%|███████████████████████████████████████████████████████████████████▌            | 108/128 [02:37<00:29,  1.46s/it]

iter #107 Loss: 0.33863233839194784


 85%|████████████████████████████████████████████████████████████████████▏           | 109/128 [02:38<00:27,  1.46s/it]

iter #108 Loss: 0.33788761939040296


 86%|████████████████████████████████████████████████████████████████████▊           | 110/128 [02:40<00:26,  1.48s/it]

iter #109 Loss: 0.337053131879435


 87%|█████████████████████████████████████████████████████████████████████▍          | 111/128 [02:41<00:25,  1.47s/it]

iter #110 Loss: 0.33637984879882205


 88%|██████████████████████████████████████████████████████████████████████          | 112/128 [02:43<00:24,  1.50s/it]

iter #111 Loss: 0.33559159316160353


 88%|██████████████████████████████████████████████████████████████████████▋         | 113/128 [02:44<00:22,  1.52s/it]

iter #112 Loss: 0.33480589570203406


 89%|███████████████████████████████████████████████████████████████████████▎        | 114/128 [02:46<00:21,  1.53s/it]

iter #113 Loss: 0.33418194236716037


 90%|███████████████████████████████████████████████████████████████████████▉        | 115/128 [02:47<00:20,  1.57s/it]

iter #114 Loss: 0.3335899658085126


 91%|████████████████████████████████████████████████████████████████████████▌       | 116/128 [02:49<00:18,  1.56s/it]

iter #115 Loss: 0.33272435066058553


 91%|█████████████████████████████████████████████████████████████████████████▏      | 117/128 [02:51<00:18,  1.69s/it]

iter #116 Loss: 0.33220417863750823


 92%|█████████████████████████████████████████████████████████████████████████▊      | 118/128 [02:53<00:17,  1.71s/it]

iter #117 Loss: 0.3316234200625553


 93%|██████████████████████████████████████████████████████████████████████████▍     | 119/128 [02:54<00:14,  1.65s/it]

iter #118 Loss: 0.33091293598719057


 94%|███████████████████████████████████████████████████████████████████████████     | 120/128 [02:56<00:13,  1.63s/it]

iter #119 Loss: 0.3304372340970233


 95%|███████████████████████████████████████████████████████████████████████████▋    | 121/128 [02:57<00:11,  1.60s/it]

iter #120 Loss: 0.32952142132417805


 95%|████████████████████████████████████████████████████████████████████████████▎   | 122/128 [02:59<00:09,  1.61s/it]

iter #121 Loss: 0.3291342083156714


 96%|████████████████████████████████████████████████████████████████████████████▉   | 123/128 [03:00<00:07,  1.58s/it]

iter #122 Loss: 0.3286045059788651


 97%|█████████████████████████████████████████████████████████████████████████████▌  | 124/128 [03:02<00:06,  1.53s/it]

iter #123 Loss: 0.32823741727220224


 98%|██████████████████████████████████████████████████████████████████████████████▏ | 125/128 [03:03<00:04,  1.55s/it]

iter #124 Loss: 0.32747757217378787


 98%|██████████████████████████████████████████████████████████████████████████████▊ | 126/128 [03:05<00:03,  1.60s/it]

iter #125 Loss: 0.32704927420555635


 99%|███████████████████████████████████████████████████████████████████████████████▍| 127/128 [03:07<00:01,  1.56s/it]

iter #126 Loss: 0.3264898631891926


100%|████████████████████████████████████████████████████████████████████████████████| 128/128 [03:08<00:00,  1.47s/it]

iter #127 Loss: 0.3260256615727383





In [25]:
# Show every trainable parameter after training.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)

# Fetch the learned embedding matrices by attribute rather than by their
# position in named_parameters(), so this does not depend on registration order.
uw = model.user_factors.weight.data  # user embeddings, one row per user index
iw = model.item_factors.weight.data  # item (movie) embeddings, one row per movie index

user_factors.weight tensor([[ 0.6566,  1.5865,  1.4255,  ...,  1.6873,  0.7072,  1.8206],
        [ 2.3735,  0.6803,  1.4607,  ...,  1.3688,  0.4064,  1.0991],
        [ 1.9848,  0.8896, -3.5604,  ...,  0.9796, -0.3901, -0.2799],
        ...,
        [ 0.0275,  1.6872,  1.4219,  ...,  1.6841,  1.5586, -1.1896],
        [ 1.0926,  0.8030,  1.4236,  ...,  1.0499,  1.4311,  0.8676],
        [ 1.8563,  0.9356,  0.3273,  ...,  1.0134,  1.6467,  1.3681]])
item_factors.weight tensor([[ 0.6710,  0.1434,  0.5828,  ...,  0.1939,  0.5384,  0.6981],
        [ 0.4802,  0.4188,  0.6303,  ...,  0.3990, -0.0062,  0.5881],
        [ 0.3190,  0.4571,  0.5744,  ...,  0.5104,  0.7044,  0.0860],
        ...,
        [ 0.3506,  0.3344,  0.3727,  ...,  0.3402,  0.3474,  0.3565],
        [ 0.4232,  0.4131,  0.4125,  ...,  0.4102,  0.4085,  0.4064],
        [ 0.4045,  0.4208,  0.4298,  ...,  0.4318,  0.3967,  0.4107]])


In [26]:
# Detach the learned item embedding matrix from autograd and bring it to host memory as a NumPy array.
trained_movie_embeddings = model.item_factors.weight.detach().cpu().numpy()

In [27]:
# Number of embedding rows, i.e. one per movie index.
trained_movie_embeddings.shape[0]

9724

In [28]:
from sklearn.cluster import KMeans

# Cluster the learned movie embeddings. n_init is pinned explicitly:
# scikit-learn 1.4 changed its default from 10 to 'auto' (a single run for
# k-means++), which would silently change the clusters. 10 keeps the
# historical behaviour on every version.
kmeans = KMeans(n_clusters=10, n_init=10, random_state=0).fit(trained_movie_embeddings)

In [29]:
TOP_N = 10  # movies listed per cluster

# Ratings per movieId, computed once. The previous per-movie boolean mask
# rescanned the whole ratings table for every movie (O(movies x ratings)).
rating_counts = ratings_df['movieId'].value_counts()

# Iterate over the clusters the model actually has, not a hard-coded 10.
for cluster in range(kmeans.n_clusters):
    print("Cluster #{}".format(cluster))
    movs = []
    for movidx in np.flatnonzero(kmeans.labels_ == cluster):
        movid = train_set.idx2movieid[movidx]
        movs.append((movie_names[movid], rating_counts[movid]))
    # Most-rated movies first; sorted() is stable, so ties keep index order.
    for title, _count in sorted(movs, key=lambda tup: tup[1], reverse=True)[:TOP_N]:
        print("\t", title)
     

Cluster #0
	 Forrest Gump (1994)
	 Shawshank Redemption, The (1994)
	 Silence of the Lambs, The (1991)
	 Matrix, The (1999)
	 Star Wars: Episode IV - A New Hope (1977)
	 Fight Club (1999)
	 Star Wars: Episode V - The Empire Strikes Back (1980)
	 Usual Suspects, The (1995)
	 Raiders of the Lost Ark (Indiana Jones and the Raiders of the Lost Ark) (1981)
	 Fugitive, The (1993)
Cluster #1
	 Pulp Fiction (1994)
	 American Beauty (1999)
	 Seven (a.k.a. Se7en) (1995)
	 Lord of the Rings: The Fellowship of the Ring, The (2001)
	 Godfather, The (1972)
	 Ace Ventura: Pet Detective (1994)
	 Memento (2000)
	 Monty Python and the Holy Grail (1975)
	 Reservoir Dogs (1992)
	 Kill Bill: Vol. 1 (2003)
Cluster #2
	 Dances with Wolves (1990)
	 Stargate (1994)
	 Home Alone (1990)
	 Waterworld (1995)
	 Net, The (1995)
	 Harry Potter and the Sorcerer's Stone (a.k.a. Harry Potter and the Philosopher's Stone) (2001)
	 Back to the Future Part III (1990)
	 Unbreakable (2000)
	 Pirates of the Caribbean: Dead Man