In [None]:
from pathlib import Path
import os
from functools import partial
from datetime import datetime, timezone, timedelta
from IPython.core.debugger import set_trace as bk
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data as D
import torch.tensor as T
import nlp
from transformers import ElectraModel, ElectraConfig, ElectraTokenizerFast, ElectraForMaskedLM,ElectraForPreTraining
from fastai2.text.all import *
import wandb
from fastai2.callback.wandb import WandbCallback
from _utils.would_like_to_pr import *
from _utils.huggingface import *
from _utils.utils import *

In [None]:
# Run configuration. `size` selects the row of every per-size hyperparameter table below.
c = MyConfig({
    'device': 'cuda:3',
    'size': 'small',
    'sampling': 'fp32_gumbel',
    'electra_mask_style': True,
    'gen_smooth_label': False,
    'disc_smooth_label': False,
    'sort_sample': True,
    'shuffle': True,
})

# check (done first so a bad setting fails before any config/tokenizer download)
assert c.sampling in ['fp32_gumbel', 'fp16_gumbel', 'multinomial']

# Per-size hyperparameters, indexed in the order small / base / large.
i = ['small', 'base', 'large'].index(c.size)
c.mask_prob = [0.15, 0.15, 0.25][i]
c.lr = [5e-4, 2e-4, 2e-4][i]
c.bs = [128, 256, 2048][i]
c.steps = [10**6, 766*1000, 400*1000][i]
c.max_length = [128, 512, 512][i]
generator_size_divisor = [4, 3, 4][i]
disc_config = ElectraConfig.from_pretrained(f'google/electra-{c.size}-discriminator')
gen_config = ElectraConfig.from_pretrained(f'google/electra-{c.size}-generator')
# Note that the public electra-small model is actually small++ and doesn't scale down the generator size,
# so the generator dimensions are derived here from the discriminator config instead.
gen_config.hidden_size = int(disc_config.hidden_size/generator_size_divisor)
gen_config.num_attention_heads = int(disc_config.num_attention_heads/generator_size_divisor)
gen_config.intermediate_size = int(disc_config.intermediate_size/generator_size_divisor)
hf_tokenizer = ElectraTokenizerFast.from_pretrained(f"google/electra-{c.size}-generator")

# path to data
Path('./checkpoints/pretrain').mkdir(exist_ok=True,parents=True)
if c.size in ['small', 'base']:
  wiki_cache_dir = Path("./datasets/wikipedia/20200501.en/1.0.0")
  book_cache_dir = Path("./datasets/bookcorpus/plain_text/1.0.0")
  wbdl_cache_dir = Path("./datasets/wikibook_dl")
  # parents=True: ./datasets may not exist yet on a fresh checkout
  wbdl_cache_dir.mkdir(exist_ok=True, parents=True)

# print info (f-string so the actual process id is printed, not the literal placeholder)
print(f"process id: {os.getpid()}")
print(c)

# 1. Load Data

In [None]:
# Only the small and base setups have a data pipeline here.
if c.size not in ['small', 'base']:
  raise NotImplementedError # for large size

# wiki: reuse the preprocessed arrow file when present, otherwise build it
wiki_arrow = wiki_cache_dir/f"wiki_electra_{c.max_length}.arrow"
if not wiki_arrow.exists():
  print('load/download wiki dataset')
  wiki = nlp.load_dataset('wikipedia', '20200501.en', cache_dir='./datasets')['train']

  print('creat data from wiki dataset for ELECTRA')
  wiki_tfm = ELECTRADataTransform(wiki, is_docs=True, text_col={'text':'input_ids'},
                                  max_length=c.max_length, hf_toker=hf_tokenizer)
  wiki = wiki_tfm.map(cache_file_name=str(wiki_arrow))
else:
  print('loading the electra data (wiki)')
  wiki = nlp.Dataset.from_file(str(wiki_arrow))

# bookcorpus: same load-or-build logic as wiki
book_arrow = book_cache_dir/f"book_electra_{c.max_length}.arrow"
if not book_arrow.exists():
  print('load/download BookCorpus dataset')
  book = nlp.load_dataset('bookcorpus', cache_dir='./datasets')['train']

  print('creat data from BookCorpus dataset for ELECTRA')
  book_tfm = ELECTRADataTransform(book, is_docs=False, text_col={'text':'input_ids'},
                                  max_length=c.max_length, hf_toker=hf_tokenizer)
  book = book_tfm.map(cache_file_name=str(book_arrow))
