In [1]:
#%pip install --user lightning-bolts -q

In [2]:
import argparse
import time
import os
import gc
import math
import multiprocessing
import io
import logging 
import itertools
import shutil
import pysnooper
import warnings
import glob
import pendulum
import random
import json
import sys
import wandb
import matplotlib
import matplotlib.pyplot as plt
import scikitplot as skplt
import seaborn as sns
import numpy as np
import pandas as pd

from icecream import ic
from collections import Counter
from tqdm import tqdm_notebook as tqdm
from pathlib import Path
from IPython.core.interactiveshell import InteractiveShell

import torch
import torch.nn.functional as F
import transformers
import torchmetrics
from torch import nn
from torch import cuda
from typing import Dict, List, Tuple
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter
    
import pyarrow as pa
from datasets import load_dataset
from datasets import total_allocated_bytes
from dataclasses import dataclass, field
from typing import Dict, Optional, Union, List
from transformers import (
    BertTokenizer, AlbertForPreTraining, AlbertModel, AlbertConfig, BertJapaneseTokenizer, PreTrainedTokenizer, PreTrainedModel
)
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

from torch.cuda.amp import GradScaler, autocast
from torch import optim
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR
from transformers import get_polynomial_decay_schedule_with_warmup    
    
InteractiveShell.ast_node_interactivity = "all"

matplotlib.use('Agg')
warnings.filterwarnings("ignore")

seed = 9527
np.set_printoptions(suppress=True)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.WARNING)

device = 'cuda' if cuda.is_available() else 'cpu'
print('Torch version: ', torch.__version__)
print('Device available:', torch.cuda.is_available())
print('Device name:', torch.cuda.get_device_name(0))
torch.set_printoptions(precision=8)


Torch version:  1.8.0
Device available: True
Device name: Tesla P100-PCIE-16GB


In [3]:
logger = logging.getLogger(__name__)
logger



In [4]:
'''
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
torch.distributed.init_process_group(backend='nccl')

local_rank = -1
if local_rank not in [-1, 0]:
    torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
    
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
'''    

'\nfrom torch.utils.data.distributed import DistributedSampler\nfrom torch.nn.parallel import DistributedDataParallel as DDP\ntorch.distributed.init_process_group(backend=\'nccl\')\n\nlocal_rank = -1\nif local_rank not in [-1, 0]:\n    torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab\n    \ndef setup(rank, world_size):\n    os.environ[\'MASTER_ADDR\'] = \'localhost\'\n    os.environ[\'MASTER_PORT\'] = \'12355\'\n\n    # initialize the process group\n    dist.init_process_group("gloo", rank=rank, world_size=world_size)\n'

In [5]:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")

args = parser.parse_args(args=[
    '--server_ip', 'localhost',
    '--server_port', '12355'
])


_StoreAction(option_strings=['--local_rank'], dest='local_rank', nargs=None, const=None, default=-1, type=<class 'int'>, choices=None, help='For distributed training: local_rank', metavar=None)

_StoreAction(option_strings=['--server_ip'], dest='server_ip', nargs=None, const=None, default='', type=<class 'str'>, choices=None, help='For distant debugging.', metavar=None)

_StoreAction(option_strings=['--server_port'], dest='server_port', nargs=None, const=None, default='', type=<class 'str'>, choices=None, help='For distant debugging.', metavar=None)

In [6]:
seed = 202105

# main'
main_path = Path('/home/jupyter/gogolook')
main_cached_path = Path('/home/jupyter/gogolook/data')

# general setting
main_data_path = main_path / 'data' / 'jp_data' 
main_model_path = main_path / 'models'
tensorboard_path = main_path / 'tensorboard_log_dir'
cache_data_path = main_cached_path / 'cache_data_dir'
cache_models_path = main_cached_path / 'cache_models_dir'

# models
albert_zh_path = main_model_path / 'albert_zh'
pretrain_model_path = main_model_path / 'jp_pretrain_model'

# data
regex_file_format = '*.json'
data_tag = 'pretraining_data'
valid_data_tag = 'pretraining_data'
test_data_tag = 'pretraining_data'

experiment_total_data_path = main_data_path / f'total_{data_tag}'
experiment_train_data_path = main_data_path / f'train_{data_tag}'
experiment_valid_data_path = main_data_path / f'valid_{valid_data_tag}'
experiment_test_data_path = main_data_path / f'test_{test_data_tag}'

training_data_path = experiment_train_data_path
validation_data_path = experiment_valid_data_path
testing_data_path = experiment_test_data_path


for folder in [pretrain_model_path, tensorboard_path]:
    if not folder.exists():
        folder.mkdir()
        

In [7]:
project = 'jp-pretrain-model'
project_shortname = 'jp-sms'
group_tag = 'experiment' # 1. functional 2. experiment 3. staging 4. production
job_type = 'baseline' # 1. baseline 2. optimize 3. hyper-tuning
addition_tag = [data_tag, 'pytorch'] # exponential_decay
method_tag = 'pretrain' # pretrain / finetune / pretrain_finetune
time_tag = pendulum.now(tz='Asia/Taipei').strftime('%Y%m%d%H%M%S')
run_id = '{}_{}_{}'.format(project, job_type, time_tag)
print('Run id is {}'.format(run_id))


Run id is jp-pretrain-model_baseline_20210914151521


In [8]:
mecab_tokenizer = BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese", word_tokenizer_type="mecab", cache_dir=cache_models_path)
# Input Japanese Text
line = "アンパサンド (&、英語名：) とは並立助詞「…と…」を意味する記号である。ラテン語の の合字で、Trebuchet MSフォントでは、と表示され \"et\" の合字であることが容易にわかる。"
mecab_inputs = mecab_tokenizer(line, return_tensors="pt")
print(mecab_tokenizer.decode(mecab_inputs['input_ids'][0]))
corpus_size = len(mecab_tokenizer)


[CLS] アンパサンド (&、 英語 名 :) と は 並立 助詞 「... と...」 を 意味 する 記号 で ある 。 ラテン語 の の 合 字 で 、 Trebuchet MS フォント で は 、 と 表示 さ れ " et " の 合 字 で ある こと が 容易 に わかる 。 [SEP]


