# Deep Latent Policy Gradient for Ant

In [None]:
import gym,warnings,time
warnings.filterwarnings("ignore") # Stop annoying warnings
gym.logger.set_level(40)
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import matplotlib.pyplot as plt
%matplotlib inline
import skvideo.io
from datetime import datetime
from custom_ant import AntEnvCustom # Custom ant 
from lgrp_class import lgrp_class # Gaussian random path
from vae_class import vae_class # VAE
from antTrainEnv_class import antTrainEnv_class
from util import PID_class,display_frames_as_gif,\
    quaternion_to_euler_angle,multi_dim_interp,cpu_sess,gpu_sess,Scaler
# Report the TensorFlow version this notebook is running with.
print ("TF version is [%s]."%(tf.__version__))

### Instantiate Class

In [None]:
# Start from an empty TF graph so re-running this cell does not stack
# duplicate ops, then build the Ant training environment wrapper.
tf.reset_default_graph() # Reset Graph
# NOTE: the hyperparameters are spelled as float literals on purpose. The
# former `1/4` evaluates to 0 under Python 2 integer division, whereas 0.25
# is the same value on both Python 2 and Python 3.
AntEnv = antTrainEnv_class(
    _tMax=3,_nAnchor=20,_maxRepeat=3,
    _hypGain=0.25,_hypLen=0.25,_pGain=0.01,
    _zDim=16,_hDims=[64,64],_vaeActv=tf.nn.relu,
    _PLOT_GRP=True)

### Train Ant

In [None]:
# Visualization switches forwarded to AntEnv.train_dlpg in the training cell
# (as _SAVE_VID, _MAKE_GIF, _PLOT_GRP and _PLOT_EVERY respectively).
PLOT_EVERY = 10
PLOT_GRP = True
SAVE_VID = MAKE_GIF = False # GIF output is probably unnecessary

In [None]:
# Training budget: number of epochs and the batch size passed to train_dlpg.
maxEpoch,batchSize = 500,100
# Session obtained from the project helper; reused by the later cells.
sess = gpu_sess()
print ("Start training...")
# Run DLPG training with the plotting/recording switches set above and
# rendering turned off.
AntEnv.train_dlpg(
    _sess=sess,
    _seed=0,
    _maxEpoch=maxEpoch,
    _batchSize=batchSize,
    _nIter4update=1e4,
    _SAVE_VID=SAVE_VID,
    _MAKE_GIF=MAKE_GIF,
    _PLOT_GRP=PLOT_GRP,
    _PLOT_EVERY=PLOT_EVERY,
    _DO_RENDER=False)

### Make Final Video

In [None]:
# Output switches for the final rollouts: write an .mp4, show a GIF,
# and plot the GRP posterior, in that order.
SAVE_VID_FINAL,MAKE_GIF_FINAL,PLOT_GRP_FINAL = True,False,True

In [None]:
# Roll out three trajectories from VAE samples and, depending on the *_FINAL
# switches, save each one as a video, show it as a GIF, and plot the GRP
# posterior it was generated from.
import os # Cell-local: only needed to create the video output folder

vidDir = 'vids'
if SAVE_VID_FINAL and not os.path.isdir(vidDir):
    # Make sure the output folder exists before any rollout is run, so a
    # finished rollout is not lost to a failed video write.
    os.makedirs(vidDir)

for _i in range(3):
    np.random.seed(seed=_i+100) # Seed NumPy's global RNG differently per rollout
    # Draw one VAE sample and view it as an (nAnchor x actDim) anchor matrix
    sampledX = AntEnv.VAE.sample(_sess=sess).reshape((AntEnv.nAnchor,AntEnv.env.actDim))
    # Min-max rescale over the whole matrix so the anchors span [0,1]
    xMin,xMax = sampledX.min(),sampledX.max()
    sampledX = (sampledX-xMin)/(xMax-xMin)
    # Condition the GRP posterior on the anchors and roll out its mean path
    AntEnv.set_anchor_grp_posterior(_anchors=sampledX,_levBtw=0.99)
    avgRwd,ret = AntEnv.unit_rollout_from_grp_mean(_maxRepeat=AntEnv.maxRepeat,_DO_RENDER=True)
    print ("  [^] avgRwd:[%.3f] Xdisp:[%.3f] Hdisp:[%.3f]"%(avgRwd,ret['xDisp'],ret['hDisp']))
    if SAVE_VID_FINAL:
        # Frames are cast to uint8 before being handed to the video writer
        outputdata = np.asarray(ret['frames']).astype(np.uint8)
        vidName = 'vids/ant_dlpg_final_%d.mp4'%(_i)
        skvideo.io.vwrite(vidName,outputdata)
        print ("[%s] saved."%(vidName))
    if MAKE_GIF_FINAL:
        display_frames_as_gif(ret['frames'],_intv_ms=20,_figsize=(6,6),_fontsize=15,
                              _titleStrs=ret['titleStrs'])
    if PLOT_GRP_FINAL:
        AntEnv.GRPposterior.plot_all(_nPath=10,_figsize=(10,4))