else:
  print('loading the electra data (BookCorpus)')
  book = nlp.Dataset.from_file(str(book_arrow))

# merge both corpora and wrap them as a train-only dataset exposing the 'input_ids' column
wb_data = HF_MergedDataset(wiki, book)
wb_dsets = HF_Datasets({'train': wb_data}, cols=['input_ids'], hf_toker=hf_tokenizer)
# NOTE(review): this cache lives under the home directory, while the config cell creates
# ./datasets/wikibook_dl relative to the working directory — confirm which one is intended.
dl_cache_dir = Path.home()/'datasets/wikibook_dl'
sort_key = None if c.sort_sample else False
dls = wb_dsets.dataloaders(bs=c.bs,
                           shuffle_train=c.shuffle,
                           srtkey_fc=sort_key,
                           cache_dir=dl_cache_dir, cache_name='dl_{split}.json')

# 2. Masked language model objective

## 2.1 MLM objective callback

In [None]:
"""
Modified from HuggingFace/transformers (https://github.com/huggingface/transformers/blob/0a3d0e02c5af20bfe9091038c4fd11fb79175546/src/transformers/data/data_collator.py#L102). Compared with the original, it is
- a few ms faster: instead of indexing a[b] with a on GPU and b on CPU, all tensors here are on the same device
- a few tens of us faster in how the special token mask is created
- independent of the HuggingFace tokenizer
- cheap: it costs only about 20 ms on a (128,128) tensor, so dynamic masking is affordable
"""
def mask_tokens(inputs, mask_token_index, vocab_size, special_token_indices, mlm_probability=0.15, replace_prob=0.1, orginal_prob=0.1, ignore_index=-100):
  """
  Prepare masked inputs/labels for masked language modeling.

  Each non-special token is selected with probability `mlm_probability`. A selected token is then
  - replaced by `mask_token_index` with probability (1 - replace_prob - orginal_prob),
  - replaced by a random token id in [0, vocab_size) with probability `replace_prob`,
  - left unchanged with probability `orginal_prob`.

  Args:
    inputs: integer tensor of token ids. It is MODIFIED IN PLACE; pass a clone to keep the original.
    mask_token_index: id written at positions chosen to be masked.
    vocab_size: exclusive upper bound of the random replacement ids.
    special_token_indices: iterable of ids that are never selected.
    mlm_probability: per-token selection probability.
    replace_prob: share of selected tokens replaced by a random token.
    orginal_prob: share of selected tokens kept as they are.
    ignore_index: label written at non-selected positions. nn.CrossEntropyLoss ignores -100 by
      default, so with the default value no ignore_index has to be given to the loss.

  Returns:
    (inputs, labels, mlm_mask): the masked inputs (same tensor object as the argument), the labels
    (original id at selected positions, `ignore_index` elsewhere) and the boolean selection mask.

  Raises:
    ValueError: if the probabilities are negative or replace_prob + orginal_prob exceeds 1.
  """
  if replace_prob < 0 or orginal_prob < 0 or replace_prob + orginal_prob > 1 + 1e-8:
    raise ValueError(f"replace_prob ({replace_prob}) and orginal_prob ({orginal_prob}) must be >= 0 and sum to at most 1")

  device = inputs.device
  labels = inputs.clone()

  # Special tokens get selection probability 0 so they are never masked.
  special_tokens_mask = torch.zeros_like(inputs, dtype=torch.bool)
  for sp_id in special_token_indices:
    special_tokens_mask |= (inputs == sp_id)
  probability_matrix = torch.full(labels.shape, float(mlm_probability), device=device)
  probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
  mlm_mask = torch.bernoulli(probability_matrix).bool()
  labels[~mlm_mask] = ignore_index  # We only compute loss on masked tokens

  # <1 - replace_prob - orginal_prob> of the time, we replace selected tokens with mask_token.
  # Clamped to [0, 1] to absorb floating point error (e.g. 1 - 0.9 - 0.1 is slightly negative).
  mask_prob = min(1.0, max(0.0, 1.0 - replace_prob - orginal_prob))
  mask_token_mask = torch.bernoulli(torch.full(labels.shape, mask_prob, device=device)).bool() & mlm_mask
  inputs[mask_token_mask] = mask_token_index

  # <replace_prob> of the time, we replace selected tokens with a random word.
  # This draw happens only among positions that did not get the mask token, hence the
  # conditional probability replace_prob / (replace_prob + orginal_prob).
  if replace_prob > 0:
    rep_prob = min(1.0, replace_prob / (replace_prob + orginal_prob))
    replace_token_mask = torch.bernoulli(torch.full(labels.shape, rep_prob, device=device)).bool() & mlm_mask & ~mask_token_mask
    random_words = torch.randint(vocab_size, labels.shape, dtype=torch.long, device=device)
    inputs[replace_token_mask] = random_words[replace_token_mask]

  # <orginal_prob> of the time, we keep the selected tokens unchanged
  return inputs, labels, mlm_mask

