This used latent space of 1024 and hidden of only 512. But it seems to really struggle because of the smaller hidden space.
    

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.nn import functional as F
from torch.autograd import Variable
from torch import nn, optim
import torch.utils.data

# load as dask array
import dask.array as da
import dask
import h5py

import os
import glob
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm

  from ._conv import register_converters as _register_converters


In [3]:
from world_models_sonic.models.vae import VAE6, loss_function_vae
from world_models_sonic.helpers.summarize import TorchSummarizeDf
from world_models_sonic.helpers.dataset import load_cache_data
from world_models_sonic.models.rnn import MDNRNN2
from world_models_sonic import config

# Init

In [4]:
cuda= torch.cuda.is_available()
env_name='sonic256'
num_epochs=200
batch_size = 2

# VAE loss function
lambda_vae_kld = 0.25
C = 0
z_dim = 256 # latent dimensions

# RNN
action_dim = 12
seq_len = 5
image_size=256
chunksize=seq_len*20



# loss function weights
lambda_vae = 1/100
lambda_finv = 1

data_cache_file = os.path.join(config.base_vae_data_dir, 'sonic_rnn_256_v30.hdf5')
NAME='RNN_v3b_256im_512z_1512_v2_greenfield'

# Load Data

In [5]:
loader_train, loader_test = load_cache_data(
    basedir=config.base_vae_data_dir, 
    env_name=env_name, 
    data_cache_file=data_cache_file, 
    image_size=image_size, 
    chunksize=chunksize, 
    action_dim=action_dim,
    batch_size=batch_size,
    seq_len=seq_len
)
loader_train, loader_test

Loaded from cache /MLDATA/sonic/vae/sonic_rnn_256_v30.hdf5


(<torch.utils.data.dataloader.DataLoader at 0x7f114a3a4be0>,
 <torch.utils.data.dataloader.DataLoader at 0x7f114a1d61d0>)

# Load VAE

In [6]:
# Load VAE
# TODO swap z and k dim, since it's inconsistent with other models
vae = VAE6(image_size=image_size, z_dim=32, conv_dim=48, code_dim=8, k_dim=z_dim)
if cuda:
    vae.cuda()
    
# # Resume

# checkpoint_file = './outputs/models/VAE6_6_256im_512z_inception_CVAE_greenfields_state_dict.pkl'
# if os.path.isfile(checkpoint_file):
#     state_dict = torch.load(checkpoint_file)
#     vae.load_state_dict(state_dict)
#     print('loaded checkpoint_file {checkpoint_file}'.format(checkpoint_file=checkpoint_file))
    
save_file = './outputs/models/{NAME}-vae_state_dict.pkl'.format(NAME=NAME)
if os.path.isfile(save_file):
    state_dict = torch.load(save_file)
    vae.load_state_dict(state_dict)
    print('loaded save_file {save_file}'.format(save_file=save_file))

# Load RNN

In [7]:
# Load MDRNN
action_dim, hidden_size, n_mixture, temp = 12, 128, 3, 0.0


mdnrnn = MDNRNN2(z_dim, action_dim, hidden_size, n_mixture, temp)

if cuda:
    mdnrnn = mdnrnn.cuda()

In [8]:
# # Resume?
save_file = './outputs/models/{NAME}-mdnrnn_state_dict.pkl'.format(NAME=NAME)
if os.path.isfile(save_file):
    state_dict = torch.load(save_file)
    mdnrnn.load_state_dict(state_dict)
    print('loaded {save_file}'.format(save_file=save_file))

# Load inverse model

In [9]:
class FInv(torch.nn.modules.Module):
    def __init__(self, z_dim, action_dim, hidden_size):
        """
        Inverse model from https://arxiv.org/abs/1804.10689.
        
        Takes in z and z' and outputs predicted action
        """
        super().__init__()
        self.ln1 = nn.Linear(z_dim*2, hidden_size)
        self.ln2 = nn.Linear(hidden_size, hidden_size)
        self.ln3 = nn.Linear(hidden_size,  action_dim)
        
    def forward(self, z_now, z_next):
        x = torch.cat((z_now, z_next), dim=-1)
        x = F.relu(self.ln1(x))
        x = F.relu(self.ln2(x))
        x = F.sigmoid(self.ln3(x))
        return x

