In [1]:
import os
import pickle
import re
from pprint import pprint

import openai
import tiktoken
from datasets import load_dataset, DatasetDict
from dotenv import load_dotenv
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [2]:
# Populate os.environ from a local .env file, then give the key to the
# module-level OpenAI client (the value is None when the variable is unset).
load_dotenv()
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Convenience alias; the helper functions below call `openai` directly.
client = openai

In [3]:
def estimate_tokens(prompt, model="gpt-4o-mini"):
    """
    Estimate the number of tokens in a given prompt for a specific model.

    If tiktoken has no mapping for ``model`` (``encoding_for_model`` raises
    KeyError for unknown names), the ``o200k_base`` encoding is used instead, so a
    new model name degrades the estimate rather than aborting the caller.
    Special-token strings that may appear in the prompt text (e.g. "<|endoftext|>")
    are counted as ordinary text instead of raising.

    :param prompt: The text prompt to estimate tokens for.
    :param model: The model name (e.g., "gpt-4o", "gpt-4o-mini").
    :return: The estimated token count.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # o200k_base is the encoding tiktoken maps the gpt-4o family to; a sane default.
        encoding = tiktoken.get_encoding("o200k_base")
    # disallowed_special=() -> treat special-token look-alikes in clause text as plain text.
    return len(encoding.encode(prompt, disallowed_special=()))

In [4]:
def write_file(filename, data):
    """
    Pickle ``data`` to ``filename``, overwriting any existing file.

    :param filename: Destination path.
    :param data: Any picklable object (here: metric dicts).
    """
    # The with-block closes the file; no explicit close() is needed.
    with open(filename, "wb") as f:
        pickle.dump(data, f)
    
def read_file(filename):
    """
    Return the object pickled in ``filename``.

    Only use this on files this notebook wrote itself: unpickling can execute
    arbitrary code.
    """
    with open(filename, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj

def collect_labels(all_predictions):
    """
    Parse classifier output lines into integer labels.

    Each line must contain ``Classification:`` followed by an integer, e.g.
    ``"Classification: 1"``. Surrounding text such as a leading ``"1."`` or
    Markdown bold markers is ignored. The number of lines is printed first as a
    quick sanity check.

    Args:
      all_predictions: Sequence of model output lines, as returned by
        classify_clauses_batch.

    Returns:
      A list of ints, one per input line, in the same order.

    Raises:
      ValueError: if a line has no ``Classification: <integer>``. The message
        names the offending index so the bad model reply can be inspected.
    """
    print(len(all_predictions))
    output_labels = []
    for index, message in enumerate(all_predictions):
        match = re.search(r"Classification:\s*([+-]?\d+)", message)
        if match is None:
            raise ValueError(
                f"Prediction {index} has no 'Classification: <int>' label: {message!r}")
        output_labels.append(int(match.group(1)))
    return output_labels

In [105]:
def _build_user_prompt(batch_clauses):
    """Return the user message: a fixed header followed by the clauses numbered from 1."""
    numbered = [f"{j}. {clause}" for j, clause in enumerate(batch_clauses, start=1)]
    return "Classify the following clauses:\n" + "\n".join(numbered)


def _extract_prediction_lines(response):
    """Collect the non-blank, non-code-fence lines from every choice of a chat completion."""
    lines = []
    for choice in response.choices:
        # The SDK types message.content as optional (e.g. on a refusal); treat None as empty.
        content = choice.message.content or ""
        for line in content.strip().split("\n"):
            stripped = line.strip()
            # Blank lines and ``` fences are not classifications; counting them would
            # trigger a size mismatch and shift every later prediction off its clause.
            if not stripped or stripped.startswith("```"):
                continue
            lines.append(stripped)
    return lines


def classify_clauses_batch(clauses, system_prompt, batch_size=10, model="gpt-4o-mini", test=False,
                           temperature=0, use_system_role=True):
    """
    Classify ``clauses`` in batches with an OpenAI chat model, one output line per clause.

    Each batch is sent as a numbered list. The reply is split into lines, and blank
    lines and Markdown code-fence lines are discarded before counting.

    Args:
      clauses: Sequence of clause strings.
      system_prompt: Instructions describing the expected output format.
      batch_size: Number of clauses per request.
      model: Chat model name passed to the API and to ``estimate_tokens``.
      test: If True, stop after the first batch (quick smoke test).
      temperature: Sampling temperature (default 0 for consistent results).
        ``None`` omits the parameter, for models that reject it (some reasoning
        models do).
      use_system_role: If False, send ``system_prompt`` and the clauses together
        as a single user message, for models that do not accept a system message.

    Returns:
      A tuple ``(all_predictions, invalid_batch_nums)``:
      - ``all_predictions``: the prediction lines, in order.
      - ``invalid_batch_nums``: the 1-based numbers of batches whose line count
        differed from their clause count.

      Lines from mismatched batches are still appended. Predictions after the
      first invalid batch are therefore NOT aligned with ``clauses``; check
      ``invalid_batch_nums`` before scoring. On an API error, the results
      gathered so far are returned.
    """
    all_predictions = []
    invalid_batch_nums = []
    for start in range(0, len(clauses), batch_size):
        batch_num = start // batch_size + 1
        batch_clauses = clauses[start:start + batch_size]
        user_prompt = _build_user_prompt(batch_clauses)
        full_prompt = system_prompt + "\n" + user_prompt
        token_count = estimate_tokens(full_prompt, model=model)
        print(f"Estimated token count for batch {batch_num}: {token_count}")

        if use_system_role:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]
        else:
            messages = [{"role": "user", "content": full_prompt}]
        request_kwargs = {"model": model, "messages": messages}
        if temperature is not None:
            request_kwargs["temperature"] = temperature

        try:
            response = openai.chat.completions.create(**request_kwargs)
            batch_predictions = _extract_prediction_lines(response)
        except Exception as e:
            # Best effort: report and hand back what has been collected so far.
            print(f"Error processing batch {batch_num}: {e}")
            return all_predictions, invalid_batch_nums

        if len(batch_predictions) != len(batch_clauses):
            invalid_batch_nums.append(batch_num)
            print(f"WARNING: Mismatch in batch size for Batch {batch_num}!")
            print(f"Expected {len(batch_clauses)} classifications, but got {len(batch_predictions)}.")
        else:
            print(f"Batch {batch_num} processed successfully with {len(batch_predictions)} classifications.")

        all_predictions.extend(batch_predictions)
        if test:
            return all_predictions, invalid_batch_nums
    return all_predictions, invalid_batch_nums


In [6]:
def evaluate_predictions(ground_truth, predictions):
    """
    Score binary Fair/Unfair predictions against reference labels.

    Precision, recall and F1 use scikit-learn's binary defaults, i.e. they
    describe the positive class 1 (Unfair).

    Args:
      ground_truth: Reference labels (0 = Fair, 1 = Unfair).
      predictions: Predicted labels in the same order and encoding.

    Returns:
      A dict with the keys 'Accuracy', 'Precision', 'Recall' and 'F1-Score',
      in that order.
    """
    scorers = {
        'Accuracy': accuracy_score,
        'Precision': precision_score,
        'Recall': recall_score,
        'F1-Score': f1_score,
    }
    return {name: scorer(ground_truth, predictions) for name, scorer in scorers.items()}

In [7]:
def load_claudette_tos_dataset():
    """
    Load the CLAUDETTE ToS clauses and split them 80/10/10 into train/validation/test.

    ``label`` is first class-encoded (stratified splitting needs a ClassLabel
    column). Both splits are stratified on it with seed 42, so the partition is
    repeatable.

    Returns:
      A DatasetDict with 'train', 'validation' and 'test' splits.
    """
    labelled = load_dataset("LawInformedAI/claudette_tos")['train'].class_encode_column('label')
    split_opts = dict(stratify_by_column='label', seed=42)
    # 80% train / 20% held out, then split the held-out part in half.
    train_vs_rest = labelled.train_test_split(test_size=0.2, **split_opts)
    val_vs_test = train_vs_rest['test'].train_test_split(test_size=0.5, **split_opts)
    return DatasetDict({
        'train': train_vs_rest['train'],
        'validation': val_vs_test['train'],
        'test': val_vs_test['test'],
    })

In [39]:

# Classifier instructions. The model must answer with exactly one
# "Classification: 0|1" line per clause: classify_clauses_batch compares the
# number of reply lines with the batch size, and collect_labels parses the
# "Classification: " prefix.
system_prompt = """
You are a legal expert in consumer rights and contract law. Your task is to classify clauses from Terms of Service (ToS) 
documents as either '0' (Fair) or '1' (Unfair).

