In [1]:
import torch
import numpy as np
import pickle

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, Product, ConstantKernel as C

import gym_sin
from gym import spaces

from utilities.arguments import get_args
from learner.posterior_multi_task import PosteriorMTAgent
from inference.inference_network import InferenceNetwork
from task.ExploreTaskGenerator import ExploreTaskGenerator
from utilities.folder_management import handle_folder_creation


In [2]:
# --- Experiment configuration ------------------------------------------------
# Seed the global NumPy / PyTorch RNGs up front. agent.train() is called with
# seed=0 further down, but the task family and the inference network are built
# before that call; without seeding here, a fresh "Restart & Run All" may not
# reproduce them (assumes they draw from the global RNGs -- TODO confirm).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

env_name = "gaussexplore-v0"

# One-dimensional continuous action bounded to [-1, 1]. The bounds and dtype are
# spelled out as float32 so the space does not rely on an implicit cast of
# integer arrays.
action_space = spaces.Box(
    low=np.array([-1.0], dtype=np.float32),
    high=np.array([1.0], dtype=np.float32),
    dtype=np.float32,
)

# Size of the latent variable; passed to both InferenceNetwork (z_dim) and
# PosteriorMTAgent (latent_dim).
latent_dim = 1

device = "cpu"

# Task-generator settings, forwarded unchanged to ExploreTaskGenerator.
# x_min / x_max are also handed to the agent as min_action / max_action.
x_min = -100
x_max = 100
noise_std = 0.001
std = 10
mean_max = 60
mean_min = 40

# Forwarded to PosteriorMTAgent as vae_min_seq / vae_max_seq.
vae_min_seq = 1
vae_max_seq = 20

# Forwarded to PosteriorMTAgent as max_old / min_old together with
# rescale_obs=True -- presumably per-dimension observation bounds used for
# rescaling; confirm against PosteriorMTAgent.
max_old = [100, 25]
min_old = [-100, 0]

obs_shape = (2,)

In [3]:
# Build the task generator from the values set in the configuration cell.
task_generator = ExploreTaskGenerator(
    x_min=x_min,
    x_max=x_max,
    noise_std=noise_std,
    std=std,
    mean_max=mean_max,
    mean_min=mean_min,
)

# Create a family of 5000 tasks in a single batch with no test split.
# The return value is kept in `f` but is not referenced by the later cells
# shown in this notebook.
f = task_generator.create_task_family(
    n_tasks=5000,
    n_batches=1,
    test_perc=0,
    batch_size=1,
)

In [4]:
# Sanity check: draw a single sample and display it as the cell output.
# NOTE(review): a recorded run of this cell showed 'mean' ~= 99.8, which lies
# outside the configured [mean_min, mean_max] = [40, 60] -- confirm whether
# sample_pair_tasks is meant to ignore those bounds.
task_generator.sample_pair_tasks(1)

([{'min_x': -100,
   'max_x': 100,
   'noise_std': 0.001,
   'std': 10,
   'mean': 99.79561614990234,
   'scale_reward': False,
   'amplitude': 1}],
 None,
 [tensor([[50.],
          [20.]])],
 tensor([[99.7956]]))

In [5]:
# Inference network whose latent size follows `latent_dim`.
# NOTE(review): n_in=4 is a bare literal; its relation to obs_shape / the
# action size is not visible here -- consider deriving it or documenting it.
vi = InferenceNetwork(
    z_dim=latent_dim,
    n_in=4,
)

# Adam optimizer over the inference network's parameters.
vi_optim = torch.optim.Adam(
    vi.parameters(),
    lr=1e-3,
)

In [7]:
# Assemble the multi-task agent. Every keyword below is forwarded as-is to
# PosteriorMTAgent; they are grouped only for readability.
agent = PosteriorMTAgent(
    # Environment interface and rollout layout.
    action_space=action_space,
    obs_shape=obs_shape,
    device=device,
    num_steps=20,
    num_processes=32,
    # Policy-optimisation settings.
    gamma=1,
    clip_param=0.1,
    ppo_epoch=4,
    num_mini_batch=8,
    value_loss_coef=0.5,
    entropy_coef=0.001,
    lr=1e-5,
    eps=1e-6,
    max_grad_norm=0.5,
    use_linear_lr_decay=False,
    use_gae=False,
    gae_lambda=0.95,
    use_proper_time_limits=False,
    # Network shape.
    latent_dim=latent_dim,
    recurrent_policy=False,
    hidden_size=8,
    use_elu=True,
    # Inference network and its optimizer (built in the previous cell).
    variational_model=vi,
    vae_optim=vi_optim,
    vae_min_seq=vae_min_seq,
    vae_max_seq=vae_max_seq,
    # Rescaling bounds for observations and actions.
    rescale_obs=True,
    max_old=max_old,
    min_old=min_old,
    max_action=x_max,
    min_action=x_min,
    # Time feature is disabled, so its rescaling options are left unset.
    use_time=False,
    rescale_time=None,
    max_time=None,
    # Remaining options.
    max_sigma=30,
    use_decay_kld=True,
    decay_kld_rate=1,
)

In [None]:
# Long training run (60000 iterations, evaluation every 100).
# NOTE(review): this cell has no execution count and its recorded output stops
# in the middle of the epoch-500 evaluation, so it appears to have been
# interrupted. The next cell calls agent.train() again on the same `agent`
# and rebinds res_eval / res_vae / test_list; on "Restart & Run All" both
# runs would execute back to back -- confirm which one is intended.
res_eval, res_vae, test_list = agent.train(
    training_iter=60000,
    env_name=env_name,
    seed=0,
    task_generator=task_generator,
    eval_interval=100,
    log_dir=".",
    use_env_obs=False,
    init_vae_steps=1,
    sw_size=10,
    num_random_task_to_eval=32,
    num_test_processes=2,
    use_data_loader=False,
    gp_list_sequences=[],
    prior_sequences=[],
    init_prior_test_sequences=[],
    verbose=True,
)

Epoch 0 MSE DIR 4124.00537109375 MSE VAR 0.2598961591720581 KLD 3.428971529006958
Epoch 0 MSE DIR 2632.0927734375 MSE VAR 0.6786916255950928 KLD 10.192231178283691
Epoch 100 MSE DIR 2387.422607421875 MSE VAR 0.06614695489406586 KLD 2.2854981422424316
Epoch 100 / 60000
Evaluation...
Evaluation using 32 tasks. Mean reward: 1.9142990624999998
Epoch 200 MSE DIR 713.0827026367188 MSE VAR 0.11949768662452698 KLD 0.9908828735351562
Epoch 200 / 60000
Evaluation...
Evaluation using 32 tasks. Mean reward: 2.16761584375
Epoch 300 MSE DIR 836.0440063476562 MSE VAR 0.0706048384308815 KLD 0.2587652802467346
Epoch 300 / 60000
Evaluation...
Evaluation using 32 tasks. Mean reward: 2.5071226875
Epoch 400 MSE DIR 915.992431640625 MSE VAR 0.0824493020772934 KLD 0.15358604490756989
Epoch 400 / 60000
Evaluation...
Evaluation using 32 tasks. Mean reward: 2.127427125
Epoch 500 MSE DIR 831.2276000976562 MSE VAR 0.06856021285057068 KLD 0.4749152362346649
Epoch 500 / 60000
Evaluation...
Evaluation using 32 tasks

In [8]:
# Training run of 10000 iterations with evaluation every 100 iterations.
# Compared with the 60000-iteration call in the preceding cell, this one also
# passes num_vae_steps=1 explicitly; all other keywords are the same.
# NOTE(review): `agent` is the object already passed through the preceding
# train() cell, so this run may start from whatever state that call left
# behind -- confirm this is intended.
res_eval, res_vae, test_list = agent.train(
    training_iter=10000,
    env_name=env_name,
    seed=0,
    task_generator=task_generator,
    eval_interval=100,
    log_dir=".",
    use_env_obs=False,
    num_vae_steps=1,
    init_vae_steps=1,
    sw_size=10,
    num_random_task_to_eval=32,
    num_test_processes=2,
    use_data_loader=False,
    gp_list_sequences=[],
    prior_sequences=[],
    init_prior_test_sequences=[],
    verbose=True,
)

Epoch 0 MSE DIR 1102.126220703125 MSE VAR 0.5106451511383057 KLD 12.363800048828125
Epoch 0 MSE DIR 61.23255157470703 MSE VAR 0.4919677674770355 KLD 16.346757888793945
Epoch 0 MSE DIR 107.28985595703125 MSE VAR 0.48426467180252075 KLD 16.97114372253418
Epoch 0 MSE DIR 87.6654281616211 MSE VAR 0.47243523597717285 KLD 23.340375900268555
Epoch 0 MSE DIR 73.08710479736328 MSE VAR 0.4773685932159424 KLD 25.774188995361328
Epoch 0 MSE DIR 99.29734802246094 MSE VAR 0.4914470911026001 KLD 22.424747467041016
Epoch 0 MSE DIR 134.47625732421875 MSE VAR 0.5051237344741821 KLD 16.921104431152344
Epoch 0 MSE DIR 45.868141174316406 MSE VAR 0.4941844046115875 KLD 19.985750198364258
Epoch 0 MSE DIR 187.98968505859375 MSE VAR 0.4844704866409302 KLD 17.062307357788086
Epoch 0 MSE DIR 183.72235107421875 MSE VAR 0.4779428243637085 KLD 21.491130828857422
Epoch 0 MSE DIR 187.09422302246094 MSE VAR 0.47421103715896606 KLD 14.130449295043945
Epoch 0 MSE DIR 324.5115966796875 MSE VAR 0.49371346831321716 KLD 17.

