### Configuration


In [None]:
import os
import shutil
import sys
import torch

# Prepend the directory two levels above the working directory (presumably the
# repository root holding the `utils_nlp` package) to the import path, once.
nlp_path = os.path.abspath(os.path.join("..", ".."))
if nlp_path not in sys.path:
    sys.path[:0] = [nlp_path]

from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset
from utils_nlp.models.transformers.extractive_summarization import (
    ExtractiveSummarizer,
    ExtSumProcessor,
)

from utils_nlp.models.transformers.datasets import SummarizationDataset

In [None]:
# Target CUDA and make GPU index 0 the current device for this process.
device = torch.device("cuda")
torch.cuda.set_device(0)

In [None]:
MODEL_NAME: str = "bert-base-uncased"  # pretrained model name given to the processor and summarizer

In [None]:
# Cache directory used while fine-tuning; an empty string is passed through as-is.
CACHE_DIR = ""

# Build the extractive-summarization processor for the chosen model.
processor = ExtSumProcessor(
    model_name=MODEL_NAME,
    cache_dir=CACHE_DIR,
)

### Data preprocessing

In [None]:
# The CNN/DailyMail data is cached in the same directory as the fine-tuning cache.
DATA_PATH: str = CACHE_DIR

In [None]:
# Load the CNN/DailyMail train/test splits, caching under DATA_PATH.
# NOTE(review): top_n=-1 presumably means "no truncation" — confirm in
# CNNDMSummarizationDataset.
train_dataset, test_dataset = CNNDMSummarizationDataset(
    top_n=-1,
    local_cache_path=DATA_PATH,
)

# Convert both splits into model inputs, labelling oracle sentences greedily.
ext_sum_train, ext_sum_test = (
    processor.preprocess(split, oracle_mode="greedy")
    for split in (train_dataset, test_dataset)
)

In [None]:
# Persist the preprocessed splits so a later session can skip preprocessing.
save_path = os.path.join("./", "processed")
# torch.save does not create missing parent directories; create the folder
# first so a fresh checkout does not fail with FileNotFoundError.
os.makedirs(save_path, exist_ok=True)
torch.save(ext_sum_train, os.path.join(save_path, "train_full.pt"))
torch.save(ext_sum_test, os.path.join(save_path, "test_full.pt"))

# To reuse previously saved data instead of preprocessing again, run:
# ext_sum_train = torch.load(os.path.join(save_path, "train_full.pt"))
# ext_sum_test = torch.load(os.path.join(save_path, "test_full.pt"))

### Model training

In [None]:
# Passed to `summarizer.fit(use_preprocessed_data=...)` and used in the
# checkpoint file name. The misspelling is kept because other cells reference it.
USE_PREPROCSSED_DATA = True

# Batch size handed to `fit`.
# NOTE(review): confirm its unit (examples vs. tokens) in ExtractiveSummarizer.fit.
BATCH_SIZE = 10
# Maximum position/sequence length given to the ExtractiveSummarizer.
MAX_POS_LENGTH = 512

NUM_GPUS = 1
# Encoder type given to the ExtractiveSummarizer.
# NOTE(review): presumably the layer stacked on the BERT outputs — confirm.
ENCODER = "transformer"

LEARNING_RATE = 2e-3

REPORT_EVERY = 100
# Step counts must be integers. Float literals like 5e4 would pass floats into
# `fit` and render as "50000.0" in the checkpoint file name.
MAX_STEPS = 50_000
WARMUP_STEPS = 5_000

In [None]:
# Build the extractive summarizer. Arguments are positional, in the order the
# constructor expects: processor, model name, encoder, max position length,
# cache directory.
summarizer = ExtractiveSummarizer(
    processor,
    MODEL_NAME,
    ENCODER,
    MAX_POS_LENGTH,
    CACHE_DIR,
)

In [None]:
# Fine-tune on the preprocessed training split.
# NOTE(review): gradient_accumulation_steps=2 presumably means one optimizer
# step per two batches — confirm in ExtractiveSummarizer.fit.
summarizer.fit(
    ext_sum_train,
    num_gpus=NUM_GPUS,
    batch_size=BATCH_SIZE,
    gradient_accumulation_steps=2,
    max_steps=MAX_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    verbose=False,
    report_every=REPORT_EVERY,
    clip_grad_norm=False,
    use_preprocessed_data=USE_PREPROCSSED_DATA,
)

In [None]:
# Save the fine-tuned model. The file name records the model, the
# preprocessing flag and the step budget. An empty prefix directory means the
# joined path is just the file name.
model_save_path = ""
model_file_name = (
    f"extsum_modelname_{MODEL_NAME}"
    f"_usepreprocess{USE_PREPROCSSED_DATA}"
    f"_steps_{MAX_STEPS}.pt"
)
summarizer.save_model(os.path.join(model_save_path, model_file_name))