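# 01. Preprocess

Build (or load) the word vocabulary, wrap the JSONL summarization splits in a `Dataset`, and turn batches into padded word-index features with per-sentence extractive labels.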

## imports

In [1]:
%load_ext lab_black

In [2]:
import sys

# Make the parent directory importable — presumably so the local `model` package imported below resolves.
sys.path.append("..")

In [3]:
import warnings

# NOTE(review): this silences *every* warning for the rest of the kernel session, including
# deprecation warnings from torch / pytorch_lightning — consider narrowing by category or module.
warnings.filterwarnings(action="ignore")

In [4]:
import json
import platform
import pickle
import dill
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TestTubeLogger  # pip install test-tube

from functools import partial
from collections import defaultdict, OrderedDict, Counter
from tqdm import tqdm

from model import SummaRunner

# from utils.data import SumDataset, Feature
from model import build_vocab
from model.types_ import *

In [5]:
# Mecab is provided by a different package per OS: `eunjeon` on Windows, `konlpy` elsewhere.
if platform.system() == "Windows":
    try:
        from eunjeon import Mecab
    except ImportError as err:
        # Fail here with an actionable message instead of a NameError at the first `Mecab()` call.
        raise ImportError("please install eunjeon module") from err
else:  # non-Windows (originally written for Ubuntu)
    from konlpy.tag import Mecab

## Load Data

In [6]:
# JSONL files for the three splits; all live in the same directory.
train_path, dev_path, test_path = (
    "../../../../datasets/kor_data/total_data/train_50965.jsonl",
    "../../../../datasets/kor_data/total_data/dev_50965.jsonl",
    "../../../../datasets/kor_data/total_data/test_50965.jsonl",
)

In [7]:
# Parse the training split: one JSON object per line. Blank lines are skipped because
# json.loads rejects them.
with open(train_path, "r", encoding="utf-8") as f:
    train_data = [json.loads(line) for line in f if line.strip()]

In [8]:
# train_data[0]

## Build Vocab Function

In [9]:
def build_vocab(
    dataset: JSONType,
    stopwords: Optional[List[str]] = None,
    num_words: int = 40000,
    tokenizer=None,
):
    """Build word<->index lookup tables from the tokenized ``article_original`` sentences.

    NOTE(review): this cell redefines (shadows) the ``build_vocab`` imported from ``model``
    at the top of the notebook; keep only one of the two definitions.

    Args:
        dataset: iterable of records, each with an ``"article_original"`` list of sentence strings.
        stopwords: tokens to drop before counting; ``None`` or empty keeps every token.
        num_words: number of most frequent tokens to keep (not counting the two special tokens).
        tokenizer: object with a ``morphs(str) -> list[str]`` method; defaults to a new ``Mecab()``.

    Returns:
        ``(word_index, index_word)``: ``word_index`` maps ``"<PAD>"`` -> 0, ``"<UNK>"`` -> 1 and
        the kept tokens -> 2, 3, ... in descending frequency; ``index_word`` is its inverse.
    """
    if tokenizer is None:
        tokenizer = Mecab()
    # A set makes each stopword test O(1) instead of a linear scan over a list.
    stopword_set = set(stopwords) if stopwords else set()

    # Count incrementally instead of materializing every token of the corpus in one list.
    counts = Counter()
    for data in tqdm(dataset):
        for sent in data["article_original"]:
            counts.update(
                token for token in tokenizer.morphs(sent) if token not in stopword_set
            )

    # Special tokens take the first two indices; vocabulary words start at 2.
    word_index = {"<PAD>": 0, "<UNK>": 1}
    for idx, (word, _) in enumerate(counts.most_common(num_words), 2):
        word_index[word] = idx

    index_word = {idx: word for word, idx in word_index.items()}

    return word_index, index_word

In [10]:
# word_index, index_word = build_vocab(train_data)

# with open("./word_index_v02.pkl", "wb") as f:
#     dill.dump(word_index, f)

In [11]:
# NOTE(review): dill (like pickle) executes arbitrary code while loading — only load vocab
# files you produced yourself.
with open("../utils//word_index_v02.pkl", "rb") as f:
    word_index = dill.load(f)

# `Feature` hard-codes PAD_IDX = 0 and UNK_IDX = 1, so the loaded vocabulary must agree.
assert (
    word_index.get("<PAD>") == 0 and word_index.get("<UNK>") == 1
), "loaded vocabulary does not map <PAD>/<UNK> to 0/1"

In [12]:
# Expected to equal num_words (40000 by default in build_vocab) + the 2 special tokens <PAD>/<UNK>.
len(word_index)

40002

## Dataset Class

In [13]:
class SumDataset(Dataset):
    """Map-style dataset over a JSON Lines summarization file.

    Each non-blank line must be a JSON object with the keys ``"article_original"``
    (list of sentence strings), ``"extractive"`` (indices of the summary sentences)
    and ``"abstractive"``.
    """

    def __init__(self, path):
        # Parse line by line; blank lines are skipped because json.loads rejects them.
        with open(path, "r", encoding="utf-8") as f:
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sentences, extractive_indices, abstractive)`` for record ``idx``."""
        record = self.data[idx]
        return record["article_original"], record["extractive"], record["abstractive"]

In [14]:
# Training split wrapped as a map-style dataset.
trainset = SumDataset(path=train_path)

## Feature Class

In [15]:
class Feature:
    """Turn documents into left-padded word-index matrices for the model.

    Args:
        word_index: mapping token -> index; expected to contain ``"<PAD>"`` -> 0 and
            ``"<UNK>"`` -> 1 (the indices hard-coded below).
        tokenizer: object exposing ``morphs(str) -> list[str]`` (e.g. Mecab).
    """

    def __init__(self, word_index, tokenizer):
        self.word_index = word_index
        self.index_word = {idx: word for word, idx in word_index.items()}
        # Fails when two words share an index, which would make index_word lossy.
        assert len(self.word_index) == len(
            self.index_word
        ), "word_index values must be unique"
        self.PAD_IDX = 0
        self.UNK_IDX = 1
        self.PAD_TOKEN = "<PAD>"
        self.UNK_TOKEN = "<UNK>"
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.word_index)

    def index_to_word(self, idx):
        """Return the word for ``idx``; raises KeyError for unknown indices."""
        return self.index_word[idx]

    def word_to_index(self, w):
        """Return the index of ``w``, or ``UNK_IDX`` for out-of-vocabulary words."""
        return self.word_index.get(w, self.UNK_IDX)

    def _sents_to_features(self, sents, sent_trunc):
        """Tokenize, truncate to ``sent_trunc`` tokens and left-pad ``sents`` to one length.

        Every returned row has the length of the longest (post-truncation) sentence;
        PAD_IDX fills the *front* of shorter rows. Returns ``[]`` for no sentences.
        """
        tokenized = [self.tokenizer.morphs(sent)[:sent_trunc] for sent in sents]
        max_sent_len = max((len(words) for words in tokenized), default=0)
        return [
            [self.PAD_IDX] * (max_sent_len - len(words))
            + [self.word_to_index(w) for w in words]
            for words in tokenized
        ]

    ###################
    # Create Features #
    ###################
    def make_features(
        self,
        docs,
        ext_idx_list,
        summaries_list,
        doc_trunc=100,
        sent_trunc=128,
        split_token="\n",
    ):
        """Build the padded sentence matrix and per-sentence labels for a batch of documents.

        Args:
            docs: list of documents, each a list of sentence strings.
            ext_idx_list: per document, indices of the sentences in the extractive summary.
            summaries_list: per document, the abstractive summary (passed through unchanged).
            doc_trunc: keep at most this many sentences per document.
            sent_trunc: keep at most this many tokens per sentence.
            split_token: unused (documents arrive already split); kept for compatibility.

        Returns:
            ``(features, targets, doc_lens, ext_sums, abs_sums, docs)``:
            ``features`` has one left-padded row per kept sentence of the whole batch
            (documents concatenated in order, padded to the batch-wide longest sentence);
            ``targets[i]`` holds 0/1 labels for document ``i``'s kept sentences;
            ``doc_lens[i]`` is its kept sentence count; ``ext_sums[i]`` its label-1
            sentences; ``abs_sums`` the summaries as given; ``docs`` the untruncated input.
        """
        all_sents, targets, doc_lens, ext_sums, abs_sums = [], [], [], [], []
        for doc, ext_indices, abs_sum in zip(docs, ext_idx_list, summaries_list):
            # Keep only the first `doc_trunc` sentences of the document.
            sents = doc[: min(doc_trunc, len(doc))]
            ext_set = set(ext_indices)
            labels = [1 if i in ext_set else 0 for i in range(len(sents))]

            all_sents.extend(sents)
            targets.append(labels)
            doc_lens.append(len(sents))
            ext_sums.append([sent for sent, label in zip(sents, labels) if label == 1])
            abs_sums.append(abs_sum)

        # Tokenize and pad the whole batch once (not once per document).
        features = self._sents_to_features(all_sents, sent_trunc)
        return features, targets, doc_lens, ext_sums, abs_sums, docs

    def make_predict_features(
        self, docs, sent_trunc=128, doc_trunc=100, split_token=". ",
    ):
        """Split raw document strings into sentences and build the padded sentence matrix.

        Args:
            docs: list of document strings.
            sent_trunc: keep at most this many tokens per sentence.
            doc_trunc: keep at most this many sentences per document.
            split_token: separator used to split each document into sentences.

        Returns:
            ``(features, doc_lens)`` with the same ``features`` layout as ``make_features``.
        """
        all_sents, doc_lens = [], []
        for doc in docs:
            sents = doc.split(split_token)
            sents = sents[: min(doc_trunc, len(sents))]
            all_sents.extend(sents)
            doc_lens.append(len(sents))

        return self._sents_to_features(all_sents, sent_trunc), doc_lens

## DataLoader 

### collate_fn

In [30]:
def collate_fn(batch, feature):
    """Collate ``(doc, ext_indices, summary)`` samples into one batch.

    ``feature.make_features`` does the tokenization and padding; this function only
    transposes the samples into per-field lists and tensorizes the numeric outputs.

    Returns:
        ``(features, targets, doc_lens, max_doc_len, ext_sums, abs_sums, docs)`` where
        ``features`` and ``doc_lens`` are LongTensors, ``max_doc_len`` is the largest entry of
        ``doc_lens`` and the remaining items are passed through from ``make_features``.
    """
    # Transpose the list of samples into one list per field.
    docs, ext_idx_list, summaries_list = (
        [sample[field] for sample in batch] for field in range(3)
    )

    (
        features,
        targets,
        doc_lens,
        ext_sums,
        abs_sums,
        docs,
    ) = feature.make_features(docs, ext_idx_list, summaries_list)

    # `targets` stays a list of lists: documents contribute different sentence counts.
    feature_tensor = torch.LongTensor(features)
    longest_doc = max(doc_lens)
    return (
        feature_tensor,
        targets,
        torch.LongTensor(doc_lens),
        longest_doc,
        ext_sums,
        abs_sums,
        docs,
    )

In [31]:
# Feature class
mecab = Mecab()  # morphological tokenizer handed to Feature for every batch
feature = Feature(tokenizer=mecab, word_index=word_index)

In [32]:
# DataLoader
# Windows (and macOS) start DataLoader workers with "spawn", which must unpickle `collate_fn`
# and `feature` in a fresh process; functions defined in a notebook's __main__ usually cannot
# be re-imported there, so load batches in the main process on those platforms.
train_loader = DataLoader(
    dataset=trainset,
    batch_size=2,
    shuffle=True,
    collate_fn=partial(collate_fn, feature=feature),
    num_workers=0 if platform.system() in ("Windows", "Darwin") else 8,
)

In [33]:
# Pull a single batch for inspection; an empty loader raises StopIteration here instead of
# silently leaving the names below undefined.
batch = next(iter(train_loader))
features, targets, doc_lens, max_doc_len, ext_sums, abs_sums, docs = batch

In [34]:
# Per-document 0/1 labels, one per kept sentence (1 = sentence index listed in "extractive").
targets

[[0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

In [36]:
# (total kept sentences across the batch, longest tokenized sentence in the batch);
# 31 rows here = 19 + 12 sentences from the two documents in `targets`.
features.shape

torch.Size([31, 73])