In [1]:
from datasets import load_dataset

In [7]:
# Load the summarization corpora and normalise both to a common schema:
# the source document lives in 'text' and its summary in 'label'.
ds1_train = (
    load_dataset("joelniklaus/legal_case_document_summarization", split='train')
    .remove_columns(['dataset_name'])
    .rename_column('judgement', 'text')
    .rename_column('summary', 'label')
)
print(ds1_train)

# NOTE: this dataset has only 50 rows, so it may not be worth using,
# although its summaries look good. 'summary_a1' becomes the label and
# 'summary_a2' is dropped.
ds2_DatasetDict = load_dataset("manasvikalyan/legal-documents-summary")
ds2_actual = (
    ds2_DatasetDict['data']
    .remove_columns(['summary_a2'])
    .rename_column('summary_a1', 'label')
    .rename_column('judgement', 'text')
)
print(ds2_actual)

# LexGLUE tasks (train splits only). Their schemas differ from the text/label
# summarization format above: case_hold has context/endings/label, while
# ecthr_a, ecthr_b and eurlex use a plural 'labels' column.
def _load_lex_glue_train(config_name):
    """Load the 'train' split of one coastalcph/lex_glue config, print it, and return it."""
    train_split = load_dataset("coastalcph/lex_glue", config_name, split='train')
    print(train_split)
    return train_split


ds3_train = _load_lex_glue_train("case_hold")
ds4_train = _load_lex_glue_train("ecthr_a")
ds5_train = _load_lex_glue_train("ecthr_b")
ds6_train = _load_lex_glue_train("eurlex")
ds7_train = _load_lex_glue_train("ledgar")
ds8_train = _load_lex_glue_train("scotus")


Repo card metadata block was not found. Setting CardData to empty.


Dataset({
    features: ['text', 'label'],
    num_rows: 7773
})
Dataset({
    features: ['text', 'label'],
    num_rows: 50
})
Dataset({
    features: ['context', 'endings', 'label'],
    num_rows: 45000
})
Dataset({
    features: ['text', 'labels'],
    num_rows: 9000
})
Dataset({
    features: ['text', 'labels'],
    num_rows: 9000
})
Dataset({
    features: ['text', 'labels'],
    num_rows: 55000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 60000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 5000
})


In [20]:
from transformers import BertTokenizer

In [21]:
# Pretrained uncased BERT tokenizer; used by the tokenization functions below
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

In [22]:
# Tokenization of the model inputs (the legal case text) for the summarization task
def tokenize_function(examples, max_length=512):
    """Tokenize a batch of legal case texts taken from the 'text' column.

    Only 'text' is tokenized here; the summaries in 'label' are tokenized
    separately by tokenize_labels.

    Args:
        examples: a batch mapping (as passed by ``Dataset.map(batched=True)``)
            containing a 'text' column.
        max_length: every text is padded/truncated to this many tokens
            (default 512, the previously hard-coded value). Longer texts are
            truncated, so their tail is discarded.

    Returns:
        The tokenizer encoding for the batch of texts.
    """
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=max_length)

In [28]:
# Encode the case text of both summarization datasets (adds the tokenizer's
# output columns next to 'text' and 'label')
ds1_train_tokenized, ds2_actual_tokenized = (
    dataset.map(tokenize_function, batched=True)
    for dataset in (ds1_train, ds2_actual)
)

print(ds1_train_tokenized, ds2_actual_tokenized, sep='\n')

Dataset({
    features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 7773
})
Dataset({
    features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 50
})


In [24]:
# Summaries are tokenized on their own, with a shorter length budget than the case text
def tokenize_labels(examples):
    """Tokenize a batch of summaries taken from the 'label' column.

    Every summary is padded/truncated to 150 tokens (adjust as needed for
    summaries). Returns the full tokenizer encoding, not just the token ids.
    """
    summaries = examples['label']
    return tokenizer(summaries, truncation=True, padding="max_length", max_length=150)

In [29]:
# Apply the tokenization to the label column (summaries)
ds1_train_tokenized = ds1_train_tokenized.map(tokenize_labels, batched=True, remove_columns=['label'])
ds2_actual_tokenized = ds2_actual_tokenized.map(tokenize_labels, batched=True, remove_columns=['label'])

print(ds1_train_tokenized)
print(ds2_actual_tokenized)

Dataset({
    features: ['text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 7773
})
Dataset({
    features: ['text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 50
})


In [30]:
ds1_train_tokenized = ds1_train_tokenized.rename_column("input_ids", "labels")
ds2_actual_tokenized = ds2_actual_tokenized.rename_column("input_ids", "labels")

print(ds1_train_tokenized)
print(ds2_actual_tokenized)

Dataset({
    features: ['text', 'labels', 'token_type_ids', 'attention_mask'],
    num_rows: 7773
})
Dataset({
    features: ['text', 'labels', 'token_type_ids', 'attention_mask'],
    num_rows: 50
})
