In [18]:
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn

from collections import OrderedDict

In [2]:
import os

# Directory of the TF checkpoint to convert (per its name: segmentation VAE,
# BCE loss, zdim=64, beta=1). Set CARLA_PPO_VAE_CHECKPOINT to point elsewhere
# instead of editing this machine-specific default.
checkpoint_path = os.environ.get(
    'CARLA_PPO_VAE_CHECKPOINT',
    r'C:\Users\Hasegawa\Desktop\AIT Note\Thesis\other implementation\Carla-ppo\vae\models\seg_bce_cnn_zdim64_beta1_kl_tolerance0.0_data\checkpoints',
)

In [3]:
class ConvVAE(nn.Module):
    """PyTorch port of the Carla-ppo segmentation ConvVAE.

    Takes a (N, 3, 80, 160) image and reconstructs a (N, 1, 80, 160) map in [0, 1].
    Submodule names (``encoder.N``, ``mean``, ``logstd``, ``latent``, ``decoder.N``)
    and their registration order are the targets of the TF checkpoint conversion,
    so they must not be renamed or reordered.

    Args:
        latent_size: Size of the latent code. Defaults to 64 (the ``zdim64`` checkpoint).
    """

    def __init__(self, latent_size=64):
        super().__init__()

        image_size = (80, 160)
        encoded_channels = 256  # channels produced by the last encoder conv

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, encoded_channels, 4, stride=2),
            nn.ReLU(),
        )

        # Spatial size of the encoder output (3x8 for an 80x160 input).
        (self.encoded_H, self.encoded_W), _ = self._calculate_spatial_size(image_size, self.encoder)
        self.encoded_C = encoded_channels
        flat_size = self.encoded_H * self.encoded_W * encoded_channels

        self.mean = nn.Linear(flat_size, latent_size)
        self.logstd = nn.Linear(flat_size, latent_size)

        # latent code -> flattened feature map that the decoder reshapes and upsamples
        self.latent = nn.Linear(latent_size, flat_size)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(encoded_channels, 128, 4, stride=2),  # 3x8   -> 8x18
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2),                # 8x18  -> 18x38
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 5, stride=2),                 # 18x38 -> 39x79 (k=5 restores the odd sizes)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2),                  # 39x79 -> 80x160
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Encode images into ``(mean, logstd)`` heads, each of shape (N, latent_size)."""
        features = self.encoder(x).flatten(start_dim=1)
        return self.mean(features), self.logstd(features)

    def decode(self, z):
        """Decode latent codes of shape (N, latent_size) into (N, 1, 80, 160) maps in [0, 1]."""
        z = self.latent(z)
        z = z.view(-1, self.encoded_C, self.encoded_H, self.encoded_W)
        return self.decoder(z)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Not used by ``forward``. NOTE(review): the encoder head is called ``logstd``,
        but the TF checkpoint scope it is converted from is ``logstd_sqare``
        (log of std squared, i.e. a log-variance), which matches the 0.5 factor here.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x, encode=False, mean=False):
        """Reconstruct ``x`` from the posterior mean.

        ``encode`` and ``mean`` are accepted for API compatibility but currently
        ignored: the decoder always receives ``mu``, so the pass is deterministic.

        Returns:
            ``(reconstruction, mu, logstd)``.
        """
        mu, logstd = self.encode(x)
        reconstruction = self.decode(mu)
        return reconstruction, mu, logstd

    def _calculate_spatial_size(self, image_size, conv_layers):
        """Compute the spatial size after passing ``image_size`` through ``conv_layers``.

        Only ``nn.Conv2d`` layers change the size; every other layer is skipped.
        Uses the ``nn.Conv2d`` output-size formula with floor division, as PyTorch does.

        Returns:
            ``((H, W), size_hist)`` where ``size_hist`` holds the size before the
            first conv followed by the size after each conv.
        """
        H, W = image_size
        size_hist = [(H, W)]

        for layer in conv_layers:
            if not isinstance(layer, nn.Conv2d):
                continue
            H = (H + 2 * layer.padding[0] - layer.dilation[0] * (layer.kernel_size[0] - 1) - 1) // layer.stride[0] + 1
            W = (W + 2 * layer.padding[1] - layer.dilation[1] * (layer.kernel_size[1] - 1) - 1) // layer.stride[1] + 1
            size_hist.append((H, W))

        return (H, W), size_hist

In [4]:
# Instantiate the PyTorch VAE with freshly initialised weights; the converted TF
# checkpoint is loaded into it further down via load_state_dict.
model = ConvVAE()

In [5]:
# Smoke test on a random 80x160 RGB image: check the reconstruction and latent
# shapes and the Sigmoid output range. No gradients are needed for this check.
with torch.no_grad():
    dummy = torch.rand(1, 3, 80, 160)
    dummy_hat, mu, logstd = model(dummy)

assert dummy_hat.size() == (1, 1, 80, 160)
assert mu.size() == logstd.size() == (1, model.mean.out_features)
assert 0.0 <= dummy_hat.min().item() and dummy_hat.max().item() <= 1.0

In [6]:
# Parameter names and tensors in registration order (named_parameters yields unique names).
params_name = list(dict(model.named_parameters()).keys())
params = list(dict(model.named_parameters()).values())

In [7]:
# PyTorch parameter names: the keys the TF checkpoint variables get mapped onto.
params_name

['encoder.0.weight',
 'encoder.0.bias',
 'encoder.2.weight',
 'encoder.2.bias',
 'encoder.4.weight',
 'encoder.4.bias',
 'encoder.6.weight',
 'encoder.6.bias',
 'mean.weight',
 'mean.bias',
 'logstd.weight',
 'logstd.bias',
 'latent.weight',
 'latent.bias',
 'decoder.0.weight',
 'decoder.0.bias',
 'decoder.2.weight',
 'decoder.2.bias',
 'decoder.4.weight',
 'decoder.4.bias',
 'decoder.6.weight',
 'decoder.6.bias']

In [8]:
# Read every variable stored in the TF checkpoint, keeping names and values aligned.
init_vars = tf.train.list_variables(checkpoint_path)
names = [var_name for var_name, _var_shape in init_vars]
arrays = [tf.train.load_variable(checkpoint_path, var_name) for var_name in names]

In [20]:
# Empty container that the conversion loop fills with converted tensors.
new_state_dict = OrderedDict()
# Reference state dict: converted TF tensors must match its keys and shapes.
model_state_dict = model.state_dict()

In [22]:
def to_torch_conv_index(idx):
    """Map a 1-based TF layer number (the X in ``convX``) to its ``nn.Sequential`` index.

    Every conv in ``ConvVAE.encoder``/``decoder`` is followed by an activation,
    so convs sit at even positions: 1, 2, 3, 4 -> '0', '2', '4', '6'.
    Returns a string so it can be joined straight into a state-dict key.
    """
    return str((int(idx) - 1) * 2)


def to_pretained_conv_index(idx):
    """Inverse of ``to_torch_conv_index``: Sequential index -> 1-based TF layer number.

    Floor division keeps the result integral ('1', not '1.0' as true division gave).
    Accepts an int or a numeric string such as the output of ``to_torch_conv_index``.
    """
    return str(int(idx) // 2 + 1)


def _layer_number(scope):
    """Return the trailing integer of a TF layer scope, e.g. 'conv3' -> 3, 'conv12' -> 12."""
    digits = scope[len(scope.rstrip('0123456789')):]
    if not digits:
        raise ValueError(f'expected a numbered layer scope, got: {scope}')
    return int(digits)


def _tf_kernel_to_torch(kernel):
    """Permute a TF ``kernel`` into the layout of the matching PyTorch ``weight``.

    - Dense ``(in, out)`` -> ``nn.Linear`` ``(out, in)``.
    - Conv2D ``(kH, kW, in, out)`` -> ``nn.Conv2d`` ``(out, in, kH, kW)``.
    - Conv2DTranspose ``(kH, kW, out, in)`` -> ``nn.ConvTranspose2d`` ``(in, out, kH, kW)``.

    Both 4-D cases need axes (3, 2, 0, 1). A bare ``.transpose()`` reverses them to
    (3, 2, 1, 0), which swaps kH and kW: invisible to the shape check for square
    kernels, but it spatially transposes every filter.
    """
    if kernel.ndim == 4:
        return kernel.transpose(3, 2, 0, 1)
    return kernel.transpose()


def _nhwc_flat_to_nchw_flat(array, axis, height, width, channels):
    """Reorder one flattened feature axis from (h, w, c) order to (c, h, w) order.

    TF flattens/reshapes channels-last feature maps, while ``ConvVAE`` uses
    ``flatten``/``view`` on channels-first tensors, so dense layers touching the
    flattened encoder features index them differently in the two frameworks.
    """
    moved = np.moveaxis(array, axis, -1)
    lead = moved.shape[:-1]
    grid = moved.reshape(*lead, height, width, channels)
    grid = np.moveaxis(grid, -1, -3)  # (..., C, H, W)
    return np.moveaxis(grid.reshape(*lead, channels * height * width), -1, axis)


# Optimizer slots and training counters have no PyTorch counterpart.
SKIPPED_TF_SCOPES = {'Adam', 'Adam_1', 'step_idx', 'beta1_power', 'beta2_power'}

# Dense params with an axis over the flattened encoder features, and which axis
# that is in the PyTorch layout. NOTE(review): assumes the TF graph kept the
# default channels_last layout before flatten/reshape -- confirm by comparing a
# forward pass of both models on the same input.
FLATTENED_FEATURE_AXIS = {'mean.weight': 1, 'logstd.weight': 1, 'latent.weight': 0, 'latent.bias': 0}

encoded_H, encoded_W = model.encoded_H, model.encoded_W
encoded_C = model_state_dict['latent.bias'].numel() // (encoded_H * encoded_W)

for name, array in zip(names, arrays):
    key_parts = name.split('/')
    if any(part in SKIPPED_TF_SCOPES for part in key_parts):
        continue

    # Drop the outermost TF scope; the remainder mirrors the PyTorch module path.
    key_parts = key_parts[1:]
    scope = key_parts[0]

    if scope == 'encoder':
        # key_parts[1] is a numbered layer scope such as "conv3"
        key_parts[1] = to_torch_conv_index(_layer_number(key_parts[1]))
    elif scope == 'decoder':
        if key_parts[1] == 'dense1':
            key_parts = ['latent', key_parts[-1]]
        else:
            key_parts[1] = to_torch_conv_index(_layer_number(key_parts[1]))
    elif scope == 'mean':
        pass
    elif scope == 'logstd_sqare':
        key_parts[0] = 'logstd'
    else:
        raise ValueError(f'not support key: {name}')

    if key_parts[-1] == 'kernel':
        key_parts[-1] = 'weight'
        array = _tf_kernel_to_torch(array)

    new_key = '.'.join(key_parts)
    if new_key in FLATTENED_FEATURE_AXIS:
        array = _nhwc_flat_to_nchw_flat(
            array, FLATTENED_FEATURE_AXIS[new_key], encoded_H, encoded_W, encoded_C
        )

    if new_key not in model_state_dict:
        raise KeyError(f'TF variable {name} maps to unknown PyTorch key {new_key}')
    current_param = model_state_dict[new_key]
    if tuple(current_param.shape) != array.shape:
        raise ValueError(f'key {new_key} has mismatch weight shape: {current_param.size()} and {array.shape}')

    # Contiguous copy so the saved tensors are compact and independent of the TF arrays.
    new_state_dict[new_key] = torch.from_numpy(np.ascontiguousarray(array))

missing_keys = [key for key in model_state_dict if key not in new_state_dict]
if missing_keys:
    raise KeyError(f'checkpoint provided no value for: {missing_keys}')

In [23]:
# Persist the converted weights. The file is a torch.save archive (read back with
# torch.load), so only load it from trusted locations.
# NOTE(review): consider saving only after model.load_state_dict has succeeded.
torch.save(obj=new_state_dict, f='carla-ppo-seg-vae.pkl')

In [11]:
import numpy as np

In [12]:
# Sequential indices -> 1-based TF layer numbers; floor division keeps them integers.
np.array([0, 2, 4, 6]) // 2 + 1

array([1., 2., 3., 4.])

In [13]:
# 1-based TF layer numbers -> nn.Sequential indices (convs sit at even positions).
2 * (np.array([1, 2, 3, 4]) - 1)

array([0, 2, 4, 6])

In [24]:
# Strict load: raises if any key is missing or unexpected, or if a shape differs.
model.load_state_dict(state_dict=new_state_dict, strict=True)

<All keys matched successfully>