In [113]:
import os
from dotenv import load_dotenv
import openai
from pprint import pprint
from dataset_functions import load_claudette_tos_dataset, estimate_tokens
import pickle
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [12]:
# Load variables from a .env file (python-dotenv) so the API key is not hard-coded here.
load_dotenv()
# os.getenv returns None when OPENAI_API_KEY is unset; API calls would then fail later.
openai.api_key = os.getenv("OPENAI_API_KEY")
# Alias for the configured module; not referenced by any other visible cell.
client = openai

In [99]:
def classify_clauses_batch(clauses, system_prompt, batch_size=10, model="gpt-4o-mini", mode="test"):
    """Classify clauses in batches through the OpenAI chat-completions API.

    The clauses are sent ``batch_size`` at a time as a numbered list in the
    user message, together with ``system_prompt`` as the system message. Every
    non-blank line of the model's reply is treated as one prediction.

    Args:
        clauses: Sequence of clause strings to classify.
        system_prompt: Instructions sent as the system message of each request.
        batch_size: Number of clauses per request.
        model: Model name passed to the API and to ``estimate_tokens``.
        mode: Unused; kept so existing callers that pass it keep working.

    Returns:
        A flat list of stripped reply lines, in batch order. The function is
        best-effort: a batch whose request raises is reported and skipped, so
        the result can be shorter than ``clauses`` and must be length-checked
        before being zipped with labels.
    """
    all_predictions = []
    for start in range(0, len(clauses), batch_size):
        batch_number = start // batch_size + 1
        batch_clauses = clauses[start: start + batch_size]
        user_prompt = "Classify the following clauses:\n" + "\n".join(
            f"{j+1}. {clause}" for j, clause in enumerate(batch_clauses)
        )
        full_prompt = system_prompt + "\n" + user_prompt
        token_count = estimate_tokens(full_prompt, model=model)
        print(f"Estimated token count for batch {batch_number}: {token_count}")

        try:
            response = openai.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0,  # For consistent results
            )
            # Extract one prediction per non-blank line. Blank lines and
            # trailing whitespace in the reply are dropped so they are not
            # miscounted as classifications; a missing (None) content yields
            # no predictions instead of raising.
            batch_predictions = []
            for choice in response.choices:
                message_content = choice.message.content or ""
                batch_predictions.extend(
                    line.strip() for line in message_content.splitlines() if line.strip()
                )

            # Verify the number of classifications matches the batch size
            if len(batch_predictions) != len(batch_clauses):
                print(f"WARNING: Mismatch in batch size for Batch {batch_number}!")
                print(f"Expected {len(batch_clauses)} classifications, but got {len(batch_predictions)}.")

            # Mismatched batches are still appended (original behaviour).
            all_predictions.extend(batch_predictions)
            print(f"Batch {batch_number} processed successfully with {len(batch_predictions)} classifications.")
        except Exception as e:
            # Best-effort: report the failure and carry on with the next batch.
            print(f"Error processing batch {batch_number}: {e}")
            print(f"WARNING: No predictions recorded for batch {batch_number}; "
                  "the output will not align with the input clauses.")
    return all_predictions

In [14]:
def collect_labels(all_predictions):
    """Turn raw 'Classification: <n>' reply text into a flat list of int labels.

    Args:
        all_predictions: Iterable of strings. Each string may hold several
            newline-separated lines of the form ``Classification: <int>``.

    Returns:
        A list of the integer labels, in input order.

    Raises:
        ValueError: If a non-blank line lacks the ``Classification: `` prefix
            or the text after it is not an integer.
    """
    print(len(all_predictions))
    marker = 'Classification: '
    output_labels = []
    for message in all_predictions:
        for raw_line in message.split("\n"):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline in the reply) carry no label.
                continue
            if marker not in line:
                raise ValueError(f"Unrecognised prediction line: {line!r}")
            output_labels.append(int(line.split(marker)[1]))
    return output_labels

In [114]:
# system_prompt = """
# You are a legal expert in consumer rights and contract law. Your task is to classify clauses from Terms of Service (ToS) documents as either '0' (Fair) or '1' (Unfair).

# A clause is '1' (Unfair) if it:
# 1. Imposes unreasonable restrictions or obligations on users.
# 2. Significantly limits users' rights or remedies.
# 3. Creates an imbalance of power favoring the service provider.
# 4. Uses vague or ambiguous language that could be exploited.
# 5. Violates established legal principles or consumer protection laws.

