In [1]:
import torch
import torch.nn as nn

## 1. Fine-Tuning BERT
Tutorial Source: Chris McCormick and Nick Ryan. (2019, July 22). BERT Fine-Tuning Tutorial with PyTorch. Retrieved from http://www.mccormickml.com

### 1.1 Prepare Dataset

### 1.2 Tokenize Dataset

### 1.3 Training & Validation Split

### 1.4 Train the Classification Model

## 3. Replace all LayerNorm modules with new implementation

In [28]:
import torch.nn as nn
import copy
from layernorm_v2 import LayerNorm_v2

# Get a deep copy of the original model so that the module replacements made
# on new_model leave bert_model's own module tree untouched.
new_model = copy.deepcopy(bert_model)

In [54]:
# Replace every nn.LayerNorm in new_model with LayerNorm_v2, reusing the weight
# and bias of the matching layer in the original bert_model.
#
# bert_model is walked by qualified module name instead of hard-coding the
# embedding block plus 12 encoder layers, so the cell works for any encoder
# depth. Reading from bert_model (never from new_model) keeps the cell safe to
# re-run after new_model has already been patched.
#
# NOTE(review): the weight/bias handed over are bert_model's own Parameter
# objects, not copies -- presumably LayerNorm_v2 keeps references to them, in
# which case both models share these tensors; confirm this is intended.
# NOTE(review): the original layer's eps is not forwarded; confirm that
# LayerNorm_v2 uses a matching epsilon.
for qualified_name, old_layernorm in bert_model.named_modules():
    # The root module has an empty name and no parent to attach a replacement to.
    if not qualified_name or not isinstance(old_layernorm, nn.LayerNorm):
        continue

    # Split e.g. 'encoder.layer.0.output.LayerNorm' into the parent's path and
    # the attribute name under which the LayerNorm is registered.
    parent_path, _, attr_name = qualified_name.rpartition('.')

    # Walk down new_model to the parent; numeric parts such as '0' resolve
    # because nn.Module attribute lookup falls back to registered submodules.
    parent_module = new_model
    for part in (parent_path.split('.') if parent_path else []):
        parent_module = getattr(parent_module, part)

    # replace LayerNorm with LayerNorm_v2 (same parameters)
    setattr(parent_module, attr_name, LayerNorm_v2(old_layernorm.weight, old_layernorm.bias))

In [51]:
# # find all LayerNorm modules
# old_layernorm_list = [module for module in bert_model.modules() if isinstance(module, nn.LayerNorm)]


# layernorm_param = []
# for i in range(len(old_layernorm_list)):
#     # store the parameters
#     param_dict = {}
#     for name, param in old_layernorm_list[i].named_parameters():
#         param_dict[name] = param
#     layernorm_param.append(param_dict)
    

# idx = 0
# for layernorm in new_model.modules():
#     if isinstance(layernorm, nn.LayerNorm):
#         layernorm = LayerNorm_v2(layernorm_param[idx]['weight'], layernorm_param[idx]['bias'])
#         idx += 1
#     NOTE: the assignment above only rebinds the loop variable `layernorm`;
#     it does not replace the module registered inside new_model.


In [57]:
# Check LayerNorm: no stock nn.LayerNorm should remain in new_model.
# NOTE(review): if LayerNorm_v2 subclasses nn.LayerNorm, this isinstance check
# would also flag the replacements -- confirm against layernorm_v2.
leftover_layernorms = [
    name
    for name, module in new_model.named_modules()
    if isinstance(module, nn.LayerNorm)
]
for name in leftover_layernorms:
    # Report where the leftover sits so it can be fixed.
    print("Not Clean!!!", name)

# Print the architecture once. Printing inside a loop over modules() would
# repeat every nested subtree, because modules() yields each submodule
# recursively and each module's repr already includes its children.
print(new_model)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm_v2()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm_v2()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linea

In [60]:
# Sanity check: run one forward pass through the patched model and show the result.
new_model_output = new_model(**encoded_input)
print(new_model_output)

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.1318, -0.0333, -0.0274,  ..., -0.0252,  0.3472,  0.0981],
         [-0.1309,  0.0768, -0.0440,  ..., -0.0012,  0.2886,  0.0154],
         [-0.1996,  0.1322, -0.0499,  ...,  0.0081,  0.2832,  0.0802],
         [-0.2461,  0.0560, -0.0619,  ..., -0.0368,  0.3188,  0.1059],
         [-0.1967,  0.0607, -0.0391,  ..., -0.0327,  0.2993, -0.0095],
         [ 0.0839, -0.1601, -0.0278,  ...,  0.0417,  0.8319,  0.1762]]],
       grad_fn=<AddBackward0>), pooler_output=tensor([[-8.0579e-01, -1.9621e-02,  9.7443e-01,  2.4826e-01, -3.7157e-01,
         -1.1247e-01,  5.5863e-01,  2.7542e-02,  9.6929e-01, -8.1530e-01,
          6.4503e-01, -7.4231e-01,  9.8629e-01, -9.2892e-01,  9.7908e-01,
          6.4986e-02,  1.7156e-01, -1.3067e-01,  1.2808e-01, -2.9844e-01,
          7.5510e-01, -9.5433e-01,  7.9106e-01,  1.6349e-01,  5.2892e-02,
         -9.1742e-01, -7.8209e-02,  9.7946e-01,  9.6361e-01,  8.1462e-01,
         -5.7140e-0