In [9]:
mecab_tokenizer.max_model_input_sizes

{'cl-tohoku/bert-base-japanese': 512,
 'cl-tohoku/bert-base-japanese-whole-word-masking': 512,
 'cl-tohoku/bert-base-japanese-char': 512,
 'cl-tohoku/bert-base-japanese-char-whole-word-masking': 512}

# Model config

In [10]:
optimizer_config = {
    "SGD": {
        "learning_rate": 1e-1, 
        "end_learning_rate": 1e-3
    },
    "Adam": {
        "learning_rate": 1e-3,
        "end_learning_rate": 1e-5,
        "weight_decay": 0.01,
        "epsilon": 1e-8
    }, 
    "RAdam": {
        "learning_rate": 1e-3,
        "end_learning_rate": 2e-5,
        "weight_decay": 0.01,
        "epsilon": 1e-8
    }     
}

model_config = {
    "epochs": 3,
    "num_train_epochs": 3,
    "per_gpu_train_batch_size": 16,
    "per_gpu_eval_batch_size": 32,
    "batch_size": 128,
    "max_tokens_length": 512,
    "threshold": 0.5,
    "optimizer_method": "Adam",
    "learning_rate": optimizer_config['Adam']['learning_rate'],
    "end_learning_rate": optimizer_config['Adam']['end_learning_rate'],
    "weight_decay": optimizer_config['Adam']['weight_decay'],
    "epsilon": optimizer_config['Adam']['epsilon'],
    "lsm": 0.0,
    "hidden_dropout_prob": 0.1,
    "max_grad_norm": 1,
    "use_warmup": True,
    "n_gpu": torch.cuda.device_count(),
    "gradient_accumulation_steps": 8,
    "output_dir": str(pretrain_model_path / run_id),
    "max_steps": 125000,
    "logging_steps": 100,
    "save_steps": 25000,
    "evaluate_during_training": False,
    "save_total_limit": 3,
    "seed": 9527,
}



# Data loader pipeline

In [11]:
#torch.multiprocessing.set_start_method('spawn')
#datasets.config.IN_MEMORY_MAX_SIZE
@dataclass(eq=False)
class GenerateDatasets: 
    #files_list: str = field(
    #    default=None, metadata={"help": "The files list of data path"}
    #)
    data_path: str = field(
        default=None, metadata={"help": "The prefix path of files location"}
    )
    regex_file_format: str = field(
        default='*.parquet', metadata={"help": "The files format."}
    )
    batch_size: int = field(
        default=128, metadata={"help": "Batch size"}
    )
    is_training: bool = field(
        default=True, metadata={"help": "Is use training mode to create data pipeline"}
    )
    device: str = field(
        default='cpu', metadata={"help": "Which device to use [cpu, cuda]"}
    )
    cache_data_path: str = field(
        default=None, metadata={"help": "The path to cache data."}
    )
    use_streaming_mode: bool = field(
        default=False, metadata={"help": "Use streaming mode to download data."}
    )
        
    def __post_init__(self):
        self.get_files_list = glob.glob(os.path.join(str(self.data_path), self.regex_file_format))
        #self.get_files_list = '/home/jupyter/gogolook/data/jp_data/valid_pretraining_data/valid_all-maxseq512_BG.parquet'
        self.encoding_columns = ['input_ids', 'token_type_ids', 'attention_mask']
        self.target_columns = ['masked_lm_labels', 'next_sentence_labels']
        
    def __call__(self, **kwargs):
        # data 已經存在 device (cuda) 裡，所以再用 pin_memory 會出現 error
        # RuntimeError: cannot pin 'torch.cuda.LongTensor' only dense CPU tensors can be pinned        
        dataset = load_dataset('parquet', data_files=self.get_files_list, cache_dir=self.cache_data_path, split='train')
        dataset.set_format(type='torch', columns=self.encoding_columns + self.target_columns) # , device=self.device
        #dataset = dataset.rename_column(self.target_column, 'labels')
        if self.is_training:
            drop_last = True
        else: 
            drop_last = False
            
        #dataloader = torch.utils.data.DataLoader(
        #    dataset,
        #    batch_size=self.batch_size,
        #    pin_memory=True,
        #    shuffle=True,
        #    drop_last=drop_last,
        #    num_workers=multiprocessing.cpu_count())
        return dataset # dataloader
        

get_train_dataset = GenerateDatasets(
    data_path=training_data_path,
    batch_size=model_config['batch_size'],
    is_training=True,
    device=device,
    cache_data_path=cache_data_path)

get_valid_dataset = GenerateDatasets(
    data_path=training_data_path,
    batch_size=model_config['batch_size'],
    is_training=False,
    device=device,
    cache_data_path=cache_data_path)

train_dataset = get_train_dataset()
valid_dataset = get_valid_dataset()

    



### Testing load from gcs

In [12]:
#import gcsfs
#from datasets import load_from_disk
#gcs = gcsfs.GCSFileSystem(project='data-research-216307') 
#gcs_files_list = gcs.glob('gs://gogolook-ml-data-production/serve-dev/sms/data/experimental_jp_data/train_pretraining_data/*.parquet')
#gcs_files_list = [ "gs://" + path for path in gcs_files_list]
#dataset = load_from_disk(dataset_path="gs://gogolook-ml-data-production/serve-dev/sms/data/experimental_jp_data/train_pretraining_data/", fs=gcs)

# saves encoded_dataset to your s3 bucket
#train_dataset.save_to_disk('gcs://gogolook-ml-data-production/serve-dev/sms/data/experimental_jp_data/preprocessing_dataset', fs=gcs)
#train_dataset.save_to_disk('/home/jupyter/gogolook/data/jp_data/preprocessing_dataset/')

### Streaming test

