In [None]:
import argparse
import time
import os
import gc
import math
import multiprocessing
import io
import logging 
import itertools
import shutil
import pysnooper
import warnings
import glob
import pendulum
import json
import sys
import wandb
import matplotlib
import matplotlib.pyplot as plt
import scikitplot as skplt
import seaborn as sns
import numpy as np
import pandas as pd

from icecream import ic
from collections import Counter
from tqdm import tqdm_notebook as tqdm
from pathlib import Path
from IPython.core.interactiveshell import InteractiveShell

import torch
import torch.nn.functional as F
import transformers
import torchmetrics
from torch import nn
from torch import cuda

# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

matplotlib.use('Agg')
warnings.filterwarnings("ignore")

seed = 9527
np.set_printoptions(suppress=True)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.WARNING)

# Fall back to the CPU when no CUDA device is visible.
device = 'cuda' if cuda.is_available() else 'cpu'
print('Torch version: ', torch.__version__)
print('Device available:', torch.cuda.is_available())
# torch.cuda.get_device_name(0) raises when there is no CUDA device, so only
# query it when one is available; otherwise report the CPU fallback.
print('Device name:', torch.cuda.get_device_name(0) if cuda.is_available() else 'cpu')
torch.set_printoptions(precision=8)


In [None]:
import argparse

