In [None]:
import os
import json
import math
import torch
import pickle
import pathlib
import shutil
import warnings
import transformers

import numpy as np
import pandas as pd
import plotly.express as px

In [None]:
from pathlib import Path
from itertools import chain
from tqdm import tqdm

In [None]:
from sklearn import metrics
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from transformers import BertTokenizer
from transformers import BertModel

In [None]:
import nbimporter

from classify_comment_taptap_model import BERTClass

Artifacts stored for each run under `./model/<flag>/`:

- best: model, loss, epoch
- last: model, epoch

In [None]:
def clean(flag, first_round):
    """Wipe the checkpoint directory of run ``flag`` at the start of a first round.

    Nothing happens unless ``first_round`` is truthy. Otherwise a warning is
    emitted and ``./model/<flag>`` is deleted recursively; errors during the
    removal (including a missing directory) are ignored.
    """
    if not first_round:
        return

    warnings.warn(f'1st round, path ./model/{flag} will be removed if exists!')
    run_dir = os.path.join('./model', flag)
    shutil.rmtree(run_dir, ignore_errors=True)

In [None]:
def load_or_create_model(flag, model_name=None, mlb=None, device='cpu'):
    """Return the model for run ``flag``, resuming from disk when possible.

    Checkpoints live under ``./model/<flag>/``. The directory is created when
    missing. If ``last.bin`` exists there it is loaded, otherwise a fresh
    ``BERTClass(model_name, mlb)`` is built.

    Args:
        flag: Run name; used as the sub-directory of ``./model``.
        model_name: Forwarded to ``BERTClass`` when a new model is built.
        mlb: Forwarded to ``BERTClass`` when a new model is built.
        device: Device the model is moved to (default ``'cpu'``). Also used as
            ``map_location`` when loading an existing checkpoint.

    Returns:
        The loaded or newly created model, after ``.to(device)`` was called on it.
    """
    path = os.path.join('./model', flag)
    path_model_last = os.path.join(path, 'last.bin')

    if not os.path.exists(path):
        # exist_ok: do not crash if the directory appears between check and create.
        pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    if os.path.exists(path_model_last):
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored when only `device` is available.
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # checkpoints from a trusted source. Recent PyTorch releases default to
        # weights_only=True, which may refuse whole-module checkpoints like
        # this one -- confirm against the installed version.
        model = torch.load(path_model_last, map_location=device)
    else:
        model = BERTClass(model_name, mlb)

    model.to(device)

    return model

In [None]:
def better_loss(loss1, loss2):
    """Tell whether ``loss1`` beats ``loss2`` on both of its first two entries.

    Entry 0 is compared first; entry 1 is only looked at when entry 0 of
    ``loss1`` is strictly smaller. Both comparisons are strict (``<``).
    """
    first_is_lower = loss1[0] < loss2[0]
    if not first_is_lower:
        return first_is_lower
    return loss1[1] < loss2[1]

In [None]:
# find and save best model

def _read_acc_epoch(path_last_epoch):
    """Return the epoch count pickled at ``path_last_epoch``, or 0 if the file is absent."""
    if not os.path.exists(path_last_epoch):
        return 0
    with open(path_last_epoch, 'rb') as f:
        return pickle.load(f)


def create_or_update_best(flag, model, loss, epoch):
    """Store ``model`` as the best checkpoint of run ``flag`` if it qualifies.

    Files written under ``./model/<flag>/``: ``best.bin`` (the model, via
    ``torch.save``), ``best.epoch`` and ``best.loss`` (both pickled).

    The model is stored when no complete best checkpoint exists yet, or when
    ``better_loss(loss, stored_loss)`` is true. Otherwise nothing is written.

    Args:
        flag: Run name; used as the sub-directory of ``./model``.
        model: Object handed to ``torch.save``.
        loss: Loss record compared with ``better_loss`` and pickled to ``best.loss``.
        epoch: Epoch number within the current round. The value stored in
            ``best.epoch`` is this number plus the count found in
            ``last.epoch`` (0 when that file does not exist). ``last.epoch``
            is only read here, never written.
    """
    path = os.path.join('./model', flag)
    path_best_epoch = os.path.join(path, 'best.epoch')
    path_best_model = os.path.join(path, 'best.bin')
    path_best_loss = os.path.join(path, 'best.loss')

    # best.loss is written last below, so a save that was interrupted leaves
    # best.bin without it. Such a half-written checkpoint counts as "no best".
    has_best = os.path.exists(path_best_model) and os.path.exists(path_best_loss)
    if has_best:
        with open(path_best_loss, 'rb') as f:
            history_best_loss = pickle.load(f)
        if not better_loss(loss, history_best_loss):
            return

    # Same formula for the first and for later best checkpoints: epochs
    # accumulated by earlier rounds plus the epoch inside this round.
    best_epoch = _read_acc_epoch(os.path.join(path, 'last.epoch')) + epoch
    print(f'current best, epoch accumulate: {best_epoch}')

    # best model
    torch.save(model, path_best_model)
    # best epoch
    with open(path_best_epoch, 'wb') as f:
        pickle.dump(best_epoch, f)
    # best loss -- written last, marks the checkpoint as complete
    with open(path_best_loss, 'wb') as f:
        pickle.dump(loss, f)