In [13]:
'''
get_files_list = glob.glob(os.path.join(str(experiment_train_data_path), "*.parquet"))

dataset = load_dataset('parquet', data_files=get_files_list[0], cache_dir=cache_data_path, split='train', streaming=True)


map_dataset = dataset.map(lambda example: (example["input_ids"], example["token_type_ids"], example["attention_mask"]), batched=True, batch_size=64)

shuffled_dataset = map_dataset.shuffle(buffer_size=100, seed=seed)


torch_dataset  = shuffled_dataset.with_format("torch")
assert isinstance(torch_dataset, torch.utils.data.IterableDataset)
#sampler = torch.utils.data.Sampler(torch_dataset)
#batch_sampler = torch.utils.data.BatchSampler(sampler, 64, False)

def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    worker_id = worker_info.id
    split_size = len(dataset.data) // worker_info.num_workers
    dataset.data = dataset.data[worker_id * split_size:(worker_id + 1) * split_size] 
    
def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)

dataloader = torch.utils.data.DataLoader(
        torch_dataset,
        batch_size=128,
        pin_memory=True,
        drop_last=False,
        num_workers=multiprocessing.cpu_count())
def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    worker_id = worker_info.id
    split_size = 64 // worker_info.num_workers
    dataset.data = dataset.data[worker_id * split_size:(worker_id + 1) * split_size]     
dataloader = torch.utils.data.DataLoader(torch_dataset, batch_size=64, worker_init_fn=worker_init_fn, num_workers=multiprocessing.cpu_count())
''' 

'\nget_files_list = glob.glob(os.path.join(str(experiment_train_data_path), "*.parquet"))\n\ndataset = load_dataset(\'parquet\', data_files=get_files_list[0], cache_dir=cache_data_path, split=\'train\', streaming=True)\n\n\nmap_dataset = dataset.map(lambda example: (example["input_ids"], example["token_type_ids"], example["attention_mask"]), batched=True, batch_size=64)\n\nshuffled_dataset = map_dataset.shuffle(buffer_size=100, seed=seed)\n\n\ntorch_dataset  = shuffled_dataset.with_format("torch")\nassert isinstance(torch_dataset, torch.utils.data.IterableDataset)\n#sampler = torch.utils.data.Sampler(torch_dataset)\n#batch_sampler = torch.utils.data.BatchSampler(sampler, 64, False)\n\ndef worker_init_fn(_):\n    worker_info = torch.utils.data.get_worker_info()\n    dataset = worker_info.dataset\n    worker_id = worker_info.id\n    split_size = len(dataset.data) // worker_info.num_workers\n    dataset.data = dataset.data[worker_id * split_size:(worker_id + 1) * split_size] \n    \ndef w

In [14]:
#get_train_dataset.get_files_list
#get_valid_dataset.get_files_list

In [15]:
#model_config['training_steps'] = len(train_dataloader) * model_config['epochs']
#if model_config['use_warmup']:
#    model_config['warmup_steps'] = int(len(train_dataloader) * model_config['epochs'] * 0.1)
#    model_config['decay_steps'] = len(train_dataloader) * model_config['epochs']
#else:
#    model_config['warmup_steps'] = None 
#    model_config['decay_steps'] = None

torch.cuda.empty_cache()
albert_config = AlbertConfig.from_json_file(albert_zh_path / 'albert_config' / 'albert_config_tiny.json')
pretrained_model_name_or_path = 'voidful/albert_chinese_tiny'
albert_pretrain_model = AlbertForPreTraining.from_pretrained(
    pretrained_model_name_or_path, 
    config=albert_config,             
    cache_dir=cache_models_path)
albert_pretrain_model.resize_token_embeddings(corpus_size)


