
# Contextual Bandits on MovieLens — **PyTorch** (GPU-ready)
This notebook implements a GPU-ready pipeline for a **contextual bandit** recommender:

1. **Matrix Factorization (Torch)** on a chronological train split → user/item embeddings.  
2. MF-based **candidate retrieval** (Top-K per event).  
3. A **stochastic logging policy** μ (softmax temperature + ε-mix) to simulate bandit logs with **propensities**; the reward is 1 iff the sampled item is the movie the user actually rated in that event.  
4. A **reward model** $\hat Q$ (Torch MLP) predicting the click probability for $(x, i)$ pairs.  
5. A **policy** $\pi_\theta$ (Torch MLP) trained with a **Doubly-Robust (DR)** objective.  
6. **Off-Policy Evaluation**: **SNIPS** and **DR** on a held-out bandit log.

> **Data expectation**: MovieLens 1M files at `base_dir/ml-1m/` (default `./data/ml-1m/`) with  
> `ratings.dat`, `users.dat`, `movies.dat`. If any of them is missing, the notebook falls back to a synthetic dataset so you can still run it end-to-end.

> **GPU**: If a CUDA GPU is available, training will use it automatically (`device = 'cuda'`). Otherwise it runs on CPU.


In [None]:

import os, math, random, time
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Repro
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and every CUDA device).

    Calling it again with the same ``seed`` replays the same random streams.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

set_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# ---- Paths -----------------------------------------------------------------
base_dir = Path("./data")
ml1m_dir = base_dir / "ml-1m"    # ratings.dat, users.dat and movies.dat are read from here

# ---- Core ------------------------------------------------------------------
D = 32              # MF embedding dimension
K = 50              # candidate-pool size per event
train_ratio = 0.9   # fraction of time-ordered ratings used for training
rating_thresh = 4   # rating >= rating_thresh is a positive MF target

# ---- Logging policy μ ------------------------------------------------------
tau = 0.7           # softmax temperature over candidate scores
eps = 0.05          # weight of the uniform component (ε-mix)

# ---- Matrix factorisation --------------------------------------------------
mf_epochs = 5
mf_lr = 0.05
mf_reg = 0.0        # forwarded to the optimiser as weight_decay

# ---- Reward model Q̂ --------------------------------------------------------
qhat_epochs, qhat_lr, qhat_hid = 5, 1e-3, 128

# ---- Policy π_θ (DR objective) ---------------------------------------------
policy_epochs, policy_lr, policy_hid = 8, 1e-3, 128
weight_clip = 10.0    # upper clip on importance weights π/μ
entropy_coef = 0.01   # weight of the entropy bonus
batch_size = 256      # policy mini-batch size

print("Config ready.")


Using device: cuda
Config ready.


In [None]:
def read_ml1m(ml1m_dir: Path):
    """Load the MovieLens-1M ``ratings``, ``users`` and ``movies`` tables.

    Args:
        ml1m_dir: directory expected to hold ``ratings.dat``, ``users.dat`` and
            ``movies.dat`` (``::``-separated, no header row).

    Returns:
        ``(ratings, users, movies)`` DataFrames, or ``(None, None, None)`` when any
        of the three files is missing.
    """
    paths = {
        "ratings": ml1m_dir / "ratings.dat",
        "users": ml1m_dir / "users.dat",
        "movies": ml1m_dir / "movies.dat",
    }
    if not all(p.exists() for p in paths.values()):
        return None, None, None

    def _load(name, columns, **extra):
        # multi-character separator requires the python parser engine
        return pd.read_csv(paths[name], sep="::", engine="python", header=None,
                           names=columns, **extra)

    ratings = _load("ratings", ["user_id", "movie_id", "rating", "ts"])
    users = _load("users", ["user_id", "gender", "age", "occupation", "zip"])
    # movies.dat is decoded as latin-1 (presumably non-UTF-8 titles)
    movies = _load("movies", ["movie_id", "title", "genres"], encoding="latin-1")
    return ratings, users, movies