class MaskedLMCallback(Callback):
  """
  Callback that applies dynamic MLM masking (via `mask_tokens`) to every batch.

  With `for_electra=False` the batch becomes x=(masked_inputs,), y=(labels,).
  With `for_electra=True` the batch becomes x=(masked_inputs, is_mlm_applied, labels), y=(labels,).
  """
  @delegates(mask_tokens)
  def __init__(self, mask_tok_id, special_tok_ids, vocab_size, ignore_index=-100, for_electra=False, **kwargs):
    """
    mask_tok_id / special_tok_ids / vocab_size / ignore_index are forwarded to `mask_tokens` as
    mask_token_index / special_token_indices / vocab_size / ignore_index; remaining kwargs
    (e.g. mlm_probability, replace_prob, orginal_prob) are passed through unchanged.
    """
    self.ignore_index = ignore_index
    self.for_electra = for_electra
    # Forward the ignore_index that was actually requested, so the labels produced by mask_tokens
    # agree with self.ignore_index used by show_batch.
    self.mask_tokens = partial(mask_tokens,
                               mask_token_index=mask_tok_id,
                               special_token_indices=special_tok_ids,
                               vocab_size=vocab_size,
                               ignore_index=ignore_index,
                               **kwargs)

  def begin_batch(self):
    "Mask the first input of the current batch and replace the learner's xb/yb with the result."
    text_indices = self.xb[0]
    # note: mask_tokens writes into text_indices in place
    masked_inputs, labels, is_mlm_applied = self.mask_tokens(text_indices)
    if self.for_electra:
      self.learn.xb, self.learn.yb = (masked_inputs, is_mlm_applied, labels), (labels,)
    else:
      self.learn.xb, self.learn.yb = (masked_inputs,), (labels,)

  @delegates(TfmdDL.show_batch)
  def show_batch(self, dl, idx_show_ignored, verbose=True, **kwargs):
    """
    Show one masked batch from `dl`.

    idx_show_ignored: token id displayed at the label positions that hold ignore_index.
    verbose: print a few notes explaining the masking mechanism.
    """
    b = dl.one_batch()
    inputs = b[0]
    # mask a clone so `inputs` keeps the original ids for the checks below
    masked_inputs, labels, is_mlm_applied = self.mask_tokens(inputs.clone())
    # check: labels are set exactly where mlm was applied, and the original can be reconstructed
    assert torch.equal(is_mlm_applied, labels!=self.ignore_index)
    assert torch.equal(((~is_mlm_applied) * masked_inputs + is_mlm_applied * labels), inputs)
    # change symbol to show the ignored position
    labels[labels==self.ignore_index] = idx_show_ignored
    # some notice to help understand the masking mechanism
    if verbose:
      print("We won't count loss from position where y is ignore index")
      print("Notice 1. Positions have label token in y will be either [Mask]/other token/orginal token in x")
      print("Notice 2. Special tokens (CLS, SEP) won't be masked.")
      print("Notice 3. Dynamic masking: every time you run gives you different results.")
    # show
    tfm_b =(masked_inputs, is_mlm_applied, labels, labels) if self.for_electra else (masked_inputs, labels)
    dl.show_batch(b=tfm_b, **kwargs)

In [None]:
# Shares of the selected tokens that are randomly replaced / kept unchanged.
# ELECTRA-style masking uses no random replacement.
if c.electra_mask_style:
  _replace_prob, _orginal_prob = 0.0, 0.15
