In [1]:
import pandas as pd

dataset_dir = 'datasets'

# read the dataset
df = pd.read_csv(f'{dataset_dir}/podcast_with_summary_train.csv')
df.head()

Unnamed: 0,text,text_short,summary,summary2
0,The following is a conversation with Yosha Bac...,The following is a conversation with Yosha Bac...,Lex Friedman interviews Yosha Bach on intellig...,"Yosha Bach discusses intelligence, cognition, ..."
1,The following is a conversation with John Hopf...,The following is a conversation with John Hopf...,"John Hopfield, renowned physicist and biologis...","John Hopfield, a Princeton professor, applied ..."
2,The following is a conversation with Ilya Sots...,The following is a conversation with Ilya Sots...,A conversation with Ilya Sotskever on deep lea...,A conversation with Ilya Sotskever on deep lea...
3,The following is a conversation with Travis Ol...,The following is a conversation with Travis Ol...,Conversation with Travis Oliphant on his impac...,"Travis Oliphant's work with NumPy, SciPy, and ..."
4,"Well, the source of energy at the origin of li...","Well, the source of energy at the origin of li...",Nick Lane discusses life's energy origins from...,Nick Lane discusses the origin of life and ene...


In [2]:
# find the summary with the longest length
df['summary_length'] = df['summary'].apply(lambda x: len(x))
print(df['summary_length'].max())

# find the text with the longest length
df['text_length'] = df['text_short'].apply(lambda x: len(x))
print(df['text_length'].max())

175
1024


In [3]:
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments

dataset = Dataset.from_pandas(df)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
model_name = "t5-small"
# 60.5M params
tokenizer = AutoTokenizer.from_pretrained(model_name)

def preprocess_data(examples):
    # Tokenize inputs
    model_inputs = tokenizer(
        examples["text_short"], max_length=1024, truncation=True, padding="max_length"
    )
    # Tokenize targets
    labels = tokenizer(
        examples["summary"], max_length=200, truncation=True, padding="max_length"
    )
    
    # Replace padding token id's of the labels by -100
    labels["input_ids"] = [
        [(l if l != tokenizer.pad_token_id else -100) for l in label]
        for label in labels["input_ids"]
    ]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# Apply the preprocessing
tokenized_dataset = dataset.map(preprocess_data, batched=True)

Map: 100%|██████████| 255/255 [00:00<00:00, 2773.96 examples/s]


In [5]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

training_args = TrainingArguments(
    output_dir="./t5_output",
    evaluation_strategy="steps",
    eval_steps=500,
    logging_steps=100,
    learning_rate=2e-5,
    per_device_train_batch_size=4,  
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=100,  
    logging_dir='./logs', 
    report_to="none", 
)



In [6]:
from transformers import DataCollatorForSeq2Seq, Trainer
import numpy as np

