# Train Recsys Artifacts on MongoDB (Two-Tower + Phoenix + FAISS)

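This notebook **trains on your own corpus (MongoDB `posts` / `user_actions`)** and produces artifacts that can be published directly to GCS, so that the production `ml-services` `ANN` endpoint returns **Mongo `posts._id` values (ObjectId strings)** and OON (Out-of-Network) retrieval works end to end.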

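## Do you need a GPU?

- **Required**: Two-Tower + FAISS (otherwise ANN keeps returning external-corpus IDs such as `Nxxxxx`, which cannot be hydrated).
- Two-Tower training: small datasets **run on CPU**, just slowly; a GPU is recommended (either an A100 or a T4 works).
- Phoenix training: this is a Transformer, so a **GPU is strongly recommended**; also prefer a smaller model configuration so that CPU inference on Cloud Run can meet its p95 latency target.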

## Output files (paths match what the server loads)

- `data/news_vocab.pkl`: keys must be **Mongo ObjectId strings**
- `data/user_vocab.pkl`
- `data/item_embeddings.npy`
- `models/two_tower_epoch_latest.pt`
- `models/phoenix_epoch_latest.pt` (optional)
- `models/faiss_ivf_pq.index`
- `models/faiss_id_mapping.pkl`

> The final step produces a `stage/` directory laid out as `two_tower/model.pt`, `phoenix/model.pt`, `faiss/*`, `data/*`, which you upload from your local machine with `gcloud storage cp -r` to `gs://telegram-467705-recsys/artifacts/$ARTIFACT_VERSION/`.


In [None]:
# --- 0) Environment check / GPU detection ---
import os
import sys
import platform

print('python:', sys.version)
print('platform:', platform.platform())

try:
    import torch
    print('torch:', torch.__version__)
    print('cuda available:', torch.cuda.is_available())
    if torch.cuda.is_available():
        print('cuda:', torch.version.cuda)
        print('gpu:', torch.cuda.get_device_name(0))
except Exception as e:
    print('torch import failed:', e)


In [None]:
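# --- 1) Install dependencies (Colab usually ships torch/numpy; this only adds what is missing) ---
# Note: faiss-cpu is used to build the index; training itself does not depend on faiss.
# If you hit an install conflict, restart the runtime and run this cell again.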

!pip -q install "pymongo>=4.6" "tqdm>=4.65" "faiss-cpu>=1.7.4" "google-cloud-storage>=2.16" "python-dotenv>=1.0"


In [None]:
# --- 2) Configuration (defaults sized for a small corpus; adjust as needed) ---
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path

@dataclass
class Cfg:
    # Time windows used when exporting training data
    days_posts: int = 30
    days_actions: int = 30

    # Two-Tower
    two_tower_embedding_dim: int = 256
    max_history: int = 100
    neg_per_pos: int = 4
    batch_size: int = 2048
    epochs: int = 3
    lr: float = 1e-3

    # Phoenix (optional; trains a small model by default, suited to CPU inference on Cloud Run)
    train_phoenix: bool = True
    phoenix_embedding_dim: int = 256
    phoenix_num_heads: int = 8
    phoenix_num_layers: int = 4
    phoenix_batch_size: int = 512
    phoenix_epochs: int = 2
    phoenix_lr: float = 5e-4

    # Output root (same directory layout as ml-services)
    out_root: Path = Path('.')

cfg = Cfg()
print(cfg)
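
Two of the settings above interact in ways that only surface later as runtime errors: PyTorch multi-head attention requires the embedding dimension to be divisible by the number of heads, and the Phoenix model defined later in this notebook uses a 512-slot position embedding, so history plus candidates has to fit in 512 positions. A minimal pre-flight sketch; `check_phoenix_config` and `MAX_POSITIONS` are hypothetical names, not part of `ml-services`:

```python
# Pre-flight check for the Phoenix settings (illustrative helper, not part of ml-services).
MAX_POSITIONS = 512  # assumed to mirror nn.Embedding(512, ...) in PhoenixRanker


def check_phoenix_config(embedding_dim: int, num_heads: int, max_history: int,
                         num_candidates: int = 1) -> list[str]:
    """Return a list of human-readable problems; an empty list means the config looks fine."""
    problems = []
    if embedding_dim % num_heads != 0:
        problems.append(
            f"embedding_dim={embedding_dim} is not divisible by num_heads={num_heads}"
        )
    if max_history + num_candidates > MAX_POSITIONS:
        problems.append(
            f"max_history + candidates = {max_history + num_candidates} "
            f"exceeds {MAX_POSITIONS} positions"
        )
    return problems


print(check_phoenix_config(embedding_dim=256, num_heads=8, max_history=100))
```

In the notebook you would pass `cfg.phoenix_embedding_dim`, `cfg.phoenix_num_heads`, and `cfg.max_history` instead of literals.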


In [None]:
# --- 3) Connect to MongoDB (prompt via getpass so the password never lands in the notebook) ---
import getpass
from pymongo import MongoClient
from pymongo.errors import ConfigurationError

if not os.environ.get('MONGODB_URI'):
    os.environ['MONGODB_URI'] = getpass.getpass('MONGODB_URI (Mongo Atlas URI): ')

mongo_uri = os.environ['MONGODB_URI']
client = MongoClient(mongo_uri)
try:
    db = client.get_default_database()
except ConfigurationError:
    # get_default_database() raises when the URI names no database, so ask for one explicitly
    db_name = os.environ.get('MONGODB_DB') or os.environ.get('MONGODB_DATABASE') or getpass.getpass('Mongo DB name: ')
    db = client[db_name]

print('db:', db.name)
posts_col = db['posts']
actions_col = db['user_actions']


In [None]:
# --- 4) Fetch posts / user_actions (a small corpus fits in memory; a larger one needs batching or spilling to disk) ---
from tqdm import tqdm

now = datetime.now(timezone.utc)
cutoff_posts = (now - timedelta(days=cfg.days_posts)).replace(tzinfo=None)
cutoff_actions = (now - timedelta(days=cfg.days_actions)).replace(tzinfo=None)

# Train only on posts that are not news and not deleted
posts_cursor = posts_col.find(
    {
        'deletedAt': None,
        'isNews': {'$ne': True},
        'createdAt': {'$gte': cutoff_posts},
    },
    {'_id': 1, 'authorId': 1, 'createdAt': 1, 'engagementScore': 1}
)
posts = list(posts_cursor)
post_ids = [str(p['_id']) for p in posts]
post_id_set = set(post_ids)
print('posts:', len(post_ids))

# Actions used for training (note: the frontend now logs impressions on viewport entry, so data quality is reasonably good)
actions_cursor = actions_col.find(
    {
        'timestamp': {'$gte': cutoff_actions},
        'targetPostId': {'$exists': True, '$ne': None},
        'action': {'$in': ['impression', 'click', 'like', 'reply', 'repost', 'quote', 'delivery']},
    },
    {'userId': 1, 'action': 1, 'targetPostId': 1, 'timestamp': 1, 'requestId': 1}
)

actions = []
for a in actions_cursor:
    pid = str(a.get('targetPostId')) if a.get('targetPostId') is not None else None
    if not pid or pid not in post_id_set:
        continue
    uid = str(a.get('userId') or '')
    if not uid:
        continue
    actions.append({
        'userId': uid,
        'action': str(a.get('action')),
        'postId': pid,
        'ts': a.get('timestamp'),
        'requestId': a.get('requestId'),
    })

print('actions:', len(actions))
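
The cell above loads every matching action into memory, which is fine at small scale. For a larger corpus, one option is to drain the cursor in fixed-size chunks and spill each chunk to disk as JSON Lines. A minimal sketch, assuming the cursor can be treated as a plain iterable of dicts; `iter_chunks`, `spill_to_jsonl`, and the output file name are hypothetical:

```python
import json
import tempfile
from itertools import islice
from pathlib import Path
from typing import Iterable, Iterator


def iter_chunks(rows: Iterable, size: int) -> Iterator[list]:
    """Yield successive lists of at most `size` rows from any iterable (e.g. a cursor)."""
    it = iter(rows)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk


def spill_to_jsonl(rows: Iterable[dict], path: Path, chunk_size: int = 10_000) -> int:
    """Write rows to `path`, one JSON object per line, chunk by chunk; return the row count."""
    written = 0
    with path.open('w', encoding='utf-8') as f:
        for chunk in iter_chunks(rows, chunk_size):
            # default=str turns datetimes and ObjectIds into strings
            f.writelines(json.dumps(row, default=str) + '\n' for row in chunk)
            written += len(chunk)
    return written


# Stand-in for actions_cursor: any iterable of dicts works.
fake_rows = ({'userId': f'u{i}', 'action': 'click'} for i in range(25))
out_path = Path(tempfile.gettempdir()) / 'actions_sample.jsonl'
print(spill_to_jsonl(fake_rows, out_path, chunk_size=10))
```

Only one chunk is held in memory at a time, so peak memory is bounded by `chunk_size` rather than by the size of the time window.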


In [None]:
# --- 5) Build the vocabs (key point: news_vocab keys must be Mongo ObjectId strings) ---
import pickle

DATA_DIR = cfg.out_root / 'data'
MODELS_DIR = cfg.out_root / 'models'
DATA_DIR.mkdir(parents=True, exist_ok=True)
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# news_vocab: postId -> index
news_vocab = {'<PAD>': 0, '<UNK>': 1}
for pid in sorted(post_id_set):
    news_vocab[pid] = len(news_vocab)

# user_vocab: userId -> index
user_ids = sorted({a['userId'] for a in actions})
user_vocab = {'<PAD>': 0, '<UNK>': 1}
for uid in user_ids:
    user_vocab[uid] = len(user_vocab)

with open(DATA_DIR / 'news_vocab.pkl', 'wb') as f:
    pickle.dump(news_vocab, f)
with open(DATA_DIR / 'user_vocab.pkl', 'wb') as f:
    pickle.dump(user_vocab, f)

print('news_vocab size:', len(news_vocab))
print('user_vocab size:', len(user_vocab))
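
Because the vocab index doubles as the row index into the embedding table (and later into the FAISS index), it is worth checking that the indices are contiguous and that the two reserved tokens sit at 0 and 1. A small sketch; `check_vocab` and `build_vocab` are hypothetical helpers and the sample IDs are made up:

```python
def check_vocab(vocab: dict[str, int]) -> None:
    """Raise AssertionError if the vocab cannot be used as a dense row index."""
    assert vocab.get('<PAD>') == 0, 'PAD must map to row 0'
    assert vocab.get('<UNK>') == 1, 'UNK must map to row 1'
    # Indices must be exactly 0..len(vocab)-1 with no gaps or duplicates.
    assert sorted(vocab.values()) == list(range(len(vocab))), 'indices are not contiguous'


def build_vocab(ids: set[str]) -> dict[str, int]:
    """Same construction as the cell above: reserved tokens first, then sorted IDs."""
    vocab = {'<PAD>': 0, '<UNK>': 1}
    for item_id in sorted(ids):
        vocab[item_id] = len(vocab)
    return vocab


sample = build_vocab({'65a1f0c2e4b0a1b2c3d4e5f7', '65a1f0c2e4b0a1b2c3d4e5f6'})
check_vocab(sample)
print(sample)
```

In the notebook, `check_vocab(news_vocab)` and `check_vocab(user_vocab)` would run the same checks on the real vocabs.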


In [None]:
# --- 6) Build training samples (Two-Tower pointwise + negative sampling) ---
import random
from collections import defaultdict

random.seed(42)

POS_ACTIONS = {'click', 'like', 'reply', 'repost', 'quote'}

# Group by user and sort by time
by_user = defaultdict(list)
for a in actions:
    by_user[a['userId']].append((a['ts'], a['action'], a['postId']))
for uid in list(by_user.keys()):
    by_user[uid].sort(key=lambda x: x[0] or datetime.min)

all_posts_list = sorted(post_id_set)  # sorted so random.seed(42) draws the same negatives on every run

samples = []

for uid, events in tqdm(by_user.items(), desc='build samples'):
    history = []
    history_set = set()

    for ts, act, pid in events:
        # Positive sample
        if act in POS_ACTIONS:
            samples.append({
                'user_id': uid,
                'history': history[-cfg.max_history:],
                'candidate_id': pid,
                'label': 1.0,
            })

            # Negative sampling
            for _ in range(cfg.neg_per_pos):
                for _try in range(20):
                    neg = random.choice(all_posts_list)
                    if neg == pid:
                        continue
                    if neg in history_set:
                        continue
                    samples.append({
                        'user_id': uid,
                        'history': history[-cfg.max_history:],
                        'candidate_id': neg,
                        'label': 0.0,
                    })
                    break

        # Update history (include every impression/click/etc., so history is not sparse from strong positives alone)
        if pid not in history_set:
            history.append(pid)
            history_set.add(pid)

print('total samples:', len(samples))
random.shuffle(samples)

split = int(len(samples) * 0.95)
train_samples = samples[:split]
val_samples = samples[split:]

print('train:', len(train_samples), 'val:', len(val_samples))
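
Each positive is paired with up to `cfg.neg_per_pos` sampled negatives, so the label distribution should sit close to 1 : `neg_per_pos` (slightly fewer negatives when the 20 retries fail to find an unseen post). A quick check, with a toy list standing in for the real `samples` and `label_ratio` as a hypothetical helper:

```python
from collections import Counter


def label_ratio(samples: list[dict]) -> float:
    """Return negatives per positive; 0.0 when there are no positives."""
    counts = Counter(s['label'] for s in samples)
    positives = counts.get(1.0, 0)
    return counts.get(0.0, 0) / positives if positives else 0.0


toy_samples = (
    [{'label': 1.0}] * 10    # 10 positives
    + [{'label': 0.0}] * 38  # 38 negatives: two draws gave up after their retries
)
print(f'negatives per positive: {label_ratio(toy_samples):.2f}')
```

A ratio far below `neg_per_pos` on real data would suggest the corpus is small relative to user histories, so the retry loop often runs out of unseen posts.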


In [None]:
# --- 7) Two-Tower training (GPU recommended; CPU works but is slower) ---
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Prefer the in-repo implementation; if the repo is not cloned, the next cell provides a fallback
using_repo_model = False
try:
    from scripts.model_arch import TwoTowerModel  # type: ignore
    using_repo_model = True
    print('✅ Using TwoTowerModel from scripts.model_arch')
except Exception as e:
    print('⚠️ Cannot import scripts.model_arch.TwoTowerModel, will use fallback. err=', e)


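If you ran `git clone` in Colab and changed the working directory to `telegram/ml-services/`, the previous cell imports `scripts.model_arch.TwoTowerModel` automatically.

If the import fails (for example, because the repo was not cloned), the next cell defines a fallback TwoTower architecture that **matches the server-side one**.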


In [None]:
# --- 7.1) TwoTowerModel fallback (used when scripts.model_arch cannot be imported) ---
import torch
import torch.nn as nn
import torch.nn.functional as F

if using_repo_model:
    print('✅ Repo TwoTowerModel already loaded, skip fallback')
else:
    class _NewsEncoder(nn.Module):
        def __init__(self, num_news, embedding_dim):
            super().__init__()
            self.news_embedding = nn.Embedding(num_news, embedding_dim)
            self.fc = nn.Sequential(
                nn.Linear(embedding_dim, embedding_dim),
                nn.ReLU(),
                nn.Linear(embedding_dim, embedding_dim),
            )

        def forward(self, news_ids):
            emb = self.news_embedding(news_ids)
            out = self.fc(emb)
            return F.normalize(out, p=2, dim=1)

    class _UserEncoder(nn.Module):
        def __init__(self, num_users, news_encoder, embedding_dim):
            super().__init__()
            self.user_embedding = nn.Embedding(num_users, embedding_dim)
            self.news_encoder = news_encoder
            self.fc = nn.Sequential(
                nn.Linear(embedding_dim, embedding_dim),
                nn.ReLU(),
                nn.Linear(embedding_dim, embedding_dim),
            )

        def forward(self, user_ids, history_news_ids, history_mask):
            user_emb = self.user_embedding(user_ids)
            hist_emb = self.news_encoder.news_embedding(history_news_ids)
            mask = history_mask.unsqueeze(-1)
            hist_emb = hist_emb * mask
            sum_emb = hist_emb.sum(dim=1)
            count = mask.sum(dim=1).clamp(min=1)
            avg_hist = sum_emb / count
            combined = user_emb + avg_hist
            out = self.fc(combined)
            return F.normalize(out, p=2, dim=1)

    class TwoTowerModel(nn.Module):
        def __init__(self, num_users, num_news, embedding_dim):
            super().__init__()
            self.news_encoder = _NewsEncoder(num_news, embedding_dim)
            self.user_encoder = _UserEncoder(num_users, self.news_encoder, embedding_dim)

        def forward(self, user_ids, history_news_ids, history_mask, target_news_ids):
            user_vec = self.user_encoder(user_ids, history_news_ids, history_mask)
            item_vec = self.news_encoder(target_news_ids)
            logits = (user_vec * item_vec).sum(dim=1)
            return logits, user_vec, item_vec


In [None]:
# --- 7.2) Dataset / DataLoader ---
class TwoTowerDataset(Dataset):
    def __init__(self, samples, news_vocab, user_vocab, max_history_len):
        self.samples = samples
        self.news_vocab = news_vocab
        self.user_vocab = user_vocab
        self.max_history_len = max_history_len
        self.unk_news = news_vocab.get('<UNK>', 1)
        self.unk_user = user_vocab.get('<UNK>', 1)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        uid = self.user_vocab.get(s['user_id'], self.unk_user)
        hist = [self.news_vocab.get(pid, self.unk_news) for pid in (s.get('history') or [])]
        if len(hist) > self.max_history_len:
            hist = hist[-self.max_history_len:]
            mask = [1.0] * self.max_history_len
        else:
            pad = self.max_history_len - len(hist)
            mask = [1.0] * len(hist) + [0.0] * pad
            hist = hist + [0] * pad

        target = self.news_vocab.get(s['candidate_id'], self.unk_news)
        label = float(s['label'])

        return (
            torch.tensor(uid, dtype=torch.long),
            torch.tensor(hist, dtype=torch.long),
            torch.tensor(mask, dtype=torch.float),
            torch.tensor(target, dtype=torch.long),
            torch.tensor(label, dtype=torch.float),
        )

train_ds = TwoTowerDataset(train_samples, news_vocab, user_vocab, cfg.max_history)
val_ds = TwoTowerDataset(val_samples, news_vocab, user_vocab, cfg.max_history)

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=2, pin_memory=True)

print('train batches:', len(train_loader), 'val batches:', len(val_loader))
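
The padding logic in `__getitem__` is easy to get wrong by one, so here is the same truncate-or-pad rule in plain Python: keep the most recent `max_len` items, right-pad with index 0 (`<PAD>`), and mark real positions with 1.0 in the mask. `pad_history` is a hypothetical standalone helper for illustration:

```python
def pad_history(hist: list[int], max_len: int, pad_id: int = 0) -> tuple[list[int], list[float]]:
    """Truncate to the last `max_len` ids or right-pad with `pad_id`; return (ids, mask)."""
    if len(hist) > max_len:
        return hist[-max_len:], [1.0] * max_len
    pad = max_len - len(hist)
    return hist + [pad_id] * pad, [1.0] * len(hist) + [0.0] * pad


print(pad_history([5, 6], max_len=4))
print(pad_history([5, 6, 7, 8, 9], max_len=3))
```

The mask is what lets the user encoder average only over real history entries, since the padded positions are multiplied by zero before the sum.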


In [None]:
# --- 7.3) Train loop ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)

model = TwoTowerModel(
    num_users=len(user_vocab),
    num_news=len(news_vocab),
    embedding_dim=cfg.two_tower_embedding_dim,
).to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=cfg.lr)

use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

best_val = 1e9

for epoch in range(cfg.epochs):
    model.train()
    total = 0.0
    for (uid, hist, mask, tgt, y) in tqdm(train_loader, desc=f'two-tower train e{epoch+1}/{cfg.epochs}'):
        uid = uid.to(device)
        hist = hist.to(device)
        mask = mask.to(device)
        tgt = tgt.to(device)
        y = y.to(device)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits, _, _ = model(uid, hist, mask, tgt)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total += float(loss.detach().cpu().item())

    avg_train = total / max(1, len(train_loader))

    model.eval()
    vtotal = 0.0
    with torch.no_grad():
        for (uid, hist, mask, tgt, y) in tqdm(val_loader, desc=f'two-tower val e{epoch+1}/{cfg.epochs}'):
            uid = uid.to(device)
            hist = hist.to(device)
            mask = mask.to(device)
            tgt = tgt.to(device)
            y = y.to(device)
            logits, _, _ = model(uid, hist, mask, tgt)
            loss = criterion(logits, y)
            vtotal += float(loss.detach().cpu().item())

    avg_val = vtotal / max(1, len(val_loader))
    print(f'[TwoTower] epoch={epoch+1} train={avg_train:.4f} val={avg_val:.4f}')

    # save best
    if avg_val < best_val:
        best_val = avg_val
        torch.save(model.state_dict(), MODELS_DIR / 'two_tower_epoch_latest.pt')
        print('  ✅ saved best to models/two_tower_epoch_latest.pt')


In [None]:
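# --- 8) Export item_embeddings.npy (used by FAISS) ---
# Note: this saves the raw embedding table, not news_encoder(ids) (fc + L2 norm), which is what
# the training logits use. Confirm which of the two the serving side expects before changing it.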
model.eval()
with torch.no_grad():
    emb = model.news_encoder.news_embedding.weight.detach().cpu().numpy().astype(np.float32)

np.save(DATA_DIR / 'item_embeddings.npy', emb)
print('saved:', DATA_DIR / 'item_embeddings.npy', 'shape=', emb.shape)


In [None]:
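# --- 9) Build the FAISS index + id mapping (ivf_pq by default, which leaves room to grow; a small corpus can use flat instead) ---
# Note: FAISS k-means needs at least as many training vectors as centroids (nlist for the IVF
# lists, 256 per 8-bit PQ codebook), so a very small corpus may fail at index.train().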
import faiss
import numpy as np

emb = np.load(DATA_DIR / 'item_embeddings.npy').astype(np.float32)
# L2 normalize for cosine similarity (IndexFlatIP / IVF* with IP)
norms = np.linalg.norm(emb, axis=1, keepdims=True)
emb = emb / (norms + 1e-10)

dim = emb.shape[1]

n = emb.shape[0]
nlist = max(64, min(int(np.sqrt(n) * 4), 2048))

m = 8
if dim % m != 0:
    # choose a divisor
    for cand in [4, 8, 16, 32]:
        if dim % cand == 0:
            m = cand
            break

print('faiss config:', {'dim': dim, 'n': n, 'nlist': nlist, 'm': m})

quantizer = faiss.IndexFlatIP(dim)
index = faiss.IndexIVFPQ(quantizer, dim, nlist, m, 8, faiss.METRIC_INNER_PRODUCT)

print('training index...')
index.train(emb)
print('adding...')
index.add(emb)
index.nprobe = 16

faiss_path = MODELS_DIR / 'faiss_ivf_pq.index'
faiss.write_index(index, str(faiss_path))
print('saved:', faiss_path, 'ntotal=', index.ntotal)

# mapping
idx_to_news_id = {v: k for k, v in news_vocab.items()}
import pickle
map_path = MODELS_DIR / 'faiss_id_mapping.pkl'
with open(map_path, 'wb') as f:
    pickle.dump({'news_vocab': news_vocab, 'idx_to_news_id': idx_to_news_id}, f)
print('saved:', map_path)
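
The index parameters above follow a simple heuristic: `nlist` is roughly 4·√n clamped to [64, 2048], and `m` (the number of PQ sub-vectors) must divide the embedding dimension. The same arithmetic in isolation, so you can see what a given corpus size produces before training; `choose_nlist` and `choose_m` are hypothetical names:

```python
import math


def choose_nlist(n: int) -> int:
    """About 4 * sqrt(n) inverted lists, clamped to the range [64, 2048]."""
    return max(64, min(int(math.sqrt(n) * 4), 2048))


def choose_m(dim: int, default: int = 8) -> int:
    """Number of PQ sub-vectors; must divide `dim`. Falls back through a few candidates."""
    if dim % default == 0:
        return default
    for cand in (4, 8, 16, 32):
        if dim % cand == 0:
            return cand
    return default  # no candidate divides dim; FAISS will reject this value


for n in (100, 10_000, 1_000_000):
    print(n, choose_nlist(n))
print(choose_m(256), choose_m(100))
```

At a few hundred posts the clamp keeps `nlist` at 64, which leaves only a handful of vectors per list; that is the regime where the flat index mentioned in the cell comment is the simpler choice.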


In [None]:
# --- 10) (Optional) Train Phoenix: multi-action heads trained jointly (click/like/reply/repost) ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

class PhoenixRanker(nn.Module):
    def __init__(self, num_news, embedding_dim=256, num_heads=8, num_layers=4, dropout=0.1):
        super().__init__()
        self.news_embedding = nn.Embedding(num_news, embedding_dim)
        self.position_embedding = nn.Embedding(512, embedding_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=embedding_dim * 4,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.click_head = nn.Linear(embedding_dim, 1)
        self.like_head = nn.Linear(embedding_dim, 1)
        self.reply_head = nn.Linear(embedding_dim, 1)
        self.repost_head = nn.Linear(embedding_dim, 1)

        self.dropout = nn.Dropout(dropout)
        self.ln_f = nn.LayerNorm(embedding_dim)

    def create_isolation_mask(self, history_len, num_candidates):
        total_len = history_len + num_candidates
        mask = torch.zeros((total_len, total_len), dtype=torch.float)
        cand_region = torch.full((num_candidates, num_candidates), float('-inf'))
        cand_region.fill_diagonal_(0.0)
        mask[history_len:, history_len:] = cand_region
        mask[:history_len, history_len:] = float('-inf')
        return mask

    def forward(self, history_ids, candidate_ids):
        bsz, hist_len = history_ids.shape
        _, cand_len = candidate_ids.shape
        input_ids = torch.cat([history_ids, candidate_ids], dim=1)
        seq_len = input_ids.shape[1]
        x = self.news_embedding(input_ids)
        pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.dropout(x + self.position_embedding(pos))
        attn_mask = self.create_isolation_mask(hist_len, cand_len).to(input_ids.device)
        out = self.transformer(x, mask=attn_mask)
        out = self.ln_f(out)
        cand_out = out[:, hist_len:, :]
        return {
            'click': self.click_head(cand_out).squeeze(-1),
            'like': self.like_head(cand_out).squeeze(-1),
            'reply': self.reply_head(cand_out).squeeze(-1),
            'repost': self.repost_head(cand_out).squeeze(-1),
        }

if not cfg.train_phoenix:
    print('cfg.train_phoenix is False, skip')
else:
    # Training samples: taken directly from user_actions records; the action type sets the label on the matching head
    ACTION_TO_HEAD = {
        'click': 'click',
        'like': 'like',
        'reply': 'reply',
        'repost': 'repost',
        'quote': 'repost',
    }

    phoenix_samples = []
    for uid, events in by_user.items():
        history = []
        history_set = set()
        for ts, act, pid in events:
            head = ACTION_TO_HEAD.get(act)
            if head:
                y = {'click': 0.0, 'like': 0.0, 'reply': 0.0, 'repost': 0.0}
                y[head] = 1.0
                phoenix_samples.append({
                    'history': history[-cfg.max_history:],
                    'candidate_id': pid,
                    'y': y,
                })
            if pid not in history_set:
                history.append(pid)
                history_set.add(pid)

    random.shuffle(phoenix_samples)
    split = int(len(phoenix_samples) * 0.95)
    ph_train = phoenix_samples[:split]
    ph_val = phoenix_samples[split:]
    print('phoenix samples:', len(phoenix_samples), 'train:', len(ph_train), 'val:', len(ph_val))

    class PhoenixDataset(Dataset):
        def __init__(self, samples, news_vocab, max_history):
            self.samples = samples
            self.news_vocab = news_vocab
            self.max_history = max_history
            self.unk = news_vocab.get('<UNK>', 1)

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            s = self.samples[idx]
            hist = [self.news_vocab.get(pid, self.unk) for pid in (s.get('history') or [])]
            if len(hist) > self.max_history:
                hist = hist[-self.max_history:]
            else:
                hist = hist + [0] * (self.max_history - len(hist))

            cand = self.news_vocab.get(s['candidate_id'], self.unk)
            y = s['y']
            return (
                torch.tensor(hist, dtype=torch.long),
                torch.tensor([cand], dtype=torch.long),
                torch.tensor([y['click']], dtype=torch.float),
                torch.tensor([y['like']], dtype=torch.float),
                torch.tensor([y['reply']], dtype=torch.float),
                torch.tensor([y['repost']], dtype=torch.float),
            )

    ph_train_loader = DataLoader(PhoenixDataset(ph_train, news_vocab, cfg.max_history), batch_size=cfg.phoenix_batch_size, shuffle=True, num_workers=2, pin_memory=True)
    ph_val_loader = DataLoader(PhoenixDataset(ph_val, news_vocab, cfg.max_history), batch_size=cfg.phoenix_batch_size, shuffle=False, num_workers=2, pin_memory=True)

    ph_model = PhoenixRanker(
        num_news=len(news_vocab),
        embedding_dim=cfg.phoenix_embedding_dim,
        num_heads=cfg.phoenix_num_heads,
        num_layers=cfg.phoenix_num_layers,
        dropout=0.1,
    ).to(device)

    bce = nn.BCEWithLogitsLoss()
    opt = optim.AdamW(ph_model.parameters(), lr=cfg.phoenix_lr)

    use_amp = torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Head weights: replies and reposts are usually sparse, so weight them up slightly
    w = {'click': 1.0, 'like': 2.0, 'reply': 3.0, 'repost': 3.0}

    best = 1e9
    for epoch in range(cfg.phoenix_epochs):
        ph_model.train()
        total = 0.0
        for (hist, cand, y_click, y_like, y_reply, y_repost) in tqdm(ph_train_loader, desc=f'phoenix train e{epoch+1}/{cfg.phoenix_epochs}'):
            hist = hist.to(device)
            cand = cand.to(device)
            y_click = y_click.to(device)
            y_like = y_like.to(device)
            y_reply = y_reply.to(device)
            y_repost = y_repost.to(device)

            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                out = ph_model(hist, cand)
                # [B,1]
                lc = bce(out['click'].flatten(), y_click.flatten()) * w['click']
                ll = bce(out['like'].flatten(), y_like.flatten()) * w['like']
                lr = bce(out['reply'].flatten(), y_reply.flatten()) * w['reply']
                lrp = bce(out['repost'].flatten(), y_repost.flatten()) * w['repost']
                loss = lc + ll + lr + lrp

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            total += float(loss.detach().cpu().item())

        avg_train = total / max(1, len(ph_train_loader))

        ph_model.eval()
        vtotal = 0.0
        with torch.no_grad():
            for (hist, cand, y_click, y_like, y_reply, y_repost) in tqdm(ph_val_loader, desc=f'phoenix val e{epoch+1}/{cfg.phoenix_epochs}'):
                hist = hist.to(device)
                cand = cand.to(device)
                y_click = y_click.to(device)
                y_like = y_like.to(device)
                y_reply = y_reply.to(device)
                y_repost = y_repost.to(device)
                out = ph_model(hist, cand)
                lc = bce(out['click'].flatten(), y_click.flatten()) * w['click']
                ll = bce(out['like'].flatten(), y_like.flatten()) * w['like']
                lr = bce(out['reply'].flatten(), y_reply.flatten()) * w['reply']
                lrp = bce(out['repost'].flatten(), y_repost.flatten()) * w['repost']
                loss = lc + ll + lr + lrp
                vtotal += float(loss.detach().cpu().item())

        avg_val = vtotal / max(1, len(ph_val_loader))
        print(f'[Phoenix] epoch={epoch+1} train={avg_train:.4f} val={avg_val:.4f}')

        if avg_val < best:
            best = avg_val
            torch.save(ph_model.state_dict(), MODELS_DIR / 'phoenix_epoch_latest.pt')
            print('  ✅ saved best to models/phoenix_epoch_latest.pt')
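
The `create_isolation_mask` method above encodes three rules: history positions attend only to history, each candidate attends to all of history, and each candidate attends to itself but not to other candidates. A dependency-free sketch that builds the same pattern as a boolean grid (`True` means "may attend"), which can be easier to eyeball than a tensor of `-inf` values; `isolation_grid` is a hypothetical helper:

```python
def isolation_grid(history_len: int, num_candidates: int) -> list[list[bool]]:
    """Row i, column j is True when position i may attend to position j."""
    total = history_len + num_candidates
    grid = [[True] * total for _ in range(total)]
    for i in range(total):
        for j in range(history_len, total):
            # Columns >= history_len are candidates: only a candidate may look at itself.
            grid[i][j] = (i == j)
    return grid


for row in isolation_grid(history_len=2, num_candidates=2):
    print(''.join('x' if ok else '.' for ok in row))
```

Isolating candidates from each other means a candidate's score does not depend on which other candidates share the batch row, which matters once more than one candidate is scored per forward pass.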


In [None]:
# --- 11) Build the uploadable stage/ directory (layout matches the GCS artifacts convention) ---
import shutil
from pathlib import Path

stage = Path('stage')
if stage.exists():
    shutil.rmtree(stage)

(stage / 'two_tower').mkdir(parents=True, exist_ok=True)
(stage / 'phoenix').mkdir(parents=True, exist_ok=True)
(stage / 'faiss').mkdir(parents=True, exist_ok=True)
(stage / 'data').mkdir(parents=True, exist_ok=True)

# Required
shutil.copy2(MODELS_DIR / 'two_tower_epoch_latest.pt', stage / 'two_tower' / 'model.pt')

phoenix_path = MODELS_DIR / 'phoenix_epoch_latest.pt'
if phoenix_path.exists():
    shutil.copy2(phoenix_path, stage / 'phoenix' / 'model.pt')
else:
    print('phoenix model missing, you can still upload two-tower + faiss and keep phoenix disabled on server')

shutil.copy2(MODELS_DIR / 'faiss_ivf_pq.index', stage / 'faiss' / 'faiss_ivf_pq.index')
shutil.copy2(MODELS_DIR / 'faiss_id_mapping.pkl', stage / 'faiss' / 'faiss_id_mapping.pkl')

# Data
shutil.copy2(DATA_DIR / 'news_vocab.pkl', stage / 'data' / 'news_vocab.pkl')
shutil.copy2(DATA_DIR / 'user_vocab.pkl', stage / 'data' / 'user_vocab.pkl')
shutil.copy2(DATA_DIR / 'item_embeddings.npy', stage / 'data' / 'item_embeddings.npy')

print('stage ready:', stage)
for p in stage.rglob('*'):
    if p.is_file():
        print(' -', p)
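
Before uploading, it can help to confirm that `stage/` contains every file the layout calls for. A sketch of such a check; `REQUIRED` mirrors the copy calls in the cell above, `OPTIONAL` reflects that Phoenix may be skipped, and `missing_artifacts` is a hypothetical helper:

```python
from pathlib import Path

REQUIRED = [
    'two_tower/model.pt',
    'faiss/faiss_ivf_pq.index',
    'faiss/faiss_id_mapping.pkl',
    'data/news_vocab.pkl',
    'data/user_vocab.pkl',
    'data/item_embeddings.npy',
]
OPTIONAL = ['phoenix/model.pt']


def missing_artifacts(stage_dir: Path) -> list[str]:
    """Return the required relative paths that are absent or empty under `stage_dir`."""
    missing = []
    for rel in REQUIRED:
        p = stage_dir / rel
        if not p.is_file() or p.stat().st_size == 0:
            missing.append(rel)
    return missing


stage_dir = Path('stage')
print('missing:', missing_artifacts(stage_dir))
print('optional present:', [rel for rel in OPTIONAL if (stage_dir / rel).is_file()])
```

An empty `missing` list means the required files are all in place; a missing Phoenix model is not treated as an error because the server can run with Phoenix disabled.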


In [None]:
# --- 12) (Optional) Bundle stage/ so it can be downloaded from Colab and then uploaded to GCS ---
import tarfile
from pathlib import Path

bundle = Path('stage_bundle.tgz')
if bundle.exists():
    bundle.unlink()

with tarfile.open(bundle, 'w:gz') as tar:
    tar.add('stage', arcname='stage')

size_mb = bundle.stat().st_size / (1024 * 1024)
print('bundle:', bundle, f'{size_mb:.2f} MB')
print('To download: find stage_bundle.tgz in the Colab Files panel on the left -> right-click -> Download')


## Upload to GCS (run on your local machine, following the same flow that worked before)

1) Pick a new version string: `ARTIFACT_VERSION=YYYY-MM-DD_buildNN`, for example `2026-02-07_build02`

2) Upload `stage/` to the bucket:

```bash
BUCKET="telegram-467705-recsys"
ARTIFACT_VERSION="2026-02-07_build02"

# Upload the contents of stage/ to artifacts/$ARTIFACT_VERSION/
gcloud storage cp -r stage/* "gs://$BUCKET/artifacts/$ARTIFACT_VERSION/"

# Verify
gcloud storage ls -r "gs://$BUCKET/artifacts/$ARTIFACT_VERSION/**"
```

3) Update Cloud Run (in the same shell, so `$ARTIFACT_VERSION` is still set):

```bash
gcloud run services update telegram-ml-services \
  --project telegram-467705 \
  --region us-central1 \
  --update-env-vars "ARTIFACT_VERSION=$ARTIFACT_VERSION,TWO_TOWER_EMBEDDING_DIM=256,PHOENIX_EMBEDDING_DIM=256,PHOENIX_NUM_HEADS=8,PHOENIX_NUM_LAYERS=4"
```

4) Verify:

- `GET /health` should report `models_loaded=true` and `faiss_enabled=true`
- `POST /ann/retrieve` should return `postId` values that are 24-character hex ObjectIds (not `Nxxxxx`)
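
The second check can be automated on whatever JSON the endpoint returns. A small sketch that assumes the response has already been parsed into a list of ID strings (the real response shape is not shown here, so `ids` is a stand-in and `is_object_id` is a hypothetical helper):

```python
import re

# A Mongo ObjectId string is exactly 24 hexadecimal characters.
OBJECT_ID_RE = re.compile(r'[0-9a-fA-F]{24}')


def is_object_id(value: str) -> bool:
    """True when `value` has the shape of a Mongo ObjectId string."""
    return bool(OBJECT_ID_RE.fullmatch(value))


ids = ['65a1f0c2e4b0a1b2c3d4e5f6', 'N12345']  # made-up examples
for post_id in ids:
    print(post_id, is_object_id(post_id))
```

Any `False` in the output would mean the service is still serving an index built from the external corpus, or is returning a reserved token such as `<UNK>`.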