Some weights of AlbertForPreTraining were not initialized from the model checkpoint at voidful/albert_chinese_tiny and are newly initialized: ['sop_classifier.classifier.bias', 'sop_classifier.classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Embedding(32000, 128)

In [16]:
albert_config

AlbertConfig {
  "_name_or_path": "voidful/albert_chinese_tiny",
  "attention_probs_dropout_prob": 0.0,
  "bos_token_id": 2,
  "classifier_dropout_prob": 0.1,
  "directionality": "bidi",
  "embedding_size": 128,
  "eos_token_id": 3,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 312,
  "initializer_range": 0.02,
  "inner_group_num": 1,
  "intermediate_size": 1248,
  "layer_norm_eps": 1e-12,
  "ln_type": "postln",
  "max_position_embeddings": 512,
  "model_type": "albert",
  "num_attention_heads": 12,
  "num_hidden_groups": 1,
  "num_hidden_layers": 4,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.6.0",
  "type_vocab_size": 2,
  "vocab_size": 32000
}

In [17]:
if model_config["n_gpu"] > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    #device_ids = [idx for idx in range(torch.cuda.device_count())]
    #albert_pretrain_model = nn.DataParallel(albert_pretrain_model, device_ids=device_ids)


Let's use 4 GPUs!


# Define Optimizer

In [18]:
model_params = list(albert_pretrain_model.named_parameters())
no_decay = ["bias", "gamma", "beta", "LayerNorm.weight"]

optimizer_grounded_parameters_by_name = [
    {'params': [n for n, p in model_params if not any(nd in n for nd in no_decay)], 
     'weight_decay_rate': 0.0 },
    {'params': [n for n, p in model_params if any(nd in n for nd in no_decay)], 
     'weight_decay_rate': 0.0 }
]

optimizer_grounded_parameters_by_name

[{'params': ['albert.embeddings.word_embeddings.weight',
   'albert.embeddings.position_embeddings.weight',
   'albert.embeddings.token_type_embeddings.weight',
   'albert.encoder.embedding_hidden_mapping_in.weight',
   'albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.weight',
   'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.query.weight',
   'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.key.weight',
   'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.value.weight',
   'albert.encoder.albert_layer_groups.0.albert_layers.0.attention.dense.weight',
   'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.weight',
   'albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight',
   'albert.pooler.weight',
   'predictions.dense.weight',
   'sop_classifier.classifier.weight'],
  'weight_decay_rate': 0.0},
 {'params': ['albert.embeddings.LayerNorm.weight',
   'albert.embeddings.LayerNorm.bias',
   'a

In [19]:
from torch.optim.lr_scheduler import _LRScheduler

class PolynomialDecay(_LRScheduler):
    def __init__(self, optimizer, decay_steps, end_learning_rate=0.0001, power=0.5, cycle=False, last_epoch=-1, verbose=False):
        if decay_steps <= 1.:
            raise ValueError('max_decay_steps should be greater than 1.')            
        self.decay_steps = decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.cycle = cycle
        super(PolynomialDecay, self).__init__(optimizer, last_epoch, verbose)
    
    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        #dtype = initial_learning_rate.dtype
        #end_learning_rate = math_ops.cast(self.end_learning_rate, dtype)
        #power = math_ops.cast(self.power, dtype)
        #global_step_recomp = math_ops.cast(step, dtype)
        #decay_steps_recomp = math_ops.cast(self.decay_steps, dtype)
        global_step_recomp = self.last_epoch
        decay_steps_recomp = self.decay_steps
        
        if self.cycle:
            if global_step_recomp == 0:
                multiplier = 1.0 
            else:
                multiplier = math.ceil(global_step_recomp / self.decay_steps)
            decay_steps_recomp = decay_steps_recomp * multiplier
        else:
            global_step_recomp = min(global_step_recomp, decay_steps_recomp)
            
        p = global_step_recomp / decay_steps_recomp
        ic(self.last_epoch, optimizer.param_groups[0]['lr'], p)
        return [((group['lr'] - self.end_learning_rate) * math.pow(1 - p, self.power) + self.end_learning_rate) for group in self.optimizer.param_groups]
    
    def _get_closed_form_lr(self):
        return [(base_lr - self.end_learning_rate) * math.pow(1 - p, self.power) + self.end_learning_rate for base_lr in self.base_lrs]


In [20]:
from transformers import (
    AdamW,
    get_linear_schedule_with_warmup,
)


def get_optimizer(config: dict, model: PreTrainedModel, num_training_steps: int):
    model_params = list(model.named_parameters())
    no_decay = ["bias", "gamma", "beta", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model_params if not any(nd in n for nd in no_decay)], 
         'weight_decay_rate': 1e-2  },
        {'params': [p for n, p in model_params if any(nd in n for nd in no_decay)], 
         'weight_decay_rate': 0.0 }
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=config["learning_rate"], eps=config["epsilon"])
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(num_training_steps * 0.1), num_training_steps=num_training_steps
    )
    return optimizer, scheduler

    
#optimizer = torch.optim.Adam(
#    params=optimizer_grounded_parameters,
#    lr=model_config["learning_rate"],
#    betas=(0.9, 0.98),
#    weight_decay=config["weight_decay"],
#    eps=config["adam_epsilon"])

#scheduler = LinearWarmupCosineAnnealingLR(
#    optimizer, 
#    warmup_epochs=model_config['warmup_steps'], 
#    max_epochs=model_config['training_steps'],
#    eta_min=model_config["end_learning_rate"])

#optimizer = optim.SGD(sms_model.parameters(), lr=model_config['init_learning_rate'], weight_decay=1e-4)
#optimizer = optim.SGD(filter(lambda p: p.requires_grad, sms_model.parameters()), lr=model_config['init_learning_rate'], weight_decay=1e-4)

#scheduler = CyclicLR(
#    optimizer, 
#    base_lr=1e-5,
#    max_lr=model_config['init_learning_rate'],
#    step_size_up=model_config['training_steps'] * 1,
#    mode='triangular2',
#    scale_mode='cycle',
#    cycle_momentum=False
#)


#if model_config["use_multi_gpus"]:
#optimizer = nn.DataParallel(optimizer, device_ids=device_ids)




In [21]:
# Define Mertice
from torchmetrics import MetricCollection
metric_collection = MetricCollection([
    torchmetrics.Accuracy(num_classes=2, average='macro', multiclass=True, dist_sync_on_step=True, mdmc_average='global').to(device),
    torchmetrics.Precision(num_classes=2, average='macro', multiclass=True, dist_sync_on_step=True, mdmc_average='global').to(device),
    torchmetrics.Recall(num_classes=2, average='macro', multiclass=True, dist_sync_on_step=True, mdmc_average='global').to(device),
    torchmetrics.F1(num_classes=2, average='macro', multiclass=True, dist_sync_on_step=True, mdmc_average='global').to(device)
], prefix='Train_')

val_metric_collection = MetricCollection([
    torchmetrics.Accuracy(num_classes=2, average='macro', multiclass=True, dist_sync_on_step=True, mdmc_average='global').to(device),
    torchmetrics.Precision(num_classes=2, average='macro', multiclass=True, dist_sync_on_step=True, mdmc_average='global').to(device),
    torchmetrics.Recall(num_classes=2, average='macro', multiclass=True, dist_sync_on_step=True, mdmc_average='global').to(device),
    torchmetrics.F1(num_classes=2, average='macro', multiclass=True, dist_sync_on_step=True, mdmc_average='global').to(device),
], prefix='Val_')


# Training model

In [22]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


In [23]:
def _sorted_checkpoints(config, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
    if not os.path.isdir(config["output_dir"]):
        os.makedirs(config["output_dir"], exist_ok=True)
        
    ordering_and_checkpoint_path = []
    glob_checkpoints = glob.glob(os.path.join(config["output_dir"], "{}-*".format(checkpoint_prefix)))

    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    return checkpoints_sorted


def _rotate_checkpoints(config, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    if not config["save_total_limit"]:
        return
    if config["save_total_limit"] <= 0:
        return

    # Check if we should delete older checkpoint(s)
    checkpoints_sorted = _sorted_checkpoints(config, checkpoint_prefix, use_mtime)
    if len(checkpoints_sorted) <= config["save_total_limit"]:
        return

    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - config["save_total_limit"])
    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
    for checkpoint in checkpoints_to_be_deleted:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
        shutil.rmtree(checkpoint)
        

### Init wandb

In [24]:
#wandb.tensorboard.patch(root_logdir=str(tensorboard_path / run_id))
wandb.init(
    project=project,
    group=group_tag,
    job_type=job_type,
    name=run_id,
    notes=method_tag,
    tags=addition_tag,
    sync_tensorboard=True,
    config={**model_config},
    reinit=True
)

wandb_config = wandb.config


[34m[1mwandb[0m: Currently logged in as: [33myuyuliao20[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


### Define training stepsm

In [25]:
from tqdm import tqdm, trange, tqdm_notebook

def training_step(
    config: dict,
    train_dataset: torch.utils.data.Dataset, 
    eval_dataset: torch.utils.data.Dataset, 
    model: PreTrainedModel,  
    device: str,
    init_wandb: object
):
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()
    train_batch_size = config["per_gpu_train_batch_size"] * max(1, config["n_gpu"])
    ic(train_batch_size)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, 
        sampler=train_sampler, 
        batch_size=train_batch_size, 
        pin_memory=True, 
        drop_last=True, 
        num_workers=multiprocessing.cpu_count())
    
    if config["max_steps"] > 0:
        t_total = config["max_steps"]
        config["num_train_epochs"] = config["max_steps"] // (len(train_dataloader) // config["gradient_accumulation_steps"]) + 1
    else:
        t_total = len(train_dataloader) // config["gradient_accumulation_steps"] * config["num_train_epochs"]
        
    optimizer, scheduler = get_optimizer(config, model, t_total)
    
    if config["n_gpu"]:
        model = torch.nn.DataParallel(model)
    
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )
    # Train !
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", config["num_train_epochs"])
    logger.info("  Instantaneous batch size per GPU = %d", config["per_gpu_train_batch_size"])
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size * config["gradient_accumulation_steps"]* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", config["gradient_accumulation_steps"])
    logger.info("  Total optimization steps = %d", t_total)
    
    set_seed(config["seed"])
    global_step = 0
    epochs_trained = 0 
    train_loss, logging_loss = 0.0, 0.0
    #train_iterator = trange(
    #    epochs_trained, int(config["num_train_epochs"]), desc="Epoch", disable=args.local_rank not in [-1, 0]
    #)
    
    train_iterator = tqdm_notebook(range(
        epochs_trained, int(config["num_train_epochs"])), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    scaler = GradScaler()
    model.train()
    for epoch in train_iterator:
        #epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        epoch_iterator = tqdm_notebook(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        
        if args.local_rank != -1:
            train_sampler.set_epoch(epoch)
            
        for step, batch in enumerate(epoch_iterator):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)        
            mlm_labels = batch['masked_lm_labels'].to(device)
            sop_labels = batch['next_sentence_labels'].to(device)
    
            with autocast():
                # Forward pass
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    labels=mlm_labels,
                    sentence_order_label=sop_labels
                )
                assert outputs.prediction_logits.dtype is torch.float16
                if config["n_gpu"] > 1:
                    loss = outputs.loss.mean()  # mean() to average on multi-gpu parallel training
                if config["gradient_accumulation_steps"] > 1:
                    loss = loss / config["gradient_accumulation_steps"]
                assert loss.dtype is torch.float32
            scaler.scale(loss).backward()
            train_loss += loss.item()
            
            if (step + 1) % config["gradient_accumulation_steps"] == 0:
                #torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])  
                # Backward pass
                # Zero gradients, perform a backward pass, and update the weights.            
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                
                if (args.local_rank in [-1, 0]) and (config["logging_steps"] > 0) and (global_step % config["logging_steps"] == 0):
                    ic(global_step % config["logging_steps"])
                    # Log metrics
                    if (args.local_rank == -1 and config["evaluate_during_training"]):  
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate_step(
                            config=config, 
                            model=model, 
                            dataset=eval_dataset, 
                            device=device,
                            init_wandb=init_wandb
                        )
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)

                    #last_lr = optimizer.param_groups[0]['lr']
                    #last_lr = scheduler.optimizer.param_groups[0]["lr"]
                    print("=== Sent event to wandb===")
                    init_wandb.log({'lr': scheduler.get_lr()[0]}, step=global_step)
                    show_logs(
                        _wandb=init_wandb,
                        loss=(train_loss - logging_loss) / config["logging_steps"],
                        step=global_step
                    )
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (train_loss - logging_loss) / config["logging_steps"], global_step)
                    logging_loss = train_loss

                if (args.local_rank in [-1, 0]) and (config["save_steps"] > 0) and (global_step % config["save_steps"] == 0):
                        checkpoint_prefix = "checkpoint"
                        # Save model checkpoint
                        output_dir = os.path.join(config["output_dir"], "{}-{}".format(checkpoint_prefix, global_step))
                        os.makedirs(output_dir, exist_ok=True)
                        model_to_save = (
                            model.module if hasattr(model, "module") else model
                        )  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        logger.info("Saving model checkpoint to %s", output_dir)
                        _rotate_checkpoints(args, checkpoint_prefix)
                    
            if config["max_steps"]> 0 and global_step > config["max_steps"]:
                epoch_iterator.close()
                break
                
        if config["max_steps"] > 0 and global_step > config["max_steps"]:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()
        
    return global_step, tr_loss / global_step

def evaluate_step(
    config: dict, 
    model: PreTrainedModel, 
    dataset: torch.utils.data.Dataset, 
    device: str, 
    init_wandb: object,
    prefix: Optional[str]="") -> dict:
    
    eval_output_dir = config["output_dir"]
    if args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)
    
    eval_batch_size = config["per_gpu_eval_batch_size"] * max(1, config["n_gpu"])
    eval_sampler = SequentialSampler(dataset)
        
    eval_dataloader = DataLoader(
        dataset,
        sampler=eval_sampler,
        batch_size=eval_batch_size, 
        pin_memory=True, 
        drop_last=False,
        num_workers=multiprocessing.cpu_count())
    
    if config["n_gpu"] > 1:
        model = torch.nn.DataParallel(model)
        
    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()
    
    for batch in tqdm_notebook(eval_dataloader, desc="Evaluating"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        mlm_labels = batch['masked_lm_labels'].to(device)
        sop_labels = batch['next_sentence_labels'].to(device)

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=mlm_labels,
                sentence_order_label=sop_labels
            )
            loss = outputs.loss.mean()
        eval_loss += loss.item()
        nb_eval_steps += 1
        
    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))
    result = {
        "loss": eval_loss,
        "perplexity": perplexity
    }
    show_logs(init_wandb, eval_loss, nb_eval_steps, prefix="Eval", perplexity=perplexity.item())
    return result