In [10]:
finv = FInv(z_dim, action_dim, hidden_size=256).cuda()

# Summarize models

In [11]:
img = np.random.randn(image_size, image_size, 3)
action = Variable(torch.from_numpy(np.random.randint(0,12,(12)))).float().cuda()[np.newaxis]
gpu_img = Variable(torch.from_numpy(img[np.newaxis].transpose(0, 3, 1, 2))).float().cuda()
if cuda:
    gpu_img = gpu_img.cuda()
with TorchSummarizeDf(vae) as tdf:
    x, mu_vae, logvar_vae = vae.forward(gpu_img)
    z = vae.sample(mu_vae, logvar_vae)
    df_vae = tdf.make_df()
#     loss_recon, loss_KLD = loss_function_vae(Y, x, mu_vae, sigma_vae)
#     loss_vae = loss_recon + lambda_vae_kld * torch.abs(loss_KLD-C)
#     loss_vae = loss_vae.mean() # mean along the batches


df_vae[df_vae.level<2]


Total parameters 8909862
Total trainable parameters 8909862


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
4,encoder.0,BasicConv2d,"[(-1, 3, 256, 256)]","[(-1, 48, 256, 256)]",1440,1
42,encoder.1,ConvBlock5,"[(-1, 48, 256, 256)]","[(-1, 96, 128, 128)]",93213,1
80,encoder.2,ConvBlock5,"[(-1, 96, 128, 128)]","[(-1, 144, 64, 64)]",281034,1
118,encoder.3,ConvBlock5,"[(-1, 144, 64, 64)]","[(-1, 192, 32, 32)]",566055,1
156,encoder.4,ConvBlock5,"[(-1, 192, 32, 32)]","[(-1, 240, 16, 16)]",948276,1
194,encoder.5,ConvBlock5,"[(-1, 240, 16, 16)]","[(-1, 288, 8, 8)]",1427697,1
232,encoder.6,ConvBlock5,"[(-1, 288, 8, 8)]","[(-1, 32, 8, 8)]",351550,1
233,mu,Linear,"[(-1, 2048)]","[(-1, 256)]",524544,0
234,logvar,Linear,"[(-1, 2048)]","[(-1, 256)]",524544,0
235,z,Linear,"[(-1, 256)]","[(-1, 2048)]",526336,0


In [12]:
with TorchSummarizeDf(mdnrnn) as tdf: 
    pi, mu, sigma, hidden_state = mdnrnn.forward(z.unsqueeze(1).repeat((1,2,1)), action.unsqueeze(1).repeat((1,2,1)))
    z_next = mdnrnn.sample(pi, mu, sigma)
#     loss_mdn = mdnrnn.rnn_loss(z, pi, mu, sigma).mean()
    df_mdnrnn = tdf.make_df()
    
df_mdnrnn

Total parameters 1779712
Total trainable parameters 1779712


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
1,rnn,LSTM,"[[(-1, 2, 268)], [[(-1, 1, 128)], [(-1, 1, 128...","[[(-1, 2, 128)], [[(-1, 1, 128)], [(-1, 1, 128...",203776,0
2,ln1,Linear,"[(-1, 128), (-1, 128)]","[(-1, 128), (-1, 128)]",16512,0
3,ln2,Linear,"[(-1, 128), (-1, 128)]","[(-1, 640), (-1, 640)]",82560,0
4,mdn,Linear,"[(-1, 640), (-1, 640)]","[(-1, 2304), (-1, 2304)]",1476864,0


