In [1]:
# notebooks/03_lstm_train.ipynb
import pandas as pd, torch, numpy as np
from pathlib import Path
df = pd.read_parquet(Path.cwd().parent/"data/trace.parquet").sort_values("ts")

# map video IDs to ints
cats = df.video.astype("category")
vid2idx = {v:i for i,v in enumerate(cats.cat.categories)}
idx2vid = {i:v for v,i in vid2idx.items()}
df["vid_idx"] = cats.cat.codes

# build user sequences
seqs = df.groupby("user")["vid_idx"].apply(list).tolist()

# split into (input, target) pairs
def make_samples(seq, win=5):
    X,Y = [],[]
    for i in range(len(seq)-win):
        X.append(seq[i:i+win])
        Y.append(seq[i+win])
    return X,Y

X,Y = [],[]
for s in seqs:
    x,y = make_samples(s, win=5)
    X.extend(x); Y.extend(y)

X = torch.tensor(X, dtype=torch.long)
Y = torch.tensor(Y, dtype=torch.long)

In [5]:
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

class NextVidLSTM(nn.Module):
    def __init__(self, n_vid, emb=64, hid=128):
        super().__init__()
        self.emb = nn.Embedding(n_vid, emb)
        self.lstm = nn.LSTM(emb, hid, batch_first=True)
        self.fc   = nn.Linear(hid, n_vid)

    def forward(self, x):
        x = self.emb(x)
        out,_ = self.lstm(x)
        return self.fc(out[:,-1])          # logits for next item

n_vid = len(vid2idx)
model = NextVidLSTM(n_vid)
opt   = torch.optim.Adam(model.parameters(), lr=1e-3)
crit  = nn.CrossEntropyLoss()

dl = DataLoader(TensorDataset(X,Y), batch_size=256, shuffle=True)

for epoch in range(100):
    loss=0
    for xb,yb in dl:
        opt.zero_grad()
        logits = model(xb)
        l = crit(logits, yb)
        l.backward(); opt.step()
        loss += l.item()*len(xb)
    print(epoch, loss/len(dl.dataset))

torch.save({"state":model.state_dict(),
            "vid2idx":vid2idx, "idx2vid":idx2vid},
           "../models/lstm.pt")

0 8.349504776093017
1 8.096119071055
2 7.279534698239736
3 6.200761841303909
4 5.301518742302572
5 4.547432731836659
6 3.90046657688524
7 3.3433361936720507
8 2.8685095446354496
9 2.462099691293965
10 2.114693300441436
11 1.822358959294257
12 1.575251950419008
13 1.3654760039127136
14 1.189701366546802
15 1.0443105772414782
16 0.9191210442444087
17 0.8137889204102449
18 0.7243100228927549
19 0.6488090398815681
20 0.5833736165355795
21 0.5259181797426526
22 0.47691318648621533
23 0.4329797269813353
24 0.39593068716800955
25 0.3633767716744635
26 0.33429946029463087
27 0.3091998318577455
28 0.2863625858229609
29 0.26597994962349314
30 0.2483328621213947
31 0.23139205962835968
32 0.21772054801026072
33 0.20348026648027473
34 0.19181689006414343
35 0.18073130472318555
36 0.1714245971646922
37 0.162391099201335
38 0.15372030666794406
39 0.14661252831794322
40 0.13947486087067157
41 0.13242905748373446
42 0.1269730710559027
43 0.12116571148849797
44 0.11585886606268156
45 0.11087395783206308