# If none of these apply, classify the clause as '0' (Fair). For each clause, You do not have to number it. You can provide the output line by line in the
# following format so it is easy to parse:

# Classification: <0 or 1>
# """

# System prompt sent with every batch. It asks for exactly one
# 'Classification: 0' / 'Classification: 1' line per clause; the parsing code in
# this notebook splits on the literal 'Classification: ' prefix, so keep both in sync.
system_prompt = """
You are a legal expert in consumer rights and contract law. Your task is to classify clauses from Terms of Service (ToS) 
documents as either '0' (Fair) or '1' (Unfair).

A clause is '1' (Unfair) if it:
1. Imposes unreasonable restrictions or obligations on users.
2. Significantly limits users' rights or remedies.
3. Creates an imbalance of power favoring the service provider.
4. Uses vague or ambiguous language that could be exploited.
5. Violates established legal principles or consumer protection laws.

Classify ONLY the clauses provided in the user prompt. Provide ONLY the classification ('Classification: 0' or 'Classification: 1') 
for each clause, line by line, in the same order as the input. Do not add extra lines, blank lines, or explanations.
"""

# Load the dataset; the result is indexed by split name ('train', 'validation',
# 'test') and then by column ('text' for clauses, 'label' for ground truth).
final_splits = load_claudette_tos_dataset()

validation_clauses = final_splits['validation']['text']
validation_labels = final_splits['validation']['label']

# Labels must come from the same split as their clauses; reading them from
# 'validation' would pair train/test clauses with unrelated labels.
train_clauses = final_splits['train']['text']
train_labels = final_splits['train']['label']

test_clauses = final_splits['test']['text']
test_labels = final_splits['test']['label']

In [97]:
# Dry run over the validation clauses: build the same prompts the classifier
# would send and estimate their token count, without calling the API.
batch_size = 10
total = len(validation_clauses)
for offset in range(0, total, batch_size):
    batch_clauses = validation_clauses[offset: offset + batch_size]
    numbered = [f"{pos}. {clause}" for pos, clause in enumerate(batch_clauses, start=1)]
    user_prompt = "Classify the following clauses:\n" + "\n".join(numbered)
    full_prompt = f"{system_prompt}\n{user_prompt}"
    # The estimate is computed but not displayed by this cell.
    token_count = estimate_tokens(full_prompt, model="gpt-4o-mini")
    batch_number = offset // batch_size + 1
    print(f"Batch {batch_number} processed successfully.")

Batch 1 processed successfully.
Batch 2 processed successfully.
Batch 3 processed successfully.
Batch 4 processed successfully.
Batch 5 processed successfully.
Batch 6 processed successfully.
Batch 7 processed successfully.
Batch 8 processed successfully.
Batch 9 processed successfully.
Batch 10 processed successfully.
Batch 11 processed successfully.
Batch 12 processed successfully.
Batch 13 processed successfully.
Batch 14 processed successfully.
Batch 15 processed successfully.
Batch 16 processed successfully.
Batch 17 processed successfully.
Batch 18 processed successfully.
Batch 19 processed successfully.
Batch 20 processed successfully.
Batch 21 processed successfully.
Batch 22 processed successfully.
Batch 23 processed successfully.
Batch 24 processed successfully.
Batch 25 processed successfully.
Batch 26 processed successfully.
Batch 27 processed successfully.
Batch 28 processed successfully.
Batch 29 processed successfully.
Batch 30 processed successfully.
Batch 31 processed 

In [31]:
# Sanity check: number of clauses in the validation split.
len(validation_clauses)

941

In [100]:
# Classify every validation clause using the function's defaults (batch_size=10, model="gpt-4o-mini").
validation_predictions = classify_clauses_batch(validation_clauses, system_prompt)

Estimated token count for batch 1: 484
Batch 1 processed successfully with 10 classifications.
Estimated token count for batch 2: 546
Batch 2 processed successfully with 10 classifications.
Estimated token count for batch 3: 563
Batch 3 processed successfully with 10 classifications.
Estimated token count for batch 4: 500
Batch 4 processed successfully with 10 classifications.
Estimated token count for batch 5: 509
Batch 5 processed successfully with 10 classifications.
Estimated token count for batch 6: 570
Batch 6 processed successfully with 10 classifications.
Estimated token count for batch 7: 593
Batch 7 processed successfully with 10 classifications.
Estimated token count for batch 8: 636
Batch 8 processed successfully with 10 classifications.
Estimated token count for batch 9: 609
Batch 9 processed successfully with 10 classifications.
Estimated token count for batch 10: 678
Batch 10 processed successfully with 10 classifications.
Estimated token count for batch 11: 616
Batch 11