In [13]:
#     loss = loss_mdn + gamma_vae * loss_vae
with TorchSummarizeDf(finv) as tdf:
    action_pred = finv(z.repeat((1,2,1)), z_next)
        
    df_finv = tdf.make_df()
    
del img, action, gpu_img, x, mu, z
df_finv

Total parameters 200204
Total trainable parameters 200204


Unnamed: 0,name,class_name,input_shape,output_shape,nb_params,level
1,ln1,Linear,"[(-1, 2, 512)]","[(-1, 2, 256)]",131328,0
2,ln2,Linear,"[(-1, 2, 256)]","[(-1, 2, 256)]",65792,0
3,ln3,Linear,"[(-1, 2, 256)]","[(-1, 2, 12)]",3084,0


# Init

In [14]:
class Model(torch.nn.modules.Module):
    def __init__(self, vae, mdnrnn, finv):
        super().__init__()
        self.vae = vae
        self.mdnrnn = mdnrnn
        self.finv = finv
        
model = Model(vae, mdnrnn, finv)

In [15]:

torch.save(finv.state_dict(), './outputs/models/{NAME}-finv_state_dict.pkl'.format(NAME=NAME))

In [16]:

import torch.optim.lr_scheduler
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, verbose=True)

# optimizer_vae = optim.Adam(vae.parameters(), lr=3e-5)
# scheduler_vae = optim.lr_scheduler.ReduceLROnPlateau(optimizer_vae, mode='min', patience=3, verbose=True)

# Train helpers

In [17]:
# Plot reconstructions
def plot_results(loader, n=2, epoch=0, figsize=(9,6)):
    vae.eval()
    mdnrnn.eval()
    
    observations, actions, rewars, dones = next(iter(loader))
    
    X = Variable(observations.transpose(1,3))
    _, channels, height, width = X.size()
    if cuda:
        X=X.cuda()
    Y, mu_vae, logvar = vae.forward(X)
    loss_recon, loss_KLD = loss_function_vae(Y, X, mu_vae, logvar)
    loss_vae = loss_recon + lambda_vae_kld * torch.abs(loss_KLD-C)
    
    # TODO do we want to sample in test or training mode?
    z_v = vae.sample(mu_vae, logvar)
    
    z_v = z_v.view(batch_size, seq_len, -1)
    Y = Y.view((batch_size, seq_len, channels, height, width))
    X = X.view((batch_size, seq_len, channels, height, width))
    loss_vae = loss_vae.view(batch_size, seq_len, -1)
    actions = actions.view(batch_size, seq_len, -1)
    
    # Forward
    actions_v = Variable(actions).float()
    

    if cuda:
        z_v=z_v.cuda()
        actions_v=actions_v.cuda()
    pi, mu, sigma, hidden_state = mdnrnn.forward(z_v, actions_v)

    loss = mdnrnn.rnn_loss(z_v, pi, mu, sigma)
    
    mu = mu.mean(2).view((batch_size*seq_len, mdnrnn.z_dim))
    X_pred = vae.decode(mu)
    X_pred = X_pred.view((batch_size, seq_len, channels, height, width))
    
    # TODO finv    
    
    for i in np.linspace(0,seq_len-2,n):
        batch = np.random.randint(0,batch_size)
        i=int(i)
        y=Y[batch][i].cpu().data.transpose(0,2).numpy()
        x_orig = X[batch][i].transpose(0,2).data.cpu().numpy()
        x_next = X[batch][i+1].transpose(0,2).data.cpu().numpy()
        x_pred = X_pred[batch][i].transpose(0,2).data.cpu().numpy()
        loss_vae_i = loss_vae[batch][i].cpu().data.item()
        loss_i = loss[batch].cpu().data.item()
        
        plt.figure(figsize=figsize)
        
        plt.subplot(2, 3, 1)
        plt.axis("off")
        plt.title('original')
        plt.imshow(x_orig)

        plt.subplot(2, 3, 4)
        plt.axis("off")
        plt.imshow(y)
        plt.title('reconstructed')
           
        plt.subplot(2, 3, 2)
        plt.axis("off")
        plt.imshow(x_next)
        plt.title('true next')
        
        plt.subplot(2, 3, 5)
        plt.axis("off")
        plt.imshow(x_pred)
        plt.title('pred next')
        
        plt.subplot(2, 3, 3)
        plt.axis("off")
        plt.imshow(np.abs(x_orig-x_next))
        plt.title('actual changes')
        
        plt.subplot(2, 3, 6)
        plt.axis("off")
        plt.imshow(np.abs(y[i]-x_pred))
        plt.title('predicted changes')

        plt.suptitle('epoch {}, seq index {}, batch={}. vae_loss {:2.4f}, loss {:2.4f}'.format(
            epoch, 
            i,
            batch,
            loss_vae_i, 
            loss_i
        ))
