# Uploading rewards from human annotations to the DB


## Import necessary packages

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
from sotopia.database import EpisodeLog
from pydantic import ValidationError
from sotopia.generation_utils.generate import LLM_Name
from typing import get_args
import numpy as np
import numpy.typing as npt
from collections import defaultdict
from rich.console import Console
from rich.table import Table
from sotopia.envs.evaluators import EvaluationBySocialDimensions
from sotopia.database.logs import AnnotationForEpisode
from sotopia.database import EpisodeLog
import tqdm
from pathlib import Path
from pandas import read_csv
from redis_om import Migrator

## Find all CSV files in the current directory

In [None]:
# All annotation CSVs in the current working directory. Sorted because
# Path.glob yields files in arbitrary (filesystem-dependent) order, which would
# make the processing order and logs of the upload cell non-reproducible.
csv_files = sorted(Path(".").glob("*.csv"))

## Iterate through all CSV files and upload their annotations to the DB

In [None]:
SOCIAL_DIMENSIONS: list[str] = list(EvaluationBySocialDimensions.__fields__.keys())

# (CSV column key, label used in the reasoning text) for each annotated agent.
_AGENT_COLUMNS: list[tuple[str, str]] = [("agent1", "agent 1"), ("agent2", "agent 2")]


def _rewards_from_row(row) -> list[tuple[float, dict[str, float]]]:
    """Return one ``(overall, per_dimension)`` reward tuple per agent for a CSV row.

    ``row`` is one row of the annotation CSV, as yielded by ``DataFrame.iterrows``.
    Per-dimension scores are read from columns ``Answer.{agent}_{dimension}``.
    The overall score is their unweighted mean over ``SOCIAL_DIMENSIONS``.
    """
    rewards: list[tuple[float, dict[str, float]]] = []
    for agent, _ in _AGENT_COLUMNS:
        per_dimension = {
            social_dim: row[f"Answer.{agent}_{social_dim}"]
            for social_dim in SOCIAL_DIMENSIONS
        }
        rewards.append((sum(per_dimension.values()) / len(per_dimension), per_dimension))
    return rewards


def _reasoning_from_row(row) -> str:
    """Concatenate the annotator's rationales for both agents into one string.

    Rationales are read from columns ``Answer.{agent}_{dimension}_rationale``.
    The layout is "agent 1 comments: <dim>: <text> ... agent 2 comments: ...",
    with each ``<dim>: <text>`` entry followed by a single space.
    """
    parts: list[str] = []
    for agent, label in _AGENT_COLUMNS:
        parts.append(f"{label} comments: ")
        parts.extend(
            f"{social_dim}: {row[f'Answer.{agent}_{social_dim}_rationale']} "
            for social_dim in SOCIAL_DIMENSIONS
        )
    return "".join(parts)


# Create/refresh the redis_om search indexes (the ``find`` query below relies on them).
Migrator().run()

for csv_file in csv_files:
    print(f"Processing {csv_file}")
    d = read_csv(csv_file)
    # iterrows() is a plain generator with no len(), so pass total explicitly
    # to get a progress bar with a known length.
    for _, row in tqdm.tqdm(d.iterrows(), total=len(d)):
        episode_id = row["Input.episode_id"]
        annotator_id = row["WorkerId"]

        # Verify the episode exists. Done outside ``assert`` so the lookup still
        # runs under ``python -O``, which strips assert statements entirely.
        if not EpisodeLog.get(pk=episode_id):
            raise ValueError(f"Episode {episode_id} not found")

        # Uploading is idempotent per (episode, annotator): skip rows already stored.
        existing_annotations = AnnotationForEpisode.find(
            (AnnotationForEpisode.episode == episode_id)
            & (AnnotationForEpisode.annotator_id == annotator_id)
        ).all()
        if existing_annotations:
            print(f"Skipping {episode_id} for {annotator_id} because it already exists")
            continue

        annotation = AnnotationForEpisode(
            episode=episode_id,
            annotator_id=annotator_id,
            rewards=_rewards_from_row(row),
            reasoning=_reasoning_from_row(row),
        )
        annotation.save()


## Remove annotations from disqualified workers

In [None]:

import ipywidgets as widgets

# Text box where the reviewer enters the ID of the worker whose annotations
# should be removed; the next cell reads ``worker_id_widget.value``.
worker_id_widget = widgets.Textarea(
    description="String:",
    placeholder="Type something",
    disabled=False,
)
# Last expression of the cell, so Jupyter renders the widget.
worker_id_widget

In [None]:
layout = widgets.Layout(width='auto', height='40px')
# The Textarea keeps any stray whitespace/newline the user typed. Strip it so
# the exact-match query on annotator_id does not silently find nothing.
worker_id = worker_id_widget.value.strip()
if worker_id:
    annotations = AnnotationForEpisode.find(AnnotationForEpisode.annotator_id == worker_id).all()
else:
    # Nothing entered: don't query for an empty annotator_id.
    annotations = []
    print("No worker ID entered")
print(f"Found {len(annotations)} annotations")


def _make_delete_handler(target_worker_id, target_annotations):
    """Return a button callback that deletes a fixed snapshot of annotations.

    The worker ID and annotation list are bound here rather than read from the
    notebook globals ``worker_id``/``annotations`` at click time. That way a
    button rendered by an earlier run of this cell cannot act on whatever those
    globals hold later.
    """

    def _delete_all(*args, **kwargs):
        # ipywidgets calls on_click callbacks with the Button instance; it is unused.
        for annotation in target_annotations:
            print(f"Deleting Worker {target_worker_id}'s annotation for Episode {annotation.episode}")
            annotation.delete(annotation.pk)

    return _delete_all


_f = _make_delete_handler(worker_id, list(annotations))
button = widgets.Button(
    description='Do you want to delete all annotations for this worker?',
    disabled=False,
    button_style='warning', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Yes',
    icon='trash', # (FontAwesome names without the `fa-` prefix)
    layout=layout
)
button.on_click(_f)
# Last expression of the cell, so Jupyter renders the button.
button