def make_synthetic(n_users=200, n_items=400, n_events=10000):
    """Generate a small MovieLens-shaped dataset so the notebook runs without ML-1M.

    Returns ``(ratings, users, movies)`` DataFrames with the same column names
    that ``read_ml1m`` produces.

    Generation:
      * users get a random gender (0/1), an age code from the same list used by
        ``preprocess_users``, an occupation in 0..20 and a constant zip;
      * each item gets 1-3 distinct genres from a fixed list;
      * latent factors ``U``/``V`` are 8-dim standard normals; every event picks a
        random user, scores all items as ``U[u] @ V.T`` plus Gaussian noise
        (std 0.5) and logs the argmax item with rating
        ``clip(round(0.5 * score + 3), 1, 5)``;
      * timestamps start at 950_000_000 and grow by 1..999 seconds per event.

    NOTE: calls ``set_seed(123)``, which reseeds the global ``random``/NumPy/torch
    RNGs as a side effect. The statement order below fixes the RNG stream, so
    reordering draws changes the generated data.
    NOTE(review): the same user can "rate" the same item more than once, and since
    the logged item is the max over many noisy scores, ratings are probably
    skewed toward the top of the scale — TODO confirm this is intended.
    """
    set_seed(123)
    genders = np.random.choice([0,1], size=n_users)  # 0=F,1=M
    ages = np.random.choice([1,18,25,35,45,50,56], size=n_users)
    occs = np.random.choice(list(range(21)), size=n_users)
    users = pd.DataFrame({
        "user_id": np.arange(1, n_users+1),
        "gender": genders,
        "age": ages,
        "occupation": occs,
        "zip": ["00000"]*n_users
    })
    genres_all = ["Action","Comedy","Drama","Romance","Sci-Fi","Thriller","Crime","Animation","Horror"]
    # per item: draw the genre count (1..3) first, then that many distinct genres
    item_genres = [ "|".join(np.random.choice(genres_all, size=np.random.randint(1,4), replace=False)) for _ in range(n_items) ]
    movies = pd.DataFrame({
        "movie_id": np.arange(1, n_items+1),
        "title": [f"Movie {i}" for i in range(1, n_items+1)],
        "genres": item_genres
    })
    # latent user / item factors that drive which item each event "rates"
    U = np.random.randn(n_users, 8)
    V = np.random.randn(n_items, 8)
    rows = []  # (user_id, movie_id, rating, ts) tuples, 1-based ids
    ts = 950_000_000
    for _ in range(n_events):
        u = np.random.randint(1, n_users+1)
        # noisy preference scores of user u over every item
        scores = U[u-1] @ V.T + 0.5*np.random.randn(n_items)
        i_star = int(np.argmax(scores)) + 1
        s = scores[i_star-1]
        rating = int(np.clip(np.round(s*0.5+3), 1, 5))
        # strictly increasing timestamps
        ts += np.random.randint(1, 1000)
        rows.append((u, i_star, rating, ts))
    ratings = pd.DataFrame(rows, columns=["user_id","movie_id","rating","ts"])
    return ratings, users, movies

def preprocess_users(users: pd.DataFrame):
    """One-hot encode user demographics.

    Feature layout per user (float32, 29 columns):
        [gender (1) | age-bucket one-hot (7) | occupation one-hot (21)]

    * gender: for non-numeric columns (object, pandas ``string``/``str``,
      ``category``) the value is 1 for ``"M"`` and 0 otherwise; numeric columns
      (e.g. the synthetic 0=F/1=M data) are cast to int as-is.
    * age: mapped through the age codes below; unknown or missing ages fall back
      to bucket 0.
    * occupation: cast to int and clipped into ``[0, occ_max - 1]``.

    Args:
        users: DataFrame with ``user_id``, ``gender``, ``age`` and ``occupation`` columns.

    Returns:
        ``(demo_df, meta)`` — ``demo_df`` is indexed by ``user_id`` with integer
        column labels; ``meta`` holds ``age_codes`` and ``occ_max``.
    """
    gender_col = users["gender"]
    # A plain `dtype == object` check misses pandas string/category dtypes, which then
    # crash in astype(int) on "M"/"F"; decide on numeric-vs-label instead.
    if pd.api.types.is_numeric_dtype(gender_col):
        g = gender_col.astype(int).values.reshape(-1, 1)
    else:
        g = (gender_col == "M").astype(int).values.reshape(-1, 1)

    age_codes = [1, 18, 25, 35, 45, 50, 56]
    age_map = {a: i for i, a in enumerate(age_codes)}
    # NaN ages would raise inside int(); route them to bucket 0 like unknown codes.
    age_idx = users["age"].map(
        lambda a: age_map.get(int(a), 0) if pd.notna(a) else 0
    ).astype(int).values
    age_oh = np.eye(len(age_codes))[age_idx]

    occ_max = 21
    occ = users["occupation"].astype(int).clip(0, occ_max - 1).values
    occ_oh = np.eye(occ_max)[occ]

    demo = np.concatenate([g, age_oh, occ_oh], axis=1).astype(np.float32)
    demo_df = pd.DataFrame(demo, index=users["user_id"].values)
    return demo_df, {"age_codes": age_codes, "occ_max": occ_max}

def preprocess_items(movies: pd.DataFrame):
    """Multi-hot encode the ``|``-separated ``genres`` column of ``movies``.

    Genre names are whitespace-stripped, empty tokens and missing values are
    ignored, and the vocabulary is sorted alphabetically.

    Returns:
        ``(gdf, meta)`` — ``gdf`` is a float32 DataFrame indexed by ``movie_id``
        with one column per genre; ``meta["genres"]`` is the ordered vocabulary.
    """
    token_lists = [
        [tok.strip() for tok in str(raw).split("|") if tok.strip()]
        for raw in movies["genres"].fillna("").tolist()
    ]
    vocab = sorted({tok for toks in token_lists for tok in toks})
    col_of = {name: j for j, name in enumerate(vocab)}

    onehot = np.zeros((len(movies), len(vocab)), dtype=np.float32)
    for row_i, toks in enumerate(token_lists):
        for tok in toks:
            onehot[row_i, col_of[tok]] = 1.0

    gdf = pd.DataFrame(onehot, index=movies["movie_id"].values, columns=vocab)
    return gdf, {"genres": vocab}

