# Steerability Evaluation

## Friday 2024-10-18

### Skeleton

Overview:

We want to build an eval that tests Honcho-like steerable systems. For our purposes, a Honcho-like steerable system is one that can a) take some observations about a user and b) infer, based on those observations, how that user would think and behave. Ultimately, this eval framework will help us make evidence-based decisions when building newer versions of Honcho.

- **Eval step 1: steering**. To measure how good a particular steerable system is, we first steer it towards a certain persona by giving it a set of observations about that persona, e.g. what they did or thought in certain scenarios.
	- Once they receive the user observations, different steerable systems will implement this steering differently. For example, a baseline system could simply inject the observations into a few-shot prompt; our first implementation of Honcho uses the observations to write up a natural language user representation; further versions could use them to train a LoRRA...
	- To be fair, our eval should be agnostic to how a steerable system implements steering. After all, what we want to measure is how well a system can adapt its behavior to model a user based on a given set of observations: whether it does so with a prompt, a LoRRA, or a pool of 500 human experts who can type at the speed of sound shouldn't affect its score.
- **Eval step 2: testing**. After steering a system towards a certain persona, we test whether it can accurately infer the persona's thoughts or behaviors, i.e. how accurately it can predict a held-out part of the scenario-action dataset for that persona.
- **Eval step 3: aggregation**. We repeat steps 1 and 2 over a large number of personas (e.g. N=100), so that the model steered to each persona x takes the test for every persona y, measuring how well it models persona y. Ideally, if we plot a heat map of the accuracy of the model steered to persona x when tested on persona y, we'd see a strong diagonal: the model steered to a given persona scores high on that persona's test and low on all others.
	- We could also look at specific horizontal or vertical slices of this plot. For example, a horizontal slice tells us, for a given persona test, how well all steered systems model it. A vertical slice tells us, for a given steered model, how well it does on each persona test.
- **Eval step 4: system score**. The entire steerable system is given a single score based on this heat map. There are a few ways we could compute it:
	- How often is a test "won" by its corresponding steered model? I.e., across test personas, how often is the highest score on a persona's test achieved by the model steered to that same persona? (Is this somewhat akin to specificity?)
	- How often does a steered model do best on its own test? I.e., across steered models, how often is a model's highest score achieved on the test for the persona it was steered to? (Is this somewhat akin to sensitivity?)
	- Combining both measures into a sort of F1-score.
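The aggregation and scoring steps above could be sketched roughly as follows. This is a minimal, hypothetical sketch: `build_heat_map`, `system_scores`, `argmax` and the `score(steered_persona, test_persona)` callback are illustrative names, not part of Honcho or this notebook. It assumes a heat map indexed as `heat_map[y][x]` (rows are test personas, columns are steered models), breaks ties in favour of the first index, and reads "a sort of F1-score" as the harmonic mean of the two win rates.

```python
from typing import Callable, Dict, List, Sequence


def build_heat_map(persona_ids: Sequence[str], score: Callable[[str, str], float]) -> List[List[float]]:
    # heat_map[y][x] = score of the model steered to persona x on persona y's test.
    # A row is a "horizontal slice" (one test, all steered models);
    # a column is a "vertical slice" (one steered model, all tests).
    return [[score(x, y) for x in persona_ids] for y in persona_ids]


def argmax(values: Sequence[float]) -> int:
    # Index of the largest value; ties go to the first maximal index.
    return max(range(len(values)), key=lambda i: values[i])


def system_scores(heat_map: List[List[float]]) -> Dict[str, float]:
    n = len(heat_map)
    # Test-win rate: share of tests (rows) won by the model steered to that same persona.
    test_win_rate = sum(argmax(heat_map[y]) == y for y in range(n)) / n
    # Model-win rate: share of steered models (columns) whose best score is on their own test.
    columns = [[heat_map[y][x] for y in range(n)] for x in range(n)]
    model_win_rate = sum(argmax(columns[x]) == x for x in range(n)) / n
    # F1-style combination: harmonic mean of the two rates.
    total = test_win_rate + model_win_rate
    f1 = 2 * test_win_rate * model_win_rate / total if total else 0.0
    return {"test_win_rate": test_win_rate, "model_win_rate": model_win_rate, "f1": f1}


# Toy 3-persona heat map with made-up numbers: on test 2, model 1 beats model 2.
heat_map = [
    [0.9, 0.2, 0.1],
    [0.3, 0.8, 0.4],
    [0.2, 0.7, 0.5],
]
print(system_scores(heat_map))
```

On this toy map every steered model scores best on its own test (model-win rate 1), but test 2 is won by the wrong model (test-win rate 2/3), giving an F1-style score of 0.8. The two rates can disagree like this because one takes the argmax along rows and the other along columns.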

In [20]:
import hashlib
from collections import Counter
from typing import List, Set, Tuple

import pandas as pd

SystemResponse = str
PersonaId = str
ScenarioId = str
ObservationId = str
MAX_PERSONAS = 20
MAX_OBSERVATIONS_PER_PERSONA = 10


def generate_short_hash(text: str) -> str:
    # Short deterministic ID: first 8 hex chars of an MD5 digest (not meant to be secure).
    return hashlib.md5(text.encode()).hexdigest()[:8]


class Persona:
    def __init__(self, persona_id: PersonaId, persona_description: str, persona_framework: str):
        self.persona_id = persona_id
        self.persona_description = persona_description
        self.persona_framework = persona_framework

    def __repr__(self):
        return f'Persona(persona_id={self.persona_id}, persona_description={self.persona_description}, persona_framework={self.persona_framework})'


class Observation:
    def __init__(
        self,
        observation_id: ObservationId,
        observation_description: str,
        scenario_id: ScenarioId,
        scenario_description: str,
        persona_id: PersonaId,
    ):
        self.observation_id = observation_id
        self.observation_description = observation_description
        self.scenario_id = scenario_id
        self.scenario_description = scenario_description
        # Record whose observation this is, so observations can be grouped per persona later.
        self.persona_id = persona_id


class Dataset:
    def __init__(
        self,
        personas_path: str,
        observations_path: str,
        max_personas: int = MAX_PERSONAS,
        max_observations_per_persona: int = MAX_OBSERVATIONS_PER_PERSONA,
        use_actions: bool = True,
        use_thoughts: bool = False,
        use_emotions: bool = False,
    ):
        self.max_personas = max_personas
        self.max_observations_per_persona = max_observations_per_persona
        self.observation_types = self.set_observation_types(use_actions, use_thoughts, use_emotions)
        self.personas, self.persona_ids = self.load_personas(personas_path)
        self.observations = self.load_observations(observations_path)

    def set_observation_types(self, use_actions: bool, use_thoughts: bool, use_emotions: bool) -> List[str]:
        # Column names in the observations CSV that are turned into observations.
        observation_types = []
        if use_actions:
            observation_types.append('action')
        if use_thoughts:
            observation_types.append('thought')
        if use_emotions:
            observation_types.append('emotion')
        return observation_types

    def load_personas(self, personas_path: str) -> Tuple[List[Persona], Set[PersonaId]]:
        # Only the first `max_personas` rows of the personas CSV are used.
        df = pd.read_csv(personas_path).head(self.max_personas)
        personas = []
        persona_ids = set()
        for _, row in df.iterrows():
            persona = Persona(row['persona_id'], row['persona_description'], row['framework_name'])
            personas.append(persona)
            persona_ids.add(persona.persona_id)
        return personas, persona_ids

    def load_observations(self, observations_path: str) -> List[Observation]:
        df = pd.read_csv(observations_path)
        observations = []
        counts = Counter()  # observations kept so far, per persona
        for _, row in df.iterrows():
            persona_id = row['persona_id']
            # Skip rows for personas that were not loaded (e.g. beyond max_personas).
            if persona_id not in self.persona_ids:
                continue
            scenario = f'{row["context"]}\n{row["scenario"]}'
            scenario_id = row['scenario_id']
            # Each row yields one observation per enabled type (action/thought/emotion).
            for observation_type in self.observation_types:
                # Enforce the per-persona cap, which was previously stored but never applied.
                if counts[persona_id] >= self.max_observations_per_persona:
                    break
                observation_description = row[observation_type]
                if pd.isna(observation_description):
                    continue  # empty cell in the CSV
                # Include the persona ID so identical text from different personas gets distinct IDs.
                observation_id = generate_short_hash(f'{persona_id}{observation_description}{scenario_id}')
                observations.append(
                    Observation(observation_id, observation_description, scenario_id, scenario, persona_id)
                )
                counts[persona_id] += 1
        return observations

In [21]:
personas_path = 'dataset/personas_tarot.csv'
observations_path = 'dataset/w5_tarot.csv'

dataset = Dataset(personas_path, observations_path)


In [23]:
dataset.observations



In [7]:
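from abc import ABC, abstractmethod
from typing import List


class BaseSteeredSystem(ABC):
    """A system that has already been steered towards a single persona."""

    def __init__(self, persona: Persona, steerable_system: 'BaseSteerableSystem'):
        self.persona = persona
        self.steerable_system = steerable_system

    @abstractmethod
    def run_inference(self, scenario: str) -> SystemResponse:
        """Predict how the steered-to persona would respond to `scenario`."""


class BaseSteerableSystem(ABC):
    """A system that can be steered towards a persona given observations about it.

    Declared as an ABC (instead of raising NotImplementedError in __init__) so that
    subclasses can still be instantiated without overriding __init__.
    """

    @abstractmethod
    def steer(self, persona: Persona, steer_observations: List[Observation]) -> BaseSteeredSystem:
        """Return a system steered towards `persona` using `steer_observations`."""


class SteerabilityEval:
    def __init__(self, tested_system: BaseSteerableSystem, dataset: Dataset):
        self.tested_system = tested_system
        self.dataset = dataset
        self.personas = self.dataset.personas
        # NOTE: Dataset does not define split() yet; it is expected to return a steer set
        # and a test set that both expose get_observations(persona).
        self.steer_set, self.test_set = self.dataset.split()

    def steer_to_persona(self, persona: Persona) -> BaseSteeredSystem:
        # Eval step 1: steer the tested system using only this persona's steer-set observations.
        steer_observations = self.steer_set.get_observations(persona)
        steered_system = self.tested_system.steer(persona, steer_observations)
        return steered_system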

