In [61]:
import ijson
import json
import numpy as np
import random
from sklearn.cluster import MiniBatchKMeans
from collections import deque, Counter, defaultdict
import pandas as pd

# ----------------------------------
# 参数设定
# ----------------------------------
TRAIN_FILE     = 'rl_dataset_train.json'
TEST_FILE      = 'rl_dataset_test.json'
K_SLATE        = 20        
S_CLUSTERS     = 200      
EPISODES       = 20       
REPLAY_CAP     = 200000   # 经验回放容量
BATCH_SIZE     = 256      # 批更新大小
ALPHA          = 0.05
GAMMA          = 0.95
EPSILON_START  = 0.5
EPSILON_DECAY  = 0.9
TOP_M          = 50       # 候选集M
N_TRAIN        = 100000    # 只读入2万训练
N_TEST         = 20000     # 读入5千测试
CATEGORIES_C   = 100      # 类别One-Hot维度

# ----------------------------------
# 工具：加载chunked JSON
# ----------------------------------
def load_chunked_json(path, n=None):
    recs, depth, buf = [], 0, ''
    with open(path, 'r', encoding='utf-8') as f:
        if f.read(1) != '{': f.seek(0)
        while True:
            c = f.read(1)
            if not c: break
            if c == '{': depth += 1; buf += c
            elif c == '}': buf += c; depth -= 1
            else:
                if depth > 0: buf += c
            if depth == 0 and buf:
                try: recs.append(json.loads(buf))
                except: pass
                buf = ''
                if n and len(recs) >= n: break
    return recs

# 加载数据
train_recs = load_chunked_json(TRAIN_FILE, n=N_TRAIN)
test_recs  = load_chunked_json(TEST_FILE,  n=N_TEST)
print(f"Loaded {len(train_recs)} train, {len(test_recs)} test")

# ----------------------------------
# 建立商户ID索引
# ----------------------------------
all_biz = set()
for rec in train_recs + test_recs:
    all_biz.add(rec['action'])
    for e in rec.get('state',[])+rec.get('next_state',[]):
        all_biz.add(e['business_id'])
biz2idx = {b:i for i,b in enumerate(sorted(all_biz))}
idx2biz = {i:b for b,i in biz2idx.items()}
N_BIZ = len(biz2idx)
print(f"Unique businesses: {N_BIZ}")

Loaded 12000 train, 3000 test
Unique businesses: 11953


In [62]:
# ----------------------------------
# Step 1: 类别映射 & Biz-Category Map
# ----------------------------------
counter = Counter()
for rec in train_recs:
    for e in rec.get('state',[]) + rec.get('next_state',[]):
        counter.update(c.strip() for c in e.get('business_categories', []))
CATEGORY_LIST = [cat for cat,_ in counter.most_common(CATEGORIES_C)]
print("Top categories:", CATEGORY_LIST)
# biz -> set(categories)
biz_cat_sets = defaultdict(set)
for rec in train_recs + test_recs:
    for e in rec.get('state',[]) + rec.get('next_state',[]):
        bi = biz2idx[e['business_id']]
        for c in e.get('business_categories',[]):
            biz_cat_sets[bi].add(c.strip())
# category -> list(biz_idx)
cat_biz_map = defaultdict(list)
for bi, cats in biz_cat_sets.items():
    for c in cats:
        if c in CATEGORY_LIST:
            cat_biz_map[c].append(bi)

