In [None]:
# offline_rl_slateq.py

import numpy as np
import pandas as pd
from math import log2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# -----------------------------
# 1. Load train/test data
# -----------------------------
def read_jsonl(path):
    """Read a JSON Lines source (one JSON object per line) into a DataFrame.

    Args:
        path: Anything ``pandas.read_json`` accepts as its first argument
            (a file path or a file-like object).

    Returns:
        A ``pandas.DataFrame`` with one row per JSON line.
    """
    frame = pd.read_json(path, lines=True)
    return frame

# Training split. Reviews are ordered by user and then by date, and re-indexed
# from zero, so each user's rows are contiguous and chronological.
train_users = read_jsonl('train_users.jsonl')
train_businesses = read_jsonl('train_businesses.jsonl')
train_reviews = (
    read_jsonl('train_reviews.jsonl')
    .sort_values(['user_id', 'date'])
    .reset_index(drop=True)
)

# Test split, loaded and ordered exactly like the training split.
test_users = read_jsonl('test_users.jsonl')
test_businesses = read_jsonl('test_businesses.jsonl')
test_reviews = (
    read_jsonl('test_reviews.jsonl')
    .sort_values(['user_id', 'date'])
    .reset_index(drop=True)
)

# -----------------------------
# 2. Preprocessing - categories vocabulary & fans_count
# -----------------------------
# Vocabulary of category labels seen in the training businesses: every
# comma-separated entry is stripped of surrounding whitespace, de-duplicated,
# and sorted so that the label -> column mapping is deterministic.
all_categories = sorted({
    label.strip()
    for cats_field in train_businesses['categories'].dropna()
    for label in cats_field.split(',')
})
# Position of each label in the sorted vocabulary (used as multi-hot column).
cat2idx = dict(zip(all_categories, range(len(all_categories))))

def parse_fans(x):
    """Turn a raw ``fans`` value into a non-negative integer count.

    Args:
        x: The raw field value. Supported forms:
            * ``str``  - comma-separated entries; blank entries are ignored,
              so an empty string yields 0 (``"".split(',')`` would otherwise
              be counted as one entry).
            * ``list`` - the number of elements.
            * real number - treated as an already-computed count
              (NOTE(review): assumes a numeric ``fans`` column stores the
              count itself - confirm against the data schema).

    Returns:
        int: The count, or 0 for ``None``, NaN/inf, negative numbers, booleans
        and any other unsupported type.
    """
    if isinstance(x, str):
        # Count only non-blank tokens so "" and "a,,b" are not over-counted.
        return sum(1 for token in x.split(',') if token.strip())
    if isinstance(x, list):
        return len(x)
    # bool is a subclass of int; it is not a meaningful count, so exclude it.
    is_number = isinstance(x, (int, float, np.integer, np.floating))
    if is_number and not isinstance(x, bool):
        if np.isfinite(x) and x > 0:
            return int(x)
    return 0

# Add a numeric 'fans_count' column to both user tables, derived from the raw
# 'fans' field via parse_fans.
for users_frame in (train_users, test_users):
    fans_column = users_frame['fans']
    users_frame['fans_count'] = fans_column.apply(parse_fans)

# -----------------------------
# 3. Feature extraction
# -----------------------------
def extract_user_feat(row):
    """Build the 3-element float32 feature vector for one user.

    Args:
        row: A mapping-like record (e.g. a DataFrame row) that can be indexed
            with 'average_stars', 'review_count' and 'fans_count'.

    Returns:
        np.ndarray: ``[average_stars, review_count, fans_count]`` as float32.
    """
    field_names = ('average_stars', 'review_count', 'fans_count')
    values = [row[name] for name in field_names]
    return np.array(values, dtype=np.float32)

