## Change state dicts' keys for `DistributedDataParallel`

A model wrapped with `DistributedDataParallel` stores the original network in its `module` attribute. Its state dict's keys are therefore the usual (unwrapped) keys with `'module.'` prepended. The pretrained models that come with the original ESRGAN-pytorch were not saved from a `DistributedDataParallel` wrapper, so their keys lack this prefix. To load them into wrapped models, their keys are prefixed with `'module.'` ahead of time. This conversion only needs to run once, and the resulting state dicts are saved to disk for later use.
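
Renaming the keys is a pure string operation, so it can be sketched without PyTorch. The helper below, `add_module_prefix`, is a hypothetical name and is not part of PyTorch or ESRGAN-pytorch. Plain strings stand in for tensors, and the example keys are made up. Unlike a bare comprehension, it leaves keys that already start with the prefix unchanged, so a second run cannot produce `'module.module.'` keys. Recent PyTorch versions also provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present` for the reverse direction, which strips the prefix.

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")


def add_module_prefix(state_dict: Mapping[str, V], prefix: str = "module.") -> Dict[str, V]:
    """Return a new dict whose keys all start with `prefix`.

    Keys that already carry the prefix are kept as-is, so applying the
    function twice gives the same result as applying it once. The values
    themselves are not copied.
    """
    return {
        (k if k.startswith(prefix) else prefix + k): v
        for k, v in state_dict.items()
    }


# Hypothetical keys; strings stand in for the weight tensors.
plain = {"conv_first.weight": "w0", "conv_first.bias": "b0"}
wrapped = add_module_prefix(plain)
print(list(wrapped))  # ['module.conv_first.weight', 'module.conv_first.bias']
```

Because the keys are only ever prefixed and never reordered, the resulting dict keeps the original parameter order.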

```python
import torch
```

```python
# Use the first GPU if one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
```

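```python
# Names of the two pretrained generator files ('psnr' and 'gan').
models = ('psnr', 'gan')

for m in models:
    # Load the original, unwrapped generator state dict.
    # With map_location=device, the tensors land on the GPU when one is available
    # and are saved that way; pass map_location when reloading them elsewhere.
    g_sdict = torch.load(f'../parameters/{m}.pth', map_location=device)
    # Prepend 'module.' to every key to match a DistributedDataParallel-wrapped model.
    g_sdict = {f'module.{k}': v for k, v in g_sdict.items()}
    # Only the generator weights are stored; every other checkpoint field is None.
    save_dict = {
        'epoch': None,
        'unit_scheduler_step': None,
        'history': None,
        'g_state_dict': g_sdict,
        'd_state_dict': None,
        'opt_g_state_dict': None,
        'opt_d_state_dict': None,
        'amp': None,
        'args': None,
    }
    torch.save(save_dict, f'../parameters/{m}_based.pth')
```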

```
!du -hs ../parameters/*
```

```
 64M	../parameters/gan.pth
 64M	../parameters/gan_based.pth
 64M	../parameters/psnr.pth
 64M	../parameters/psnr_based.pth
```
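
The matching file sizes are consistent with only the key names changing. A stricter check compares keys and values directly. The sketch below uses a hypothetical helper, `is_prefixed_copy`, on plain dicts. With the real files, one would presumably pass the dict loaded from `{m}.pth` and the `'g_state_dict'` entry loaded from `{m}_based.pth`, with `same=torch.equal` so that tensors are compared element-wise.

```python
import operator
from typing import Any, Callable, Mapping


def is_prefixed_copy(
    original: Mapping[str, Any],
    converted: Mapping[str, Any],
    prefix: str = "module.",
    same: Callable[[Any, Any], bool] = operator.eq,
) -> bool:
    """Check that `converted` holds exactly the entries of `original`,
    each under its key with `prefix` prepended."""
    # Equal sizes plus every prefixed key present means no extra keys exist.
    if len(original) != len(converted):
        return False
    for k, v in original.items():
        new_key = prefix + k
        if new_key not in converted or not same(v, converted[new_key]):
            return False
    return True


# Floats stand in for tensors; the keys are made up for illustration.
original = {"conv_first.weight": 1.0, "conv_first.bias": 0.5}
converted = {"module." + k: v for k, v in original.items()}
print(is_prefixed_copy(original, converted))  # True
print(is_prefixed_copy(original, original))   # False: keys lack the prefix
```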