Top categories: ['Restaurants', 'Food', 'Nightlife', 'Bars', 'American (New)', 'American (Traditional)', 'Breakfast & Brunch', 'Shopping', 'Sandwiches', 'Coffee & Tea', 'Italian', 'Arts & Entertainment', 'Pizza', 'Event Planning & Services', 'Seafood', 'Burgers', 'Mexican', 'Specialty Food', 'Salad', 'Desserts', 'Bakeries', 'Active Life', 'Japanese', 'Sushi Bars', 'Chinese', 'Cafes', 'Fast Food', 'Beauty & Spas', 'Beer', 'Cocktail Bars', 'Fashion', 'Local Flavor', 'Ice Cream & Frozen Yogurt', 'Pubs', 'Hotels & Travel', 'Asian Fusion', 'Steakhouses', 'Wine & Spirits', 'Cajun/Creole', 'Grocery', 'Delis', 'Vegetarian', 'Caterers', 'Barbeque', 'Wine Bars', 'Diners', 'Venues & Event Spaces', 'Lounges', 'Automotive', 'Thai', 'Southern', 'Sports Bars', 'Home & Garden', 'Local Services', 'Mediterranean', 'Music Venues', 'Chicken Wings', 'French', 'Gluten-Free', 'Vietnamese', 'Vegan', 'Breweries', 'Latin American', 'Soup', 'Home Services', 'Nail Salons', 'Health & Medical', 'Public Services & G

In [63]:
# ----------------------------------
# Step 2: 特征化 & 状态离散化
# ----------------------------------
def featurize_event(e):
    f = [
        e.get('days_since_action', 0),
        e.get('recency_weight', 1.0),
        e.get('stars', 0.0),
        e.get('user_profile',{}).get('review_count',0),
        e.get('user_profile',{}).get('average_stars',0.0),
        e.get('user_profile',{}).get('friend_count',0),
        e.get('business_checkin',{}).get('morning',0),
        e.get('business_checkin',{}).get('afternoon',0),
        e.get('business_checkin',{}).get('evening',0),
        len(e.get('business_categories',[]))
    ]
    for seg in ['morning','afternoon','evening']:
        f.append(1 if e.get('time_segment','')==seg else 0)
    cats = set(c.strip() for c in e.get('business_categories',[]))
    for cat in CATEGORY_LIST:
        f.append(1 if cat in cats else 0)
    return f

def flatten_states(recs, field):
    X = []
    for rec in recs:
        vec=[]
        for e in rec.get(field,[]): vec += featurize_event(e)
        X.append(vec)
    return np.array(X)

S_train  = flatten_states(train_recs,'state')
SP_train = flatten_states(train_recs,'next_state')
S_test   = flatten_states(test_recs,'state')
kmeans = MiniBatchKMeans(n_clusters=S_CLUSTERS,random_state=0,batch_size=1024)
state_ids_train = kmeans.fit_predict(S_train)
next_ids_train  = kmeans.predict(SP_train)
state_ids_test  = kmeans.predict(S_test)

actions_train = np.array([biz2idx[r['action']] for r in train_recs])
rewards_train = np.array([r['reward'] for r in train_recs])


In [64]:
# ----------------------------------
# Step 3: 经验回放 & 用户选择
# ----------------------------------
class ReplayBuffer:
    def __init__(self,cap): self.buf=deque(maxlen=cap)
    def push(self,x):self.buf.append(x)
    def sample(self,n):return random.sample(self.buf,n)
    def __len__(self):return len(self.buf)
buffer=ReplayBuffer(REPLAY_CAP)

def user_choice(slate,s,Q):
    vals=np.array([Q[s,a] for a in slate])
    exp=np.exp(vals-vals.max());probs=exp/exp.sum()
    return np.random.choice(slate,p=probs)


In [65]:
# ----------------------------------
# Step 4: 候选集生成 (Category-based)
# ----------------------------------

def retrieve_candidates(rec, M=TOP_M):
    # 基于用户最后交互的类别
    last_cats = rec['state'][-1]['business_categories']
    cand = []
    for c in last_cats:
        for bi in cat_biz_map.get(c.strip(),[]):
            if bi not in cand: cand.append(bi)
    if len(cand) >= M:
        return random.sample(cand, M)
    # 不够则补全随机
    pool = [i for i in range(N_BIZ) if i not in cand]
    return cand + random.sample(pool, M-len(cand))

In [66]:
# ----------------------------------
# Step 5: Slate 生成 (MMR)
# ----------------------------------
def generate_slate_mmr(s, rec, Q, M=TOP_M, K=K_SLATE, lam=0.7):
    cand = retrieve_candidates(rec, M)
    slate = []
    while len(slate) < K:
        best, best_score = None, -1e9
        for a in cand:
            if a in slate: continue
            rel = Q[s,a]
            if not slate:
                score = rel
            else:
                # Jaccard similarity
                div = max(
                    len(biz_cat_sets[a] & biz_cat_sets[b])/
                    len(biz_cat_sets[a] | biz_cat_sets[b])
                    for b in slate
                )
                score = lam*rel - (1-lam)*div
            if score>best_score:
                best_score, best = score, a
        slate.append(best)
    return slate

In [67]:
# ----------------------------------
# Step 6: Tabular SlateQ Agent
# ----------------------------------
class SlateQAgent:
    def __init__(self,S,N,alpha,gamma,eps):
        self.Q = np.zeros((S,N));self.alpha=alpha;self.gamma=gamma;self.epsilon=eps
    def select(self,s,rec):
        if random.random()<self.epsilon:
            return random.sample(range(N_BIZ),K_SLATE)
        return generate_slate_mmr(s,rec,self.Q)
    def update(self,s,slate,r,s_next):
        best = np.max(self.Q[s_next])
        for a in slate:
            td = r + self.gamma*best - self.Q[s,a]
            self.Q[s,a] += self.alpha*td
agent=SlateQAgent(S_CLUSTERS,N_BIZ,ALPHA,GAMMA,EPSILON_START)


In [68]:
# ----------------------------------
# Step 7: 训练循环（带进度条）
# ----------------------------------
from tqdm.auto import tqdm

for ep in range(EPISODES):
    print(f"--- Starting Episode {ep+1}/{EPISODES} ---")
    # 使用 tqdm 展示对训练样本的进度
    for i, rec in enumerate(tqdm(train_recs, desc=f"Epoch {ep+1} samples", ncols=80)):
        s = state_ids_train[i]
        slate = agent.select(s, rec)
        r = rewards_train[i]
        s_next = next_ids_train[i]
        buffer.push((s, slate, r, s_next))

    # 批量更新也用 tqdm
    for _ in tqdm(range(len(buffer) // BATCH_SIZE), desc="Batch updates", ncols=80):
        for s, slate, r, s_next in buffer.sample(BATCH_SIZE):
            agent.update(s, slate, r, s_next)

    agent.epsilon *= EPSILON_DECAY
    print(f"Completed Episode {ep+1}/{EPISODES}, ε={agent.epsilon:.3f}")



--- Starting Episode 1/10 ---


Epoch 1 samples:   0%|                                | 0/12000 [00:00<?, ?it/s]

Batch updates:   0%|                                     | 0/46 [00:00<?, ?it/s]

Completed Episode 1/10, ε=0.450
--- Starting Episode 2/10 ---


Epoch 2 samples:   0%|                                | 0/12000 [00:00<?, ?it/s]

Batch updates:   0%|                                     | 0/93 [00:00<?, ?it/s]

Completed Episode 2/10, ε=0.405
--- Starting Episode 3/10 ---


Epoch 3 samples:   0%|                                | 0/12000 [00:00<?, ?it/s]

Batch updates:   0%|                                    | 0/140 [00:00<?, ?it/s]

Completed Episode 3/10, ε=0.365
--- Starting Episode 4/10 ---


Epoch 4 samples:   0%|                                | 0/12000 [00:00<?, ?it/s]

Batch updates:   0%|                                    | 0/187 [00:00<?, ?it/s]

Completed Episode 4/10, ε=0.328
--- Starting Episode 5/10 ---


Epoch 5 samples:   0%|                                | 0/12000 [00:00<?, ?it/s]

Batch updates:   0%|                                    | 0/234 [00:00<?, ?it/s]

Completed Episode 5/10, ε=0.295
--- Starting Episode 6/10 ---


Epoch 6 samples:   0%|                                | 0/12000 [00:00<?, ?it/s]

Batch updates:   0%|                                    | 0/281 [00:00<?, ?it/s]

Completed Episode 6/10, ε=0.266
--- Starting Episode 7/10 ---


Epoch 7 samples:   0%|                                | 0/12000 [00:00<?, ?it/s]

Batch updates:   0%|                                    | 0/328 [00:00<?, ?it/s]

Completed Episode 7/10, ε=0.239
--- Starting Episode 8/10 ---


Epoch 8 samples:   0%|                                | 0/12000 [00:00<?, ?it/s]

Batch updates:   0%|                                    | 0/375 [00:00<?, ?it/s]

Completed Episode 8/10, ε=0.215
--- Starting Episode 9/10 ---


Epoch 9 samples:   0%|                                | 0/12000 [00:00<?, ?it/s]

Batch updates:   0%|                                    | 0/421 [00:00<?, ?it/s]

Completed Episode 9/10, ε=0.194
--- Starting Episode 10/10 ---


Epoch 10 samples:   0%|                               | 0/12000 [00:00<?, ?it/s]

Batch updates:   0%|                                    | 0/468 [00:00<?, ?it/s]

Completed Episode 10/10, ε=0.174


In [69]:
# ----------------------------------
# Step 8: 离线评估 HIT/NDCG...
# ----------------------------------
agent.epsilon=0.0
hits=0;ndcg=0.0
for i,rec in enumerate(test_recs):
    s = state_ids_test[i]
    slate=agent.select(s,rec)
    true = biz2idx[rec['action']]
    if true in slate:
        hits+=1
        pos = slate.index(true)
        ndcg += 1/np.log2(pos+2)
num=len(test_recs)
hit_rate=hits/num
precision=hits/(num*K_SLATE)
recall=hit_rate
f1 = 2*precision*recall/(precision+recall) if precision+recall>0 else 0.0
ndcg=ndcg/num
print(f"Hit Rate@{K_SLATE}: {hit_rate:.4f}")
print(f"Precision@{K_SLATE}: {precision:.4f}")
print(f"Recall@{K_SLATE}: {recall:.4f}")
print(f"F1@{K_SLATE}: {f1:.4f}")
print(f"NDCG@{K_SLATE}: {ndcg:.4f}")


Hit Rate@20: 0.0013
Precision@20: 0.0001
Recall@20: 0.0013
F1@20: 0.0001
NDCG@20: 0.0003
