# Possible models to use

## DistilBART - distilled version of BART, which is much smaller than the full BART model but retains much of its performance. Since it is distilled, it's faster and more efficient while still being well-suited for summarization tasks. DistilBART is designed for text summarization, and the cnn-12-6 variant is trained on news articles, making it a viable medium sized model for summarizing legal documents.

## T5 (Text-to-Text Transfer Transformer) - Small or Base - T5 treats every task as a text-to-text problem, making it very flexible for summarization. The small and base variants offer a middle ground between performance and model size, making them suitable for use cases where computational resources are limited.

In [3]:
from datasets import load_dataset

### Here I load the datasets and edit some of the columns prior to tokenizing the datasets

In [7]:
# Load the datasets
ds1_train = load_dataset("joelniklaus/legal_case_document_summarization", split='train')
ds1_train = ds1_train.remove_columns(['dataset_name'])
ds1_train = ds1_train.rename_column('judgement', 'text')
ds1_train = ds1_train.rename_column('summary', 'label')
print(ds1_train)

# NOTE: This dataset only has 50 rows. It may not be a dataset we want to use.
# Although the summaries appear to be good
ds2_DatasetDict = load_dataset("manasvikalyan/legal-documents-summary")
ds2_actual = ds2_DatasetDict['data']
ds2_actual = ds2_actual.remove_columns(['summary_a2'])
ds2_actual = ds2_actual.rename_column('summary_a1', 'label')
ds2_actual = ds2_actual.rename_column('judgement', 'text')
print(ds2_actual)

ds3_train = load_dataset("coastalcph/lex_glue", "case_hold", split='train')
print(ds3_train)
ds4_train = load_dataset("coastalcph/lex_glue", "ecthr_a", split='train')
print(ds4_train)
ds5_train = load_dataset("coastalcph/lex_glue", "ecthr_b", split='train')
print(ds5_train)
ds6_train = load_dataset("coastalcph/lex_glue", "eurlex", split='train')
print(ds6_train)
ds7_train = load_dataset("coastalcph/lex_glue", "ledgar", split='train')
print(ds7_train)
ds8_train = load_dataset("coastalcph/lex_glue", "scotus", split='train')
print(ds8_train)


Repo card metadata block was not found. Setting CardData to empty.


Dataset({
    features: ['text', 'label'],
    num_rows: 7773
})
Dataset({
    features: ['text', 'label'],
    num_rows: 50
})
Dataset({
    features: ['context', 'endings', 'label'],
    num_rows: 45000
})
Dataset({
    features: ['text', 'labels'],
    num_rows: 9000
})
Dataset({
    features: ['text', 'labels'],
    num_rows: 9000
})
Dataset({
    features: ['text', 'labels'],
    num_rows: 55000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 60000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 5000
})


### Here I am pre-processing the data for the DistilBART model

In [12]:
from transformers import BartTokenizer

In [14]:
# Load the BART tokenizer
tokenizer = BartTokenizer.from_pretrained('sshleifer/distilbart-cnn-12-6')

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.80k [00:00<?, ?B/s]



In [15]:
# Tokenization function for text and summaries
def tokenize_function(examples):
    # Tokenize both the input text and summary (label)
    model_inputs = tokenizer(examples['text'], max_length=512, truncation=True, padding='max_length')
    
    # Tokenize the labels (summaries)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples['label'], max_length=150, truncation=True, padding='max_length')
    
    model_inputs["labels"] = labels["input_ids"]
    
    return model_inputs

### Here I am just Tokenizing 'ds1' and 'ds2' for DistilBART (ds1_train and ds2_actual)

### TODO: Tokenize the other datasets later

In [16]:
# Tokenize the datasets for DistilBART (ds1_train and ds2_actual)
ds1_train_tokenized = ds1_train.map(tokenize_function, batched=True)
ds2_actual_tokenized = ds2_actual.map(tokenize_function, batched=True)

Map:   0%|          | 0/7773 [00:00<?, ? examples/s]



Map:   0%|          | 0/50 [00:00<?, ? examples/s]

### Here I am concatenating the datasets to use all together