#         plt.subplots_adjust(wspace=-.4, hspace=.1)#, bottom=0.1, right=0.8, top=0.9)
        plt.show()
        


TODO

- [ ] make a module containing all 3 including inverse model from https://arxiv.org/pdf/1804.10689.pdf
    - that way they can use the same optimizer
- [ ] do dual training

In [26]:
import collections

def train(loader, vae, mdnrnn, optimizer, max_batches=None, test=False, cuda=True, joint_training=False):
    vae.eval()
    if test:
        mdnrnn.eval()
    else:
        mdnrnn.train()
    info = collections.defaultdict(list)
    hidden_state = None
    if max_batches is None:
        max_batches = len(loader)
    else:
        max_batches = min(max_batches, len(loader))
    iterator = iter(loader)

    with tqdm(total=max_batches*loader.batch_size, mininterval=0.5, desc='test' if test else 'training') as prog:
        for i in range(max_batches):
            # the loader batch_size is seq_len*batch_size
            # we put it through the VAE as (seq_len*batch_size,...)
            # then reshape to (batch_size,seq_len,...) for the mdnrnn
            observations, actions, rewards, dones = next(iterator)
            X = Variable(observations.transpose(1,3))
            if cuda:
                X=X.cuda()
                
            # VAE forward
            Y, mu_vae, logvar = vae.forward(X)
            
            loss_recon, loss_KLD = loss_function_vae(Y, X, mu_vae, logvar)
            loss_vae = loss_recon + lambda_vae_kld * torch.abs(loss_KLD-C)
            loss_vae = loss_vae.mean() # mean along the batches

            # MDNRNN Forward
            z_v = vae.sample(mu_vae, logvar)
            z_v = z_v.view(batch_size, seq_len, -1)
            actions_v = Variable(actions).float()
            actions_v = actions_v.view(batch_size, seq_len, -1)
            if cuda:
                z_v=z_v.cuda()
                actions_v=actions_v.cuda()
            pi, mu, sigma, hidden_state = mdnrnn.forward(z_v, actions_v)

            # We are evaluating how the output distribution for the next step
            # matches the real next step. So we have to discard the last step in the 
            # sequence which has no next step.
            z_true_next = z_v[:,1:]
            loss_mdn = mdnrnn.rnn_loss(z_true_next, pi[:,:-1], mu[:,:-1], sigma[:,:-1]).mean()