A clause is '1' (Unfair) if it:
1. Limits or excludes the provider's liability for damages.
2. Allows the provider to terminate services or the contract unilaterally.
3. Allows the provider to change contract terms unilaterally.
4. Grants the provider rights to remove content without reason or notice.
5. Binds users to terms simply by using the service, without explicit agreement.
6. Specifies governing law favoring the provider over the user's country.
7. Forces dispute resolution in a jurisdiction not favorable to the user.
8. Forces arbitration restricting the user's legal rights.
9. Imposes unreasonable restrictions or obligations on users.
10. Significantly limits users' rights or remedies.
11. Creates an imbalance of power favoring the service provider.
12. Uses vague or ambiguous language that could be exploited.
13. Violates established legal principles or consumer protection laws.

Classify ONLY the clauses provided in the user prompt. Provide ONLY the classification ('Classification: 0' or 'Classification: 1') 
for each clause, line by line, in the same order as the input. Do not add extra lines, blank lines, or explanations.
Example:
Classification: 0
Classification: 1
"""

# Load the 80/10/10 stratified splits and pull out clause texts and labels.
final_splits = load_claudette_tos_dataset()

validation_clauses = final_splits['validation']['text']
validation_labels = final_splits['validation']['label']

# The training split is not used by the prompting cells in this notebook;
# uncomment to inspect it.
# train_clauses = final_splits['train']['text']
# train_labels = final_splits['train']['label']

test_clauses = final_splits['test']['text']
test_labels = final_splits['test']['label']

In [21]:
# validation_result (commented out; the name is not defined in this notebook).
# The dict shown below is validation-split metrics from an earlier run.

{'Accuracy': 0.7640807651434643,
 'Precision': 0.30618892508143325,
 'Recall': 0.912621359223301,
 'F1-Score': 0.4585365853658537}

In [19]:
# Exploratory: fetch metadata for the o1-mini model (needs a valid API key).
# NOTE(review): `info` is never used below; consider deleting this cell so that
# Restart & Run All does not make an extra API call.
info = openai.models.retrieve("o1-mini")

In [112]:
# Classify the whole test split with gpt-4o, five clauses per request.
test_predictions, invalid_batch_nums = classify_clauses_batch(
    clauses=test_clauses,
    system_prompt=system_prompt,
    batch_size=5,
    model="gpt-4o",
)

Estimated token count for batch 1: 535
Batch 1 processed successfully with 5 classifications.
Estimated token count for batch 2: 441
Batch 2 processed successfully with 5 classifications.
Estimated token count for batch 3: 445
Batch 3 processed successfully with 5 classifications.
Estimated token count for batch 4: 596
Batch 4 processed successfully with 5 classifications.
Estimated token count for batch 5: 486
Batch 5 processed successfully with 5 classifications.
Estimated token count for batch 6: 511
Batch 6 processed successfully with 5 classifications.
Estimated token count for batch 7: 456
Batch 7 processed successfully with 5 classifications.
Estimated token count for batch 8: 521
Batch 8 processed successfully with 5 classifications.
Estimated token count for batch 9: 627
Batch 9 processed successfully with 5 classifications.
Estimated token count for batch 10: 437
Batch 10 processed successfully with 5 classifications.
Estimated token count for batch 11: 426
Batch 11 processed

In [66]:
# Leftover manual cleanup: drop a stray "```" code-fence line from the model
# output (collect_labels cannot parse it). Commented out, so the "0" shown
# below is output from before it was disabled.
# test_predictions.remove("```")
# test_predictions.count("```")

0

In [59]:
# test_labels[-5:]
# 1-based batch numbers whose reply line count differed from the clause count.
# Predictions after such a batch may be shifted relative to test_labels.
invalid_batch_nums

[73]

In [113]:
# Number of prediction lines; should equal len(test_labels).
len(test_predictions)

942

In [114]:
# Parse "Classification: N" lines into ints (collect_labels also prints the count).
test_predictions_labels = collect_labels(test_predictions)

942


In [70]:
# Sanity check: one parsed label per prediction line.
len(test_predictions_labels)

942

In [31]:
# Number of ground-truth test labels, for comparison with the predictions.
len(test_labels)

942

In [115]:
# Accuracy / precision / recall / F1 of the parsed labels on the test split.
test_results = evaluate_predictions(test_labels, test_predictions_labels)

In [116]:
# gpt-4o results (the classification run above used model="gpt-4o", batch_size=5)
test_results

{'Accuracy': 0.7802547770700637,
 'Precision': 0.32068965517241377,
 'Recall': 0.9029126213592233,
 'F1-Score': 0.4732824427480916}

In [110]:
# gpt-4o-mini results: saved output from an earlier run; re-running this cell
# shows whatever test_results currently holds.
test_results

{'Accuracy': 0.7526539278131635,
 'Precision': 0.2929936305732484,
 'Recall': 0.8932038834951457,
 'F1-Score': 0.4412470023980815}

In [100]:
# o4-mini results: saved output from an earlier run (persisted as
# test-o4mini-5-improved.pkl); re-running shows the current test_results.
test_results

{'Accuracy': 0.8089171974522293,
 'Precision': 0.35361216730038025,
 'Recall': 0.9029126213592233,
 'F1-Score': 0.5081967213114754}

In [87]:
# o3-mini results: saved output from an earlier run; re-running this cell
# shows whatever test_results currently holds.
test_results

{'Accuracy': 0.856687898089172,
 'Precision': 0.42727272727272725,
 'Recall': 0.912621359223301,
 'F1-Score': 0.5820433436532507}

In [73]:
# o1-mini results: saved output from an earlier run; the values match
# test-o1mini-10-improved.pkl loaded at the end of the notebook.
test_results

{'Accuracy': 0.7452229299363057,
 'Precision': 0.290519877675841,
 'Recall': 0.9223300970873787,
 'F1-Score': 0.4418604651162791}

In [33]:
# gpt-4 results: saved output from an earlier run; re-running this cell
# shows whatever test_results currently holds.
test_results

{'Accuracy': 0.7866242038216561,
 'Precision': 0.32867132867132864,
 'Recall': 0.912621359223301,
 'F1-Score': 0.4832904884318766}

In [26]:
# gpt-4o-mini results (an earlier run than the other gpt-4o-mini cell); saved
# output only, re-running shows the current test_results.
test_results

{'Accuracy': 0.7505307855626328,
 'Precision': 0.2962962962962963,
 'Recall': 0.9320388349514563,
 'F1-Score': 0.4496487119437939}

In [101]:
# Persist the o4-mini metrics ("5" in the filename is presumably the batch size).
write_file("test-o4mini-5-improved.pkl", test_results)

In [44]:
# Unlabelled run: the values are identical to test-gpt4o-batch-10.pkl, so
# presumably gpt-4o with batch_size=10.
test_results

{'Accuracy': 0.7346072186836518,
 'Precision': 0.25418060200668896,
 'Recall': 0.7378640776699029,
 'F1-Score': 0.3781094527363184}

In [117]:
# Persist the gpt-4o metrics from the batch_size=5 run.
write_file("test-gpt4o-5-improved.pkl", test_results)

In [5]:
# Reload saved gpt-4o metrics from a batch-size-10 run (trusted, locally written pickle).
read_file("test-gpt4o-batch-10.pkl")

{'Accuracy': 0.7346072186836518,
 'Precision': 0.25418060200668896,
 'Recall': 0.7378640776699029,
 'F1-Score': 0.3781094527363184}

In [75]:
# Reload the saved o1-mini metrics (trusted, locally written pickle).
read_file("test-o1mini-10-improved.pkl")

{'Accuracy': 0.7452229299363057,
 'Precision': 0.290519877675841,
 'Recall': 0.9223300970873787,
 'F1-Score': 0.4418604651162791}