# Possible models to use

## DistilBART - a distilled version of BART that is much smaller than the full BART model but retains much of its performance. Because it is distilled, it is faster and more efficient while remaining well suited to summarization tasks. DistilBART is designed for text summarization; the cnn-12-6 variant was trained on news articles rather than legal text, so it will likely need fine-tuning on legal data, but it is a viable medium-sized model for summarizing legal documents.

## T5 (Text-to-Text Transfer Transformer) - Small or Base - T5 treats every task as a text-to-text problem, making it very flexible for summarization. The small and base variants offer a middle ground between performance and model size, making them suitable for use cases where computational resources are limited.

In [23]:
from datasets import load_dataset

### Here I load the datasets and edit some of the columns prior to tokenizing the datasets

In [24]:
# --- Summarization corpora: both are normalised to ('text', 'label') columns ---

# ds1: 'judgement' becomes the input text, 'summary' becomes the target label;
# the 'dataset_name' column is dropped.
ds1_train = (
    load_dataset("joelniklaus/legal_case_document_summarization", split='train')
    .remove_columns(['dataset_name'])
    .rename_column('judgement', 'text')
    .rename_column('summary', 'label')
)
print(ds1_train)

# ds2: 'judgement' becomes the input text and 'summary_a1' the target label;
# the alternative summary 'summary_a2' is dropped.
# NOTE: This dataset only has 50 rows, so it may not be worth using.
# NOTE: This data is currently not playing nicely with concatenation,
#       although the summaries themselves appear to be good.
ds2_DatasetDict = load_dataset("manasvikalyan/legal-documents-summary")
ds2_actual = (
    ds2_DatasetDict['data']
    .remove_columns(['summary_a2'])
    .rename_column('summary_a1', 'label')
    .rename_column('judgement', 'text')
)
print(ds2_actual)

# --- LexGLUE train splits, loaded as-is (no column edits) ---
# NOTE(review): the printed features for these tasks are 'label'/'labels'
# columns with no summary field, so confirm they are usable as summarization
# data before tokenizing them with the same pipeline.
lex_glue_train_splits = []
for lex_glue_config in ("case_hold", "ecthr_a", "ecthr_b", "eurlex", "ledgar", "scotus"):
    lex_glue_split = load_dataset("coastalcph/lex_glue", lex_glue_config, split='train')
    print(lex_glue_split)
    lex_glue_train_splits.append(lex_glue_split)

# Keep the original per-dataset variable names (same order as the configs above).
(ds3_train, ds4_train, ds5_train,
 ds6_train, ds7_train, ds8_train) = lex_glue_train_splits


Repo card metadata block was not found. Setting CardData to empty.


Dataset({
    features: ['text', 'label'],
    num_rows: 7773
})
Dataset({
    features: ['text', 'label'],
    num_rows: 50
})
Dataset({
    features: ['context', 'endings', 'label'],
    num_rows: 45000
})
Dataset({
    features: ['text', 'labels'],
    num_rows: 9000
})
Dataset({
    features: ['text', 'labels'],
    num_rows: 9000
})
Dataset({
    features: ['text', 'labels'],
    num_rows: 55000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 60000
})
Dataset({
    features: ['text', 'label'],
    num_rows: 5000
})


### Here I am pre-processing the data for the DistilBART model

In [25]:
from transformers import BartTokenizer

In [26]:
# Fetch the BART tokenizer published with the DistilBART cnn-12-6 checkpoint
tokenizer_checkpoint = 'sshleifer/distilbart-cnn-12-6'
tokenizer = BartTokenizer.from_pretrained(tokenizer_checkpoint)



In [27]:
# Tokenization function for text and summaries
def tokenize_function(examples, max_input_length=512, max_target_length=150):
    """Tokenize source documents and their reference summaries.

    Uses the module-level ``tokenizer``.

    Args:
        examples: Mapping with a 'text' entry (source documents) and a
            'label' entry (reference summaries). With
            ``Dataset.map(..., batched=True)`` both entries are lists.
        max_input_length: Length every tokenized document is truncated or
            padded to.
        max_target_length: Length every tokenized summary is truncated or
            padded to.

    Returns:
        The tokenizer output for the documents, with an added 'labels' entry
        holding the summary token ids. Padding positions in 'labels' are set
        to -100 so the loss ignores them.
    """
    # Tokenize the input text
    inputs = tokenizer(
        examples['text'],
        max_length=max_input_length,
        truncation=True,
        padding='max_length',
    )

    # Tokenize the output summaries as targets. `text_target=` replaces the
    # deprecated `with tokenizer.as_target_tokenizer():` context manager.
    targets = tokenizer(
        text_target=examples['label'],
        max_length=max_target_length,
        truncation=True,
        padding='max_length',
    )

    pad_id = tokenizer.pad_token_id

    def _mask_padding(token_ids):
        # -100 is the index the cross-entropy loss skips; without this the
        # model is also trained to predict the padding after each summary.
        return [tok if tok != pad_id else -100 for tok in token_ids]

    target_ids = targets['input_ids']
    if target_ids and isinstance(target_ids[0], (list, tuple)):
        # Batched call: one list of token ids per summary.
        masked_ids = [_mask_padding(seq) for seq in target_ids]
    else:
        # Single example: a flat list of token ids.
        masked_ids = _mask_padding(target_ids)

    # Set the tokenized labels in the input dictionary
    inputs['labels'] = masked_ids

    return inputs

