In [None]:
import copy
import warnings

import torch
from transformers import AutoModelForMaskedLM

def _load_checkpoint_into(model, path, base_state, map_location):
    """Reset ``model`` to ``base_state``, then overlay the checkpoint at ``path``.

    Keys the checkpoint lacks keep their base values and keys the model lacks
    are ignored (``strict=False``). Either case emits a warning instead of
    passing silently.

    Returns:
        The raw object returned by ``torch.load`` (expected to be a flat
        state dict for ``model``).
    """
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    # Reset first so missing keys fall back to the base weights rather than
    # to whatever an earlier checkpoint left behind in this model.
    model.load_state_dict(base_state)
    result = model.load_state_dict(checkpoint, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"{path}: {len(result.missing_keys)} missing key(s) kept base "
            f"values, {len(result.unexpected_keys)} unexpected key(s) ignored "
            f"(missing e.g. {result.missing_keys[:3]}, "
            f"unexpected e.g. {result.unexpected_keys[:3]})",
            stacklevel=3,
        )
    return checkpoint


def _print_preview(header, state, key):
    """Print ``header`` and the top-left 5x5 slice of ``state[key]``, if present."""
    if key is None or key not in state:
        return
    print(header)
    # Assumes the previewed tensor is at least 2-D (true for embedding weights).
    print(state[key][:5, :5])


def load_and_average_checkpoints(
    checkpoint_paths,
    *,
    base_model_name="facebook/esm2_t12_35M_UR50D",
    preview_key="esm.embeddings.word_embeddings.weight",
    map_location="cpu",
    model_factory=None,
):
    """Build a uniform "model soup" by averaging the parameters of checkpoints.

    Each checkpoint is overlaid on a copy of the base model with
    ``strict=False``, so keys absent from a checkpoint keep their base values.
    The parameters of the resulting models are then averaged element-wise.
    Non-parameter state (buffers) is taken from the first checkpoint's model.

    Args:
        checkpoint_paths: Iterable of paths accepted by ``torch.load``; each
            file is expected to hold a state dict for the base model.
        base_model_name: Model id passed to
            ``AutoModelForMaskedLM.from_pretrained`` when ``model_factory``
            is not given.
        preview_key: State-dict key whose top-left 5x5 slice is printed for
            each checkpoint and for the result; ``None`` disables printing.
        map_location: Forwarded to ``torch.load``; the default ``"cpu"`` lets
            GPU-saved checkpoints load on CPU-only machines.
        model_factory: Optional zero-argument callable returning the base
            model; overrides ``base_model_name``. It is called exactly once.

    Returns:
        The model holding the averaged parameters.

    Raises:
        ValueError: If ``checkpoint_paths`` is empty.
    """
    checkpoint_paths = list(checkpoint_paths)
    if not checkpoint_paths:
        raise ValueError("checkpoint_paths must contain at least one path")

    if model_factory is None:
        def model_factory():
            return AutoModelForMaskedLM.from_pretrained(base_model_name)

    # Load the base model once (not once per checkpoint) and snapshot its
    # weights so every checkpoint is overlaid on the same pristine base.
    soup_model = model_factory()
    base_state = {
        key: value.detach().clone()
        for key, value in soup_model.state_dict().items()
    }
    # Checkpoints 2..N go into a scratch copy, so soup_model keeps the first
    # checkpoint's buffers exactly as the previous implementation did.
    scratch_model = (
        copy.deepcopy(soup_model) if len(checkpoint_paths) > 1 else None
    )
    param_names = [name for name, _ in soup_model.named_parameters()]

    param_sums = {}
    with torch.no_grad():
        for index, path in enumerate(checkpoint_paths):
            model = soup_model if index == 0 else scratch_model
            checkpoint = _load_checkpoint_into(
                model, path, base_state, map_location
            )
            _print_preview(f"\nWeights from {path}:", checkpoint, preview_key)

            # Build the state dict once per checkpoint, not once per parameter.
            state = model.state_dict()
            for name in param_names:
                value = state[name]
                if index == 0:
                    # Accumulate in at least float32 so half-precision
                    # weights do not lose precision while summing.
                    acc_dtype = torch.promote_types(value.dtype, torch.float32)
                    param_sums[name] = value.to(acc_dtype, copy=True)
                else:
                    param_sums[name] += value

        count = len(checkpoint_paths)
        for name, param in soup_model.named_parameters():
            param.copy_(param_sums[name] / count)

    _print_preview("\nAveraged weights:", soup_model.state_dict(), preview_key)
    return soup_model

# Checkpoints sft_updated_esm2_version_0.pt .. _2.pt, all in the same folder.
paths = [
    rf"C:\Users\Sri Seshadri\Desktop\Sri RLXF\trained_models\SFT_ESM2_35M\sft_updated_esm2_version_{version}.pt"
    for version in range(3)
]

soup_model = load_and_average_checkpoints(paths)


Weights from C:\Users\Sri Seshadri\Desktop\Sri RLXF\trained_models\SFT_ESM2_35M\sft_updated_esm2_version_0.pt:
tensor([[-3.6652e-03, -1.9487e-01,  8.1242e-03, -4.9705e-03,  1.4702e-02],
        [-1.1057e-01, -5.2987e-01,  1.2668e-01, -1.0137e-02,  3.4593e-01],
        [ 3.2176e-02, -2.0748e-01, -4.8242e-03, -1.4734e-02,  1.5920e-02],
        [-1.1008e-01, -5.3231e-01,  1.2728e-01, -3.9668e-03,  3.4007e-01],
        [ 1.7804e-01,  5.5836e-02, -3.5763e-04,  7.5548e-03, -7.6871e-03]])

Weights from C:\Users\Sri Seshadri\Desktop\Sri RLXF\trained_models\SFT_ESM2_35M\sft_updated_esm2_version_1.pt:
tensor([[-0.0124, -0.1433, -0.0062,  0.0175,  0.0135],
        [-0.0606, -0.4748,  0.0759, -0.0090,  0.2915],
        [-0.0086, -0.1552, -0.0038,  0.0099, -0.0093],
        [-0.0601, -0.4772,  0.0764, -0.0149,  0.2857],
        [ 0.1374,  0.0363,  0.0121, -0.0313, -0.0030]])

Weights from C:\Users\Sri Seshadri\Desktop\Sri RLXF\trained_models\SFT_ESM2_35M\sft_updated_esm2_version_2.pt:
tensor([[-0.