Epoch 0 MSE DIR 141.29693603515625 MSE VAR 0.47057539224624634 KLD 23.040828704833984
Epoch 0 MSE DIR 61.44329071044922 MSE VAR 0.4773663878440857 KLD 23.273666381835938
Epoch 0 MSE DIR 230.534912109375 MSE VAR 0.48569005727767944 KLD 11.099896430969238
Epoch 0 MSE DIR 65.67581939697266 MSE VAR 0.5192599296569824 KLD 21.584991455078125
Epoch 100 / 10000
Evaluation...
Evaluation using 32 tasks. Mean reward: 15.163866562499999
Epoch 0 MSE DIR 87.88123321533203 MSE VAR 0.50455242395401 KLD 15.208534240722656
Epoch 0 MSE DIR 40.37436294555664 MSE VAR 0.4782692790031433 KLD 19.356348037719727
Epoch 0 MSE DIR 106.21256256103516 MSE VAR 0.4635176360607147 KLD 21.250804901123047
Epoch 0 MSE DIR 101.51525115966797 MSE VAR 0.4706513583660126 KLD 19.620582580566406
Epoch 0 MSE DIR 68.88744354248047 MSE VAR 0.49311041831970215 KLD 12.761812210083008
Epoch 0 MSE DIR 65.1770248413086 MSE VAR 0.5100916624069214 KLD 16.940345764160156
Epoch 0 MSE DIR 190.8408203125 MSE VAR 0.5099077820777893 KLD 14.67

Epoch 0 MSE DIR 9.867161750793457 MSE VAR 0.48726195096969604 KLD 19.526012420654297
Epoch 0 MSE DIR 31.774131774902344 MSE VAR 0.48194006085395813 KLD 22.180946350097656
Epoch 0 MSE DIR 657.4150390625 MSE VAR 0.48657456040382385 KLD 14.701711654663086
Epoch 0 MSE DIR 45.95131301879883 MSE VAR 0.4983396530151367 KLD 18.866907119750977
Epoch 0 MSE DIR 27.07884979248047 MSE VAR 0.4901244640350342 KLD 21.86675262451172
Epoch 0 MSE DIR 24.114336013793945 MSE VAR 0.4846692383289337 KLD 17.397783279418945
Epoch 0 MSE DIR 48.213321685791016 MSE VAR 0.48692312836647034 KLD 16.805139541625977
Epoch 200 / 10000
Evaluation...
Evaluation using 32 tasks. Mean reward: 14.186022031250001
Epoch 0 MSE DIR 21.0980281829834 MSE VAR 0.4818415641784668 KLD 16.96601104736328
Epoch 0 MSE DIR 71.22007751464844 MSE VAR 0.4911867678165436 KLD 20.355897903442383
Epoch 0 MSE DIR 22.282121658325195 MSE VAR 0.4940316081047058 KLD 19.638656616210938
Epoch 0 MSE DIR 20.045873641967773 MSE VAR 0.4837679862976074 KLD 2

Epoch 0 MSE DIR 62.24097442626953 MSE VAR 0.49273622035980225 KLD 16.94883918762207
Epoch 0 MSE DIR 55.49863815307617 MSE VAR 0.4910373091697693 KLD 17.860841751098633
Epoch 0 MSE DIR 46.79127883911133 MSE VAR 0.48874518275260925 KLD 18.00894546508789
Epoch 0 MSE DIR 71.09456634521484 MSE VAR 0.4870077073574066 KLD 23.721355438232422
Epoch 0 MSE DIR 270.919921875 MSE VAR 0.46819907426834106 KLD 8.20993423461914
Epoch 0 MSE DIR 74.04637908935547 MSE VAR 0.49559450149536133 KLD 22.579376220703125
Epoch 0 MSE DIR 33.47499465942383 MSE VAR 0.497993528842926 KLD 20.056655883789062
Epoch 0 MSE DIR 58.39998245239258 MSE VAR 0.49654385447502136 KLD 20.153209686279297
Epoch 0 MSE DIR 14.626851081848145 MSE VAR 0.48243752121925354 KLD 18.788986206054688
Epoch 0 MSE DIR 89.05166625976562 MSE VAR 0.47537103295326233 KLD 18.51671600341797
Epoch 300 / 10000
Evaluation...
Evaluation using 32 tasks. Mean reward: 15.602823781249999
Epoch 0 MSE DIR 22.99711036682129 MSE VAR 0.48044097423553467 KLD 20.98

Epoch 0 MSE DIR 127.16033172607422 MSE VAR 0.5023034811019897 KLD 18.921062469482422
Epoch 0 MSE DIR 57.850711822509766 MSE VAR 0.4871797263622284 KLD 15.997712135314941
Epoch 0 MSE DIR 65.6239242553711 MSE VAR 0.4752005934715271 KLD 20.420581817626953
Epoch 0 MSE DIR 153.79180908203125 MSE VAR 0.48138466477394104 KLD 17.66363525390625
Epoch 0 MSE DIR 83.01224517822266 MSE VAR 0.49148058891296387 KLD 19.094173431396484
Epoch 0 MSE DIR 83.08248138427734 MSE VAR 0.4932485520839691 KLD 13.506078720092773
Epoch 0 MSE DIR 50.66714859008789 MSE VAR 0.4967525601387024 KLD 18.89590835571289
Epoch 0 MSE DIR 289.33349609375 MSE VAR 0.48430532217025757 KLD 15.248037338256836
Epoch 0 MSE DIR 17.967926025390625 MSE VAR 0.47943076491355896 KLD 14.554520606994629
Epoch 0 MSE DIR 180.1589813232422 MSE VAR 0.48723888397216797 KLD 18.810781478881836
Epoch 0 MSE DIR 74.10955810546875 MSE VAR 0.49077096581459045 KLD 20.516202926635742
Epoch 0 MSE DIR 82.7069320678711 MSE VAR 0.49163979291915894 KLD 21.748

Epoch 0 MSE DIR 45.32157897949219 MSE VAR 0.47182202339172363 KLD 22.10578155517578
Epoch 0 MSE DIR 81.09587860107422 MSE VAR 0.48539915680885315 KLD 19.17051124572754
Epoch 0 MSE DIR 147.21792602539062 MSE VAR 0.5084992051124573 KLD 15.346427917480469
Epoch 0 MSE DIR 315.54095458984375 MSE VAR 0.47696423530578613 KLD 11.880226135253906
Epoch 0 MSE DIR 36.197750091552734 MSE VAR 0.49747851490974426 KLD 22.3729248046875
Epoch 0 MSE DIR 133.33236694335938 MSE VAR 0.49284374713897705 KLD 21.165090560913086
Epoch 0 MSE DIR 79.88632202148438 MSE VAR 0.48031342029571533 KLD 21.899349212646484
Epoch 0 MSE DIR 132.4051055908203 MSE VAR 0.4826715886592865 KLD 22.454486846923828
Epoch 0 MSE DIR 146.26864624023438 MSE VAR 0.4840106964111328 KLD 14.625062942504883
Epoch 0 MSE DIR 59.1157112121582 MSE VAR 0.4884253144264221 KLD 16.723451614379883
Epoch 0 MSE DIR 142.42088317871094 MSE VAR 0.4942640960216522 KLD 26.267059326171875
Epoch 0 MSE DIR 81.40058135986328 MSE VAR 0.489757776260376 KLD 18.92

Epoch 0 MSE DIR 19.227025985717773 MSE VAR 0.5034778118133545 KLD 20.466096878051758
Epoch 0 MSE DIR 54.33757781982422 MSE VAR 0.5057421922683716 KLD 24.53722381591797
Epoch 0 MSE DIR 294.6529541015625 MSE VAR 0.46553075313568115 KLD 10.94767951965332
Epoch 0 MSE DIR 56.55366134643555 MSE VAR 0.4801122546195984 KLD 19.424325942993164
Epoch 0 MSE DIR 61.89160919189453 MSE VAR 0.49078258872032166 KLD 19.944652557373047
Epoch 0 MSE DIR 213.339599609375 MSE VAR 0.5011710524559021 KLD 24.481853485107422
Epoch 0 MSE DIR 36.606876373291016 MSE VAR 0.4848836362361908 KLD 18.33942413330078
Epoch 0 MSE DIR 77.98754119873047 MSE VAR 0.48162081837654114 KLD 21.027698516845703
Epoch 0 MSE DIR 102.03801727294922 MSE VAR 0.49416130781173706 KLD 14.522067070007324
Epoch 0 MSE DIR 132.9329071044922 MSE VAR 0.4831654727458954 KLD 16.73040199279785
Epoch 0 MSE DIR 218.65756225585938 MSE VAR 0.49494338035583496 KLD 14.107095718383789
Epoch 0 MSE DIR 251.8457489013672 MSE VAR 0.47656235098838806 KLD 17.466