else:
  _replace_prob, _orginal_prob = 0.1, 0.1

mlm_cb = MaskedLMCallback(
  mask_tok_id=hf_tokenizer.mask_token_id,
  special_tok_ids=hf_tokenizer.all_special_ids,
  vocab_size=hf_tokenizer.vocab_size,
  mlm_probability=c.mask_prob,
  replace_prob=_replace_prob,
  orginal_prob=_orginal_prob,
  for_electra=True,
)
# mlm_cb.show_batch(dls[0], )

# 3. ELECTRA (replaced token detection objective)
See the paper [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555) for details.

In [None]:
class ELECTRAModel(nn.Module):
  """
  Generator + discriminator wrapper for the replaced token detection objective.

  The generator predicts tokens at the mlm-applied positions, tokens are sampled from its
  predictions to build a corrupted sequence, and the discriminator scores every position of
  that sequence. The sampling strategy is read from the global config `c.sampling`.
  """

  def __init__(self, generator, discriminator):
    super().__init__()
    self.generator, self.discriminator = generator,discriminator
    # CPU / default dtype distribution; rebuilt on the right device and dtype in `to`
    self.gumbel_dist = torch.distributions.gumbel.Gumbel(0.,1.)
    self.toker = hf_tokenizer
    # tie embeddings (word, pos, token_type) of the discriminator to those of the generator
    # Note input and output embedding of generator has been tied by huggingface/transformers
    self.discriminator.model.electra.embeddings = self.generator.model.electra.embeddings

  def to(self, *args, **kwargs):
    "Move the module as nn.Module.to does, and rebuild the Gumbel distribution on the new device. Returns self."
    super().to(*args, **kwargs)
    a_tensor = next(self.parameters())
    device, dtype = a_tensor.device, a_tensor.dtype
    # with 'fp32_gumbel' the noise stays in fp32 whatever the parameter dtype is
    if c.sampling=='fp32_gumbel': dtype = torch.float32
    self.gumbel_dist = torch.distributions.gumbel.Gumbel(torch.tensor(0., device=device, dtype=dtype), torch.tensor(1., device=device, dtype=dtype))
    # nn.Module.to returns self; keep that contract so `model = model.to(device)` keeps working
    return self

  def forward(self, masked_inputs, is_mlm_applied, labels):
    """
    masked_inputs: (B,L) token ids after masking
    is_mlm_applied: (B,L) boolean mask of the mlm-applied positions
    labels: (B,L) ids whose values at the mlm-applied positions are compared with the sampled tokens

    Returns (gen_logits, generated, disc_logits, is_replaced).
    """
    #assert is_mlm_applied.dtype == torch.bool

    gen_logits = self.generator(masked_inputs) # (B, L, vocab size)

    # reduce size to speed up: only the mlm positions are sampled; detached from the graph
    mlm_gen_logits = gen_logits[is_mlm_applied, :].detach() # ( #mlm_positions, vocab_size)
    # sampling
    pred_toks = self.sample(mlm_gen_logits)
    # pred_toks: ( #mlm_positions, )
    # use predicted token to fill the mlm applied positions
    generated = masked_inputs.clone() # (B,L)
    generated[is_mlm_applied] = pred_toks # (B,L)
    # replaced = at an mlm applied position AND the sampled token differs from the label
    is_replaced = is_mlm_applied.clone() # (B,L)
    is_replaced[is_mlm_applied] = (pred_toks != labels[is_mlm_applied]) # (B,L)

    disc_logits = self.discriminator(generated) # (B, L)

    return gen_logits, generated, disc_logits, is_replaced

  def sample(self, logits):
    "reimplement it cuz there is a bug in torch.nn.functional.gumbel_softmax when fp16 (https://github.com/pytorch/pytorch/issues/41663)"
    # This is equal to the code of official ELECTRA repo. standard gumbel dist. = -ln(-ln(standard uniform dist.))
    if c.sampling == 'fp32_gumbel':
      return (logits.float() + self.gumbel_dist.sample(logits.shape)).argmax(dim=-1)
    elif c.sampling == 'fp16_gumbel': # 5.06 ms
      return (logits + self.gumbel_dist.sample(logits.shape)).argmax(dim=-1)
    elif c.sampling == 'multinomial':
      # squeeze only the sample dim, so a single row still yields a 1-d tensor
      return torch.multinomial(F.softmax(logits, dim=-1), 1).squeeze(-1)
    # fail loudly instead of returning None
    raise ValueError(f"Unknown sampling method: {c.sampling}")

