#### From the case briefs and the JSON transcripts of the cases, build a CSV of the actual statements made by each justice. For each statement, record the justice who made it and the context up to the point where the statement was made (including the system prompt).
Most code taken from `finetuning-inference-script`

In [1]:
import argparse
import json
import os
from datasets import Dataset
import re
import pandas as pd
import copy

# Input locations. Plain string literals: the previous f-prefix had no
# placeholders to interpolate.
TRANSCRIPTS_DIR = "../2024_cases_json/"           # directory of raw JSONs of oral argument transcripts
CASEBRIEF_DIR = "../2023-2024_case_briefs/"      # directory of raw JSONs of case briefs

Helper Functions

In [2]:
def clean_text(text):
    '''
    Strip HTML tags and replace curly quotes with their ASCII equivalents.

    @param text -- Raw string to clean; may be None or empty
    @return -- Cleaned text without leading/trailing whitespace, or 'UNKNOWN' when text is falsy
    '''
    if not text:
        return 'UNKNOWN'

    # Drop anything that looks like an HTML tag.
    cleaned = re.sub(r'<[^>]+>', '', text)

    # Map typographic quotes onto plain ASCII quotes.
    ascii_quotes = {
        "\u201c": "\"",
        "\u201d": "\"",
        "\u2018": "'",
        "\u2019": "'",
    }
    for curly, plain in ascii_quotes.items():
        cleaned = cleaned.replace(curly, plain)

    return cleaned.strip()

def get_facts_and_question(transcript_id, dir=CASEBRIEF_DIR):
    '''
    Load a case brief JSON and return its cleaned facts and legal question.

    @param transcript_id -- Base name (without ".json") of the case brief file
    @param dir -- Directory containing the case brief JSON files
    @return -- Tuple (facts, question), each passed through clean_text

    Errors from opening or parsing the file, and KeyError for a missing
    'facts_of_the_case' or 'question' entry, propagate to the caller.
    '''
    case_brief_file_path = os.path.join(dir, transcript_id + ".json")
    # Decode explicitly as UTF-8 (the JSON interchange encoding) rather than
    # relying on the platform's locale default.
    with open(case_brief_file_path, 'r', encoding="utf-8") as json_file:
        case_brief_json = json.load(json_file)

    facts = clean_text(case_brief_json['facts_of_the_case'])
    question = clean_text(case_brief_json['question'])
    return facts, question

def get_system_prompt(transcript_id):
    '''
    Build the system prompt for one case from its case brief.

    @param transcript_id -- Identifier passed on to get_facts_and_question
    @return -- Prompt string embedding the facts of the case and the legal question
    '''
    brief = get_facts_and_question(transcript_id)
    facts = brief[0]
    question = brief[1]
    prompt = f"You are a legal expert trained to simulate Supreme Court oral arguments.\n\nFACTS_OF_THE_CASE:\n{facts}\n\nLEGAL_QUESTION:\n{question}"
    return prompt

def get_formatted_text_of_turn(turn, advocate):
    '''
    Return all text within a turn as a dict denoting speaker role, and text.

    A speaker is treated as a justice when the inspected role entry has type
    "scotus_justice"; every other speaker (no roles, or a non-justice role)
    is labelled with `advocate`.

    @param turn -- JSON representing a single speaker turn
    @param advocate -- Label used as the role of any non-justice speaker
    @return -- Dict with keys "role", "content", or None if the turn has no speaker
    '''
    speaker = turn["speaker"]
    if not speaker:  # skip turns that have no speaker like "Laughter"
        return None

    # "roles" is inspected in the same two places the original code looked:
    # key '2' when it is a dict, index 0 when it is a list.
    # NOTE(review): the dict-with-key-'2' shape is taken from the original
    # check; confirm against the transcript schema.
    roles = speaker["roles"]
    if isinstance(roles, dict):
        role_entry = roles.get('2')
    elif roles:
        role_entry = roles[0]
    else:
        role_entry = None

    # Default to the advocate so a non-justice role no longer leaves the role
    # undefined (previously an UnboundLocalError).
    is_justice = (
        isinstance(role_entry, dict)
        and role_entry.get("type") == "scotus_justice"
    )

    if is_justice:
        identifier = f'justice_{speaker["identifier"]}'
    else:
        identifier = advocate

    text = " ".join(block["text"] for block in turn["text_blocks"])

    return {
        "role": identifier,
        "content": text
    }

def get_transcript_data(json_file_name, section):
    '''
    Load one section of an oral argument transcript as a list of formatted turns.

    @param json_file_name -- Name of the oral argument JSON file
    @param section -- Index into the transcript's "sections" list; a falsy
                      value (0) labels non-justice speakers "petitioner",
                      any other value labels them "respondent"
    @return -- List of dicts with keys "role", "content" representing each speaker turn in the section
    '''
    transcript_file_path = os.path.join(TRANSCRIPTS_DIR, json_file_name)
    # Decode explicitly as UTF-8 (the JSON interchange encoding) rather than
    # relying on the platform's locale default.
    with open(transcript_file_path, 'r', encoding="utf-8") as json_file:
        transcript_json = json.load(json_file)

    advocate = 'respondent' if section else 'petitioner'
    raw_turns = transcript_json["transcript"]["sections"][section]["turns"]

    formatted_turns = (get_formatted_text_of_turn(turn, advocate) for turn in raw_turns)
    # Turns without a speaker come back as None and are dropped.
    return [turn for turn in formatted_turns if turn]

