In [None]:
%%capture
!pip install openai

In [None]:
import numpy as np
import openai
import os
import pandas as pd
from   pathlib import Path
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
import time

In [None]:
from google.colab import drive
drive.mount('/content/drive')

from google.colab import auth
import gspread
from google.auth import default

auth.authenticate_user()
creds, _ = default()
gc = gspread.authorize(creds)

Mounted at /content/drive


In [None]:
interp_dir = '/content/drive/MyDrive/legal_interpretation/code/generative_testing'
train_test_dir = os.path.join(interp_dir, 'train_test_splits', 'train_test_splits_2')
output_path = os.path.join(interp_dir, 'interpretation_results')
generated_output_path = os.path.join(output_path, 'generations')
descriptive_errors_dir = os.path.join(output_path, 'errors')

In [None]:
worksheet = gc.open('final_cleaned_paragraphs').sheet1
rows = worksheet.get_all_values()
interpretation_df = pd.DataFrame(rows)

In [None]:
interpretation_df.columns = interpretation_df.iloc[0]
interpretation_df = interpretation_df.iloc[1:]

In [None]:
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)  # for exponential backoff


device_name = 'cuda'
max_length = 512

In [None]:
openai.api_key = open(os.path.join(interp_dir, 'private', 'openai_key.txt')).read().strip()

In [None]:
interpretation_prompt = "Some paragraphs in court cases interpret statutes. In this type of paragraph, there is an analysis of a statute and a claim made about its meaning. \n\nIn the following paragraph, determine if legal interpretation occurs. If yes, respond with \”interpretation\” and if not, respond with \”no interpretation\”"

In [None]:
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def get_binary_interp(text, interpretation_prompt):
  response = openai.ChatCompletion.create(
      model='gpt-4',
      max_tokens=5,
      messages = [{'role': 'user', 'content': interpretation_prompt + text}]
  )
  return response['choices'][0]['message']['content'].strip().lower()


In [None]:
interpretation_df = interpretation_df[interpretation_df['class'].notna()]
interpretation_df["interpretation"] = np.where(interpretation_df["class"].isin(["FORMAL", "GRAND"]), "interpretation", "no interpretation")

In [None]:
example_text = "Nevertheless, respondent urges that the legislative purpose of the statute is best served by construing it to permit some choice in determining the length of the penalty period. In respondent's view, the purpose of the statute is essentially remedial and compensatory, and thus it should not be interpreted literally to produce a monetary award that is so far in excess of any equitable remedy as to be punitive."

In [None]:
get_binary_interp(example_text, interpretation_prompt)

'interpretation'

In [None]:
example_text = interpretation_df["paragraph"].tolist()[2]

In [None]:
from collections import Counter
Counter(interpretation_df["interpretation"])

Counter({'interpretation': 880, 'no interpretation': 1868})

In [None]:
macro_f1_l = []
macro_precision_l = []
macro_recall_l = []

weighted_f1_l = []
weighted_precision_l = []
weighted_recall_l = []

one_f1_l = []
one_precision_l = []
one_recall_l = []

zero_f1_l = []
zero_precision_l = []
zero_recall_l = []


In [None]:
full_df = pd.DataFrame()

In [None]:
for split in range(0, 5):
  start_time = time.time()

  split_id_file = os.path.join(train_test_dir, f'split_{split}')

  with open(split_id_file, 'r') as file:
      train_ids = file.read().split("\n")

  interpretation_train_df = interpretation_df[interpretation_df["section_id"].isin(train_ids)]
  interpretation_test_df = interpretation_df[~interpretation_df["section_id"].isin(train_ids)]


  X_test = interpretation_test_df["paragraph"].to_list()
  y_test = interpretation_test_df["interpretation"].to_list()

  total = len(X_test)

  predicted_labels = []
  for i, text in enumerate(X_test):
    prediction = get_binary_interp(text, interpretation_prompt)
    predicted_labels.append(prediction)

    if i % 50 == 0:
      precent = round((i/total)*100, 2)
      print(f"{precent} percent through processing.")

  with open(os.path.join(generated_output_path, f'predictions_{split}.txt'), 'w') as file:
     for label in predicted_labels:
        file.write(f"{label}\n")

  predictions_df = pd.DataFrame(
  {'section_id': interpretation_test_df["section_id"].tolist(),
    'gold': y_test,
    'predicted': predicted_labels,
    'text': X_test
  })
  errors_df = predictions_df.query('gold != predicted')
  errors_df.to_csv(os.path.join(descriptive_errors_dir, f"{split}_errors.csv"))

