In [None]:
import os
import json
from typing import get_args
from tqdm.notebook import tqdm
import rich
import logging
from collections import Counter
from sotopia.database.persistent_profile import (
    AgentProfile,
    EnvironmentProfile,
    RelationshipProfile,
)
from sotopia.database.logs import EpisodeLog
from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage

### Episodes to CSV/JSON

In [None]:
from typing import Literal

from pydantic import ValidationError


# Model names accepted as the first two "_"-separated parts of an EpisodeLog
# tag by `_is_valid_episode_log_pk` (checked via `get_args(LLM_Name)`).
# Keep this a literal subscript (not built from a variable) so type checkers
# still understand it as a Literal type.
LLM_Name = Literal[
    "togethercomputer/llama-2-7b-chat",
    "togethercomputer/llama-2-70b-chat",
    "togethercomputer/mpt-30b-chat",
    "gpt-3.5-turbo",
    "text-davinci-003",
    "gpt-4",
    "gpt-4-turbo",
    "human",
    "redis",
]


def _has_clean_tag(episode: EpisodeLog) -> bool:
    """Return True if ``episode.tag`` looks like ``"<model_1>_<model_2>_v0.0.1_clean"``.

    Both model names must be members of ``LLM_Name``.
    """
    try:
        model_1, model_2, version = episode.tag.split("_", maxsplit=2)
    except (ValueError, AttributeError):
        # ValueError: tag has fewer than 3 "_"-separated parts
        # AttributeError: tag is None
        return False
    known_models = get_args(LLM_Name)
    return (
        model_1 in known_models
        and model_2 in known_models
        and version == "v0.0.1_clean"
    )


def _load_valid_episode_log(pk: str) -> "EpisodeLog | None":
    """Fetch the EpisodeLog for ``pk``; return it only if it has a clean tag.

    Returns None when the stored record fails pydantic validation or when its
    tag does not match ``_has_clean_tag``.
    """
    try:
        episode = EpisodeLog.get(pk=pk)
    except ValidationError:
        return None
    return episode if _has_clean_tag(episode) else None


def _is_valid_episode_log_pk(pk: str) -> bool:
    """Return True if ``pk`` loads as an EpisodeLog with a clean v0.0.1 tag."""
    return _load_valid_episode_log(pk) is not None


# Load and validate each pk in a single step. filter(_is_valid_episode_log_pk)
# followed by another EpisodeLog.get would fetch every kept episode twice.
episodes: list[EpisodeLog] = [
    episode
    for episode in map(_load_valid_episode_log, EpisodeLog.all_pks())
    if episode is not None
]

In [None]:
from sotopia.database.serialization import episodes_to_csv


csv_output_path = "../data/sotopia_episodes_v1.csv"
episodes_to_csv(episodes, csv_output_path)

In [None]:
from sotopia.database.serialization import episodes_to_json


json_output_path = "../data/sotopia_episodes_v1.json"
episodes_to_json(episodes, json_output_path)

## Relationship Profile

In [None]:
# Load every RelationshipProfile (best effort) and count each relationship value.
res_pks = list(RelationshipProfile.all_pks())
print(len(res_pks))
res = []
failed_res_pks = []
for pk in res_pks:
    try:
        res.append(RelationshipProfile.get(pk=pk))
    except Exception as e:
        # Keep going past bad records, but report which pk failed and why.
        print(f"error loading RelationshipProfile {pk}: {e}")
        failed_res_pks.append(pk)
print(f"{len(failed_res_pks)} relationship profile(s) failed to load")
res_relationships = [r.relationship for r in res]
Counter(res_relationships)

## Agent Profiles

In [None]:
# find all agent profiles with a given first name (returns a list)
first_name_query = AgentProfile.find(AgentProfile.first_name == "ss")
agents = first_name_query.all()
rich.print(agents)

In [None]:
# find specific agents: gender "Man" and age over 30
agents = AgentProfile.find(
    AgentProfile.gender == "Man",
    AgentProfile.age > 30,
)
for matched_agent in agents:
    rich.print(matched_agent)

In [None]:
# obtain all agents' basic info (best effort: unloadable profiles are reported and skipped)
agent_pks = list(AgentProfile.all_pks())
print(len(agent_pks))
agents = []
for pk in agent_pks:
    try:
        agents.append(AgentProfile.get(pk=pk))
    except Exception as e:
        # Say which record failed and why instead of a bare "error".
        print(f"error loading AgentProfile {pk}: {e}")
# output agents' basic info
for agent in agents:
    rich.print(agent)

In [None]:
# Count the stored agent profiles.
agent_pks = [*AgentProfile.all_pks()]
print(len(agent_pks))