In [26]:
'''
def training_step(model, input_ids, attention_mask, token_type_ids, mlm_labels, sop_labels, scaler, use_multi_gpus=False):
    with autocast():
        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=mlm_labels,
            sentence_order_label=sop_labels
        )
        assert outputs.prediction_logits.dtype is torch.float16
        
        loss = outputs.loss.mean()
        assert loss.dtype is torch.float32
    # Backward pass
    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    #if use_multi_gpus:
    #    scaler.step(optimizer.module)
    #else:
    scaler.step(optimizer)
    scaler.update()
    #torch.nn.utils.clip_grad_norm_(optimizer_grounded_parameters, max_norm=0.5)
    #if epoch > swa_start:
    #    swa_model.update_parameters(model)
    #    swa_scheduler.step()
    #else:
    #    scheduler.step()
    return loss

@torch.no_grad()
def validataion_step(model, input_ids, attention_mask, token_type_ids, mlm_labels, sop_labels):    
    with autocast():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=mlm_labels,
            sentence_order_label=sop_labels
        )
        loss = outputs.loss.mean()
    return loss

'''

'\ndef training_step(model, input_ids, attention_mask, token_type_ids, mlm_labels, sop_labels, scaler, use_multi_gpus=False):\n    with autocast():\n        # Forward pass\n        outputs = model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            labels=mlm_labels,\n            sentence_order_label=sop_labels\n        )\n        assert outputs.prediction_logits.dtype is torch.float16\n        \n        loss = outputs.loss.mean()\n        assert loss.dtype is torch.float32\n    # Backward pass\n    # Zero gradients, perform a backward pass, and update the weights.\n    optimizer.zero_grad()\n    scaler.scale(loss).backward()\n    #if use_multi_gpus:\n    #    scaler.step(optimizer.module)\n    #else:\n    scaler.step(optimizer)\n    scaler.update()\n    #torch.nn.utils.clip_grad_norm_(optimizer_grounded_parameters, max_norm=0.5)\n    #if epoch > swa_start:\n    #    swa_model.update_parameters(model)\n

