# Legal Clause Similarity Detection
## Deep Learning Assignment 2 - CS425

**Student:** Syed Taha Hasan  
**FastID:** i211767  
**Date:** November 2025

### Objective
Develop NLP models to identify semantic similarity between legal clauses using baseline architectures (BiLSTM and ESIM - Enhanced Sequential Inference Model) without pre-trained transformers.


## 1. Setup and Imports


In [None]:
import sys
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.deterministic = True

# Import custom modules
from data_loader import LegalClauseDataLoader
from text_preprocessor import TextPreprocessor
from models import BiLSTMSimilarityModel, ESIMSimilarityModel
from trainer import ModelTrainer, ClausePairDataset
from evaluator import ModelEvaluator
from visualization import TrainingVisualizer

print("Setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")


Setup complete!
PyTorch version: 2.7.0+cpu
CUDA available: False


## 2. Data Loading and Exploration


In [2]:
# Initialize data loader
data_loader = LegalClauseDataLoader(data_dir="archive (1)")

# Load all data
clauses_by_category = data_loader.load_all_data()

# Get statistics
stats = data_loader.get_statistics()
print("\nDataset Statistics:")
for key, value in stats.items():
    print(f"  {key}: {value}")


Loading 395 CSV files...
Loaded 395 categories
Total clauses: 150881
Categories with most clauses: [('time-of-essence', 630), ('time-of-the-essence', 620), ('capitalized-terms', 590), ('definitions-and-interpretation', 590), ('captions', 580)]

Dataset Statistics:
  num_categories: 395
  total_clauses: 150881
  avg_clauses_per_category: 381.9772151898734
  min_clauses_per_category: 15
  max_clauses_per_category: 630


In [3]:
# Create similarity pairs
NUM_PAIRS = 10000  # Adjust based on computational resources
POSITIVE_RATIO = 0.5  # 50% similar, 50% dissimilar

pairs, labels = data_loader.create_similarity_pairs(
    num_pairs=NUM_PAIRS,
    positive_ratio=POSITIVE_RATIO,
    seed=42
)

print(f"\nTotal pairs created: {len(pairs)}")
print(f"Similar pairs: {sum(labels)}")
print(f"Dissimilar pairs: {len(labels) - sum(labels)}")


Created 10000 pairs: 5000 positive, 5000 negative

Total pairs created: 10000
Similar pairs: 5000
Dissimilar pairs: 5000


## 3. Text Preprocessing and Data Splitting


In [4]:
# Initialize preprocessor
MAX_VOCAB_SIZE = 10000
MAX_SEQ_LENGTH = 200

preprocessor = TextPreprocessor(
    max_vocab_size=MAX_VOCAB_SIZE,
    max_seq_length=MAX_SEQ_LENGTH
)

# Build vocabulary from all clause texts
all_texts = [text for pair in pairs for text in pair]
preprocessor.build_vocabulary(all_texts)

print(f"\nVocabulary size: {preprocessor.vocab_size}")
print(f"Max sequence length: {preprocessor.max_seq_length}")

# Split data: 70% train, 15% validation, 15% test
X_train, X_temp, y_train, y_temp = train_test_split(
    pairs, labels, test_size=0.3, random_state=42, stratify=labels
)

X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

print(f"\nTraining pairs: {len(X_train)}")
print(f"Validation pairs: {len(X_val)}")
print(f"Test pairs: {len(X_test)}")

# Create datasets
BATCH_SIZE = 32
train_dataset = ClausePairDataset(X_train, y_train, preprocessor)
val_dataset = ClausePairDataset(X_val, y_val, preprocessor)
test_dataset = ClausePairDataset(X_test, y_test, preprocessor)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)


Building vocabulary...
Vocabulary size: 10000
Most common words: [('the', 182334), ('or', 174451), ('of', 161900), ('any', 122302), ('in', 86170), ('to', 68240), ('and', 64064), ('its', 38740), ('by', 35922), ('company', 35099)]

Vocabulary size: 10000
Max sequence length: 200

Training pairs: 7000
Validation pairs: 1500
Test pairs: 1500


In [5]:
# Initialize and train BiLSTM model
bilstm_model = BiLSTMSimilarityModel(
    vocab_size=preprocessor.vocab_size,
    embedding_dim=128,
    hidden_dim=256,
    num_layers=2,
    dropout=0.3,
    num_classes=2
)

