# Train BART on CNN Dataset for Sentencification
## 0 Install Libraries

In [1]:
!pip install datasets transformers rouge_score nltk
!pip install --upgrade accelerate

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting transformers
  Downloading transformers-4.29.2-py3-none-any.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0mm
[?25hCollecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting nltk
  Downloading nltk-3.8.1-py3-none-any.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0mm
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0

## 1 Dataset Preparation

In [2]:
import random


def parse_relations(rel_pth: str):
    ''' Parse a relation-triple data file into aligned relation strings and sentences.

        Every line is stripped and split on tabs; the first field selects the record type:

        * ``S`` starts a new sentence. The sentence text is taken from the fourth field
          onwards (further tab-separated fields are joined with single spaces).
        * ``R`` is one relation triple. It is appended to the relation string of the
          sentence being read as ``<subject>..<predicate>..<object>..``. Fields after the
          object are joined into the object with spaces; a triple without an object field
          contributes only ``<subject>..<predicate>..``.
        * Any other record type (e.g. phrases) is ignored.

        A sentence that has no relation at all is dropped, including the last sentence of
        the file, so both returned lists have the same length and ``rel[i]`` is the
        relation string collected for ``stn[i]`` (provided every ``R`` record follows the
        ``S`` record it belongs to).

        Args:
            rel_pth: Path of the tab-separated relation file.

        Returns:
            ``(rel, stn)``: the list of concatenated relation strings and the list of
            sentences.

        Raises:
            IndexError: If an ``R`` record has fewer than three fields.
    '''
    rel = []
    stn = []
    rel_line = ""
    with open(rel_pth, 'r') as f:
        for raw_line in f:
            fields = raw_line.strip().split('\t')
            if fields[0] == 'S':
                # A new sentence closes the relation string of the previous one.
                if rel_line:
                    rel.append(rel_line)
                    rel_line = ""
                if len(rel) < len(stn):  # The previous sentence had no relation: drop it.
                    del stn[-1]
                stn.append(' '.join(fields[3:]))
            elif fields[0] == 'R':  # Relation triple
                if len(fields) >= 4:
                    rel_line += '<subject>%s<predicate>%s<object>%s' % (fields[1], fields[2], ' '.join(fields[3:]))
                else:
                    rel_line += '<subject>%s<predicate>%s' % (fields[1], fields[2])
    if rel_line:
        rel.append(rel_line)
    if len(rel) < len(stn):
        # The last sentence of the file had no relation. Without this check it would stay
        # in `stn` and leave the two lists with different lengths.
        del stn[-1]
    return rel, stn

def concatenated_dataset(rel, stn, min_choice=1, max_choice=3, random_seed=42, len_dataset=None):
    ''' Build a dataset from parsed relations and sentences.

        Each example is made from ``n`` distinct, randomly chosen sentence indices, where
        ``n`` is drawn uniformly from ``[min_choice, max_choice]``. The relation strings of
        the chosen indices are concatenated without a separator, the sentences are joined
        with a single space.

        Args:
            rel: Relation strings; ``rel[i]`` belongs to ``stn[i]``.
            stn: Sentences. Indices are sampled from ``range(len(stn))``.
            min_choice: Smallest number of sentences per example.
            max_choice: Largest number of sentences per example.
            random_seed: Seed of the private random generator used for all draws.
            len_dataset: Number of examples to generate. If None, the number of examples
                is the same as the number of input sentences.

        Returns:
            A ``datasets.Dataset`` with the string columns ``relations`` and ``sentence``.

        Raises:
            ValueError: If there are more sentences than relation strings, or if a drawn
                number of sentences is larger than ``len(stn)``.
    '''
    if len(rel) < len(stn):
        # A sampled index could otherwise point past the end of `rel` and fail with an
        # IndexError only for some draws. Extra trailing relations are ignored as before.
        raise ValueError(
            'got %d relation strings for %d sentences; every sentence needs one' % (len(rel), len(stn))
        )

    # A private generator gives the same draws as seeding the module-level one with the
    # same seed, without resetting the global random state for the rest of the notebook.
    rng = random.Random(random_seed)

    if len_dataset is None:
        len_dataset = len(stn)

    rel_out, stn_out = [], []
    for _ in range(len_dataset):
        n_choice = rng.randint(min_choice, max_choice)
        idxs = rng.sample(range(len(stn)), n_choice)
        rel_out.append(''.join(rel[j] for j in idxs))
        stn_out.append(' '.join(stn[j] for j in idxs))

    df = pd.DataFrame({'relations': rel_out, 'sentence': stn_out})
    return Dataset.from_pandas(df)

In [3]:
import pandas as pd
from datasets import Dataset, DatasetDict

# Parse the raw triple file, then cut both lists at the same positions:
# the first 80% for training, the next 10% for validation and the rest for testing.
rel, stn = parse_relations("triple_350k.txt")
split_train = int(len(rel) * 0.8)
split_validation = int(len(rel) * 0.9)

train_rel, train_stn = rel[:split_train], stn[:split_train]
validation_rel, validation_stn = rel[split_train:split_validation], stn[split_train:split_validation]
test_rel, test_stn = rel[split_validation:], stn[split_validation:]

# Every split is turned into randomly concatenated examples (default settings).
train_dataset = concatenated_dataset(train_rel, train_stn)
validation_dataset = concatenated_dataset(validation_rel, validation_stn)
test_dataset = concatenated_dataset(test_rel, test_stn)

CNN_dataset = DatasetDict({
    'train': train_dataset,
    'validation': validation_dataset,
    'test': test_dataset,
})
CNN_dataset.reset_format()
CNN_dataset

  from .autonotebook import tqdm as notebook_tqdm


DatasetDict({
    train: Dataset({
        features: ['relations', 'sentence'],
        num_rows: 822221
    })
    validation: Dataset({
        features: ['relations', 'sentence'],
        num_rows: 102778
    })
    test: Dataset({
        features: ['relations', 'sentence'],
        num_rows: 102778
    })
})

In [4]:
# Preview a few training examples. The row is indexed first: indexing the column first
# (`CNN_dataset['train']['sentence'][i]`) would load the whole column on every iteration
# just to read a single value.
for i in range(100, 110):
    example = CNN_dataset['train'][i]
    print(example['sentence'])
    print(example['relations'], '\n')

It's called Make A Leap (Leap stands for Lowering Emissions and Particulates).
<subject>It<predicate>Make<object>Particulates<subject>It<predicate>'s called<object>Make Particulates<subject>It<predicate>Make<object>A<subject>It<predicate>'s called<object>Make A 

Two days after the missed party, Adams was officially reported missing.
<subject>Adams<predicate>missing<subject>Adams<predicate>was officially reported<object>Two days after the missed party 

Speakers often blasted loud music from the front of the house, Pusztay recalled, while dogs walked on the property. Woods encouraged passengers to call 1-800-USA-RAIL for information about refunds and credits. "Since the agency's decision was final and since the Sacketts have no other adequate remedy in a court, they may bring their suit" under federal law, said Justice Antonin Scalia.
<subject>dogs<predicate>walked<object>on the property<subject>Pusztay<predicate>recalled<subject>Speakers<predicate>blasted<object>loud music from the fr

## 2 Model Training

In [5]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Base checkpoint to fine-tune; the model is moved to the GPU right after loading.
model_checkpoint = 'facebook/bart-base'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model = model.to('cuda')

# Add the markers that delimit the parts of a triple in the "relations" strings to the
# tokenizer, then resize the model's embedding matrix to the new vocabulary size.
marker_tokens = ['<subject>', '<predicate>', '<object>']
tokenizer.add_tokens(marker_tokens)
model.resize_token_embeddings(len(tokenizer))

Embedding(50268, 768)

In [6]:
def preprocess_function(examples, max_input_length=512, max_target_length=128):
    """Tokenize a batch of examples for sequence-to-sequence training.

    Args:
        examples: Batch with the keys "relations" (model input) and "sentence" (target).
        max_input_length: Inputs are truncated to this many tokens.
        max_target_length: Targets are truncated to this many tokens.

    Returns:
        The tokenizer output for the inputs, with the token ids of the targets stored
        under the additional key "labels".
    """
    encoded = tokenizer(examples["relations"], truncation=True, max_length=max_input_length)
    target_ids = tokenizer(
        examples["sentence"], truncation=True, max_length=max_target_length
    )["input_ids"]
    encoded["labels"] = target_ids
    return encoded

In [7]:
# Tokenize every split in batches, then drop the raw text columns so that only the
# columns produced by preprocess_function remain.
raw_columns = CNN_dataset["train"].column_names
CNN_tokenized = CNN_dataset.map(preprocess_function, batched=True).remove_columns(raw_columns)
CNN_tokenized

                                                                     

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 822221
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 102778
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 102778
    })
})

