# Self-Supervised Shared Latent Embedding (S$^3$LE)

In [1]:
import glob
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import scipy.io as sio
from swae import SharedWassersteinAutoEncoderClass
from util import *
%matplotlib inline
%config InlineBackend.figure_format = 'retina'    
print ("Ready.")

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Ready.


### Select which robot model to use

In [2]:
model_name = 'coman'
print("model_name:[%s]"%(model_name)) 

model_name:[coman]


### Load

In [3]:
PRINT_EACH_MAT = 0

# 1. Paired data from SSL sampling
math_paths = glob.glob('../data/Pairs of %s*.mat'%(model_name))
for m_idx,mat_path in enumerate(math_paths):
    l = sio.loadmat(mat_path) # load
    if m_idx == 0:
        q_pair_tildes,q_pair_bars,x_pair_tildes,x_pair_bars = \
            l['q_tildes'],l['q_bars'],l['x_tildes'],l['x_bars']
    else:
        # Append
        for gamma in [0,0.2,0.5]: 
            q_pair_tildes = np.concatenate((q_pair_tildes,l['q_tildes']),axis=0)
            q_pair_bars = np.concatenate((q_pair_bars,l['q_bars']),axis=0)
            x_pair_tildes = np.concatenate((x_pair_tildes,gamma*l['x_bars']+(1-gamma)*l['x_tildes']),axis=0)
            x_pair_bars = np.concatenate((x_pair_bars,l['x_bars']),axis=0)
print ("[%d] paired data from SSL sampling."%(q_pair_tildes.shape[0]))

# 2. Paired data from motion retargeting 
math_paths = glob.glob('../data/Glue of CMU mocap*%s.mat'%(model_name))
math_paths = math_paths[:100] # use upto 100
for m_idx,mat_path in enumerate(math_paths):
    l = sio.loadmat(mat_path) # load
    if m_idx == 0:
        q_glue_tildes,q_glue_bars,x_glue_tildes,x_glue_bars = \
            l['q_tildes'],l['q_bars'],l['x_tildes'],l['x_bars']
    else:
        for gamma in [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8]: 
            q_glue_tildes = np.concatenate((q_glue_tildes,l['q_tildes']),axis=0)
            q_glue_bars = np.concatenate((q_glue_bars,l['q_bars']),axis=0)
            x_glue_tildes = np.concatenate((x_glue_tildes,gamma*l['x_bars']+(1-gamma)*l['x_tildes']),axis=0)
            x_glue_bars = np.concatenate((x_glue_bars,l['x_bars']),axis=0)
print ("[%d] glue data from [%d] motion retargeting data."%(q_glue_bars.shape[0],len(math_paths)))  

# 3. Robot-specific data
math_paths = glob.glob('../data/Domain of %s*.mat'%(model_name))
for m_idx,mat_path in enumerate(math_paths):
    l = sio.loadmat(mat_path) # load
    if m_idx == 0:
        q_domain_bars = l['q_bars']
    else:
        q_domain_bars = np.concatenate((q_domain_bars,l['q_bars']),axis=0)
print ("[%d] robot-specific data."%(q_domain_bars.shape[0]))  

# 4. Mocap-specific data
math_paths = glob.glob('../data/Domain of %s*.mat'%('mocap'))
for m_idx,mat_path in enumerate(math_paths):
    l = sio.loadmat(mat_path) # load
    if m_idx == 0:
        x_domain_tildes = l['x_tildes']
    else:
        x_domain_tildes = np.concatenate((x_domain_tildes,l['x_tildes']),axis=0)
print ("[%d] mocap-specific data."%(x_domain_tildes.shape[0]))  

print ("Done.") 

[500010] paired data from SSL sampling.
[105574] glue data from [100] motion retargeting data.
[200004] robot-specific data.
[201123] mocap-specific data.
Done.


In [4]:
# q_pair_tildes,q_pair_bars,x_pair_tildes,x_pair_bars
# q_glue_tildes,q_glue_bars,x_glue_tildes,x_glue_bars
# q_domain_bars
# x_domain_tildes

In [5]:
# Relaxed to feasible, perhaps, feasible to feasible is a better idea.