print("BiLSTM Model created")
total_params = sum(p.numel() for p in bilstm_model.parameters())
print(f"Total parameters: {total_params:,}")

# Train
bilstm_trainer = ModelTrainer(bilstm_model)
bilstm_history = bilstm_trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=15,
    learning_rate=0.001,
    weight_decay=1e-5,
    patience=5
)

# Evaluate
bilstm_evaluator = ModelEvaluator(bilstm_model)
bilstm_metrics = bilstm_evaluator.compute_metrics(test_loader)
bilstm_evaluator.print_metrics(bilstm_metrics, "BiLSTM Model")
bilstm_qualitative = bilstm_evaluator.get_qualitative_results(test_loader, X_test, y_test, num_examples=5)

# Plot training history
visualizer = TrainingVisualizer()
fig = visualizer.plot_training_history(bilstm_history, "BiLSTM Model")
plt.show()


BiLSTM Model created
Total parameters: 3,976,194
Training on device: cpu
Model parameters: 3,976,194

Epoch 1/15
--------------------------------------------------


Training:  33%|███▎      | 73/219 [05:04<10:09,  4.17s/it]


KeyboardInterrupt: 

## 5. Model 2: ESIM (Enhanced Sequential Inference Model)


In [None]:
# Initialize and train ESIM model
esim_model = ESIMSimilarityModel(
    vocab_size=preprocessor.vocab_size,
    embedding_dim=128,
    hidden_dim=256,
    num_layers=1,
    dropout=0.3,
    num_classes=2
)

print("ESIM Model created")
print("ESIM Architecture:")
print("  - Input Encoding: BiLSTM")
print("  - Local Inference: Soft attention alignment")
print("  - Inference Composition: BiLSTM")
print("  - Pooling: Mean and Max")
total_params = sum(p.numel() for p in esim_model.parameters())
print(f"Total parameters: {total_params:,}")

# Train
esim_trainer = ModelTrainer(esim_model)
esim_history = esim_trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=15,
    learning_rate=0.001,
    weight_decay=1e-5,
    patience=5
)

# Evaluate
esim_evaluator = ModelEvaluator(esim_model)
esim_metrics = esim_evaluator.compute_metrics(test_loader)
esim_evaluator.print_metrics(esim_metrics, "ESIM Model")
esim_qualitative = esim_evaluator.get_qualitative_results(test_loader, X_test, y_test, num_examples=5)

# Plot training history
fig = visualizer.plot_training_history(esim_history, "ESIM Model")
plt.show()


## 6. Model Comparison and Results


In [None]:
# Compare models
comparison_metrics = [bilstm_metrics, esim_metrics]
model_names = ["BiLSTM", "ESIM"]

fig = visualizer.plot_model_comparison(comparison_metrics, model_names)
plt.show()

# Create comparison table
comparison_df = pd.DataFrame({
    'Model': model_names,
    'Accuracy': [m['accuracy'] for m in comparison_metrics],
    'Precision': [m['precision'] for m in comparison_metrics],
    'Recall': [m['recall'] for m in comparison_metrics],
    'F1-Score': [m['f1_score'] for m in comparison_metrics],
    'ROC-AUC': [m['roc_auc'] for m in comparison_metrics],
    'PR-AUC': [m['pr_auc'] for m in comparison_metrics],
    'Training Time (s)': [bilstm_history['training_time'], esim_history['training_time']]
})

print("\nModel Comparison:")
print(comparison_df.to_string(index=False))


## 7. Qualitative Results


In [None]:
# Display qualitative results
print("\n" + "="*80)
print("BiLSTM Model - Sample Correct Predictions")
print("="*80)
for i, example in enumerate(bilstm_qualitative['correct'][:3], 1):
    print(f"\nExample {i}:")
    print(f"  Text 1: {example['text1'][:150]}...")
    print(f"  Text 2: {example['text2'][:150]}...")
    print(f"  True: {example['true_label']}, Predicted: {example['predicted_label']}, Confidence: {example['confidence']:.3f}")

print("\n" + "="*80)
print("ESIM Model - Sample Correct Predictions")
print("="*80)
for i, example in enumerate(esim_qualitative['correct'][:3], 1):
    print(f"\nExample {i}:")
    print(f"  Text 1: {example['text1'][:150]}...")
    print(f"  Text 2: {example['text2'][:150]}...")
    print(f"  True: {example['true_label']}, Predicted: {example['predicted_label']}, Confidence: {example['confidence']:.3f}")