In [None]:
# Update agent's information: select the profile named Ava Martinez.
# `agents` holds a single AgentProfile (not a list); the following cells
# display and update it.
ava_matches = AgentProfile.find(
    AgentProfile.first_name == "Ava", AgentProfile.last_name == "Martinez"
).all()
if not ava_matches:
    raise LookupError("No AgentProfile found for Ava Martinez")
if len(ava_matches) > 1:
    # Warn loudly: the next cells write to whichever profile is picked here.
    print(f"warning: {len(ava_matches)} profiles match Ava Martinez; using the first")
agents = ava_matches[0]

In [None]:
# Display the selected AgentProfile (a bare last expression gets rich output).
agents

In [None]:
# Overwrite the selected profile's `secret` field.
# NOTE(review): `update` presumably persists the change to the database — verify
# before re-running. The text mixes "their" and "her"; confirm the intended pronoun.
new_secret = "Keeps their bisexuality a secret from her conservative family"
agents.update(secret=new_secret)

## Environment Profile

In [None]:
# get all environment profile primary keys and preview the first five
all_envs = [*EnvironmentProfile.all_pks()]
env_count = len(all_envs)
print(env_count)
print(all_envs[:5])

In [None]:
# look up one environment profile by its primary key and print it
env_profile_id = "01H7VFHPJKR16MD1KC71V4ZRCF"
rich.print(env := EnvironmentProfile.get(env_profile_id))

## EnvAgentComboStorage

In [None]:
# all env-agent combos. `all_pks` and `get` are called on the class itself;
# constructing a throwaway EnvAgentComboStorage() instance is unnecessary.
all_combos = list(EnvAgentComboStorage.all_pks())
print(len(all_combos))
if all_combos:
    # Guard: indexing [0] on an empty store would raise IndexError.
    rich.print(EnvAgentComboStorage.get(all_combos[0]))

In [None]:
# check for duplicates in EnvAgentComboStorage: two combos are duplicates when
# they have the same env_id and the same first and second agent ids
cache = set()
first_pk_by_combo = {}
duplicate_combo_pks = []
for combo_pk in all_combos:
    combo = EnvAgentComboStorage.get(combo_pk)
    curr_tuple = (combo.env_id, combo.agent_ids[0], combo.agent_ids[1])
    if curr_tuple in cache:
        # Name both pks so the duplicate can actually be tracked down.
        print(f"duplicate: {combo_pk} (same as {first_pk_by_combo[curr_tuple]})")
        duplicate_combo_pks.append(combo_pk)
    else:
        cache.add(curr_tuple)
        first_pk_by_combo[curr_tuple] = combo_pk
print(f"{len(duplicate_combo_pks)} duplicate combo(s) found")

## Episode Log

In [None]:
# find all episode logs carrying a given tag
episode_tag = "aug20_gpt4_llama-2-70b-chat_zqi2"
tag_query = EpisodeLog.find(EpisodeLog.tag == episode_tag)
Episodes = tag_query.all()
len(Episodes)

In [None]:
# collect every episode log primary key; show the count and the first key
episode_pks = [*EpisodeLog.all_pks()]
print(len(episode_pks))
first_episode_pk = episode_pks[0]
print(first_episode_pk)

In [None]:
# Some episodes fail validation while loading; split them into good and buggy.
# Each buggy entry records the failing pk and its load error. The entries
# expose `.pk` like an EpisodeLog, so `[ep.pk for ep in buggy_eps]` still works.
from types import SimpleNamespace

good_eps = []
buggy_eps = []
for pk in episode_pks:
    try:
        curr_ep = EpisodeLog.get(pk)
    except Exception as e:
        print(pk)
        print(e)
        # Record the failing pk itself. Appending `curr_ep` here would store the
        # previous iteration's (good) episode, or raise NameError on a first
        # failure, so good episodes would be mistaken for buggy ones.
        buggy_eps.append(SimpleNamespace(pk=pk, error=e))
        continue
    good_eps.append(curr_ep)
len(buggy_eps)

In [None]:
num_buggy_eps = len(buggy_eps)
print(num_buggy_eps)

In [None]:
def delete_and_save(delete_pks, save_path):
    """Back up each EpisodeLog in ``delete_pks`` to ``<save_path>/<pk>.json``, then delete it.

    The backup is written *before* the record is deleted. If writing fails, the
    record stays in the database instead of being lost.

    Records that fail pydantic validation cannot be loaded with
    ``EpisodeLog.get``. For those, the raw stored document is backed up instead.

    Args:
        delete_pks: iterable of EpisodeLog primary keys.
        save_path: directory for the backup files; created if missing.

    Raises:
        KeyError: if a pk has no stored document at all (nothing is deleted for it).
    """
    os.makedirs(save_path, exist_ok=True)
    for pk in tqdm(delete_pks):
        try:
            ep_json = EpisodeLog.get(pk=pk).json()
        except ValidationError:
            # NOTE(review): relies on redis_om's `db()` / `make_primary_key` to read
            # the unvalidated JSON document — confirm against the installed version.
            raw = EpisodeLog.db().json().get(EpisodeLog.make_primary_key(pk))
            if raw is None:
                raise KeyError(pk)
            ep_json = json.dumps(raw)
        with open(os.path.join(save_path, f"{pk}.json"), "w") as f:
            f.write(ep_json)
        # Delete only after the backup file was written successfully.
        EpisodeLog.delete(pk)