In [8]:
import numpy as np
from datasets import load_metric
import nltk
from nltk.tokenize import sent_tokenize

# ROUGE metric object; compute_metrics below calls its `compute` method.
rouge_score = load_metric("rouge")
# Download the NLTK "punkt" data (presumably required by `sent_tokenize`, which
# compute_metrics uses to split text into sentences — TODO confirm).
nltk.download("punkt")

def compute_metrics(eval_pred):
    """Compute ROUGE F-measures (in percent) for generated sentences.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays. ``predictions`` may
            also be a tuple whose first element holds the generated token ids.

    Returns:
        Dict mapping each ROUGE key to its ``mid`` F-measure, multiplied by 100 and
        rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        # Keep only the generated token ids if extra outputs are returned with them.
        predictions = predictions[0]
    # -100 is not a valid token id and can't be decoded. It marks ignored label positions
    # and may also be used to pad generated sequences when batches are gathered, so it is
    # replaced by the pad token in both arrays before decoding.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    # Decode generated and reference sentences into text
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # ROUGE expects a newline after each sentence
    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
    # Compute ROUGE scores
    result = rouge_score.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Report the `mid` value of every aggregated score as a percentage
    return {key: round(value.mid.fmeasure * 100, 4) for key, value in result.items()}

  rouge_score = load_metric("rouge")
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
from transformers import Seq2SeqTrainer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments

import os

from transformers.trainer_utils import get_last_checkpoint

batch_size = 32
num_train_epochs = 5
# Show the training loss with every epoch
logging_steps = len(CNN_tokenized["train"]) // batch_size
model_name = model_checkpoint.split("/")[-1]
output_dir = f"{model_name}-finetuned-CNN"

args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    learning_rate=5.6e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=num_train_epochs,
    # Evaluate on generated text so that compute_metrics receives token ids it can decode
    predict_with_generate=True,
    logging_steps=logging_steps,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=CNN_tokenized["train"],
    eval_dataset=CNN_tokenized["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Resume only if an earlier run left a checkpoint in the output directory.
# `resume_from_checkpoint=True` raises when there is none, which makes a first run
# (or Restart & Run All in a clean environment) fail; None starts from the loaded model.
last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
res = trainer.train(resume_from_checkpoint=last_checkpoint)

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss


In [12]:
res.metrics

{'train_runtime': 9126.6438,
 'train_samples_per_second': 450.451,
 'train_steps_per_second': 14.077,
 'total_flos': 7.610631514620826e+17,
 'train_loss': 0.0590253048282719,
 'epoch': 5.0}

## 3 Model Saving and Evaluation

In [17]:
import time

# The timestamp of the end of training makes the directory name of the saved model unique.
train_end = time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime())
final_model_dir = f"bart-base-finetuned-CNN-{train_end}"

# Keep the trainer's log history as plain text next to the notebook.
with open('result.txt', 'w') as log_file:
    log_file.write(str(trainer.state.log_history))

trainer.save_model(final_model_dir)

In [20]:
import shutil

# `train_end` is the timestamp string set when the model was saved; the saved model
# directory and the zip archive created from it share the same base name.
archive_name = f'bart-base-finetuned-CNN-{train_end}'
shutil.make_archive(base_name=archive_name, format='zip', root_dir=archive_name)

'/root/bart-base-finetuned-CNN-2023-05-21_09-19-38.zip'

In [21]:
from transformers import pipeline

# Wrap the current in-memory `model` and `tokenizer` in a "summarization" pipeline that
# runs on the first CUDA device; print_summary uses it to generate a sentence from relations.
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device='cuda:0')

In [22]:
def print_summary(dataset, idx, summarizer):
    """Print one test example together with the sentence generated for it.

    Args:
        dataset: Mapping with a "test" split whose items have the keys "relations" and
            "sentence".
        idx: Index of the example inside the "test" split.
        summarizer: Callable that takes the relation string and returns a list whose
            first item holds the generated text under "summary_text".

    The index is always printed first. If the relation string contains nothing but
    whitespace, a notice is printed and the summarizer is not called.
    """
    print(f"\n>>> {idx}")
    example = dataset["test"][idx]
    relations, sentence = example["relations"], example["sentence"]

    if not relations.split():
        print("\n>>> There's no contents.")
        return

    # Generate before printing the details, so nothing else is shown if generation fails.
    generated = summarizer(relations)[0]["summary_text"]
    for label, text in (("Relations", relations), ("Sentence", sentence), ("Result", generated)):
        print(f"\n>>> {label}: {text}")

In [23]:
for i in range(20): print_summary(CNN_dataset, i, summarizer)

Your max_length is set to 128, but your input_length is only 127. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=63)



>>> 0

>>> Relations: <subject>We<predicate>do not want to tell<object>him such bad news<subject>We<predicate>do not want<object>to tell him such bad news<subject>We<predicate>do not want to have<object>it on our conscience<subject>We<predicate>do not want<object>to have it on our conscience<subject>The groups<predicate>are looking<object>at issues such as housing to even personal displays of affection<subject>The groups<predicate>are looking<object>at issues such as housing to entitlements<subject>both the civil investigations<predicate>are<object>ongoing<subject>both the criminal investigations<predicate>are<object>ongoing<subject>The Alavi Foundation 's former president<predicate>remains<object>under investigation for alleged obstruction of justice

>>> Sentence: "We do not want to have it on our conscience and tell him such bad news. The groups are looking at issues such as housing to entitlements and even personal displays of affection. The Alavi Foundation's former president rem

Your max_length is set to 128, but your input_length is only 23. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=11)
Your max_length is set to 128, but your input_length is only 100. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)



>>> Relations: <subject>he<predicate>was cast<object>as undercover cop Brian O'Conner infiltrating a street - racing gang in 2001's " the Furious<subject>His career<predicate>took off<object>when he was cast as undercover cop Brian O'Conner infiltrating a street - racing gang in 2001's " the Furious<subject>undercover cop Brian O'Conner<predicate>infiltrating<object>a street - racing gang<subject>he<predicate>was cast<object>as undercover cop Brian O'Conner infiltrating a street - racing gang in 2001's " The Fast<subject>His career<predicate>took off<object>when he was cast as undercover cop Brian O'Conner infiltrating a street - racing gang in 2001's " The Fast<subject>15 people there<predicate>were<object>students with no ties to organized crime<subject>The city<predicate>has become<object>a focal point of Calderon's anti-drug efforts after the January 31 killings of 15 people there

>>> Sentence: His career really took off when he was cast as undercover cop Brian O'Conner infiltrat

Your max_length is set to 128, but your input_length is only 90. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=45)



>>> Relations: <subject>the settlement<predicate>was reached<object>in large part because of remedial actions<subject>a statement on Monday<predicate>stating<object>that the settlement was reached in large part because of remedial actions<subject>remedial actions<predicate>instituted<object>at the company over the past two years<subject>the settlement<predicate>was reached<object>in large part because of reforms<subject>a statement on Monday<predicate>stating<object>that the settlement was reached in large part because of reforms<subject>Maxim<predicate>released<object>a statement on Monday

>>> Sentence: Maxim released a statement on Monday stating that the settlement was reached in large part because of reforms and remedial actions instituted at the company over the past two years.

>>> Result: Maxim released a statement on Monday stating that the settlement was reached in large part because of reforms and remedial actions instituted at the company over the past two years.

>>> 4


Your max_length is set to 128, but your input_length is only 27. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=13)
Your max_length is set to 128, but your input_length is only 13. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=6)



>>> Relations: <subject>a lot of people<predicate>are<object>concerned about<subject>he<predicate>articulates<object>what a lot of people are concerned about<subject>I<predicate>think<object>he articulates what a lot of people are concerned about<subject>the tax initiatives<predicate>unveiled<object>Saturday evening<subject>Republicans<predicate>were<object>dismissive of the tax initiatives<subject>Thirty-nine people<predicate>have been injured<subject>authorities<predicate>said<object>Thirty-nine people have been injured

>>> Sentence: I think he articulates what a lot of people are concerned about. Republicans were dismissive of the tax initiatives unveiled Saturday evening. Thirty-nine people have been injured, authorities said.

>>> Result: "I think he articulates what a lot of people are concerned about. Republicans were dismissive of the tax initiatives unveiled Saturday evening. Thirty-nine people have been injured, authorities said.

>>> 5

>>> Relations: <subject>her career<p

Your max_length is set to 128, but your input_length is only 46. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=23)



>>> Relations: <subject>we<predicate>repaired as best<object>at the races

>>> Sentence: "The one he had got damaged during the first few races, and we repaired it as best we could at the races.

>>> Result: "But we repaired as best as we could at the races.

>>> 7

>>> Relations: <subject>Self - immolations that prompt political change<predicate>are<object>rare events<subject>Self - immolations that prompt political change<predicate>are<object>extraordinary events<subject>Self - immolations<predicate>prompt<object>political change

>>> Sentence: Self-immolations that prompt political change are extraordinary and rare events.

>>> Result: Self-immolations that prompt political change are extraordinary and rare events.

>>> 8

>>> Relations: <subject>they<predicate>arrested<object>one suspect later that day<subject>they<predicate>found<object>London 's vehicle in a residential area of Palm Springs<subject>Police<predicate>said<object>they found London 's vehicle in a residential area o

Your max_length is set to 128, but your input_length is only 84. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=42)



>>> Relations: <subject>them<predicate>to get<object>into drugs<subject>We<predicate>do n't want<object>them to get into drugs<subject>them<predicate>to get<object>into smoking<subject>We<predicate>do n't want<object>them to get into smoking<subject>them<predicate>wasting<object>their time<subject>We<predicate>do n't want<object>them wasting their time<subject>other Chinese students<predicate>living<object>overseas<subject>Users of the popular Chinese social media platform Weibo<predicate>expressed<object>anger over the concern about other Chinese students<subject>Users of the popular Chinese social media platform Weibo<predicate>expressed<object>anger over the attack<subject>We<predicate>know<object>our Customers are going to appreciate the fact that every seat on every flight is a reward seat<subject>every seat on every flight<predicate>is<object>a reward seat<subject>our Customers<predicate>to appreciate<object>the fact that every seat on every flight is a reward seat<subject>We<pr

Your max_length is set to 128, but your input_length is only 93. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=46)



>>> Relations: <subject>me<predicate>to come<object>back any more<subject>They<predicate>do n't ask<object>me to come back any more<subject>the coaliton<predicate>would be<object>the first ones to criticize the show<subject>They<predicate>said<object>if Lifetime were depicting Latinas in negative stereotypical roles, the coaliton would be the first ones<subject>the first ones<predicate>to criticize<object>the show<subject>Lifetime<predicate>were depicting<object>Latinas

>>> Sentence: They don't ask me to come back any more. They said if Lifetime were depicting Latinas in negative stereotypical roles, the coaliton would be the first ones to criticize the show.

>>> Result: They don't ask me to come back any more. They said if Lifetime were depicting Latinas in negative stereotypical roles, the coaliton would be the first ones to criticize the show.

>>> 11


Your max_length is set to 128, but your input_length is only 65. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=32)



>>> Relations: <subject>I<predicate>should n't say<object>that<subject>I shouldn't say that<predicate>says<object>he<subject>the Archdiocese of Philadelphia<predicate>to prevent<object>the sexual abuse of children<subject>a 2011 grand jury report<predicate>blamed<object>the Archdiocese of Philadelphia for failing to prevent the sexual abuse of children<subject>Dozens of priests<predicate>were placed<object>on administrative leave after the release of a 2011 grand jury report<subject>we<predicate>treat<object>everyone the same

>>> Sentence: "I shouldn't say that," he says. Dozens of priests were placed on administrative leave after the release of a 2011 grand jury report that blamed the Archdiocese of Philadelphia for failing to prevent the sexual abuse of children. Yeah, we treat everyone the same.

>>> Result: "I shouldn't say that," he says. Dozens of priests were placed on administrative leave after the release of a 2011 grand jury report that blamed the Archdiocese of Philadelphi

Your max_length is set to 128, but your input_length is only 111. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=55)



>>> Relations: <subject>NATO<predicate>would not confirm<object>what other types of ordnance may have been dropped<subject>a 2,000 - pound weapon<predicate>to penetrate<object>reinforced concrete<subject>a 2,000 - pound weapon<predicate>designed<object>to penetrate reinforced concrete<subject>A photograph of the site<predicate>showed<object>the unexploded bomb

>>> Sentence: A photograph of the site showed the unexploded bomb, a 2,000-pound weapon designed to penetrate reinforced concrete, but NATO would not confirm what other types of ordnance may have been dropped.

>>> Result: A photograph of the site showed the unexploded bomb, a 2,000-pound weapon designed to penetrate reinforced concrete, but NATO would not confirm what other types of ordnance may have been dropped.

>>> 13


Your max_length is set to 128, but your input_length is only 37. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=18)



>>> Relations: <subject>PRESS Inc<predicate>cites<object>insurance industry sources as saying 2,000 spectators are injured in a year<subject>her staff<predicate>would be<object>ready if he resigns<subject>she<predicate>told<object>the governor her staff would be ready if he resigns<subject>She<predicate>said<object>she told the governor her staff would be ready<subject>she<predicate>would be<object>ready if he resigns<subject>she<predicate>told<object>the governor she would be ready if he resigns<subject>She<predicate>said<object>she told the governor she would be ready<subject>he<predicate>resigns

>>> Sentence: PRESS Inc., a racing safety company, cites insurance industry sources as saying 2,000 spectators are injured in a year. She said she told the governor she and her staff would be ready if he resigns.

>>> Result: Press Inc. cites insurance industry sources as saying 2,000 spectators are injured in a year. She said she told the governor she and her staff would be ready if he re

Your max_length is set to 128, but your input_length is only 79. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=39)



>>> Relations: <subject>the International Monetary Fund chief<predicate>had agreed to turn over<object>his U.N. travel document<subject>the International Monetary Fund chief<predicate>had agreed<object>to turn over his U.N. travel document<subject>He<predicate>said<object>the International Monetary Fund chief had agreed to turn over his U.N. travel document then<subject>the International Monetary Fund chief<predicate>had agreed to post<object>$ 1 million in cash , to be confined to home detention in Manhattan with electronic monitoring<subject>the International Monetary Fund chief<predicate>had agreed<object>to post $ 1 million in cash<subject>He<predicate>said<object>the International Monetary Fund chief had agreed to post $ 1 million in cash then<subject>$ 1 million in cash<predicate>to be confined<object>to home detention in Manhattan with electronic monitoring<subject>everyone<predicate>was<object>interested in everyone else's business Here<subject>Cronin<predicate>said<object>Her

Your max_length is set to 128, but your input_length is only 51. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=25)
Your max_length is set to 128, but your input_length is only 15. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=7)



>>> Relations: <subject>Terrorism<predicate>has<object>no religion or country<subject>Jumma Atiga<predicate>said<object>Terrorism has no religion or country<subject>high<predicate>ranking<object>Libyan politician<subject>Camden 's Mayor Dana Redd<predicate>will rehire<object>50 15 firefighters<subject>Dana Redd<predicate>is Mayor of<object>Camden<subject>Camden 's Mayor Dana Redd<predicate>will rehire<object>50 police officers

>>> Sentence: "Terrorism has no religion or country," said Jumma Atiga, a high-ranking Libyan politician. Camden's Mayor Dana Redd will rehire 50 police officers and 15 firefighters.

>>> Result: "Terrorism has no religion or country," said Jumma Atiga, a high-ranking Libyan politician. Camden's Mayor Dana Redd will rehire 50 police officers and 15 firefighters.

>>> 17

>>> Relations: <subject>the victim<predicate>was<subject>it<predicate>did not specify<object>where the victim was from<subject>The victim<predicate>was not<object>of Dutch nationality<subject>t

Your max_length is set to 128, but your input_length is only 98. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=49)



>>> Relations: <subject>Zine el<predicate>is President of<object>Tunisia

>>> Sentence: Less than a month after the self-immolation, Tunisian President Zine el

>>> Result: Tunisia's President Zine el-Abidine Ben Ali.

>>> 19

>>> Relations: <subject>he<predicate>did n't accidentally say<object>that he was going to flaccin ' Disneyland<subject>he<predicate>to flaccin<object>Disneyland<subject>He<predicate>voted<object>the Most Valuable Player of the game<subject>He<predicate>accepted<object>an award for being<subject>the prominent fiscal matters front and center<predicate>highlight<object>his area of expertise now<subject>Ryan's recent reticence<predicate>is<object>more noticeable because the prominent fiscal matters front and center now highlight his area of expertise

>>> Sentence: He also accepted an award for being voted the Most Valuable Player of the game, but at least he didn't accidentally say that he was going to flaccin' Disneyland. Ryan's recent reticence is more noticeable

In [12]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Reload a fine-tuned model and its tokenizer from a local directory. This replaces the
# global `model` and `tokenizer` that the earlier cells created.
# model_checkpoint = 'facebook/bart-base'
# NOTE(review): the directory name is hard-coded with a fixed timestamp instead of being
# built from `train_end`, and this cell's execution count is out of order with the cells
# around it — confirm that this directory exists and is the model you intend to load.
model_checkpoint = 'bart-base-finetuned-CNN-2023-05-20_15-55-40'
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Presumably a no-op when the saved tokenizer already contains these tokens — TODO confirm.
tokenizer.add_tokens(['<subject>', '<predicate>', '<object>'])
# Keep the embedding matrix in sync with the tokenizer's vocabulary size.
model.resize_token_embeddings(len(tokenizer))

Embedding(50268, 768, padding_idx=1)

In [25]:
model

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50268, 768)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50268, 768)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0): BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    

In [24]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# model_checkpoint = f'bart-base-finetuned-CNN-{train_end}'
# model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to('cuda')
# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# tokenizer.add_tokens(['<subject>', '<predicate>', '<object>'])
# model.resize_token_embeddings(len(tokenizer))
 
import os
from getpass import getpass

# Create the repository (if needed) and upload the model
REPO_NAME = 'Cynki/rtsum_abs_bart' # ex) 'my-bert-fine-tuned'
# Never store an access token in the notebook: it ends up in version control and in every
# shared copy. The token that used to be hard-coded here must be revoked.
# Create a token at <https://huggingface.co/settings/token> and provide it through the
# HF_TOKEN environment variable, or type it in when prompted.
AUTH_TOKEN = os.environ.get('HF_TOKEN') or getpass('Hugging Face access token: ')
if not AUTH_TOKEN:
    raise RuntimeError('A Hugging Face access token is required to upload the model.')

## Upload to Huggingface Hub
model.push_to_hub(
    REPO_NAME,
    use_temp_dir=True,
    use_auth_token=AUTH_TOKEN
)
tokenizer.push_to_hub(
    REPO_NAME,
    use_temp_dir=True,
    use_auth_token=AUTH_TOKEN
)

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]
pytorch_model.bin:   0%|          | 0.00/558M [00:00<?, ?B/s][A
pytorch_model.bin:   0%|          | 8.19k/558M [00:00<12:02:05, 12.9kB/s][A
pytorch_model.bin:   0%|          | 90.1k/558M [00:00<1:08:13, 136kB/s]  [A
pytorch_model.bin:   0%|          | 156k/558M [00:01<47:15, 197kB/s]   [A
pytorch_model.bin:   0%|          | 377k/558M [00:01<19:25, 478kB/s][A
pytorch_model.bin:   0%|          | 598k/558M [00:01<13:44, 676kB/s][A
pytorch_model.bin:   0%|          | 1.38M/558M [00:01<05:33, 1.67MB/s][A
pytorch_model.bin:   0%|          | 2.10M/558M [00:01<04:02, 2.29MB/s][A
pytorch_model.bin:   1%|          | 4.55M/558M [00:02<02:17, 4.03MB/s][A
pytorch_model.bin:   1%|          | 6.04M/558M [00:02<01:56, 4.75MB/s][A
pytorch_model.bin:   1%|▏         | 7.53M/558M [00:02<01:41, 5.40MB/s][A
pytorch_model.bin:   2%|▏         | 9.01M/558M [00:02<01:32, 5.92MB/s][A
pytorch_model.bin:   2%|▏         | 10.5M/558M [00:03<01:25,

CommitInfo(commit_url='https://huggingface.co/Cynki/rtsum_abs_bart/commit/3b0158af9599cbfce549e461f62848cdaca0de84', commit_message='Upload tokenizer', commit_description='', oid='3b0158af9599cbfce549e461f62848cdaca0de84', pr_url=None, pr_revision=None, pr_num=None)