### TODO: concatenate the rest of the datasets later

In [20]:
from datasets import concatenate_datasets

In [23]:
# Concatenate/Merge the datasets
combined_dataset = concatenate_datasets([ds1_train_tokenized, ds2_actual_tokenized])

### Splitting the combined dataset into train and validation sets

### TODO: concatenate the rest of the datasets later

In [24]:
combined_dataset = combined_dataset.train_test_split(test_size=0.2)

train_dataset = combined_dataset['train']he 
val_dataset = combined_dataset['test']

### Load the DistilBART model here

In [26]:
from transformers import BartForConditionalGeneration

In [27]:
# Load the DistilBART model for conditional generation
model = BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')

pytorch_model.bin:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

### Setting up training arguments for the model here

### TODO: These can be modified later to improve the model

In [28]:
from transformers import TrainingArguments, Trainer

In [30]:
# Set up training arguments
training_args = TrainingArguments(
    output_dir='./results',            # output directory
    eval_strategy="epoch",       # evaluate at each epoch
    learning_rate=5e-5,                # learning rate
    per_device_train_batch_size=4,     # batch size for training
    per_device_eval_batch_size=4,      # batch size for evaluation
    num_train_epochs=3,                # number of training epochs
    weight_decay=0.01,                 # strength of weight decay
    save_total_limit=2,                # only keep last 2 checkpoints
)

In [31]:
# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

### Training the model here

# Possible models to use

## DistilBART - distilled version of BART, which is much smaller than the full BART model but retains much of its performance. Since it is distilled, it's faster and more efficient while still being well-suited for summarization tasks. DistilBART is designed for text summarization, and the cnn-12-6 variant is trained on news articles, making it a viable medium sized model for summarizing legal documents.

## T5 (Text-to-Text Transfer Transformer) - Small or Base - T5 treats every task as a text-to-text problem, making it very flexible for summarization. The small and base variants offer a middle ground between performance and model size, making them suitable for use cases where computational resources are limited.

In [3]:
from datasets import load_dataset

### Here I load the datasets and edit some of the columns prior to tokenizing the datasets

In [7]:
# Load the datasets
ds1_train = load_dataset("joelniklaus/legal_case_document_summarization", split='train')
ds1_train = ds1_train.remove_columns(['dataset_name'])
ds1_train = ds1_train.rename_column('judgement', 'text')
ds1_train = ds1_train.rename_column('summary', 'label')
print(ds1_train)

# NOTE: This dataset only has 50 rows. It may not be a dataset we want to use.
# Although the summaries appear to be good
ds2_DatasetDict = load_dataset("manasvikalyan/legal-documents-summary")
ds2_actual = ds2_DatasetDict['data']
ds2_actual = ds2_actual.remove_columns(['summary_a2'])
ds2_actual = ds2_actual.rename_column('summary_a1', 'label')
ds2_actual = ds2_actual.rename_column('judgement', 'text')
print(ds2_actual)

ds3_train = load_dataset("coastalcph/lex_glue", "case_hold", split='train')
print(ds3_train)
ds4_train = load_dataset("coastalcph/lex_glue", "ecthr_a", split='train')
print(ds4_train)
ds5_train = load_dataset("coastalcph/lex_glue", "ecthr_b", split='train')
print(ds5_train)
ds6_train = load_dataset("coastalcph/lex_glue", "eurlex", split='train')
print(ds6_train)
ds7_train = load_dataset("coastalcph/lex_glue", "ledgar", split='train')
print(ds7_train)
ds8_train = load_dataset("coastalcph/lex_glue", "scotus", split='train')
print(ds8_train)


Repo card metadata block was not found. Setting CardData to empty.


Dataset({
    features: ['text', 'label'],
    num_rows: 7773
})
Dataset({
    features: ['text', 'label'],
    num_rows: 50
})
Dataset({
    features: ['context', 'endings', 'label'],
    num_rows: 45000
})
Dataset({
    features: ['text', 'labels'],
    num_rows: 9000
})
Dataset({
    features: ['text', 'labels'],
    num_rows: 9000
})
Dataset({
    features: ['text', 'labels'],
    num_rows: 55000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 60000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 5000
})


