In [1]:
import os
import gc
import random
from pathlib import Path

import datasets
import nltk
import numpy as np
import pandas as pd
import torch
import transformers

from datasets import Dataset, concatenate_datasets
from evaluate import combine, load
from functional import seq
from huggingface_hub import notebook_login
from IPython.display import HTML, display
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          DataCollatorForSeq2Seq, GenerationConfig,
                          Seq2SeqTrainer, Seq2SeqTrainingArguments)

from funcutils import get

os.environ["TOKENIZERS_PARALLELISM"] = "true"
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
assert torch.cuda.is_available()

NUM_TRAIN_EPOCHS = 5
IS_MULTI_CORPUS = True
TASK = 'mt' # 'd2s' or 's2d' or 'mt' pull from argv
MODEL_CKPNT = "t5-small" # t5-small or t5-base
NATURAL_LANGUAGE = "nl"
STRUCTURED_DATA = "sd"
MULTI_CORP = '-multicorp' if IS_MULTI_CORPUS else ""
LR = 2.0e-4
TRAIN_CHKPNT_NAME = f"models/{MODEL_CKPNT}-finetuned-webnlg-{TASK}-{LR:.1e}{MULTI_CORP}"

TARGET = NATURAL_LANGUAGE if TASK == 'd2s' else STRUCTURED_DATA 
INPUT = STRUCTURED_DATA if TASK == 'd2s' else NATURAL_LANGUAGE 
TRAIN_CHKPNT_NAME

'models/t5-small-finetuned-webnlg-mt-2.0e-04-multicorp'

In [2]:
assert TARGET != INPUT
del NATURAL_LANGUAGE
del STRUCTURED_DATA

In [3]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_CKPNT)

In [4]:
max_input_length = 256
max_target_length = 256
tokenize = lambda x: tokenizer(x, max_length = max_input_length, truncation=True, padding=True)
tokenize

<function __main__.<lambda>(x)>

In [5]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CKPNT)
model = model.to(device)
model

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [6]:
generation_config = GenerationConfig.from_pretrained(MODEL_CKPNT)
generation_config.min_length = 5
generation_config.num_beams = 4
generation_config.max_length = 2048
generation_config.early_stopping = True
generation_config.no_repeat_ngram_size = 2
generation_config.temperature = .9

In [7]:
batch_size = 64 if MODEL_CKPNT == "t5-small" else 16
# START: ADAPTED FROM https://huggingface.co/docs/transformers/tasks/summarization
args = Seq2SeqTrainingArguments(
    TRAIN_CHKPNT_NAME,
    eval_steps=1500,
    evaluation_strategy = "steps",
    learning_rate=LR,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size//2,
    gradient_accumulation_steps=2 if MODEL_CKPNT != 't5-small' else 1, # so we have an effective batch size of 32
    weight_decay=0.01,
    save_total_limit=5,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False,
    save_steps=600,
    generation_config=generation_config,
    generation_max_length=200,
)
# END: ADAPTED FROM https://huggingface.co/docs/transformers/tasks/summarization

In [8]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
metric = combine([
    load("rouge"),
])
metric

<evaluate.module.CombinedEvaluations at 0x7f52d8687dc0>

In [9]:
# START: COPIED FROM https://huggingface.co/docs/transformers/tasks/summarization
def compute_metrics(eval_pred):
    # monitor memory and force gc. probably slows us down, probably 
    torchmem = torch.cuda.memory_allocated()
    torchcap = torch.cuda.get_device_properties(0).total_memory

    print(f"torch has allocated {torchmem} of {torchcap}")

    predictions, labels = eval_pred
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)

    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return result
# END: COPIED FROM https://huggingface.co/docs/transformers/tasks/summarization

In [10]:
df = pd.read_pickle("~/repos/nlgs-research/pipeline/normalized_data/webnlg_clean.pkl")

if IS_MULTI_CORPUS:
    df = pd.read_pickle("~/repos/nlgs-research/pipeline/normalized_data/webnlg_wikibio_joint.pkl")
df

