In [1]:
!pip install transformers



In [2]:
!pip install torch



In [3]:
!pip install torch torchvision torchaudio



In [2]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the pretrained "facebook/bart-large-cnn" seq2seq checkpoint and its matching tokenizer.
model_name = "facebook/bart-large-cnn"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

In [5]:
# Load the article text to summarize.
# The `with` block guarantees the file is closed even if reading fails.
with open('../article_text.txt', 'r', encoding='utf-8') as f:
    article = f.read()

In [8]:
# Encode the article as a PyTorch tensor, truncating it to at most 1024 tokens.
inputs = tokenizer.encode(
    article,
    return_tensors="pt",
    truncation=True,
    max_length=1024,
)

# Decoding settings: 10 beams, summary length between 50 and 200 tokens.
generation_settings = {
    "num_beams": 10,
    "min_length": 50,
    "max_length": 200,
    "length_penalty": 2.0,
    "early_stopping": True,
}
summary_ids = model.generate(inputs, **generation_settings)

# Decode the first generated sequence without special tokens and display it.
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
summary

'NEW: Negotiations over occupation have stalled, student group says. NEW: Students occupying Hamilton Hall face expulsion from the university, the school says. Dozens of protesters are detained at the University of North Carolina in Chapel Hill. Portland State University in Oregon closed the campus Tuesday citing an "ongoing incident"'

In [12]:
# Load the "cnn_dailymail" dataset (config "3.0.0") through the Hugging Face `datasets` library.
from datasets import load_dataset

dataset = load_dataset(path='cnn_dailymail', name='3.0.0')

In [15]:
# preprocess datasets
def preprocess_function(examples, max_input_length=1024, max_target_length=142):
    """Tokenize a batch of articles and reference highlights for seq2seq fine-tuning.

    Intended for ``dataset.map(preprocess_function, batched=True)``; relies on the
    notebook-global ``tokenizer``.

    Args:
        examples: batch mapping with an "article" column (model input text) and a
            "highlights" column (reference summary text).
        max_input_length: articles are truncated/padded to exactly this many tokens.
        max_target_length: highlights are truncated/padded to exactly this many tokens.

    Returns:
        The tokenizer output for the articles, plus a "labels" entry holding the
        highlight token ids with every padding position replaced by -100.
    """
    model_inputs = tokenizer(
        list(examples["article"]),
        max_length=max_input_length,
        truncation=True,
        padding="max_length",
    )

    # `text_target` replaces the deprecated `as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=list(examples["highlights"]),
        max_length=max_target_length,
        truncation=True,
        padding="max_length",
    )

    # Padding to max_length fills the labels with pad_token_id. Left as-is, the
    # model would be trained to predict padding. -100 is the ignore_index of
    # PyTorch's CrossEntropyLoss, so those positions are excluded from the loss.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

# Apply preprocess_function to every split of `dataset`, one batch of rows per call.
tokenized_datasets = dataset.map(
    function=preprocess_function,
    batched=True,
)

Map: 100%|██████████| 287113/287113 [07:05<00:00, 674.58 examples/s]
Map: 100%|██████████| 13368/13368 [00:19<00:00, 668.47 examples/s]
Map: 100%|██████████| 11490/11490 [00:16<00:00, 698.19 examples/s]


In [17]:
!pip install transformers datasets torch accelerate


Collecting accelerate
  Downloading accelerate-0.29.3-py3-none-any.whl.metadata (18 kB)
Downloading accelerate-0.29.3-py3-none-any.whl (297 kB)
   ---------------------------------------- 0.0/297.6 kB ? eta -:--:--
   --------------------------------- ------ 245.8/297.6 kB 7.4 MB/s eta 0:00:01
   ---------------------------------------- 297.6/297.6 kB 6.1 MB/s eta 0:00:00
Installing collected packages: accelerate
Successfully installed accelerate-0.29.3


In [21]:
pip install accelerate==0.21.0

Collecting accelerate==0.21.0Note: you may need to restart the kernel to use updated packages.

  Downloading accelerate-0.21.0-py3-none-any.whl.metadata (17 kB)
Downloading accelerate-0.21.0-py3-none-any.whl (244 kB)
   ---------------------------------------- 0.0/244.2 kB ? eta -:--:--
   --------------------------------- ------ 204.8/244.2 kB 6.3 MB/s eta 0:00:01
   ---------------------------------------- 244.2/244.2 kB 3.8 MB/s eta 0:00:00
Installing collected packages: accelerate
  Attempting uninstall: accelerate
    Found existing installation: accelerate 0.29.3
    Uninstalling accelerate-0.29.3:
      Successfully uninstalled accelerate-0.29.3
Successfully installed accelerate-0.21.0


In [None]:
from transformers import Trainer, TrainingArguments

# Fine-tuning hyperparameters: 3 epochs, evaluation on the validation split after each epoch,
# at most 3 checkpoints kept in ./results, no external experiment tracker.
# NOTE(review): save_strategy is left at the library default; with ~287k training rows
# and batch size 4, consider aligning it with evaluation_strategy — confirm intended behavior.
# NOTE(review): plain Trainer reports eval loss; generation-based metrics (e.g. ROUGE)
# would presumably need Seq2SeqTrainer with predict_with_generate — verify if wanted.
training_config = {
    "output_dir": "./results",
    "evaluation_strategy": "epoch",
    "learning_rate": 2e-5,
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "weight_decay": 0.01,
    "save_total_limit": 3,
    "num_train_epochs": 3,
    "report_to": "none",
}
training_args = TrainingArguments(**training_config)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)

In [None]:
# Run fine-tuning and display the value returned by trainer.train() as the cell result.
train_result = trainer.train()
train_result

In [7]:
# Save the tokenizer and the model, each to its own local directory (tokenizer first).
# NOTE(review): because they are split, from_pretrained('summary_model') will not find the
# tokenizer there; saving both to one directory would make the checkpoint self-contained.
for component, save_dir in ((tokenizer, 'tokenizer'), (model, 'summary_model')):
    component.save_pretrained(save_dir)


Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