### Training

In [None]:
print('[RUN ID]: {}'.format(run_id))
torch.cuda.empty_cache()
use_epoch_tracking = False
use_step_tracking = True
#wandb.watch(sms_model, log="all", log_freq=1000)        
def show_logs(_wandb, loss, step, is_epoch=False, prefix='Train', **kwargs):
    loss = float(loss)
    if is_epoch:
        _wandb.log({"epoch": step, f"{prefix}_loss": loss}, step=step)
    else:
        _wandb.log({f"{prefix}_step_loss": loss}, step=step)
        #print(f"{prefix} loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}")
    if "perplexity" in kwargs.keys():
        _wandb.log({f"{prefix}_perplexity": kwargs["perplexity"]}, step=step)

def save_model(model, save_model_path):
    logging.info("[INFO] Start to save model ...")
    #if not save_model_path.parent.exists():
    #    save_model_path.parent.mkdir()       
    torch.save(model.state_dict(), save_model_path)
    
start_time = time.time()
albert_pretrain_model.to(device)

if args.local_rank not in [-1, 0]:
    torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training download model & vocab
logger.info("Training/evaluation parameters %s", args)

global_step, tr_loss = training_step(
    config=model_config, 
    train_dataset=train_dataset, 
    eval_dataset=valid_dataset, 
    model=albert_pretrain_model, 
    device=device,
    init_wandb=wandb)

ic(" global_step = %s, average loss = %s", global_step, tr_loss)

end_time = time.time()
each_steps_compute_time = (end_time - start_time)
print(each_steps_compute_time)
    

[RUN ID]: jp-pretrain-model_baseline_20210914151521


AlbertForPreTraining(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(32000, 128)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=312, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=312, out_features=312, bias=True)
                (key): Linear(in_features=312, out_features=312, bias=True)
                (value): Linear(in_features=312, out_features=312, bias=True)
            