Epoch 0 MSE DIR 198.69021606445312 MSE VAR 0.4896264672279358 KLD 17.359067916870117
Epoch 0 MSE DIR 273.14453125 MSE VAR 0.4578806161880493 KLD 10.670649528503418
Epoch 0 MSE DIR 87.82743072509766 MSE VAR 0.5116400718688965 KLD 22.120161056518555
Epoch 0 MSE DIR 63.36162567138672 MSE VAR 0.5123215913772583 KLD 15.249475479125977
Epoch 0 MSE DIR 25.898006439208984 MSE VAR 0.4899921417236328 KLD 17.71997833251953
Epoch 0 MSE DIR 20.70561981201172 MSE VAR 0.47069051861763 KLD 20.382709503173828
Epoch 0 MSE DIR 32.495933532714844 MSE VAR 0.4648597240447998 KLD 18.94539451599121
Epoch 0 MSE DIR 22.1234073638916 MSE VAR 0.48553821444511414 KLD 15.749350547790527
Epoch 0 MSE DIR 105.10142517089844 MSE VAR 0.5076412558555603 KLD 18.240299224853516
Epoch 0 MSE DIR 65.17884063720703 MSE VAR 0.5077087879180908 KLD 13.323673248291016
Epoch 0 MSE DIR 110.32081604003906 MSE VAR 0.4858185648918152 KLD 21.326574325561523
Epoch 0 MSE DIR 52.286617279052734 MSE VAR 0.4733860492706299 KLD 22.13526153564

Epoch 0 MSE DIR 40.28684997558594 MSE VAR 0.4875836968421936 KLD 18.372299194335938
Epoch 0 MSE DIR 127.25979614257812 MSE VAR 0.48851796984672546 KLD 21.86199378967285
Epoch 0 MSE DIR 9.607891082763672 MSE VAR 0.48584815859794617 KLD 24.64865493774414
Epoch 0 MSE DIR 39.144039154052734 MSE VAR 0.48714107275009155 KLD 19.958175659179688
Epoch 0 MSE DIR 18.854358673095703 MSE VAR 0.48700204491615295 KLD 26.138164520263672
Epoch 0 MSE DIR 15.077594757080078 MSE VAR 0.48884621262550354 KLD 21.89006996154785
Epoch 0 MSE DIR 49.78361892700195 MSE VAR 0.4910001754760742 KLD 16.558589935302734
Epoch 0 MSE DIR 80.0846939086914 MSE VAR 0.49492889642715454 KLD 19.684215545654297
Epoch 0 MSE DIR 19.74982452392578 MSE VAR 0.47548481822013855 KLD 21.933839797973633
Epoch 0 MSE DIR 17.208763122558594 MSE VAR 0.48023685812950134 KLD 22.790449142456055
Epoch 0 MSE DIR 18.51198959350586 MSE VAR 0.4897380471229553 KLD 19.999631881713867
Epoch 0 MSE DIR 25.085081100463867 MSE VAR 0.5002104043960571 KLD 1

Epoch 0 MSE DIR 140.37286376953125 MSE VAR 0.4883897602558136 KLD 17.097015380859375
Epoch 0 MSE DIR 117.24163055419922 MSE VAR 0.4905126690864563 KLD 15.05345630645752
Epoch 0 MSE DIR 30.1375732421875 MSE VAR 0.4685407280921936 KLD 19.26534652709961
Epoch 0 MSE DIR 32.39005661010742 MSE VAR 0.47313740849494934 KLD 22.3079776763916
Epoch 0 MSE DIR 50.05064392089844 MSE VAR 0.49694931507110596 KLD 24.142841339111328
Epoch 0 MSE DIR 200.78379821777344 MSE VAR 0.5087414383888245 KLD 26.359983444213867
Epoch 0 MSE DIR 28.805879592895508 MSE VAR 0.499141126871109 KLD 21.56205940246582
Epoch 0 MSE DIR 239.11849975585938 MSE VAR 0.4767211675643921 KLD 15.952174186706543
Epoch 0 MSE DIR 273.7214050292969 MSE VAR 0.4808153212070465 KLD 14.500499725341797
Epoch 0 MSE DIR 64.58554077148438 MSE VAR 0.4696347117424011 KLD 20.043521881103516
Epoch 0 MSE DIR 11.553781509399414 MSE VAR 0.48914051055908203 KLD 24.359403610229492
Epoch 0 MSE DIR 28.611637115478516 MSE VAR 0.5031355023384094 KLD 22.23572

Epoch 0 MSE DIR 181.41244506835938 MSE VAR 0.5019809007644653 KLD 15.073084831237793
Epoch 0 MSE DIR 112.19654846191406 MSE VAR 0.48432618379592896 KLD 18.225841522216797
Epoch 0 MSE DIR 115.8187026977539 MSE VAR 0.47220224142074585 KLD 15.433236122131348
Epoch 0 MSE DIR 75.85935974121094 MSE VAR 0.4790509343147278 KLD 18.88509750366211
Epoch 0 MSE DIR 17.778907775878906 MSE VAR 0.48655006289482117 KLD 19.595914840698242
Epoch 0 MSE DIR 34.95167541503906 MSE VAR 0.4944153428077698 KLD 16.030118942260742
Epoch 0 MSE DIR 58.10845947265625 MSE VAR 0.5024714469909668 KLD 15.426299095153809
Epoch 0 MSE DIR 56.58894729614258 MSE VAR 0.4887406826019287 KLD 19.8917179107666
Epoch 0 MSE DIR 77.37503814697266 MSE VAR 0.4769022762775421 KLD 17.638378143310547
Epoch 0 MSE DIR 49.33723068237305 MSE VAR 0.48079541325569153 KLD 18.510475158691406
Epoch 0 MSE DIR 91.4287109375 MSE VAR 0.4935775399208069 KLD 22.25686264038086
Epoch 0 MSE DIR 9.209882736206055 MSE VAR 0.4896429181098938 KLD 16.329494476

Epoch 0 MSE DIR 71.29983520507812 MSE VAR 0.47430115938186646 KLD 20.073890686035156
Epoch 0 MSE DIR 88.60379028320312 MSE VAR 0.4760999083518982 KLD 15.71755599975586
Epoch 0 MSE DIR 46.52521514892578 MSE VAR 0.49229541420936584 KLD 17.184223175048828
Epoch 0 MSE DIR 49.233036041259766 MSE VAR 0.500225841999054 KLD 17.889135360717773
Epoch 0 MSE DIR 234.2256317138672 MSE VAR 0.4964783489704132 KLD 12.20599365234375
Epoch 0 MSE DIR 228.86679077148438 MSE VAR 0.485530823469162 KLD 14.861664772033691
Epoch 0 MSE DIR 23.47625732421875 MSE VAR 0.47665008902549744 KLD 17.965497970581055
Epoch 0 MSE DIR 218.1004180908203 MSE VAR 0.48165595531463623 KLD 16.496570587158203
Epoch 0 MSE DIR 156.11924743652344 MSE VAR 0.4896770715713501 KLD 15.420536041259766
Epoch 0 MSE DIR 257.3022766113281 MSE VAR 0.4955922067165375 KLD 16.6164608001709
Epoch 0 MSE DIR 33.535213470458984 MSE VAR 0.493579626083374 KLD 19.86968231201172
Epoch 0 MSE DIR 101.61994934082031 MSE VAR 0.48444679379463196 KLD 17.033546

Epoch 0 MSE DIR 233.1103515625 MSE VAR 0.48826274275779724 KLD 11.620061874389648
Epoch 0 MSE DIR 197.61587524414062 MSE VAR 0.5180681347846985 KLD 12.681014060974121
Epoch 0 MSE DIR 256.3540344238281 MSE VAR 0.5120110511779785 KLD 9.795289993286133
Epoch 0 MSE DIR 30.864810943603516 MSE VAR 0.4854755103588104 KLD 15.444585800170898
Epoch 0 MSE DIR 105.62046813964844 MSE VAR 0.4657641649246216 KLD 14.824427604675293
Epoch 0 MSE DIR 213.7608642578125 MSE VAR 0.4736882746219635 KLD 16.014368057250977
Epoch 0 MSE DIR 199.7503662109375 MSE VAR 0.4766363501548767 KLD 22.150371551513672
Epoch 0 MSE DIR 109.83899688720703 MSE VAR 0.5044326782226562 KLD 18.716127395629883
Epoch 0 MSE DIR 120.83804321289062 MSE VAR 0.5038810968399048 KLD 14.806671142578125
Epoch 0 MSE DIR 121.06861114501953 MSE VAR 0.4922262132167816 KLD 20.217201232910156
Epoch 0 MSE DIR 119.52975463867188 MSE VAR 0.48260778188705444 KLD 20.4218692779541
Epoch 0 MSE DIR 64.78526306152344 MSE VAR 0.4728448987007141 KLD 21.10810