class ELECTRALoss():
  """
  Weighted sum of the generator (MLM) loss and the discriminator (replaced token detection) loss.

  pad_idx: token id of padding; positions where `generated` equals it are left out of the discriminator loss
  loss_weights: (generator weight, discriminator weight)
  gen_label_smooth / disc_label_smooth: falsy to disable, a float to use as eps, any other truthy value for eps=0.1
  """
  def __init__(self, pad_idx, loss_weights=(1.0, 50.0), gen_label_smooth=False, disc_label_smooth=False):
    self.pad_idx, self.loss_weights = pad_idx, loss_weights
    self.disc_label_smooth = disc_label_smooth
    self.disc_loss_fc = nn.BCEWithLogitsLoss()
    if not gen_label_smooth:
      self.gen_loss_fc = CrossEntropyLossFlat()
    else:
      smoothing = gen_label_smooth if isinstance(gen_label_smooth, float) else 0.1
      self.gen_loss_fc = LabelSmoothingCrossEntropyFlat(eps=smoothing)

  def __call__(self, pred, targ_ids):
    "pred is the 4-tuple returned by the model; targ_ids are the MLM labels (-100 where ignored)."
    gen_logits, generated, disc_logits, is_replaced = pred
    # generator part: positions where targ_id==-100 are ignored
    gen_loss = self.gen_loss_fc(gen_logits.float(), targ_ids)
    # discriminator part: only non-padding positions count, flattened to 1d
    keep = generated != self.pad_idx
    flat_logits = disc_logits.masked_select(keep)
    flat_targets = is_replaced.masked_select(keep)
    if self.disc_label_smooth:
      smoothing = self.disc_label_smooth if isinstance(self.disc_label_smooth, float) else 0.1
      # "not replaced" targets become eps, "replaced" targets stay 1
      flat_targets = flat_targets.float().masked_fill(~flat_targets, smoothing)
    disc_loss = self.disc_loss_fc(flat_logits.float(), flat_targets.float())
    return gen_loss * self.loss_weights[0] + disc_loss * self.loss_weights[1]

  def decodes(self, pred):
    "Turn logits into predictions: argmax token for the generator, logit > 0 for the discriminator."
    gen_logits, generated, disc_logits, is_replaced = pred
    return gen_logits.argmax(dim=-1), generated, disc_logits > 0, is_replaced

# 4. Learning rate schedule

In [None]:
def linear_warmup_and_decay(pct, lr_max, warmup_steps, total_steps, fake_total_steps=None, end_lr=0.0, decay_power=1):
  """
  Learning rate at training progress `pct`: polynomial decay multiplied by a linear warmup.

  pct (float): fastai counts it as ith_step/(num_epoch*len(dl)), with ith_step counted from 0. When the
    number of epochs is fake, pass the step count fastai believes in as `fake_total_steps`, so the
    real step index can be recovered from pct.
  lr_max: learning rate the decay starts from.
  warmup_steps: number of steps of linear warmup; 0 or None disables the warmup.
  total_steps: number of steps over which the rate decays from lr_max to end_lr.
  fake_total_steps: total number of steps pct is relative to; defaults to total_steps.
  end_lr: learning rate reached at total_steps and kept afterwards.
  decay_power: power of the polynomial decay (1 is linear).
  """
  step_i = round(pct * (fake_total_steps if fake_total_steps else total_steps)) + 1 # fastai count step from 0, so we have to add 1 back
  # According to the original source code, two schedules take effect at the same time, but the decaying schedule is negligible early on.
  # The step is clamped to total_steps as tf.compat.v1.train.polynomial_decay does; otherwise the base
  # (1 - step/total) becomes negative past the end, giving a negative (or complex) learning rate.
  decay_step = min(step_i, total_steps)
  decayed_lr = (lr_max-end_lr) * (1 - decay_step/total_steps) ** decay_power + end_lr # https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/polynomial_decay
  # https://github.com/google-research/electra/blob/81f7e5fc98b0ad8bfd20b641aa8bc9e6ac00c8eb/model/optimization.py#L44
  warmup_scale = min(1.0, step_i/warmup_steps) if warmup_steps else 1.0
  return decayed_lr * warmup_scale

