In [2]:
# !pip install datasets

# Reading the Legal Cases Dataframe
* case_id - represents a unique id for each case
* text - represents the complaint text for each case
* domain - the domain name of which this case belongs to

In [1]:
from datasets import load_dataset
dataset = load_dataset('darrow-ai/legal-task')
dataset = dataset['train'].to_pandas()
dataset.head()

  from .autonotebook import tqdm as notebook_tqdm


Unnamed: 0,id,text,domain
0,r-e4EYcBD5gMZwcz41zP,UNITED STATES DISTRICT COURT \nEASTERN DISTRIC...,consumer fraud
1,i9H5DocBD5gMZwcztj0y,IN THE UNITED STATES DISTRICT COURT \nFOR THE ...,privacy
2,SMn3DYcBD5gMZwcz-hwH,IN THE UNITED STATES DISTRICT COURT\n FOR THE ...,privacy
3,GMIWDYcBD5gMZwczDQBb,Case No. _______________ \n \n \nCLASS ACTION ...,criminal & enforcement
4,lELw_IgBF5pVm5zYONwC,UNITED STATES DISTRICT COURT \n SOUTHERN DISTR...,consumer fraud


In [2]:
dataset.id.count()

np.int64(1204)

In [5]:
dataset.groupby('domain').agg(
    total_count=('id', 'count'),
    distinct_count=('id', 'nunique')
).sort_values('distinct_count', ascending=False)


Unnamed: 0_level_0,total_count,distinct_count
domain,Unnamed: 1_level_1,Unnamed: 2_level_1
consumer fraud,200,200
securities,200,200
privacy,200,200
employment & labor,200,200
"civil rights, immigration, family",167,167
antitrust,126,126
products liability and mass tort,56,56
discrimination,20,20
criminal & enforcement,16,16
healthcare,9,9


In [6]:
# print(dataset.sample(1).text.iloc[0])


# Data Cleaning For Case Similarity

## Deduplication

In [7]:
import pandas as pd
import numpy as np
import re
import matplotlib.pyplot as plt

pd.set_option("display.max_colwidth", 300)
plt.rcParams["figure.figsize"] = (8, 5)
plt.rcParams["axes.grid"] = True

print("Initial shape:", dataset.shape)

# Find all rows that are duplicated on the exact pair (id, text)
dup_mask = dataset.duplicated(subset=["id", "text"], keep=False)
duplicates_df = dataset[dup_mask].copy()

print(f"Number of rows involved in [id, text] duplicates: {duplicates_df.shape[0]}")

# Save them for transparency / potential manual inspection
if not duplicates_df.empty:
    duplicates_df.to_csv("duplicates_id_text.csv", index=False)
    print("Saved [id, text] duplicates to 'duplicates_id_text.csv'.")

# Now create a deduplicated version keeping the first occurrence
dataset_deduped = dataset.drop_duplicates(subset=["id", "text"], keep="first").reset_index(drop=True)

print("Shape after deduplication on [id, text]:", dataset_deduped.shape)


Initial shape: (1204, 3)
Number of rows involved in [id, text] duplicates: 0
Shape after deduplication on [id, text]: (1204, 3)


In [8]:
import re
import numpy as np

# Start from your deduped / filtered dataframe:
# df has columns: id, text, domain

def clean_for_similarity(text: str) -> str:
    """
    Light normalization for semantic similarity:
    - ensure string
    - normalize whitespace (incl. newlines) to single spaces
    - strip
    - lowercase
    We keep punctuation and numbers – they can be informative in legal texts.
    """
    if not isinstance(text, str):
        text = "" if text is None else str(text)

    # Collapse all whitespace (spaces, tabs, newlines) into a single space
    text = re.sub(r"\s+", " ", text)

    # Strip leading/trailing spaces
    text = text.strip()

    # Lowercase for stability
    text = text.lower()

    return text

# Apply cleaning
dataset_deduped["text_clean_similarity"] = dataset_deduped["text"].apply(clean_for_similarity)

# Build helper mappings
case_ids = dataset_deduped["id"].tolist()
id_to_index = {cid: i for i, cid in enumerate(case_ids)}


# Part #1 - Similarity Calculation

In [9]:
# def calculate_cases_similarity(case_id_a, case_id_b):
#     '''
#     This method should return a similarity score [0-1] that represents how similar the cases are
#
#     @param case_id_a - the id of the first case
#     @param case_id_b - the id of the second case
#     @returns a similarity score between the cases
#     @rtype float
#     '''
#     pass

> We compute case-to-case similarity using Sentence-Transformers rather than raw ```transformers```, as the former are explicitly trained for producing semantically meaningful sentence and document embeddings suitable for cosine-similarity comparison.
>
> Among available models, we selected ```all-MiniLM-L6-v2``` for its strong balance between semantic quality, speed, and hardware efficiency.
>
> While domain-specific models such as ```Legal-BERT``` exist, they are significantly heavier and optimized for fine-tuned downstream tasks rather than zero-shot semantic clustering. Given our dataset (~1 000 cases) and CPU-only environment, MiniLM provides near-state-of-the-art similarity performance with practical runtime.

In [10]:
from sentence_transformers import SentenceTransformer

# You can pick another suitable model if you like
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

similarity_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Encode cleaned texts; normalize for direct cosine via dot product
embeddings = similarity_model.encode(
    dataset_deduped["text_clean_similarity"].tolist(),
    batch_size=32,
    show_progress_bar=True,
    normalize_embeddings=True
)

# embeddings is a 2D numpy array: [n_cases, dim]
embeddings = np.asarray(embeddings)


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Batches: 100%|██████████| 38/38 [00:22<00:00,  1.72it/s]


In [11]:
def calculate_cases_similarity(case_id_a, case_id_b):
    '''
    This method should return a similarity score [0-1] that represents how similar the cases are

    @param case_id_a - the id of the first case
    @param case_id_b - the id of the second case
    @returns a similarity score between the cases
    @rtype float
    '''
    # Ensure both IDs exist
    idx_a = id_to_index.get(case_id_a)
    idx_b = id_to_index.get(case_id_b)

    if idx_a is None:
        raise ValueError(f"Unknown case_id_a: {case_id_a}")
    if idx_b is None:
        raise ValueError(f"Unknown case_id_b: {case_id_b}")

    # If same case → max similarity
    if idx_a == idx_b:
        return 1.0

    vec_a = embeddings[idx_a]
    vec_b = embeddings[idx_b]

    # Cosine similarity for normalized vectors is just dot product in [-1, 1]
    cos_sim = float(np.dot(vec_a, vec_b))

    # Map from [-1, 1] to [0, 1], clamp for numerical stability
    score = (cos_sim + 1.0) / 2.0
    score = max(0.0, min(1.0, score))

    return score


In [12]:
# quick sanity check

# Pick two random IDs
import random
a, b = random.sample(case_ids, 2)
print(a, b, calculate_cases_similarity(a, b))

# Same ID should be 1.0
print(a, a, calculate_cases_similarity(a, a))


i1EyBIkBRpLueGJZMLf9 deRdEYcBD5gMZwczFnnT 0.80626180768013
i1EyBIkBRpLueGJZMLf9 i1EyBIkBRpLueGJZMLf9 1.0
