# Fine-Tuning RoBERTa for Legal Contract Clause Extraction
## Project Overview
This project fine-tunes a RoBERTa model to automatically extract 6 key clause types from legal contracts using the CUAD dataset.

**Target Clauses:**
- Governing Law
- Expiration Date
- Effective Date
- Anti-Assignment
- Cap On Liability
- License Grant

## Phase 1: Data Preparation
Loading the CUAD dataset and formatting it for the question-answering task.

In [90]:
import json

# Load the CUAD dataset
with open('../data/CUADv1.json', 'r') as f:
    cuad_data = json.load(f)

# Check the structure
print(f"Keys in dataset: {cuad_data.keys()}")
print(f"Number of entries: {len(cuad_data['data'])}")

Keys in dataset: dict_keys(['version', 'data'])
Number of entries: 510


In [91]:
# Explore first entry
first_entry = cuad_data['data'][0]
print(f"Title: {first_entry['title']}")
print(f"\nNumber of paragraphs: {len(first_entry['paragraphs'])}")
print(f"\nFirst paragraph keys: {first_entry['paragraphs'][0].keys()}")

Title: LIMEENERGYCO_09_09_1999-EX-10-DISTRIBUTOR AGREEMENT

Number of paragraphs: 1

First paragraph keys: dict_keys(['qas', 'context'])


In [92]:
# Look at the context and questions
first_para = first_entry['paragraphs'][0]

print(f"Contract text length: {len(first_para['context'])} characters")
print(f"\nNumber of questions (clause types): {len(first_para['qas'])}")
print(f"\nFirst question example:")
print(f"Question: {first_para['qas'][0]['question']}")
print(f"Answer: {first_para['qas'][0]['answers']}")

Contract text length: 54290 characters

Number of questions (clause types): 41

First question example:
Question: Highlight the parts (if any) of this contract related to "Document Name" that should be reviewed by a lawyer. Details: The name of the contract
Answer: [{'text': 'DISTRIBUTOR AGREEMENT', 'answer_start': 44}]


In [93]:
# Count how many contracts have answers for each clause type
clause_counts = {}

for entry in cuad_data['data']:
    for para in entry['paragraphs']:
        for qa in para['qas']:
            question = qa['question']
            has_answer = len(qa['answers']) > 0
            
            if question not in clause_counts:
                clause_counts[question] = 0
            if has_answer:
                clause_counts[question] += 1

# Show top 10 most common clauses
sorted_clauses = sorted(clause_counts.items(), key=lambda x: x[1], reverse=True)
print("Top 10 most common clause types:\n")
for clause, count in sorted_clauses[:10]:
    print(f"{count} contracts: {clause}")


Top 10 most common clause types:

510 contracts: Highlight the parts (if any) of this contract related to "Document Name" that should be reviewed by a lawyer. Details: The name of the contract
509 contracts: Highlight the parts (if any) of this contract related to "Parties" that should be reviewed by a lawyer. Details: The two or more parties who signed the contract
470 contracts: Highlight the parts (if any) of this contract related to "Agreement Date" that should be reviewed by a lawyer. Details: The date of the contract
437 contracts: Highlight the parts (if any) of this contract related to "Governing Law" that should be reviewed by a lawyer. Details: Which state/country's law governs the interpretation of the contract?
413 contracts: Highlight the parts (if any) of this contract related to "Expiration Date" that should be reviewed by a lawyer. Details: On what date will the contract's initial term expire?
390 contracts: Highlight the parts (if any) of this contract related to "Effe

In [94]:
# Define target clause types
target_clauses = [
    "Highlight the parts (if any) of this contract related to \"Governing Law\" that should be reviewed by a lawyer. Details: Which state/country's law governs the interpretation of the contract?",
    "Highlight the parts (if any) of this contract related to \"Expiration Date\" that should be reviewed by a lawyer. Details: On what date will the contract's initial term expire?",
    "Highlight the parts (if any) of this contract related to \"Effective Date\" that should be reviewed by a lawyer. Details: The date when the contract is effective ",
    "Highlight the parts (if any) of this contract related to \"Anti-Assignment\" that should be reviewed by a lawyer. Details: Is consent or notice required of a party if the contract is assigned to a third party?",
    "Highlight the parts (if any) of this contract related to \"Cap On Liability\" that should be reviewed by a lawyer. Details: Does the contract include a cap on liability upon the breach of a party’s obligation? This includes time limitation for the counterparty to bring claims or maximum amount for recovery.",
    "Highlight the parts (if any) of this contract related to \"License Grant\" that should be reviewed by a lawyer. Details: Does the contract contain a license granted by one party to its counterparty?"
]

print(f"Selected {len(target_clauses)} clause types for extraction")


Selected 6 clause types for extraction


In [95]:
# Extract data using partial matching on clause names
clause_names = ["Governing Law", "Expiration Date", "Effective Date", 
                "Anti-Assignment", "Cap On Liability", "License Grant"]

filtered_data = []

for entry in cuad_data['data']:
    contract_id = entry['title']
    for para in entry['paragraphs']:
        context = para['context']
        
        for qa in para['qas']:
            # Check if any of our clause names appears in the question
            for clause_name in clause_names:
                if f'"{clause_name}"' in qa['question']:
                    filtered_data.append({
                        'contract_id': contract_id,
                        'context': context,
                        'question': qa['question'],
                        'answers': qa['answers']
                    })
                    break  # Found this clause, move to next qa

print(f"Total examples: {len(filtered_data)}")


Total examples: 3060


In [96]:
# Count examples per clause type
from collections import Counter

clause_distribution = Counter([item['question'] for item in filtered_data])

print("Examples per clause type:\n")
for question, count in clause_distribution.items():
    # Extract just the clause name for readability
    clause_name = question.split('"')[1]
    print(f"{clause_name}: {count} examples")

Examples per clause type:

Effective Date: 510 examples
Expiration Date: 510 examples
Governing Law: 510 examples
Anti-Assignment: 510 examples
License Grant: 510 examples
Cap On Liability: 510 examples


In [97]:
# Count examples with vs without answers
has_answer = sum(1 for item in filtered_data if len(item['answers']) > 0)
no_answer = len(filtered_data) - has_answer

print(f"Examples WITH answers: {has_answer}")
print(f"Examples WITHOUT answers: {no_answer}")
print(f"Percentage with answers: {has_answer/len(filtered_data)*100:.1f}%")

Examples WITH answers: 2144
Examples WITHOUT answers: 916
Percentage with answers: 70.1%
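The 70.1% figure is pooled across all six clause types, so it can hide differences between them. The helper below (a hypothetical addition, not part of the original notebook) breaks the answer rate down per clause for records shaped like `filtered_data`. It takes the clause name from between the first pair of double quotes in the question, as the notebook's own cells do. The toy records are made up for illustration.

```python
from collections import defaultdict

def answer_rate_by_clause(examples):
    """Return {clause_name: (answered, total)} for records shaped like filtered_data."""
    counts = defaultdict(lambda: [0, 0])  # clause_name -> [answered, total]
    for item in examples:
        # Clause name sits between the first pair of double quotes in the question
        clause_name = item["question"].split('"')[1]
        counts[clause_name][1] += 1
        if len(item["answers"]) > 0:
            counts[clause_name][0] += 1
    return {name: (answered, total) for name, (answered, total) in counts.items()}

# Hypothetical toy records (not from CUAD) in the same shape as filtered_data
toy = [
    {"question": 'related to "Governing Law" ...', "answers": [{"text": "New York", "answer_start": 10}]},
    {"question": 'related to "Governing Law" ...', "answers": []},
    {"question": 'related to "License Grant" ...', "answers": []},
]
for name, (answered, total) in answer_rate_by_clause(toy).items():
    print(f"{name}: {answered}/{total} answered ({answered / total:.0%})")

# On the notebook's data: answer_rate_by_clause(filtered_data)
```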


In [98]:
# Show one example with an answer
example_with_answer = [item for item in filtered_data if len(item['answers']) > 0][0]

clause_name = example_with_answer['question'].split('"')[1]
answer_text = example_with_answer['answers'][0]['text']
answer_start = example_with_answer['answers'][0]['answer_start']

print(f"Clause: {clause_name}")
print(f"\nExtracted text: {answer_text}")
print(f"Position in contract: character {answer_start}")


Clause: Effective Date

Extracted text: The term of this  Agreement  shall be ten (10)                            years (the "Term")  which shall  commence on the date                            upon which the Company  delivers to  Distributor  the                            last Sample, as defined  hereinafter.
Position in contract: character 5268
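Because the training labels are derived from `answer_start`, it is worth confirming that each offset really points at the stored answer text. A minimal check, sketched here as a hypothetical helper with a made-up contract sentence, slices the context at the recorded offset and compares:

```python
def check_answer_offsets(context, answers):
    """Return the answers whose 'answer_start' does not point at their 'text' in context."""
    mismatches = []
    for ans in answers:
        start = ans["answer_start"]
        # Slice the context at the recorded offset and compare with the stored text
        if context[start:start + len(ans["text"])] != ans["text"]:
            mismatches.append(ans)
    return mismatches

# Hypothetical toy contract text, not from CUAD
ctx = "This Agreement is governed by the laws of the State of New York."
good = [{"text": "New York", "answer_start": ctx.find("New York")}]
bad = [{"text": "New York", "answer_start": 0}]
print(check_answer_offsets(ctx, good))       # []
print(len(check_answer_offsets(ctx, bad)))   # 1

# On the notebook's data (not run here):
# misaligned = [item for item in filtered_data
#               if check_answer_offsets(item["context"], item["answers"])]
```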


In [99]:
# Format data for QA task
formatted_examples = []

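for item in filtered_data:
    # Clause name is the text between the first pair of double quotes in the question.
    # It is computed outside the f-string because a backslash inside an f-string
    # expression is a SyntaxError before Python 3.12.
    clause_name = item['question'].split('"')[1]
    example = {
        'id': f"{item['contract_id']}_{clause_name}",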
        'context': item['context'],
        'question': item['question'],
        'answers': item['answers']
    }
    formatted_examples.append(example)

print(f"Total formatted examples: {len(formatted_examples)}")
print(f"\nFirst example keys: {formatted_examples[0].keys()}")

Total formatted examples: 3060

First example keys: dict_keys(['id', 'context', 'question', 'answers'])


In [100]:
# Check the actual keys in our formatted examples
print(formatted_examples[0].keys())


dict_keys(['id', 'context', 'question', 'answers'])


### Train/Validation/Test Split
Splitting by contract ID (70/15/15) so that no contract's text appears in more than one split, which prevents data leakage between train, validation, and test.

In [101]:
from sklearn.model_selection import train_test_split

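# Recover contract IDs from the 'id' field (title_clause -> title).
# sorted() makes the order deterministic: iteration order of a set of strings changes
# between Python sessions (hash randomization), so random_state alone would not make
# the split reproducible.
unique_contracts = sorted(set(item['id'].rsplit('_', 1)[0] for item in formatted_examples))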
print(f"Total unique contracts: {len(unique_contracts)}")

# Split contract IDs
train_ids, temp_ids = train_test_split(unique_contracts, test_size=0.3, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)

print(f"Train contracts: {len(train_ids)}")
print(f"Val contracts: {len(val_ids)}")
print(f"Test contracts: {len(test_ids)}")

Total unique contracts: 510
Train contracts: 357
Val contracts: 76
Test contracts: 77
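The uneven 76/77 split follows from how scikit-learn sizes a float `test_size`: the test count is the ceiling of `test_size × n`, and the rest goes to the other side. The arithmetic below reproduces the printed counts; it is a sketch of that rounding rule, not a call into scikit-learn.

```python
import math

def split_sizes(n, first_test=0.3, second_test=0.5):
    """Contract counts for the two-stage split, rounding each test share up."""
    n_temp = math.ceil(first_test * n)        # held-out contracts (val + test)
    n_train = n - n_temp
    n_test = math.ceil(second_test * n_temp)  # second split: test share rounded up
    n_val = n_temp - n_test
    return n_train, n_val, n_test

print(split_sizes(510))  # (357, 76, 77)
```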


In [102]:
# Split examples based on contract IDs (sets make the membership checks O(1))
train_set, val_set, test_set = set(train_ids), set(val_ids), set(test_ids)
train_data = [ex for ex in formatted_examples if ex['id'].rsplit('_', 1)[0] in train_set]
val_data = [ex for ex in formatted_examples if ex['id'].rsplit('_', 1)[0] in val_set]
test_data = [ex for ex in formatted_examples if ex['id'].rsplit('_', 1)[0] in test_set]

print(f"Train examples: {len(train_data)} ({len(train_data)/6:.0f} contracts × 6 clauses)")
print(f"Val examples: {len(val_data)} ({len(val_data)/6:.0f} contracts × 6 clauses)")
print(f"Test examples: {len(test_data)} ({len(test_data)/6:.0f} contracts × 6 clauses)")


Train examples: 2142 (357 contracts × 6 clauses)
Val examples: 456 (76 contracts × 6 clauses)
Test examples: 462 (77 contracts × 6 clauses)
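The counts are consistent with a contract-level split, but an explicit check makes the no-leakage guarantee testable. The sketch below uses hypothetical helper names and toy ids in the notebook's `f"{title}_{clause_name}"` format; it relies on the six clause names containing no underscores, so splitting on the last underscore recovers the full title.

```python
def contract_of(example_id):
    """Recover the contract title from an id of the form f"{title}_{clause_name}"."""
    # Split on the LAST underscore: titles contain underscores, the six clause names do not
    return example_id.rsplit("_", 1)[0]

def assert_no_contract_overlap(*splits):
    """Raise ValueError if any contract contributes examples to more than one split."""
    owner = {}  # contract title -> index of the first split it was seen in
    for split_index, split in enumerate(splits):
        for example in split:
            contract = contract_of(example["id"])
            first = owner.setdefault(contract, split_index)
            if first != split_index:
                raise ValueError(f"{contract!r} appears in splits {first} and {split_index}")

# Hypothetical toy ids following the notebook's f"{title}_{clause_name}" pattern
toy_train = [{"id": "ACME_2001-EX-10_Governing Law"}, {"id": "ACME_2001-EX-10_License Grant"}]
toy_val = [{"id": "BETA_1999-EX-10_Governing Law"}]
assert_no_contract_overlap(toy_train, toy_val)  # no error: the splits share no contract

# On the notebook's data: assert_no_contract_overlap(train_data, val_data, test_data)
```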


In [103]:
import json

# Save datasets
with open('../data/train_data.json', 'w') as f:
    json.dump(train_data, f, indent=2)

with open('../data/val_data.json', 'w') as f:
    json.dump(val_data, f, indent=2)

with open('../data/test_data.json', 'w') as f:
    json.dump(test_data, f, indent=2)

print("Saved train_data.json")
print("Saved val_data.json")
print("Saved test_data.json")


Saved train_data.json
Saved val_data.json
Saved test_data.json


## Phase 2: Model Selection & Setup
Using RoBERTa-base for question-answering on legal contracts.

In [104]:
# Check if transformers is installed
try:
    import transformers
    print(f"transformers version: {transformers.__version__}")
except ImportError:
    print("Need to install: !pip install transformers datasets accelerate")

transformers version: 5.1.0


### Loading RoBERTa Model and Tokenizer

In [105]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

model_name = "roberta-base"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

print(f"Loaded {model_name}")
print(f"Model parameters: {model.num_parameters():,}")

RobertaForQuestionAnswering LOAD REPORT from: roberta-base
Key                             | Status
--------------------------------+-----------
lm_head.layer_norm.bias         | UNEXPECTED
lm_head.dense.weight            | UNEXPECTED
roberta.embeddings.position_ids | UNEXPECTED
lm_head.layer_norm.weight       | UNEXPECTED
lm_head.dense.bias              | UNEXPECTED
lm_head.bias                    | UNEXPECTED
qa_outputs.bias                 | MISSING
qa_outputs.weight               | MISSING

Notes:
- UNEXPECTED: can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING: those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.


Loaded roberta-base
Model parameters: 124,056,578
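The printed parameter count can be reproduced from the standard `roberta-base` configuration (vocabulary 50,265, 514 positions, 1 token type, hidden size 768, feed-forward size 3,072, 12 layers). The arithmetic shows that the only new parameters are the 1,538 in `qa_outputs`, a single linear layer that maps each token's 768-dimensional hidden state to a start logit and an end logit; these are the weights reported as MISSING above.

```python
# roberta-base configuration values
vocab, max_pos, type_vocab = 50265, 514, 1
hidden, ffn, layers = 768, 3072, 12

def linear(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias

layer_norm = 2 * hidden  # gain + bias

embeddings = (vocab + max_pos + type_vocab) * hidden + layer_norm
per_layer = (
    3 * linear(hidden, hidden)   # query, key, value projections
    + linear(hidden, hidden)     # attention output projection
    + layer_norm
    + linear(hidden, ffn)        # feed-forward up-projection
    + linear(ffn, hidden)        # feed-forward down-projection
    + layer_norm
)
qa_head = linear(hidden, 2)      # qa_outputs: start and end logit per token

total = embeddings + layers * per_layer + qa_head
print(f"{total:,}")  # 124,056,578
```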


### Tokenizing the Dataset
Converting text to token IDs and mapping answer positions to token positions.

In [106]:
from datasets import Dataset

# Convert our lists to Hugging Face Dataset format
train_dataset = Dataset.from_list(train_data)
val_dataset = Dataset.from_list(val_data)
test_dataset = Dataset.from_list(test_data)

print(f"Created HF datasets")
print(f"Train: {len(train_dataset)} examples")
print(f"Val: {len(val_dataset)} examples")
print(f"Test: {len(test_dataset)} examples")


Created HF datasets
Train: 2142 examples
Val: 456 examples
Test: 462 examples


### Preprocessing Function
Splits each contract into overlapping 384-token windows (stride 128) and maps each answer's character span to token start/end positions.

In [107]:
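def preprocess_function(examples):
    """
    Tokenizes question/context pairs into overlapping windows and maps each answer's
    character span to token start/end positions. Windows that do not fully contain
    the answer, and questions with no answer, are labeled with the CLS token index.
    """

    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",   # window the contract, never truncate the question
        max_length=384,
        stride=128,                 # tokens of overlap between consecutive windows
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    start_positions = []
    end_positions = []

    # Each window points back to the example it came from
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples.pop("offset_mapping")

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)
        # 0 = question token, 1 = context token, None = special or padding token
        sequence_ids = tokenized_examples.sequence_ids(i)

        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]

        # Normalize to a list of {"text", "answer_start"} dicts (SQuAD-style dict of lists also accepted)
        if isinstance(answers, dict):
            answers = [{"text": t, "answer_start": s}
                       for t, s in zip(answers.get("text", []), answers.get("answer_start", []))]

        if len(answers) == 0:
            # Unanswerable question: point start and end at CLS
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Only the first annotated span is used as the training target
        start_char = answers[0]["answer_start"]
        end_char = start_char + len(answers[0]["text"])

        # First and last token of the context inside this window
        context_start = sequence_ids.index(1)
        context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)

        # Answer not fully inside this window: label as "no answer here"
        if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Move right from the context start to the token containing start_char
        token_start_index = context_start
        while token_start_index <= context_end and offsets[token_start_index][0] <= start_char:
            token_start_index += 1
        start_positions.append(token_start_index - 1)

        # Move left from the context end to the token containing end_char
        token_end_index = context_end
        while token_end_index >= context_start and offsets[token_end_index][1] >= end_char:
            token_end_index -= 1
        end_positions.append(token_end_index + 1)

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions

    return tokenized_examples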
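To see what the start/end scanning loops do inside a single window, the mapping can be pulled out into a standalone function and run on a hand-built window. This follows the standard Hugging Face extractive-QA preprocessing pattern: find the context boundaries with `sequence_ids`, label the window with CLS when the answer is not fully inside it, otherwise walk inward to the covering tokens. The helper name, tokens, and offsets are hypothetical; offsets for context tokens are character positions in the context string, as a fast tokenizer's `offset_mapping` reports them.

```python
def char_span_to_token_span(offsets, sequence_ids, start_char, end_char, cls_index=0):
    """Map a character span [start_char, end_char) to inclusive token indices in one window.

    Returns (cls_index, cls_index) when the span is not fully inside the window's context.
    """
    context_start = sequence_ids.index(1)
    context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return cls_index, cls_index
    start = context_start
    while start <= context_end and offsets[start][0] <= start_char:
        start += 1
    end = context_end
    while end >= context_start and offsets[end][1] >= end_char:
        end -= 1
    return start - 1, end + 1

# Hand-built window (hypothetical), RoBERTa pair layout: <s> question </s></s> context </s>
# Context string: "governed by New York law"
tokens       = ["<s>", "Law?", "</s>", "</s>", "governed", "by", "New", "York", "law", "</s>"]
sequence_ids = [None, 0, None, None, 1, 1, 1, 1, 1, None]
offsets      = [(0, 0), (0, 4), (0, 0), (0, 0), (0, 8), (9, 11), (12, 15), (16, 20), (21, 24), (0, 0)]

start, end = char_span_to_token_span(offsets, sequence_ids, 12, 20)  # "New York"
print(tokens[start:end + 1])                                         # ['New', 'York']
print(char_span_to_token_span(offsets, sequence_ids, 30, 35))        # (0, 0): not in this window
```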


### Apply Tokenization to All Datasets

In [108]:
# Tokenize all datasets
tokenized_train = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names
)

tokenized_val = val_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=val_dataset.column_names
)

print(f"Tokenized train: {len(tokenized_train)} examples")
print(f"Tokenized val: {len(tokenized_val)} examples")



Tokenized train: 126982 examples
Tokenized val: 31702 examples
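Windowing multiplies the dataset: each training example becomes about 59 windows on average (126,982 / 2,142) and each validation example about 70 (31,702 / 456). A rough estimate of the window count for one pair is sketched below. It assumes RoBERTa's pair template `<s> Q </s></s> C </s>` (4 special tokens) and that `stride` is the number of context tokens shared by consecutive windows; the 12,000-token contract and 60-token question are illustrative values, not measurements from CUAD.

```python
import math

def approx_num_windows(n_context_tokens, n_question_tokens, max_length=384, stride=128,
                       n_special_tokens=4):
    """Approximate number of overlapping windows for one (question, context) pair."""
    budget = max_length - n_question_tokens - n_special_tokens  # context tokens per window
    step = budget - stride                                      # new context tokens per window
    if n_context_tokens <= budget:
        return 1
    return 1 + math.ceil((n_context_tokens - budget) / step)

print(approx_num_windows(12_000, 60))  # 62
print(approx_num_windows(300, 60))     # 1: a short context fits in one window
```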


## Phase 3: Fine-Tuning Configuration
Setting up the training arguments and the Trainer for fine-tuning.

In [118]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",              # where to save model checkpoints
    eval_strategy="epoch",               # evaluate after each epoch
    learning_rate=3e-5,                  # common choice for fine-tuning BERT-style models
    per_device_train_batch_size=4,       # small batch for M1 Mac memory
    per_device_eval_batch_size=4,
    num_train_epochs=3,                  # 3 full passes through the data
    weight_decay=0.01,                   # mild regularization against overfitting
    save_strategy="epoch",               # save a checkpoint each epoch
    save_total_limit=2,                  # keep at most 2 checkpoints; the best one is always retained with load_best_model_at_end
    load_best_model_at_end=True,         # load best model after training
    metric_for_best_model="eval_loss",   # use validation loss to pick best
)

print("Training arguments configured")

Training arguments configured
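With these arguments the size of the run follows directly from the tokenized dataset sizes printed earlier. The sketch below assumes one device, no gradient accumulation, and that the last partial batch is kept (the Trainer's default); under those assumptions each epoch is about 31,746 optimizer steps and each end-of-epoch evaluation about 7,926 batches.

```python
import math

def planned_steps(n_features, per_device_batch, epochs):
    """Steps per epoch and in total for a single-device run without gradient accumulation."""
    steps_per_epoch = math.ceil(n_features / per_device_batch)  # last partial batch kept
    return steps_per_epoch, steps_per_epoch * epochs

print(planned_steps(126_982, 4, 3))  # (31746, 95238)

eval_batches = math.ceil(31_702 / 4)
print(eval_batches)                  # 7926
```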


### Setting up the Trainer
The Trainer handles the training loop, evaluation, and checkpointing.

In [119]:
from transformers import Trainer, DefaultDataCollator

# Data collator handles batching
data_collator = DefaultDataCollator()

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
)

print("Trainer initialized and ready")


Trainer initialized and ready


### Training the Model
Starting the fine-tuning process.

In [120]:
# Start training
trainer.train()



Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 
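The run was stopped manually (KeyboardInterrupt) before any epoch was logged. Only a handful of windows can contain a given answer span, so the large majority of the 126,982 training windows are labeled no-answer (start at CLS). One common way to shorten training is to keep every window that contains an answer and only a random fraction of the no-answer windows. The sketch below is a hypothetical helper, not part of the original notebook; the 0.1 fraction is an illustrative assumption rather than a tuned value, and it assumes the CLS index is 0 (RoBERTa's `<s>` at position 0 with right padding). The validation set would normally be left intact so evaluation still sees every window.

```python
import random

def subsample_negative_windows(start_positions, cls_index=0, keep_fraction=0.1, seed=42):
    """Indices of features to keep: all windows with an answer span, plus a random
    `keep_fraction` of the windows labeled no-answer (start == cls_index)."""
    rng = random.Random(seed)  # seeded so the same subset is chosen on every run
    keep = []
    for i, start in enumerate(start_positions):
        if start != cls_index or rng.random() < keep_fraction:
            keep.append(i)
    return keep

print(subsample_negative_windows([0, 5, 0, 0, 12], keep_fraction=0.0))  # [1, 4]

# On the notebook's data (not run here), Dataset.select keeps the chosen rows:
# keep = subsample_negative_windows(tokenized_train["start_positions"])
# tokenized_train = tokenized_train.select(keep)
```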