ic| train_batch_size: 64


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=8.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=131487.0, style=ProgressStyle(description…

ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


ic| global_step % config["logging_steps"]: 0


=== Sent event to wandb===


In [None]:
'''
print('[RUN ID]: {}'.format(run_id))
torch.cuda.empty_cache()
use_epoch_tracking = False
use_step_tracking = True
#wandb.watch(sms_model, log="all", log_freq=1000)        
def show_logs(loss, step, is_epoch=False, prefix='Train', **kwargs):
    loss = float(loss)
    if is_epoch:
        wandb.log({"epoch": step, f"{prefix}_loss": loss}, step=step)
    else:
        wandb.log({f"{prefix}_step_loss": loss}, step=step)
        #print(f"{prefix} loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}")
    if "perplexity" in kwargs.keys():
        wandb.log({f"{prefix}_perplexity": kwargs["perplexity"]}, step=step)

def save_model(model, save_model_path):
    logging.info("[INFO] Start to save model ...")
    #if not save_model_path.parent.exists():
    #    save_model_path.parent.mkdir()       
    torch.save(model.state_dict(), save_model_path)
    
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in tqdm(range(model_config['epochs'])): # model_config['epochs']
    start_time = time.time()    
    train_batch_loss = 0
    valid_batch_loss = 0  
    
    train_perplexity = 0
    valid_perplexity = 0
    
    # Training Step
    albert_pretrain_model = albert_pretrain_model.train()
    for step, train_batch in tqdm(enumerate(train_dataloader), 
                                  dynamic_ncols=False, 
                                  bar_format="{n_fmt}/{total_fmt}{bar} ETA: {remaining}s - {desc}", 
                                  total=len(train_dataloader),
                                  leave=True, 
                                  unit='steps'):        
        input_ids = train_batch['input_ids'].to(device)
        attention_mask = train_batch['attention_mask'].to(device)
        token_type_ids = train_batch['token_type_ids'].to(device)        
        mlm_labels = train_batch['masked_lm_labels'].to(device)
        sop_labels = train_batch['next_sentence_labels'].to(device)
        
        train_loss = training_step(
            model=albert_pretrain_model, 
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            token_type_ids=token_type_ids, 
            mlm_labels=mlm_labels, 
            sop_labels=sop_labels,
            scaler=scaler,
            use_multi_gpus=model_config["use_multi_gpus"]
        )
        scheduler.step()
        train_batch_loss += train_loss.item()
        train_perplexity += torch.exp(train_loss)
        
        #if model_config["use_multi_gpus"]:
        #    last_lr = optimizer.module.param_groups[0]['lr']
        #else:
        #last_lr = optimizer.param_groups[0]['lr']
        last_lr = scheduler.optimizer.param_groups[0]["lr"]
          
        if use_step_tracking:
            record_step = (step + 1) + (len(train_dataloader)) * epoch
            wandb.log({'learning_rate': last_lr}, step=record_step)
            show_logs(
                train_batch_loss / record_step, 
                record_step, 
                perplexity=train_perplexity.item() / record_step)
    
    save_model_checkpoint_path = str(save_model_path / f'{wandb.run.name}_{epoch}_model_weight.pt')
    save_model(albert_pretrain_model, save_model_checkpoint_path)
    
    if use_epoch_tracking:
        train_epoch_loss = train_batch_loss / step
        wandb.log({'learning_rate': last_lr}, step=epoch)
        show_log(train_epoch_loss, epoch, is_epoch=True)

    end_time = time.time()
    each_steps_compute_time = (end_time - start_time)
    print(each_steps_compute_time)
'''    

In [None]:
wandb.finish()


# Save model

In [None]:
save_models_path = main_model_path / wandb.run.name
if not save_models_path.exists():
    save_models_path.mkdir()
    

In [None]:
torch.save({
            'epoch': epoch,
            'model_state_dict': albert_pretrain_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
}, str(save_models_path / 'jp_pretrain_model.pt'))

In [None]:
torch.save(albert_pretrain_model.state_dict(), str(save_models_path / 'jp_pretrain_model_weight.pt'))


In [None]:
checkpoint = torch.load(str(save_models_path / 'jp_pretrain_model_weight.pt'))
albert_pretrain_model.load_state_dict(checkpoint)


In [None]:
albert_pretrain_model.module.state_dict().keys()

In [None]:
checkpoint.keys()

In [None]:
assert 1 == 2

# Test

In [None]:
print('[RUN ID]: {}'.format(run_id))
torch.cuda.empty_cache()
use_epoch_tracking = False
use_step_tracking = True
#wandb.watch(sms_model, log="all", log_freq=1000)        
def show_logs(loss, step, is_epoch=False, prefix='Train', **kwargs):
    loss = float(loss)
    if is_epoch:
        wandb.log({"epoch": step, f"{prefix}_loss": loss}, step=step)
    else:
        wandb.log({f"{prefix}_step_loss": loss}, step=step)
        #print(f"{prefix} loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}")
    if "perplexity" in kwargs.keys():
        wandb.log({f"{prefix}_perplexity": kwargs["perplexity"]}, step=step)

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in tqdm(range(model_config['epochs'])): # model_config['epochs']
    start_time = time.time()    
    train_batch_loss = 0
    valid_batch_loss = 0  
    
    train_perplexity = 0
    valid_perplexity = 0
    
    # Training Step
    albert_pretrain_model = albert_pretrain_model.train()
    for step, train_batch in tqdm(enumerate(train_dataloader), 
                                  dynamic_ncols=False, 
                                  bar_format="{n_fmt}/{total_fmt}{bar} ETA: {remaining}s - {desc}", 
                                  total=len(train_dataloader),
                                  leave=True, 
                                  unit='steps'):        
        input_ids = train_batch['input_ids'].to(device)
        attention_mask = train_batch['attention_mask'].to(device)
        token_type_ids = train_batch['token_type_ids'].to(device)        
        mlm_labels = train_batch['masked_lm_labels'].to(device)
        sop_labels = train_batch['next_sentence_labels'].to(device)
        
        train_loss = training_step(
            model=albert_pretrain_model, 
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            token_type_ids=token_type_ids, 
            mlm_labels=mlm_labels, 
            sop_labels=sop_labels,
            scaler=scaler,
            use_multi_gpus=model_config["use_multi_gpus"]
        )
        scheduler.step()
        train_batch_loss += train_loss.item()
        train_perplexity += torch.exp(train_loss)
        
        #if model_config["use_multi_gpus"]:
        #    last_lr = optimizer.module.param_groups[0]['lr']
        #else:
        #last_lr = optimizer.param_groups[0]['lr']
        last_lr = scheduler.optimizer.param_groups[0]["lr"]
          
        if use_step_tracking:
            record_step = (step + 1) * (epoch + 1)
            wandb.log({'learning_rate': last_lr}, step=record_step)
            show_logs(
                train_batch_loss / record_step, 
                record_step, 
                perplexity=train_perplexity.item() / record_step)
    
    if use_epoch_tracking:
        train_epoch_loss = train_batch_loss / step
        wandb.log({'learning_rate': last_lr}, step=epoch)
        show_log(train_epoch_loss, epoch, is_epoch=True)  
        #train_metric_records = metric_collection.compute()
        #wandb.log(train_metric_records, step=epoch)
    
    # Validation Step
    albert_pretrain_model = albert_pretrain_model.eval()
    for step, valid_batch in tqdm(enumerate(val_dataloader), 
                                dynamic_ncols=False, 
                                bar_format="{n_fmt}/{total_fmt}{bar} ETA: {remaining}s - {desc}", 
                                total=len(val_dataloader),
                                leave=True, 
                                unit='steps'):            
        input_ids = valid_batch['input_ids'].to(device)
        attention_mask = valid_batch['attention_mask'].to(device)
        token_type_ids = valid_batch['token_type_ids'].to(device)
        mlm_labels = valid_batch['masked_lm_labels'].to(device)
        sop_labels = valid_batch['next_sentence_labels'].to(device)
            
        valid_loss = validataion_step(
            model=albert_pretrain_model, 
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            token_type_ids=token_type_ids, 
            mlm_labels=mlm_labels,
            sop_labels=sop_labels
        )
        valid_batch_loss += valid_loss.item()
        valid_perplexity += torch.exp(valid_loss)
        
        if use_step_tracking:
            record_step = (step + 1) * (epoch + 1)
            show_logs(
                valid_batch_loss / record_step, 
                record_step, 
                prefix='valid', 
                perplexity=valid_perplexity.item() / record_step)
            
        #sk_metrics = sklearn_metrics(val_outputs, labels, 'train')
        #ic(sk_metrics)
        #ic(val_metric_collection(outputs, labels).compute())
        #ic(val_metric_collection(outputs, labels))
        
    if use_epoch_tracking:        
        valid_epoch_loss = valid_batch_loss / step 
        show_logs(valid_epoch_loss, epoch, is_epoch=True, prefix='Val')
        #val_metric_records = val_metric_collection.compute()
        #wandb.log(val_metric_records, step=epoch)
    
    loss_template = ("Epoch {}/{} - {:.0f}s {:.0f}ms/step - lr:{:} - loss: {:.6f} - val_loss: {:.6f}")  
    #metrics_template = (
    #    """
    #    categorical_accuracy: {:.4f} - f1_score: {:.4f} - multi_precision: {:.4f} - multi_recall: {:.4f}
    #    val_categorical_accuracy: {:.4f} -  val_f1_score: {:.4f} - val_multi_precision: {:.4f} - val_multi_recall: {:.4f}
    #    """
    #)
    end_time = time.time()
    each_steps_compute_time = (end_time - start_time)
    print(loss_template.format(
        epoch,
        model_config['epochs'], 
        each_steps_compute_time,
        each_steps_compute_time * 1000 / model_config['training_steps'],
        last_lr,
        train_epoch_loss,
        val_epoch_loss)
    )

    #print(metrics_template.format(
    #    train_metric_records['Train_Accuracy'],
    #    train_metric_records['Train_F1'],
    #    train_metric_records['Train_Precision'],
    #    train_metric_records['Train_Recall'],
    #    val_metric_records['Val_Accuracy'],
    #    val_metric_records['Val_F1'],
    #    val_metric_records['Val_Precision'],
    #    val_metric_records['Val_Recall']
    #))
    
    if use_epoch_tracking:
        metric_collection.reset()
        val_metric_collection.reset()
     

In [None]:
%matplotlib inline
import math
from torch.optim.lr_scheduler import _LRScheduler
from torch import nn
from torch import cuda
from torch import optim
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

net = NeuralNetwork()

optimizer = optim.SGD(net.parameters(), lr = 1e-2)
lambda1 = lambda epoch: 0.2 if epoch % 5 == 0 else 1
lambda2 = lambda epoch: 0.2

#scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda = lambda2)
#scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5,10,15], gamma=0.1)
#scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = 0.9)
#scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: PolynomialDecay(step))

class PolynomialDecay(_LRScheduler):
    def __init__(self, optimizer, decay_steps, end_learning_rate=0.0001, power=0.5, cycle=False, last_epoch=-1, verbose=False):
        if decay_steps <= 1.:
            raise ValueError('max_decay_steps should be greater than 1.')            
        self.decay_steps = decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.cycle = cycle
        super(PolynomialDecay, self).__init__(optimizer, last_epoch, verbose)
    
    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        #dtype = initial_learning_rate.dtype
        #end_learning_rate = math_ops.cast(self.end_learning_rate, dtype)
        #power = math_ops.cast(self.power, dtype)
        #global_step_recomp = math_ops.cast(step, dtype)
        #decay_steps_recomp = math_ops.cast(self.decay_steps, dtype)
        global_step_recomp = self.last_epoch
        decay_steps_recomp = self.decay_steps
        
        if self.cycle:
            if global_step_recomp == 0:
                multiplier = 1.0 
            else:
                multiplier = math.ceil(global_step_recomp / self.decay_steps)
            decay_steps_recomp = decay_steps_recomp * multiplier
        else:
            global_step_recomp = min(global_step_recomp, decay_steps_recomp)
            
        p = global_step_recomp / decay_steps_recomp
        #c(self.last_epoch, optimizer.param_groups[0]['lr'], p)
        return [((group['lr'] - self.end_learning_rate) * math.pow(1 - p, self.power) + self.end_learning_rate) for group in self.optimizer.param_groups]
    
    def _get_closed_form_lr(self):
        return [(base_lr - self.end_learning_rate) * math.pow(1 - p, self.power) + self.end_learning_rate for base_lr in self.base_lrs]


    
def polynomial_decay_scale_fun(global_steps, initial_learning_rate=1e-2, decay_steps=100, power=0.5, end_learning_rate=1e-5, cycle=False):
    if cycle:
        if global_steps == 0:
            multiplier = 1.0 
        else:
            multiplier = math.ceil(global_steps / decay_steps)
            decay_steps = decay_steps * multiplier
    else:
        global_steps = min(global_steps, decay_steps)
    p = global_steps / decay_steps
    #ic(global_steps, p)
    return (initial_learning_rate - end_learning_rate) * math.pow(1 - p, power) + end_learning_rate
    
    
#optimizer = optim.SGD(net.parameters(), lr=1e-2)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
#scheduler = PolynomialDecay(optimizer, decay_steps=1000, end_learning_rate=1e-5)
 
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer, 
    warmup_epochs=model_config['warmup_steps'], 
    max_epochs=model_config['training_steps'],
    eta_min=model_config["end_learning_rate"])

#scheduler = optim.lr_scheduler.CyclicLR(
#    optimizer, 
#    base_lr=1e-5,
#    max_lr=1e-2,
#    step_size_up=20,
#    scale_fn=polynomial_decay_scale_fun,
#    mode='triangular2',
#    scale_mode='cycle',
#    cycle_momentum=False)

iteration = model_config['epochs']
scheduler_lr_list = []
for epoch in range(1, iteration):
    scheduler.step()
    #print(epoch, scheduler.get_last_lr()[0])
    scheduler_lr_list.append(scheduler.get_last_lr()[0])

plt.xlabel('Training Iterations')
plt.ylabel('Learning Rate')
plt.title("CLR - 'triangular' Policy")
plt.plot(range(1, iteration), scheduler_lr_list)
