## Train SummaRuNNer

## imports 

In [1]:
%load_ext lab_black

In [2]:
import sys

sys.path.append("..")

In [4]:
import dill
import yaml
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TestTubeLogger  # pip install test-tube

from functools import partial
from collections import OrderedDict
from konlpy.tag import Mecab

from model import SummaRunner
from model import SumDataset, Feature
from model import build_vocab
from model.types_ import *

import warnings

warnings.filterwarnings(action="ignore")

In [5]:
# DEVICE = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## config file

In [6]:
config_path = "./config.yaml"

with open(config_path, "r") as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(exc)

In [7]:
config

{'model_params': {'embed_dim': 100,
  'hidden_dim': 128,
  'num_layers': 1,
  'bidirectional': True,
  'dropout_p': 0.3,
  'maxlen': 50},
 'exp_params': {'train_path': '../../../../datasets/kor_data/total_data/train_50965.jsonl',
  'dev_path': '../../../../datasets/kor_data/total_data/dev_50965.jsonl',
  'test_path': '../../../../datasets/kor_data/total_data/test_50965.jsonl',
  'vocab_path': './word_index_v02.pkl',
  'batch_size': 32,
  'LR': 0.001},
 'trainer_params': {'gpus': 1, 'max_epochs': 30},
 'logging_params': {'save_dir': 'logs/',
  'name': 'SummaRuNNer',
  'manual_seed': 42}}

## Model

In [10]:
# ----------------
# SetUp Model
# ----------------

# vocab_size
config["model_params"]["vocab_size"] = len(word_index)
# num_class
config["model_params"]["num_class"] = 1

model = SummaRunner(**config["model_params"])
experiment = Experiment(model, config["exp_params"])

In [11]:
# ----------------
# TestTubeLogger
# ----------------
tt_logger = TestTubeLogger(
    save_dir=config["logging_params"]["save_dir"],
    name=config["logging_params"]["name"],
    debug=False,
    create_git_tag=False,
)

# ----------------
# Checkpoint
# ----------------
checkpoint_callback = ModelCheckpoint(
    filepath="./checkpoints/summarunner{epoch:02d}_{val_loss:.3f}",
    monitor="val_loss",
    verbose=True,
    save_top_k=5,
)

early_stopping = EarlyStopping(monitor="val_loss", patience=5, verbose=True)

EarlyStopping mode auto is unknown, fallback to auto mode.
EarlyStopping mode set to min for monitoring val_loss.


## Trainer

In [12]:
# ----------------
# Trainer
# ----------------

runner = Trainer(
    default_save_path=f"{tt_logger.save_dir}",
    min_epochs=1,
    logger=tt_logger,
    log_save_interval=100,
    train_percent_check=1.0,
    val_percent_check=1.0,
    num_sanity_val_steps=5,
    early_stop_callback=early_stopping,
    checkpoint_callback=checkpoint_callback,
    **config["trainer_params"],
)

GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]


## Train

In [None]:
# ----------------
# Start Train
# ----------------
runner.fit(experiment, train_loader, dev_loader)


   | Name                              | Type            | Params
------------------------------------------------------------------
0  | model                             | SummaRunner     | 4 M   
1  | model.abs_pos_embed               | Embedding       | 5 K   
2  | model.rel_pos_embed               | Embedding       | 500   
3  | model.encoder                     | Encoder         | 4 M   
4  | model.encoder.sent_encoder        | SentenceEncoder | 4 M   
5  | model.encoder.sent_encoder.embed  | Embedding       | 4 M   
6  | model.encoder.sent_encoder.bilstm | LSTM            | 235 K 
7  | model.encoder.doc_encoder         | DocumentEncoder | 395 K 
8  | model.encoder.doc_encoder.bilstm  | LSTM            | 395 K 
9  | model.fc                          | Linear          | 65 K  
10 | model.content                     | Linear          | 256   
11 | model.salience                    | Bilinear        | 65 K  
12 | model.novelty                     | Bilinear        | 65 K  
13 | mod

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…