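## Test SummaRuNNer

Load a trained SummaRuNNer checkpoint, score every sentence of the test set, and write, one file per document under `../outputs/`, the top-3 scoring sentences (`hyp/`) next to the two reference summaries (`ext_ref/`, `abs_ref/`).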

## Imports

In [1]:
%load_ext lab_black

In [2]:
import sys

# Put the parent of the working directory on the import path so the project
# packages imported below (`experiment`, `model`) can be found from here.
sys.path.append("..")

In [3]:
import os
import warnings
from collections import OrderedDict
from functools import partial

import dill
import numpy as np
import yaml

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TestTubeLogger  # pip install test-tube

from konlpy.tag import Mecab
from tqdm import tqdm

from experiment import Experiment
from model import SummaRunner
from model import SumDataset, Feature
from model import build_vocab, collate_fn
from model.types_ import *

warnings.filterwarnings(action="ignore")

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix the numpy / torch RNG seeds and force deterministic cuDNN kernels
# (no autotuning) so repeated runs of this notebook are reproducible.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Config file

In [5]:
config_path = "./config.yaml"

# Every later cell indexes into `config`. Re-raise parse errors so a malformed
# file fails here, instead of surfacing later as a NameError on `config`.
with open(config_path, "r", encoding="utf8") as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        raise ValueError(f"Could not parse YAML config {config_path!r}") from exc

## DataLoader

In [6]:
# ----------------
# DataLoader
# ----------------

# data paths (from config.yaml)
test_path = config["exp_params"]["test_path"]
vocab_path = config["exp_params"]["vocab_path"]

# vocab
# NOTE(review): dill.load, like pickle, can execute arbitrary code while loading.
# Only point vocab_path at files produced by this project.
with open(vocab_path, "rb") as f:
    word_index = dill.load(f)

# Feature wraps the vocab and a Mecab tokenizer; collate_fn uses it to build batches.
feature = Feature(word_index, Mecab())

# Dataset
testset = SumDataset(test_path)

# Number of loader worker processes; can be overridden via
# exp_params.num_workers in config.yaml, defaults to 8.
num_workers = config["exp_params"].get("num_workers", 8)

# DataLoader: keep the file order (shuffle=False) so output ids follow the test set.
test_loader = DataLoader(
    dataset=testset,
    batch_size=config["exp_params"]["batch_size"],
    shuffle=False,
    collate_fn=partial(collate_fn, feature=feature),
    num_workers=num_workers,
)

In [7]:
# Peek at a single batch. The unpacking order must match what collate_fn returns.
# It is kept identical to the unpacking in the Test section so the two cells
# cannot disagree. TODO confirm against model.collate_fn.
features, targets, doc_lens, ext_sums, abs_sums, docs = next(iter(test_loader))

In [8]:
# Show the first entry of `abs_sums` from the peeked batch (a plain string,
# judging by the rendered output).
abs_sums[0]

'국회 국방위원회 소속 민홍청 의원이 통일시대를 대비하여 자연환경이 우수하고 나라의 슬픈 역사를 간직하고 있는 DMZ를 보전하고 DMZ의 평화적인 이용을 내용으로 하는 특별법을 발의하였다. '

## Load model

In [9]:
ckpt_path = "../checkpoints/summarunnerepoch=13_val_loss=0.446.ckpt"

# map_location puts the stored tensors on DEVICE. This lets a checkpoint saved
# on a GPU load on a CPU-only machine, where plain torch.load would fail.
checkpoint = torch.load(ckpt_path, map_location=DEVICE)

# The state_dict keys carry a leading "model." (presumably the attribute name of
# the training wrapper; TODO confirm). Strip only that leading prefix so the keys
# match SummaRunner's parameter names. str.replace would also delete "model."
# found later in a key (e.g. inside "submodel.").
state_dict_prefix = "model."
checkpoint["state_dict"] = OrderedDict(
    (
        key[len(state_dict_prefix) :] if key.startswith(state_dict_prefix) else key,
        val,
    )
    for key, val in checkpoint["state_dict"].items()
)

In [10]:
# ----------------
# Set up model
# ----------------

# Override vocab_size with the size of the loaded vocab, so the embedding table
# matches word_index (load_state_dict rejects mismatched parameter shapes).
model_params = config["model_params"]  # same dict object, so config is updated too
model_params["vocab_size"] = len(word_index)

model = SummaRunner(**model_params)
model = model.to(DEVICE)
model.load_state_dict(checkpoint["state_dict"])
# Last expression: switches to eval mode and displays the architecture.
model.eval()

SummaRunner(
  (abs_pos_embed): Embedding(100, 50)
  (rel_pos_embed): Embedding(25, 50)
  (encoder): Encoder(
    (sent_encoder): SentenceEncoder(
      (embed): Embedding(40002, 100, padding_idx=0)
      (bilstm): LSTM(100, 128, batch_first=True, bidirectional=True)
    )
    (doc_encoder): DocumentEncoder(
      (bilstm): LSTM(256, 128, batch_first=True, bidirectional=True)
    )
  )
  (fc): Linear(in_features=256, out_features=256, bias=True)
  (content): Linear(in_features=256, out_features=1, bias=False)
  (salience): Bilinear(in1_features=256, in2_features=256, out_features=1, bias=False)
  (novelty): Bilinear(in1_features=256, in2_features=256, out_features=1, bias=False)
  (abs_pos): Linear(in_features=50, out_features=1, bias=False)
  (rel_pos): Linear(in_features=50, out_features=1, bias=False)
)

## Test

In [12]:
num_topk = 3
file_id = 1

# Each document gets one file, named by a running 1-based id, in each sub-directory.
output_dir = "../outputs"
for sub_dir in ("ext_ref", "abs_ref", "hyp"):
    # open(..., "w") below does not create missing parent directories.
    os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True)


def _as_text(summary):
    """Return `summary` as text: a str unchanged, a sequence of sentences newline-joined.

    Newline-joining a str directly would put a newline between every character.
    `abs_sums` entries are plain strings (see the `abs_sums[0]` peek output).
    """
    if isinstance(summary, str):
        return summary
    return "\n".join(summary)


def _write_text(path, text):
    """Write `text` to `path` as UTF-8, replacing any existing file."""
    with open(path, "w", encoding="utf8") as f:
        f.write(text)


# model.eval() only changes layer behaviour. no_grad additionally stops autograd
# from recording the forward pass, which inference never back-propagates through.
with torch.no_grad():
    for batch in tqdm(test_loader):
        features, targets, doc_lens, ext_sums, abs_sums, docs = batch
        # Assumes `preds` holds one score per sentence for the whole batch, with
        # documents concatenated in order; the slicing below relies on it.
        preds = model(features.to(DEVICE), doc_lens)

        start = 0
        for doc_id, doc_len in enumerate(doc_lens):
            doc_len = int(doc_len)
            stop = start + doc_len
            pred = preds[start:stop]

            # topk raises when k exceeds the number of sentences, so shorter
            # documents contribute all of their sentences instead.
            k = min(num_topk, doc_len)
            # Sort the selected indices to keep the sentences in document order.
            topk_indices = sorted(pred.topk(k)[1].tolist())

            doc = docs[doc_id]
            hyp = [doc[idx] for idx in topk_indices]
            ext_ref = ext_sums[doc_id]
            abs_ref = abs_sums[doc_id]

            _write_text(f"{output_dir}/ext_ref/{file_id}.txt", _as_text(ext_ref))
            _write_text(f"{output_dir}/abs_ref/{file_id}.txt", _as_text(abs_ref))
            _write_text(f"{output_dir}/hyp/{file_id}.txt", _as_text(hyp))

            start = stop
            file_id += 1

  3%|▎         | 157/5000 [00:38<19:44,  4.09it/s]


In [13]:
import os

In [None]:
# Sanity check: the Test loop writes one hypothesis file per document, so the
# first number should be at least the number of test documents (second number).
len(os.listdir("../outputs/hyp")), len(testset)