def chronological_split(ratings: pd.DataFrame, train_ratio=0.9):
    """Split ratings into an earlier train part and a later test part by ``ts``.

    Rows are ordered with a *stable* sort, so events sharing a timestamp keep
    their input order; the first ``int(len(ratings) * train_ratio)`` rows go to
    train and the rest to test.

    Args:
        ratings: DataFrame with a ``ts`` column (not modified).
        train_ratio: fraction in ``[0, 1]`` of rows assigned to train.

    Returns:
        ``(train_df, test_df)`` copies of the sorted frame; train is indexed from 0
        and test continues that numbering.

    Raises:
        ValueError: if ``train_ratio`` is outside ``[0, 1]`` (or NaN).
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")
    # mergesort is stable; the default quicksort does not guarantee tie order.
    ordered = ratings.sort_values("ts", kind="mergesort").reset_index(drop=True)
    cut = int(len(ordered) * train_ratio)
    return ordered.iloc[:cut].copy(), ordered.iloc[cut:].copy()

def build_id_maps(ratings: pd.DataFrame):
    """Map raw user / movie ids to dense 0-based indices in ascending id order.

    Returns:
        ``(umap, imap)`` dicts ``raw_id -> index`` with plain Python ``int`` keys.
    """
    def _dense_index(column):
        ordered_ids = sorted(ratings[column].unique().tolist())
        return {raw: pos for pos, raw in enumerate(ordered_ids)}

    return _dense_index("user_id"), _dense_index("movie_id")

def make_time_features(ts_series: pd.Series):
    """Cyclical hour-of-day / day-of-week encoding of Unix-second timestamps.

    Returns:
        float32 array ``(n, 4)`` with columns
        ``[sin(hour), cos(hour), sin(weekday), cos(weekday)]`` — hour on a 24-step
        cycle, weekday (Monday=0) on a 7-step cycle.
    """
    stamps = pd.to_datetime(ts_series, unit='s', origin='unix')
    hour_angle = 2*np.pi*stamps.dt.hour.values/24.0
    day_angle = 2*np.pi*stamps.dt.dayofweek.values/7.0
    columns = (np.sin(hour_angle), np.cos(hour_angle), np.sin(day_angle), np.cos(day_angle))
    return np.column_stack(columns).astype(np.float32)


In [None]:
def train_mf_torch(train_df, umap, imap, D=32, epochs=5, lr=0.05, weight_decay=0.0, thresh=4, batch=4096,
                   optimizer="sgd"):
    """Fit a biased matrix-factorisation model to binarised ratings.

    Each rating becomes a 0/1 target (``rating >= thresh``) regressed with MSE on
    ``<U[u], V[i]> + bu[u] + bi[i] + g``.

    Args:
        train_df: DataFrame with ``user_id``, ``movie_id`` and ``rating`` columns.
        umap, imap: raw id -> dense index maps; every id in ``train_df`` must be present.
        D: embedding dimension.
        epochs, lr, weight_decay, batch: optimisation settings.
        thresh: rating threshold that defines a positive target.
        optimizer: ``"sgd"`` (default, previous behaviour), ``"adagrad"`` or ``"adam"``.
            NOTE(review): in the run recorded in this notebook the SGD loss stayed
            near 0.24 for every epoch, which suggests the embeddings barely moved
            from their random init — an adaptive optimiser may fit better; TODO confirm.

    Returns:
        ``(U, V, bu, bi, g)`` as detached tensors on ``device``.

    Raises:
        ValueError: on an unknown ``optimizer`` or ids missing from ``umap``/``imap``.
    """
    optimizers = {"sgd": torch.optim.SGD, "adagrad": torch.optim.Adagrad, "adam": torch.optim.Adam}
    opt_cls = optimizers.get(str(optimizer).lower())
    if opt_cls is None:
        raise ValueError(f"optimizer must be one of {sorted(optimizers)}, got {optimizer!r}")

    u_codes = train_df["user_id"].map(umap)
    i_codes = train_df["movie_id"].map(imap)
    # An unmapped id becomes NaN here and cannot be converted to a valid index.
    if u_codes.isna().any() or i_codes.isna().any():
        raise ValueError("train_df has user_id/movie_id values missing from umap/imap")

    nU, nI = len(umap), len(imap)
    U  = nn.Parameter(0.1*torch.randn(nU, D, device=device))
    V  = nn.Parameter(0.1*torch.randn(nI, D, device=device))
    bu = nn.Parameter(torch.zeros(nU, device=device))
    bi = nn.Parameter(torch.zeros(nI, device=device))
    g  = nn.Parameter(torch.zeros((), device=device))
    opt = opt_cls([U, V, bu, bi, g], lr=lr, weight_decay=weight_decay)

    us = torch.tensor(u_codes.values.astype(np.int64), device=device, dtype=torch.long)
    is_ = torch.tensor(i_codes.values.astype(np.int64), device=device, dtype=torch.long)
    y  = torch.tensor((train_df["rating"].values >= thresh).astype('float32'), device=device)

    n = len(us)
    for ep in range(epochs):
        idx = torch.randperm(n, device=device)
        sq_err_sum = 0.0
        for b in idx.split(batch):
            u, i, r = us[b], is_[b], y[b]
            pred = (U[u] * V[i]).sum(-1) + bu[u] + bi[i] + g
            loss = F.mse_loss(pred, r)
            opt.zero_grad(); loss.backward(); opt.step()
            # weight by batch size so the printed value is the per-sample epoch mean,
            # not just the (possibly partial) last batch
            sq_err_sum += loss.item() * len(b)
        print(f"[MF-torch] epoch {ep+1}/{epochs} MSE={sq_err_sum / max(n, 1):.4f}")
    return U.detach(), V.detach(), bu.detach(), bi.detach(), g.detach()


In [None]:
def topk_candidates_torch(uvec, V, seen_mask, K):
    """Return the highest dot-product items for one user, excluding seen items.

    Args:
        uvec: user embedding, shape ``[D]``.
        V: item embedding table, shape ``[nI, D]``.
        seen_mask: bool tensor ``[nI]``; ``True`` entries are scored ``-inf``.
        K: requested pool size; clamped to ``nI`` so ``torch.topk`` cannot raise.

    Returns:
        ``(idx, vals)`` — item indices and scores sorted by descending score. When
        fewer than ``K`` items are unseen, the tail holds seen items scored ``-inf``.
    """
    scores = (V @ uvec).masked_fill(seen_mask, float('-inf'))
    vals, idx = torch.topk(scores, k=min(K, scores.numel()))
    return idx, vals

def softmax_temp_torch(x, tau):
    """Softmax of a 1-D score vector at temperature ``tau`` (floored at 1e-6).

    Scores are shifted by their maximum before exponentiation for numerical
    stability; lower ``tau`` gives a sharper distribution.
    """
    temperature = max(tau, 1e-6)
    scaled = x / temperature
    return torch.softmax(scaled - scaled.max(), dim=0)

def simulate_bandit_logs_torch(ratings_df, users_df, movies_df, U, V, umap, imap, K=50, tau=0.7, eps=0.05, thresh=4, seed=123):
    """Replay rating events as a simulated contextual-bandit log collected by policy μ.

    For each rating event, in ``ts`` order:
      1. retrieve the top-``K`` items by ``U[u] @ V`` among items this user has not
         yet rated earlier in this replay; if the actually-rated item ``i_star`` is
         not in the pool it replaces the lowest-scored candidate;
      2. form μ = (1 - eps) * softmax(scores / tau) + eps * uniform over the pool;
      3. sample one candidate from μ and record its probability as the propensity;
      4. reward = 1 iff the sampled item equals ``i_star``.

    NOTE(review): ``thresh`` is accepted but not used here — the reward does not
    depend on the rating value. Confirm whether that is intended.

    Args:
        ratings_df: events with ``user_id``, ``movie_id`` and ``ts`` columns; rows
            whose user or movie is missing from ``umap``/``imap`` are skipped.
        users_df, movies_df: raw tables fed to ``preprocess_users`` / ``preprocess_items``.
        U, V: user / item embedding tensors indexed by the dense ids of ``umap`` / ``imap``.
        K, tau, eps: pool size, softmax temperature, uniform-mix weight.
        seed: reseeds torch, ``random`` and NumPy (global RNG side effect).

    Returns:
        ``(logs, demo, genres)`` — ``logs`` is a list of dicts with keys
        ``u, i_star, cand_idx, x_t, logprop, a_pos, a_iid, reward``; ``demo`` and
        ``genres`` are NumPy feature tables indexed by dense user / item id (rows
        for ids absent from ``users_df`` / ``movies_df`` stay zero).
    """
    torch.manual_seed(seed); random.seed(seed); np.random.seed(seed)
    # demos & genres
    demo_df, _ = preprocess_users(users_df)
    genres_df, _ = preprocess_items(movies_df)

    # dense feature tables: row k holds the features of the raw id mapped to k
    max_uid = max(umap.values())+1
    max_iid = max(imap.values())+1
    demo = torch.zeros((max_uid, demo_df.shape[1]), device=device)
    for uid, row in demo_df.iterrows():
        if uid in umap:
            demo[umap[uid]] = torch.tensor(row.values, device=device, dtype=torch.float32)
    genres = torch.zeros((max_iid, genres_df.shape[1]), device=device)
    for iid, row in genres_df.iterrows():
        if iid in imap:
            genres[imap[iid]] = torch.tensor(row.values, device=device, dtype=torch.float32)

    # seen sets per user (internal ids)
    seen_by_user = [set() for _ in range(max_uid)]
    logs = []

    # reset_index makes the iterrows() index positional, so it also indexes time_feats.
    # (default sort kind: ties are not guaranteed to keep their input order)
    ratings_df = ratings_df.sort_values("ts").reset_index(drop=True)
    time_feats = torch.tensor(make_time_features(ratings_df["ts"]), device=device)

    for idx, row in ratings_df.iterrows():
        u_raw = int(row["user_id"]); i_star_raw = int(row["movie_id"])
        if u_raw not in umap or i_star_raw not in imap:
            continue
        u = umap[u_raw]; i_star = imap[i_star_raw]

        # NOTE(review): the mask is rebuilt from the Python set for every event, so
        # per-event cost grows with the user's history.
        seen_mask = torch.zeros(V.shape[0], dtype=torch.bool, device=device)
        if len(seen_by_user[u])>0:
            seen_idx = torch.tensor(list(seen_by_user[u]), device=device, dtype=torch.long)
            seen_mask[seen_idx] = True

        cand_idx, cand_scores = topk_candidates_torch(U[u], V, seen_mask, K=K)
        if (i_star not in cand_idx.tolist()):
            # replace worst: force the true item into the pool, scored with the same
            # plain dot product used for retrieval (no MF biases)
            minpos = torch.argmin(cand_scores)
            cand_idx[minpos] = i_star
            cand_scores[minpos] = (U[u] * V[i_star]).sum()

        # logging policy μ: tempered softmax mixed with a uniform distribution
        p_soft = softmax_temp_torch(cand_scores, tau)
        p_mu = (1.0 - eps) * p_soft + eps * (1.0/len(cand_idx))

        a_pos = torch.multinomial(p_mu, 1).item()
        a_iid = int(cand_idx[a_pos].item())
        logprop = float(p_mu[a_pos].item())
        reward = int(a_iid == i_star)

        # context vector: user embedding, then demographics, then time features
        x_t = torch.cat([U[u], demo[u], time_feats[idx]], dim=0).detach()

        logs.append({
            "u": u, "i_star": i_star, "cand_idx": cand_idx.detach().cpu().numpy().astype('int32'),
            "x_t": x_t.detach().cpu().numpy().astype('float32'),
            "logprop": logprop, "a_pos": a_pos, "a_iid": a_iid, "reward": reward
        })
        # exclude this item from the user's future candidate pools
        seen_by_user[u].add(i_star)

    return logs, demo.detach().cpu().numpy(), genres.detach().cpu().numpy()


In [None]:
def make_pair_features(x_t_np, item_iid, U, V, genres_np):
    """Build the (context, item) feature vector used by Q̂ and the policy.

    Layout: ``[u_emb (D), v_emb (D), u_emb * v_emb (D), genres[item], ctx_rest]``
    with ``D = U.shape[1]``; ``u_emb`` is the first ``D`` entries of ``x_t_np`` and
    ``ctx_rest`` the remaining entries.

    Args:
        x_t_np: NumPy context vector whose first ``D`` entries are the user embedding.
        item_iid: dense item index into ``V`` and ``genres_np``.
        U: user embedding table (only its width is read).
        V: item embedding tensor on ``device``.
        genres_np: NumPy genre table indexed by dense item id.

    Returns:
        1-D tensor on ``device``.
    """
    width = U.shape[1]
    context = torch.tensor(x_t_np, device=device)
    user_part, remainder = context[:width], context[width:]
    item_part = V[item_iid]
    genre_part = torch.tensor(genres_np[item_iid], device=device)
    return torch.cat([user_part, item_part, user_part * item_part, genre_part, remainder], dim=0)

class MLPQ(nn.Module):
    """One-hidden-layer MLP mapping a pair-feature vector to a single click logit."""

    def __init__(self, in_dim, hid=128):
        super().__init__()
        layers = [nn.Linear(in_dim, hid), nn.ReLU(), nn.Linear(hid, 1)]
        # wrapped in `net` so state_dict keys stay net.0.* / net.2.*
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``[N, in_dim]`` features to ``[N, 1]`` raw logits (no sigmoid)."""
        logits = self.net(x)
        return logits