Epoch 0 MSE DIR 119.78511047363281 MSE VAR 0.4943385720252991 KLD 19.418170928955078
Epoch 0 MSE DIR 25.269519805908203 MSE VAR 0.4799535870552063 KLD 19.794227600097656
Epoch 0 MSE DIR 29.731666564941406 MSE VAR 0.46961063146591187 KLD 24.69777488708496
Epoch 0 MSE DIR 46.076141357421875 MSE VAR 0.48233968019485474 KLD 20.96694564819336
Epoch 0 MSE DIR 76.9797592163086 MSE VAR 0.49794304370880127 KLD 17.693098068237305
Epoch 0 MSE DIR 41.62681579589844 MSE VAR 0.5022619962692261 KLD 19.57851219177246
Epoch 0 MSE DIR 24.2376651763916 MSE VAR 0.4951440989971161 KLD 18.74045753479004
Epoch 0 MSE DIR 32.35086441040039 MSE VAR 0.48415839672088623 KLD 19.5249080657959
Epoch 0 MSE DIR 29.14082908630371 MSE VAR 0.4764924645423889 KLD 23.689655303955078
Epoch 0 MSE DIR 48.42539978027344 MSE VAR 0.4759950339794159 KLD 16.425296783447266
Epoch 0 MSE DIR 21.863521575927734 MSE VAR 0.49211549758911133 KLD 20.861949920654297
Epoch 0 MSE DIR 28.487781524658203 MSE VAR 0.49779975414276123 KLD 16.5869

Epoch 0 MSE DIR 487.8580322265625 MSE VAR 0.4793797731399536 KLD 11.905815124511719
Epoch 0 MSE DIR 215.5006866455078 MSE VAR 0.5046302676200867 KLD 19.165481567382812
Epoch 0 MSE DIR 51.360145568847656 MSE VAR 0.5139515399932861 KLD 24.26093864440918
Epoch 0 MSE DIR 22.322574615478516 MSE VAR 0.4957998991012573 KLD 17.47703742980957
Epoch 0 MSE DIR 130.57901000976562 MSE VAR 0.4646577835083008 KLD 13.684941291809082
Epoch 0 MSE DIR 123.02931213378906 MSE VAR 0.47353753447532654 KLD 19.637649536132812
Epoch 0 MSE DIR 163.96560668945312 MSE VAR 0.48204755783081055 KLD 15.67310905456543
Epoch 0 MSE DIR 109.09918975830078 MSE VAR 0.5028653144836426 KLD 20.295692443847656
Epoch 0 MSE DIR 292.42242431640625 MSE VAR 0.49828046560287476 KLD 15.132563591003418
Epoch 0 MSE DIR 176.06600952148438 MSE VAR 0.48765408992767334 KLD 15.013571739196777
Epoch 0 MSE DIR 223.5155029296875 MSE VAR 0.48510557413101196 KLD 13.853239059448242
Epoch 0 MSE DIR 34.853939056396484 MSE VAR 0.4860149621963501 KLD 

Epoch 0 MSE DIR 25.085830688476562 MSE VAR 0.5032123327255249 KLD 18.677635192871094
Epoch 0 MSE DIR 33.47722244262695 MSE VAR 0.49343085289001465 KLD 23.33433723449707
Epoch 0 MSE DIR 26.284189224243164 MSE VAR 0.4781636893749237 KLD 23.017539978027344
Epoch 0 MSE DIR 29.853042602539062 MSE VAR 0.478108286857605 KLD 25.837678909301758
Epoch 0 MSE DIR 12.793170928955078 MSE VAR 0.4856448769569397 KLD 24.33844757080078
Epoch 0 MSE DIR 39.85375213623047 MSE VAR 0.493753045797348 KLD 17.01637077331543
Epoch 0 MSE DIR 28.859201431274414 MSE VAR 0.49527600407600403 KLD 17.8420352935791
Epoch 0 MSE DIR 62.64944076538086 MSE VAR 0.48910799622535706 KLD 21.723342895507812
Epoch 0 MSE DIR 22.866451263427734 MSE VAR 0.485961377620697 KLD 17.639232635498047
Epoch 0 MSE DIR 40.13640213012695 MSE VAR 0.47946780920028687 KLD 16.692440032958984
Epoch 0 MSE DIR 35.46184539794922 MSE VAR 0.4867767095565796 KLD 17.017303466796875
Epoch 0 MSE DIR 49.88806915283203 MSE VAR 0.4882991313934326 KLD 20.690942

Epoch 0 MSE DIR 190.72935485839844 MSE VAR 0.500895619392395 KLD 16.842832565307617
Epoch 0 MSE DIR 35.22027587890625 MSE VAR 0.48344117403030396 KLD 16.3721981048584
Epoch 0 MSE DIR 47.457969665527344 MSE VAR 0.48182517290115356 KLD 19.731292724609375
Epoch 0 MSE DIR 10.183211326599121 MSE VAR 0.48372331261634827 KLD 23.712060928344727
Epoch 0 MSE DIR 71.34004211425781 MSE VAR 0.49345672130584717 KLD 18.432647705078125
Epoch 0 MSE DIR 20.924602508544922 MSE VAR 0.4903709292411804 KLD 20.45587730407715
Epoch 0 MSE DIR 15.909265518188477 MSE VAR 0.49203771352767944 KLD 22.745620727539062
Epoch 0 MSE DIR 19.329532623291016 MSE VAR 0.48324504494667053 KLD 19.80515480041504
Epoch 0 MSE DIR 8.949529647827148 MSE VAR 0.4821719527244568 KLD 21.90863037109375
Epoch 0 MSE DIR 45.9636344909668 MSE VAR 0.48428553342819214 KLD 24.410953521728516
Epoch 0 MSE DIR 58.66006088256836 MSE VAR 0.4947209060192108 KLD 20.064048767089844
Epoch 0 MSE DIR 30.388002395629883 MSE VAR 0.49789518117904663 KLD 21.

Epoch 0 MSE DIR 165.7857666015625 MSE VAR 0.4807072877883911 KLD 15.714155197143555
Epoch 0 MSE DIR 70.57417297363281 MSE VAR 0.49161332845687866 KLD 13.763221740722656
Epoch 0 MSE DIR 86.45368957519531 MSE VAR 0.49292826652526855 KLD 14.168668746948242
Epoch 0 MSE DIR 171.67706298828125 MSE VAR 0.4970052242279053 KLD 17.72986602783203
Epoch 0 MSE DIR 222.24136352539062 MSE VAR 0.4791804254055023 KLD 13.531681060791016
Epoch 0 MSE DIR 174.36166381835938 MSE VAR 0.5097532272338867 KLD 16.488433837890625
Epoch 0 MSE DIR 185.2281494140625 MSE VAR 0.4910373091697693 KLD 17.73805046081543
Epoch 0 MSE DIR 251.20249938964844 MSE VAR 0.45053625106811523 KLD 16.38677978515625
Epoch 0 MSE DIR 60.752662658691406 MSE VAR 0.455909788608551 KLD 16.836271286010742
Epoch 0 MSE DIR 63.18536376953125 MSE VAR 0.4997546076774597 KLD 19.933176040649414
Epoch 0 MSE DIR 83.73587799072266 MSE VAR 0.5291010141372681 KLD 21.284465789794922
Epoch 0 MSE DIR 43.02959060668945 MSE VAR 0.5155417919158936 KLD 23.6365

Epoch 0 MSE DIR 10.636083602905273 MSE VAR 0.4836614727973938 KLD 14.98672103881836
Epoch 0 MSE DIR 70.54531860351562 MSE VAR 0.4934287369251251 KLD 16.456336975097656
Epoch 0 MSE DIR 57.22184753417969 MSE VAR 0.4932252764701843 KLD 18.52709197998047
Epoch 0 MSE DIR 30.462488174438477 MSE VAR 0.488617479801178 KLD 19.641036987304688
Epoch 0 MSE DIR 84.27660369873047 MSE VAR 0.48676663637161255 KLD 16.438505172729492
Epoch 0 MSE DIR 232.5491485595703 MSE VAR 0.4666479825973511 KLD 15.622676849365234
Epoch 0 MSE DIR 52.84527587890625 MSE VAR 0.5013476014137268 KLD 19.384552001953125
Epoch 0 MSE DIR 63.75547409057617 MSE VAR 0.4989011287689209 KLD 21.7077579498291
Epoch 0 MSE DIR 55.02542495727539 MSE VAR 0.49091458320617676 KLD 15.63502311706543
Epoch 0 MSE DIR 265.619140625 MSE VAR 0.48454052209854126 KLD 17.511260986328125
Epoch 0 MSE DIR 43.983646392822266 MSE VAR 0.47705817222595215 KLD 20.63077163696289
Epoch 0 MSE DIR 29.06559944152832 MSE VAR 0.48057276010513306 KLD 16.88313293457