Load all transcripts

In [3]:
# Role identifiers of the justices, in the "justice_<identifier>" form
# produced by get_formatted_text_of_turn.
justices = [
    "justice_amy_coney_barrett",
    "justice_brett_m_kavanaugh",
    "justice_clarence_thomas",
    "justice_elena_kagan",
    "justice_john_g_roberts_jr",
    "justice_ketanji_brown_jackson",
    "justice_neil_gorsuch",
    "justice_samuel_a_alito_jr",
    "justice_sonia_sotomayor"
]

# Load all transcripts: one entry per (transcript file, section), each
# starting with the case's system prompt followed by the section's turns.
data_transcripts = []
cases_dir = os.fsencode(TRANSCRIPTS_DIR)  # not read in this cell; kept in case other cells use it
for json_file_name in os.listdir(TRANSCRIPTS_DIR):
    if not json_file_name.endswith('.json'):
        continue

    # Extract the transcript_id by dropping the last 9 characters of the file name.
    # NOTE(review): presumably a fixed 9-character suffix such as "_t01.json" — confirm.
    transcript_id = json_file_name[:-9].strip()

    # The system prompt depends only on the case, not on the section, so the
    # case brief is loaded once per file and a missing brief is reported once.
    try:
        # Load the corresponding case brief and extract the facts of the case and the legal question
        system_prompt = get_system_prompt(transcript_id)
    except Exception as e:
        # Deliberately best-effort: a missing or malformed brief skips the case.
        print(f"Could not get facts and question from case brief: Skipping {transcript_id}")
        print(e)
        continue

    for section in [0, 1]:
        # A fresh list per section so the two entries never share messages.
        messages = [
            {
                "role": "system",
                "content": system_prompt
            }
        ]
        # Load the transcript and extract the messages
        messages.extend(get_transcript_data(json_file_name, section))
        data_transcripts.append({
            "transcript_id": transcript_id,
            "messages": messages
        })

Format transcripts into CSV form and save. Remember to apply the tokenizer and chat template to the context!

In [4]:
# Build one row per justice turn: the justice, their statement, and a snapshot
# of every message that preceded it in the same transcript section.
context_list = []
for transcript in data_transcripts:
    context = []
    for turn in transcript["messages"]:
        # The system prompt and advocate turns only extend the running
        # context; they never produce a row of their own.
        if turn["role"] in ("system", "petitioner", "respondent"):
            context.append(turn)
            continue

        new_df_row = {
            "transcript_id": transcript["transcript_id"],
            # Snapshot the context before this turn is appended to it.
            "context": copy.deepcopy(context),
            "justice": turn["role"],
            "text": turn["content"],
        }
        context_list.append(new_df_row)
        context.append(turn)

df = pd.DataFrame(context_list)
# Create the output directory up front so to_csv does not fail on a fresh checkout.
os.makedirs("generated_data", exist_ok=True)
df.to_csv("generated_data/context_based_statements_format.csv")

Format transcripts into jsonl form and save.

In [10]:
import json
import copy
import os

# Role identifiers of the justices; each JSONL record gets one empty
# prediction slot per entry.
justices = [
    "justice_amy_coney_barrett",
    "justice_brett_m_kavanaugh",
    "justice_clarence_thomas",
    "justice_elena_kagan",
    "justice_john_g_roberts_jr",
    "justice_ketanji_brown_jackson",
    "justice_neil_gorsuch",
    "justice_samuel_a_alito_jr",
    "justice_sonia_sotomayor"
]

jsonl_path = "generated_data/context_based_statements_format.jsonl"
# Create the output directory up front so opening the file does not fail on a fresh checkout.
os.makedirs(os.path.dirname(jsonl_path), exist_ok=True)
with open(jsonl_path, "w", encoding="utf-8") as outfile:
    for transcript in data_transcripts:
        context = []
        for turn in transcript["messages"]:
            # The system prompt and advocate turns only extend the running
            # context; they never produce a record of their own.
            if turn["role"] in ["system", "petitioner", "respondent"]:
                context.append(turn)
                continue

            json_obj = {
                "transcript_id": transcript["transcript_id"],
                # Snapshot the context before this turn is appended to it.
                "context": copy.deepcopy(context),
                "actual_justice": turn["role"],
                "text": turn["content"],
                # Empty placeholder per justice, written as {justice: {}}.
                "predictions": [{justice: {}} for justice in justices]
            }

            # One JSON object per line.
            outfile.write(json.dumps(json_obj) + "\n")
            context.append(turn)