In [None]:
for split in range(1, 5):

  split_id_file = os.path.join(train_test_dir, f'split_{split}')

  with open(split_id_file, 'r') as file:
      train_ids = file.read().split("\n")

  interpretation_train_df = interpretation_df[interpretation_df["section_id"].isin(train_ids)]
  interpretation_test_df = interpretation_df[~interpretation_df["section_id"].isin(train_ids)]


  X_test = interpretation_test_df["paragraph"].to_list()
  y_test = interpretation_test_df["interpretation"].to_list()

  with open(os.path.join(generated_output_path, f'predictions_{split}.txt'), 'r') as file:
    print(file)
    predicted_labels = [line.rstrip() for line in file]

  print(y_test, predicted_labels)
  class_report = classification_report(y_test, predicted_labels, output_dict=True)

  sample_dict = {
      "model": "interpretation_generative",
      "split": split,

      "macro_f1": round(class_report["macro avg"]["f1-score"], 3),
      "macro_precision": round(class_report["macro avg"]["precision"], 3),
      "macro_recall": round(class_report["macro avg"]["recall"], 3),

      "weighted_f1": round(class_report["weighted avg"]["f1-score"], 3),
      "weighted_precision": round(class_report["weighted avg"]["precision"], 3),
      "weighted_recall": round(class_report["weighted avg"]["recall"], 3),

      "1_f1": round(class_report["interpretation"]["f1-score"], 3),
      "1_precision": round(class_report["interpretation"]["precision"], 3),
      "1_recall": round(class_report["interpretation"]["recall"], 3),

      "0_f1": round(class_report["no interpretation"]["f1-score"], 3),
      "0_precision": round(class_report["no interpretation"]["precision"], 3),
      "0_recall": round(class_report["no interpretation"]["recall"], 3),

  }

  new_row = pd.DataFrame(sample_dict, index = [0])
  full_df = pd.concat([full_df, new_row])

  macro_f1_l.append(class_report["macro avg"]["f1-score"])
  macro_precision_l.append(class_report["macro avg"]["precision"])
  macro_recall_l.append(class_report["macro avg"]["recall"])

  weighted_f1_l.append(class_report["weighted avg"]["f1-score"])
  weighted_precision_l.append(class_report["weighted avg"]["precision"])
  weighted_recall_l.append(class_report["weighted avg"]["recall"])

  one_f1_l.append(class_report["interpretation"]["f1-score"])
  one_precision_l.append(class_report["interpretation"]["precision"])
  one_recall_l.append(class_report["interpretation"]["recall"])

  zero_f1_l.append(class_report["no interpretation"]["f1-score"])
  zero_precision_l.append(class_report["no interpretation"]["precision"])
  zero_recall_l.append(class_report["no interpretation"]["recall"])

  # end_time = time.time()

  # total_minutes = round((end_time - start_time) / 60, 2)
  # print(f"Total time: {total_minutes} minutes.")