Epoch 0 MSE DIR 65.30889129638672 MSE VAR 0.4747569262981415 KLD 16.600173950195312
Epoch 0 MSE DIR 70.9288330078125 MSE VAR 0.47085967659950256 KLD 19.804996490478516
Epoch 0 MSE DIR 273.83929443359375 MSE VAR 0.4846150577068329 KLD 12.48015308380127
Epoch 0 MSE DIR 131.65065002441406 MSE VAR 0.48632121086120605 KLD 17.08696746826172
Epoch 0 MSE DIR 202.39019775390625 MSE VAR 0.4845258593559265 KLD 10.938974380493164
Epoch 0 MSE DIR 263.296630859375 MSE VAR 0.515533447265625 KLD 16.58094024658203
Epoch 0 MSE DIR 139.33888244628906 MSE VAR 0.4948576092720032 KLD 19.761442184448242
Epoch 0 MSE DIR 134.6815185546875 MSE VAR 0.4773796796798706 KLD 16.408437728881836
Epoch 0 MSE DIR 56.673683166503906 MSE VAR 0.4686531722545624 KLD 21.452760696411133
Epoch 0 MSE DIR 49.19620132446289 MSE VAR 0.48327597975730896 KLD 20.252277374267578
Epoch 0 MSE DIR 100.35322570800781 MSE VAR 0.49803784489631653 KLD 19.586130142211914
Epoch 0 MSE DIR 277.2655944824219 MSE VAR 0.48942264914512634 KLD 11.910

Epoch 0 MSE DIR 136.23037719726562 MSE VAR 0.49043816328048706 KLD 14.775948524475098
Epoch 0 MSE DIR 75.4044418334961 MSE VAR 0.502511739730835 KLD 18.011838912963867
Epoch 0 MSE DIR 11.128196716308594 MSE VAR 0.4986345171928406 KLD 15.39555835723877
Epoch 0 MSE DIR 28.119112014770508 MSE VAR 0.48652568459510803 KLD 22.689781188964844
Epoch 0 MSE DIR 33.58951950073242 MSE VAR 0.47499752044677734 KLD 20.82097625732422
Epoch 0 MSE DIR 21.655057907104492 MSE VAR 0.47699764370918274 KLD 19.755094528198242
Epoch 0 MSE DIR 53.262393951416016 MSE VAR 0.4871726930141449 KLD 15.556474685668945
Epoch 0 MSE DIR 31.76093292236328 MSE VAR 0.4960762858390808 KLD 18.369197845458984
Epoch 0 MSE DIR 32.018890380859375 MSE VAR 0.4967794418334961 KLD 13.984676361083984
Epoch 0 MSE DIR 132.863037109375 MSE VAR 0.49092212319374084 KLD 24.2818603515625
Epoch 0 MSE DIR 56.60456085205078 MSE VAR 0.4824283719062805 KLD 18.570505142211914
Epoch 0 MSE DIR 42.75009536743164 MSE VAR 0.48068925738334656 KLD 15.926

Epoch 0 MSE DIR 62.67934799194336 MSE VAR 0.5047699809074402 KLD 18.269615173339844
Epoch 0 MSE DIR 353.98602294921875 MSE VAR 0.4842996895313263 KLD 9.930245399475098
Epoch 0 MSE DIR 51.449195861816406 MSE VAR 0.5027763247489929 KLD 16.366893768310547
Epoch 0 MSE DIR 37.7125129699707 MSE VAR 0.48800942301750183 KLD 23.0653076171875
Epoch 0 MSE DIR 254.11593627929688 MSE VAR 0.47932082414627075 KLD 15.618597030639648
Epoch 0 MSE DIR 106.22830200195312 MSE VAR 0.47788211703300476 KLD 16.590070724487305
Epoch 0 MSE DIR 62.89764404296875 MSE VAR 0.48153626918792725 KLD 20.6069278717041
Epoch 0 MSE DIR 47.78356170654297 MSE VAR 0.49418479204177856 KLD 17.973556518554688
Epoch 0 MSE DIR 22.817848205566406 MSE VAR 0.4915810823440552 KLD 24.62181854248047
Epoch 0 MSE DIR 151.81715393066406 MSE VAR 0.4980581998825073 KLD 18.612211227416992
Epoch 0 MSE DIR 108.93367004394531 MSE VAR 0.4855189919471741 KLD 17.2612361907959
Epoch 0 MSE DIR 103.99579620361328 MSE VAR 0.4820929765701294 KLD 16.6206

Epoch 0 MSE DIR 26.171031951904297 MSE VAR 0.49624040722846985 KLD 19.27621078491211
Epoch 0 MSE DIR 48.47919845581055 MSE VAR 0.4819428622722626 KLD 20.17801284790039
Epoch 0 MSE DIR 170.66094970703125 MSE VAR 0.4769115149974823 KLD 25.980844497680664
Epoch 0 MSE DIR 425.219482421875 MSE VAR 0.47018948197364807 KLD 7.80620813369751
Epoch 0 MSE DIR 29.03307342529297 MSE VAR 0.5046126842498779 KLD 25.457656860351562
Epoch 0 MSE DIR 33.79882049560547 MSE VAR 0.5090874433517456 KLD 21.209423065185547
Epoch 0 MSE DIR 42.29501724243164 MSE VAR 0.48881176114082336 KLD 14.831239700317383
Epoch 0 MSE DIR 45.713069915771484 MSE VAR 0.4787399172782898 KLD 21.77016830444336
Epoch 0 MSE DIR 43.9899787902832 MSE VAR 0.47445598244667053 KLD 20.008167266845703
Epoch 0 MSE DIR 67.99532318115234 MSE VAR 0.47167956829071045 KLD 16.194169998168945
Epoch 0 MSE DIR 46.57965850830078 MSE VAR 0.4857686161994934 KLD 17.472900390625
Epoch 0 MSE DIR 36.269718170166016 MSE VAR 0.5019770264625549 KLD 18.489614486

Epoch 0 MSE DIR 29.73991584777832 MSE VAR 0.5043911337852478 KLD 19.77332878112793
Epoch 0 MSE DIR 80.36019134521484 MSE VAR 0.49245384335517883 KLD 16.277048110961914
Epoch 0 MSE DIR 401.22955322265625 MSE VAR 0.4929845929145813 KLD 12.63232421875
Epoch 0 MSE DIR 38.9374885559082 MSE VAR 0.4714098870754242 KLD 16.639259338378906
Epoch 0 MSE DIR 68.35247802734375 MSE VAR 0.4746211767196655 KLD 12.63078784942627
Epoch 0 MSE DIR 71.68399810791016 MSE VAR 0.4818432331085205 KLD 14.455817222595215
Epoch 0 MSE DIR 213.87184143066406 MSE VAR 0.49928486347198486 KLD 16.68682098388672
Epoch 0 MSE DIR 40.346405029296875 MSE VAR 0.5102865695953369 KLD 15.876436233520508
Epoch 0 MSE DIR 14.633035659790039 MSE VAR 0.48973119258880615 KLD 18.107608795166016
Epoch 0 MSE DIR 315.5648193359375 MSE VAR 0.4648666977882385 KLD 11.304473876953125
Epoch 0 MSE DIR 30.683147430419922 MSE VAR 0.4764270484447479 KLD 17.932706832885742
Epoch 0 MSE DIR 72.74365997314453 MSE VAR 0.49255624413490295 KLD 21.0949344

Epoch 0 MSE DIR 57.11321258544922 MSE VAR 0.4922392964363098 KLD 16.279422760009766
Epoch 0 MSE DIR 18.407289505004883 MSE VAR 0.4889700710773468 KLD 22.555076599121094
Epoch 0 MSE DIR 75.45045471191406 MSE VAR 0.4887810945510864 KLD 19.649641036987305
Epoch 0 MSE DIR 47.28452682495117 MSE VAR 0.4818139374256134 KLD 18.98459243774414
Epoch 0 MSE DIR 35.90883255004883 MSE VAR 0.48363250494003296 KLD 21.50166893005371
Epoch 0 MSE DIR 52.35362243652344 MSE VAR 0.48823949694633484 KLD 19.36288833618164
Epoch 0 MSE DIR 23.731826782226562 MSE VAR 0.4882471263408661 KLD 22.654029846191406
Epoch 0 MSE DIR 229.0205841064453 MSE VAR 0.49658721685409546 KLD 13.503154754638672
Epoch 0 MSE DIR 22.03561019897461 MSE VAR 0.48712897300720215 KLD 26.578821182250977
Epoch 0 MSE DIR 31.828161239624023 MSE VAR 0.4854203462600708 KLD 20.727569580078125
Epoch 0 MSE DIR 23.413061141967773 MSE VAR 0.4835151433944702 KLD 21.44118309020996
Epoch 0 MSE DIR 18.14818000793457 MSE VAR 0.4849001169204712 KLD 25.3597