Unnamed: 0,subset,category,index,sd,nl
4341,train,SportsTeam,2722,Azerbaijan Premier League|champions|Qarabag FK,[The name of the championship football team in...
3299,train,WikBio,393200000,Dov Sternberg|nationality|American;Dov Sternbe...,[Dov Sternberg is an American karateka.]
9142,train,Politician,7523,Abraham A. Ribicoff|successor|John N. Dempsey;...,[Abraham A Ribicoff was born in the U.S. and w...
1157,dev,City,1157,"Albuquerque, New Mexico|area code|505, 575; Al...","[Albuquerque, in New Mexico, has a total are o..."
1257,dev,Monument,1257,"Adams County, Pennsylvania|has to its west|Fra...",[The 11th Mississippi Infantry Monument is a C...
...,...,...,...,...,...
4404,train,SportsTeam,2785,Massimo Drago|club|S.S.D. Potenza Calcio,[Massimo Drago played for S.S.D. Potenza Calci...
1176,dev,City,1176,"United States|demonym|Americans; Albany, Georg...",[The people inhabiting the United States are k...
9363,train,SportsTeam,7744,A.D. Isidro Metapan|ground|Estadio Jorge Caler...,[A.D. (Asociacion Deportiva) Isidro Metapan pl...
2167,train,WikBio,148200000,David Cooke|occupation|rugby union internation...,[David Cooke is a former a rugby union interna...


 we must invent `seed_number` since d2s can output multiple sentences for the
 same data input. So the seed will be a generation parameter, in the case that
 we are working in a deterministic environment, so generation can vary as
 desired. This computes a cartesian product.

In [11]:
cartesian_sd_nl = []
for (i, subset, cat, indx, sd, nl) in df.itertuples():
    for j, nl_option in enumerate(nl):
        pairing = dict(
            record_idx=i,
            seed_number=j,
            subset=subset,
            category=cat,
            split_index=indx,
            sd=sd,
            nl=nl_option,
            task=TASK if TASK != 'mt' else 's2d' 
        )
        cartesian_sd_nl.append(pairing)
        if TASK == "mt":
            reverse_pair = pairing.copy()
            reverse_pair['sd'] = nl_option
            reverse_pair['nl'] = sd
            reverse_pair['task'] = 'd2s'
            cartesian_sd_nl.append(reverse_pair)

# calling this "flattened" because it no longer has nested records
has_not_run = True
flt = pd.DataFrame(cartesian_sd_nl)
flt

Unnamed: 0,record_idx,seed_number,subset,category,split_index,sd,nl,task
0,4341,0,train,SportsTeam,2722,Azerbaijan Premier League|champions|Qarabag FK,The name of the championship football team in ...,s2d
1,4341,0,train,SportsTeam,2722,The name of the championship football team in ...,Azerbaijan Premier League|champions|Qarabag FK,d2s
2,4341,1,train,SportsTeam,2722,Azerbaijan Premier League|champions|Qarabag FK,Qarabag FK are the champions of the Azerbaijan...,s2d
3,4341,1,train,SportsTeam,2722,Qarabag FK are the champions of the Azerbaijan...,Azerbaijan Premier League|champions|Qarabag FK,d2s
4,4341,2,train,SportsTeam,2722,Azerbaijan Premier League|champions|Qarabag FK,The champions of the Azerbaijan Premier League...,s2d
...,...,...,...,...,...,...,...,...
94865,10126,0,train,Artist,8507,"Alfredo Zitarrosa, born in Uruguay, is a music...",Alfredo Zitarrosa|record label|RCA Records; Al...,d2s
94866,10126,1,train,Artist,8507,Alfredo Zitarrosa|record label|RCA Records; Al...,Singer Alfredo Zitarrosa is associated with Ta...,s2d
94867,10126,1,train,Artist,8507,Singer Alfredo Zitarrosa is associated with Ta...,Alfredo Zitarrosa|record label|RCA Records; Al...,d2s
94868,10126,2,train,Artist,8507,Alfredo Zitarrosa|record label|RCA Records; Al...,"Alfredo Zitarrosa, born in Uruguay, plays Taqu...",s2d


In [12]:
# prepend the seed number. This should be rt of the prompt hereafter for `d2s`
# tasks. So, prompting with two different numbers should never generate the same
# output.

if (TASK == "mt") and has_not_run:
    has_not_run = False
    flt['sd'] = flt.task + flt.seed_number.map(lambda x: " " + str(x) + ": ") + flt.sd

    # allow the model to code switch between corpora
    if IS_MULTI_CORPUS:
        flt['sd'] = flt.category.map(lambda x: 'wb' if x == 'WikiBio' else "") + flt.sd
flt

Unnamed: 0,record_idx,seed_number,subset,category,split_index,sd,nl,task
0,4341,0,train,SportsTeam,2722,s2d 0: Azerbaijan Premier League|champions|Qar...,The name of the championship football team in ...,s2d
1,4341,0,train,SportsTeam,2722,d2s 0: The name of the championship football t...,Azerbaijan Premier League|champions|Qarabag FK,d2s
2,4341,1,train,SportsTeam,2722,s2d 1: Azerbaijan Premier League|champions|Qar...,Qarabag FK are the champions of the Azerbaijan...,s2d
3,4341,1,train,SportsTeam,2722,d2s 1: Qarabag FK are the champions of the Aze...,Azerbaijan Premier League|champions|Qarabag FK,d2s
4,4341,2,train,SportsTeam,2722,s2d 2: Azerbaijan Premier League|champions|Qar...,The champions of the Azerbaijan Premier League...,s2d
...,...,...,...,...,...,...,...,...
94865,10126,0,train,Artist,8507,"d2s 0: Alfredo Zitarrosa, born in Uruguay, is ...",Alfredo Zitarrosa|record label|RCA Records; Al...,d2s
94866,10126,1,train,Artist,8507,s2d 1: Alfredo Zitarrosa|record label|RCA Reco...,Singer Alfredo Zitarrosa is associated with Ta...,s2d
94867,10126,1,train,Artist,8507,d2s 1: Singer Alfredo Zitarrosa is associated ...,Alfredo Zitarrosa|record label|RCA Records; Al...,d2s
94868,10126,2,train,Artist,8507,s2d 2: Alfredo Zitarrosa|record label|RCA Reco...,"Alfredo Zitarrosa, born in Uruguay, plays Taqu...",s2d


In [13]:
tokenized = tokenize(list(flt[INPUT].values))

 !!Heads-up!! The following fields comprise the "interface" of the model,
 despite the fact the documentation doesn't make this obvious. Without these
 particular names, ['input_ids', 'attention_mask', 'labels'],
 the model will not train and provide cryptic error messges.

In [14]:
flt['input_ids'] = tokenized['input_ids']
flt['attention_mask'] = tokenized['attention_mask']
flt['labels'] = flt[TARGET].map(lambda x: tokenize(x)['input_ids'])
flt['input_ids'].map(len)

0        256
1        256
2        256
3        256
4        256
        ... 
94865    256
94866    256
94867    256
94868    256
94869    256
Name: input_ids, Length: 94870, dtype: int64

In [15]:
flt

Unnamed: 0,record_idx,seed_number,subset,category,split_index,sd,nl,task,input_ids,attention_mask,labels
0,4341,0,train,SportsTeam,2722,s2d 0: Azerbaijan Premier League|champions|Qar...,The name of the championship football team in ...,s2d,"[37, 564, 13, 8, 10183, 3370, 372, 16, 8, 71, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 3, 632, 10, 71, 2558, 9441, 70..."
1,4341,0,train,SportsTeam,2722,d2s 0: The name of the championship football t...,Azerbaijan Premier League|champions|Qarabag FK,d2s,"[71, 2558, 9441, 7066, 6552, 3815, 9175, 17788...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 37, 564, 13, 8, 10..."
2,4341,1,train,SportsTeam,2722,s2d 1: Azerbaijan Premier League|champions|Qar...,Qarabag FK are the champions of the Azerbaijan...,s2d,"[1593, 2551, 7893, 377, 439, 33, 8, 6336, 7, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 209, 10, 71, 2558, 9441, 7066,..."
3,4341,1,train,SportsTeam,2722,d2s 1: Qarabag FK are the champions of the Aze...,Azerbaijan Premier League|champions|Qarabag FK,d2s,"[71, 2558, 9441, 7066, 6552, 3815, 9175, 17788...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 1593, 2551, 7893, 377..."
4,4341,2,train,SportsTeam,2722,s2d 2: Azerbaijan Premier League|champions|Qar...,The champions of the Azerbaijan Premier League...,s2d,"[37, 6336, 7, 13, 8, 71, 2558, 9441, 7066, 655...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 204, 10, 71, 2558, 9441, 7066,..."
...,...,...,...,...,...,...,...,...,...,...,...
94865,10126,0,train,Artist,8507,"d2s 0: Alfredo Zitarrosa, born in Uruguay, is ...",Alfredo Zitarrosa|record label|RCA Records; Al...,d2s,"[19850, 32, 3969, 2046, 1859, 9, 9175, 60, 762...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 19850, 32, 3969, 2..."
94866,10126,1,train,Artist,8507,s2d 1: Alfredo Zitarrosa|record label|RCA Reco...,Singer Alfredo Zitarrosa is associated with Ta...,s2d,"[24366, 19850, 32, 3969, 2046, 1859, 9, 19, 19...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 209, 10, 19850, 32, 3969, 2046..."
94867,10126,1,train,Artist,8507,d2s 1: Singer Alfredo Zitarrosa is associated ...,Alfredo Zitarrosa|record label|RCA Records; Al...,d2s,"[19850, 32, 3969, 2046, 1859, 9, 9175, 60, 762...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 24366, 19850, 32, 396..."
94868,10126,2,train,Artist,8507,s2d 2: Alfredo Zitarrosa|record label|RCA Reco...,"Alfredo Zitarrosa, born in Uruguay, plays Taqu...",s2d,"[19850, 32, 3969, 2046, 1859, 9, 6, 2170, 16, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 204, 10, 19850, 32, 3969, 2046..."


In [16]:
# this will keep only the needed fields in memory on the GPU
def pd_to_dataset(df: pd.DataFrame, split='train') -> Dataset:
    print(df)
    d = df[df.subset== split][['input_ids','attention_mask','labels']]
    return Dataset.from_pandas(d)
        
# get_ds alias should bake in the desired argument. Makes you wish python
# supported currying
get_ds = lambda x: pd_to_dataset(flt, x)
tds = get_ds('train')
eds = get_ds('dev')
tds

       record_idx  seed_number subset    category  split_index  \
0            4341            0  train  SportsTeam         2722   
1            4341            0  train  SportsTeam         2722   
2            4341            1  train  SportsTeam         2722   
3            4341            1  train  SportsTeam         2722   
4            4341            2  train  SportsTeam         2722   
...           ...          ...    ...         ...          ...   
94865       10126            0  train      Artist         8507   
94866       10126            1  train      Artist         8507   
94867       10126            1  train      Artist         8507   
94868       10126            2  train      Artist         8507   
94869       10126            2  train      Artist         8507   

                                                      sd  \
0      s2d 0: Azerbaijan Premier League|champions|Qar...   
1      d2s 0: The name of the championship football t...   
2      s2d 1: Azerbaijan Pr

Dataset({
    features: ['input_ids', 'attention_mask', 'labels', '__index_level_0__'],
    num_rows: 77790
})

In [17]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tds,
    eval_dataset=eds,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [18]:
# we must try-catch because resume_from_checkpoint throws a value error (for
# some reason instead of raising a warning) if training did not begin first.
try:
    trainer.train(resume_from_checkpoint=True)
except ValueError as e:
    print(e)
    trainer.train()

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


No valid checkpoint found in output directory (models/t5-small-finetuned-webnlg-mt-2.0e-04-multicorp)


Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1500,0.8444,0.48071,0.781354,0.585964,0.65854,0.682488,43.392261
3000,0.7098,0.412711,0.804703,0.620581,0.682433,0.7074,43.594069
4500,0.678,0.385597,0.815051,0.6363,0.6933,0.718141,44.197637
6000,0.651,0.376364,0.819612,0.64259,0.698263,0.723851,44.293095




torch has allocated 767947776 of 25447170048
torch has allocated 767947776 of 25447170048
torch has allocated 767947776 of 25447170048
torch has allocated 767947776 of 25447170048


In [19]:
if False:
    trainer.push_to_hub()

In [20]:
if True:
    trainer.push_to_hub()

For more details, please read https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http.


OSError: Tried to clone a repository in a non-empty folder that isn't a git repository ('/home/vente/repos/nlgs-research/models/t5-small-finetuned-webnlg-mt-2.0e-04-multicorp'). If you really want to do this, do it manually:
 cd /home/vente/repos/nlgs-research/models/t5-small-finetuned-webnlg-mt-2.0e-04-multicorp && git init && git remote add origin && git pull origin main
 or clone repo to a new folder and move your existing files there afterwards.

In [21]:
if True:
    trainer.push_to_hub()

For more details, please read https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http.
Cloning https://huggingface.co/vente/t5-small-finetuned-webnlg-mt-2.0e-04-multicorp into local empty directory.


Upload file pytorch_model.bin:   0%|          | 1.00/231M [00:00<?, ?B/s]

Upload file spiece.model:   0%|          | 1.00/773k [00:00<?, ?B/s]

Upload file training_args.bin:   0%|          | 1.00/5.31k [00:00<?, ?B/s]

To https://huggingface.co/vente/t5-small-finetuned-webnlg-mt-2.0e-04-multicorp
   559b40c..6c02aa5  main -> main

   559b40c..6c02aa5  main -> main

   559b40c..6c02aa5  main -> main

To https://huggingface.co/vente/t5-small-finetuned-webnlg-mt-2.0e-04-multicorp
   6c02aa5..074ebbe  main -> main

   6c02aa5..074ebbe  main -> main



In [22]:
tds = get_ds('test')
# debug = Dataset.from_dict(tds[0:2])
predictions = trainer.predict(tds)
predictions

       record_idx  seed_number subset    category  split_index  \
0            4341            0  train  SportsTeam         2722   
1            4341            0  train  SportsTeam         2722   
2            4341            1  train  SportsTeam         2722   
3            4341            1  train  SportsTeam         2722   
4            4341            2  train  SportsTeam         2722   
...           ...          ...    ...         ...          ...   
94865       10126            0  train      Artist         8507   
94866       10126            1  train      Artist         8507   
94867       10126            1  train      Artist         8507   
94868       10126            2  train      Artist         8507   
94869       10126            2  train      Artist         8507   

                                                      sd  \
0      s2d 0: Azerbaijan Premier League|champions|Qar...   
1      d2s 0: The name of the championship football t...   
2      s2d 1: Azerbaijan Pr

torch has allocated 768295424 of 25447170048


PredictionOutput(predictions=array([[   0,    3,    7, ..., -100, -100, -100],
       [   0,    3,   26, ..., -100, -100, -100],
       [   0,    3,    7, ..., -100, -100, -100],
       ...,
       [   0,    3,   26, ..., -100, -100, -100],
       [   0,    3,    7, ..., -100, -100, -100],
       [   0,    3,   26, ..., -100, -100, -100]]), label_ids=array([[   3,    7,  357, ..., -100, -100, -100],
       [   3,   26,  357, ..., -100, -100, -100],
       [   3,    7,  357, ..., -100, -100, -100],
       ...,
       [   3,   26,  357, ..., -100, -100, -100],
       [   3,    7,  357, ..., -100, -100, -100],
       [   3,   26,  357, ..., -100, -100, -100]]), metrics={'test_loss': 0.367276668548584, 'test_rouge1': 0.8208428392933265, 'test_rouge2': 0.6433249391671216, 'test_rougeL': 0.6981754611733058, 'test_rougeLsum': 0.7222714254923237, 'test_gen_len': 43.332149621212125, 'test_runtime': 156.2663, 'test_samples_per_second': 54.062, 'test_steps_per_second': 1.689})

In [23]:
flat_keep_positive = lambda x: [e for e in x if e > 1]
pred_df = pd.DataFrame(columns=['pred_ids'], data=pd.Series(list(predictions.predictions)))
decoded = pred_df.pred_ids.map(flat_keep_positive).map(tokenizer.decode)
pred_df['decoded'] = decoded
pred_df['subset'] = 'test'
pred_df

Unnamed: 0,pred_ids,decoded,subset
0,"[0, 3, 7, 357, 26, 209, 10, 5954, 7, 1334, 573...",s2d 1: Andrews County Airport|location|Texas; ...,test
1,"[0, 3, 26, 357, 7, 209, 10, 5954, 7, 1334, 573...",d2s 1: Andrews County Airport is located in Te...,test
2,"[0, 3, 7, 357, 26, 209, 10, 5954, 7, 1334, 573...",s2d 1: Andrews County Airport|location|Texas; ...,test
3,"[0, 3, 26, 357, 7, 209, 10, 5954, 7, 1334, 573...",d2s 1: Andrews County Airport is located in Te...,test
4,"[0, 3, 7, 357, 26, 209, 10, 5954, 7, 1334, 573...",s2d 1: Andrews County Airport|location|Texas; ...,test
...,...,...,...
8443,"[0, 3, 26, 357, 7, 209, 10, 11375, 32, 2255, 1...",d2s 1: Alberto Teisaire was a Provisional Pres...,test
8444,"[0, 3, 7, 357, 26, 209, 10, 11375, 32, 2255, 1...","s2d 1: Alberto Teisaire|office (worked at, wor...",test
8445,"[0, 3, 26, 357, 7, 209, 10, 11375, 32, 2255, 1...",d2s 1: Alberto Teisaire was a Provisional Pres...,test
8446,"[0, 3, 7, 357, 26, 209, 10, 11375, 32, 2255, 1...",s2d 1: Alberto Teisaire|successor|Isaac Rojas;...,test


In [24]:
test_set = flt[flt.subset == 'test'].copy()
test_set['pred_ids'] = list(pred_df['pred_ids'].values)
test_set['decoded'] = list(pred_df['decoded'].values)
test_set

Unnamed: 0,record_idx,seed_number,subset,category,split_index,sd,nl,task,input_ids,attention_mask,labels,pred_ids,decoded
86,15192,0,test,Airport,697,s2d 0: Andrews County Airport|location|Texas; ...,The runway length at Andrews County airport ( ...,s2d,"[37, 22750, 2475, 44, 5954, 7, 1334, 3761, 41,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 3, 632, 10, 5954, 7, 1334, 573...","[0, 3, 7, 357, 26, 209, 10, 5954, 7, 1334, 573...",s2d 1: Andrews County Airport|location|Texas; ...
87,15192,0,test,Airport,697,d2s 0: The runway length at Andrews County air...,Andrews County Airport|location|Texas; Andrews...,d2s,"[5954, 7, 1334, 5735, 9175, 14836, 9175, 13598...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 37, 22750, 2475, 4...","[0, 3, 26, 357, 7, 209, 10, 5954, 7, 1334, 573...",d2s 1: Andrews County Airport is located in Te...
88,15192,1,test,Airport,697,s2d 1: Andrews County Airport|location|Texas; ...,"Located in Texas, Andrews County Airport, is 9...",s2d,"[3, 8691, 16, 2514, 6, 5954, 7, 1334, 5735, 6,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 209, 10, 5954, 7, 1334, 5735, ...","[0, 3, 7, 357, 26, 209, 10, 5954, 7, 1334, 573...",s2d 1: Andrews County Airport|location|Texas; ...
89,15192,1,test,Airport,697,"d2s 1: Located in Texas, Andrews County Airpor...",Andrews County Airport|location|Texas; Andrews...,d2s,"[5954, 7, 1334, 5735, 9175, 14836, 9175, 13598...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 3, 8691, 16, 2514, 6,...","[0, 3, 26, 357, 7, 209, 10, 5954, 7, 1334, 573...",d2s 1: Andrews County Airport is located in Te...
90,15192,2,test,Airport,697,s2d 2: Andrews County Airport|location|Texas; ...,Andrews County Airport is located in Texas and...,s2d,"[5954, 7, 1334, 5735, 19, 1069, 16, 2514, 11, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 204, 10, 5954, 7, 1334, 5735, ...","[0, 3, 7, 357, 26, 209, 10, 5954, 7, 1334, 573...",s2d 1: Andrews County Airport|location|Texas; ...
...,...,...,...,...,...,...,...,...,...,...,...,...,...
94703,15119,0,test,Politician,624,d2s 0: Alberto Teisaire worked as the Provisio...,"Alberto Teisaire|office (worked at, worked as)...",d2s,"[11375, 32, 2255, 159, 2378, 9175, 19632, 41, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 3, 632, 10, 11375, 32, 2255, 1...","[0, 3, 26, 357, 7, 209, 10, 11375, 32, 2255, 1...",d2s 1: Alberto Teisaire was a Provisional Pres...
94704,15119,1,test,Politician,624,"s2d 1: Alberto Teisaire|office (worked at, wor...",Alberto Teisaire worked as a Provisional Presi...,s2d,"[11375, 32, 2255, 159, 2378, 1279, 38, 3, 9, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 209, 10, 11375, 32, 2255, 159,...","[0, 3, 7, 357, 26, 209, 10, 11375, 32, 2255, 1...","s2d 1: Alberto Teisaire|office (worked at, wor..."
94705,15119,1,test,Politician,624,d2s 1: Alberto Teisaire worked as a Provisiona...,"Alberto Teisaire|office (worked at, worked as)...",d2s,"[11375, 32, 2255, 159, 2378, 9175, 19632, 41, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 26, 357, 7, 209, 10, 11375, 32, 2255, 159,...","[0, 3, 26, 357, 7, 209, 10, 11375, 32, 2255, 1...",d2s 1: Alberto Teisaire was a Provisional Pres...
94706,15119,2,test,Politician,624,"s2d 2: Alberto Teisaire|office (worked at, wor...",Isaac Rojas was the successor to Alberto Teisa...,s2d,"[20876, 2158, 1191, 7, 47, 8, 22261, 12, 11375...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 7, 357, 26, 204, 10, 11375, 32, 2255, 159,...","[0, 3, 7, 357, 26, 209, 10, 11375, 32, 2255, 1...",s2d 1: Alberto Teisaire|successor|Isaac Rojas;...


In [25]:
save_fname

NameError: name 'save_fname' is not defined

In [26]:
save_fname = f"~/repos/nlgs-research/pipeline/predictions/{TASK}-{MODEL_CKPNT}-{NUM_TRAIN_EPOCHS}{MULTI_CORP}.pkl"
test_set.to_pickle(save_fname)
save_fname

'~/repos/nlgs-research/pipeline/predictions/mt-t5-small-5-multicorp.pkl'

In [27]:
def text_to_prediction_single(text):
    tensors = tokenizer("<pad>" + text + "</s>", return_tensors='pt').to(device)['input_ids']
    generation = trainer.model.generate(tensors,
        early_stopping=True,
        num_beams=5,
        max_new_tokens=1024,
        temperature=.9,
    ) 
    return tokenizer.decode(generation[0], skip_special_tokens=True)

t = "The leader of Aarhus is Jacob Bundsgaard."
text_to_prediction_single(t)

's2d 1: Aarhus|leader name|Jacob Bundsgaard'

In [28]:
print("\n".join(map(tokenizer.decode,
                np.where(predictions.predictions != -100, predictions.predictions, tokenizer.pad_token_id)
                )))

<pad> s2d 1: Andrews County Airport|location|Texas; Andrews County Airport|runway length|1773.0; Andrews County Airport|elevation above the sea level (in metres)|973.0</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
<pad> d2s 1: Andrews County Airport is located in Texas and is 973 metres above sea level. It has a runway length of 1773.0.</s><pad>