In [1]:
import os
import cv2
import math
import time
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import cm
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_url
from torch.utils.data import random_split, DataLoader, Dataset
from torchsummary import summary

In [2]:
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [3]:
random_seed = 42
torch.manual_seed(random_seed)

<torch._C.Generator at 0x1f01d58eaf0>

In [4]:
torch.set_printoptions(edgeitems=5)

In [5]:
def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')

def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True) #, dtype=torch.float

class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

In [7]:
device = get_default_device()
print(device)

cuda


In [28]:
ckpt1 = torch.load('netG_50.pth')
ckpt2 = torch.load('netG_1200.pth')

In [16]:
print(len(ckpt1.keys()))
print(len(list(ckpt1.keys())))

for key in list(ckpt1.keys()):
    print(key)

384
384
all_modules.0.weight
all_modules.0.bias
all_modules.1.weight
all_modules.1.bias
all_modules.2.weight
all_modules.2.bias
all_modules.3.GroupNorm_0.style.weight
all_modules.3.GroupNorm_0.style.bias
all_modules.3.Conv_0.weight
all_modules.3.Conv_0.bias
all_modules.3.Dense_0.weight
all_modules.3.Dense_0.bias
all_modules.3.GroupNorm_1.style.weight
all_modules.3.GroupNorm_1.style.bias
all_modules.3.Conv_1.weight
all_modules.3.Conv_1.bias
all_modules.4.GroupNorm_0.style.weight
all_modules.4.GroupNorm_0.style.bias
all_modules.4.Conv_0.weight
all_modules.4.Conv_0.bias
all_modules.4.Dense_0.weight
all_modules.4.Dense_0.bias
all_modules.4.GroupNorm_1.style.weight
all_modules.4.GroupNorm_1.style.bias
all_modules.4.Conv_1.weight
all_modules.4.Conv_1.bias
all_modules.5.GroupNorm_0.style.weight
all_modules.5.GroupNorm_0.style.bias
all_modules.5.Conv_0.weight
all_modules.5.Conv_0.bias
all_modules.5.Dense_0.weight
all_modules.5.Dense_0.bias
all_modules.5.GroupNorm_1.style.weight
all_modules.5.G

In [17]:
print(len(ckpt2.keys()))
print(len(list(ckpt2.keys())))

for key in list(ckpt2.keys()):
    print(key)

384
384
module.all_modules.0.weight
module.all_modules.0.bias
module.all_modules.1.weight
module.all_modules.1.bias
module.all_modules.2.weight
module.all_modules.2.bias
module.all_modules.3.GroupNorm_0.style.weight
module.all_modules.3.GroupNorm_0.style.bias
module.all_modules.3.Conv_0.weight
module.all_modules.3.Conv_0.bias
module.all_modules.3.Dense_0.weight
module.all_modules.3.Dense_0.bias
module.all_modules.3.GroupNorm_1.style.weight
module.all_modules.3.GroupNorm_1.style.bias
module.all_modules.3.Conv_1.weight
module.all_modules.3.Conv_1.bias
module.all_modules.4.GroupNorm_0.style.weight
module.all_modules.4.GroupNorm_0.style.bias
module.all_modules.4.Conv_0.weight
module.all_modules.4.Conv_0.bias
module.all_modules.4.Dense_0.weight
module.all_modules.4.Dense_0.bias
module.all_modules.4.GroupNorm_1.style.weight
module.all_modules.4.GroupNorm_1.style.bias
module.all_modules.4.Conv_1.weight
module.all_modules.4.Conv_1.bias
module.all_modules.5.GroupNorm_0.style.weight
module.all_m

In [20]:
for key in list(ckpt2.keys()):
    ckpt2[key[7:]] = ckpt2.pop(key)

In [25]:
print(len(ckpt2.keys()))
print(len(list(ckpt2.keys())))

for key in list(ckpt2.keys()):
    print(key)

384
384
all_modules.0.weight
all_modules.0.bias
all_modules.1.weight
all_modules.1.bias
all_modules.2.weight
all_modules.2.bias
all_modules.3.GroupNorm_0.style.weight
all_modules.3.GroupNorm_0.style.bias
all_modules.3.Conv_0.weight
all_modules.3.Conv_0.bias
all_modules.3.Dense_0.weight
all_modules.3.Dense_0.bias
all_modules.3.GroupNorm_1.style.weight
all_modules.3.GroupNorm_1.style.bias
all_modules.3.Conv_1.weight
all_modules.3.Conv_1.bias
all_modules.4.GroupNorm_0.style.weight
all_modules.4.GroupNorm_0.style.bias
all_modules.4.Conv_0.weight
all_modules.4.Conv_0.bias
all_modules.4.Dense_0.weight
all_modules.4.Dense_0.bias
all_modules.4.GroupNorm_1.style.weight
all_modules.4.GroupNorm_1.style.bias
all_modules.4.Conv_1.weight
all_modules.4.Conv_1.bias
all_modules.5.GroupNorm_0.style.weight
all_modules.5.GroupNorm_0.style.bias
all_modules.5.Conv_0.weight
all_modules.5.Conv_0.bias
all_modules.5.Dense_0.weight
all_modules.5.Dense_0.bias
all_modules.5.GroupNorm_1.style.weight
all_modules.5.G