In [88]:
# Persist test_validation_predictions to disk.
# NOTE(review): test_validation_predictions is not assigned in any visible cell — confirm where it comes from.
with open("test_validation_predictions.pkl", "wb") as f:
    pickle.dump(test_validation_predictions, f)

In [56]:
import pickle
# Load the saved predictions from the pickle file; this rebinds
# validation_predictions, replacing any value produced by an earlier cell.
# NOTE: pickle.load can execute arbitrary code — only load files this notebook wrote.
with open("validation_prediction.pkl", "rb") as f:
    validation_predictions = pickle.load(f)

In [111]:
# Parse each raw 'Classification: <n>' entry into its integer label; the two
# printed counts should match if every entry was parsed.
print(len(validation_predictions))
output_labels = [
    int(raw_entry.strip().split('Classification: ')[1])
    for raw_entry in validation_predictions
]
print(len(output_labels))

941
941


In [112]:
# Peek at the first ten parsed labels.
output_labels[:10]

[0, 1, 1, 1, 0, 0, 0, 0, 1, 1]

In [117]:
def evaluate_predictions(ground_truth, predictions):
    """Score predicted labels against the ground truth.

    Args:
        ground_truth: A list of ground truth labels (0 for Fair, 1 for Unfair).
        predictions: A list of predicted labels (0 for Fair, 1 for Unfair).

    Returns:
        A dictionary with the keys 'Accuracy', 'Precision', 'Recall' and
        'F1-Score', each mapped to the value returned by the corresponding
        scikit-learn scorer called with its default options.
    """
    # (result key, scorer) pairs, evaluated in this order.
    scorers = (
        ('Accuracy', accuracy_score),
        ('Precision', precision_score),
        ('Recall', recall_score),
        ('F1-Score', f1_score),
    )
    return {label: scorer(ground_truth, predictions) for label, scorer in scorers}

# metrics = evaluate_predictions(ground_truth, predictions)
# print(metrics)

In [115]:
# Bind the parsed integer labels to the name used for evaluation (same list object, not a copy).
validation_predicted_labels = output_labels

In [118]:
# Score the parsed predictions against the validation ground-truth labels.
evaluate_predictions(validation_labels, validation_predicted_labels)

{'Accuracy': 0.6588735387885228,
 'Precision': 0.20054945054945056,
 'Recall': 0.7087378640776699,
 'F1-Score': 0.31263383297644537}

In [28]:
import pickle

# Save the ground-truth validation labels to a pickle file
with open("validation_labels.pkl", "wb") as f:
    pickle.dump(validation_labels, f)

In [41]:
# Save the raw validation predictions; other cells reload this same file name.
with open("validation_prediction.pkl", "wb") as f:
    pickle.dump(validation_predictions, f)

In [29]:
# Load the ground-truth validation labels back from the pickle file.
# NOTE: pickle.load can execute arbitrary code — only load files this notebook wrote.
with open("validation_labels.pkl", "rb") as f:
    validation_loaded = pickle.load(f)

# This file holds the labels, not the predictions, so the message says so.
print("Validation labels loaded successfully.")
print(validation_loaded)

