In [1]:
import argparse
import os
import tempfile

import torch
from pytorch_lightning.trainer.trainer import Trainer
from torch.utils.data import DataLoader, Dataset

from omegaconf import OmegaConf, open_dict

from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.utils import AppState, model_utils

In [2]:
# Save/restore connector used by the expansion cell to unpack the source .nemo
# archive into a temporary folder and to re-pack that folder into a new archive.
connector = NLPSaveRestoreConnector()

In [3]:
# Path to the source .nemo file. The expansion cell raises if the model config
# reports tensor_model_parallel_size > 1, so this must be a TP=1 checkpoint.
# Alternative inputs (commented out):
# nemo_file = "/data/megatron_3b_1TP/megatron_t5_expanded_vocab_posemb.nemo"
# nemo_file = "/datap/misc/Checkpoints/megatron_t5_220m/tp1_pp1/megatron_t5.nemo"
nemo_file = "/Data/Checkpoints/multilingual_t5/megatron_mt5.nemo"
# Path of the .nemo file written after expansion. When this is None the output
# path is derived from the input: <nemo_file without extension> + "_expanded_vocab_posemb.nemo".
out_file = None #  "/data/megatron_3b_1TP/"

In [5]:

def expand_vocab(key, state_dict, increase_by=9984+128, divisible_by=1):
    """Grow dimension 0 of ``state_dict[key]`` and store the result back in place.

    The existing rows keep their values. The ``increase_by`` rows appended
    after them are filled with the constant ``1e-6``. If the resulting row
    count is not a multiple of ``divisible_by``, extra "dummy" rows filled
    with ``0.0`` are appended until it is.

    Args:
        key: Name of the tensor inside ``state_dict`` to expand.
        state_dict: Mapping from name to ``torch.Tensor``; it is mutated in place.
        increase_by: Number of rows to append along dimension 0.
        divisible_by: The final row count is padded up to a multiple of this
            value. The default of 1 disables padding.

    Returns:
        The same ``state_dict`` object that was passed in.

    Raises:
        ValueError: If the tensor has no dimensions, ``increase_by`` is
            negative, or ``divisible_by`` is smaller than 1.
    """
    original = state_dict[key]
    original_shape = original.shape
    if len(original_shape) == 0:
        raise ValueError(f"state_dict[{key!r}] is a 0-dim tensor and cannot be expanded")
    if increase_by < 0:
        raise ValueError(f"increase_by must be >= 0, got {increase_by}")
    if divisible_by < 1:
        raise ValueError(f"divisible_by must be >= 1, got {divisible_by}")
    print("Original shape :", original_shape, len(original_shape))

    original_rows = original_shape[0]
    new_vocab_size = original_rows + increase_by
    print("New vocab size :", new_vocab_size)

    # Pad with dummy rows so the final row count is a multiple of divisible_by.
    remainder = new_vocab_size % divisible_by
    dummy_tokens = divisible_by - remainder if remainder else 0
    if dummy_tokens:
        print("Adding Dummy Tokens :", dummy_tokens)
    final_vocab_size = new_vocab_size + dummy_tokens

    # Only dimension 0 grows; every trailing dimension is carried over unchanged,
    # which makes this work for 1-D biases, 2-D embeddings and higher ranks alike.
    new_shape = torch.Size([final_vocab_size, *original_shape[1:]])

    # Start from zeros (same dtype/device as the source) so the dummy rows at
    # the tail are already 0.0 and need no separate write.
    new_output_layer = torch.zeros(new_shape, dtype=original.dtype, device=original.device)
    new_output_layer[:original_rows] = original
    print(new_output_layer.size())

    # Small constant init for the newly added (non-dummy) rows.
    new_output_layer[original_rows:new_vocab_size] = 1e-6

    state_dict[key] = new_output_layer
    print(f"FINAL state_dict[{key}].shape {state_dict[key].shape}")

    return state_dict

In [6]:
# State-dict entries whose leading dimension is expanded by the loop further
# down in this cell. Entries whose name contains "position" are grown by
# 4096-512 rows; all other entries use expand_vocab's default increase.
keys = [
    'enc_dec_model.encoder_embedding.word_embeddings.weight',
    'enc_dec_model.decoder_embedding.word_embeddings.weight',
    'enc_dec_model.tokens_head.bias',
    'enc_dec_model.encoder_embedding.position_embeddings.weight',
    'enc_dec_model.decoder_embedding.position_embeddings.weight',
]

