In [4]:
import sys
import os
# Redis connection URL exported as REDIS_OM_URL. The value below is a
# placeholder template (password / server_name / port_num): fill it in locally
# and avoid committing real credentials to the notebook.
# NOTE(review): this is assigned ahead of the sotopia.database imports below,
# presumably because the connection is configured at import time — confirm
# before reordering this cell.
os.environ[
    "REDIS_OM_URL"
] = "redis://:password@server_name:port_num"

import json
from tqdm.notebook import tqdm
import rich
import logging
from collections import defaultdict, Counter
from sotopia.database.persistent_profile import AgentProfile, EnvironmentProfile, RelationshipProfile
from sotopia.database.logs import EpisodeLog
from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage
from collections import Counter 
from redis_om import Migrator
from rich.console import Console
from rich.terminal_theme import MONOKAI 
import matplotlib.pyplot as plt
import numpy as np
from utils.prompt_reverse_engineering import reverse_episode_log, parse_prompt_to_json


In [5]:
def get_episodes_by_env_dict(
    tags: list,
    scenarios: set,
) -> dict[str, list[EpisodeLog]]:
    """Load the episodes stored under `tags` and group them by environment.

    Args:
        tags: Episode tags to query; one `EpisodeLog.find` query is issued per
            distinct tag.
        scenarios: Environment ids to keep. Any container supporting `in`
            works; episodes whose `environment` is not in it are dropped.

    Returns:
        A dict mapping each retained environment id to its episodes, in the
        order they were returned by the per-tag queries (tags visited in
        first-seen order).
    """
    eps_by_env = {}
    # dict.fromkeys() de-duplicates repeated tags while keeping first-seen
    # order, so a tag listed twice does not contribute its episodes twice.
    for tag in dict.fromkeys(tags):
        # Group while iterating instead of flattening with sum(lists, []),
        # which copies the accumulated list on every addition.
        for ep in EpisodeLog.find(EpisodeLog.tag == tag).all():
            if ep.environment in scenarios:
                eps_by_env.setdefault(ep.environment, []).append(ep)

    return eps_by_env


def get_sorted_episode_list_for_target_agent(
    tags: list,
    scenarios: set,
    agent_model: str,
    reward_metric: str = "overall_score"
) -> list:
    """Rank the episodes involving `agent_model` by that agent's reward.

    Each returned entry is a `(role, episode)` tuple where `role` is:
        0 - exactly one of `models[1]` / `models[2]` equals `agent_model`;
        1 - both equal `agent_model`, entry scored by `rewards[0]`;
        2 - both equal `agent_model`, entry scored by `rewards[1]`.
    A self-play episode therefore appears twice, once per side.

    Args:
        tags: Episode tags forwarded to `get_episodes_by_env_dict`.
        scenarios: Environment ids forwarded to `get_episodes_by_env_dict`.
        agent_model: Model name compared against `ep.models[1]` and
            `ep.models[2]`.
        reward_metric: Key looked up in the reward dict of the scored side.

    Returns:
        The `(role, episode)` entries sorted by descending reward; ties keep
        the order single-agent entries, then role 1, then role 2.
    """
    eps_by_env = get_episodes_by_env_dict(tags, scenarios)

    # Split into the cases where the target agent talks to another model or to itself
    single_agent_eps_list = []
    dual_agent_1_eps_list = []
    dual_agent_2_eps_list = []
    # Walk the per-environment lists directly rather than flattening them with
    # sum(lists, []), which is quadratic in the number of episodes.
    for eps in eps_by_env.values():
        for ep in eps:
            first_is_target = ep.models[1] == agent_model
            second_is_target = ep.models[2] == agent_model
            if first_is_target and second_is_target:
                dual_agent_1_eps_list.append((1, ep))
                dual_agent_2_eps_list.append((2, ep))
            elif first_is_target or second_is_target:
                single_agent_eps_list.append((0, ep))

    combined_eps_list = single_agent_eps_list + dual_agent_1_eps_list + dual_agent_2_eps_list

    def reward_sort_fn(x):
        role, ep = x
        if role == 0:
            # Score the side on which the target agent actually played.
            reward_idx = 0 if ep.models[1] == agent_model else 1
        else:
            # role 1 -> rewards[0], role 2 -> rewards[1]
            reward_idx = role - 1
        # NOTE(review): assumes every rewards entry is indexable as
        # (score, metrics_dict); an entry of another shape would raise here —
        # confirm against EpisodeLog.
        return ep.rewards[reward_idx][1][reward_metric]

    return sorted(combined_eps_list, key=reward_sort_fn, reverse=True)


def select_top_reward_eps(
    sorted_eps_list: list,
    ratio: float = 0.7
) -> list:
    """Return the leading `ratio` share of an already-ranked episode list.

    The number of entries kept is `round(ratio * len(sorted_eps_list))`, so it
    follows the built-in `round` (exact halves go to the even integer).
    """
    total = len(sorted_eps_list)
    keep_count = round(ratio * total)
    top_entries = sorted_eps_list[:keep_count]
    return top_entries

