In [2]:
from datasets import load_dataset

In [3]:
# Load the datasets
# ds1_train = load_dataset("joelniklaus/legal_case_document_summarization", split='train')
# ds1_train = ds1_train.remove_columns(['dataset_name'])
# ds1_train = ds1_train.rename_column('judgement', 'text')
# ds1_train = ds1_train.rename_column('summary', 'labels')
# print(ds1_train)

# ds1_test = load_dataset("joelniklaus/legal_case_document_summarization", split='test')
# ds1_test = ds1_test.remove_columns(['dataset_name'])
# ds1_test = ds1_test.rename_column('judgement', 'text')
# ds1_test = ds1_test.rename_column('summary', 'labels')

# # NOTE: This dataset only has 50 rows. It may not be a dataset we want to use.
# # NOTE: THIS DATA IS NOT PLAYING NICELY WITH CONCATENATION
# # Although the summaries appear to be good
# ds2 = load_dataset("manasvikalyan/legal-documents-summary")
# ds2 = ds2['data']
# ds2 = ds2.remove_columns(['summary_a2'])
# ds2 = ds2.rename_column('summary_a1', 'labels')
# ds2 = ds2.rename_column('judgement', 'text')
# print(ds2)

In [4]:
# ds9: AjayMukundS/Legal_Text_Summarization-llama2
ds9_train = load_dataset("AjayMukundS/Legal_Text_Summarization-llama2", split='train')
ds9_test = load_dataset("AjayMukundS/Legal_Text_Summarization-llama2", split='test')
print(ds9_train)

Dataset({
    features: ['judgement', 'dataset_name', 'summary', 'text'],
    num_rows: 7773
})


In [5]:
from transformers import BartTokenizer

In [6]:
# Load the BART tokenizer
tokenizer = BartTokenizer.from_pretrained('sshleifer/distilbart-cnn-12-6')



In [7]:
from transformers import AutoTokenizer

In [8]:
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")

In [9]:
# Tokenize the input texts
def preprocess_data(batch):
    inputs = batch['text']
    targets = batch['summary']
    
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length")

    model_inputs['labels'] = labels['input_ids']
    return model_inputs

In [10]:
# Tokenize the datasets for DistilBART
# Apply preprocessing to the dataset
train_dataset = ds9_train.map(preprocess_data, batched=True, remove_columns=ds9_train.column_names)
test_dataset = ds9_test.map(preprocess_data, batched=True, remove_columns=ds9_test.column_names)

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [11]:
# Set dataset format for PyTorch
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

In [12]:
# Taking 10 examples from each tokenized set to test the training.
train_sample = train_dataset.select(range(30))
test_sample = test_dataset.select(range(10))

print(train_sample)
print(test_sample)


Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 30
})
Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 10
})


In [13]:
from datasets import concatenate_datasets

In [14]:
# Maybe do later?

In [15]:
from transformers import BartForConditionalGeneration

In [16]:
# Load the DistilBART model for conditional generation
model = BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')

In [17]:
# To handle padding dynamically (i.e., pad to the longest sequence in a batch rather than a fixed length)
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [18]:
from transformers import TrainingArguments, Trainer, logging

In [19]:
import torch
from accelerate import Accelerator

accelerator = Accelerator()
device = accelerator.device
print(f"Accelerator is using device: {device}")

Accelerator is using device: cpu


In [20]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    logging_dir="./logs",
    logging_steps=100,
)



In [21]:
# Initialize the Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_sample,
    eval_dataset=test_sample,
    tokenizer=tokenizer
)

In [22]:
# Train the model
trainer.train()

Epoch,Training Loss,Validation Loss
1,No log,2.389033
2,No log,2.340875
3,No log,2.342005




TrainOutput(global_step=24, training_loss=2.1253064473470054, metrics={'train_runtime': 339.6156, 'train_samples_per_second': 0.265, 'train_steps_per_second': 0.071, 'total_flos': 139312087695360.0, 'train_loss': 2.1253064473470054, 'epoch': 3.0})

In [23]:
# Evaluate the model
eval_results = trainer.evaluate()
print(eval_results)

{'eval_loss': 2.3420045375823975, 'eval_runtime': 9.9767, 'eval_samples_per_second': 1.002, 'eval_steps_per_second': 0.301, 'epoch': 3.0}


In [None]:
import random

# Get a random index
random_index = random.randint(0, len(ds9_test) - 1)

# Access the 'text' property of the randomly selected sample
random_sample_text = ds9_test[random_index]['text']
print(random_sample_text)

In [27]:
# Example of generating a summary

inputs = [random_sample_text]

# Tokenize the input
input_ids = tokenizer(inputs, return_tensors="pt", max_length=1024, truncation=True, padding="max_length").input_ids

# Generate summary
outputs = model.generate(input_ids=input_ids, max_length=128, num_beams=4, early_stopping=True)
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(summary)



The plaintiffs applied for an interim injunction restraining the Bank of India Ltd. the first defendant from taking any steps in pursuance of a letter of credit opened in favour of M/s. V/O Tractors Export, Moscow, the second defendant.
The High Court observed that an order granting interim injunction "is a final order, as far as this Court is concerned"


In [None]:
# Save your model
model.save_pretrained('./trained_HX100_model_2')

# Save the tokenizer
tokenizer.save_pretrained('./trained_HX100_model_2')

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the trained model and tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained('./trained_HX100_model_2')
tokenizer = AutoTokenizer.from_pretrained('./trained_HX100_model_2')

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [28]:
print(ds9_test)

Dataset({
    features: ['judgement', 'dataset_name', 'summary', 'text'],
    num_rows: 200
})


In [29]:
print(ds9_test.column_names)

['judgement', 'dataset_name', 'summary', 'text']


In [35]:
# Randomly select a sample from the test dataset for comparison
sample_index = random.randint(0, len(ds9_test) - 1)
sample = ds9_test[sample_index]

# Original text and summary
original_text = sample['text']
original_summary = sample['summary']

# Tokenize the input
input_ids = tokenizer([original_text], return_tensors="pt", max_length=1024, truncation=True, padding="max_length").input_ids

# Generate summary with the model
outputs = model.generate(input_ids=input_ids, max_length=1024, num_beams=4, early_stopping=True)
generated_summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Print both summaries for comparison
print("------------------------------------------------------------------------------------------")
print(f"Example {idx + 1}:")
print("Original Summary:", original_summary)
print("*****************************************************************************************")
print("Generated Summary:", generated_summary)
print("\n")

------------------------------------------------------------------------------------------
Example 190:
Original Summary: Section 28 of the Housing Act 1988 The Housing Act 1988 (the 1988 Act) was brought in with a view to stimulating the availability of rented accommodation in the private sector; it allowed landlords to let new tenancies on terms more advantageous to themselves [4].
Parliament included safeguards to deter unscrupulous landlords from evicting existing tenants with protected tenancies [15].
Section 27 of the 1988 Act provides the right to claim damages for unlawful eviction [6].
Section 28 sets out the method by which such damages are calculated, being the difference in value between two alternative calculations of the landlords interest in the building at the time immediately prior to the unlawful eviction: (1) The basis for the assessment of damages [for unlawful eviction] is the difference in value, determined as at the time immediately before the residential occupie