def train_qhat_torch(X, y, hid=128, lr=1e-3, epochs=5, batch=4096):
    """Fit the reward model Q̂ (an ``MLPQ``) with binary cross-entropy.

    Args:
        X: array ``[N, in_dim]`` of pair features.
        y: array ``[N]`` of 0/1 click labels.
        hid: hidden width of the MLP.
        lr, epochs, batch: AdamW learning rate and mini-batch settings.

    Returns:
        The trained model; it outputs logits (apply a sigmoid for probabilities).
    """
    model = MLPQ(X.shape[1], hid).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    crit = nn.BCEWithLogitsLoss()
    X_t = torch.tensor(X, device=device, dtype=torch.float32)
    y_t = torch.tensor(y, device=device, dtype=torch.float32).view(-1, 1)
    n = len(X_t)
    for ep in range(epochs):
        perm = torch.randperm(n, device=device)
        total = 0.0
        for b in perm.split(batch):
            loss = crit(model(X_t[b]), y_t[b])
            opt.zero_grad(); loss.backward(); opt.step()
            # weight by batch size so the printed value is the per-sample epoch mean,
            # not just the (possibly partial) last batch
            total += loss.item() * len(b)
        print(f"[Qhat] epoch {ep+1}/{epochs} BCE={total / max(n, 1):.4f}")
    return model

