In [None]:
# --- User settings -----------------------------------------------------------
# Target size of the symbol table after extension. It has to be set to a number
# larger than the symbol count found in the config file, otherwise the script
# stops with an error before touching any weights.
extend_to = None #<= CHANGE THIS. The original is 178 symbols

# Relative paths; they are resolved after the `%cd ..` executed in the next cell.
# save_path   : directory that receives the resulting `extended.pth`
# config_path : YAML file whose `symbol` and `model_params` sections are read
# model_path  : checkpoint whose `net` entry supplies the existing weights
save_path = "./Extend/New_Weights"
config_path = "./Configs/config.yaml"
model_path = "./Models/Finetune/base_model.pth"

#⚠️ Run this notebook first, BEFORE adding any new symbol to the config file

In [None]:
# load packages
%cd ..
import yaml
import torch
from torch import nn
import os
from models import *
from utils import *
import warnings
# Suppress every Python warning from this point on.
warnings.filterwarnings("ignore")

# NOTE(review): `device` is not referenced by the code visible in this notebook;
# the checkpoint is loaded with an explicit map_location='cpu' further down.
device = 'cpu'

# Read the YAML config. The context manager closes the file handle as soon as
# parsing is done instead of leaving it open until garbage collection.
with open(config_path, "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# Collect every symbol group in the order: pad, punctuation, letters,
# letters_ipa, extend. A missing group (KeyError) or an empty/None section
# (TypeError) means the config file does not have the expected layout.
try:
    symbol_config = config['symbol']
    symbols = (
        list(symbol_config['pad']) +
        list(symbol_config['punctuation']) +
        list(symbol_config['letters']) +
        list(symbol_config['letters_ipa']) +
        list(symbol_config['extend'])
    )
except (KeyError, TypeError) as e:
    print(f"\nERROR: Cannot find {e} in config file!\nYour config file is likely outdated, please download updated version from the repository.")
    raise SystemExit(1)

# Map each symbol to its position in the list; if a symbol occurs more than
# once, the last position wins.
symbol_dict = {symbol: index for index, symbol in enumerate(symbols)}

# The token count is one more than the number of distinct symbols.
n_token = len(symbol_dict) + 1
print("\nFound", n_token, "symbols in the original config file")

# `extend_to` ships as None and has to be filled in by the user. Check that
# explicitly so a forgotten setting yields a readable message instead of a
# TypeError from the comparison below.
if not isinstance(extend_to, int):
    print(f"\nERROR: 'extend_to' must be set to an integer greater than {n_token}, got {extend_to!r}.")
    raise SystemExit(1)

# The new table must be strictly larger than the current one.
if extend_to <= n_token:
    print(f"\nERROR: Cannot extend from {n_token} to {extend_to}.")
    raise SystemExit(1)

# Build the network using the token count currently described by the config.
model_params = recursive_munch(config['model_params'])
model_params['n_token'] = n_token
model = build_model(model_params)

# Names of the sub-modules that are carried over into the extended checkpoint.
keys_to_keep = {'predictor', 'decoder', 'text_encoder', 'style_encoder', 'text_aligner', 'pitch_extractor', 'mpd', 'msd'}

# Read the base checkpoint onto the CPU and narrow its 'net' entry down to the
# sub-modules listed above.
params_whole = torch.load(model_path, map_location='cpu')
params = {
    name: weights
    for name, weights in params_whole['net'].items()
    if name in keys_to_keep
}

# Drop every freshly built sub-module that is not in the keep-set. The names
# are collected first so the container is not modified while being iterated.
for name in [name for name in model.keys() if name not in keys_to_keep]:
    del model[name]

from collections import OrderedDict

# Copy the checkpoint weights into every retained sub-module.
for key in model:
    if key not in params:
        continue
    try:
        model[key].load_state_dict(params[key])
    except RuntimeError:
        # Strict loading rejected the key names. Retry with the `module.`
        # prefix removed (presumably left over from a wrapped/parallel
        # training run). Only keys that really start with the prefix are
        # shortened; all others are passed through unchanged.
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in params[key].items()
        )
        # load params
        model[key].load_state_dict(new_state_dict, strict=False)
    # Report only after the weights have actually been loaded.
    print('%s loaded' % key)

# Gather the four layers that are resized by the extension step below: one from
# the text encoder and three from the text aligner.
text_aligner = model['text_aligner']
old_weight = [
    model['text_encoder'].embedding,
    text_aligner.ctc_linear[2].linear_layer,
    text_aligner.asr_s2s.embedding,
    text_aligner.asr_s2s.project_to_n_symbols,
]

# Show each layer together with its weight shape before the resize.
print("\nOld shape:")
for module in old_weight:
    print(module, module.weight.shape)

# Grow the first dimension of every collected layer to `extend_to` rows.
# Existing rows keep their trained values; the added rows are freshly initialised.
for layer in old_weight:
    trained_weight = layer.weight.detach()
    trained_rows = trained_weight.size(0)

    # New rows: mean=0, std=0.01, created with the dtype and device of the
    # existing weight so the resized layer stays consistent with the rest
    # of the model.
    new_weight = torch.randn(
        (extend_to, trained_weight.shape[1]),
        dtype=trained_weight.dtype,
        device=trained_weight.device,
    ) * 0.01
    new_weight[:trained_rows, :] = trained_weight

    if isinstance(layer, nn.Embedding):
        layer.num_embeddings = extend_to

    if isinstance(layer, nn.Linear):
        layer.out_features = extend_to
        # Update the bias as well; layers created with bias=False have none.
        if layer.bias is not None:
            trained_bias = layer.bias.detach()
            new_bias = torch.zeros(
                extend_to,
                dtype=trained_bias.dtype,
                device=trained_bias.device,
            )
            # Added output units start with a zero bias.
            new_bias[:trained_bias.shape[0]] = trained_bias
            layer.bias.data = new_bias

    # Swap in the enlarged weight as a trainable parameter.
    layer.weight = nn.Parameter(new_weight, requires_grad=True)

# Show each resized layer together with its new weight shape.
print("\nNew shape:")
for module in old_weight:
    print(module, module.weight.shape)

# makedirs also creates missing parent directories (e.g. "./Extend") and does
# not fail when the target directory already exists.
os.makedirs(save_path, exist_ok=True)

#save new weights
# Only the network weights are stored; optimizer state and the training
# counters are written as empty/zero values.
state = {
    'net': {key: model[key].state_dict() for key in model},
    'optimizer': None,
    'iters': 0,
    'val_loss': 0,
    'epoch': 0,
}
torch.save(state, os.path.join(save_path, 'extended.pth'))

# Announce success only once the checkpoint has been written to disk.
print(f"\n\n✅ Successfully extended the token set to a maximum of {extend_to} symbols.")
print(f"You can now add {extend_to - n_token} additional symbols in the config file.")