In [None]:
# agents/recommender.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class ActorCritic(nn.Module):
    """Shared-trunk actor-critic network.

    The observation and plan embedding are concatenated along the last
    dimension and fed through a two-layer ReLU MLP. Two linear heads on top of
    that trunk produce action logits (``actor``) and a state-value estimate
    (``critic``).

    Args:
        obs_dim: width of the observation features.
        plan_emb_dim: width of the plan embedding.
        action_dim: number of discrete actions (output width of ``actor``).
        hidden: width of both hidden layers.
    """

    def __init__(self, obs_dim, plan_emb_dim, action_dim, hidden=256):
        super().__init__()
        in_features = obs_dim + plan_emb_dim
        # Layer creation order (trunk, then actor, then critic) is kept so that
        # seeded initialization and state_dict keys are unchanged.
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.actor = nn.Linear(hidden, action_dim)
        self.critic = nn.Linear(hidden, 1)

    def forward(self, obs, plan_emb):
        """Return ``(logits, value)``.

        ``logits`` has trailing dimension ``action_dim``. ``value`` is the
        critic output with its trailing size-1 dimension removed.
        """
        features = self.fc(torch.cat((obs, plan_emb), dim=-1))
        return self.actor(features), self.critic(features).squeeze(-1)

class Recommender:
    """Actor-critic agent that wraps an ``ActorCritic`` network and an Adam optimizer.

    Args:
        obs_dim: observation feature width.
        plan_emb_dim: plan-embedding width.
        action_dim: number of discrete actions.
        lr: Adam learning rate.
        device: target device. When falsy, "cuda" is used if available,
            otherwise "cpu".
        max_grad_norm: threshold for global gradient-norm clipping in
            ``update`` (default 0.5, the previously hard-coded value).
    """

    def __init__(self, obs_dim, plan_emb_dim, action_dim, lr=3e-4, device=None,
                 max_grad_norm=0.5):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.net = ActorCritic(obs_dim, plan_emb_dim, action_dim).to(self.device)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)
        self.max_grad_norm = max_grad_norm

    def _to_tensor(self, x):
        """Return ``x`` as a float32 tensor on ``self.device`` (tensor, ndarray or list)."""
        return torch.as_tensor(x, dtype=torch.float32, device=self.device)

    def act(self, obs, plan_emb):
        """Sample an action from the current policy without tracking gradients.

        Args:
            obs: observation of shape (batch, obs_dim) or (obs_dim,). Torch
                tensors, NumPy arrays and nested lists are accepted.
            plan_emb: plan embedding of shape (batch, plan_emb_dim) or
                (plan_emb_dim,).

        Returns:
            ``(action, log_prob, value)``. For a single sample (unbatched input,
            or a batch of size 1) these are Python scalars (int, float, float).
            For larger batches they are NumPy arrays of shape (batch,).
        """
        obs_t = self._to_tensor(obs)
        plan_t = self._to_tensor(plan_emb)
        if obs_t.ndim == 1:
            obs_t = obs_t.unsqueeze(0)
        if plan_t.ndim == 1:
            plan_t = plan_t.unsqueeze(0)
        # Rollout-only path: every output is detached to host values, so an
        # autograd graph here would only waste memory and time.
        with torch.no_grad():
            logits, value = self.net(obs_t, plan_t)
            # Building the distribution from logits avoids the
            # softmax -> normalize -> log round-trip of Categorical(probs),
            # which is more numerically stable.
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
        if action.numel() == 1:
            return action.item(), log_prob.item(), value.item()
        # .item() only works on single-element tensors; batches return arrays.
        return action.cpu().numpy(), log_prob.cpu().numpy(), value.cpu().numpy()

    def get_logits_values(self, obs, plan_emb):
        """Return ``(logits, value)`` from the network without tracking gradients.

        Input shapes pass through unchanged (no batch dimension is added), so a
        1-D input yields 1-D logits and a 0-D value. Outputs stay on
        ``self.device``.
        """
        with torch.no_grad():
            logits, value = self.net(self._to_tensor(obs), self._to_tensor(plan_emb))
        return logits, value

    def update(self, loss):
        """Take one optimizer step on ``loss``.

        Gradients are zeroed, back-propagated, clipped to a global L2 norm of
        ``self.max_grad_norm``, and then applied.

        Returns:
            The total gradient norm before clipping, as a float (for logging).
        """
        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
        self.optimizer.step()
        return float(grad_norm)