@torch.no_grad()
def qhat_predict_probs(model, X):
    """Return Q̂ click probabilities ``sigmoid(model(X))`` as a 1-D tensor.

    Args:
        model: module mapping ``[N, in_dim]`` features to ``[N, 1]`` logits.
        X: ``[N, in_dim]`` tensor already on the model's device.

    Returns:
        ``[N]`` probabilities, computed without autograd tracking.
    """
    logits = model(X)
    return logits.sigmoid().squeeze(-1)


In [None]:
class PolicyMLP(nn.Module):
    """Candidate-scoring policy π_θ.

    A shared MLP produces one logit per candidate; a softmax over the candidate
    axis turns the logits into an action distribution.
    """
    def __init__(self, in_dim, hid=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hid), nn.ReLU(),
            nn.Linear(hid, 1)  # per-candidate logit
        )

    def forward(self, X):
        """Score a batch of candidate sets.

        Args:
            X: tensor ``[B, K, in_dim]``; it does not need to be contiguous.

        Returns:
            ``(P, z)`` — probabilities and raw logits, both ``[B, K]``; each row of
            ``P`` sums to 1.
        """
        B, K, feat_dim = X.shape
        # reshape (not view): view() raises on non-contiguous inputs such as
        # transposed or expanded tensors
        z = self.net(X.reshape(B * K, feat_dim)).view(B, K)
        P = F.softmax(z, dim=1)
        return P, z