#             loss_mdn += 10 # this is to make sure it stays positive
            
            # Finv forward
            z_next_pred = mdnrnn.sample(pi, mu, sigma)
            action_pred = finv(z_v[:,1:], z_next_pred[:,:-1])
            loss_inv = ((action_pred-actions_v[:,1:])**2).sum(-1)
            loss_inv = loss_inv.mean()
            
            loss = loss_mdn# + lambda_finv * loss_inv #+ lambda_vae * loss_vae

            if not test:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            # Record
            info['loss_inv'].append(loss_inv.cpu().data.numpy())
            info['loss_mdn'].append(loss_mdn.cpu().data.numpy())
            info['loss_vae'].append(loss_vae.cpu().data.numpy())
            info['loss_recon'].append(loss_recon.mean().cpu().data.item())
            info['loss_KLD'].append(loss_KLD.mean().cpu().data.item())
            
            prog.update(loader.batch_size)
            prog.desc='loss={loss:2.4f}, loss_rnn={loss_mdn:2.4f}, loss_inv= {loss_inv2:2.4f}={lambda_finv}* {loss_inv:2.4f}, loss_vae={loss_vae:2.4f}={lambda_vae:2.4f} * ({loss_recon:2.2f} + {lambda_vae_kld}*|{loss_KLD:2.2f} - {C}|)'.format(
                loss=loss.cpu().data.item(),
                loss_mdn=np.mean(info['loss_mdn']), 
                loss_recon=np.mean(info['loss_recon']),
                loss_KLD=np.mean(info['loss_KLD']),
                loss_vae=lambda_vae*(np.mean(info['loss_recon'])+lambda_vae_kld*(np.mean(info['loss_KLD'])-C)),
                loss_inv=np.mean(info['loss_inv']),
                loss_inv2=np.mean(info['loss_inv'])*lambda_finv,
                lambda_vae_kld=lambda_vae_kld,
                lambda_finv=lambda_finv,
                lambda_vae=lambda_vae,
                C=C
            )
            if i%400==0:
                print('[{}/{}]'.format(i, max_batches), prog.desc)

        print(prog.desc)
        prog.close()

    return info

# Train

In [19]:
max_batches=30000//loader_train.batch_size
max_batches
torch.cuda.empty_cache()

In [20]:
# Load previous history
import pandas as pd
if os.path.isfile('./outputs/models/{NAME}.csv'.format(NAME=NAME)):
    histories = pd.read_csv('./outputs/models/{NAME}.csv'.format(NAME=NAME)).to_dict(orient='records')
else:
    histories = []