data_collator = DataCollatorForSeq2Seq(
    tokenizer, model=model, label_pad_token_id=-100
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

  trainer = Trainer(


In [7]:
trainer.train()

  0%|          | 0/6400 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
  2%|▏         | 101/6400 [00:08<08:09, 12.86it/s]

{'loss': 3.0088, 'grad_norm': 9.208634376525879, 'learning_rate': 1.96875e-05, 'epoch': 1.56}


  3%|▎         | 201/6400 [00:15<07:58, 12.96it/s]

{'loss': 2.4699, 'grad_norm': 5.115659713745117, 'learning_rate': 1.9375e-05, 'epoch': 3.12}


  5%|▍         | 301/6400 [00:23<07:53, 12.87it/s]

{'loss': 2.326, 'grad_norm': 5.296049118041992, 'learning_rate': 1.9062500000000003e-05, 'epoch': 4.69}


  6%|▋         | 401/6400 [00:31<07:43, 12.94it/s]

{'loss': 2.2053, 'grad_norm': 5.426468372344971, 'learning_rate': 1.8750000000000002e-05, 'epoch': 6.25}


  8%|▊         | 500/6400 [00:38<07:38, 12.86it/s]

{'loss': 2.0841, 'grad_norm': 5.410431385040283, 'learning_rate': 1.84375e-05, 'epoch': 7.81}


                                                  
  8%|▊         | 500/6400 [00:40<07:38, 12.86it/s]

{'eval_loss': 1.7665706872940063, 'eval_runtime': 1.4885, 'eval_samples_per_second': 171.313, 'eval_steps_per_second': 42.996, 'epoch': 7.81}


  9%|▉         | 601/6400 [00:48<07:29, 12.89it/s]

{'loss': 2.0378, 'grad_norm': 4.19727897644043, 'learning_rate': 1.8125e-05, 'epoch': 9.38}


 11%|█         | 701/6400 [00:56<07:22, 12.89it/s]

{'loss': 1.9979, 'grad_norm': 4.812345027923584, 'learning_rate': 1.7812500000000003e-05, 'epoch': 10.94}


 13%|█▎        | 801/6400 [01:04<07:15, 12.85it/s]

{'loss': 1.9184, 'grad_norm': 4.945547580718994, 'learning_rate': 1.7500000000000002e-05, 'epoch': 12.5}


 14%|█▍        | 901/6400 [01:11<07:05, 12.93it/s]

{'loss': 1.8957, 'grad_norm': 5.673614501953125, 'learning_rate': 1.71875e-05, 'epoch': 14.06}


 16%|█▌        | 1000/6400 [01:19<07:15, 12.40it/s]

{'loss': 1.821, 'grad_norm': 5.465287208557129, 'learning_rate': 1.6875e-05, 'epoch': 15.62}


                                                   
 16%|█▌        | 1000/6400 [01:21<07:15, 12.40it/s]

{'eval_loss': 1.4676735401153564, 'eval_runtime': 1.5265, 'eval_samples_per_second': 167.049, 'eval_steps_per_second': 41.926, 'epoch': 15.62}


 17%|█▋        | 1101/6400 [01:29<06:52, 12.84it/s]

{'loss': 1.8089, 'grad_norm': 5.986147403717041, 'learning_rate': 1.6562500000000003e-05, 'epoch': 17.19}


 19%|█▉        | 1201/6400 [01:37<06:42, 12.92it/s]

{'loss': 1.7648, 'grad_norm': 4.055754661560059, 'learning_rate': 1.6250000000000002e-05, 'epoch': 18.75}


 20%|██        | 1301/6400 [01:45<06:34, 12.94it/s]

{'loss': 1.7132, 'grad_norm': 5.830413341522217, 'learning_rate': 1.59375e-05, 'epoch': 20.31}


 22%|██▏       | 1401/6400 [01:52<06:28, 12.86it/s]

{'loss': 1.6686, 'grad_norm': 3.733013391494751, 'learning_rate': 1.5625e-05, 'epoch': 21.88}


 23%|██▎       | 1500/6400 [02:00<06:18, 12.93it/s]

{'loss': 1.6355, 'grad_norm': 5.81425142288208, 'learning_rate': 1.5312500000000003e-05, 'epoch': 23.44}


                                                   
 23%|██▎       | 1500/6400 [02:01<06:18, 12.93it/s]

{'eval_loss': 1.2624425888061523, 'eval_runtime': 1.492, 'eval_samples_per_second': 170.912, 'eval_steps_per_second': 42.895, 'epoch': 23.44}


 25%|██▌       | 1601/6400 [02:10<06:05, 13.14it/s]

{'loss': 1.6024, 'grad_norm': 11.128982543945312, 'learning_rate': 1.5000000000000002e-05, 'epoch': 25.0}


 27%|██▋       | 1701/6400 [02:17<06:06, 12.82it/s]

{'loss': 1.5567, 'grad_norm': 4.157416820526123, 'learning_rate': 1.4687500000000001e-05, 'epoch': 26.56}


 28%|██▊       | 1801/6400 [02:25<05:54, 12.98it/s]

{'loss': 1.5702, 'grad_norm': 3.7350826263427734, 'learning_rate': 1.4375e-05, 'epoch': 28.12}


 30%|██▉       | 1901/6400 [02:33<05:50, 12.84it/s]

{'loss': 1.5108, 'grad_norm': 4.509944915771484, 'learning_rate': 1.4062500000000001e-05, 'epoch': 29.69}


 31%|███▏      | 2000/6400 [02:41<05:42, 12.85it/s]

{'loss': 1.4955, 'grad_norm': 5.321959972381592, 'learning_rate': 1.375e-05, 'epoch': 31.25}


                                                   
 31%|███▏      | 2000/6400 [02:42<05:42, 12.85it/s]

{'eval_loss': 1.1083637475967407, 'eval_runtime': 1.5148, 'eval_samples_per_second': 168.344, 'eval_steps_per_second': 42.251, 'epoch': 31.25}


 33%|███▎      | 2101/6400 [02:51<05:33, 12.90it/s]

{'loss': 1.4715, 'grad_norm': 4.337399482727051, 'learning_rate': 1.3437500000000001e-05, 'epoch': 32.81}


 34%|███▍      | 2201/6400 [02:58<05:26, 12.87it/s]

{'loss': 1.4422, 'grad_norm': 4.254682540893555, 'learning_rate': 1.3125e-05, 'epoch': 34.38}


 36%|███▌      | 2301/6400 [03:06<05:19, 12.83it/s]

{'loss': 1.4181, 'grad_norm': 5.717597961425781, 'learning_rate': 1.2812500000000001e-05, 'epoch': 35.94}


 38%|███▊      | 2401/6400 [03:14<05:09, 12.93it/s]

{'loss': 1.42, 'grad_norm': 5.130390167236328, 'learning_rate': 1.25e-05, 'epoch': 37.5}


 39%|███▉      | 2500/6400 [03:22<04:55, 13.20it/s]

{'loss': 1.3914, 'grad_norm': 4.639967918395996, 'learning_rate': 1.2187500000000001e-05, 'epoch': 39.06}


                                                   
 39%|███▉      | 2500/6400 [03:23<04:55, 13.20it/s]

{'eval_loss': 0.9921921491622925, 'eval_runtime': 1.4926, 'eval_samples_per_second': 170.839, 'eval_steps_per_second': 42.877, 'epoch': 39.06}


 41%|████      | 2601/6400 [03:31<04:55, 12.84it/s]

{'loss': 1.3721, 'grad_norm': 4.704996585845947, 'learning_rate': 1.1875e-05, 'epoch': 40.62}


 42%|████▏     | 2701/6400 [03:39<04:45, 12.94it/s]

{'loss': 1.3499, 'grad_norm': 3.732299327850342, 'learning_rate': 1.1562500000000002e-05, 'epoch': 42.19}


 44%|████▍     | 2801/6400 [03:47<04:41, 12.80it/s]

{'loss': 1.3412, 'grad_norm': 5.176841735839844, 'learning_rate': 1.125e-05, 'epoch': 43.75}


 45%|████▌     | 2901/6400 [03:55<04:31, 12.91it/s]

{'loss': 1.3147, 'grad_norm': 4.611674785614014, 'learning_rate': 1.0937500000000002e-05, 'epoch': 45.31}


 47%|████▋     | 3000/6400 [04:02<04:23, 12.89it/s]

{'loss': 1.2968, 'grad_norm': 4.700432777404785, 'learning_rate': 1.0625e-05, 'epoch': 46.88}


                                                   
 47%|████▋     | 3000/6400 [04:04<04:23, 12.89it/s]

{'eval_loss': 0.9012151956558228, 'eval_runtime': 1.4995, 'eval_samples_per_second': 170.055, 'eval_steps_per_second': 42.681, 'epoch': 46.88}


 48%|████▊     | 3101/6400 [04:12<04:16, 12.86it/s]

{'loss': 1.2703, 'grad_norm': 4.611576557159424, 'learning_rate': 1.0312500000000002e-05, 'epoch': 48.44}


 50%|█████     | 3201/6400 [04:20<04:02, 13.17it/s]

{'loss': 1.2618, 'grad_norm': 5.742932319641113, 'learning_rate': 1e-05, 'epoch': 50.0}


 52%|█████▏    | 3301/6400 [04:28<04:00, 12.87it/s]

{'loss': 1.2579, 'grad_norm': 4.094886779785156, 'learning_rate': 9.6875e-06, 'epoch': 51.56}


 53%|█████▎    | 3401/6400 [04:35<03:50, 13.01it/s]

{'loss': 1.2095, 'grad_norm': 4.293745517730713, 'learning_rate': 9.375000000000001e-06, 'epoch': 53.12}


 55%|█████▍    | 3500/6400 [04:43<03:45, 12.89it/s]

{'loss': 1.2334, 'grad_norm': 5.052093982696533, 'learning_rate': 9.0625e-06, 'epoch': 54.69}


                                                   
 55%|█████▍    | 3500/6400 [04:44<03:45, 12.89it/s]

{'eval_loss': 0.8242633938789368, 'eval_runtime': 1.4915, 'eval_samples_per_second': 170.969, 'eval_steps_per_second': 42.91, 'epoch': 54.69}


 56%|█████▋    | 3601/6400 [04:53<03:36, 12.92it/s]

{'loss': 1.1849, 'grad_norm': 4.96734619140625, 'learning_rate': 8.750000000000001e-06, 'epoch': 56.25}


 58%|█████▊    | 3701/6400 [05:01<03:30, 12.85it/s]

{'loss': 1.1816, 'grad_norm': 4.782904148101807, 'learning_rate': 8.4375e-06, 'epoch': 57.81}


 59%|█████▉    | 3801/6400 [05:08<03:21, 12.91it/s]

{'loss': 1.2102, 'grad_norm': 5.93049430847168, 'learning_rate': 8.125000000000001e-06, 'epoch': 59.38}


 61%|██████    | 3901/6400 [05:16<03:14, 12.82it/s]

{'loss': 1.1652, 'grad_norm': 3.7036774158477783, 'learning_rate': 7.8125e-06, 'epoch': 60.94}


 62%|██████▎   | 4000/6400 [05:24<03:05, 12.91it/s]

{'loss': 1.1631, 'grad_norm': 5.288204193115234, 'learning_rate': 7.500000000000001e-06, 'epoch': 62.5}


                                                   
 62%|██████▎   | 4000/6400 [05:25<03:05, 12.91it/s]

{'eval_loss': 0.7700837850570679, 'eval_runtime': 1.495, 'eval_samples_per_second': 170.569, 'eval_steps_per_second': 42.809, 'epoch': 62.5}


 64%|██████▍   | 4101/6400 [05:33<02:55, 13.07it/s]

{'loss': 1.1429, 'grad_norm': 4.74070930480957, 'learning_rate': 7.1875e-06, 'epoch': 64.06}


 66%|██████▌   | 4201/6400 [05:41<02:51, 12.85it/s]

{'loss': 1.163, 'grad_norm': 6.863603115081787, 'learning_rate': 6.875e-06, 'epoch': 65.62}


 67%|██████▋   | 4301/6400 [05:49<02:42, 12.94it/s]

{'loss': 1.1222, 'grad_norm': 5.683067321777344, 'learning_rate': 6.5625e-06, 'epoch': 67.19}


 69%|██████▉   | 4401/6400 [05:57<02:35, 12.90it/s]

{'loss': 1.1548, 'grad_norm': 5.209715366363525, 'learning_rate': 6.25e-06, 'epoch': 68.75}


 70%|███████   | 4500/6400 [06:04<02:26, 12.97it/s]

{'loss': 1.114, 'grad_norm': 6.12436580657959, 'learning_rate': 5.9375e-06, 'epoch': 70.31}


                                                   
 70%|███████   | 4500/6400 [06:06<02:26, 12.97it/s]

{'eval_loss': 0.7253891825675964, 'eval_runtime': 1.491, 'eval_samples_per_second': 171.026, 'eval_steps_per_second': 42.924, 'epoch': 70.31}


 72%|███████▏  | 4601/6400 [06:14<02:20, 12.84it/s]

{'loss': 1.1254, 'grad_norm': 5.196582317352295, 'learning_rate': 5.625e-06, 'epoch': 71.88}


 73%|███████▎  | 4701/6400 [06:22<02:12, 12.86it/s]

{'loss': 1.0961, 'grad_norm': 4.068312644958496, 'learning_rate': 5.3125e-06, 'epoch': 73.44}


 75%|███████▌  | 4801/6400 [06:30<02:01, 13.20it/s]

{'loss': 1.098, 'grad_norm': 4.541546821594238, 'learning_rate': 5e-06, 'epoch': 75.0}


 77%|███████▋  | 4901/6400 [06:37<01:56, 12.88it/s]

{'loss': 1.0802, 'grad_norm': 5.998882293701172, 'learning_rate': 4.6875000000000004e-06, 'epoch': 76.56}


 78%|███████▊  | 5000/6400 [06:45<01:47, 13.06it/s]

{'loss': 1.0966, 'grad_norm': 4.248074054718018, 'learning_rate': 4.3750000000000005e-06, 'epoch': 78.12}


                                                   
 78%|███████▊  | 5000/6400 [06:47<01:47, 13.06it/s]

{'eval_loss': 0.6925932765007019, 'eval_runtime': 1.4925, 'eval_samples_per_second': 170.854, 'eval_steps_per_second': 42.881, 'epoch': 78.12}


 80%|███████▉  | 5101/6400 [06:55<01:41, 12.85it/s]

{'loss': 1.098, 'grad_norm': 6.818365097045898, 'learning_rate': 4.0625000000000005e-06, 'epoch': 79.69}


 81%|████████▏ | 5201/6400 [07:03<01:33, 12.84it/s]

{'loss': 1.0465, 'grad_norm': 4.979048728942871, 'learning_rate': 3.7500000000000005e-06, 'epoch': 81.25}


 83%|████████▎ | 5301/6400 [07:10<01:25, 12.85it/s]

{'loss': 1.0825, 'grad_norm': 5.419297695159912, 'learning_rate': 3.4375e-06, 'epoch': 82.81}


 84%|████████▍ | 5401/6400 [07:18<01:17, 12.89it/s]

{'loss': 1.0469, 'grad_norm': 3.566602945327759, 'learning_rate': 3.125e-06, 'epoch': 84.38}


 86%|████████▌ | 5500/6400 [07:26<01:09, 12.97it/s]

{'loss': 1.0626, 'grad_norm': 4.869844913482666, 'learning_rate': 2.8125e-06, 'epoch': 85.94}


                                                   
 86%|████████▌ | 5500/6400 [07:27<01:09, 12.97it/s]

{'eval_loss': 0.6721423864364624, 'eval_runtime': 1.491, 'eval_samples_per_second': 171.026, 'eval_steps_per_second': 42.924, 'epoch': 85.94}


 88%|████████▊ | 5601/6400 [07:35<01:01, 12.91it/s]

{'loss': 1.054, 'grad_norm': 4.634754180908203, 'learning_rate': 2.5e-06, 'epoch': 87.5}


 89%|████████▉ | 5701/6400 [07:43<00:53, 13.08it/s]

{'loss': 1.0451, 'grad_norm': 5.654057502746582, 'learning_rate': 2.1875000000000002e-06, 'epoch': 89.06}


 91%|█████████ | 5801/6400 [07:51<00:46, 12.94it/s]

{'loss': 1.0467, 'grad_norm': 4.909338474273682, 'learning_rate': 1.8750000000000003e-06, 'epoch': 90.62}


 92%|█████████▏| 5901/6400 [07:59<00:38, 12.92it/s]

{'loss': 1.0809, 'grad_norm': 6.184578895568848, 'learning_rate': 1.5625e-06, 'epoch': 92.19}


 94%|█████████▍| 6000/6400 [08:06<00:30, 12.96it/s]

{'loss': 1.0491, 'grad_norm': 5.639376640319824, 'learning_rate': 1.25e-06, 'epoch': 93.75}


                                                   
 94%|█████████▍| 6000/6400 [08:08<00:30, 12.96it/s]

{'eval_loss': 0.6595089435577393, 'eval_runtime': 1.489, 'eval_samples_per_second': 171.255, 'eval_steps_per_second': 42.982, 'epoch': 93.75}


 95%|█████████▌| 6101/6400 [08:16<00:23, 12.85it/s]

{'loss': 1.0445, 'grad_norm': 4.149512767791748, 'learning_rate': 9.375000000000001e-07, 'epoch': 95.31}


 97%|█████████▋| 6201/6400 [08:24<00:15, 12.90it/s]

{'loss': 1.0497, 'grad_norm': 7.004147052764893, 'learning_rate': 6.25e-07, 'epoch': 96.88}


 98%|█████████▊| 6301/6400 [08:32<00:07, 12.90it/s]

{'loss': 1.0458, 'grad_norm': 5.085520267486572, 'learning_rate': 3.125e-07, 'epoch': 98.44}


100%|██████████| 6400/6400 [08:39<00:00, 12.88it/s]

{'loss': 1.0596, 'grad_norm': 7.01862096786499, 'learning_rate': 0.0, 'epoch': 100.0}


100%|██████████| 6400/6400 [08:40<00:00, 12.30it/s]

{'train_runtime': 520.1561, 'train_samples_per_second': 49.024, 'train_steps_per_second': 12.304, 'train_loss': 1.4070651566982268, 'epoch': 100.0}





TrainOutput(global_step=6400, training_loss=1.4070651566982268, metrics={'train_runtime': 520.1561, 'train_samples_per_second': 49.024, 'train_steps_per_second': 12.304, 'total_flos': 6902431875072000.0, 'train_loss': 1.4070651566982268, 'epoch': 100.0})

In [8]:
def run_inference(text):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=1024
    )
    device = model.device
    inputs = {k: v.to(device) for k, v in inputs.items()}  # Move inputs to GPU

    summary_ids = model.generate(
        inputs["input_ids"],
        max_length=200, 
        num_beams=4,
        early_stopping=True
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary

In [9]:
test_df = pd.read_csv(f'{dataset_dir}/podcast_with_summary_test.csv')

test_text = test_df.text_short[0]
print("Original:", test_text)

summary = run_inference(test_text)

print("Predicted:", summary)
print("Actual:", test_df.summary[0])

Original: The following is a conversation with Andrew Ng, one of the most impactful educators, researchers, innovators, and leaders in artificial intelligence and technology space in general. He cofounded Coursera and Google Brain, launched Deep Learning AI, Landing AI, and the AI Fund, and was the chief scientist at Baidu. As a Stanford professor and with Coursera and Deep Learning AI, he has helped educate and inspire millions of students, including me. This is the Artificial Intelligence Podcast. If you enjoy it, subscribe on YouTube, give it five stars on Apple Podcast, support it on Patreon, or simply connect with me on Twitter at Lex Friedman, spelled F R I D M A N. As usual, I'll do one or two minutes of ads now and never any ads in the middle that can break the flow of the conversation. I hope that works for you and doesn't hurt the listening experience. This show is presented by Cash App, the number one finance app in the App Store. When you get it, use code LEXPODCAST.
Predic

In [10]:
import time
from SharedUtils import evaluate_and_save_metrics

# iterate over all rows in test_df
def evaluate_df(df, name):
    reference_summaries = []
    predicted_summaries = []
    total_time = 0
    for i, row in df.iterrows():
        test_text = row.text_short
        
        start_time = time.time()
        summary = run_inference(test_text)
        end_time = time.time()
        elapsed_time = end_time - start_time
        total_time += elapsed_time
        
        reference_summaries.append(row.summary)
        predicted_summaries.append(summary)
    
    model_name = "t5-small"

    rouge_results, results_bleu = evaluate_and_save_metrics(
        model_name,
        name,
        "finetuned",
        reference_summaries,
        predicted_summaries,
        total_time
    )

    results_df = pd.DataFrame({
        'summary': reference_summaries,
        'summary_tuned': predicted_summaries
    })
    results_df.to_csv(f"./results/{model_name}/{name}/summaries.csv")

    print(rouge_results)
    print(results_bleu)
    print("Total time (seconds):", total_time)
    print("Total time (minutes): ", total_time / 60)
    

In [11]:
evaluate_df(test_df, "test_dataset")

{'rouge1': 0.42616871051572736, 'rouge2': 0.186207504189896, 'rougeL': 0.346356334240144, 'rougeLsum': 0.3457898933348439}
{'bleu': 0.11286839272559764, 'precisions': [0.4532095901005414, 0.1757526444263629, 0.07296137339055794, 0.0326975476839237], 'brevity_penalty': 0.961324598499326, 'length_ratio': 0.9620535714285714, 'translation_length': 1293, 'reference_length': 1344}
Total time (seconds): 13.28156065940857
Total time (minutes):  0.22135934432347615


In [12]:
train_df = pd.read_csv(f'{dataset_dir}/podcast_with_summary_train.csv')
evaluate_df(train_df, 'train_dataset')

{'rouge1': 0.5203022353242932, 'rouge2': 0.309463613303818, 'rougeL': 0.47092757451794753, 'rougeLsum': 0.47013257034315753}
{'bleu': 0.2432040962010115, 'precisions': [0.556078431372549, 0.30691434468524253, 0.1934640522875817, 0.1344867358708189], 'brevity_penalty': 0.9421339254342243, 'length_ratio': 0.9437453737971873, 'translation_length': 5100, 'reference_length': 5404}
Total time (seconds): 56.27474498748779
Total time (minutes):  0.9379124164581298


In [13]:
whole_df = pd.read_csv(f'{dataset_dir}/podcast_with_summary.csv')
evaluate_df(whole_df, 'whole_dataset')

{'rouge1': 0.4924141255332442, 'rouge2': 0.27201495504419093, 'rougeL': 0.43571749420356526, 'rougeLsum': 0.43675843020484584}
{'bleu': 0.21425726757200045, 'precisions': [0.5198959289868381, 0.267739340305712, 0.16095658073270014, 0.10722610722610723], 'brevity_penalty': 0.9677787711511946, 'length_ratio': 0.9682868998221695, 'translation_length': 6534, 'reference_length': 6748}
Total time (seconds): 68.71482419967651
Total time (minutes):  1.1452470699946085