def extract_biz_feat(row):
    """Build the float32 feature vector for one business.

    The vector is ``[stars, review_count]`` followed by a multi-hot encoding of
    the business's categories over the module-level ``all_categories``
    vocabulary. Labels that are not in ``cat2idx`` are ignored, and a missing
    'categories' value leaves the multi-hot part all zeros.

    Args:
        row: A mapping-like record indexable with 'stars', 'review_count' and
            'categories' (a comma-separated string, or a missing value).

    Returns:
        np.ndarray: float32 vector of length ``2 + len(all_categories)``.
    """
    numeric_values = [row['stars'], row['review_count']]
    category_flags = np.zeros(len(all_categories), dtype=np.float32)

    raw_categories = row['categories']
    if pd.notna(raw_categories):
        for label in map(str.strip, raw_categories.split(',')):
            if label in cat2idx:
                category_flags[cat2idx[label]] = 1.0

    numeric_part = np.array(numeric_values, dtype=np.float32)
    return np.concatenate([numeric_part, category_flags])

# -----------------------------
# 4. Construct transitions
# -----------------------------
def build_transitions(reviews_df, users_df, biz_df):
    """Turn per-user review sequences into (s, a, r, s2) transitions.

    For every user, each pair of consecutive reviews (in the row order in
    which they appear in ``reviews_df``) yields one transition:

    * ``s``  - user features concatenated with the current business features
    * ``a``  - a copy of the current business features
    * ``r``  - 1.0 if the current review has ``stars >= 4``, else 0.0
    * ``s2`` - the same user features concatenated with the next business
      features

    Users with fewer than two reviews contribute no transitions.

    Args:
        reviews_df: DataFrame with 'user_id', 'business_id' and 'stars'
            columns. Rows of one user are paired in their existing order, so
            the caller is responsible for sorting them chronologically.
        users_df: DataFrame with a 'user_id' column plus the columns read by
            ``extract_user_feat``.
        biz_df: DataFrame with a 'business_id' column plus the columns read by
            ``extract_biz_feat``.

    Returns:
        list[tuple[np.ndarray, np.ndarray, float, np.ndarray]]: one tuple per
        transition.

    Raises:
        KeyError: If a reviewed user or business is absent from ``users_df`` /
            ``biz_df``.
    """
    users_idx = users_df.set_index('user_id')
    biz_idx   = biz_df.set_index('business_id')

    # A business's feature vector depends only on its row, so compute it once
    # per business instead of twice per transition.
    biz_feat_cache = {}

    def biz_feat(business_id):
        feat = biz_feat_cache.get(business_id)
        if feat is None:
            feat = extract_biz_feat(biz_idx.loc[business_id])
            biz_feat_cache[business_id] = feat
        return feat

    trans = []
    for uid, grp in reviews_df.groupby('user_id', sort=False):
        if len(grp) < 2:
            continue
        # User features do not change along the sequence: look them up once
        # per user rather than once per transition.
        u_feat = extract_user_feat(users_idx.loc[uid])
        # Plain Python lists avoid a per-row .loc lookup inside the loop.
        biz_ids = grp['business_id'].tolist()
        stars = grp['stars'].tolist()
        # zip stops at the shortest input, i.e. after len(grp) - 1 pairs.
        for cur_bid, nxt_bid, cur_stars in zip(biz_ids, biz_ids[1:], stars):
            b_feat = biz_feat(cur_bid)
            s  = np.concatenate([u_feat, b_feat])
            # Copy so the stored action never aliases the cached vector.
            a  = b_feat.copy()
            r  = 1.0 if cur_stars >= 4 else 0.0
            s2 = np.concatenate([u_feat, biz_feat(nxt_bid)])
            trans.append((s, a, r, s2))
    return trans

# Materialise the transition lists for both splits and report their sizes.
train_trans = build_transitions(
    reviews_df=train_reviews,
    users_df=train_users,
    biz_df=train_businesses,
)
test_trans = build_transitions(
    reviews_df=test_reviews,
    users_df=test_users,
    biz_df=test_businesses,
)
print(f"Train transitions: {len(train_trans)}  Test transitions: {len(test_trans)}")

# -----------------------------
# 5. PyTorch Dataset
# -----------------------------
class TransDataset(Dataset):
    """Dataset view over a sequence of ``(s, a, r, s2)`` transitions.

    ``s``, ``a`` and ``s2`` are expected to be NumPy arrays and ``r`` a scalar;
    items are returned as tensors in the same order.
    """

    def __init__(self, transitions):
        # Keep a reference only; nothing is copied or converted up front.
        self.data = transitions

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        state, action, reward, next_state = self.data[i]
        # from_numpy shares memory with the stored arrays (no copy).
        as_tensor = torch.from_numpy
        reward_t = torch.tensor(reward, dtype=torch.float32)
        return (as_tensor(state), as_tensor(action), reward_t, as_tensor(next_state))

