In [None]:
# Target size of the text-symbol vocabulary after extension.
# The original model uses 178 symbols; set this to the new total.
# ⚠️ This must equal the total number of symbols defined in meldataset.py!
extend_to = None  # <= CHANGE THIS (an int)

In [None]:
# load packages
%cd ..
import yaml
import torch
from torch import nn
import os

from models import *
from utils import *

def represent_list_flow(dumper, data):
    """Represent a Python list as a YAML flow sequence (``[a, b, c]``).

    Registered on ``yaml.SafeDumper`` below, so every subsequent
    ``yaml.safe_dump`` call in this kernel writes lists inline.
    """
    seq_tag = yaml.resolver.BaseResolver.DEFAULT_SEQUENCE_TAG
    return dumper.represent_sequence(seq_tag, data, flow_style=True)


yaml.SafeDumper.add_representer(list, represent_list_flow)

# Load the training config and instantiate the model skeleton.
# Errors are deliberately not caught: every later cell needs `config` and
# `model`, and a swallowed failure would leave them undefined (or stale from
# an earlier run of this notebook).
with open("./Configs/config.yaml", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)
model_params = recursive_munch(config['model_params'])
model = build_model(model_params)
device = 'cpu'

In [None]:
# Load the fine-tuned checkpoint and keep only the sub-modules listed in
# `keys_to_keep`; every other sub-module is removed from `model` as well.
from collections import OrderedDict


def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with ``prefix`` removed from keys that start with it.

    Keys without the prefix are kept unchanged; slicing every key blindly would
    corrupt them and make a non-strict load silently skip those tensors.
    """
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


keys_to_keep = {'predictor', 'decoder', 'text_encoder', 'style_encoder', 'text_aligner', 'pitch_extractor', 'mpd', 'msd'}
params_whole = torch.load("Models/Finetune_Remove/og_finetune.pth", map_location='cpu')
params = {key: value for key, value in params_whole['net'].items() if key in keys_to_keep}

for key in [k for k in model.keys() if k not in keys_to_keep]:
    del model[key]

for key in model:
    if key not in params:
        continue
    try:
        model[key].load_state_dict(params[key])
    except RuntimeError:
        # Strict load failed (e.g. keys saved with a `module.` prefix). Retry
        # without the prefix, non-strictly, and report whatever still does not
        # line up instead of hiding it.
        result = model[key].load_state_dict(_strip_module_prefix(params[key]), strict=False)
        if result.missing_keys or result.unexpected_keys:
            print('%s: %d missing keys, %d unexpected keys'
                  % (key, len(result.missing_keys), len(result.unexpected_keys)))
    print('%s loaded' % key)

def _extend_vocab_rows(module, new_size):
    """Grow an ``nn.Embedding`` or ``nn.Linear`` in place to ``new_size`` rows.

    Rows are dim 0 of ``module.weight`` (``num_embeddings`` for an embedding,
    ``out_features`` for a linear layer). Existing rows, and for ``nn.Linear``
    the existing bias entries, are copied unchanged; new rows are drawn from
    N(0, 0.01^2) and new bias entries start at zero.

    Raises:
        TypeError: ``new_size`` is not an int (e.g. ``extend_to`` was left as
            None), or ``module`` is neither ``nn.Embedding`` nor ``nn.Linear``.
        ValueError: ``new_size`` is smaller than the current number of rows.
    """
    if not isinstance(new_size, int):
        raise TypeError('extend_to must be an int (set it in the first cell), got %r' % (new_size,))
    if not isinstance(module, (nn.Embedding, nn.Linear)):
        raise TypeError('cannot extend %s; expected nn.Embedding or nn.Linear' % type(module).__name__)

    old_w = module.weight
    old_rows = old_w.shape[0]
    if new_size < old_rows:
        raise ValueError('extend_to (%d) is smaller than the current size (%d) of %s'
                         % (new_size, old_rows, module))

    with torch.no_grad():
        # Match dtype/device of the existing weight so the row copy is valid.
        new_w = torch.randn(new_size, old_w.shape[1], dtype=old_w.dtype, device=old_w.device) * 0.01
        new_w[:old_rows] = old_w.detach()
        module.weight = nn.Parameter(new_w, requires_grad=True)

        if isinstance(module, nn.Embedding):
            module.num_embeddings = new_size
        else:
            module.out_features = new_size
            # Linear layers may be built with bias=False.
            if module.bias is not None:
                old_b = module.bias
                new_b = torch.zeros(new_size, dtype=old_b.dtype, device=old_b.device)
                new_b[:old_b.shape[0]] = old_b.detach()
                module.bias = nn.Parameter(new_b, requires_grad=old_b.requires_grad)


# Layers resized to `extend_to` rows below.
old_weight = [
    model['text_encoder'].embedding,
    model['text_aligner'].ctc_linear[2].linear_layer,
    model['text_aligner'].asr_s2s.embedding,
    model['text_aligner'].asr_s2s.project_to_n_symbols,
]

print("\nOld shape:")
for module in old_weight:
    print(module, module.weight.shape)

for module in old_weight:
    _extend_vocab_rows(module, extend_to)

print("\nNew shape:")
for module in old_weight:
    print(module, module.weight.shape)

In [None]:
save_path = "./Extend/New_Weights"

# makedirs also creates a missing `./Extend` parent and is a no-op when the
# directory already exists, so this cell is safe to re-run.
os.makedirs(save_path, exist_ok=True)

# Save the extended sub-module weights together with the original training
# metadata.
# NOTE(review): 'optimizer' is copied verbatim from the original checkpoint, so
# any per-parameter state it holds still has the pre-extension vocabulary size;
# confirm the training script does not restore it for the resized layers.
state = {
    'net': {key: model[key].state_dict() for key in model},
    'optimizer': params_whole['optimizer'],
    'iters': params_whole['iters'],
    'val_loss': params_whole['val_loss'],
    'epoch': params_whole['epoch'],
}
torch.save(state, os.path.join(save_path, 'extended.pth'))

# Persist the config next to the new weights, with both vocabulary-size
# entries set to extend_to.
model_cfg = config['model_params']
model_cfg['ASR_params']['n_token'] = extend_to
model_cfg['n_token'] = extend_to

config_out_path = os.path.join(save_path, 'config.yaml')
with open(config_out_path, 'w') as outfile:
    yaml.safe_dump(
        config,
        outfile,
        default_flow_style=False,
        sort_keys=False,
        indent=4,
    )