### Configuration


In [None]:
import os
import sys
import torch

# Resolve the directory two levels above the working directory and put it at
# the front of the import search path (once) so the `utils_nlp` imports below
# can be found.
nlp_path = os.path.abspath("../../")
if sys.path.count(nlp_path) == 0:
    sys.path.insert(0, nlp_path)

from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset
from utils_nlp.eval import compute_rouge_python
from utils_nlp.models.transformers.extractive_summarization import (
    ExtractiveSummarizer,
    ExtSumProcessor,
)

from utils_nlp.models.transformers.datasets import SummarizationDataset
import nltk
from nltk import tokenize
import pickle

In [None]:
# Use CUDA for this session and make GPU 0 the current CUDA device.
device = torch.device("cuda")
torch.cuda.set_device("cuda:0")

In [None]:
# Same values as the ones used for BertSum training.
# NOTE: the misspelled name USE_PREPROCSSED_DATA is kept as-is because the
# checkpoint-loading cell references it.
MODEL_NAME = "bert-base-uncased"
ENCODER = "transformer"
MAX_POS_LENGTH = 512
NUM_GPUS = 1
USE_PREPROCSSED_DATA = True
# Deliberately a float: it is formatted verbatim into the checkpoint filename.
MAX_STEPS = 5e4

# Locations are left empty here; fill them in before running.
CACHE_DIR = ""
model_save_path = ""

# Processor configured for the model named above.
processor = ExtSumProcessor(cache_dir=CACHE_DIR, model_name=MODEL_NAME)

In [None]:
# Restore a previously saved extractive summarizer.
import torch

# The checkpoint filename encodes the model name, the preprocessing flag and
# the step count from the configuration cell.
checkpoint_name = "extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt".format(
    MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS
)
model_path = os.path.join(model_save_path, checkpoint_name)

summarizer = ExtractiveSummarizer(
    processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR
)
# Weights are deserialized onto the CPU before being copied into the model.
summarizer.model.load_state_dict(torch.load(model_path, map_location="cpu"))

### XSUM PREPROCESSING

In [None]:
from datasets import load_dataset

# Load the XSum train split first, then the test split.
xsum_train, xsum_test = (
    load_dataset('xsum', split=split_name) for split_name in ('train', 'test')
)

In [None]:
# Collect (document, summary) text pairs from the training split, dropping
# any row where either field is empty.
train_src, train_tgt = [], []

for example in xsum_train:
    article = example['document']
    if len(article) > 0 and len(example['summary']) > 0:
        train_src.append(article)
        train_tgt.append(example['summary'])

In [None]:
# Wrap the in-memory training texts in a SummarizationDataset.
# Sources go through NLTK sentence splitting; NLTK's word tokenizer is passed
# as the word-level tokenizer.
sentence_splitters = [tokenize.sent_tokenize]
train_dataset = SummarizationDataset(
    None,  # first positional argument is left unset; texts are passed via `source`/`target`
    source=train_src,
    target=train_tgt,
    source_preprocessing=sentence_splitters,
    word_tokenize=nltk.word_tokenize,
)

In [None]:
# Build a processor with the same model name and cache directory as in the
# configuration cell, then preprocess the training dataset with it.
processor = ExtSumProcessor(cache_dir=CACHE_DIR, model_name=MODEL_NAME)
preprocessed_traindata = processor.preprocess(train_dataset)

### Train tokens

In [None]:
# Run the extractive summarizer over the preprocessed training articles.
# top_n: ratio of salient sentences to keep (value between 0 and 1).
# The "\n" separator used here is the one the token-building cell splits on.
prediction = summarizer.predict(
    preprocessed_traindata,
    num_gpus=1,
    batch_size=10,
    sentence_separator="\n",
    top_n=0.5,
)

In [None]:
from transformers import BartTokenizer

# BART tokenizer used to turn both articles and summaries into token ids.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

In [None]:
def _build_train_example(src_txt, extracted, tgt_txt, tokenizer):
    """Tokenize one article and record which token spans were extracted.

    Args:
        src_txt: Sentences of the article, in order (strings).
        extracted: Set of sentence strings the extractive summarizer selected.
        tgt_txt: Reference summary text.
        tokenizer: Tokenizer providing `encode`, `__call__`, `bos_token_id`
            and `eos_token_id`.

    Returns:
        A dict with
          'document': token ids of the whole article, starting with BOS and
                      ending with EOS,
          'summary':  token ids of the reference summary (with special tokens),
          'pos':      inclusive [start, end] index pairs into 'document' for
                      sentences that were extracted,
          'nos':      the same for sentences that were not extracted.
    """
    # Take the special ids from the tokenizer instead of hard-coding them
    # (they are 0 and 2 for the BART vocabulary).
    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id

    document = [bos_id]
    pos = []
    nos = []
    for sent in src_txt:
        # Every sentence is terminated with a newline before being encoded.
        tokens = tokenizer.encode(sent + '\n', add_special_tokens=False, truncation=True)
        span = [len(document), len(document) + len(tokens) - 1]
        if sent in extracted:
            pos.append(span)
        else:
            nos.append(span)
        document.extend(tokens)

    if len(document) > 1:
        # The final token (end of the last sentence) is replaced by EOS.
        document[-1] = eos_id
    else:
        # No sentence tokens at all: keep BOS and append EOS rather than
        # overwriting BOS.
        document.append(eos_id)

    summary = tokenizer(tgt_txt, add_special_tokens=True, truncation=True).input_ids
    return {'document': document, 'summary': summary, 'pos': pos, 'nos': nos}


# One record per prediction; records are aligned by index with
# `preprocessed_traindata`.
train_token = []

for i, predicted in enumerate(prediction):
    example = preprocessed_traindata[i]
    # A set makes the per-sentence "was it extracted?" lookup constant time.
    extracted = set(predicted.split('\n'))
    train_token.append(
        _build_train_example(example['src_txt'], extracted, example['tgt_txt'], tokenizer)
    )

In [None]:
# Persist the tokenized training records as a pickle file named 'train' in
# the current working directory.
train_save_path = 'train'
with open(train_save_path, mode='wb') as out_file:
    pickle.dump(train_token, out_file)

### Test tokens

In [None]:
# Tokenize the XSum test split into records of document / summary token ids.
test_token = []

for row in xsum_test:
    # Skip rows with empty text. The check must be done on the raw strings:
    # with add_special_tokens=True the tokenizer output always contains the
    # BOS/EOS ids, so a length check on the ids never fires.
    if not row['document'] or not row['summary']:
        continue
    document = tokenizer(row['document'], add_special_tokens=True, truncation=True).input_ids
    summary = tokenizer(row['summary'], add_special_tokens=True, truncation=True).input_ids
    test_token.append({
        'document': document,
        # Legacy misspelled key, kept as an alias of 'document' so existing
        # readers of this pickle keep working; drop it once they are updated.
        'docuemnt': document,
        'summary': summary,
    })

In [None]:
# Persist the tokenized test records as a pickle file named 'test' in the
# current working directory.
test_save_path = 'test'
with open(test_save_path, mode='wb') as out_file:
    pickle.dump(test_token, out_file)