# In [None]:
# train/magrpo_train.py
import torch
import torch.nn.functional as F
import numpy as np

class MAGRPOTrainer:
    """PPO-style trainer coupling an analyzer, a recommender policy and a validator.

    On every environment step the analyzer turns the current state into a text
    ``insight`` plus an embedding of it, the recommender picks an action from
    (observation, plan embedding), and the validator shapes the environment
    reward. Transitions are stored in ``buffer``; :meth:`update` runs
    clipped-surrogate optimisation epochs over everything in the buffer and
    then clears it.

    Args:
        analyzer: object exposing ``generate_insight(state_repr)`` and
            ``embed_text(text)``.
        recommender: object exposing ``act(obs, plan_emb) -> (action, logprob, value)``,
            a ``net(obs, plan_emb) -> (logits, values)`` callable and ``update(loss)``.
        validator: object exposing
            ``shaped_reward(env_reward, insight, action, tokens_used=...)``.
        env: environment with ``reset()`` and a 4-tuple ``step(action)``.
        buffer: transition store with ``store``, ``get_all`` and ``clear``.
        gamma: discount factor.
        lam: GAE lambda.
        clip_eps: ratio clipping range for the surrogate objective.
        epochs: optimisation passes over the buffer per :meth:`update` call.
        batch_size: minibatch size.
        device: torch device for network inputs; defaults to ``"cuda"`` when
            available, otherwise ``"cpu"``.
    """

    def __init__(self, analyzer, recommender, validator, env, buffer, gamma=0.99, lam=0.95, clip_eps=0.2, epochs=4, batch_size=32, device=None):
        self.analyzer = analyzer
        self.recommender = recommender
        self.validator = validator
        self.env = env
        self.buffer = buffer
        self.gamma = gamma
        self.lam = lam
        self.clip_eps = clip_eps
        self.epochs = epochs
        self.batch_size = batch_size
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _to_python_scalar(x):
        """Return a single-element tensor, array or number as a plain Python scalar.

        Tensors are detached and moved to CPU first, so values that still carry
        autograd history or live on an accelerator convert safely. Raises if
        ``x`` holds more than one element.
        """
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().item()
        return np.asarray(x).item()

    def collect_episode(self, max_steps=50):
        """Roll out one episode of at most ``max_steps`` steps into the buffer.

        Each stored transition is a dict with keys ``state``, ``plan_text``,
        ``plan_emb``, ``action``, ``reward`` (validator-shaped), ``value``,
        ``logprob``, ``next_state``, ``done`` and ``truncated``. ``truncated``
        is True only on the final transition of an episode that hit
        ``max_steps`` before the environment reported ``done``. :meth:`update`
        uses it so advantages never bootstrap across episode boundaries.
        """
        state = self.env.reset()
        done = False
        steps = 0
        while not done and steps < max_steps:
            state_repr = state.tolist() if hasattr(state, 'tolist') else str(state)
            insight = self.analyzer.generate_insight(state_repr)
            plan_emb = self.analyzer.embed_text(insight)
            # Build inputs on the device update() trains the network on.
            obs_t = torch.tensor(state, dtype=torch.float32, device=self.device)
            plan_emb_t = torch.tensor(plan_emb, dtype=torch.float32, device=self.device)
            # Rollout is inference only: update() recomputes log-probs and values
            # with gradients, so don't keep autograd graphs alive in the buffer.
            with torch.no_grad():
                action, logprob, value = self.recommender.act(obs_t, plan_emb_t)
            next_state, env_reward, done, _ = self.env.step(action)
            # Whitespace-split word count is passed as the token count.
            shaped = self.validator.shaped_reward(env_reward, insight, action, tokens_used=len(insight.split()))
            steps += 1
            transition = {
                'state': state,
                'plan_text': insight,
                'plan_emb': plan_emb,
                'action': action,
                'reward': shaped,
                'value': value,
                'logprob': logprob,
                'next_state': next_state,
                'done': done,
                # Episode cut off by the step limit rather than ended by the env.
                'truncated': bool(not done and steps >= max_steps),
            }
            self.buffer.store(transition)
            state = next_state

    def compute_gae(self, rewards, values, dones):
        """Compute GAE(gamma, lambda) returns and standardised advantages.

        Args:
            rewards: per-step rewards in buffer order.
            values: value estimates for the same steps (floats, numpy scalars or
                single-element tensors, in a list or array).
            dones: per-step episode-end flags. A truthy flag stops both value
                bootstrapping and advantage accumulation past that step.

        Returns:
            ``(returns, advantages)`` as float numpy arrays of length
            ``len(rewards)``. Advantages are standardised to zero mean and unit
            std (``1e-8`` guards a zero std). The value following the final step
            is taken to be 0.

        Raises:
            ValueError: if the three inputs differ in length.
        """
        n = len(rewards)
        if len(values) != n or len(dones) != n:
            raise ValueError(
                f"rewards, values and dones must have equal length, got {n}, {len(values)}, {len(dones)}"
            )
        if n == 0:
            return np.zeros(0), np.zeros(0)

        rewards = np.asarray([float(self._to_python_scalar(r)) for r in rewards], dtype=np.float64)
        # np.append (not list `+`) so ndarray inputs get a trailing 0.0 bootstrap value.
        values = np.append(
            np.asarray([float(self._to_python_scalar(v)) for v in values], dtype=np.float64), 0.0
        )
        masks = 1.0 - np.asarray([float(bool(d)) for d in dones], dtype=np.float64)

        advs = np.zeros(n, dtype=np.float64)
        gae = 0.0
        for step in reversed(range(n)):
            delta = rewards[step] + self.gamma * values[step + 1] * masks[step] - values[step]
            gae = delta + self.gamma * self.lam * masks[step] * gae
            advs[step] = gae
        returns = advs + values[:-1]
        return returns, (advs - advs.mean()) / (advs.std() + 1e-8)

    def update(self):
        """Run clipped-surrogate optimisation over every buffered transition, then clear the buffer.

        Advantages come from :meth:`compute_gae`. Both ``done`` and
        ``truncated`` transitions count as episode ends, so one episode never
        bootstraps from the next episode's first value. For ``self.epochs``
        passes, shuffled minibatches of ``self.batch_size`` transitions are
        scored with ``recommender.net``, and the loss
        ``policy_loss + 0.5 * value_loss - 0.01 * entropy`` is passed to
        ``recommender.update``. Does nothing when the buffer is empty.
        """
        transitions = self.buffer.get_all()
        if len(transitions) == 0:
            return
        to_scalar = self._to_python_scalar
        states = np.stack([t['state'] for t in transitions])
        plan_embs = np.stack([t['plan_emb'] for t in transitions])
        # Stored actions/log-probs/values may be (possibly grad-tracking) tensors.
        actions = np.asarray([int(to_scalar(t['action'])) for t in transitions], dtype=np.int64)
        rewards = [t['reward'] for t in transitions]
        values = [t['value'] for t in transitions]
        logprobs = np.asarray([float(to_scalar(t['logprob'])) for t in transitions], dtype=np.float64)
        # .get keeps older transitions without a 'truncated' key working.
        episode_ends = [bool(t['done']) or bool(t.get('truncated', False)) for t in transitions]

        returns, advs = self.compute_gae(rewards, values, episode_ends)
        obs = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        plan_emb = torch.as_tensor(plan_embs, dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(actions, dtype=torch.long, device=self.device)
        old_logprobs = torch.as_tensor(logprobs, dtype=torch.float32, device=self.device)
        returns_t = torch.as_tensor(returns, dtype=torch.float32, device=self.device)
        advs_t = torch.as_tensor(advs, dtype=torch.float32, device=self.device)

        n = len(actions)
        inds = np.arange(n)
        for _ in range(self.epochs):
            np.random.shuffle(inds)
            for start in range(0, n, self.batch_size):
                mb_idx = inds[start:start + self.batch_size]
                mb_obs = obs[mb_idx]
                mb_plan = plan_emb[mb_idx]
                mb_actions = actions_t[mb_idx]
                mb_old_logp = old_logprobs[mb_idx]
                mb_returns = returns_t[mb_idx]
                mb_advs = advs_t[mb_idx]

                logits, vals = self.recommender.net(mb_obs, mb_plan)
                # logits= normalises via log-softmax; avoids a softmax -> log round trip.
                dist = torch.distributions.Categorical(logits=logits)
                new_logp = dist.log_prob(mb_actions)
                ratio = torch.exp(new_logp - mb_old_logp)
                surr1 = ratio * mb_advs
                surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * mb_advs
                policy_loss = -torch.min(surr1, surr2).mean()
                # Flatten so a (B, 1) value head is compared element-wise with the
                # (B,) returns instead of silently broadcasting to (B, B).
                value_loss = F.mse_loss(vals.reshape(-1), mb_returns)
                entropy = dist.entropy().mean()
                loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

                self.recommender.update(loss)

        self.buffer.clear()
