In [58]:
import numpy as np
from datasets import load_dataset, Dataset
import pandas as pd
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

In [26]:
# Load the Multi-LexSum corpus (config "v20230518"); files are cached under ./data.
dataset = load_dataset(
    "allenai/multi_lexsum",
    "v20230518",
    cache_dir="./data",
)
# Show the available splits, their features and their row counts.
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'sources', 'sources_metadata', 'summary/long', 'summary/short', 'summary/tiny', 'case_metadata'],
        num_rows: 3177
    })
    validation: Dataset({
        features: ['id', 'sources', 'sources_metadata', 'summary/long', 'summary/short', 'summary/tiny', 'case_metadata'],
        num_rows: 454
    })
    test: Dataset({
        features: ['id', 'sources', 'sources_metadata', 'summary/long', 'summary/short', 'summary/tiny', 'case_metadata'],
        num_rows: 908
    })
})


In [45]:
# Materialise each split of `dataset` as its own pandas DataFrame.
train_df, validation_df, test_df = (
    pd.DataFrame(dataset[split_name])
    for split_name in ("train", "validation", "test")
)

In [46]:
# Stack the three split frames into one; `keys` adds an outer index level that
# records which split each row came from, and the per-split row index is kept.
df = pd.concat(
    [train_df, validation_df, test_df],
    keys=["train", "validation", "test"],
    ignore_index=False,
)

In [47]:
# Preview the first five rows of the combined frame (5 is `head`'s default).
df.head()

Unnamed: 0,Unnamed: 1,id,sources,sources_metadata,summary/long,summary/short,summary/tiny,case_metadata
train,0,EE-AL-0045,[Case 1:05-cv-00530-D Document 1-1 Filed 09/19...,"{'doc_id': ['EE-AL-0045-0001', 'EE-AL-0045-000...","On September 15, 2005, the Equal Employment Op...",Equal Employment Opportunity Commission brough...,,{'case_name': 'EEOC v. House of Philadelphia C...
train,1,PB-NJ-0003,[Case 3:05-cv-01784-SRC-JJH Document 2 Filed 0...,"{'doc_id': ['PB-NJ-0003-0001', 'PB-NJ-0003-000...",NOTE: This is one of three identically named ...,The case was brought by a non-profit organizat...,,{'case_name': 'Disability Rights New Jersey v....
train,2,EE-FL-0136,[Case 9:07-cv-80713-KAM Document 1 Entered on ...,"{'doc_id': ['EE-FL-0136-0001', 'EE-FL-0136-000...","On August 9, 2007, the United States Departmen...",,,{'case_name': 'United States v. Palm Beach Cou...
train,3,EE-CA-0305,[2006 WL 1787244\n2006 WL 1787244 (N.D.Cal.) (...,"{'doc_id': ['EE-CA-0305-0001', 'EE-CA-0305-000...","On May 11, 2006, African-American employees of...",This case was brought by African American empl...,,{'case_name': 'Wynne v. McCormick & Schmick's ...
train,4,NH-NJ-0002,[IN THE UNITED STATES DISTRICT COURT FOR THE D...,"{'doc_id': ['NH-NJ-0002-0001', 'NH-NJ-0002-000...",Pursuant to the Civil Rights of Institutionali...,Pursuant to the Civil Rights of Institutionali...,,"{'case_name': 'U.S. v. Mercer County, New Jers..."


In [53]:
# Number of missing values in each column of the combined frame.
df.isnull().sum()

id                     0
sources                0
sources_metadata       0
summary/long           0
summary/short       1401
summary/tiny        2936
case_metadata          0
summary                0
dtype: int64

In [55]:
# Model input: every source text of a case joined into one space-separated string.
df['document'] = df['sources'].map(' '.join)
# Training target: the long-form summary.
df['summary'] = df['summary/long']
# Preview the two columns the model will consume.
df[['document', 'summary']].head()

Unnamed: 0,Unnamed: 1,document,summary
train,0,Case 1:05-cv-00530-D Document 1-1 Filed 09/19/...,"On September 15, 2005, the Equal Employment Op..."
train,1,Case 3:05-cv-01784-SRC-JJH Document 2 Filed 05...,NOTE: This is one of three identically named ...
train,2,Case 9:07-cv-80713-KAM Document 1 Entered on F...,"On August 9, 2007, the United States Departmen..."
train,3,2006 WL 1787244\n2006 WL 1787244 (N.D.Cal.) (T...,"On May 11, 2006, African-American employees of..."
train,4,IN THE UNITED STATES DISTRICT COURT FOR THE DI...,Pursuant to the Civil Rights of Institutionali...


In [30]:
# One checkpoint name drives both the tokenizer and the TensorFlow seq2seq model;
# both cache their downloaded files under ../models.
checkpoint = "t5-large"
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    cache_dir="../models",
)
model = TFAutoModelForSeq2SeqLM.from_pretrained(
    checkpoint,
    cache_dir="../models",
)

All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.

All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


In [59]:
# Rebuild one `datasets.Dataset` per split from the combined frame.
# `df.xs(split)` drops the outer split level but leaves the per-split integer
# index behind. That index is not a RangeIndex, so `Dataset.from_pandas` would
# by default store it as an extra `__index_level_0__` column;
# `preserve_index=False` keeps the datasets limited to the real data columns.
train_dataset = Dataset.from_pandas(df.xs('train'), preserve_index=False)
val_dataset = Dataset.from_pandas(df.xs('validation'), preserve_index=False)
test_dataset = Dataset.from_pandas(df.xs('test'), preserve_index=False)

In [61]:
def preprocess_function(train, max_input_length=512, max_target_length=128, prefix="summarize: "):
    """Tokenize a batch of documents and summaries for seq2seq training.

    Intended to be used with ``Dataset.map(..., batched=True)``. Relies on the
    notebook-level ``tokenizer`` created in an earlier cell.

    Args:
        train: Batch mapping with a ``'document'`` column (source texts) and a
            ``'summary'`` column (target texts), each a sequence of strings.
        max_input_length: Documents are truncated/padded to this many tokens.
        max_target_length: Summaries are truncated/padded to this many tokens.
        prefix: Task prefix prepended to every document. T5 checkpoints expect
            ``"summarize: "`` for summarization; pass ``""`` to disable.

    Returns:
        The tokenizer output for the documents (NumPy arrays) with an added
        ``'labels'`` entry: the summary token ids, where padding positions are
        replaced by -100.
    """
    inputs = [prefix + doc for doc in train['document']]
    targets = list(train['summary'])

    # Pad inputs as well as truncating them: without padding, a batch holding
    # any document shorter than the limit is ragged and cannot become a tensor.
    model_inputs = tokenizer(
        inputs,
        max_length=max_input_length,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    )
    labels = tokenizer(
        text_target=targets,
        max_length=max_target_length,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    )

    # Replace pad ids in the labels with -100 (the Hugging Face "ignore this
    # position" label value) so the loss is not computed on padding.
    label_ids = np.asarray(labels['input_ids'])
    model_inputs['labels'] = np.where(label_ids == tokenizer.pad_token_id, -100, label_ids)
    return model_inputs

# Tokenize every split in batches with `preprocess_function`.
tokenized_train_dataset = train_dataset.map(function=preprocess_function, batched=True)
tokenized_val_dataset = val_dataset.map(function=preprocess_function, batched=True)
tokenized_test_dataset = test_dataset.map(function=preprocess_function, batched=True)

Map:   0%|          | 0/3177 [00:00<?, ? examples/s]