Epoch 0 MSE DIR 37.26631164550781 MSE VAR 0.49576815962791443 KLD 18.923450469970703
Epoch 0 MSE DIR 54.36326217651367 MSE VAR 0.5047555565834045 KLD 13.04146957397461
Epoch 0 MSE DIR 131.37796020507812 MSE VAR 0.49434128403663635 KLD 15.285767555236816
Epoch 0 MSE DIR 137.99903869628906 MSE VAR 0.4799124002456665 KLD 13.91370677947998
Epoch 0 MSE DIR 62.26124572753906 MSE VAR 0.4769609570503235 KLD 14.654330253601074
Epoch 0 MSE DIR 190.21652221679688 MSE VAR 0.4874851107597351 KLD 14.977587699890137
Epoch 0 MSE DIR 148.94210815429688 MSE VAR 0.48860204219818115 KLD 15.01517391204834
Epoch 0 MSE DIR 281.9072570800781 MSE VAR 0.48285773396492004 KLD 12.575149536132812
Epoch 0 MSE DIR 26.977519989013672 MSE VAR 0.5052722692489624 KLD 20.77043914794922
Epoch 0 MSE DIR 19.34347915649414 MSE VAR 0.49428778886795044 KLD 20.2283935546875
Epoch 0 MSE DIR 48.007484436035156 MSE VAR 0.4802229702472687 KLD 17.946754455566406
Epoch 0 MSE DIR 87.71917724609375 MSE VAR 0.47682178020477295 KLD 15.30

Epoch 0 MSE DIR 50.38332748413086 MSE VAR 0.4941220283508301 KLD 21.221221923828125
Epoch 0 MSE DIR 62.81395721435547 MSE VAR 0.48949944972991943 KLD 22.812965393066406
Epoch 0 MSE DIR 280.6525573730469 MSE VAR 0.48741960525512695 KLD 14.43233871459961
Epoch 0 MSE DIR 130.20596313476562 MSE VAR 0.48532718420028687 KLD 13.472476959228516
Epoch 0 MSE DIR 60.01957321166992 MSE VAR 0.4815253019332886 KLD 22.459007263183594
Epoch 0 MSE DIR 122.4842758178711 MSE VAR 0.4840971529483795 KLD 16.079177856445312
Epoch 0 MSE DIR 91.38145446777344 MSE VAR 0.4978363811969757 KLD 18.46520233154297
Epoch 0 MSE DIR 91.07223510742188 MSE VAR 0.4933130145072937 KLD 20.907644271850586
Epoch 0 MSE DIR 94.76850128173828 MSE VAR 0.48886677622795105 KLD 20.06264305114746
Epoch 0 MSE DIR 225.21226501464844 MSE VAR 0.47984763979911804 KLD 19.990650177001953
Epoch 0 MSE DIR 133.1103057861328 MSE VAR 0.48564428091049194 KLD 24.081560134887695
Epoch 0 MSE DIR 22.26288414001465 MSE VAR 0.4849807322025299 KLD 24.017

Epoch 0 MSE DIR 244.2334442138672 MSE VAR 0.47549206018447876 KLD 17.98997688293457
Epoch 0 MSE DIR 348.5998840332031 MSE VAR 0.5044434070587158 KLD 12.852302551269531
Epoch 0 MSE DIR 287.45458984375 MSE VAR 0.5094374418258667 KLD 12.889674186706543
Epoch 0 MSE DIR 63.97327423095703 MSE VAR 0.4955926835536957 KLD 15.182429313659668
Epoch 0 MSE DIR 104.03915405273438 MSE VAR 0.47695186734199524 KLD 18.966588973999023
Epoch 0 MSE DIR 125.16159057617188 MSE VAR 0.4701671600341797 KLD 14.442363739013672
Epoch 0 MSE DIR 63.95769500732422 MSE VAR 0.4794623851776123 KLD 21.382238388061523
Epoch 0 MSE DIR 126.49195098876953 MSE VAR 0.4960917532444 KLD 15.252179145812988
Epoch 0 MSE DIR 66.60888671875 MSE VAR 0.5053583979606628 KLD 16.119356155395508
Epoch 0 MSE DIR 117.6066665649414 MSE VAR 0.4975047707557678 KLD 21.262916564941406
Epoch 0 MSE DIR 136.17214965820312 MSE VAR 0.47780871391296387 KLD 24.732614517211914
Epoch 0 MSE DIR 143.00477600097656 MSE VAR 0.47820568084716797 KLD 19.11358642

Epoch 0 MSE DIR 66.20268249511719 MSE VAR 0.4792574644088745 KLD 17.502925872802734
Epoch 0 MSE DIR 101.19564819335938 MSE VAR 0.47911274433135986 KLD 16.151809692382812
Epoch 0 MSE DIR 176.7066650390625 MSE VAR 0.4872366786003113 KLD 15.046976089477539
Epoch 0 MSE DIR 597.8797607421875 MSE VAR 0.4917947053909302 KLD 10.434537887573242
Epoch 0 MSE DIR 39.800880432128906 MSE VAR 0.5062151551246643 KLD 24.717926025390625
Epoch 0 MSE DIR 82.01763916015625 MSE VAR 0.4956875145435333 KLD 23.195926666259766
Epoch 0 MSE DIR 61.27955627441406 MSE VAR 0.4684419333934784 KLD 20.977746963500977
Epoch 0 MSE DIR 72.5771713256836 MSE VAR 0.4785144329071045 KLD 27.152565002441406
Epoch 0 MSE DIR 82.76850128173828 MSE VAR 0.48581334948539734 KLD 17.836687088012695
Epoch 0 MSE DIR 224.81097412109375 MSE VAR 0.49282538890838623 KLD 18.214900970458984
Epoch 0 MSE DIR 386.9143981933594 MSE VAR 0.5139948725700378 KLD 7.062069416046143
Epoch 0 MSE DIR 108.37980651855469 MSE VAR 0.477922648191452 KLD 16.8779

Epoch 0 MSE DIR 212.55477905273438 MSE VAR 0.4872341454029083 KLD 15.596334457397461
Epoch 0 MSE DIR 211.73280334472656 MSE VAR 0.46561741828918457 KLD 17.70681381225586
Epoch 0 MSE DIR 152.63941955566406 MSE VAR 0.47327008843421936 KLD 14.41152286529541
Epoch 0 MSE DIR 206.24720764160156 MSE VAR 0.4946286976337433 KLD 20.302148818969727
Epoch 0 MSE DIR 129.87159729003906 MSE VAR 0.5077182650566101 KLD 12.99826717376709
Epoch 0 MSE DIR 117.0491714477539 MSE VAR 0.5044350624084473 KLD 16.770465850830078
Epoch 0 MSE DIR 149.3950958251953 MSE VAR 0.4844638407230377 KLD 9.098769187927246
Epoch 0 MSE DIR 99.1652603149414 MSE VAR 0.4761034846305847 KLD 12.105501174926758
Epoch 0 MSE DIR 38.360328674316406 MSE VAR 0.4747827649116516 KLD 13.636770248413086
Epoch 0 MSE DIR 95.33283233642578 MSE VAR 0.4878915548324585 KLD 16.878007888793945
Epoch 0 MSE DIR 166.22967529296875 MSE VAR 0.4973873794078827 KLD 14.950218200683594
Epoch 0 MSE DIR 270.5062255859375 MSE VAR 0.5061148405075073 KLD 11.7628

Epoch 0 MSE DIR 189.78857421875 MSE VAR 0.48987632989883423 KLD 14.556461334228516
Epoch 0 MSE DIR 115.69550323486328 MSE VAR 0.49022576212882996 KLD 19.227638244628906
Epoch 0 MSE DIR 177.08352661132812 MSE VAR 0.48894354701042175 KLD 14.984594345092773
Epoch 0 MSE DIR 214.35284423828125 MSE VAR 0.4865833520889282 KLD 16.583160400390625
Epoch 0 MSE DIR 181.57257080078125 MSE VAR 0.4845717251300812 KLD 14.712023735046387
Epoch 0 MSE DIR 80.53824615478516 MSE VAR 0.4849781095981598 KLD 16.810943603515625
Epoch 0 MSE DIR 112.33812713623047 MSE VAR 0.48797985911369324 KLD 15.962019920349121
Epoch 0 MSE DIR 146.31886291503906 MSE VAR 0.4918188154697418 KLD 18.04660987854004
Epoch 0 MSE DIR 171.5209503173828 MSE VAR 0.4881912171840668 KLD 17.6558837890625
Epoch 0 MSE DIR 214.14984130859375 MSE VAR 0.4905090928077698 KLD 16.407501220703125
Epoch 0 MSE DIR 189.68771362304688 MSE VAR 0.48454976081848145 KLD 16.1245059967041
Epoch 0 MSE DIR 122.80288696289062 MSE VAR 0.48803046345710754 KLD 19.