def build_qhat_dataset(logs, U, V, genres_np, neg_per_pos=1, cap=None):
    """Build a supervised ``(features, label)`` set for the reward model Q̂.

    For each logged event (only the first ``cap`` events when ``cap`` is given):
      * the logged action's pair features, labelled with the observed reward;
      * up to ``neg_per_pos`` other candidates from the same pool, chosen with the
        Python ``random`` module, labelled 0.

    NOTE(review): an unchosen candidate can be the item the user actually rated
    (``i_star``); labelling it 0 treats every unlogged candidate as a non-click,
    which may bias Q̂ downward — confirm this is the intended assumption.

    Returns:
        ``(X, y)`` — float32 arrays ``[M, in_dim]`` and ``[M]``.
    """
    selected = logs if cap is None else logs[:cap]
    rows, labels = [], []

    def _add(x_t, item, label):
        rows.append(make_pair_features(x_t, int(item), U, V, genres_np).cpu().numpy())
        labels.append(label)

    for event in selected:
        pool, chosen = event["cand_idx"], event["a_pos"]
        _add(event["x_t"], pool[chosen], event["reward"])
        alternatives = [pos for pos in range(len(pool)) if pos != chosen]
        random.shuffle(alternatives)
        # take from the end of the shuffled list, at most neg_per_pos of them
        for pos in alternatives[::-1][:max(0, neg_per_pos)]:
            _add(event["x_t"], pool[pos], 0)

    X = np.stack(rows, axis=0).astype(np.float32)
    y = np.array(labels, dtype=np.float32)
    return X, y