### Here I am pre-processing the data for the DistilBART model

In [12]:
from transformers import BartTokenizer

In [14]:
# Load the BART tokenizer
tokenizer = BartTokenizer.from_pretrained('sshleifer/distilbart-cnn-12-6')

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.80k [00:00<?, ?B/s]



In [15]:
# Tokenization function for text and summaries
def tokenize_function(examples):
    # Tokenize both the input text and summary (label)
    model_inputs = tokenizer(examples['text'], max_length=512, truncation=True, padding='max_length')
    
    # Tokenize the labels (summaries)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples['label'], max_length=150, truncation=True, padding='max_length')
    
    model_inputs["labels"] = labels["input_ids"]
    
    return model_inputs

### Here I am just Tokenizing 'ds1' and 'ds2' for DistilBART (ds1_train and ds2_actual)

### TODO: Tokenize the other datasets later

In [16]:
# Tokenize the datasets for DistilBART (ds1_train and ds2_actual)
ds1_train_tokenized = ds1_train.map(tokenize_function, batched=True)
ds2_actual_tokenized = ds2_actual.map(tokenize_function, batched=True)

Map:   0%|          | 0/7773 [00:00<?, ? examples/s]



Map:   0%|          | 0/50 [00:00<?, ? examples/s]

### Here I am concatenating the datasets to use all together

### TODO: concatenate the rest of the datasets later

In [20]:
from datasets import concatenate_datasets

In [23]:
# Concatenate/Merge the datasets
combined_dataset = concatenate_datasets([ds1_train_tokenized, ds2_actual_tokenized])

### Splitting the combined dataset into train and validation sets

### TODO: concatenate the rest of the datasets later

In [24]:
combined_dataset = combined_dataset.train_test_split(test_size=0.2)

train_dataset = combined_dataset['train']he 
val_dataset = combined_dataset['test']

### Load the DistilBART model here

In [26]:
from transformers import BartForConditionalGeneration

In [27]:
# Load the DistilBART model for conditional generation
model = BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')

pytorch_model.bin:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

### Setting up training arguments for the model here

### TODO: These can be modified later to improve the model

In [28]:
from transformers import TrainingArguments, Trainer

In [30]:
# Set up training arguments
training_args = TrainingArguments(
    output_dir='./results',            # output directory
    eval_strategy="epoch",       # evaluate at each epoch
    learning_rate=5e-5,                # learning rate
    per_device_train_batch_size=4,     # batch size for training
    per_device_eval_batch_size=4,      # batch size for evaluation
    num_train_epochs=3,                # number of training epochs
    weight_decay=0.01,                 # strength of weight decay
    save_total_limit=2,                # only keep last 2 checkpoints
)

In [31]:
# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

### Training the model here

# START HERE. HAVING ISSUES WITH RUNNING THIS RN

In [32]:
# Train the model
trainer.train()

ValueError: too many dimensions 'str'

In [35]:
print(combined_dataset['train'][0])  # Print the first training example to inspect

{'text': 'Appeals Nos. 934935 of 1963.\nAppeals from the judgment and orders dated August 12, 1960, and April 30, 1960, of the Madhya Pradesh High Court in Civil Suit No. 1 of 1958 and Misc.Petition No. 101 of 1958 respectively.\nC.K. Daphtary, Attorney General, R. Ganapathy Iyer and R. H. Dhebar, for the appellants (in both the appeals).\nM.C. Setalvad, K. A. Chitale, M. K. Nambyar.\nRameshwar Nath and section N. Andley, for the respondents (in both the appeals).\nApril 28, 1964.\nThe judgment of the Court was delivered by WANCHOO, J.\nThese two appeals on certificates granted by the Madhya Pradesh High Court raise common questions of law and will be dealt with together.\nThe respondent the Gwalior Rayon Silk Manufacturing (Weaving) Company Limited (hereinafter referred to as the company) is registered under the Indian Companies Act.\nIt is necessary to set out how the company came to be established in order to understand the case put forward by the company.\nIn October 1946 Messrs. B