## Convert OS based questions to DIALOGUE STYLE questions

In [None]:
import re
import pandas as pd
import json
import os
from sklearn.model_selection import train_test_split

In [None]:
# Directory the input CSVs in this notebook are read from.
IN_DIR = "../../datasets/original"
# Output directory; not referenced in the visible part of this notebook.
OUT_DIR = "../finetuning_datasets/eval_only"

# Directory scanned for oral-argument transcript JSON files.
TRANSCRIPTS_DIR = "../../2024_cases_json/"
CASEBRIEF_DIR = "../../2023-2024_case_briefs/"      # directory of raw JSONs of case briefs

def save_jsonl(df, filename):
    """Write *df* to *filename* as JSON Lines, one record (row) per line."""
    df.to_json(path_or_buf=filename, orient="records", lines=True)

def read_jsonl(filename):
    """Read a JSON Lines file and return its records as a list.

    Blank or whitespace-only lines are skipped rather than passed to the JSON
    parser, so a file with empty separator lines or trailing blank lines
    still loads. The file is decoded as UTF-8.

    Args:
        filename: Path of the JSON Lines file to read.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

## Generate jsonl for inference on finetuned model

In [None]:
# Load the 2024 questions CSV and collect the distinct values of its 'justice' column.
input_fp = f'{IN_DIR}/2024_all_questions.csv'
questions_df = pd.read_csv(input_fp)
justices = list(questions_df['justice'].unique())

In [None]:
# Fixed list of justice identifiers. Running this cell rebinds `justices`,
# replacing whatever list was assigned to that name earlier.
justices = [
    "justice_amy_coney_barrett",
    "justice_brett_m_kavanaugh",
    "justice_clarence_thomas",
    "justice_elena_kagan",
    "justice_john_g_roberts_jr",
    "justice_ketanji_brown_jackson",
    "justice_neil_gorsuch",
    "justice_samuel_a_alito_jr",
    "justice_sonia_sotomayor"
 ]

In [None]:
# Load the full-text transcripts CSV into `df` (this rebinds `input_fp`).
input_fp = f'{IN_DIR}/2024_full_text_transcripts.csv'
df = pd.read_csv(input_fp)


def extract_speaker_and_text(input_string):
    """Return the contents of the first ``<text>...</text>`` element.

    Despite the name, only the text is returned; the ``<speaker>`` element is
    not part of the result, so its absence is not treated as an error.

    Args:
        input_string: Markup expected to contain a ``<text>...</text>``
            element. Non-string values (for example NaN from an empty CSV
            cell when used with ``Series.apply``) are tolerated.

    Returns:
        The raw inner text of the first ``<text>`` element (may span several
        lines), or ``None`` when *input_string* is not a string or contains
        no such element.
    """
    if not isinstance(input_string, str):
        return None

    text_match = re.search(r"<text>(.*?)</text>", input_string, re.DOTALL)
    if text_match is None:
        return None

    return text_match.group(1)

# Derive one plain-text column per side by running the extractor over each
# side's opening-statement column.
df['petitioner_turn'] = df['petitioner_opening_text'].apply(extract_speaker_and_text)
df['respondent_turn'] = df['respondent_opening_statement'].apply(extract_speaker_and_text)

In [None]:
# Preview the first rows of the transcripts table, including the derived columns.
df.head()

In [7]:
import re
def clean_text(text):
    """Strip markup tags and normalise curly quotes in *text*.

    Falsy input (``None`` or an empty string) yields the placeholder
    ``'UNKNOWN'``. Otherwise every ``<...>`` tag is removed, curly double and
    single quotes are replaced by their straight equivalents, and surrounding
    whitespace is trimmed.
    """
    if not text:
        return 'UNKNOWN'

    # Remove HTML tags
    cleaned = re.sub(r'<[^>]+>', '', text)

    # Map curly quotes onto their plain ASCII counterparts
    quote_pairs = (
        ("\u201c", "\""),
        ("\u201d", "\""),
        ("\u2018", "'"),
        ("\u2019", "'"),
    )
    for curly, straight in quote_pairs:
        cleaned = cleaned.replace(curly, straight)

    return cleaned.strip()

def get_facts_and_question(transcript_id, dir=CASEBRIEF_DIR):
    """Load a case brief and return its cleaned facts and legal question.

    The brief is read from ``<dir>/<transcript_id>.json``; the values under
    the ``facts_of_the_case`` and ``question`` keys are each passed through
    ``clean_text``.

    @param transcript_id -- File stem of the case brief JSON
    @param dir -- Directory containing the case brief JSON files
    @return -- Tuple ``(facts, question)`` of cleaned strings
    """
    brief_path = os.path.join(dir, transcript_id + ".json")
    with open(brief_path, 'r') as brief_file:
        brief = json.load(brief_file)

    cleaned_facts = clean_text(brief['facts_of_the_case'])
    cleaned_question = clean_text(brief['question'])
    return cleaned_facts, cleaned_question

def get_system_prompt(transcript_id):
    """Build the system prompt for a case from its brief.

    The prompt is a fixed instruction sentence followed by the case facts and
    the legal question, each under its own upper-case heading.

    @param transcript_id -- File stem of the case brief JSON
    @return -- The system prompt string
    """
    case_facts, legal_question = get_facts_and_question(transcript_id)
    return (
        "You are a legal expert trained to simulate Supreme Court oral arguments."
        f"\n\nFACTS_OF_THE_CASE:\n{case_facts}"
        f"\n\nLEGAL_QUESTION:\n{legal_question}"
    )


In [8]:

def get_formatted_text_of_turn(turn, advocate):
    '''
    Return all text within a turn as a dict denoting speaker role, and text.

    A speaker counts as a justice when the inspected role entry has type
    "scotus_justice". The entry inspected is the first element when "roles" is
    a list, or the value under key '2' when "roles" is a dict. Every other
    speaker (no roles, or a non-justice role) is labelled with `advocate`.

    @param turn -- JSON representing a single speaker turn
    @param advocate -- Label to use as the role of any non-justice speaker
    @return -- Dict with keys "role", "content", or None if the turn has no speaker
    '''
    speaker = turn["speaker"]
    if not speaker:  # skip turns that have no speaker like "Laughter"
        return None

    roles = speaker["roles"]
    if isinstance(roles, dict):
        role_entry = roles.get('2')
    elif roles:
        role_entry = roles[0]
    else:
        role_entry = None

    # Anything that is not positively identified as a justice falls back to the
    # advocate label, so the role is always defined.
    is_justice = isinstance(role_entry, dict) and role_entry.get("type") == "scotus_justice"

    if is_justice:
        identifier = f'justice_{speaker["identifier"]}'
    else:
        identifier = advocate

    text = " ".join(block["text"] for block in turn["text_blocks"])

    return {
        "role": identifier,
        "content": text
    }

def get_transcript_data(json_file_name):
    '''
    Read one oral-argument JSON file and flatten its first two sections into a
    list of formatted turns. Turns in section 0 are attributed to the
    'petitioner' for non-justice speakers, turns in section 1 to the
    'respondent'. Turns without a speaker are dropped.

    @param json_file_name -- Name of the oral argument JSON file
    @return -- List of dicts with keys "role", "content" representing each speaker turn in the transcript
    '''
    with open(os.path.join(TRANSCRIPTS_DIR, json_file_name), 'r') as transcript_file:
        transcript = json.load(transcript_file)

    all_turns = []
    for section_index, advocate in enumerate(('petitioner', 'respondent')):
        raw_turns = transcript["transcript"]["sections"][section_index]["turns"]
        for raw_turn in raw_turns:
            formatted = get_formatted_text_of_turn(raw_turn, advocate)
            if formatted:  # skip speaker-less turns
                all_turns.append(formatted)

    return all_turns

# Load all transcripts
data_transcripts = []
# NOTE(review): cases_dir is not referenced in the visible part of the notebook;
# kept in case a later cell relies on it.
cases_dir = os.fsencode(TRANSCRIPTS_DIR)
# os.listdir returns entries in arbitrary order; sort them so that positions in
# `data_transcripts` / `dialogues` are the same on every run and machine.
for json_file_name in sorted(os.listdir(TRANSCRIPTS_DIR)):
    if not json_file_name.endswith('.json'):
        continue

    # Extract the transcript_id by dropping the last 9 characters of the file
    # name (presumably a fixed suffix plus ".json" -- TODO confirm).
    transcript_id = json_file_name[:-9].strip()

    try:
        # Load the corresponding case brief and extract the facts of the case and the legal question
        system_prompt = get_system_prompt(transcript_id)
    except Exception as err:
        # Best effort: a transcript without a usable case brief is skipped, but
        # the reason is reported so real errors are not hidden.
        print(f"Could not get facts and question from case brief: Skipping {transcript_id} ({err!r})")
        continue

    messages = [
        {
            "role": "system",
            "content": system_prompt
        }
    ]
    # Load the transcript and extract the messages
    messages.extend(get_transcript_data(json_file_name))
    data_transcripts.append({
        "transcript_id": transcript_id,
        "messages": messages
    })

# One message list per successfully loaded transcript.
dialogues = [transcript["messages"] for transcript in data_transcripts]

In [9]:
# Keep only the first three messages of each dialogue: the system prompt plus
# the first two transcript turns (presumably the petitioner's opening exchange -- verify).
petitioner_chat = [d[:3] for d in dialogues]

In [10]:
# Number of truncated dialogues available for inference.
len(petitioner_chat)

13

In [11]:
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM
)
from peft import PeftModel

# base_model_id = "path/to/original_base_model" 
# lora_model_id = "path/to/finetuned_adapter_folder"
# Local path of the base model and of the fine-tuned adapter checkpoint.
base_model_dir = "/scratch/gpfs/nnadeem/transformer_cache/Meta-Llama-3.1-8B-Instruct-bnb-4bit/"
adapter_dir = "../models/finetuned_Meta-Llama-3.1-8B-Instruct-bnb-4bit_dialogue_style/checkpoint-242"

# 1) Load tokenizer from the *base* model
tokenizer = AutoTokenizer.from_pretrained(base_model_dir, use_fast=False)

# 2) Load the base model
# NOTE(review): the recorded output of this cell warns that `load_in_4bit` is
# deprecated in favour of passing a `BitsAndBytesConfig` via
# `quantization_config`; migrate once it can be re-run against this checkpoint.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_dir,
    # If you are using 4-bit or 8-bit, set it up here
    load_in_4bit=True,  
    device_map="auto",
)

# 3) Load the LoRA adapter on top of the base model
model = PeftModel.from_pretrained(base_model, adapter_dir)

  from .autonotebook import tqdm as notebook_tqdm
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


In [12]:
def set_chat_template():
    """Return the Jinja chat template string used to render a dialogue.

    The template emits ``<|begin_of_text|>`` once, then for every message a
    role header, a blank line, the message content and ``<|eot_id|>``.
    Note that this function only returns the template; the caller assigns it.
    """
    template = (
        "<|begin_of_text|>"
        "{%- for message in messages %}"
        "<|start_header_id|>{{ message['role'] }}<|end_header_id|>\n\n"
        "{{ message['content'] }}<|eot_id|>"
        "{%- endfor %}"
    )
    return template
# Install the custom template on the tokenizer; apply_chat_template is called on it below.
tokenizer.chat_template = set_chat_template()

**Sanity Test:**

In [13]:
# Render every truncated dialogue to a single prompt string using the chat template.
formatted_chats = [tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False) for sample in petitioner_chat]
prompt = formatted_chats[0]
justice = "justice_sonia_sotomayor"
# Append an open turn header for the chosen justice so generation continues as that speaker.
prompt += f"<|start_header_id|>{justice}<|end_header_id|>\n\n"
# NOTE(review): the recorded output starts with two <|begin_of_text|> tokens. The
# template already emits one, so the tokenizer presumably prepends a second --
# confirm whether add_special_tokens=False is wanted here.
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs)
# Decode the whole sequence returned by generate (special tokens are kept).
print(tokenizer.decode(outputs[0]))

<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a legal expert trained to simulate Supreme Court oral arguments.

FACTS_OF_THE_CASE:
ATF, created in 1972, is responsible for regulating firearms under the Gun Control Act of 1968 (GCA). The GCA requires federal firearms licensees (FFLs) to conduct background checks, record firearm transfers, and serialize firearms when selling or transferring them. The GCA's regulation of firearms is based on the definition of "firearm," which includes the "frame or receiver." However, ATF's 1978 definition of "frame or receiver" became outdated due to changes in modern firearm design, such as the AR-15 and Glock pistols. Furthermore, the rise of privately made firearms (PMFs) or "ghost guns" posed challenges to law enforcement because they were not regulated under the GCA and did not require serialization. In response, ATF issued a Final Rule in 2022, updating the definitions of "frame," "receiver," and "firearm" to

**Test regex for extracting question from generation:**

In [14]:
import re

def extract_justice_text(transcript: str, justice_identifier: str) -> "str | None":
    """
    Extracts the text of the LAST turn opened by
    <|start_header_id|>justice_identifier<|end_header_id|>, up to the next
    <|eot_id|>.

    The last occurrence is used because the prompt itself may already contain
    an earlier turn by the same justice; the generated turn is the one that
    follows the final header. The identifier is matched literally (no regex
    interpretation).

    Returns the matched text (with surrounding whitespace stripped),
    or None if the header is absent or no <|eot_id|> follows it.
    """
    header = f"<|start_header_id|>{justice_identifier}<|end_header_id|>"

    header_start = transcript.rfind(header)
    if header_start == -1:
        return None

    body_start = header_start + len(header)
    body_end = transcript.find("<|eot_id|>", body_start)
    if body_end == -1:
        # Turn was never terminated (e.g. generation stopped early).
        return None

    return transcript[body_start:body_end].strip()

# Try the extractor on the generation produced by the sanity test.
full_response = tokenizer.decode(outputs[0])
j = "justice_sonia_sotomayor"
extract_justice_text(full_response, j)


'Mr. Prelogar, can you tell me what the difference is between a frame and a receiver?'

**Try full flow**:

In [None]:
# Render every truncated dialogue to a single prompt string using the chat template.
formatted_chats = [tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False) for sample in petitioner_chat]
results = []
for i, chat in enumerate(formatted_chats):
    print(f"Processing transcript {i}...\n\n")
    result = {
        "prompt": chat,
        "responses": []
    }
    for j in justices:
        print(f"Processing justice {j}...")
        # Build the prompt for this justice from the unmodified chat. Appending
        # to `chat` in place would carry the (empty) headers of all previously
        # processed justices into every later prompt.
        justice_prompt = chat + f"<|start_header_id|>{j}<|end_header_id|>\n\n"
        # NOTE(review): no generation length is set here, so the model's default
        # generation settings apply -- confirm they allow a full question.
        inputs = tokenizer(justice_prompt, return_tensors="pt").to("cuda")
        outputs = model.generate(**inputs)
        full_response = tokenizer.decode(outputs[0])
        result["responses"].append({
            "justice": j,
            "full_response": full_response,
            "parsed_response": extract_justice_text(full_response, j)
        })
    results.append(result)

Processing transcript 0...


Processing justice justice_amy_coney_barrett...
Processing justice justice_brett_m_kavanaugh...
Processing justice justice_clarence_thomas...


In [None]:
# Inspect the third recorded response (index 2) for the first transcript.
results[0]["responses"][2]