def _str2bool(value):
    """Convert a command-line token such as ``"false"`` or ``"1"`` to a bool.

    argparse keeps option values as strings unless a ``type`` is given, and any
    non-empty string (including ``"False"``) is truthy, so boolean options need
    an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, got {!r}.".format(value))


# Every numeric / boolean option declares its type so that values supplied on
# the command line are converted instead of being stored as raw strings.
parser = argparse.ArgumentParser()
parser.add_argument('--input_file', default=None, help="Input raw text file (or comma-separated list of files).")
parser.add_argument('--output_file', default=None, help="Output TF example file (or comma-separated list of files).")
parser.add_argument('--vocab_file', default=None, help="The vocabulary file that the ALBERT model was trained on.")
parser.add_argument('--spm_model_file', default=None, help="The model file for sentence piece tokenization.")
parser.add_argument('--input_file_mode', default="r",  help="The data format of the input file.")
parser.add_argument('--do_lower_case', type=_str2bool, default=True, help="Whether to lower case the input text. Should be True for uncased models and False for cased models.")
parser.add_argument('--do_whole_word_mask', type=_str2bool, default=True, help="Whether to use whole word masking rather than per-WordPiece masking.")
parser.add_argument('--do_permutation', type=_str2bool, default=False, help="Whether to do the permutation training.")
parser.add_argument('--favor_shorter_ngram', type=_str2bool, default=True, help="Whether to set higher probabilities for sampling shorter ngrams.")
parser.add_argument('--random_next_sentence', type=_str2bool, default=False, help="Whether to use the sentence that's right before the current sentence "
                    "as the negative sample for next sentence prection, rather than using "
                    "sentences from other random documents.")
parser.add_argument('--max_seq_length', type=int, default=512, help="Maximum sequence length.")
parser.add_argument('--ngram', type=int, default=3, help="Maximum number of ngrams to mask.")
parser.add_argument('--max_predictions_per_seq', type=int, default=20, help="Maximum number of masked LM predictions per sequence.")
parser.add_argument('--random_seed', type=int, default=12345, help="Random seed for data generation.")
parser.add_argument('--dupe_factor', type=int, default=5, help="Number of times to duplicate the input data (with different masks).")
parser.add_argument('--masked_lm_prob', type=float, default=0.15, help="Masked LM probability.")
parser.add_argument('--short_seq_prob', type=float, default=0.1, help="Probability of creating sequences which are shorter than the maximum length.")


# Notebook stand-in for sys.argv: argparse expects string tokens, and the full
# option name is spelled out rather than relying on prefix matching.
opt = parser.parse_args(args=[
    '--input_file', '1995_income',
    '--output_file', 'MLP',
    '--spm_model_file', './wiki-ja_albert.model',
    '--vocab_file', './wiki-ja_albert.vocab',
    '--do_whole_word_mask', 'False',
    '--do_permutation', 'False',
    '--favor_shorter_ngram', 'False',
    '--random_next_sentence', 'False'
])


In [None]:
seed = 202105

# Root directories of the project and of its cache area.
main_path = Path('/home/jupyter/gogolook')
main_cached_path = main_path.joinpath('data')

# General locations derived from the two roots.
main_data_path = main_path.joinpath('data', 'jp_data')
main_model_path = main_path.joinpath('models')
cache_data_path = main_cached_path.joinpath('cache_data_dir')
cache_models_path = main_cached_path.joinpath('cache_models_fir')

# Model directories.
albert_zh_path = main_model_path.joinpath('albert_zh')

# Data: file pattern and the tag used to name each split's directory.
regex_file_format = '*.json'
data_tag = valid_data_tag = test_data_tag = 'pretraining_data'

# One directory per split: "<split>_<tag>" below the main data directory.
experiment_train_data_path = main_data_path.joinpath('train_' + data_tag)
experiment_valid_data_path = main_data_path.joinpath('valid_' + valid_data_tag)
experiment_test_data_path = main_data_path.joinpath('test_' + test_data_tag)

# Aliases used by the data-loading cells.
training_data_path, validation_data_path, testing_data_path = (
    experiment_train_data_path,
    experiment_valid_data_path,
    experiment_test_data_path,
)


In [None]:
# Labels that identify this run; they are handed to wandb.init further down.
project = 'jp-pretrain-model'
project_shortname = 'jp-sms'
group_tag = 'experiment'  # 1. functional 2. experiment 3. staging 4. production
job_type = 'baseline'  # 1. baseline 2. optimize 3. hyper-tuning
addition_tag = [data_tag, 'pytorch']  # exponential_decay
method_tag = 'pretrain'  # pretrain / finetune / pretrain_finetune

# The run id is "<project>_<job type>_<timestamp>", timestamp taken in the Asia/Taipei zone.
time_tag = pendulum.now(tz='Asia/Taipei').strftime('%Y%m%d%H%M%S')
run_id = '_'.join([project, job_type, time_tag])
print(f'Run id is {run_id}')



In [None]:
import pyarrow as pa
from datasets import load_dataset
from datasets import total_allocated_bytes
from dataclasses import dataclass, field
from typing import Dict, Optional, Union, List
from transformers import (
    BertTokenizer, AlbertForPreTraining, AlbertModel, AlbertConfig, PreTrainedTokenizer)

from transformers import AutoModel, AutoTokenizer, BertJapaneseTokenizer
#import datasets
#datasets.logging.set_verbosity_info()
#datasets.logging.get_verbosity()


In [None]:
# Load the Japanese BERT tokenizer (MeCab word tokenizer), caching the download locally.
mecab_tokenizer = BertJapaneseTokenizer.from_pretrained(
    "cl-tohoku/bert-base-japanese",
    word_tokenizer_type="mecab",
    cache_dir=cache_models_path,
)

# Round-trip a sample Japanese sentence: encode it, then decode the ids of the first sequence.
line = "アンパサンド (&、英語名：) とは並立助詞「…と…」を意味する記号である。ラテン語の の合字で、Trebuchet MSフォントでは、と表示され \"et\" の合字であることが容易にわかる。"
mecab_inputs = mecab_tokenizer(line, return_tensors="pt")
sample_ids = mecab_inputs['input_ids'][0]
print(mecab_tokenizer.decode(sample_ids))

# Tokenizer length; used later to resize the model's token embeddings.
corpus_size = len(mecab_tokenizer)

In [None]:
# Inspection only: display the tokenizer's `max_model_input_sizes` attribute as the cell output.
mecab_tokenizer.max_model_input_sizes

# Model config

In [None]:
# Learning-rate presets per optimizer family.
optimizer_config = dict(
    SGD=dict(init_learning_rate=1e-1, pre_finetune_learning_rate=1e-3),
    Adam=dict(init_learning_rate=1e-3, pre_finetune_learning_rate=1e-5),
    RAdam=dict(init_learning_rate=1e-3, pre_finetune_learning_rate=2e-5),
)

# The run configuration copies its learning rates from the Adam preset.
_adam_rates = optimizer_config['Adam']

model_config = dict(
    epochs=5,
    initial_epochs=20,
    batch_size=32,
    max_tokens_length=512,
    threshold=0.5,
    optimizer_method="Adam",
    init_learning_rate=_adam_rates['init_learning_rate'],
    pre_finetune_learning_rate=_adam_rates['pre_finetune_learning_rate'],
    end_learning_rate=1e-5,
    lsm=0.0,
    hidden_dropout_prob=0.1,
    use_warmup=False,
    use_multi_gpus=True,
)



# Data loader pipeline

In [None]:
# set_start_method raises RuntimeError once a start method has been chosen, so
# re-running this cell would fail; force=True makes the call repeatable.
torch.multiprocessing.set_start_method('spawn', force=True)
#datasets.config.IN_MEMORY_MAX_SIZE
@dataclass(eq=False)
class GenerateDatasets:
    """Callable factory that builds a ``DataLoader`` over the parquet files of a directory.

    On construction the files matching ``regex_file_format`` inside ``data_path``
    are collected into ``get_files_list``. Calling the instance loads them with
    ``load_dataset('parquet', ...)``, exposes the encoding and target columns as
    torch tensors and wraps the dataset in a ``torch.utils.data.DataLoader``.
    """
    data_path: str = field(
        default=None, metadata={"help": "The prefix path of files location"}
    )
    regex_file_format: str = field(
        default='*.parquet', metadata={"help": "The files format."}
    )
    batch_size: int = field(
        default=128, metadata={"help": "Batch size"}
    )
    is_training: bool = field(
        default=True, metadata={"help": "Is use training mode to create data pipeline"}
    )
    device: str = field(
        default='cpu', metadata={"help": "Which device to use [cpu, cuda]"}
    )
    cache_data_path: str = field(
        default=None, metadata={"help": "The path to cache data."}
    )

    def __post_init__(self):
        """Collect the matching data files and declare the columns fed to the model."""
        pattern = os.path.join(str(self.data_path), self.regex_file_format)
        # glob returns matches in arbitrary order; sort them so the file order
        # (and therefore the dataset row order) is the same on every run.
        self.get_files_list = sorted(glob.glob(pattern))
        self.encoding_columns = ['input_ids', 'token_type_ids', 'attention_mask']
        self.target_columns = ['masked_lm_labels', 'next_sentence_labels']

    def __call__(self, **kwargs):
        """Load the collected files and return a shuffled ``DataLoader``.

        Incomplete final batches are dropped only when ``is_training`` is True.
        ``**kwargs`` is accepted for call compatibility and is not used.

        Raises:
            FileNotFoundError: if no file matched ``regex_file_format`` in ``data_path``.
        """
        if not self.get_files_list:
            raise FileNotFoundError(
                "No files matching {!r} found in {!r}.".format(self.regex_file_format, str(self.data_path)))

        dataset = load_dataset('parquet', data_files=self.get_files_list, cache_dir=self.cache_data_path, split='train')
        # The tensors are deliberately left on the CPU (no `device=` here): the
        # DataLoader below uses pin_memory=True, and pinning tensors that already
        # live on the GPU fails with
        # "RuntimeError: cannot pin 'torch.cuda.LongTensor' only dense CPU tensors can be pinned".
        dataset.set_format(type='torch', columns=self.encoding_columns + self.target_columns)

        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            pin_memory=True,
            shuffle=True,
            drop_last=bool(self.is_training),
            num_workers=multiprocessing.cpu_count())
        return dataloader
        

# DataLoader factory for the training split.
get_train_dataset = GenerateDatasets(
    data_path=training_data_path,
    batch_size=model_config['batch_size'],
    is_training=True,
    device=device,
    cache_data_path=cache_data_path)

# DataLoader factory for the validation split. It must read from
# `validation_data_path`; pointing it at the training directory would make the
# validation loss a second training loss.
get_valid_dataset = GenerateDatasets(
    data_path=validation_data_path,
    batch_size=model_config['batch_size'],
    is_training=False,
    device=device,
    cache_data_path=cache_data_path)

train_dataloader = get_train_dataset()
# The training cell iterates over `val_dataloader`, so it has to exist after a fresh "Run All".
val_dataloader = get_valid_dataset()



In [None]:
# Inspection only: show the files matched for the training dataset.
get_train_dataset.get_files_list
#get_valid_dataset.get_files_list

In [None]:
# Inspection only: pull one batch from the training DataLoader and display it.
next(iter(train_dataloader))

In [None]:
torch.cuda.empty_cache()

# Architecture comes from the local "tiny" ALBERT config file ...
albert_config_file = albert_zh_path / 'albert_config' / 'albert_config_tiny.json'
albert_config = AlbertConfig.from_json_file(albert_config_file)

# ... while the weights are loaded from this pretrained checkpoint name.
pretrained_model_name_or_path = 'voidful/albert_chinese_tiny'
albert_pretrain_model = AlbertForPreTraining.from_pretrained(
    pretrained_model_name_or_path,
    config=albert_config,
    cache_dir=cache_models_path,
)

# Resize the token embeddings to the length of the Japanese tokenizer.
albert_pretrain_model.resize_token_embeddings(corpus_size)


In [None]:
# Inspection only: display the loaded ALBERT configuration.
albert_config

In [None]:
# Optionally replicate the model over every visible GPU, then move it to `device`.
if model_config["use_multi_gpus"]:
    gpu_count = torch.cuda.device_count()
    print("Let's use", gpu_count, "GPUs!")
    device_ids = list(range(gpu_count))
    albert_pretrain_model = nn.DataParallel(albert_pretrain_model, device_ids=device_ids)
albert_pretrain_model.to(device)


In [None]:
# Same grouping as the optimizer parameter groups, but listing parameter NAMES
# so the split can be checked by eye.
model_params = list(albert_pretrain_model.named_parameters())

# Name fragments of parameters that should not be weight-decayed. PyTorch names
# LayerNorm parameters "...LayerNorm.weight" / "...layer_norm.weight" (the
# TensorFlow-style 'gamma' / 'beta' never occur), so they are listed explicitly.
no_decay_markers = ('bias', 'gamma', 'beta', 'LayerNorm.weight', 'layer_norm.weight')

# 'weight_decay' is the option name torch.optim reads; an unknown key such as
# 'weight_decay_rate' would be silently ignored.
optimizer_grounded_parameters_by_name = [
    {'params': [n for n, p in model_params if not any(nd in n for nd in no_decay_markers)],
     'weight_decay': 1e-2},
    {'params': [n for n, p in model_params if any(nd in n for nd in no_decay_markers)],
     'weight_decay': 0.0}
]

optimizer_grounded_parameters_by_name

In [None]:
model_params = list(albert_pretrain_model.named_parameters())

# Name fragments of parameters that should not be weight-decayed. PyTorch names
# LayerNorm parameters "...LayerNorm.weight" / "...layer_norm.weight" (the
# TensorFlow-style 'gamma' / 'beta' never occur), so they are listed explicitly.
no_decay_markers = ('bias', 'gamma', 'beta', 'LayerNorm.weight', 'layer_norm.weight')

# Two optimizer parameter groups: decayed weights and non-decayed biases / norms.
# 'weight_decay' is the per-group option torch.optim optimizers read; the
# previous 'weight_decay_rate' key was silently ignored.
optimizer_grounded_parameters = [
    {'params': [p for n, p in model_params if not any(nd in n for nd in no_decay_markers)],
     'weight_decay': 1e-2},
    {'params': [p for n, p in model_params if any(nd in n for nd in no_decay_markers)],
     'weight_decay': 0.0}
]

In [None]:
from torch import optim
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR
        
# Adam over all model parameters with a fixed learning rate and no weight decay.
optimizer = torch.optim.Adam(
    #params=optimizer_grounded_parameters, 
    params=albert_pretrain_model.parameters(),
    lr=1e-5,
    betas=(0.9, 0.98),
    weight_decay=0.0,
    eps=1e-6)

#optimizer = optim.SGD(sms_model.parameters(), lr=model_config['init_learning_rate'], weight_decay=1e-4)
#optimizer = optim.SGD(filter(lambda p: p.requires_grad, sms_model.parameters()), lr=model_config['init_learning_rate'], weight_decay=1e-4)

#scheduler = CyclicLR(
#    optimizer, 
#    base_lr=1e-5,
#    max_lr=model_config['init_learning_rate'],
#    step_size_up=model_config['training_steps'] * 1,
#    mode='triangular2',
#    scale_mode='cycle',
#    cycle_momentum=False
#)

if model_config["use_multi_gpus"]:
    # An optimizer is not an nn.Module and must not be wrapped in nn.DataParallel:
    # on a single GPU the wrapper calls `.to()` on it and fails, and
    # `optimizer.zero_grad()` on the wrapper is nn.Module.zero_grad(), which finds
    # no parameters and leaves the model's gradients untouched.
    # Later cells still reach the optimizer through `optimizer.module` in
    # multi-GPU mode, so expose the optimizer under that name to stay compatible.
    optimizer.module = optimizer
    
#swa_model = AveragedModel(model)
#scheduler = CosineAnnealingLR(optimiezer, eta_min=1e-5, T_max=10)
#swa_start = 5
#swa_scheduler = SWALR(optimizer, swa_lr=0.05)


In [None]:
# Define metrics
from torchmetrics import MetricCollection
def _make_metric_collection(prefix):
    """Build a MetricCollection of Accuracy, Precision, Recall and F1 on `device`.

    All four metrics share the same arguments; `prefix` is passed to
    MetricCollection.
    """
    shared_kwargs = dict(
        num_classes=2,
        average='macro',
        multiclass=True,
        dist_sync_on_step=True,
        mdmc_average='global',
    )
    metric_classes = (
        torchmetrics.Accuracy,
        torchmetrics.Precision,
        torchmetrics.Recall,
        torchmetrics.F1,
    )
    return MetricCollection(
        [metric_cls(**shared_kwargs).to(device) for metric_cls in metric_classes],
        prefix=prefix)


metric_collection = _make_metric_collection('Train_')

val_metric_collection = _make_metric_collection('Val_')


# Training model

### Init wandb

In [None]:
# Derive the schedule lengths from the size of the training DataLoader.
steps_per_epoch = len(train_dataloader)

if not model_config['use_warmup']:
    model_config['warmup_steps'] = None
    model_config['decay_steps'] = None
else:
    # Warm-up covers the first 10% of all planned steps; decay spans all of them.
    planned_steps = steps_per_epoch * model_config['epochs']
    model_config['warmup_steps'] = int(planned_steps * 0.1)
    model_config['decay_steps'] = planned_steps
model_config['training_steps'] = steps_per_epoch

# Start the Weights & Biases run, passing a copy of the model configuration.
wandb.init(
    project=project,
    group=group_tag,
    job_type=job_type,
    name=run_id,
    notes=method_tag,
    tags=addition_tag,
    sync_tensorboard=False,
    config=dict(model_config),
    reinit=True,
)

wandb_config = wandb.config


### Simple test

In [None]:
'''
prefix = 'train'
for epoch in tqdm(range(10)): # model_config['epochs']
    start_time = time.time()    
    train_batch_loss = 0
    val_batch_loss = 0    
    
    # Training Step
    albert_pretrain_model = albert_pretrain_model.train()
    for step, train_batch in tqdm(enumerate(train_dataloader), 
                                  dynamic_ncols=False, 
                                  bar_format="{n_fmt}/{total_fmt}{bar} ETA: {remaining}s - {desc}", 
                                  total=len(train_dataloader),
                                  leave=True, 
                                  unit='steps'):        
        input_ids = train_batch['input_ids'].to(device)
        attention_mask = train_batch['attention_mask'].to(device)
        token_type_ids = train_batch['token_type_ids'].to(device)

        mlm_labels = train_batch['masked_lm_labels'].to(device)
        sop_labels = train_batch['next_sentence_labels'].to(device)
        
        outputs = albert_pretrain_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=mlm_labels,
            sentence_order_label=sop_labels
        )
                    
        loss = outputs.loss.mean()
        perplexity  = torch.exp(loss)
        
        optimizer.zero_grad()
        loss.backward()
        if model_config["use_multi_gpus"]:
            optimizer.module.step()
        else:
            optimizer.step()
            
        wandb.log({
            "loss": loss,
            "perplexity": perplexity,            
        }, step=step)
        
'''

### Define training steps

In [None]:
def training_step(model, input_ids, attention_mask, token_type_ids, mlm_labels, sop_labels, use_multi_gpus=False):
    """Run one optimisation step on a batch and return its loss.

    The module-level ``optimizer`` is used for the update.

    Args:
        model: Called with ``input_ids``, ``attention_mask``, ``token_type_ids``,
            ``labels`` and ``sentence_order_label``; its output must expose ``.loss``.
        input_ids, attention_mask, token_type_ids: Model inputs for the batch.
        mlm_labels: Passed to the model as ``labels``.
        sop_labels: Passed to the model as ``sentence_order_label``.
        use_multi_gpus: Kept for call compatibility; a wrapped optimizer is now
            detected through its ``module`` attribute, whatever this flag says.

    Returns:
        The mean of ``outputs.loss``, detached from the autograd graph.
    """
    # Forward pass
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        labels=mlm_labels,
        sentence_order_label=sop_labels
    )
    # `.mean()` collapses a loss that holds more than one value into a scalar.
    loss = outputs.loss.mean()

    # When the optimizer has been wrapped, the real one sits in `.module`.
    # zero_grad() must be called on the real optimizer: on an nn.Module wrapper
    # it is nn.Module.zero_grad(), which sees no parameters and would let the
    # gradients accumulate from step to step.
    active_optimizer = getattr(optimizer, 'module', optimizer)

    # Backward pass: zero gradients, back-propagate, update the weights.
    active_optimizer.zero_grad()
    loss.backward()
    active_optimizer.step()

    # Return a detached loss so callers that accumulate it do not keep every
    # batch's autograd graph alive.
    return loss.detach()

@torch.no_grad()
def validataion_step(model, input_ids, attention_mask, token_type_ids, mlm_labels, sop_labels):
    """Score one batch without gradient tracking and return the mean loss.

    The model is switched to eval mode before the forward pass. ``mlm_labels``
    and ``sop_labels`` are passed to the model as ``labels`` and
    ``sentence_order_label``.
    """
    model.eval()
    forward_kwargs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'token_type_ids': token_type_ids,
        'labels': mlm_labels,
        'sentence_order_label': sop_labels,
    }
    return model(**forward_kwargs).loss.mean()

@torch.no_grad()
def testing_step(model, dataset_inputs):
    pass


### Training

In [None]:

print('[RUN ID]: {}'.format(run_id))

use_epoch_tracking = False
use_step_tracking = True
#wandb.watch(sms_model, log="all", log_freq=1000)


def show_logs(loss, step, is_epoch=False, prefix='Train', **kwargs):
    """Send a loss value (and optionally a perplexity) to Weights & Biases.

    Args:
        loss: Loss value; converted with ``float``.
        step: Value passed to ``wandb.log`` as ``step``.
        is_epoch: If True, log ``{prefix}_loss`` together with an ``epoch``
            entry; otherwise log ``{prefix}_step_loss``.
        prefix: Prefix of the metric names.
        **kwargs: ``perplexity`` is logged as ``{prefix}_perplexity`` when given;
            ``epoch`` is the value of the ``epoch`` entry when ``is_epoch`` is
            True (defaults to ``step``).
    """
    loss = float(loss)
    if is_epoch:
        wandb.log({"epoch": kwargs.get("epoch", step), f"{prefix}_loss": loss}, step=step)
    else:
        wandb.log({f"{prefix}_step_loss": loss}, step=step)
    if "perplexity" in kwargs:
        wandb.log({f"{prefix}_perplexity": kwargs["perplexity"]}, step=step)


def _progress(dataloader):
    """Wrap a dataloader in the progress bar shared by the training and validation loops."""
    return tqdm(dataloader,
                dynamic_ncols=False,
                bar_format="{n_fmt}/{total_fmt}{bar} ETA: {remaining}s - {desc}",
                total=len(dataloader),
                leave=True,
                unit='steps')


def _model_inputs(batch):
    """Move one batch to `device`, keyed by the argument names of the step functions."""
    return {
        'input_ids': batch['input_ids'].to(device),
        'attention_mask': batch['attention_mask'].to(device),
        'token_type_ids': batch['token_type_ids'].to(device),
        'mlm_labels': batch['masked_lm_labels'].to(device),
        'sop_labels': batch['next_sentence_labels'].to(device),
    }


# The loop validates after every epoch; build the validation loader here if no
# earlier cell created it.
if 'val_dataloader' not in globals():
    val_dataloader = get_valid_dataset()

# The optimizer may be wrapped (the real one then sits in `.module`); unwrap it
# once for the learning-rate lookups below.
active_optimizer = getattr(optimizer, 'module', optimizer)

total_batches = len(train_dataloader) * model_config['epochs']

# One counter shared by training and validation batches over all epochs, so the
# step given to wandb never decreases. The previous (step + 1) * (epoch + 1)
# repeated and skipped values and was also used as the averaging denominator.
global_step = 0

loss_template = ("Epoch {}/{} - {:.0f}s {:.0f}ms/step - lr:{:} - loss: {:.6f} - val_loss: {:.6f}")

for epoch in tqdm(range(model_config['epochs'])):
    start_time = time.time()
    train_batch_loss = 0.0
    valid_batch_loss = 0.0
    train_perplexity = 0.0
    valid_perplexity = 0.0
    last_lr = active_optimizer.param_groups[0]['lr']

    # Training Step
    albert_pretrain_model = albert_pretrain_model.train()
    train_batches = 0
    for train_batches, train_batch in enumerate(_progress(train_dataloader), start=1):
        train_loss = training_step(
            model=albert_pretrain_model,
            use_multi_gpus=model_config["use_multi_gpus"],
            **_model_inputs(train_batch))
        # Detach before accumulating: summing exp(loss) tensors that still carry
        # their graph would keep every batch's activations in memory.
        train_loss = train_loss.detach()
        train_batch_loss += train_loss.item()
        train_perplexity += torch.exp(train_loss).item()

        last_lr = active_optimizer.param_groups[0]['lr']
        global_step += 1

        if use_step_tracking:
            wandb.log({'learning_rate': last_lr}, step=global_step)
            # Running means over the batches seen so far in this epoch.
            show_logs(
                train_batch_loss / train_batches,
                global_step,
                perplexity=train_perplexity / train_batches)

    # Always computed: the epoch summary printed below needs it.
    train_epoch_loss = train_batch_loss / max(train_batches, 1)
    if use_epoch_tracking:
        wandb.log({'learning_rate': last_lr}, step=global_step)
        show_logs(train_epoch_loss, global_step, is_epoch=True, epoch=epoch)
        #train_metric_records = metric_collection.compute()
        #wandb.log(train_metric_records, step=epoch)

    # Validation Step
    albert_pretrain_model = albert_pretrain_model.eval()
    valid_batches = 0
    for valid_batches, valid_batch in enumerate(_progress(val_dataloader), start=1):
        valid_loss = validataion_step(
            model=albert_pretrain_model,
            **_model_inputs(valid_batch))
        valid_batch_loss += valid_loss.item()
        valid_perplexity += torch.exp(valid_loss).item()

        global_step += 1

        if use_step_tracking:
            wandb.log({'learning_rate': last_lr}, step=global_step)
            show_logs(
                valid_batch_loss / valid_batches,
                global_step,
                prefix='valid',
                perplexity=valid_perplexity / valid_batches)

    valid_epoch_loss = valid_batch_loss / max(valid_batches, 1)
    if use_epoch_tracking:
        show_logs(valid_epoch_loss, global_step, is_epoch=True, prefix='Val', epoch=epoch)
        #val_metric_records = val_metric_collection.compute()
        #wandb.log(val_metric_records, step=epoch)

    end_time = time.time()
    each_steps_compute_time = (end_time - start_time)
    print(loss_template.format(
        epoch + 1, model_config['epochs'], each_steps_compute_time,
        each_steps_compute_time * 1000 / model_config['training_steps'],
        last_lr, train_epoch_loss, valid_epoch_loss))

    if use_epoch_tracking:
        metric_collection.reset()
        val_metric_collection.reset()

wandb.finish()

In [None]:
# NOTE(review): this assertion always fails. Presumably it is a deliberate stop
# so that "Run All" halts at this point — confirm, or remove the cell.
assert 1 == 2