In [27]:
for epoch in range(num_epochs):
    # Run
    info = train(loader_train, vae, mdnrnn, optimizer, max_batches=max_batches, test=False, cuda=True, joint_training=True)
    torch.cuda.empty_cache()
    info_val = train(loader_test, vae, mdnrnn, optimizer, max_batches=max_batches//6, test=True, cuda=True, joint_training=True)
    torch.cuda.empty_cache()
    
    # Adjust
    scheduler.step(np.mean(info_val['loss_mdn']))
    
    # View
    print('Epoch {}, loss={:2.4f}, loss_val={:2.4f}, loss_vae={:2.4f}, loss_vae_val={:2.4f},  loss_finv={:2.4f}, loss_finv_vae={:2.4f}, ,'.format(
        epoch, 
        np.mean(info['loss_mdn']), 
        np.mean(info_val['loss_mdn']),
        np.mean(info['loss_vae']), 
        np.mean(info_val['loss_vae']),
        np.mean(info['loss_finv']),
        np.mean(info_val['loss_finv'])
    ))
    plot_results(loader_test, n=2, epoch=epoch)
    
    # Record
    history = {k+'_val':np.mean(v) for k,v in info_val.items()}
    history.update({k:np.mean(v) for k,v in info.items()})
    histories.append(history)
    
    torch.save(mdnrnn.state_dict(), './outputs/models/{NAME}-mdnrnn_{epoch}_state_dict.pkl'.format(NAME=NAME))
    torch.save(vae.state_dict(), './outputs/models/{NAME}-vae_{epoch}_state_dict.pkl'.format(NAME=NAME))
    torch.save(finv.state_dict(), './outputs/models/{NAME}-finv_{epoch}_state_dict.pkl'.format(NAME=NAME))
    
    # Tidy
    torch.cuda.empty_cache()

HBox(children=(IntProgress(value=0, description='training', max=60000), HTML(value='')))

[0/6000] loss=1.2277, loss_rnn=1.2277, loss_inv= 3.0198=1* 3.0198, loss_vae=356.3856=0.0100 * (35638.50 + 0.25*|0.26 - 0|)



RuntimeError: cuda runtime error (2) : out of memory at /pytorch/aten/src/THC/generic/THCStorage.cu:58

In [None]:
import pandas as pd
df_history = pd.DataFrame(histories)
df_history.plot()

In [None]:
df_history[['loss_mdn','loss_mdn_val']].plot()

In [None]:
df_history[['loss_vae','loss_vae_val']].plot()

In [None]:
df_history[['loss_inv','loss_inv_val']].plot()

## Save

In [29]:
torch.save(mdnrnn.state_dict(), f'./outputs/models/{NAME}-mdnrnn_state_dict.pkl')
torch.save(vae.state_dict(), f'./outputs/models/{NAME}-vae_state_dict.pkl')
torch.save(finv.state_dict(), f'./outputs/models/{NAME}-finv_state_dict.pkl')
df_history.to_csv(f'./outputs/models/{NAME}.csv', index=False)

# torch.save(mdnrnn, f'./outputs/models/{NAME}-mdnrnn.pkl')
# torch.save(vae, f'./outputs/models/{NAME}-vae')
# torch.save(finv.state_dict(), f'./outputs/models/{NAME}-finv.pkl')

SyntaxError: invalid syntax (<ipython-input-29-efe22a6dfec1>, line 1)

## View

In [28]:
plot_results(loader_test, n=4, epoch=0)

RuntimeError: cuda runtime error (2) : out of memory at /pytorch/aten/src/THC/generic/THCStorage.cu:58

In [None]:
plot_results(loader_train, n=4, epoch=0)

In [None]:
torch.cuda.empty_cache()

# DEBUG

In [None]:
# DEBUG the distributions
vae.train()
mdnrnn.train()

observations, actions, rewards, dones = next(iter(loader_train))

X = Variable(observations.transpose(1,3))
_, channels, height, width = X.size()
if cuda:
    X=X.cuda()
Y, mu_vae, logvar = vae.forward(X)
loss_recon, loss_KLD = loss_function_vae(Y, X, mu_vae, logvar)
loss_vae = loss_recon + lambda_vae_kld * torch.abs(loss_KLD-C)

# TODO do we want to sample in test or training mode?
z_v = vae.sample(mu_vae, logvar)

z_v = z_v.view(batch_size, seq_len, -1)
Y = Y.view((batch_size, seq_len, channels, height, width))
X = X.view((batch_size, seq_len, channels, height, width))
loss_vae = loss_vae.view(batch_size, seq_len, -1)
actions = actions.view(batch_size, seq_len, -1)

# Forward
actions_v = Variable(actions).float()


if cuda:
    z_v=z_v.cuda()
    actions_v=actions_v.cuda()
pi, mu, sigma, hidden_state = mdnrnn.forward(z_v, actions_v)

loss = mdnrnn.rnn_loss(z_v, pi, mu, sigma)

# mu = mu.mean(2).view((batch_size*seq_len, mdnrnn.z_dim))
# X_pred = vae.decode(mu)
# X_pred = X_pred.view((batch_size, seq_len, channels, height, width))


mdnrnn.train()

zs=mdnrnn.sample(pi, mu, sigma)







z_v = vae.sample(mu_vae, logvar)
plt.hist(z_v.cpu().data.numpy().flatten(), bins=50)
plt.title('z_v')
plt.show()

plt.hist(zs.cpu().data.numpy().flatten(), bins=50)
plt.title('z_pred')
plt.show()

plt.hist(mu_vae.cpu().data.numpy().flatten(), bins=50)
plt.title('mu_vae')
plt.show()

plt.hist(mu.cpu().data.numpy().flatten(), bins=50)
plt.title('mu')
plt.show()

plt.hist(logvar.exp().cpu().data.numpy().flatten(), bins=50)
plt.title('sigma_vae')
plt.show()

plt.hist(sigma.cpu().data.numpy().flatten(), bins=50)
plt.title('sigma')
plt.show()
