## See whether I can convert an MDM snapshot to safetensors and load it in my model code

In [1]:
import torch

In [2]:
# Location of the MDM training checkpoint to convert (machine-specific path).
snapshot = "/home/stefanwebb/code/python/motion-diffusion-model/save/my_humanml_trans_enc_512/model000600000.pt"
# The checkpoint is used below purely as a mapping of parameter names to tensors,
# so the restricted unpickler is enough: weights_only=True refuses to execute
# arbitrary pickled code, unlike weights_only=False.
# map_location="cpu" lets the conversion run on a machine without the device the
# checkpoint was saved from.
model_snapshot = torch.load(snapshot, map_location="cpu", weights_only=True)

In [3]:
# Iterating a mapping yields its keys, i.e. the names stored in the checkpoint.
list(model_snapshot)

['input_process.poseEmbedding.weight',
 'input_process.poseEmbedding.bias',
 'sequence_pos_encoder.pe',
 'seqTransEncoder.layers.0.self_attn.in_proj_weight',
 'seqTransEncoder.layers.0.self_attn.in_proj_bias',
 'seqTransEncoder.layers.0.self_attn.out_proj.weight',
 'seqTransEncoder.layers.0.self_attn.out_proj.bias',
 'seqTransEncoder.layers.0.linear1.weight',
 'seqTransEncoder.layers.0.linear1.bias',
 'seqTransEncoder.layers.0.linear2.weight',
 'seqTransEncoder.layers.0.linear2.bias',
 'seqTransEncoder.layers.0.norm1.weight',
 'seqTransEncoder.layers.0.norm1.bias',
 'seqTransEncoder.layers.0.norm2.weight',
 'seqTransEncoder.layers.0.norm2.bias',
 'seqTransEncoder.layers.1.self_attn.in_proj_weight',
 'seqTransEncoder.layers.1.self_attn.in_proj_bias',
 'seqTransEncoder.layers.1.self_attn.out_proj.weight',
 'seqTransEncoder.layers.1.self_attn.out_proj.bias',
 'seqTransEncoder.layers.1.linear1.weight',
 'seqTransEncoder.layers.1.linear1.bias',
 'seqTransEncoder.layers.1.linear2.weight',
 '

In [4]:
# Target model whose parameter naming the snapshot keys must be translated to.
# NOTE(review): `model` here is presumably the project-local module holding the
# re-implementation; it is instantiated with its default constructor arguments.
from model import MotionDiffusionModel
model = MotionDiffusionModel()



In [5]:
# Parameter names of the target model, in the order named_parameters() yields them.
param_names = [name for name, _param in model.named_parameters()]

In [14]:
# Ordered (snapshot substring, model substring) rename rules.
# Rules are applied one after another, each on the result of the previous one.
replacements = [
    ('seqTransEncoder', 'encoder'),
    ('embed_timestep', 'timestep_encoder'),
    ('embed_text', 'text_proj'),
    ('output_process.poseFinal', 'output_proj'),
    ('sequence_pos_encoder', 'pos_encoder'),
    ('input_process.poseEmbedding', 'input_proj'),
]


def map_param_name(param_name: str) -> str:
    """Translate a snapshot key into the target model's naming scheme.

    Every rule in ``replacements`` is applied in list order; each occurrence of
    the rule's first string anywhere in the name is replaced by its second.

    Args:
        param_name: Key as it appears in the snapshot state dict.

    Returns:
        The renamed key; names that match no rule are returned unchanged.
    """
    renamed = param_name
    for snapshot_part, model_part in replacements:
        if snapshot_part in renamed:
            renamed = renamed.replace(snapshot_part, model_part)
    return renamed

# Display the target model's parameter names for eyeballing against the renamed snapshot keys.
param_names

['encoder.layers.0.self_attn.in_proj_weight',
 'encoder.layers.0.self_attn.in_proj_bias',
 'encoder.layers.0.self_attn.out_proj.weight',
 'encoder.layers.0.self_attn.out_proj.bias',
 'encoder.layers.0.linear1.weight',
 'encoder.layers.0.linear1.bias',
 'encoder.layers.0.linear2.weight',
 'encoder.layers.0.linear2.bias',
 'encoder.layers.0.norm1.weight',
 'encoder.layers.0.norm1.bias',
 'encoder.layers.0.norm2.weight',
 'encoder.layers.0.norm2.bias',
 'encoder.layers.1.self_attn.in_proj_weight',
 'encoder.layers.1.self_attn.in_proj_bias',
 'encoder.layers.1.self_attn.out_proj.weight',
 'encoder.layers.1.self_attn.out_proj.bias',
 'encoder.layers.1.linear1.weight',
 'encoder.layers.1.linear1.bias',
 'encoder.layers.1.linear2.weight',
 'encoder.layers.1.linear2.bias',
 'encoder.layers.1.norm1.weight',
 'encoder.layers.1.norm1.bias',
 'encoder.layers.1.norm2.weight',
 'encoder.layers.1.norm2.bias',
 'encoder.layers.2.self_attn.in_proj_weight',
 'encoder.layers.2.self_attn.in_proj_bias',
 '

In [7]:
# Snapshot keys translated to the target naming scheme, sorted so they can be
# compared position-by-position with the sorted model parameter names.
# Entries whose final dotted component is exactly "pe" (e.g. 'sequence_pos_encoder.pe')
# are skipped; presumably these are positional-encoding buffers rather than parameters.
# Comparing the last component, rather than a raw suffix, avoids dropping a key
# that merely happens to end in the letters "pe".
mapped_names = sorted(
    map_param_name(key)
    for key in model_snapshot
    if key.rsplit('.', 1)[-1] != 'pe'
)

In [22]:
# Display the renamed, sorted snapshot keys.
mapped_names

['encoder.layers.0.linear1.bias',
 'encoder.layers.0.linear1.weight',
 'encoder.layers.0.linear2.bias',
 'encoder.layers.0.linear2.weight',
 'encoder.layers.0.norm1.bias',
 'encoder.layers.0.norm1.weight',
 'encoder.layers.0.norm2.bias',
 'encoder.layers.0.norm2.weight',
 'encoder.layers.0.self_attn.in_proj_bias',
 'encoder.layers.0.self_attn.in_proj_weight',
 'encoder.layers.0.self_attn.out_proj.bias',
 'encoder.layers.0.self_attn.out_proj.weight',
 'encoder.layers.1.linear1.bias',
 'encoder.layers.1.linear1.weight',
 'encoder.layers.1.linear2.bias',
 'encoder.layers.1.linear2.weight',
 'encoder.layers.1.norm1.bias',
 'encoder.layers.1.norm1.weight',
 'encoder.layers.1.norm2.bias',
 'encoder.layers.1.norm2.weight',
 'encoder.layers.1.self_attn.in_proj_bias',
 'encoder.layers.1.self_attn.in_proj_weight',
 'encoder.layers.1.self_attn.out_proj.bias',
 'encoder.layers.1.self_attn.out_proj.weight',
 'encoder.layers.2.linear1.bias',
 'encoder.layers.2.linear1.weight',
 'encoder.layers.2.lin

In [8]:
# Sorted parameter names of the target model, for comparison with mapped_names.
param_names = sorted(name for name, _param in model.named_parameters())

In [11]:
# (number of model parameter names, number of renamed snapshot names)
tuple(map(len, (param_names, mapped_names)))

(106, 106)

In [None]:
# Verify that the renamed snapshot keys line up exactly with the model's parameter names.
# zip() stops at the shorter list, so check the lengths first to avoid silently
# ignoring trailing names.
assert len(param_names) == len(mapped_names), (
    f"name count differs: model has {len(param_names)}, "
    f"snapshot maps to {len(mapped_names)}"
)

# Collect the offending pairs so that a failure says which names disagree,
# instead of printing one boolean per parameter.
mismatched_names = [
    (model_name, snapshot_name)
    for model_name, snapshot_name in zip(param_names, mapped_names)
    if model_name != snapshot_name
]
assert not mismatched_names, (
    f"{len(mismatched_names)} mismatched names, e.g. {mismatched_names[:5]}"
)

In [23]:
# State dict keyed by the target model's names. Each snapshot key is renamed once
# and paired with its tensor; the entry that maps to 'pos_encoder.pe' is left out.
mapped_state_dict = {
    renamed_key: tensor
    for renamed_key, tensor in zip(
        map(map_param_name, model_snapshot.keys()),
        model_snapshot.values(),
    )
    if renamed_key != 'pos_encoder.pe'
}

In [24]:
# strict=True makes a missing or unexpected key raise immediately, so a partially
# loaded model can never reach the export step unnoticed. On success the
# returned (and displayed) result is the same as with strict=False.
model.load_state_dict(mapped_state_dict, strict=True)

<All keys matched successfully>

In [25]:
from safetensors.torch import load_file, save_file

# save_file rejects non-contiguous tensors; .contiguous() returns the tensor
# itself when it is already contiguous, so this is free in the common case.
contiguous_state_dict = {
    name: tensor.contiguous() for name, tensor in mapped_state_dict.items()
}
save_file(contiguous_state_dict, "model.safetensors")

# Round-trip check: confirm the written file loads back into the model
# with every key matching.
model.load_state_dict(load_file("model.safetensors"), strict=True)