Validation predictions loaded successfully.
[0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [42]:
# Load the saved raw predictions under a separate name so the in-memory
# validation_predictions variable is left untouched.
# NOTE: pickle.load can execute arbitrary code — only load files this notebook wrote.
with open("validation_prediction.pkl", "rb") as f:
    validation_prediction_test = pickle.load(f)

print("Validation predictions loaded successfully.")
print(validation_prediction_test)

Validation predictions loaded successfully.
['Classification: 0  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 0  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassification: 1  \nClassific

In [67]:
# def compare_validation_predictions(test_pkl_path, full_pkl_path):
def compare_validation_predictions(test_validation_predictions, validation_predictions):
    """
    Compare the structure and content of two in-memory lists of validation predictions.

    Every element of either list is expected to be a string; the newline-separated
    lines inside one element are counted as individual classifications.

    :param test_validation_predictions: Predictions reported under the "Test" labels.
    :param validation_predictions: Predictions reported under the "Full" labels.
    :return: None (prints differences in structure and content).
    """
    # Bind each argument to the label it is reported under. These were
    # previously swapped, so every "Test"/"Full" line described the wrong list.
    test_predictions = test_validation_predictions
    full_predictions = validation_predictions

    # Print basic information about the two inputs
    print("=== Basic Information ===")
    print(f"Test Predictions: {len(test_predictions)} batches")
    print(f"Full Predictions: {len(full_predictions)} batches")
    print()

    # Compare the structure of the two inputs
    print("=== Structure Comparison ===")
    if isinstance(test_predictions, list) and isinstance(full_predictions, list):
        print("Both files are lists.")
    else:
        print("Mismatch in data types: Test is a", type(test_predictions), "and Full is a", type(full_predictions))
        return

    # Split every batch into lines once and reuse the result below.
    test_split = [batch.split("\n") for batch in test_predictions]
    full_split = [batch.split("\n") for batch in full_predictions]

    # Compare the lengths of the batches
    test_batch_sizes = [len(lines) for lines in test_split]
    full_batch_sizes = [len(lines) for lines in full_split]
    print(f"Test Batch Sizes: {test_batch_sizes}")
    print(f"Full Batch Sizes: {full_batch_sizes}")
    print()

    # Compare the content of the first few batches (at most three pairs)
    max_batches_shown = 3
    print("=== Content Comparison ===")
    for i, (test_lines, full_lines) in enumerate(zip(test_split, full_split)):
        if i >= max_batches_shown:
            break
        print(f"Batch {i + 1}:")

        # Compare the number of lines in each batch
        print(f"  Test Batch Lines: {len(test_lines)}")
        print(f"  Full Batch Lines: {len(full_lines)}")

        # Compare the first few lines of each batch
        print("  Test Batch Sample:")
        print(test_lines[:5])
        print("  Full Batch Sample:")
        print(full_lines[:5])

    # Check for any discrepancies in the total number of classifications
    print("\n=== Total Classifications ===")
    test_total = sum(test_batch_sizes)
    full_total = sum(full_batch_sizes)
    print(f"Total Classifications in Test Predictions: {test_total}")
    print(f"Total Classifications in Full Predictions: {full_total}")

    if test_total != full_total:
        print("WARNING: The total number of classifications differs between the two files!")
    else:
        print("The total number of classifications matches between the two files.")

In [101]:
# Compare the two in-memory prediction lists.
# NOTE(review): test_validation_predictions is not assigned in any visible cell — confirm it exists on a fresh kernel.
compare_validation_predictions(test_validation_predictions, validation_predictions)

TypeError: unhashable type: 'list'

In [68]:
# NOTE(review): this call passes pickle file names, but the function as currently defined
# takes already-loaded prediction lists (its path-based signature is commented out) — confirm which is intended.
compare_validation_predictions('test_validation_predictions.pkl', 'validation_prediction.pkl')

=== Basic Information ===
Test Predictions: 4 batches
Full Predictions: 7 batches

=== Structure Comparison ===
Both files are lists.
Test Batch Sizes: [10, 10, 10, 10]
Full Batch Sizes: [3277, 3277, 3277, 3277, 3277, 3277, 43]

=== Content Comparison ===
Batch 1:
  Test Batch Lines: 10
  Full Batch Lines: 3277
  Test Batch Sample:
['Classification: 0  ', 'Classification: 1  ', 'Classification: 1  ', 'Classification: 1  ', 'Classification: 0  ']
  Full Batch Sample:
['Classification: 0  ', 'Classification: 1  ', 'Classification: 1  ', 'Classification: 1  ', 'Classification: 0  ']
Batch 2:
  Test Batch Lines: 10
  Full Batch Lines: 3277
  Test Batch Sample:
['Classification: 1  ', 'Classification: 0  ', 'Classification: 1  ', 'Classification: 1  ', 'Classification: 1  ']
  Full Batch Sample:
['Classification: 0  ', 'Classification: 0  ', 'Classification: 0  ', 'Classification: 0  ', 'Classification: 0  ']
Batch 3:
  Test Batch Lines: 10
  Full Batch Lines: 3277
  Test Batch Sample:
['Cl