# Option 1: tilde -> bar, bar -> bar
x_recon_in  = np.concatenate((x_pair_tildes,x_glue_tildes,x_pair_bars,x_glue_bars),axis=0)
x_recon_out = np.concatenate((x_pair_bars,x_glue_bars,x_pair_bars,x_glue_bars),axis=0)
y_recon_in  = np.concatenate((q_pair_tildes,q_glue_tildes,q_pair_bars,q_glue_bars),axis=0)
y_recon_out = np.concatenate((q_pair_bars,q_glue_bars,q_pair_bars,q_glue_bars),axis=0)


# Option 2: bar -> bar
"""
x_recon_in  = np.concatenate((x_pair_bars,x_glue_bars),axis=0)
x_recon_out = np.concatenate((x_pair_bars,x_glue_bars),axis=0)
y_recon_in  = np.concatenate((q_pair_bars,q_glue_bars),axis=0)
y_recon_out = np.concatenate((q_pair_bars,q_glue_bars),axis=0)
"""

"""
# Option 3: tilde -> tilde
x_recon_in  = np.concatenate((x_pair_tildes,x_glue_tildes),axis=0)
x_recon_out = np.concatenate((x_pair_tildes,x_glue_tildes),axis=0)
y_recon_in  = np.concatenate((q_pair_tildes,q_glue_tildes),axis=0)
y_recon_out = np.concatenate((q_pair_tildes,q_glue_tildes),axis=0)
"""

# For latent modeling, we use all possible data
x_latent = np.concatenate((x_pair_tildes,x_pair_bars,x_glue_tildes,x_glue_bars,x_domain_tildes),axis=0)
y_latent = np.concatenate((q_pair_tildes,q_pair_bars,q_glue_tildes,q_glue_bars,q_domain_bars),axis=0)

USE_BOTH_SSL_AND_MR = True
if USE_BOTH_SSL_AND_MR:
    # Relaxed mocap -> Feasible robot pose
    x_glue = np.concatenate((x_pair_tildes,x_glue_tildes),axis=0)
    y_glue = np.concatenate((q_pair_bars,q_glue_bars),axis=0)

    # Relaxed mocap -> Feasible robot pose
    x_x2y = np.concatenate((x_pair_tildes,x_glue_tildes),axis=0)
    y_x2y = np.concatenate((q_pair_bars,q_glue_bars),axis=0)
else:
    # Relaxed mocap -> Feasible robot pose
    x_glue = x_glue_tildes
    y_glue = q_glue_bars

    # Relaxed mocap -> Feasible robot pose
    x_x2y = x_glue_tildes
    y_x2y = q_glue_bars

# SSL 
x_nce_anc = np.concatenate((x_pair_tildes,x_glue_tildes),axis=0)
x_nce_pos = np.concatenate((x_pair_bars,x_glue_bars),axis=0)


# Print out stats
n_x_recon,n_x_latent,n_y_recon,n_y_latent = x_recon_in.shape[0],x_latent.shape[0],y_recon_in.shape[0],y_latent.shape[0]
n_glue,n_nce,n_x2y = x_glue.shape[0],x_nce_anc.shape[0],x_x2y.shape[0]
print ("n_x_recon:[%d] n_x_latent:[%d] n_y_recon:[%d] n_y_latent:[%d] n_glue:[%d] n_nce:[%d] n_x2y:[%d]."%
       (n_x_recon,n_x_latent,n_y_recon,n_y_latent,n_glue,n_nce,n_x2y))

n_x_recon:[1211168] n_x_latent:[1412291] n_y_recon:[1211168] n_y_latent:[1411172] n_glue:[605584] n_nce:[605584] n_x2y:[605584].


In [6]:
xdim,ydim = x_pair_tildes.shape[1],q_pair_bars.shape[1]
print ('xdim:[%d] ydim:[%d].'%(xdim,ydim))

xdim:[21] ydim:[11].


### Train ($\texttt{MR}: x \mapsto q$)

In [7]:
# Hyperparameters
seed = 0
zdim = 14 # max(14,qdim)
hdims,actv_Q,actv_P,actv_D = [128]*3,tf.nn.relu,tf.nn.relu,tf.nn.relu
# ki = tf.contrib.layers.xavier_initializer()
ki = tf.truncated_normal_initializer(stddev=0.1) 
adam_beta1,adam_beta2,adam_epsilon = 0.9,0.9,1e-0 # 0.5,0.9,1e-0
max_iter,batch_size,print_every,save_every = 1e5,64,1000,1000
lr_rate_fr,lr_rate_to,warmup_it = 1.0,0.1,5e3
# Latent prior
latent_beta = 0.1 # 0.1
lr_d = 2e-4
lr_g = 2e-4
# WAE recon
l1_recon_coef = 0.0
l2_recon_coef = 1.0 # 1.0
lr_recon = 1e-3
# Weight decay
wd_coef = 1e-6
# Latent consensus
l1_lc_coef = 1.0
l2_lc_coef = 5.0 # 1.0
lr_lc = 1e-3
# NCE SSL
nce_coef = 0.01
lr_nce = 1e-3
# X->Y mapping
l1_x2y_coef = 0.0
l2_x2y_coef = 0.1
lr_x2y = 1e-3
    
