In [1]:
import copy

import torch
# Prefer the first CUDA GPU when one is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [2]:
def add_networks(dst_net, src_net):
    """Accumulate the parameters of ``src_net`` into ``dst_net`` in place.

    Each parameter of ``src_net`` is added element-wise to the parameter of
    ``dst_net`` that has the same name. Source parameters without a same-named
    counterpart in ``dst_net`` are silently skipped. Only parameters are
    summed; buffers are left untouched.

    Args:
        dst_net: Module whose parameters receive the sum (modified in place).
        src_net: Module whose parameters are added; it is not modified.

    Returns:
        ``dst_net``, to allow call chaining.
    """
    dst_params = dict(dst_net.named_parameters())
    # no_grad lets us update leaf parameters in place without recording the
    # operation in autograd, regardless of their requires_grad flag.
    with torch.no_grad():
        for name, src_param in src_net.named_parameters():
            dst_param = dst_params.get(name)
            if dst_param is None:
                continue
            # Move the source onto the destination's device so that models
            # living on different devices can still be combined.
            dst_param.add_(src_param.to(dst_param.device))
    return dst_net

def apply_denominator(dst_net, denominator):
    """Divide every parameter of ``dst_net`` by ``denominator`` in place.

    Intended as the final step of parameter averaging: after the parameters of
    several networks have been summed into ``dst_net``, dividing by the number
    of networks yields their mean. Only parameters are scaled; buffers are
    left untouched.

    Args:
        dst_net: Module whose parameters are scaled (modified in place).
        denominator: Value to divide by.

    Returns:
        ``dst_net``, to allow call chaining.

    Raises:
        ZeroDivisionError: If ``denominator`` is a Python number equal to zero.
    """
    # Fail loudly instead of filling the network with inf/nan values.
    if isinstance(denominator, (int, float)) and denominator == 0:
        raise ZeroDivisionError("denominator must be non-zero")
    # no_grad lets us update leaf parameters in place without recording the
    # operation in autograd, regardless of their requires_grad flag.
    with torch.no_grad():
        for param in dst_net.parameters():
            param.div_(denominator)
    return dst_net

In [3]:
# Directory prefix of the checkpoints. The trailing slash matters because the
# file names are appended to it by plain string concatenation.
aug_model_path = 'checkpoints/'

# File-name suffixes identifying the checkpoints whose weights are averaged.
aug_models = [
    "_media",
    "_sampo",
    "_nothing2",
    "s1",
    "s2",
    "s3",
]
num_models = len(aug_models)

In [4]:
# Load every text->image / image->text checkpoint pair as a frozen, eval-mode
# module on `device`.
#
# NOTE(security): the loaded objects are used directly as modules, i.e. the
# files are fully pickled models. Unpickling can execute arbitrary code, so
# only load checkpoints from a trusted source. On PyTorch >= 2.6 torch.load
# defaults to weights_only=True and rejects such files unless
# weights_only=False is passed explicitly.
text_to_images = []
image_to_texts = []
for am in aug_models:
    # map_location remaps the stored tensors onto `device`, so checkpoints
    # saved on a GPU still load when running on the CPU fallback.
    text_to_image = torch.load(aug_model_path + "t2i" + am + ".pt", map_location=device)
    text_to_image.requires_grad_(False).eval().to(device)
    image_to_text = torch.load(aug_model_path + "i2t" + am + ".pt", map_location=device)
    image_to_text.requires_grad_(False).eval().to(device)
    text_to_images.append(text_to_image)
    image_to_texts.append(image_to_text)

In [5]:
# Average the parameters of all loaded models and save one averaged
# text->image model and one averaged image->text model.
if num_models == 0:
    raise ValueError("no models were loaded; there is nothing to average")

# Accumulate into deep copies of the first models. add_networks and
# apply_denominator modify their destination in place, so accumulating directly
# into text_to_images[0] / image_to_texts[0] would corrupt the loaded models
# and make this cell produce different (wrong) averages when re-run.
avg_t2i = copy.deepcopy(text_to_images[0])
avg_i2t = copy.deepcopy(image_to_texts[0])
for i in range(1, num_models):
    avg_t2i = add_networks(avg_t2i, text_to_images[i])
    avg_i2t = add_networks(avg_i2t, image_to_texts[i])

# Turn the accumulated sums into means.
apply_denominator(avg_t2i, num_models)
apply_denominator(avg_i2t, num_models)

torch.save(avg_t2i, 't2i_avg1.pt')
torch.save(avg_i2t, 'i2t_avg1.pt')