Epoch 0 MSE DIR 394.7746276855469 MSE VAR 0.4999646842479706 KLD 6.127535820007324
Epoch 0 MSE DIR 143.9781036376953 MSE VAR 0.49166417121887207 KLD 14.193462371826172
Epoch 0 MSE DIR 49.5312614440918 MSE VAR 0.48043906688690186 KLD 19.563852310180664
Epoch 0 MSE DIR 92.01725769042969 MSE VAR 0.4796505570411682 KLD 17.303043365478516
Epoch 0 MSE DIR 20.810691833496094 MSE VAR 0.4860578775405884 KLD 24.60117530822754
Epoch 0 MSE DIR 24.491313934326172 MSE VAR 0.4953603446483612 KLD 15.359150886535645
Epoch 0 MSE DIR 59.39488983154297 MSE VAR 0.49618807435035706 KLD 19.871910095214844
Epoch 0 MSE DIR 52.17931365966797 MSE VAR 0.48500436544418335 KLD 17.3857479095459
Epoch 0 MSE DIR 166.11819458007812 MSE VAR 0.48154720664024353 KLD 15.339666366577148
Epoch 0 MSE DIR 85.11192321777344 MSE VAR 0.48483848571777344 KLD 16.84358024597168
Epoch 0 MSE DIR 48.477474212646484 MSE VAR 0.48829859495162964 KLD 21.5751895904541
Epoch 0 MSE DIR 32.77687072753906 MSE VAR 0.49112701416015625 KLD 14.3330

Epoch 0 MSE DIR 89.98662567138672 MSE VAR 0.48460882902145386 KLD 18.433021545410156
Epoch 0 MSE DIR 81.67595672607422 MSE VAR 0.48387178778648376 KLD 18.757232666015625
Epoch 0 MSE DIR 38.73862075805664 MSE VAR 0.487802654504776 KLD 21.383766174316406
Epoch 0 MSE DIR 39.17967224121094 MSE VAR 0.490060031414032 KLD 23.069751739501953
Epoch 0 MSE DIR 217.25100708007812 MSE VAR 0.49617525935173035 KLD 15.021693229675293
Epoch 0 MSE DIR 92.74906921386719 MSE VAR 0.4754589796066284 KLD 19.51173973083496
Epoch 0 MSE DIR 32.28672409057617 MSE VAR 0.48687249422073364 KLD 26.300716400146484
Epoch 0 MSE DIR 35.390560150146484 MSE VAR 0.4935140609741211 KLD 24.595853805541992
Epoch 0 MSE DIR 31.997177124023438 MSE VAR 0.49391353130340576 KLD 19.85216522216797
Epoch 0 MSE DIR 81.19160461425781 MSE VAR 0.48547494411468506 KLD 19.134490966796875
Epoch 0 MSE DIR 68.98343658447266 MSE VAR 0.4799458980560303 KLD 17.610132217407227
Epoch 0 MSE DIR 13.970151901245117 MSE VAR 0.4826270341873169 KLD 20.84

Epoch 0 MSE DIR 77.69110870361328 MSE VAR 0.47823432087898254 KLD 18.652647018432617
Epoch 3200 / 10000
Evaluation...
Evaluation using 32 tasks. Mean reward: 11.679546187500002
Epoch 0 MSE DIR 241.85769653320312 MSE VAR 0.4695931673049927 KLD 22.34691047668457
Epoch 0 MSE DIR 78.90731811523438 MSE VAR 0.47743624448776245 KLD 23.66362953186035
Epoch 0 MSE DIR 255.8307647705078 MSE VAR 0.4946187138557434 KLD 18.746946334838867
Epoch 0 MSE DIR 45.872989654541016 MSE VAR 0.5036299824714661 KLD 17.480236053466797
Epoch 0 MSE DIR 634.7990112304688 MSE VAR 0.5030899047851562 KLD 15.01127815246582
Epoch 0 MSE DIR 31.373655319213867 MSE VAR 0.47907552123069763 KLD 11.510225296020508
Epoch 0 MSE DIR 107.20407104492188 MSE VAR 0.47176432609558105 KLD 13.194710731506348
Epoch 0 MSE DIR 494.569580078125 MSE VAR 0.4871092140674591 KLD 17.458660125732422
Epoch 0 MSE DIR 247.13514709472656 MSE VAR 0.4925449788570404 KLD 20.917875289916992
Epoch 0 MSE DIR 179.1746368408203 MSE VAR 0.49287810921669006 K

Epoch 0 MSE DIR 43.56894302368164 MSE VAR 0.49064016342163086 KLD 18.07627296447754
Epoch 0 MSE DIR 29.643329620361328 MSE VAR 0.48738762736320496 KLD 21.256013870239258
Epoch 0 MSE DIR 35.061458587646484 MSE VAR 0.4871926009654999 KLD 20.01259994506836
Epoch 0 MSE DIR 31.293121337890625 MSE VAR 0.4864305853843689 KLD 17.225391387939453
Epoch 3300 / 10000
Evaluation...
Evaluation using 32 tasks. Mean reward: 15.925458062499999
Epoch 0 MSE DIR 56.06727600097656 MSE VAR 0.48849424719810486 KLD 20.919363021850586
Epoch 0 MSE DIR 10.62805461883545 MSE VAR 0.48728975653648376 KLD 18.59916114807129
Epoch 0 MSE DIR 23.689477920532227 MSE VAR 0.48634833097457886 KLD 20.848730087280273
Epoch 0 MSE DIR 7.77536153793335 MSE VAR 0.49039342999458313 KLD 18.095138549804688
Epoch 0 MSE DIR 336.82763671875 MSE VAR 0.4856868386268616 KLD 18.475067138671875
Epoch 0 MSE DIR 16.13042640686035 MSE VAR 0.4913173019886017 KLD 18.457090377807617
Epoch 0 MSE DIR 244.3730010986328 MSE VAR 0.4903998076915741 KLD

Epoch 0 MSE DIR 71.10592651367188 MSE VAR 0.48704326152801514 KLD 17.160913467407227
Epoch 0 MSE DIR 41.737449645996094 MSE VAR 0.48312312364578247 KLD 14.225652694702148
Epoch 0 MSE DIR 57.73102569580078 MSE VAR 0.4863468408584595 KLD 14.513413429260254
Epoch 0 MSE DIR 92.88526153564453 MSE VAR 0.4912530481815338 KLD 14.798735618591309
Epoch 0 MSE DIR 26.465513229370117 MSE VAR 0.49079248309135437 KLD 18.59834098815918
Epoch 0 MSE DIR 24.459266662597656 MSE VAR 0.48820269107818604 KLD 12.691387176513672
Epoch 0 MSE DIR 72.12145233154297 MSE VAR 0.48364633321762085 KLD 16.372411727905273
Epoch 3400 / 10000
Evaluation...
Evaluation using 32 tasks. Mean reward: 12.88411275
Epoch 0 MSE DIR 135.45506286621094 MSE VAR 0.4861694276332855 KLD 16.28935432434082
Epoch 0 MSE DIR 160.98402404785156 MSE VAR 0.48790061473846436 KLD 16.583589553833008
Epoch 0 MSE DIR 24.94438934326172 MSE VAR 0.49088072776794434 KLD 20.062606811523438
Epoch 0 MSE DIR 16.03910255432129 MSE VAR 0.4916076064109802 KLD 

Epoch 0 MSE DIR 22.799114227294922 MSE VAR 0.48700249195098877 KLD 16.334354400634766
Epoch 0 MSE DIR 58.36989974975586 MSE VAR 0.48427337408065796 KLD 19.233081817626953
Epoch 0 MSE DIR 23.881114959716797 MSE VAR 0.4863952100276947 KLD 16.10511016845703
Epoch 0 MSE DIR 381.136474609375 MSE VAR 0.4878505766391754 KLD 12.637229919433594
Epoch 0 MSE DIR 23.392953872680664 MSE VAR 0.49190080165863037 KLD 17.163082122802734
Epoch 0 MSE DIR 172.74993896484375 MSE VAR 0.49261799454689026 KLD 16.939695358276367
Epoch 0 MSE DIR 75.21613311767578 MSE VAR 0.4930357038974762 KLD 14.917959213256836
Epoch 0 MSE DIR 46.1317024230957 MSE VAR 0.4739533066749573 KLD 22.178260803222656
Epoch 0 MSE DIR 247.377197265625 MSE VAR 0.4964938759803772 KLD 18.10205078125
Epoch 0 MSE DIR 51.228755950927734 MSE VAR 0.47605639696121216 KLD 19.035274505615234
Epoch 0 MSE DIR 75.75636291503906 MSE VAR 0.48592835664749146 KLD 14.83336067199707
Epoch 3500 / 10000
Evaluation...
Evaluation using 32 tasks. Mean reward: 1