with tempfile.TemporaryDirectory() as tmpdir:
    # Unpack the .nemo archive into a scratch folder. The weights are edited
    # there and the folder is re-packed into a new archive at the end.
    connector._unpack_nemo_file(nemo_file, tmpdir)

    # Load the model config
    config_path = os.path.join(tmpdir, connector.model_config_yaml)
    config = OmegaConf.load(config_path)

    tp_size = config.get('tensor_model_parallel_size', 1)
    pp_size = config.get('pipeline_model_parallel_size', 1)

    if tp_size > 1:
        raise RuntimeError(
            "GPT model's vocab and output embeddings cannot be expanded for tensor parallelism > 1.\n"
            "Please use the `megatron_change_num_partitions.py` script to change the number of partitions.\n"
            "Then run this script to expand the vocabulary and output embeddings.\n"
            "Finally run the `megatron_change_num_partitions.py` script again to restore the number of partitions."
        )

    appstate = AppState()
    # Placeholder ranks: both are overwritten for every checkpoint visited below.
    appstate.tensor_model_parallel_rank = 1
    appstate.pipeline_model_parallel_rank = 1
    appstate.tensor_model_parallel_size = tp_size
    appstate.pipeline_model_parallel_size = pp_size
    appstate.model_parallel_size = tp_size * pp_size

    # Load the TP 1 PP Y model checkpoint

    # Visit the first pipeline stage and, when there is more than one stage,
    # the last one as well. (The list used to be *replaced* by the last stage,
    # which silently skipped stage 0.)
    # NOTE(review): this assumes every entry of `keys` lives on the first or
    # last stage - confirm for this model; the check after the loop raises if not.
    pp_checks = [0]
    if pp_size > 1:
        pp_checks.append(pp_size - 1)
    print(pp_checks)

    # Keys that have not been found on any visited stage yet.
    pending_keys = set(keys)

    for pp in pp_checks:
        for tp in range(1):  # tp size
            appstate.tensor_model_parallel_rank = tp
            appstate.pipeline_model_parallel_rank = pp

            checkpoint_path = os.path.join(tmpdir, connector.model_weights_ckpt)
            checkpoint_path = model_utils.inject_model_parallel_rank(checkpoint_path)

            print("Parsing checkpoint at location: ", checkpoint_path)
            # torch.load unpickles the file: only run this on checkpoints you trust.
            state_dict = torch.load(checkpoint_path, map_location='cpu')

            # A pipeline stage only holds part of the model, so expand just
            # the keys that are actually present in this stage's checkpoint.
            keys_on_stage = [name for name in keys if name in state_dict]
            pending_keys.difference_update(keys_on_stage)

            for k in keys_on_stage:
                print(f"Before {k} --> {state_dict[k].shape}")
                if "position" in k:
                    state_dict = expand_vocab(k, state_dict, 4096-512)
                else:
                    state_dict = expand_vocab(k, state_dict)
                print(f"After {k} --> {state_dict[k].shape}")

            # Nothing to write back when this stage held none of the keys.
            if keys_on_stage:
                print("Saving state dict ...")
                torch.save(state_dict, checkpoint_path)

    # Refuse to pack a half-expanded model.
    if pending_keys:
        raise KeyError(
            f"Keys not found in any visited checkpoint (pipeline stages {pp_checks}): {sorted(pending_keys)}"
        )

    # Save the full nemo file
    save_filepath = out_file
    if save_filepath is None:
        save_filepath = os.path.splitext(nemo_file)[0] + '_expanded_vocab_posemb.nemo'
        print(save_filepath)

    connector._make_nemo_file_from_folder(save_filepath, tmpdir)
    print("Done")

[0]
Parsing checkpoint at location:  /tmp/tmpv3fmhfls/model_weights.ckpt
Before enc_dec_model.encoder_embedding.word_embeddings.weight --> torch.Size([250112, 768])
Original shape : torch.Size([250112, 768]) 2
New vocab size : 260224
torch.Size([260224, 768])
FINAL state_dict[enc_dec_model.encoder_embedding.word_embeddings.weight].shape torch.Size([260224, 768])
After enc_dec_model.encoder_embedding.word_embeddings.weight --> torch.Size([260224, 768])
Before enc_dec_model.decoder_embedding.word_embeddings.weight --> torch.Size([250112, 768])
Original shape : torch.Size([250112, 768]) 2
New vocab size : 260224
torch.Size([260224, 768])
FINAL state_dict[enc_dec_model.decoder_embedding.word_embeddings.weight].shape torch.Size([260224, 768])
After enc_dec_model.decoder_embedding.word_embeddings.weight --> torch.Size([260224, 768])
Before enc_dec_model.tokens_head.bias --> torch.Size([250112])
Original shape : torch.Size([250112]) 1
New vocab size : 260224
torch.Size([260224])
FINAL state_d

In [7]:
# Report the sizes of the five expanded tensors on one line, in the order:
# encoder word, encoder position, decoder word, decoder position, output bias.
print(
    *(
        state_dict[tensor_name].size()
        for tensor_name in (
            'enc_dec_model.encoder_embedding.word_embeddings.weight',
            'enc_dec_model.encoder_embedding.position_embeddings.weight',
            'enc_dec_model.decoder_embedding.word_embeddings.weight',
            'enc_dec_model.decoder_embedding.position_embeddings.weight',
            'enc_dec_model.tokens_head.bias',
        )
    )
)

torch.Size([260224, 768]) torch.Size([4096, 768]) torch.Size([260224, 768]) torch.Size([4096, 768]) torch.Size([260224])


In [8]:
# List every state-dict entry whose name mentions "position", with its shape.
for name, tensor in list(state_dict.items()):
    if "position" not in name:
        continue
    print(f"{name} --> {tensor.shape}")

enc_dec_model.encoder_embedding.position_embeddings.weight --> torch.Size([4096, 768])
enc_dec_model.decoder_embedding.position_embeddings.weight --> torch.Size([4096, 768])


In [9]:
# List every state-dict entry belonging to the encoder embedding, with its shape.
for name in filter(lambda entry: "encoder_embedding" in entry, list(state_dict.keys())):
    print(f"{name} --> {state_dict[name].shape}")

enc_dec_model.encoder_embedding.word_embeddings.weight --> torch.Size([260224, 768])
enc_dec_model.encoder_embedding.position_embeddings.weight --> torch.Size([4096, 768])
