# Notes

## Rewards

NLI encoder consists of:

- Embedder $f_\theta:\text{text}\rightarrow\mathbb{R}^d$
- Classifier $g_\phi:\mathbb{R}^d\rightarrow\mathbb{R}^3$, outputting logits $(a_e, a_n, a_c)$
- (Or $h(\mathbf{x})=\text{softmax}(g(\mathbf{x}))=(p_e, p_n, p_c)$)

Recall softmax: $p_i = e^{a_i} / Z$, with $Z=\sum_{i'} e^{a_{i'}}$ summed over all three classes.
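A minimal NumPy sketch of the softmax above. Subtracting $\max_i a_i$ before exponentiating is a standard numerical-stability trick. It leaves $p$ unchanged because the factor $e^{-\max a}$ cancels between the numerator and $Z$. The logit values are made up for illustration.

```python
import numpy as np

def softmax(a: np.ndarray) -> np.ndarray:
    """Softmax over the last axis; Z sums over all logits."""
    shifted = a - np.max(a, axis=-1, keepdims=True)  # stability shift, cancels in p
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

# Illustrative logits (a_e, a_n, a_c); not real model output.
a = np.array([2.0, 0.5, -1.0])
p = softmax(a)
print(p, p.sum())
```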


The reward signal needs to be a scalar, not a tuple. We can use either E vs. C, or E vs. $\lnot E$ (i.e. E vs. (N or C)).

The log-odds of E vs. C are

$$
\log \frac{p_e}{p_c} = \log \frac{e^{a_e}/Z}{e^{a_c}/Z} = \log \frac{e^{a_e}}{e^{a_c}} = a_e - a_c
$$
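A quick numerical check of this identity, using arbitrary illustrative logits. Since $Z$ cancels, the E vs. C score needs only the two logits and never depends on $a_n$.

```python
import numpy as np

# Illustrative logits (a_e, a_n, a_c); values are arbitrary.
a_e, a_n, a_c = 2.0, 0.5, -1.0
Z = np.exp(a_e) + np.exp(a_n) + np.exp(a_c)
p_e, p_c = np.exp(a_e) / Z, np.exp(a_c) / Z

log_odds_from_probs = np.log(p_e / p_c)
log_odds_from_logits = a_e - a_c  # Z cancels, so no softmax is needed
print(log_odds_from_probs, log_odds_from_logits)  # both approximately 3.0
```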

The log-odds of E vs. $\lnot E$ are

$$
\log \frac{p_e}{p_n + p_c} = a_e - \log(e^{a_n}+e^{a_c})
$$
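A sketch of computing this second score directly from logits. It uses `np.logaddexp(x, y)`, which evaluates $\log(e^x + e^y)$ without overflow. Unlike E vs. C, this version counts neutral mass as evidence against entailment, so it is always lower than $a_e - a_c$. The function name is hypothetical.

```python
import numpy as np

def log_odds_e_vs_not_e(a_e: float, a_n: float, a_c: float) -> float:
    """log(p_e / (p_n + p_c)) = a_e - log(exp(a_n) + exp(a_c)), computed stably."""
    return a_e - np.logaddexp(a_n, a_c)

a_e, a_n, a_c = 2.0, 0.5, -1.0  # illustrative logits
p = np.exp([a_e, a_n, a_c])
p /= p.sum()
direct = np.log(p[0] / (p[1] + p[2]))  # same quantity via explicit probabilities
print(log_odds_e_vs_not_e(a_e, a_n, a_c), direct)
```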

[ continue here ]

Assume the encoder is calibrated, so that

$$
h(\mathbf{x})_i = p_i = \Pr(Y=i\mid\mathbf{x}) = 
\frac{\Pr(\mathbf{x}\mid Y=i)\,\pi_i}{\sum_{i'}\Pr(\mathbf{x}\mid Y=i')\,\pi_{i'}}
$$

where $\pi_i = \Pr(Y=i)$ is the class prior.
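A small sketch of the calibration assumption with made-up likelihood values and priors. It shows that, under this assumption, the E vs. C log-odds split into the prior log-odds $\log(\pi_e/\pi_c)$ plus a log-likelihood ratio.

```python
import numpy as np

# Made-up values of Pr(x | Y=i) for one input x, ordered (e, n, c),
# and made-up class priors pi_i.
lik = np.array([0.30, 0.10, 0.05])
prior = np.array([0.2, 0.5, 0.3])

joint = lik * prior                 # Pr(x | Y=i) * pi_i
posterior = joint / joint.sum()     # Pr(Y=i | x), what a calibrated h(x) would output

# Log-odds E vs. C = prior log-odds + log-likelihood ratio.
log_odds = np.log(posterior[0] / posterior[2])
prior_log_odds = np.log(prior[0] / prior[2])
llr = np.log(lik[0] / lik[2])
print(log_odds, prior_log_odds + llr)
```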

[ continue ]

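So for each state $s_t$ we have the associated log-odds

$$
\ell_t = \ell_0 + \sum_{t'=1}^t \lambda_{t'}
$$

where $\ell_0$ is the prior log-odds (e.g. $\log(\pi_e/\pi_c)$ for E vs. C) and $\lambda_{t'}$ is the log-likelihood ratio contributed by turn $t'$. This additive form assumes the turns are conditionally independent given $Y$.

## Returns

Assume the reward $r_t$ for state $s_t$ is the change in log-odds:

$$
r_t = \ell_{t+1} - \ell_t
$$

(NOTE: because of our independence assumption this reduces to $\lambda_{t+1}$, the contribution of turn $t+1$ alone.)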
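A sketch of the accumulation and reward definitions with made-up per-turn log-likelihood ratios. The array `lam` holds $\lambda_1,\dots,\lambda_T$ at indices $0,\dots,T-1$; the prior values are hypothetical. Differencing the cumulative log-odds recovers each turn's own contribution, and the rewards telescope to $\ell_T - \ell_0$.

```python
import numpy as np

# Made-up per-turn log-likelihood ratios lambda_1..lambda_T (stored at index 0..T-1).
lam = np.array([0.4, -0.2, 1.1, 0.3])
ell_0 = np.log(0.2 / 0.3)   # hypothetical prior log-odds log(pi_e / pi_c)

# ell[t] for t = 0..T: prior plus cumulative evidence up to turn t.
ell = ell_0 + np.concatenate([[0.0], np.cumsum(lam)])

# r_t = ell_{t+1} - ell_t for t = 0..T-1, i.e. r[t] = lambda_{t+1} = lam[t]
r = np.diff(ell)
print(r)  # approximately equal to lam
```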

Now define the return at time $t$ for target utterance $v$, with discount factor $\gamma\in[0,1]$ and $T$ the number of turns, as

$$
G_v^{(t)} = \sum_{t'=t}^{T-1} \gamma^{t'-t} r_{t'}
$$
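A sketch computing $G^{(t)}$ for every $t$ at once. It uses the equivalent backward recursion $G^{(t)} = r_t + \gamma G^{(t+1)}$ with $G^{(T)} = 0$, which avoids recomputing each sum. The rewards are illustrative and the function name is hypothetical.

```python
def discounted_returns(rewards: list[float], gamma: float) -> list[float]:
    """G^(t) = sum_{t'=t}^{T-1} gamma^(t'-t) * r_{t'} for every t,
    via the backward recursion G^(t) = r_t + gamma * G^(t+1), G^(T) = 0."""
    G = 0.0
    out = [0.0] * len(rewards)
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * G
        out[t] = G
    return out

r = [1.0, 0.0, 2.0]   # illustrative rewards r_0, r_1, r_2
print(discounted_returns(r, gamma=0.5))  # [1.5, 1.0, 2.0]
```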



## Cluster-level

The tree consists of nodes $S_j^{(t)}$.  Each node represents: 

- Action-type $a$.  This is the type of action taken at this node.
- Persuader clusters $\mathcal{U} = \{U_k\}$ ($|\mathcal{U}| = K$). Each cluster $U_k$ consists of persuader utterances $u_{k_i}$.
- Target clusters $\mathcal{V} = \{V_m\}$ ($|\mathcal{V}| = M$). Each cluster $V_m$ consists of target utterances $v_{m_i}$.

For $\mathcal{U}$, we use a semantic embedding model. For $\mathcal{V}$, we use an NLI-based embedding model.

We also track the following $K\times M$ matrices:

- $N_{k\rightarrow m}$: number of times an utterance in $k$ led to $m$
- $W_{k\rightarrow m} = \sum_{v\in V_m} G_v$, summed over the target utterances $v\in V_m$ that followed a persuader utterance in $U_k$: total return from $m$ coming from $k$
- $Q_{k\rightarrow m} = (W/N)_{k\rightarrow m}$: average return
- $\pi_{k\rightarrow m} = N_{k\rightarrow m}/\sum_{m'} N_{k\rightarrow m'}$: empirical transition matrix (row stochastic)
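A minimal NumPy sketch of this bookkeeping for a single node, with made-up $(k, m, G)$ observations. The helper name `record` is hypothetical, and how to define $Q$ and $\pi$ for transitions or rows that were never visited is not specified above; here they are simply left at 0 (an assumption).

```python
import numpy as np

K, M = 3, 2                  # illustrative cluster counts
N = np.zeros((K, M))         # N[k, m]: times an utterance in U_k led to V_m
W = np.zeros((K, M))         # W[k, m]: summed returns G_v for those transitions

def record(k: int, m: int, G: float) -> None:
    """Add one observed (k, m, G) triple to the node's counts."""
    N[k, m] += 1
    W[k, m] += G

# Made-up observations (k, m, G).
for k, m, G in [(0, 0, 1.0), (0, 1, -0.5), (0, 0, 2.0), (1, 1, 0.5)]:
    record(k, m, G)

# Q = W / N, left at 0 where a transition was never observed (assumption).
Q = np.divide(W, N, out=np.zeros_like(W), where=N > 0)

# pi: row-normalized N; rows with no visits stay all-zero (assumption).
row = N.sum(axis=1, keepdims=True)
pi = np.divide(N, row, out=np.zeros_like(N), where=row > 0)
print(Q)
print(pi)
```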

And the following value functions:

- $Q_k = \mathbb{E}[G|u\in U_k] = \sum_m \pi_{k\rightarrow m} Q_{k\rightarrow m}$: expected empirical return for a persuader utterance in $U_k$
- $Q_a = \max_k Q_k$: value function for action $a$
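A short sketch of these two value functions on illustrative $(K=2, M=2)$ statistics. Because $\pi_{k\rightarrow m} Q_{k\rightarrow m} = W_{k\rightarrow m} / \sum_{m'} N_{k\rightarrow m'}$, $Q_k$ reduces to the total return over the total visits of cluster $k$, which the final line checks.

```python
import numpy as np

# Illustrative statistics for one node.
N = np.array([[2.0, 1.0],
              [0.0, 1.0]])
W = np.array([[3.0, -0.5],
              [0.0, 0.5]])

Q_km = np.divide(W, N, out=np.zeros_like(W), where=N > 0)
pi = N / N.sum(axis=1, keepdims=True)   # every row here has at least one visit

Q_k = (pi * Q_km).sum(axis=1)           # sum_m pi[k, m] * Q[k, m]
Q_a = Q_k.max()                         # value assigned to this node's action
print(Q_k, Q_a)

# Equivalent form: total return over total visits for each cluster.
assert np.allclose(Q_k, W.sum(axis=1) / N.sum(axis=1))
```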

## Clustered Open-loop MCTS (COL-MCTS)

### Selection

Select an action $a$ using the PUCT form of the Upper Confidence Trees (UCT) criterion:
$$
a^* = \text{arg}\max_a \left[ Q_a + c_1 P(a|s) \frac{\sqrt{N(s)}}{1+N(s,a)} \right]
$$

This balances exploitation of branches already known to score well (first term) against exploration of less-visited branches (second term).

$N(s)$ is the visit count of the current (parent) node and $N(s,a)$ is the visit count of action $a$ from it. The bonus grows slowly with $N(s)$ and shrinks as $N(s,a)$ grows, so less-explored actions receive a larger bonus.

For now, assume $P(a|s) = 1/|A|$.
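A sketch of this selection rule with the uniform prior. Two assumptions not stated in the notes: $N(s)$ is taken as the sum of the child visit counts, and ties go to the first action in the dictionary. The function name, action labels, and statistics are hypothetical.

```python
import math

def puct_select(Q: dict[str, float], N_sa: dict[str, int], c1: float) -> str:
    """argmax_a Q(a) + c1 * P(a|s) * sqrt(N(s)) / (1 + N(s,a)), with P(a|s) = 1/|A|."""
    actions = list(Q)
    prior = 1.0 / len(actions)
    N_s = sum(N_sa.values())   # assumption: parent visits = total child visits

    def score(a: str) -> float:
        return Q[a] + c1 * prior * math.sqrt(N_s) / (1 + N_sa[a])

    return max(actions, key=score)

# Hypothetical action values and visit counts.
Q = {"a": 0.5, "b": 0.4, "c": 0.0}
N_sa = {"a": 10, "b": 1, "c": 0}
print(puct_select(Q, N_sa, c1=0.0))  # a: pure exploitation
print(puct_select(Q, N_sa, c1=1.0))  # c: exploration bonus dominates
```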

### Expansion

Expand an unexplored action from this node, choosing from remaining actions uniformly at random.  We now have a path corresponding to an action sequence $(a_{0:t})$ with $a_t$ the newly expanded node.
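A minimal sketch of the expansion step, assuming the node keeps a set of already-tried actions. The action labels and the function name are hypothetical, since the notes do not list the real action types.

```python
import random

def expand(tried: set, all_actions: list, rng: random.Random):
    """Return an untried action chosen uniformly at random, or None if none remain."""
    remaining = [a for a in all_actions if a not in tried]
    return rng.choice(remaining) if remaining else None

# Hypothetical action types.
actions = ["ask", "appeal", "concede"]
rng = random.Random(0)
a_new = expand({"ask"}, actions, rng)
print(a_new)
```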

### Rollout

We simulate an entire conversation along this sequence (open-loop), using persuader clusters to condition the type of persuader utterance and target clusters to score target utterances (clustered).

Specifically, we iterate through the path, at each step (node):

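- select a persuader cluster $U_k$ using a UCB1-style Upper Confidence Bound, where $N_k = \sum_m N_{k\rightarrow m}$ is the visit count of cluster $k$:
$$
k^* = \arg\max_k \left[ Q_k + c_2 \sqrt{\frac{\log \sum_j N_j}{1+N_k}} \right]
$$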

- generate the persuader response. If $k^*$ is "none", we condition only on the action type; otherwise we also condition on the selected persuader cluster (response type).

- generate target response

- add this pair to the node, compute its cluster assignments $(k, m)$ and its reward $r$, and record $(k,m,r)$

At the end we have a list $\mathcal{P} = [(\text{node}_t, k_t, m_t, r_t)]_t$ for the expanded path.
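A sketch of the cluster-selection step from the rollout loop. Since $\log 0$ is undefined before any cluster has been visited, the sketch uses $\log\max(\sum_j N_j, 1)$, which is an assumption rather than part of the formula. How the "none" option is scored is not specified, so it is not modeled. Names and values are hypothetical.

```python
import math

def ucb_select_cluster(Q_k: list[float], N_k: list[int], c2: float) -> int:
    """k* = argmax_k Q_k + c2 * sqrt(log(sum_j N_j) / (1 + N_k))."""
    # Assumption: guard log(0) by clamping the total to 1, so the bonus is 0 before any visits.
    log_total = math.log(max(sum(N_k), 1))
    scores = [q + c2 * math.sqrt(log_total / (1 + n)) for q, n in zip(Q_k, N_k)]
    return max(range(len(scores)), key=scores.__getitem__)

Q_k = [0.8, 0.5, 0.0]   # illustrative cluster values
N_k = [20, 5, 0]        # illustrative visit counts N_k = sum_m N_{k->m}
print(ucb_select_cluster(Q_k, N_k, c2=0.0))  # 0: pure exploitation
print(ucb_select_cluster(Q_k, N_k, c2=1.0))  # 2: the unvisited cluster wins
```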

### Backprop

Now we traverse through $\mathcal{P}$ in reverse, computing

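```python
# Accumulate the discounted return backward: G_t = r_t + gamma * G_{t+1}
G = 0
for node, k, m, r in reversed(path):  # path is the list P from the rollout
    G = r + gamma * G
    node.update(k, m, G)  # credit return G to transition k -> m at this node
```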

where `.update` updates the node's internal statistics ($N$, $W$, $Q$, etc.).
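A self-contained sketch of backprop with a hypothetical `NodeStats` class standing in for the real node (it is not the `OLNode` API). Its `update` increments $N_{k\rightarrow m}$ and adds $G$ to $W_{k\rightarrow m}$, so $Q_{k\rightarrow m}$ becomes the average return over all rollouts that used that transition. The two rollouts are made up.

```python
from collections import defaultdict
from dataclasses import dataclass, field

@dataclass
class NodeStats:
    """Hypothetical stand-in for a node's (k, m) statistics."""
    N: dict = field(default_factory=lambda: defaultdict(int))    # visit counts N_{k->m}
    W: dict = field(default_factory=lambda: defaultdict(float))  # summed returns W_{k->m}

    def update(self, k: int, m: int, G: float) -> None:
        self.N[(k, m)] += 1
        self.W[(k, m)] += G

    def Q(self, k: int, m: int) -> float:
        n = self.N.get((k, m), 0)
        return self.W[(k, m)] / n if n else 0.0

def backprop(path, gamma: float) -> None:
    """Walk the rollout path backward, crediting each node with its discounted return."""
    G = 0.0
    for node, k, m, r in reversed(path):
        G = r + gamma * G
        node.update(k, m, G)

root, child = NodeStats(), NodeStats()
# Two made-up rollouts through the same (k, m) transitions.
backprop([(root, 0, 0, 1.0), (child, 1, 0, 2.0)], gamma=0.5)  # child G=2.0, root G=2.0
backprop([(root, 0, 0, 0.0), (child, 1, 0, 0.0)], gamma=0.5)  # both G=0.0
print(root.Q(0, 0), child.Q(1, 0))  # 1.0 1.0
```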


# code

```python
import json
import sys
import os
import pickle

# hack so we can import normally from other packages
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

import matplotlib.pyplot as plt
import numpy as np
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
# from sentence_transformers import SentenceTransformer
import torch

from agent.agent import Agent
from agent.llm_client import LLMClient
from mcts.mcts_node import ConversationState
from mcts.mcts_node import OLNode, ResponseBank
```

```python
with open("/home/stmorse/data/mdp/fender/test/v0_-1.00_v1_1.00/turn_0_root.pkl", "rb") as f:
    root = pickle.load(f)
```



```python
root.children[0].target_bank.embeddings[0].shape
```

```
torch.Size([1024])
```