Epoch 0 MSE DIR 227.3913116455078 MSE VAR 0.49744540452957153 KLD 15.215030670166016
Epoch 0 MSE DIR 52.38847732543945 MSE VAR 0.49686238169670105 KLD 22.532123565673828
Epoch 0 MSE DIR 31.53994369506836 MSE VAR 0.48692217469215393 KLD 19.397043228149414
Epoch 0 MSE DIR 361.0370178222656 MSE VAR 0.4743519425392151 KLD 20.8346004486084
Epoch 0 MSE DIR 296.81085205078125 MSE VAR 0.4850369989871979 KLD 12.209278106689453
Epoch 0 MSE DIR 21.485130310058594 MSE VAR 0.4950293004512787 KLD 17.081342697143555
Epoch 0 MSE DIR 19.008243560791016 MSE VAR 0.49411502480506897 KLD 23.224525451660156
Epoch 0 MSE DIR 61.19855499267578 MSE VAR 0.4858565330505371 KLD 21.71772003173828
Epoch 0 MSE DIR 62.53136444091797 MSE VAR 0.4833485186100006 KLD 17.621334075927734
Epoch 0 MSE DIR 49.16545486450195 MSE VAR 0.48443081974983215 KLD 19.106746673583984
Epoch 0 MSE DIR 19.143146514892578 MSE VAR 0.4886527955532074 KLD 17.897201538085938
Epoch 0 MSE DIR 32.7155876159668 MSE VAR 0.4929197132587433 KLD 18.229

Epoch 0 MSE DIR 29.75576400756836 MSE VAR 0.49339383840560913 KLD 19.597007751464844
Epoch 0 MSE DIR 20.450279235839844 MSE VAR 0.4863625466823578 KLD 24.412830352783203
Epoch 0 MSE DIR 232.92926025390625 MSE VAR 0.4851354658603668 KLD 14.903364181518555
Epoch 0 MSE DIR 48.88629913330078 MSE VAR 0.4806923270225525 KLD 18.06434440612793
Epoch 0 MSE DIR 19.654464721679688 MSE VAR 0.486275315284729 KLD 20.141889572143555
Epoch 0 MSE DIR 54.6407585144043 MSE VAR 0.4909486770629883 KLD 16.580142974853516
Epoch 0 MSE DIR 74.94363403320312 MSE VAR 0.4926506280899048 KLD 20.770584106445312
Epoch 0 MSE DIR 32.430274963378906 MSE VAR 0.49357840418815613 KLD 21.188575744628906
Epoch 0 MSE DIR 92.7341079711914 MSE VAR 0.4887000322341919 KLD 21.025354385375977
Epoch 0 MSE DIR 117.94430541992188 MSE VAR 0.47942572832107544 KLD 20.7138729095459
Epoch 0 MSE DIR 11.015804290771484 MSE VAR 0.4799269437789917 KLD 26.985862731933594
Epoch 0 MSE DIR 49.39066696166992 MSE VAR 0.4909909963607788 KLD 23.96019

Epoch 0 MSE DIR 65.18696594238281 MSE VAR 0.4936406910419464 KLD 16.93975257873535
Epoch 0 MSE DIR 27.360654830932617 MSE VAR 0.49033406376838684 KLD 20.579952239990234
Epoch 0 MSE DIR 57.257835388183594 MSE VAR 0.4809683561325073 KLD 14.748924255371094
Epoch 0 MSE DIR 43.73590850830078 MSE VAR 0.48042747378349304 KLD 21.85757827758789
Epoch 0 MSE DIR 36.90272903442383 MSE VAR 0.49207642674446106 KLD 23.14450454711914
Epoch 0 MSE DIR 135.73257446289062 MSE VAR 0.4971372187137604 KLD 16.044801712036133
Epoch 0 MSE DIR 51.17812728881836 MSE VAR 0.48745524883270264 KLD 19.719833374023438
Epoch 0 MSE DIR 27.910743713378906 MSE VAR 0.48248666524887085 KLD 23.181262969970703
Epoch 0 MSE DIR 46.4608268737793 MSE VAR 0.481234610080719 KLD 14.227428436279297
Epoch 0 MSE DIR 192.34677124023438 MSE VAR 0.4898410439491272 KLD 14.70449161529541
Epoch 0 MSE DIR 66.88955688476562 MSE VAR 0.494772732257843 KLD 16.112510681152344
Epoch 0 MSE DIR 72.49478912353516 MSE VAR 0.4937962293624878 KLD 17.20683

Epoch 0 MSE DIR 38.53894805908203 MSE VAR 0.48461079597473145 KLD 21.424068450927734
Epoch 0 MSE DIR 120.11686706542969 MSE VAR 0.48972365260124207 KLD 21.393810272216797
Epoch 0 MSE DIR 20.009326934814453 MSE VAR 0.49449658393859863 KLD 21.000919342041016
Epoch 0 MSE DIR 49.76842498779297 MSE VAR 0.49235665798187256 KLD 19.147823333740234
Epoch 0 MSE DIR 168.87515258789062 MSE VAR 0.4808998703956604 KLD 18.952083587646484
Epoch 0 MSE DIR 107.23477172851562 MSE VAR 0.48107486963272095 KLD 19.7667179107666
Epoch 0 MSE DIR 270.22900390625 MSE VAR 0.47655731439590454 KLD 16.409423828125
Epoch 0 MSE DIR 98.76374816894531 MSE VAR 0.5013583898544312 KLD 17.581165313720703
Epoch 0 MSE DIR 341.1606750488281 MSE VAR 0.5030604004859924 KLD 21.215587615966797
Epoch 0 MSE DIR 154.86746215820312 MSE VAR 0.48710501194000244 KLD 13.861568450927734
Epoch 0 MSE DIR 114.57585144042969 MSE VAR 0.4775637984275818 KLD 13.942689895629883
Epoch 0 MSE DIR 35.11960220336914 MSE VAR 0.47776055335998535 KLD 19.2

KeyboardInterrupt: 

In [None]:
# Run `agent.train` for 50000 iterations. Note that a later cell in this
# notebook assigns to the same three result names (res_eval_2, res_vae_2,
# test_list_2), so whatever is stored here is replaced once that cell runs.
res_eval_2, res_vae_2, test_list_2 = agent.train(
    # --- run length / environment ---
    training_iter=50000,
    env_name=env_name,
    seed=0,
    task_generator=task_generator,
    # --- evaluation / logging ---
    eval_interval=100,
    log_dir=".",
    use_env_obs=False,
    # --- inference-network (VAE) schedule ---
    num_vae_steps=1,
    init_vae_steps=1,
    sw_size=10,
    # --- test-time settings ---
    num_random_task_to_eval=32,
    num_test_processes=2,
    use_data_loader=False,
    # No pre-recorded sequences are supplied; each list is passed empty.
    gp_list_sequences=[],
    prior_sequences=[],
    init_prior_test_sequences=[],
    verbose=True,
)

In [None]:
# Second experiment: a fresh inference network with its own optimizer and a
# new agent that uses a smaller policy learning rate (5e-6).
vi_2 = InferenceNetwork(n_in=4, z_dim=latent_dim)
# The optimizer has to be built from vi_2's parameters. Building it from
# `vi.parameters()` would make every optimizer step update the first
# network's weights instead of the network handed to agent_2 below.
vi_optim_2 = torch.optim.Adam(vi_2.parameters(), lr=1e-3)

agent_2 = PosteriorMTAgent(
    action_space=action_space,
    device=device,
    gamma=1,
    # --- rollout / PPO settings ---
    num_steps=20,
    num_processes=32,
    clip_param=0.1,
    ppo_epoch=4,
    num_mini_batch=8,
    value_loss_coef=0.5,
    entropy_coef=0.001,
    lr=0.000005,
    eps=1e-6,
    max_grad_norm=0.5,
    use_linear_lr_decay=False,
    use_gae=False,
    gae_lambda=0.95,
    use_proper_time_limits=False,
    # --- policy network ---
    obs_shape=obs_shape,
    latent_dim=latent_dim,
    recurrent_policy=False,
    hidden_size=8,
    use_elu=True,
    # --- inference network and its optimizer (the pair created above) ---
    variational_model=vi_2,
    vae_optim=vi_optim_2,
    # --- rescaling bounds ---
    rescale_obs=True,
    max_old=max_old,
    min_old=min_old,
    vae_min_seq=vae_min_seq,
    vae_max_seq=vae_max_seq,
    max_action=x_max,
    min_action=x_min,
    use_time=False,
    rescale_time=None,
    max_time=None,
    max_sigma=30,
    use_decay_kld=True,
    decay_kld_rate=1,
)

# NOTE(review): unlike the earlier `agent.train(...)` call in this notebook,
# this call does not pass `num_vae_steps`. Confirm that `train` provides a
# default for it (and that the default is the intended value).
res_eval_2, res_vae_2, test_list_2 = agent_2.train(
    training_iter=60000,
    env_name=env_name,
    seed=0,
    task_generator=task_generator,
    eval_interval=100,
    log_dir=".",
    use_env_obs=False,
    init_vae_steps=1,
    sw_size=10,
    num_random_task_to_eval=32,
    num_test_processes=2,
    use_data_loader=False,
    # No pre-recorded sequences are supplied; each list is passed empty.
    gp_list_sequences=[],
    prior_sequences=[],
    init_prior_test_sequences=[],
    verbose=True,
)