def train_policy_dr_torch(logs, U, V, genres_np, qhat_model, K, lr=1e-3, epochs=5, wclip=10.0, ent=0.01, bs=256,
                          hid=None):
    """Train a ``PolicyMLP`` by maximising a clipped doubly-robust value estimate.

    Per logged event with candidates ``c``, logged position ``a``, propensity ``mu``
    and reward ``r``:

        value = Σ_k π(k|x) Q̂(x, c_k) + min(π(a|x) / mu, wclip) * (r - Q̂(x, c_a))

    plus ``ent`` times the entropy of π. The loss is the negative batch mean.

    Args:
        logs: logged events (dicts with ``x_t``, ``cand_idx``, ``a_pos``, ``logprop``,
            ``reward``). The list is NOT modified; shuffling uses an index permutation.
        U, V, genres_np: forwarded to ``make_pair_features``.
        qhat_model: reward model; its probabilities are computed without gradients.
        K: unused (kept for call compatibility); pool sizes come from ``cand_idx``.
        lr, epochs, bs: AdamW learning rate, epochs and mini-batch size.
        wclip: upper clip on importance weights.
        ent: entropy-bonus coefficient.
        hid: hidden width of the policy MLP; ``None`` falls back to the notebook's
            ``policy_hid`` config value (previous behaviour).

    Returns:
        The trained ``PolicyMLP``.

    Raises:
        ValueError: if ``logs`` is empty.
    """
    if not logs:
        raise ValueError("train_policy_dr_torch needs at least one logged event")
    if hid is None:
        hid = policy_hid  # notebook-level config value

    # infer the pair-feature width from the first logged candidate
    x0 = logs[0]["x_t"]; i0 = int(logs[0]["cand_idx"][0])
    in_dim = make_pair_features(x0, i0, U, V, genres_np).numel()
    policy = PolicyMLP(in_dim, hid=hid).to(device)
    opt = torch.optim.AdamW(policy.parameters(), lr=lr)

    def make_batch(batch):
        """Stack candidate features, Q̂ probabilities and logged quantities."""
        X_all, Q_all, a_pos, mu, r = [], [], [], [], []
        with torch.no_grad():
            for ex in batch:
                feats = [make_pair_features(ex["x_t"], int(c), U, V, genres_np) for c in ex["cand_idx"]]
                X = torch.stack(feats, 0).to(device)             # [K, in_dim]
                X_all.append(X)
                Q_all.append(qhat_predict_probs(qhat_model, X))  # [K]
                a_pos.append(ex["a_pos"]); mu.append(ex["logprop"]); r.append(ex["reward"])
        return (torch.stack(X_all, 0),                           # [B, K, in_dim]
                torch.stack(Q_all, 0),                           # [B, K]
                torch.tensor(a_pos, device=device, dtype=torch.long),
                torch.tensor(mu, device=device, dtype=torch.float32),
                torch.tensor(r, device=device, dtype=torch.float32))

    N = len(logs)
    # Shuffle positions instead of the caller's list so `logs` stays untouched.
    # Reshuffling the same index list each epoch yields the same batch sequence
    # as reshuffling the list itself.
    order = list(range(N))
    for ep in range(epochs):
        random.shuffle(order)
        tot = 0.0; steps = 0
        for s in range(0, N, bs):
            batch = [logs[j] for j in order[s:s + bs]]
            X_all, Q_all, a_pos, mu, r = make_batch(batch)
            P, _ = policy(X_all)                                 # [B, K]
            rows = torch.arange(P.size(0), device=device)

            direct = (P * Q_all).sum(1)                          # model-based term
            pi_at = P[rows, a_pos] + 1e-12
            w = (pi_at / (mu + 1e-12)).clamp(max=wclip)          # clipped importance weight
            Qa = Q_all[rows, a_pos]
            resid = w * (r - Qa)                                 # IPS correction of Q̂'s error
            H = -(P * torch.log(P + 1e-12)).sum(1)               # policy entropy

            loss = -(direct + resid + ent * H).mean()
            opt.zero_grad(); loss.backward(); opt.step()

            tot += loss.item(); steps += 1
        print(f"[Policy-DR] epoch {ep+1}/{epochs} loss={tot/max(1,steps):.4f}")
    return policy

@torch.no_grad()
def ope_snips(policy, logs, U, V, genres_np, K):
    """Self-normalised IPS estimate of the policy's CTR on a logged dataset.

    SNIPS = Σ w_t r_t / Σ w_t with unclipped weights w_t = π(a_t|x_t) / μ_t.
    ``K`` is unused. Returns 0.0 when the weight sum is not positive (e.g. empty logs).
    """
    weighted_reward = 0.0
    weight_sum = 0.0
    for event in logs:
        feats = torch.stack(
            [make_pair_features(event["x_t"], int(item), U, V, genres_np) for item in event["cand_idx"]], 0
        ).to(device)
        probs = policy(feats.unsqueeze(0))[0][0]   # [K] action probabilities
        weight = probs[event["a_pos"]].item() / (event["logprop"] + 1e-12)
        weighted_reward += weight * event["reward"]
        weight_sum += weight
    return float(weighted_reward / weight_sum) if weight_sum > 0 else 0.0