# Shuffled mini-batches of 256 training transitions.
train_dataset = TransDataset(train_trans)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

# -----------------------------
# 6. Define Q-network
# -----------------------------
class QNet(nn.Module):
    """Three-layer MLP that scores a (state, action) pair with one scalar.

    Args:
        state_dim: Width of the state vector.
        action_dim: Width of the action vector.
        hidden: Width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()
        in_features = state_dim + action_dim
        layers = [
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, s, a):
        """Return one Q-value per row for 2-D inputs ``s`` and ``a``."""
        # State and action are joined along the feature axis.
        joint = torch.cat((s, a), dim=1)
        scores = self.net(joint)
        # Drop the trailing size-1 output dimension.
        return scores.squeeze(1)

# Feature widths: 3 user features; 2 numeric business features plus one
# multi-hot column per category. The state is user + business features and the
# action is a business feature vector.
user_dim = 3
biz_dim = len(all_categories) + 2
state_dim = user_dim + biz_dim
action_dim = biz_dim

# Use the GPU when one is available.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Online network and a target network initialised with identical weights.
Q = QNet(state_dim, action_dim).to(device)
Q_tar = QNet(state_dim, action_dim).to(device)
Q_tar.load_state_dict(Q.state_dict())

# Only the online network's parameters are optimised.
opt = optim.Adam(Q.parameters(), lr=1e-3)

# -----------------------------
# 7. SlateQ (SARSA-TS) training
# -----------------------------
gamma = 0.99   # discount factor applied to the next-state value
tau = 0.005    # soft-update rate for the target network
K = 10         # slate size

# Candidate item matrix: one business feature vector per row, kept on `device`.
# NOTE(review): the candidate pool is built from test_businesses even though it
# is used while training - confirm this is intended.
candidate_rows = test_businesses.set_index('business_id')
candidate_matrix = np.stack(
    [extract_biz_feat(biz_row) for _, biz_row in candidate_rows.iterrows()]
)
cand_feats = torch.from_numpy(candidate_matrix).to(device)

# The loss module is stateless, so build it once instead of once per batch.
loss_fn = nn.MSELoss()
n_cand = cand_feats.size(0)
# torch.topk raises when k exceeds the number of candidates; clamp the slate.
slate_k = min(K, n_cand)

for epoch in range(20):
    total_loss = 0.0
    for s, a, r, s2 in train_loader:
        s, a, r, s2 = s.to(device), a.to(device), r.to(device), s2.to(device)

        # Q-value of the logged (state, action) pair from the online network.
        q = Q(s, a)

        # The bootstrap target must not receive gradients. Computing it under
        # no_grad avoids building an autograd graph over batch * n_cand rows.
        with torch.no_grad():
            batch = s2.size(0)
            # Pair each next-state user vector with every candidate business.
            # expand creates broadcast views rather than materialised copies.
            user_part = s2[:, :user_dim].unsqueeze(1).expand(-1, n_cand, -1)
            feats = cand_feats.unsqueeze(0).expand(batch, -1, -1)
            inp = torch.cat([user_part, feats], dim=2).reshape(-1, state_dim)
            qs2 = Q_tar(inp, feats.reshape(-1, biz_dim)).view(batch, n_cand)
            # Next-state value = mean target Q over the top-k candidates.
            topk_vals, _ = torch.topk(qs2, slate_k, dim=1)
            q_next = topk_vals.mean(dim=1)

        y = r + gamma * q_next

        loss = loss_fn(q, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item()

        # Soft (Polyak) update of the target network towards the online one.
        with torch.no_grad():
            for p, pt in zip(Q.parameters(), Q_tar.parameters()):
                pt.copy_(tau * p + (1 - tau) * pt)

    print(f"Epoch {epoch+1}  loss={total_loss/len(train_loader):.4f}")