In [None]:
# fastai sees len(dls[0])*9999 steps in total (9999 nominal epochs), so that number is passed as
# fake_total_steps to let the schedule recover the real step index from `pct`.
_lr_schedule_fn = partial(
  linear_warmup_and_decay,
  lr_max=c.lr,
  warmup_steps=10000,
  total_steps=c.steps,
  fake_total_steps=len(dls[0])*9999,
)
lr_shedule = ParamScheduler({'lr': _lr_schedule_fn})

# 5. Train

In [None]:
class GradientClipping(Callback):
    "Callback that clips the gradient norm of the model parameters after each backward pass."
    def __init__(self, clip:float = 0.):
        # maximum gradient norm; a falsy value (default 0.) turns clipping off
        self.clip = clip

    def after_backward(self):
        "Runs after backward, i.e. before the optimizer step; no-op when `clip` is falsy."
        if not self.clip: return
        model_params = self.learn.model.parameters()
        nn.utils.clip_grad_norm_(model_params, self.clip)

In [None]:
def now_time(now=None):
  """
  Return a run-name friendly timestamp "<month>-<day>_<hour>-<minute>-<second>", e.g. "7-21_14-05-33".

  now: datetime to format; defaults to the current time in UTC+8.

  The month is not zero padded (same output as before for months 1-9). Building the string from
  the datetime fields instead of slicing str(datetime) keeps both digits of months 10-12 and does
  not depend on the microsecond part being present in the string.
  """
  if now is None:
    now = datetime.now(timezone(timedelta(hours=+8)))
  return f"{now.month}-{now:%d_%H-%M-%S}"

In [None]:
# Build generator and discriminator from their configs (HF_Model is a project helper), wrap them in
# ELECTRAModel and create the combined loss with the label smoothing options from the config.
generator = HF_Model(ElectraForMaskedLM, gen_config, hf_tokenizer, variable_sep=True)
discriminator = HF_Model(ElectraForPreTraining, disc_config, hf_tokenizer, variable_sep=True)
electra_model = ELECTRAModel(generator, discriminator)
electra_loss_func = ELECTRALoss(pad_idx=hf_tokenizer.pad_token_id, gen_label_smooth=c.gen_smooth_label, disc_label_smooth=c.disc_smooth_label)

# Put the dataloaders on the configured device; the timestamp is used to name this run.
dls.to(torch.device(c.device))
run_name = now_time()
print(run_name)
# mlm_cb masks every batch. RunSteps is a project helper; NOTE(review): presumably it ends training
# after c.steps steps and saves at the given fractions under run_name+"_{percent}" — confirm.
learn = Learner(dls, electra_model,
                loss_func=electra_loss_func,
                opt_func=partial(Adam, eps=1e-6,),
                path='./checkpoints',
                model_dir='pretrain',
                cbs=[mlm_cb,
                    RunSteps(c.steps, [0.0625, 0.125, 0.5, 1.0], run_name+"_{percent}"),
                    ],
                )

# On cuda: mixed precision, with gradient clipping (clip=1.) passed to to_fp16.
# Otherwise: full precision, with the GradientClipping callback using the same threshold.
if c.device.startswith('cuda'): 
  # too large loss_scale will cause overflow and lose the batch, too small loss_scale will cause underflow or can't fully use the value of loss
  # adjust the initial loss scale and the interval to scale the loss scale back up, because the disc loss is scaled by 50 and there are so many training steps
  learn = learn.to_fp16(max_loss_scale=2.**11, scale_wait=int(c.steps*0.01), clip=1.)
else: 
  learn.add_cb(GradientClipping(1.))
# to also exclude layernorm: create the optimizer explicitly, then turn off the 'do_wd' flag on every
# state returned by my_bn_bias_state (project helper)
learn.create_opt()
for p in my_bn_bias_state(learn, True): p['do_wd'] = False 
# ----------------------------
# 9999 is a nominal epoch count; it matches the 9999 used for fake_total_steps in the lr schedule cell
learn.fit(9999, cbs=[lr_shedule])