print ("Done.")

Done.


In [8]:
tf.reset_default_graph()
tf.set_random_seed(seed=seed); np.random.seed(seed=seed)
S = SharedWassersteinAutoEncoderClass(
    xname='s3le_wae_x',yname='s3le_wae_y',xdim=xdim,ydim=ydim,zdim=zdim,
    hdims_Q=hdims,hdims_P=hdims,hdims_D=hdims,
    actv_Q=tf.nn.relu,actv_P=tf.nn.relu,actv_D=tf.nn.relu,
    actv_latent=None,actv_out=None,ki=ki,
    adam_beta1=adam_beta1,adam_beta2=adam_beta2,adam_epsilon=adam_epsilon,
)
sess = gpu_sess() 
tf.set_random_seed(seed=seed); np.random.seed(seed=seed)
sess.run(tf.global_variables_initializer())
print ("Done.")



Instructions for updating:
Use keras.layers.Dense instead.
Instructions for updating:
Please use `layer.__call__` method instead.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where





Done.


In [None]:
# Loop
lr_rate = 1.0
lr_d_weight_x,lr_d_weight_y = 1.0,1.0 # adaptive learning rate 
total_loss_prev = np.inf
for it in range(int(max_iter)): 
    zero_to_one = it/max_iter
    # lr_rate = lr_rate_fr - (lr_rate_fr-lr_rate_to)*zero_to_one
    lr_rate = min((it+1e-4)**(-0.5),(it+1e-4)*warmup_it**(-1.5)) / \
        min((warmup_it+1e-4)**(-0.5),(warmup_it+1e-4)*warmup_it**(-1.5))

    noise_std = 0.0
    # X Recon
    r_idx = np.random.permutation(x_recon_in.shape[0])[:batch_size]
    x_batch_recon_in = x_recon_in[r_idx,:] + noise_std*np.random.randn(batch_size,xdim) # x_pair_tildes ?
    x_batch_recon_out = x_recon_out[r_idx,:]
    # X Latent
    r_idx = np.random.permutation(x_latent.shape[0])[:batch_size]
    x_batch_latent = x_latent[r_idx,:]
    # Y Recon
    r_idx = np.random.permutation(y_recon_in.shape[0])[:batch_size]
    y_batch_recon_in = y_recon_in[r_idx,:] + noise_std*np.random.randn(batch_size,ydim)
    y_batch_recon_out = y_recon_out[r_idx,:]
    # Y Latent
    r_idx = np.random.permutation(y_latent.shape[0])[:batch_size]
    y_batch_latent = y_latent[r_idx,:]
    # X-Y glue
    r_idx = np.random.permutation(x_glue.shape[0])[:batch_size]
    x_batch_glue = x_glue[r_idx,:] + noise_std*np.random.randn(batch_size,xdim)
    y_batch_glue = y_glue[r_idx,:]
    # X SSL
    r_idx = np.random.permutation(x_nce_anc.shape[0])[:batch_size]
    x_batch_nce_anc = x_nce_anc[r_idx,:]
    x_batch_nce_pos = x_nce_pos[r_idx,:]
    # X->Y mapping
    r_idx = np.random.permutation(x_x2y.shape[0])[:batch_size]
    x_batch_x2y = x_x2y[r_idx,:] + noise_std*np.random.randn(batch_size,xdim)
    y_batch_x2y = y_x2y[r_idx,:]
    
    # Update
    recon_loss_x,wd_loss_x,d_loss_x,g_loss_x,\
        recon_loss_y,wd_loss_y,d_loss_y,g_loss_y,\
        lc_loss,nce_loss,x2y_loss = S.update(
            sess,
            x_recon_in=x_batch_recon_in,x_recon_out=x_batch_recon_out,x_latent=x_batch_latent,
            y_recon_in=y_batch_recon_in,y_recon_out=y_batch_recon_out,y_latent=y_batch_latent,
            x_glue=x_batch_glue,y_glue=y_batch_glue,x_nce_anc=x_batch_nce_anc,x_nce_pos=x_batch_nce_pos,
            x_x2y=x_batch_x2y,y_x2y=y_batch_x2y,
            latent_beta=latent_beta,l1_recon_coef=l1_recon_coef,l2_recon_coef=l2_recon_coef,wd_coef=wd_coef,
            lr_recon_x=lr_recon,lr_d_x=lr_d*lr_d_weight_x,lr_g_x=lr_g,
            lr_recon_y=lr_recon,lr_d_y=lr_d*lr_d_weight_y,lr_g_y=lr_g,
            lr_lc=lr_lc,l1_lc_coef=l1_lc_coef,l2_lc_coef=l2_lc_coef,lr_nce=lr_nce,nce_coef=nce_coef,
            lr_x2y=lr_x2y,l1_x2y_coef=l1_x2y_coef,l2_x2y_coef=l2_x2y_coef,lr_rate=lr_rate
            )
    lr_d_weight_x,lr_d_weight_y = min(1.0,d_loss_x/(0.1+g_loss_x)),min(1.0,d_loss_y/(0.1+g_loss_y))
    total_loss = recon_loss_x + wd_loss_x + d_loss_x + g_loss_x + recon_loss_y + wd_loss_y + \
        d_loss_y + g_loss_y + lc_loss + nce_loss + x2y_loss

    # Print results every some iterations 
    if ((it % print_every) == 0) or ((it+1) == max_iter): 
        print (("[%d][%.1f%%]lr:[%.2f] X R:[%.3f] D:[%.3f] G:[%.3f] WD:[%.3f] / Y R:[%.3f] D:[%.3f] G:[%.3f] WD:[%.3f]\n"
                "   LC:[%.3f] / NCE:[%.3f] / X2Y:[%.3f] / total_loss:[%.3f]")%
               (it,zero_to_one*100,lr_rate,recon_loss_x,d_loss_x,g_loss_x,wd_loss_x,
                recon_loss_y,d_loss_y,g_loss_y,wd_loss_y,
                lc_loss,nce_loss,x2y_loss,total_loss))
    # Save?
    if ((it % save_every) == 0) and (total_loss < total_loss_prev):
        total_loss_prev = total_loss
        S.W_x.save_to_mat(sess,it=it,suffix='',VERBOSE=False)
        S.W_y.save_to_mat(sess,it=it,suffix='',VERBOSE=False)
        print ("Checkpoint it:[%d] total_loss:[%.3f]."%(it,total_loss))
    
