In [None]:
!pip3 install accelerate

In [None]:
# TPU environment only (TPU 환경): uncomment these lines when running on a TPU.
# !pip3 uninstall -y tensorflow
# !pip3 install accelerate==0.26.0

In [None]:
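# TPU environment only (TPU 환경): uncomment these lines when running on a TPU.
# NOTE(review): presumably clears TPU runtime variables that interfere with
# accelerate's notebook_launcher — confirm for your TPU setup. The pop() calls
# use a default so re-running this cell does not raise KeyError.
# import os
# os.environ["TPU_NAME"] = os.environ["TPU_WORKER_ID"]
# os.environ.pop('TPU_PROCESS_ADDRESSES', None)
# os.environ.pop('CLOUD_TPU_TASK_ID', None)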

In [None]:
from datasets import load_dataset
from transformers import BartConfig, BartTokenizerFast, BartForConditionalGeneration

def preprocess_data(example, tokenizer):
    """Tokenize news articles and attach the tokenized summaries as labels.

    Intended for ``Dataset.map(..., batched=True)``, so ``example`` maps each
    column name to a list of values for the current batch.

    Args:
        example: Batch holding a "document" column (source text) and a
            "summary" column (target text).
        tokenizer: Hugging Face tokenizer. Both inputs and targets are
            truncated to its ``model_max_length``.

    Returns:
        The tokenizer output. For a Hugging Face tokenizer this contains the
        encoded documents, plus ``labels`` built from ``text_target``.
    """
    documents = example["document"]
    summaries = example["summary"]
    return tokenizer(documents, text_target=summaries, truncation=True)

# Hugging Face Hub id of the pretrained Korean BART checkpoint to fine-tune.
model_name = "gogamza/kobart-base-v2"
tokenizer = BartTokenizerFast.from_pretrained(model_name)
# Only the config is needed here (for the position-embedding limit). The model
# weights are loaded right before training, so don't load them twice.
model_config = BartConfig.from_pretrained(model_name)

dataset = load_dataset("daekeun-ml/naver-news-summarization-ko")
print(dataset)

# Truncate documents and summaries to the number of positions the model embeds.
tokenizer.model_max_length = model_config.max_position_embeddings
# Pass the tokenizer through fn_kwargs instead of a lambda closure, which is the
# documented way to give extra arguments to the mapped function.
processed_dataset = dataset.map(
    preprocess_data,
    fn_kwargs={"tokenizer": tokenizer},
    batched=True,
    remove_columns=dataset["train"].column_names,
)

# Select the row first so only one example is materialized, not the whole
# "labels" column of the training split.
sample = processed_dataset["train"][0]["labels"]
print(sample)
print(tokenizer.decode(sample))

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq
from accelerate.utils import set_seed
from accelerate import Accelerator, notebook_launcher

def create_dataloaders(batch_size):
    """Build the train and validation DataLoaders over the tokenized dataset.

    Reads the notebook-level ``tokenizer`` and ``processed_dataset``. Each batch
    is padded to its longest sequence (``padding="longest"``). Only the
    training split is shuffled.

    Args:
        batch_size: Batch size used by both DataLoaders.

    Returns:
        Tuple ``(train_dataloader, eval_dataloader)``.
    """
    collate = DataCollatorForSeq2Seq(
        tokenizer=tokenizer, padding="longest", return_tensors="pt"
    )

    def _loader(split, shuffle):
        # Same collator and batch size for every split; only shuffling differs.
        return DataLoader(
            processed_dataset[split],
            shuffle=shuffle,
            batch_size=batch_size,
            collate_fn=collate,
        )

    train_loader = _loader("train", True)
    validation_loader = _loader("validation", False)
    return train_loader, validation_loader

def training_loop(model, epochs, seed, mixed_precision, batch_size, logging_steps,
                  learning_rate=5e-5, output_dir=None):
    """Fine-tune ``model`` for summarization; intended to run via ``notebook_launcher``.

    Each launched process seeds itself and builds its own ``Accelerator``,
    dataloaders and Adam optimizer. It then trains for ``epochs`` epochs and
    reports the mean validation loss after every epoch.

    Args:
        model: Seq2seq model whose forward pass returns ``.loss`` when the batch
            contains ``labels``.
        epochs: Number of passes over the training split.
        seed: Seed passed to ``accelerate.utils.set_seed``.
        mixed_precision: Passed to ``Accelerator`` (e.g. "no", "fp16", "bf16").
        batch_size: Batch size given to each process's DataLoaders.
        logging_steps: Print the training loss every this many steps. A value
            <= 0 disables per-step logging instead of dividing by zero.
        learning_rate: Adam learning rate. The default 5e-5 matches the value
            that was previously hard-coded.
        output_dir: If given, the fine-tuned model and the notebook-level
            ``tokenizer`` are saved there after training. Training runs in
            launched worker processes, so weight updates are not expected to
            reach the notebook's own ``model`` object; saving is how to keep
            them. NOTE(review): confirm for your launcher setup.
    """
    set_seed(seed)
    accelerator = Accelerator(mixed_precision=mixed_precision)
    train_dataloader, eval_dataloader = create_dataloaders(batch_size)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    for epoch in range(epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            # Guard the modulo: logging_steps == 0 would raise ZeroDivisionError.
            if logging_steps > 0 and step % logging_steps == 0:
                accelerator.print(f"epoch {epoch}: {loss.item()}")

        # Report validation loss using the prepared (previously unused) eval loader.
        model.eval()
        eval_losses = []
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(**batch)
            # Repeat the batch-mean loss once per example so gather_for_metrics
            # can drop samples duplicated to even out the final batch across
            # processes.
            eval_losses.append(
                accelerator.gather_for_metrics(
                    outputs.loss.repeat(batch["labels"].shape[0])
                )
            )
        if eval_losses:
            eval_loss = torch.cat(eval_losses).mean().item()
            accelerator.print(f"epoch {epoch}: eval_loss {eval_loss}")

    if output_dir is not None:
        # All processes must finish training before the main process writes files.
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            output_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(output_dir)

# Training hyper-parameters, in the positional order training_loop expects.
NUM_EPOCHS = 5
SEED = 2024
MIXED_PRECISION = "fp16"
BATCH_SIZE = 8  # batch size given to each process's DataLoaders
LOGGING_STEPS = 100
NUM_PROCESSES = 4  # NOTE(review): should match the number of available accelerators

# Load the pretrained weights that are handed to the launched processes.
model = BartForConditionalGeneration.from_pretrained(model_name)
args = (model, NUM_EPOCHS, SEED, MIXED_PRECISION, BATCH_SIZE, LOGGING_STEPS)
notebook_launcher(training_loop, args, num_processes=NUM_PROCESSES)