@torch.no_grad()
def ope_dr(policy, qhat_model, logs, U, V, genres_np, K, wclip=10.0):
    """Clipped doubly-robust estimate of the policy's CTR on a logged dataset.

    Per event: Σ_k π(k|x) Q̂(x, c_k) + min(π(a|x) / μ, wclip) * (r - Q̂(x, c_a)).
    Returns the mean over events, or 0.0 for empty logs. ``K`` is unused.
    """
    clip = float(wclip)
    per_event = []
    for event in logs:
        chosen = event["a_pos"]
        feats = torch.stack(
            [make_pair_features(event["x_t"], int(item), U, V, genres_np) for item in event["cand_idx"]], 0
        ).to(device)
        probs = policy(feats.unsqueeze(0))[0][0]     # [K] action probabilities
        q = qhat_predict_probs(qhat_model, feats)    # [K] predicted click probabilities
        dm_term = float((probs * q).sum().item())
        weight = min(float(probs[chosen].item() / (event["logprop"] + 1e-12)), clip)
        correction = weight * (event["reward"] - float(q[chosen].item()))
        per_event.append(dm_term + correction)
    return float(np.mean(per_event)) if per_event else 0.0


In [None]:
# ==== Driver ====
ratings, users, movies = read_ml1m(ml1m_dir)
if ratings is None:
    print("MovieLens 1M files not found; using synthetic dataset for a quick run.")
    ratings, users, movies = make_synthetic(n_users=200, n_items=400, n_events=8000)

print("Ratings:", ratings.shape, "Users:", users.shape, "Movies:", movies.shape)
train_df, test_df = chronological_split(ratings, train_ratio=train_ratio)
print(f"Split: train={len(train_df)}, test={len(test_df)}")

# Id maps span train AND test, so test-only users/items get U/V rows that the
# MF loss (fit on train_df only) never touches.
umap, imap = build_id_maps(ratings)

# Train MF on the train split only
U, V, bu, bi, g = train_mf_torch(train_df, umap, imap, D=D, epochs=mf_epochs, lr=mf_lr, weight_decay=mf_reg, thresh=rating_thresh)

# Simulate μ logs (different seeds for train/test)
train_logs, demo_np, genres_np = simulate_bandit_logs_torch(train_df, users, movies, U, V, umap, imap, K=K, tau=tau, eps=eps, thresh=rating_thresh, seed=123)
test_logs,  _,       _         = simulate_bandit_logs_torch(test_df,  users, movies, U, V, umap, imap, K=K, tau=tau, eps=eps, thresh=rating_thresh, seed=456)
print(f"Logs built — train: {len(train_logs)}, test: {len(test_logs)}")

# Build Q̂ dataset: logged action + 2 sampled unchosen candidates per event (first 20k events)
Xq, yq = build_qhat_dataset(train_logs, U, V, genres_np, neg_per_pos=2, cap=20000)
print("Q̂ dataset:", Xq.shape, yq.shape)

# Train Q̂
qhat_model = train_qhat_torch(Xq, yq, hid=qhat_hid, lr=qhat_lr, epochs=qhat_epochs, batch=4096)

# Train policy with DR
policy_model = train_policy_dr_torch(train_logs, U, V, genres_np, qhat_model, K=K,
                                     lr=policy_lr, epochs=policy_epochs, wclip=weight_clip, ent=entropy_coef, bs=batch_size)

# OPE on held-out test log
snips = ope_snips(policy_model, test_logs, U, V, genres_np, K)
drval = ope_dr(policy_model, qhat_model, test_logs, U, V, genres_np, K, wclip=weight_clip)
# Observed CTR of the logging policy μ on the same log: the reference point for the OPE numbers.
mu_ctr = float(np.mean([ex["reward"] for ex in test_logs])) if test_logs else 0.0
print("\n=== OPE on test log ===")
print(f"Logging μ CTR (observed): {mu_ctr:.4f}")
print(f"SNIPS CTR estimate: {snips:.4f}")
print(f"DR    CTR estimate: {drval:.4f}")


Ratings: (1000209, 4) Users: (6040, 5) Movies: (3883, 3)
Split: train=900188, test=100021
[MF-torch] epoch 1/5 MSE=0.2429
[MF-torch] epoch 2/5 MSE=0.2430
[MF-torch] epoch 3/5 MSE=0.2413
[MF-torch] epoch 4/5 MSE=0.2414
[MF-torch] epoch 5/5 MSE=0.2383
Logs built — train: 900188, test: 100021
Q̂ dataset: (60000, 147) (60000,)
[Qhat] epoch 1/5 BCE=0.4797
[Qhat] epoch 2/5 BCE=0.2922
[Qhat] epoch 3/5 BCE=0.1466
[Qhat] epoch 4/5 BCE=0.0848
[Qhat] epoch 5/5 BCE=0.0556
[Policy-DR] epoch 1/8 loss=-0.1809
[Policy-DR] epoch 2/8 loss=-0.2009
[Policy-DR] epoch 3/8 loss=-0.2017
[Policy-DR] epoch 4/8 loss=-0.2027
[Policy-DR] epoch 5/8 loss=-0.2034
[Policy-DR] epoch 6/8 loss=-0.2040
[Policy-DR] epoch 7/8 loss=-0.2045
[Policy-DR] epoch 8/8 loss=-0.2049
\n=== OPE on test log ===
SNIPS CTR estimate: 0.7952
DR    CTR estimate: 0.1940