<_io.TextIOWrapper name='/content/drive/MyDrive/legal_interpretation/code/generative_testing/interpretation_results/generations/predictions_1.txt' mode='r' encoding='UTF-8'>
['interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'interpretation', 'interpretation', 'interpretation', 'no interpretation', 'interpretation', 'interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'interpretation', 'no inter

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


<_io.TextIOWrapper name='/content/drive/MyDrive/legal_interpretation/code/generative_testing/interpretation_results/generations/predictions_2.txt' mode='r' encoding='UTF-8'>
['interpretation', 'no interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'no interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'interpretation', 'no interpretation', 'interpretation', 'interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'interpretation', 'interpretation', 'interpretation', 'inter

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


<_io.TextIOWrapper name='/content/drive/MyDrive/legal_interpretation/code/generative_testing/interpretation_results/generations/predictions_3.txt' mode='r' encoding='UTF-8'>
['interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'no interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'no interpretation', 'interpretat

In [None]:
macro_f1 = sum(macro_f1_l) / len(macro_f1_l)
macro_precision = sum(macro_precision_l) / len(macro_precision_l)
macro_recall = sum(macro_recall_l) / len(macro_recall_l)

weighted_f1 = sum(weighted_f1_l) / len(weighted_f1_l)
weighted_precision = sum(weighted_precision_l) / len(weighted_precision_l)
weighted_recall = sum(weighted_recall_l) / len(weighted_recall_l)

one_f1 = sum(one_f1_l) / len(one_f1_l)
one_precision = sum(one_precision_l) / len(one_precision_l)
one_recall = sum(one_recall_l) / len(one_recall_l)

zero_f1 = sum(zero_f1_l) / len(zero_f1_l)
zero_precision = sum(zero_precision_l) / len(zero_precision_l)
zero_recall = sum(zero_recall_l) / len(zero_recall_l)

In [None]:
model_dict = {
    "model": "descriptive_generative",
    "split": "averages",

    "macro_f1": round(macro_f1, 3),
    "macro_precision": round(macro_precision, 3),
    "macro_recall": round(macro_recall, 3),

    "weighted_f1": round(weighted_f1, 3),
    "weighted_precision": round(weighted_precision, 3),
    "weighted_recall": round(weighted_recall, 3),

    "1_f1": round(one_f1, 3),
    "1_precision": round(one_precision, 3),
    "1_recall": round(one_recall, 3),

    "0_f1": round(zero_f1, 3),
    "0_precision": round(zero_precision, 3),
    "0_recall": round(zero_recall, 3),

}

new_row = pd.DataFrame(model_dict, index = [0])
full_df = pd.concat([full_df, new_row])

# full_df.to_csv(os.path.join(output_path, 'gpt_interpretation_results.csv'))



In [None]:
full_df

Unnamed: 0,model,split,macro_f1,macro_precision,macro_recall,weighted_f1,weighted_precision,weighted_recall,1_f1,1_precision,1_recall,0_f1,0_precision,0_recall
0,interpretation_generative,1,0.386,0.44,0.438,0.577,0.737,0.578,0.583,0.433,0.89,0.574,0.886,0.424
0,interpretation_generative,1,0.386,0.44,0.438,0.577,0.737,0.578,0.583,0.433,0.89,0.574,0.886,0.424
0,interpretation_generative,2,0.283,0.329,0.329,0.567,0.754,0.565,0.562,0.408,0.901,0.569,0.908,0.415
0,interpretation_generative,3,0.588,0.667,0.668,0.588,0.748,0.588,0.587,0.436,0.897,0.589,0.898,0.438
0,interpretation_generative,4,0.575,0.656,0.65,0.571,0.726,0.575,0.586,0.439,0.885,0.563,0.874,0.415
0,descriptive_generative,averages,0.443,0.506,0.504,0.576,0.74,0.577,0.58,0.43,0.893,0.574,0.891,0.423


In [None]:
averages_df = full_df[full_df["split"] == "averages"]

In [None]:
averages_df = averages_df.drop(columns = ["weighted_f1", "weighted_precision", "weighted_recall"])

In [None]:
print(averages_df.to_latex(
                  formatters={"name": str.upper},
                  float_format="{:.2f}".format,
                  index = False
                  ))

\begin{tabular}{llrrrrrrrrr}
\toprule
                 model &    split &  macro\_f1 &  macro\_precision &  macro\_recall &  1\_f1 &  1\_precision &  1\_recall &  0\_f1 &  0\_precision &  0\_recall \\
\midrule
descriptive\_generative & averages &      0.44 &             0.51 &          0.50 &  0.58 &         0.43 &      0.89 &  0.57 &         0.89 &      0.42 \\
\bottomrule
\end{tabular}



  print(averages_df.to_latex(