### Here I am tokenizing only 'ds1' (ds1_train) for DistilBART; 'ds2' (ds2_actual) is skipped for now because tokenizing it is not working yet

### TODO: Tokenize the other datasets later

In [28]:
# Run tokenize_function over ds1_train in batches; map() returns a new dataset
# that carries the tokenizer outputs alongside the original columns.
ds1_train_tokenized = ds1_train.map(function=tokenize_function, batched=True)
# ds2_actual is skipped for now (it is "not playing nicely" yet):
# ds2_actual_tokenized = ds2_actual.map(tokenize_function, batched=True)

### TODO: set the other dataset formats later:

In [29]:
# Return the model-facing columns as PyTorch tensors.
# The dataset is printed before and after: set_format() changes how rows are
# returned, not which features are stored, so both printouts list the same
# features.
tensor_columns = ['input_ids', 'attention_mask', 'labels']
print(ds1_train_tokenized)
ds1_train_tokenized.set_format(type='torch', columns=tensor_columns)
print(ds1_train_tokenized)

Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 7773
})
Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 7773
})


### Here I will concatenate the datasets so they can be used together (currently disabled, so only ds1 is used)

### TODO: concatenate the rest of the datasets later

In [30]:
from datasets import concatenate_datasets

In [31]:
# Concatenate/merge the datasets (disabled until ds2_actual can be tokenized)
# FIX LATER --> combined_dataset = concatenate_datasets([ds1_train_tokenized, ds2_actual_tokenized])

### Splitting the dataset (currently only ds1, not the combined dataset) into train and validation sets

### TODO: concatenate the rest of the datasets later

In [32]:
# FIX LATER --> combined_dataset = combined_dataset.train_test_split(test_size=0.2)
# Hold out 20% of ds1 for validation. The fixed seed makes the split identical
# on every run, so validation losses are comparable between runs.
# NOTE(review): this rebinds ds1_train_tokenized from the tokenized dataset to
# the split result indexed by 'train'/'test', so re-running this cell alone
# will presumably fail; re-run the tokenization cell first.
ds1_train_tokenized = ds1_train_tokenized.train_test_split(test_size=0.2, seed=42)
train_dataset = ds1_train_tokenized['train']
val_dataset = ds1_train_tokenized['test']

### Load the DistilBART model here

In [33]:
from transformers import BartForConditionalGeneration

In [34]:
# Fetch the pretrained DistilBART weights with the conditional-generation head
model_checkpoint = 'sshleifer/distilbart-cnn-12-6'
model = BartForConditionalGeneration.from_pretrained(model_checkpoint)

### Setting up training arguments for the model here

### TODO: These can be modified later to improve the model

In [35]:
from transformers import TrainingArguments, Trainer

In [36]:
# Hyperparameters for the fine-tuning run, gathered in one mapping so they are
# easy to tweak, then expanded into TrainingArguments.
training_config = dict(
    # Where results are written
    output_dir='./results',
    # Run evaluation at the end of every epoch
    eval_strategy="epoch",
    # Optimisation settings
    learning_rate=5e-5,
    weight_decay=0.01,
    num_train_epochs=3,
    # Per-device batch sizes for training and evaluation
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Keep at most the 2 most recent checkpoints
    save_total_limit=2,
)
training_args = TrainingArguments(**training_config)

In [37]:
# Wire the model, the hyperparameters and both data splits into a Trainer
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

### Training the model here

In [None]:
# Train the model
# Runs the fine-tuning loop on `trainer` (built from `model`, `training_args`
# and the train/validation splits). The progress table printed by this cell
# has Epoch, Training Loss and Validation Loss columns.
trainer.train()

Epoch,Training Loss,Validation Loss


{'text': 'Civil Appeal No. 1046 of 1982.\nFrom the Order dated 20.9.1980 of the Madhya Pradesh High Court in M.P. No. 84 of 1978.\nDr. N.M. Ghatate, S.V. Deshpande and S.K. Agnihotri for the Appellants.\nAman Vachher, S.K. mehta, Mrs. Anjali Verma, D.N. Mishra (for JBD & Co.) and Ashok Srivastava for the Respondents.\nThe Judgment of the Court was delivered by T.K. THOMMEN, J.\nThis appeal by the State of Madhya 176 Pradesh arises from the Order of the Madhya Pradesh High Court in Misc.\nPetition No.84 of 1978 quashing Order dated 1.10.1977 of the Additional Collector, Gwalior, whereby he initiated proceedings against the 3rd respondent, the Gwalior Dairy Limited (hereinafter called `the Company \') under section 182(2)(i) of the M.P. Land Revenue Code, 1959 (`the Code \').\nRespondent Nos.\n1,2 and 4 are shareholders of the third respondent.\nThe High Court by the impugned Order held that the Company was not a Government lessee within the meaning of section 181 [read with section 2(h)