In [None]:
def load_best(flag, device=None):
    """Load the best checkpoint of run ``flag`` from ``./model/<flag>/best.bin``.

    Args:
        flag: Run name; used as the sub-directory of ``./model``.
        device: Optional ``map_location`` for ``torch.load``. Pass e.g.
            ``'cpu'`` to restore a checkpoint saved on a GPU on a machine
            without one. ``None`` (default) keeps ``torch.load``'s own default.

    Returns:
        Whatever object ``torch.load`` reconstructs from the file.
    """
    path_best_model = os.path.join('./model', flag, 'best.bin')
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from a trusted source.
    return torch.load(path_best_model, map_location=device)

In [None]:
def load_acc_epoch(flag):
    """Read the accumulated epoch count stored for run ``flag``.

    The value is unpickled from ``./model/<flag>/last.epoch``. The file is
    opened unconditionally, so a missing file raises from ``open``.
    """
    epoch_file = os.path.join('./model', flag, 'last.epoch')

    with open(epoch_file, 'rb') as fh:
        acc_epoch = pickle.load(fh)
    return acc_epoch

In [None]:
def load_or_create_acc_epoch(flag, epoch=0):
    """Add ``epoch`` to the epoch counter of run ``flag`` and return the new total.

    The counter is pickled in ``./model/<flag>/last.epoch``. When the file
    does not exist yet, it is created holding ``epoch``. Otherwise the stored
    value plus ``epoch`` is written back. With the default ``epoch=0`` the
    stored value is returned unchanged (the file is still rewritten).
    """
    epoch_file = os.path.join('./model', flag, 'last.epoch')

    total = epoch
    if os.path.exists(epoch_file):
        # counter already on disk: accumulate on top of it
        with open(epoch_file, 'rb') as fh:
            total = pickle.load(fh) + epoch

    with open(epoch_file, 'wb') as fh:
        pickle.dump(total, fh)

    return total

In [None]:
# save the latest model and update the accumulated epoch count

def create_or_update_last(flag, model, epoch):
    """Store ``model`` as the latest checkpoint of run ``flag``.

    First ``epoch`` is added to the run's epoch counter through
    ``load_or_create_acc_epoch`` and the new total is printed. Then the model
    is written to ``./model/<flag>/last.bin`` with ``torch.save``, replacing
    any earlier file.
    """
    # bump the counter (created on first use) before writing the checkpoint
    total_epochs = load_or_create_acc_epoch(flag, epoch)
    print(f'last epoch accumulate: {total_epochs}')

    checkpoint_file = os.path.join('./model', flag, 'last.bin')
    torch.save(model, checkpoint_file)

In [None]:
# save train loss, auto increment

def create_or_update_loss(flag, loss):
    """Append ``loss`` to the loss history of run ``flag``.

    The history is pickled in ``./model/<flag>/loss``. When the file does not
    exist, ``loss`` itself becomes the history. Otherwise the stored history
    is extended with ``history + loss`` and written back.

    Args:
        flag: Run name; used as the sub-directory of ``./model``.
        loss: New loss records; must support ``+`` with the stored history
            (presumably a list of per-epoch records -- confirm with caller).

    Raises:
        TypeError: If ``loss`` cannot be concatenated to the stored history.
            The file on disk is left untouched in that case.
    """
    path_loss = os.path.join('./model', flag, 'loss')

    combined = loss
    if os.path.exists(path_loss):
        with open(path_loss, 'rb') as f:
            history_loss = pickle.load(f)
        # Concatenate BEFORE reopening for writing: opening with 'wb'
        # truncates the file, so a failing `+` would wipe the history.
        combined = history_loss + loss

    with open(path_loss, 'wb') as f:
        pickle.dump(combined, f)

In [None]:
# load the existing loss history

def load_loss(flag):
    """Return the loss history unpickled from ``./model/<flag>/loss``.

    The file is opened unconditionally, so a missing file raises from ``open``.
    """
    loss_file = os.path.join('./model', flag, 'loss')
    with open(loss_file, 'rb') as fh:
        return pickle.load(fh)

In [None]:
# plot loss, with history

def plot_loss(flag):
    """Draw the stored loss history of run ``flag`` and return it in long form.

    The history from ``load_loss(flag)`` is read as a two-column table
    (``train``, ``test``), one row per epoch. It is reshaped to one row per
    (epoch, stage) pair with columns ``epoch``, ``stage`` and ``loss``, where
    ``epoch`` starts at 1. A plotly line chart (one line per stage) is shown,
    and the long-form DataFrame is returned.
    """
    wide = pd.DataFrame(load_loss(flag), columns=['train', 'test'])

    # stack -> Series indexed by (row number, column name); name it and flatten
    long_df = (
        wide.stack()
        .rename('loss')
        .reset_index()
        .rename(columns={'level_0':'epoch', 'level_1':'stage'})
    )
    # row numbers are 0-based; epochs are reported 1-based
    long_df['epoch'] = long_df['epoch'] + 1

    figure = px.line(long_df, x="epoch", y="loss", color='stage')
    figure.show()

    return long_df