In [None]:
# show the kernel's working directory (pure Python instead of a shell escape)
print(os.getcwd())

In [None]:
# Back up and delete the buggy episodes.
# The backup dir is built from the home directory instead of a hard-coded
# /Users/<name> path, so the cell also works for other users (presumably the same
# location for the original author — confirm).
BUGGY_EPISODE_BACKUP_DIR = os.path.expanduser("~/sotopia/backup/buggy_episode")
buggy_ep_pks = [ep.pk for ep in buggy_eps]
print(len(buggy_ep_pks))
delete_and_save(buggy_ep_pks, BUGGY_EPISODE_BACKUP_DIR)

In [None]:
# Inspect one backed-up episode file.
# Path built from the home directory rather than a hard-coded /Users/<name> path.
BUGGY_EPISODE_BACKUP_DIR = os.path.expanduser("~/sotopia/backup/buggy_episode")
backup_file = os.path.join(BUGGY_EPISODE_BACKUP_DIR, "01H6YDDT9M8ZYM2FN1F1STM1QB.json")
with open(backup_file) as f:
    test_ep_data = json.load(f)
test_ep_data.keys()

In [None]:
# Rebuild the EpisodeLog from the backup data and print it: profiles first, then messages.
test_ep = EpisodeLog(**test_ep_data)
agent_profiles, conversation = test_ep.render_for_humans()
for rendered_item in [*agent_profiles, *conversation]:
    rich.print(rendered_item)

In [None]:
# Get the epilogs whose `models` list is ["gpt-4", model1, model2] or
# ["gpt-4", model2, model1]. Episodes that fail to load are skipped.
# NOTE: the result name `gpt35_llama2_eps` does not reflect the current
# model1/model2 choice (gpt-4 vs gpt-4).
model1 = "gpt-4"
model2 = "gpt-4"
model_comp1 = ["gpt-4", model1, model2]
model_comp2 = ["gpt-4", model2, model1]
wanted_model_lists = (model_comp1, model_comp2)

gpt35_llama2_eps = []
for epid in episode_pks:
    try:
        curr_ep = EpisodeLog.get(epid)
    except Exception:
        continue
    if curr_ep.models in wanted_model_lists:
        gpt35_llama2_eps.append(curr_ep)
len(gpt35_llama2_eps)

In [None]:
# Check model-swap symmetry of epilogs. For every epilog with environment env,
# agents (a1, a2) and models (m0, m1, m2), there should be one with the same env
# and the same agents in the same order, but with models (m0, m2, m1).
def is_symmetric_epilogs(epilogs):
    """Return True if every epilog has a counterpart with models[1] and models[2] swapped.

    Epilogs are keyed by ``(environment, agents[0], agents[1], models[0],
    models[1], models[2])``. When several epilogs share a key, the last one's pk
    is kept. An epilog whose models[1] equals models[2] is its own counterpart.

    If any epilog lacks a counterpart, a warning listing their pks is logged and
    False is returned. An empty input returns True.
    """
    pk_by_key = {
        (
            ep.environment,
            ep.agents[0],
            ep.agents[1],
            ep.models[0],
            ep.models[1],
            ep.models[2],
        ): ep.pk
        for ep in epilogs
    }

    def swapped(key):
        env, agent_1, agent_2, model_0, model_1, model_2 = key
        return (env, agent_1, agent_2, model_0, model_2, model_1)

    asymetric_epilogs = [
        pk for key, pk in pk_by_key.items() if swapped(key) not in pk_by_key
    ]
    if asymetric_epilogs:
        logging.warning(
            f"Found {len(asymetric_epilogs)} asymetric epilogs: {asymetric_epilogs}"
        )
        return False
    return True

In [None]:
# Check model-swap symmetry of the episodes found by tag; False means a warning
# listing the unmatched pks was logged.
is_symmetric_epilogs(Episodes)

## Render Episodes

In [None]:
# Print a human-readable version of the second episode found by tag:
# agent profiles first, then the conversation.
selected_episode = Episodes[1]
agent_profiles, conversation = selected_episode.render_for_humans()
for rendered_item in [*agent_profiles, *conversation]:
    rich.print(rendered_item)

In [None]:
# count the distinct environments among the episodes found by tag
len({episode.environment for episode in Episodes})