print ("Done.")

[0][0.0%]lr:[0.00] X R:[7.033] D:[0.069] G:[0.069] WD:[0.000] / Y R:[5.658] D:[0.069] G:[0.070] WD:[0.000]
   LC:[6.412] / NCE:[0.092] / X2Y:[0.568] / total_loss:[20.042]
[nets/s3le_wae_x/weights.mat] saved. Size is[0.448]MB.
[nets/s3le_wae_y/weights.mat] saved. Size is[0.436]MB.
Checkpoint it:[0] total_loss:[20.042].
[1000][1.0%]lr:[0.20] X R:[2.317] D:[0.069] G:[0.070] WD:[0.000] / Y R:[2.647] D:[0.071] G:[0.067] WD:[0.000]
   LC:[0.578] / NCE:[0.097] / X2Y:[0.316] / total_loss:[6.232]
[nets/s3le_wae_x/weights.mat] saved. Size is[0.448]MB.
[nets/s3le_wae_y/weights.mat] saved. Size is[0.436]MB.
Checkpoint it:[1000] total_loss:[6.232].
[2000][2.0%]lr:[0.40] X R:[1.543] D:[0.068] G:[0.072] WD:[0.000] / Y R:[0.749] D:[0.071] G:[0.066] WD:[0.000]
   LC:[0.493] / NCE:[0.096] / X2Y:[0.217] / total_loss:[3.376]
[nets/s3le_wae_x/weights.mat] saved. Size is[0.448]MB.
[nets/s3le_wae_y/weights.mat] saved. Size is[0.436]MB.
Checkpoint it:[2000] total_loss:[3.376].
[3000][3.0%]lr:[0.60] X R:[0.737