In [6]:
# Episode tags to query; each one is matched against EpisodeLog.tag.
TAGS = ["ft-mistral-7b_ft-mistral-7b_generation-1_26_ruiyi_1220", "ft-mistral-7b_gpt-4_generation-1_26_ruiyi_1220", "ft-mistral-7b_gpt-3.5-turbo_generation-1_26_ruiyi_1220"]
# Environment ids to keep; compared against each episode's `environment`.
# NOTE(review): this is a list, while the helpers annotate `scenarios` as a set;
# the `in` test works with either.
SCENARIOS = ["01HFSDNWGEK1808TP4ZDVZCJEP", "01HFSDNWGHBV0SDXS9FT3P6AF9", "01HFSDNWG8RG3XNA8TNHE00PWD", "01HFSDNWG9NFCGCX3DZ39HHBNR", "01HFSDNWFYTK35BFQDGHTJYBAG", "01HFSDNWHKRS2E646B4K9DQS8Z", "01HFSDNWH76WPTBG0BKRRAMNFG", "01HFSDNWF5X4NJQ6EZPKEDGVGX", "01HFSDNWH17A1ADSAM3BJGW914", "01HFSDNWFZKX4SVRWFDD4GKAMC", "01HFSDNWGBJ7HJXAJVPCZPS491", "01HFSDNWFWG5KJMZ80D10VBYQ7", "01HFSDNWGA8KXENGCJYPHD1YQW", "01HFSDNWFJGESH1WB5NYYEAST8", "01HFSDNWGTP2EMNM949JRA20H0", "01HFSDNWHA7X3EM8S0ZQF94TBH", "01HFSDNWFQ616QWSFSP4RDHJ7S", "01HFSDNWFCEK8NDHMT8EX88CMY", "01HFSDNWFX9NYHZ2GRVN2D1494", "01HFSDNWH0RZDXD8MRX0FVETWC"]
# Model name that identifies the target agent; compared against
# ep.models[1] / ep.models[2] when ranking and exporting episodes.
TARGET_AGENT = "localhost"

In [7]:
# Rank the TARGET_AGENT episodes found under TAGS / SCENARIOS by descending
# reward (default metric: "overall_score").
sorted_combined_eps_list = get_sorted_episode_list_for_target_agent(TAGS, SCENARIOS, TARGET_AGENT)

In [8]:
# Inspect the ten highest-ranked (role, EpisodeLog) entries.
sorted_combined_eps_list[:10]

[(0,
  EpisodeLog(pk='01HJ30J6KZ2SYZWDM359F7JTY4', environment='01HFSDNWFQ616QWSFSP4RDHJ7S', agents=['01H5TNE5PKW8P500417PMSGSAC', '01H5TNE5PPK39HR52G61PQ5KQ7'], tag='ft-mistral-7b_gpt-4_generation-1_26_ruiyi_1220', models=['gpt-4', 'localhost', 'gpt-4'], messages=[[('Environment', 'Miles Hawkins', "\nHere is the context of this interaction:\nScenario: Two neighbors are both avid gardeners. One neighbor has just received a shipment of rare plant seeds but has run out of pots. The other neighbor has a surplus of pots but has been unsuccessful in finding new and exciting plant varieties to grow.\nParticipants: Miles Hawkins and Zane Bennett\nMiles Hawkins's background: Miles Hawkins is a 50-year-old male chef. He/him pronouns. Miles Hawkins, a chef, is a green thumb enthusiast and spends his free time tending to his kitchen garden, using some of his fresh produce in his dishes. Personality and values description: Miles Hawkins, spontaneous and free-spirited, values sanctity and benevolen

"01HFSDNWG78KV60B0H60GQNNF9", "01HFSDNWGPJSBBS6HT8GDWMC9Q", "01HFSDNWG5Z99ZX40DHZY88PY8", "01HFSDNWHDNSAWC41SWM1ZXTCG", "01HFSDNWGMHY3VNT4BGDXP0GZN", "01HFSDNWG1WFWWBEMMQ0M694BG"

In [6]:
# NOTE(review): this re-runs the same ranking query as the earlier cell.
sorted_combined_eps_list = get_sorted_episode_list_for_target_agent(TAGS, SCENARIOS, TARGET_AGENT)
# Keep the leading 70% of the ranked entries (cut-off is round(0.7 * len(list))).
selected_eps_list = select_top_reward_eps(sorted_combined_eps_list, ratio=0.7)
print(len(selected_eps_list))

104


In [28]:
def get_train_data_from_eps_list(
    selected_eps_list: list,
    output_file: str,
    target_agent: "str | None" = None,
):
    """Convert ranked episodes into instruction-tuning records and save them as JSON.

    Args:
        selected_eps_list: `(role, episode)` entries. `role` 0 means the target
            agent met a different model, 1 and 2 mean a self-play episode
            taken from the first or second side respectively.
        output_file: Path of the JSON file to (over)write.
        target_agent: Model name identifying the target agent. When omitted,
            the notebook-level `TARGET_AGENT` constant is used, which matches
            the previous behaviour.

    The file receives a JSON list of `{"instruction", "input", "output"}`
    dicts, one per turn returned by `reverse_episode_log`.
    """
    if target_agent is None:
        target_agent = TARGET_AGENT

    conv_list = []
    for entry in selected_eps_list:
        role, episode = entry[0], entry[1]
        if role == 0:
            # target agent + other agent: later_speak is set when the target
            # is not models[1]
            later_speak = episode.models[1] != target_agent
        else:
            # target agent * 2: role 1 -> later_speak=False, role 2 -> True
            later_speak = role != 1
        # NOTE(review): `later_speak` presumably selects which participant's
        # turns are extracted — confirm in utils.prompt_reverse_engineering.
        conv = reverse_episode_log(episode, later_speak=later_speak)

        for turn in conv:
            conv_list.append({
                "instruction": "",
                "input": turn['prompt'],
                "output": turn['result']
            })

    # Serialise before opening the file so that a serialisation error cannot
    # leave an existing output file truncated.
    payload = json.dumps(conv_list, indent=4)
    with open(output_file, 'w') as f:
        f.write(payload)

In [29]:
get_train_data_from_eps_list(selected_eps_list, output_file="data_generation_20_scenarios.json")