In [1]:
import torch
import numpy as np
import pickle

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, Product, ConstantKernel as C

import gym_sin
from gym import spaces

from utilities.arguments import get_args
from learner.posterior_multi_task import PosteriorMTAgent
from inference.inference_network import InferenceNetwork
from task.ExploitTaskGenerator import ExploitTaskGenerator
from utilities.folder_management import handle_folder_creation


In [2]:
# Environment name handed to agent.train() and evaluate().
env_name = "exploit-v0"
# Discrete action space with three actions (ids 0, 1 and 2).
action_space = spaces.Discrete(3)
# Latent size; used as z_dim of the inference network and as latent_dim of the agent.
latent_dim = 1
# Bounds handed to ExploitTaskGenerator as theta_min / theta_max.
theta_min = -0.1
theta_max = 0.1
# Noise variance handed to ExploitTaskGenerator.
noise_var = 0.01

# Device handed to the agent.
device = "cpu"


In [3]:
# Task generator built from the noise variance and theta bounds defined in the configuration cell.
task_generator = ExploitTaskGenerator(noise_var=noise_var, theta_min=theta_min, theta_max=theta_max)
# NOTE(review): `f` is not referenced by any other visible cell — confirm whether
# create_task_family() is called for a side effect on task_generator or can be removed.
f = task_generator.create_task_family(n_tasks=5000, n_batches=1, test_perc=0, batch_size=1)

In [4]:
# Forwarded to the agent as vae_min_seq / vae_max_seq
# (presumably the min/max sequence length used for VAE updates — TODO confirm).
vae_min_seq = 1
vae_max_seq = 180

# Forwarded to the agent as max_old / min_old; left unset in this notebook.
max_old = None
min_old = None

# Observation shape forwarded to the agent as obs_shape.
obs_shape = (2,)

In [5]:
# Inference network with latent size z_dim = latent_dim; passed to the agent as variational_model.
# NOTE(review): n_in=4 is hard-coded — confirm how it is derived from the network's input features.
vi = InferenceNetwork(n_in=4, z_dim=latent_dim)
# Adam optimiser over the inference network's parameters; passed to the agent as vae_optim.
vi_optim = torch.optim.Adam(vi.parameters(), lr=1e-3)

In [6]:
# Build the multi-task posterior agent from the configuration defined in the cells above.
agent = PosteriorMTAgent(
    # Environment interface and rollout size
    action_space=action_space,
    device=device,
    gamma=1,
    num_steps=180,
    num_processes=32,
    # Policy optimisation settings
    clip_param=0.2,
    ppo_epoch=4,
    num_mini_batch=8,
    value_loss_coef=0.5,
    entropy_coef=0.,
    lr=0.0001,
    eps=1e-6,
    max_grad_norm=0.5,
    use_linear_lr_decay=False,
    use_gae=False,
    gae_lambda=0.95,
    use_proper_time_limits=False,
    # Observation / latent sizes and policy network options
    obs_shape=obs_shape,
    latent_dim=latent_dim,
    recurrent_policy=False,
    hidden_size=8,
    use_elu=True,
    # Inference network and its optimiser
    variational_model=vi,
    vae_optim=vi_optim,
    # Rescaling, sequence-length and time options
    rescale_obs=False,
    max_old=max_old,
    min_old=min_old,
    vae_min_seq=vae_min_seq,
    vae_max_seq=vae_max_seq,
    max_action=None,
    min_action=None,
    use_time=False,
    rescale_time=None,
    max_time=None,
    max_sigma=10,
)

In [7]:
# First training run: 1000 iterations with an evaluation every 10 iterations
# on 32 randomly drawn tasks.
res_eval, res_vae, test_list = agent.train(
    training_iter=1000,
    env_name=env_name,
    seed=0,
    task_generator=task_generator,
    eval_interval=10,
    log_dir=".",
    use_env_obs=False,
    num_vae_steps=1,
    init_vae_steps=1,
    sw_size=10,
    num_random_task_to_eval=32,
    num_test_processes=2,
    use_data_loader=False,
    gp_list_sequences=[],
    prior_sequences=[],
    init_prior_test_sequences=[],
    verbose=True,
)

Vae step 0/1, mse 2.223270893096924, kdl 10.13418960571289, num steps 41
Vae step 0/1, mse 1.807769536972046, kdl 7.851034641265869, num steps 38
Vae step 0/1, mse 12.095327377319336, kdl 59.60596466064453, num steps 130
Vae step 0/1, mse 2.88102388381958, kdl 13.364028930664062, num steps 60
Vae step 0/1, mse 15.059921264648438, kdl 76.19908905029297, num steps 164
Vae step 0/1, mse 0.6722270250320435, kdl 1.960891604423523, num steps 17
Vae step 0/1, mse 5.787330627441406, kdl 28.66786766052246, num steps 109
Vae step 0/1, mse 0.9614312648773193, kdl 3.8893930912017822, num steps 39
Vae step 0/1, mse 10.635293006896973, kdl 53.83515930175781, num steps 171
Vae step 0/1, mse 0.6341069340705872, kdl 1.7078030109405518, num steps 13
Vae step 0/1, mse 8.301385879516602, kdl 41.73979568481445, num steps 171
Vae step 0/1, mse 4.403647422790527, kdl 22.242023468017578, num steps 134
Epoch 10 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 167.63432509375002
Vae step 0/1, mse 1.

Vae step 0/1, mse 0.0667431652545929, kdl 0.0674041360616684, num steps 102
Vae step 0/1, mse 0.06767112761735916, kdl 0.06839581578969955, num steps 102
Vae step 0/1, mse 0.054562389850616455, kdl 0.12951888144016266, num steps 113
Vae step 0/1, mse 0.06417982280254364, kdl 0.08382787555456161, num steps 105
Epoch 100 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 159.89266715624998
Vae step 0/1, mse 0.08105107396841049, kdl 0.02839094027876854, num steps 92
Vae step 0/1, mse 0.1648421734571457, kdl 0.06252983212471008, num steps 57
Vae step 0/1, mse 0.045954108238220215, kdl 0.20162232220172882, num steps 125
Vae step 0/1, mse 0.052372053265571594, kdl 0.14093858003616333, num steps 116
Vae step 0/1, mse 0.10393977165222168, kdl 0.002718971110880375, num steps 80
Vae step 0/1, mse 0.039603058248758316, kdl 0.32196611166000366, num steps 143
Vae step 0/1, mse 0.04130571708083153, kdl 0.2663552463054657, num steps 137
Vae step 0/1, mse 0.04591430351138115, kdl 0.210236936

Vae step 0/1, mse 0.05844917520880699, kdl 0.10673164576292038, num steps 134
Vae step 0/1, mse 0.17925727367401123, kdl 0.09352752566337585, num steps 43
Vae step 0/1, mse 0.10221871733665466, kdl 0.004023544490337372, num steps 86
Vae step 0/1, mse 0.21645085513591766, kdl 0.1778116226196289, num steps 29
Vae step 0/1, mse 0.10110899806022644, kdl 0.004632917232811451, num steps 86
Vae step 0/1, mse 0.07992448657751083, kdl 0.030949287116527557, num steps 104
Vae step 0/1, mse 0.049707673490047455, kdl 0.161124125123024, num steps 144
Vae step 0/1, mse 0.05997074395418167, kdl 0.09492781013250351, num steps 127
Vae step 0/1, mse 0.24317994713783264, kdl 0.24375814199447632, num steps 20
Epoch 200 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 166.36610984375
Vae step 0/1, mse 0.05432553589344025, kdl 0.13215342164039612, num steps 138
Vae step 0/1, mse 0.09108395874500275, kdl 0.011724181473255157, num steps 93
Vae step 0/1, mse 0.08225862681865692, kdl 0.02404474094510

Vae step 0/1, mse 0.13480867445468903, kdl 0.01896834187209606, num steps 36
Vae step 0/1, mse 0.16998302936553955, kdl 0.07325288653373718, num steps 11
Vae step 0/1, mse 0.14271382987499237, kdl 0.0256753358989954, num steps 21
Epoch 290 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 161.67851368750001
Vae step 0/1, mse 0.10755009949207306, kdl 0.0014162061270326376, num steps 74
Vae step 0/1, mse 0.10192103683948517, kdl 0.002832245547324419, num steps 81
Vae step 0/1, mse 0.10428964346647263, kdl 0.0020685866475105286, num steps 75
Vae step 0/1, mse 0.06453654170036316, kdl 0.08095614612102509, num steps 165
Vae step 0/1, mse 0.11309203505516052, kdl 0.003987230826169252, num steps 54
Vae step 0/1, mse 0.08000265061855316, kdl 0.030881155282258987, num steps 121
Vae step 0/1, mse 0.09588773548603058, kdl 0.006685073021799326, num steps 86
Vae step 0/1, mse 0.07739613950252533, kdl 0.03495720773935318, num steps 125
Vae step 0/1, mse 0.08421008288860321, kdl 0.02266190

Vae step 0/1, mse 0.09278276562690735, kdl 0.009024912491440773, num steps 140
Vae step 0/1, mse 0.10259881615638733, kdl 0.005921523552387953, num steps 42
Vae step 0/1, mse 0.0946509838104248, kdl 0.007494510151445866, num steps 127
Vae step 0/1, mse 0.10490148514509201, kdl 0.0013887325767427683, num steps 21
Vae step 0/1, mse 0.10174728184938431, kdl 0.002559439977630973, num steps 64
Vae step 0/1, mse 0.1184883713722229, kdl 0.004510483704507351, num steps 11
Vae step 0/1, mse 0.10482727736234665, kdl 0.0016381675377488136, num steps 20
Vae step 0/1, mse 0.09295041859149933, kdl 0.00895618461072445, num steps 153
Vae step 0/1, mse 0.09709866344928741, kdl 0.005671632010489702, num steps 106
Epoch 390 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 163.82588759375
Vae step 0/1, mse 0.08963353931903839, kdl 0.01228410005569458, num steps 174
Vae step 0/1, mse 0.10278388112783432, kdl 0.007797461003065109, num steps 31
Vae step 0/1, mse 0.09814140200614929, kdl 0.0042221

Vae step 0/1, mse 0.09266228973865509, kdl 0.009382638148963451, num steps 173
Vae step 0/1, mse 0.09124286472797394, kdl 0.011373015120625496, num steps 63
Vae step 0/1, mse 0.09378592669963837, kdl 0.007785668596625328, num steps 178
Vae step 0/1, mse 0.08972229063510895, kdl 0.012257721275091171, num steps 53
Epoch 480 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 164.94660025000002
Vae step 0/1, mse 0.09347192943096161, kdl 0.008873154409229755, num steps 118
Vae step 0/1, mse 0.09405370056629181, kdl 0.008710995316505432, num steps 139
Vae step 0/1, mse 0.0903012603521347, kdl 0.012635811232030392, num steps 49
Vae step 0/1, mse 0.09212145954370499, kdl 0.01011863723397255, num steps 96
Vae step 0/1, mse 0.0886080414056778, kdl 0.015222747810184956, num steps 33
Vae step 0/1, mse 0.08830422908067703, kdl 0.01513705588877201, num steps 35
Vae step 0/1, mse 0.09110032767057419, kdl 0.009964264929294586, num steps 67
Vae step 0/1, mse 0.09478440880775452, kdl 0.0071158

Vae step 0/1, mse 0.09102886915206909, kdl 0.010956782847642899, num steps 67
Vae step 0/1, mse 0.09214605391025543, kdl 0.009335213340818882, num steps 86
Vae step 0/1, mse 0.08718457818031311, kdl 0.015100482851266861, num steps 28
Vae step 0/1, mse 0.09351778775453568, kdl 0.007968729361891747, num steps 104
Vae step 0/1, mse 0.09358236193656921, kdl 0.00805056281387806, num steps 130
Vae step 0/1, mse 0.09281928837299347, kdl 0.008049936033785343, num steps 111
Vae step 0/1, mse 0.09374458342790604, kdl 0.007451980374753475, num steps 113
Vae step 0/1, mse 0.08880941569805145, kdl 0.013662870973348618, num steps 41
Vae step 0/1, mse 0.09496952593326569, kdl 0.007877571508288383, num steps 167
Vae step 0/1, mse 0.09369593858718872, kdl 0.007863297127187252, num steps 95
Epoch 580 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 168.15648609375
Vae step 0/1, mse 0.08794242143630981, kdl 0.01499776728451252, num steps 29
Vae step 0/1, mse 0.0951000526547432, kdl 0.00659162

Vae step 0/1, mse 0.09386757761240005, kdl 0.007302552927285433, num steps 110
Vae step 0/1, mse 0.09250850230455399, kdl 0.00843394547700882, num steps 74
Vae step 0/1, mse 0.0927872285246849, kdl 0.008200030773878098, num steps 94
Vae step 0/1, mse 0.09318996965885162, kdl 0.008071299642324448, num steps 122
Vae step 0/1, mse 0.09224028140306473, kdl 0.008680780418217182, num steps 82
Epoch 670 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 165.70060456250002
Vae step 0/1, mse 0.0926552563905716, kdl 0.008524836972355843, num steps 102
Vae step 0/1, mse 0.09270038455724716, kdl 0.008394414559006691, num steps 128
Vae step 0/1, mse 0.09163986891508102, kdl 0.010761003941297531, num steps 54
Vae step 0/1, mse 0.09235908836126328, kdl 0.008915172889828682, num steps 102
Vae step 0/1, mse 0.09272453933954239, kdl 0.008775897324085236, num steps 152
Vae step 0/1, mse 0.08961157500743866, kdl 0.012991419062018394, num steps 36
Vae step 0/1, mse 0.09190944582223892, kdl 0.0092

Evaluation using 32 tasks. Mean reward: 165.16367474999998
Vae step 0/1, mse 0.09046446532011032, kdl 0.011122297495603561, num steps 84
Vae step 0/1, mse 0.08509465306997299, kdl 0.020846806466579437, num steps 14
Vae step 0/1, mse 0.08211701363325119, kdl 0.02345835231244564, num steps 21
Vae step 0/1, mse 0.08309666812419891, kdl 0.022422831505537033, num steps 22
Vae step 0/1, mse 0.08760282397270203, kdl 0.014767132699489594, num steps 10
Vae step 0/1, mse 0.09732400625944138, kdl 0.004279213026165962, num steps 154
Vae step 0/1, mse 0.08935466408729553, kdl 0.011986801400780678, num steps 64
Vae step 0/1, mse 0.09941915422677994, kdl 0.0038283972535282373, num steps 158
Vae step 0/1, mse 0.0941338837146759, kdl 0.006965005770325661, num steps 106
Vae step 0/1, mse 0.22146889567375183, kdl 0.18516263365745544, num steps 2
Epoch 770 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 167.17538146875
Vae step 0/1, mse 0.09062887728214264, kdl 0.010608948767185211, num steps

Vae step 0/1, mse 0.08750647306442261, kdl 0.017530448734760284, num steps 31
Vae step 0/1, mse 0.21122656762599945, kdl 0.15996231138706207, num steps 2
Vae step 0/1, mse 0.09107276797294617, kdl 0.010427914559841156, num steps 140
Vae step 0/1, mse 0.09137734770774841, kdl 0.010001951828598976, num steps 121
Vae step 0/1, mse 0.0914975181221962, kdl 0.009956416673958302, num steps 104
Epoch 860 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 167.88087678125
Vae step 0/1, mse 0.08962374180555344, kdl 0.014095399528741837, num steps 52
Vae step 0/1, mse 0.08668015152215958, kdl 0.019041186198592186, num steps 26
Vae step 0/1, mse 0.09164881706237793, kdl 0.009467706084251404, num steps 95
Vae step 0/1, mse 0.08975450694561005, kdl 0.012935075908899307, num steps 51
Vae step 0/1, mse 0.09292306005954742, kdl 0.008801013231277466, num steps 127
Vae step 0/1, mse 0.3625454604625702, kdl 0.6371777653694153, num steps 1
Vae step 0/1, mse 0.09434762597084045, kdl 0.0074084317311

Evaluation using 32 tasks. Mean reward: 167.41260053125
Vae step 0/1, mse 0.08814237266778946, kdl 0.013569505885243416, num steps 43
Vae step 0/1, mse 0.09397140145301819, kdl 0.007558486424386501, num steps 167
Vae step 0/1, mse 0.09272489696741104, kdl 0.00867924652993679, num steps 115
Vae step 0/1, mse 0.08904079347848892, kdl 0.012556849978864193, num steps 48
Vae step 0/1, mse 0.08575239777565002, kdl 0.01720087230205536, num steps 27
Vae step 0/1, mse 0.09611290693283081, kdl 0.005731344223022461, num steps 177
Vae step 0/1, mse 0.09523900598287582, kdl 0.006138717755675316, num steps 169
Vae step 0/1, mse 0.08782587945461273, kdl 0.014342566020786762, num steps 36
Vae step 0/1, mse 0.09473539888858795, kdl 0.00743718771263957, num steps 109
Vae step 0/1, mse 0.09322252124547958, kdl 0.00847990345209837, num steps 89
Epoch 960 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 167.62532015624998
Vae step 0/1, mse 0.09489281475543976, kdl 0.006603818386793137, num step

In [8]:
# Second call to agent.train() on the same agent object, with the same arguments
# as the first run (including seed=0); results are stored under separate names.
# NOTE(review): presumably this continues from the weights left by the previous
# call rather than restarting — confirm against PosteriorMTAgent.train.
res_eval_2, res_vae_2, test_list_2 = agent.train(
    training_iter=1000,
    env_name=env_name,
    seed=0,
    task_generator=task_generator,
    eval_interval=10,
    log_dir=".",
    use_env_obs=False,
    num_vae_steps=1,
    init_vae_steps=1,
    sw_size=10,
    num_random_task_to_eval=32,
    num_test_processes=2,
    use_data_loader=False,
    gp_list_sequences=[],
    prior_sequences=[],
    init_prior_test_sequences=[],
    verbose=True,
)

Vae step 0/1, mse 0.0888662114739418, kdl 0.013013292104005814, num steps 39
Vae step 0/1, mse 0.08953103423118591, kdl 0.012997377663850784, num steps 45
Vae step 0/1, mse 0.09180683642625809, kdl 0.009154877625405788, num steps 75
Vae step 0/1, mse 0.09360834956169128, kdl 0.007442496716976166, num steps 100
Vae step 0/1, mse 0.09226628392934799, kdl 0.00899707991629839, num steps 71
Vae step 0/1, mse 0.09667956829071045, kdl 0.00679086335003376, num steps 136
Vae step 0/1, mse 0.09261520206928253, kdl 0.008095569908618927, num steps 71
Vae step 0/1, mse 0.08719795942306519, kdl 0.015345987863838673, num steps 29
Vae step 0/1, mse 0.09316270053386688, kdl 0.008506467565894127, num steps 72
Vae step 0/1, mse 0.0888509526848793, kdl 0.013012906536459923, num steps 36
Vae step 0/1, mse 0.09504552185535431, kdl 0.006209836341440678, num steps 116
Vae step 0/1, mse 0.09384175390005112, kdl 0.0073607442900538445, num steps 90
Epoch 10 / 1000
Evaluation...
Evaluation using 32 tasks. Mean re

Vae step 0/1, mse 0.09284500777721405, kdl 0.009152527898550034, num steps 135
Vae step 0/1, mse 0.09423569589853287, kdl 0.007245917804539204, num steps 80
Vae step 0/1, mse 0.09221474826335907, kdl 0.009673785418272018, num steps 140
Vae step 0/1, mse 0.09098441898822784, kdl 0.010331188328564167, num steps 34
Vae step 0/1, mse 0.09396465122699738, kdl 0.007525123190134764, num steps 59
Vae step 0/1, mse 0.09231701493263245, kdl 0.009133968502283096, num steps 161
Epoch 100 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 168.4969448125
Vae step 0/1, mse 0.09301028400659561, kdl 0.008170904591679573, num steps 131
Vae step 0/1, mse 0.0909842848777771, kdl 0.011550594121217728, num steps 29
Vae step 0/1, mse 0.09300248324871063, kdl 0.007929431274533272, num steps 58
Vae step 0/1, mse 0.09305823594331741, kdl 0.00796554982662201, num steps 50
Vae step 0/1, mse 0.09330467879772186, kdl 0.008006769232451916, num steps 157
Vae step 0/1, mse 0.09437324106693268, kdl 0.00797599

Vae step 0/1, mse 0.09377288073301315, kdl 0.008746584877371788, num steps 107
Epoch 190 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 169.79057628125003
Vae step 0/1, mse 0.0928582102060318, kdl 0.009071249514818192, num steps 118
Vae step 0/1, mse 0.09250848740339279, kdl 0.009462352842092514, num steps 131
Vae step 0/1, mse 0.09403109550476074, kdl 0.007395443506538868, num steps 73
Vae step 0/1, mse 0.09223363548517227, kdl 0.008825048804283142, num steps 118
Vae step 0/1, mse 0.0928167998790741, kdl 0.00826241821050644, num steps 120
Vae step 0/1, mse 0.09248094260692596, kdl 0.010029982775449753, num steps 145
Vae step 0/1, mse 0.19652782380580902, kdl 0.1253485530614853, num steps 2
Vae step 0/1, mse 0.09146025776863098, kdl 0.010708791203796864, num steps 164
Vae step 0/1, mse 0.09126434475183487, kdl 0.010323408991098404, num steps 159
Vae step 0/1, mse 0.09175246208906174, kdl 0.009515980258584023, num steps 154
Epoch 200 / 1000
Evaluation...
Evaluation using 3

Vae step 0/1, mse 0.09184892475605011, kdl 0.00897308625280857, num steps 64
Vae step 0/1, mse 0.09191319346427917, kdl 0.009214702993631363, num steps 62
Vae step 0/1, mse 0.0959053784608841, kdl 0.005522166378796101, num steps 109
Vae step 0/1, mse 0.09661233425140381, kdl 0.0049187010154128075, num steps 130
Vae step 0/1, mse 0.09840685874223709, kdl 0.0036993399262428284, num steps 149
Vae step 0/1, mse 0.08912921696901321, kdl 0.012718446552753448, num steps 39
Vae step 0/1, mse 0.0974438488483429, kdl 0.0042680855840444565, num steps 156
Epoch 290 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 171.81734515624998
Vae step 0/1, mse 0.09564206004142761, kdl 0.005882083438336849, num steps 122
Vae step 0/1, mse 0.09219945222139359, kdl 0.010191859677433968, num steps 5
Vae step 0/1, mse 0.09469529241323471, kdl 0.006468572653830051, num steps 161
Vae step 0/1, mse 0.08730186522006989, kdl 0.014939554035663605, num steps 38
Vae step 0/1, mse 0.09190303832292557, kdl 0.00

Vae step 0/1, mse 0.09359589964151382, kdl 0.007623963989317417, num steps 93
Epoch 380 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 173.63486340625002
Vae step 0/1, mse 0.09384637326002121, kdl 0.0073765055276453495, num steps 126
Vae step 0/1, mse 0.09356045722961426, kdl 0.007748846895992756, num steps 162
Vae step 0/1, mse 0.09079372137784958, kdl 0.011186166666448116, num steps 38
Vae step 0/1, mse 0.08671306073665619, kdl 0.01614338532090187, num steps 20
Vae step 0/1, mse 0.08926810324192047, kdl 0.013326922431588173, num steps 28
Vae step 0/1, mse 0.08724027127027512, kdl 0.015910020098090172, num steps 18
Vae step 0/1, mse 0.09607838094234467, kdl 0.005817989818751812, num steps 164
Vae step 0/1, mse 0.09291048347949982, kdl 0.008290143683552742, num steps 48
Vae step 0/1, mse 0.09633678197860718, kdl 0.00548488087952137, num steps 164
Vae step 0/1, mse 0.09594438225030899, kdl 0.005529666785150766, num steps 106
Epoch 390 / 1000
Evaluation...
Evaluation using 

Vae step 0/1, mse 0.09311633557081223, kdl 0.008038266561925411, num steps 168
Vae step 0/1, mse 0.09497052431106567, kdl 0.006545864045619965, num steps 94
Vae step 0/1, mse 0.09236092120409012, kdl 0.00886128842830658, num steps 6
Vae step 0/1, mse 0.09019254148006439, kdl 0.011389507912099361, num steps 22
Vae step 0/1, mse 0.09467057883739471, kdl 0.006876589730381966, num steps 130
Vae step 0/1, mse 0.09519096463918686, kdl 0.006291444879025221, num steps 70
Vae step 0/1, mse 0.0937284529209137, kdl 0.007921519689261913, num steps 170
Epoch 480 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 173.6291550625
Vae step 0/1, mse 0.09473156929016113, kdl 0.0064658522605896, num steps 71
Vae step 0/1, mse 0.09422733634710312, kdl 0.006794296205043793, num steps 71
Vae step 0/1, mse 0.0947866439819336, kdl 0.0065418812446296215, num steps 78
Vae step 0/1, mse 0.1205788105726242, kdl 0.004576263017952442, num steps 3
Vae step 0/1, mse 0.09244170039892197, kdl 0.009064078330993

Vae step 0/1, mse 0.0933905765414238, kdl 0.007920341566205025, num steps 80
Vae step 0/1, mse 0.0932876393198967, kdl 0.007975522428750992, num steps 77
Epoch 570 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 175.63050731250001
Vae step 0/1, mse 0.09248510748147964, kdl 0.008722788654267788, num steps 153
Vae step 0/1, mse 0.092506043612957, kdl 0.008640380576252937, num steps 151
Vae step 0/1, mse 0.09409791976213455, kdl 0.007248770445585251, num steps 96
Vae step 0/1, mse 0.2797214090824127, kdl 0.3550350069999695, num steps 1
Vae step 0/1, mse 0.08685412257909775, kdl 0.016622893512248993, num steps 8
Vae step 0/1, mse 0.09259527176618576, kdl 0.008806915022432804, num steps 166
Vae step 0/1, mse 0.09211002290248871, kdl 0.009206394664943218, num steps 166
Vae step 0/1, mse 0.09260832518339157, kdl 0.008409986272454262, num steps 85
Vae step 0/1, mse 0.09241445362567902, kdl 0.009024490602314472, num steps 163
Vae step 0/1, mse 0.09310460835695267, kdl 0.00846234895

Vae step 0/1, mse 0.0944250077009201, kdl 0.008316680788993835, num steps 132
Vae step 0/1, mse 0.08801800012588501, kdl 0.014086619950830936, num steps 17
Vae step 0/1, mse 0.0941130593419075, kdl 0.00690098013728857, num steps 96
Vae step 0/1, mse 0.09218583256006241, kdl 0.010194572620093822, num steps 168
Vae step 0/1, mse 0.09453977644443512, kdl 0.010929832234978676, num steps 178
Vae step 0/1, mse 0.09407258778810501, kdl 0.007618321571499109, num steps 136
Vae step 0/1, mse 0.09452784806489944, kdl 0.006745371036231518, num steps 108
Vae step 0/1, mse 0.09317038208246231, kdl 0.007983257994055748, num steps 178
Epoch 670 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 171.92571740625
Vae step 0/1, mse 0.09340817481279373, kdl 0.008711636066436768, num steps 161
Vae step 0/1, mse 0.09352192282676697, kdl 0.008578058332204819, num steps 137
Vae step 0/1, mse 0.09384459257125854, kdl 0.008245361037552357, num steps 143
Vae step 0/1, mse 0.09479810297489166, kdl 0.0064

Vae step 0/1, mse 0.0954613983631134, kdl 0.006177307106554508, num steps 55
Vae step 0/1, mse 0.093757763504982, kdl 0.007286525797098875, num steps 110
Vae step 0/1, mse 0.09480983018875122, kdl 0.006326412782073021, num steps 76
Epoch 760 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 174.9754988125
Vae step 0/1, mse 0.09001177549362183, kdl 0.012202599085867405, num steps 8
Vae step 0/1, mse 0.09306937456130981, kdl 0.008607447147369385, num steps 123
Vae step 0/1, mse 0.0936906710267067, kdl 0.007329246494919062, num steps 84
Vae step 0/1, mse 0.09341323375701904, kdl 0.00804136972874403, num steps 112
Vae step 0/1, mse 0.09035711735486984, kdl 0.01144351251423359, num steps 6
Vae step 0/1, mse 0.09152022004127502, kdl 0.010003548115491867, num steps 158
Vae step 0/1, mse 0.09455692768096924, kdl 0.00675712525844574, num steps 49
Vae step 0/1, mse 0.09494432806968689, kdl 0.006407475098967552, num steps 60
Vae step 0/1, mse 0.09456177800893784, kdl 0.0068317623808979

Vae step 0/1, mse 0.09170122444629669, kdl 0.009585618041455746, num steps 145
Vae step 0/1, mse 0.09162842482328415, kdl 0.009490638971328735, num steps 155
Vae step 0/1, mse 0.09211090952157974, kdl 0.009617852047085762, num steps 173
Vae step 0/1, mse 0.09536252915859222, kdl 0.006650217808783054, num steps 83
Vae step 0/1, mse 0.09557032585144043, kdl 0.00587182492017746, num steps 57
Vae step 0/1, mse 0.0949699729681015, kdl 0.006393800489604473, num steps 96
Vae step 0/1, mse 0.09425973147153854, kdl 0.007291561923921108, num steps 131
Vae step 0/1, mse 0.09452741593122482, kdl 0.006721158511936665, num steps 111
Epoch 860 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 175.04699290625
Vae step 0/1, mse 0.09387145936489105, kdl 0.009447576478123665, num steps 179
Vae step 0/1, mse 0.09451335668563843, kdl 0.00666788499802351, num steps 95
Vae step 0/1, mse 0.09463710337877274, kdl 0.006879640743136406, num steps 90
Vae step 0/1, mse 0.2474283128976822, kdl 0.26271846

Vae step 0/1, mse 0.0923207625746727, kdl 0.008732245303690434, num steps 157
Vae step 0/1, mse 0.09292091429233551, kdl 0.00869828276336193, num steps 34
Vae step 0/1, mse 0.09317275136709213, kdl 0.008101035840809345, num steps 130
Epoch 950 / 1000
Evaluation...
Evaluation using 32 tasks. Mean reward: 176.637393
Vae step 0/1, mse 0.09161287546157837, kdl 0.009434184059500694, num steps 170
Vae step 0/1, mse 0.09395350515842438, kdl 0.00733186723664403, num steps 71
Vae step 0/1, mse 0.09400107711553574, kdl 0.007471160497516394, num steps 44
Vae step 0/1, mse 0.08900108188390732, kdl 0.014487739652395248, num steps 9
Vae step 0/1, mse 0.24253877997398376, kdl 0.24831129610538483, num steps 1
Vae step 0/1, mse 0.09214220196008682, kdl 0.009245135821402073, num steps 34
Vae step 0/1, mse 0.09339138865470886, kdl 0.007652252912521362, num steps 72
Vae step 0/1, mse 0.09153452515602112, kdl 0.00995609164237976, num steps 118
Vae step 0/1, mse 0.09103009849786758, kdl 0.01062521431595087,

In [23]:
def evaluate(agent, num_task_to_evaluate, task_generator, log_dir, seed, use_env_obs, env_name):
    """Roll out the agent's policy on sampled tasks and report the mean episode reward.

    Tasks are evaluated in batches of ``agent.num_processes`` environments. For
    every batch the vectorised environments are rebuilt, and the policy is
    stepped until at least ``agent.num_processes`` episode results have been
    collected from the ``infos`` returned by the environments.

    Relies on the names ``reduce``, ``np``, ``torch``, ``get_vec_envs_multi_task``,
    ``augment_obs_posterior``, ``augment_obs_time`` and ``get_posterior_no_prev``
    being available in the enclosing namespace when the function is called.

    Args:
        agent: Object exposing ``num_processes``, ``gamma``, ``device``, ``envs``,
            ``latent_dim``, ``rescale_obs``, ``max_old``, ``min_old``, ``use_time``,
            ``rescale_time``, ``max_time``, ``min_action``, ``max_action``,
            ``actor_critic`` and ``vae``.
        num_task_to_evaluate: Number of tasks to evaluate; must be a multiple of
            ``agent.num_processes``.
        task_generator: Object whose ``sample_pair_tasks(n)`` returns
            ``(envs_kwargs, prev_task, prior, new_tasks)``.
        log_dir: Forwarded to ``get_vec_envs_multi_task``.
        seed: Forwarded to ``get_vec_envs_multi_task``.
        use_env_obs: Forwarded to ``augment_obs_posterior``.
        env_name: Forwarded to ``get_vec_envs_multi_task``.

    Returns:
        A tuple ``(mean_reward, action_list, r_list)``: the mean over all collected
        episode rewards, the action returned by the policy at every step, and the
        reward returned by the environments at every step (both accumulated across
        all batches, in step order).

    Raises:
        AssertionError: If ``num_task_to_evaluate`` is not a multiple of
            ``agent.num_processes``.

    Side effects:
        Replaces ``agent.envs`` and prints progress and the mean reward.
    """
    assert num_task_to_evaluate % agent.num_processes == 0

    print("Evaluation...")

    # Number of batches of agent.num_processes parallel environments.
    n_iter = num_task_to_evaluate // agent.num_processes
    r_epi_list = []   # one list of episode rewards per batch
    action_list = []  # action of every step, across all batches
    r_list = []       # reward of every step, across all batches

    for _ in range(n_iter):
        # Only envs_kwargs and prior are used below; prev_task and new_tasks are ignored.
        envs_kwargs, prev_task, prior, new_tasks = task_generator.sample_pair_tasks(agent.num_processes)
        # The previous agent.envs is handed to the factory and the result replaces it.
        agent.envs = get_vec_envs_multi_task(env_name, seed, agent.num_processes, agent.gamma, log_dir, agent.device,
                                            True, envs_kwargs, agent.envs, num_frame_stack=None)

        eval_episode_rewards = []

        # Initial observation is augmented with the prior (is_prior=True).
        obs = agent.envs.reset()
        obs = augment_obs_posterior(obs, agent.latent_dim, prior,
                                    use_env_obs, rescale_obs=agent.rescale_obs,
                                    max_old=agent.max_old, min_old=agent.min_old, is_prior=True)
        if agent.use_time:
            obs = augment_obs_time(obs=obs, time=0, rescale_time=agent.rescale_time, max_time=agent.max_time)

        # Recurrent state and masks start at zero for every process.
        eval_recurrent_hidden_states = torch.zeros(
            agent.num_processes, agent.actor_critic.recurrent_hidden_state_size, device=agent.device)
        eval_masks = torch.zeros(agent.num_processes, 1, device=agent.device)

        # False only for the first posterior update of the batch.
        use_prev_state = False
        step = 0
        # Keep stepping until enough episode results have been reported via `infos`.
        while len(eval_episode_rewards) < agent.num_processes:
            with torch.no_grad():
                # NOTE(review): deterministic=False is passed during evaluation, so actions
                # are presumably sampled rather than greedy — confirm this is intended.
                _, action, _, eval_recurrent_hidden_states = agent.actor_critic.act(
                    obs,
                    eval_recurrent_hidden_states,
                    eval_masks,
                    deterministic=False)
            action_list.append(action)

            # Observe reward and next obs
            obs, reward, done, infos = agent.envs.step(action)
            r_list.append(reward)

            # Posterior computed from the latest action/reward and the prior;
            # it augments the next observation (is_prior=False).
            posterior = get_posterior_no_prev(agent.vae, action, reward, prior,
                                              min_action=agent.min_action, max_action=agent.max_action,
                                              use_prev_state=use_prev_state)
            obs = augment_obs_posterior(obs, agent.latent_dim, posterior,
                                        use_env_obs, rescale_obs=agent.rescale_obs,
                                        max_old=agent.max_old, min_old=agent.min_old, is_prior=False)
            if agent.use_time:
                obs = augment_obs_time(obs=obs, time=step + 1, rescale_time=agent.rescale_time,
                                       max_time=agent.max_time)
            step += 1

            use_prev_state = True
            # Mask is 0.0 for processes whose `done` flag is set, 1.0 otherwise.
            eval_masks = torch.tensor(
                [[0.0] if done_ else [1.0] for done_ in done],
                dtype=torch.float32,
                device=agent.device)

            # An 'episode' entry in an info dict is taken as a finished-episode report
            # (presumably written by a monitor wrapper around the env — TODO confirm).
            for info in infos:
                if 'episode' in info.keys():
                    total_epi_reward = info['episode']['r']
                    eval_episode_rewards.append(total_epi_reward)

        r_epi_list.append(eval_episode_rewards)

    # Flatten the per-batch lists into a single list of episode rewards.
    r_epi_list = reduce(list.__add__, r_epi_list)
    print("Evaluation using {} tasks. Mean reward: {}".format(num_task_to_evaluate, np.mean(r_epi_list)))
    return np.mean(r_epi_list), action_list, r_list

In [24]:
from functools import reduce

import numpy as np
import torch

from utilities.observation_utils import augment_obs_posterior, get_posterior_no_prev, augment_obs_time
from inference.inference_utils import loss_inference_closed_form
from ppo_a2c.algo.ppo import PPO
from ppo_a2c.envs import get_vec_envs_multi_task
from ppo_a2c.model import MLPBase, Policy
from ppo_a2c.storage import RolloutStorage

# Evaluate the trained agent on 32 tasks and keep the per-step actions and rewards.
m, a_list, r_list = evaluate(
    agent=agent,
    num_task_to_evaluate=32,
    task_generator=task_generator,
    log_dir=".",
    seed=0,
    use_env_obs=False,
    env_name=env_name,
)

Evaluation...
Evaluation using 32 tasks. Mean reward: 179.38415956250003


In [25]:
# Drop dimension 1 of every per-step tensor and convert it to a plain Python list.
a_list = [step_actions.squeeze(1).tolist() for step_actions in a_list]
r_list = [step_rewards.squeeze(1).tolist() for step_rewards in r_list]

In [26]:
# Stack the per-step lists into arrays; later cells index them as [time, proc].
a_list = np.array(a_list)
r_list = np.array(r_list)

In [27]:
# Number of columns of r_list (processes) whose reward averaged over axis 0 (steps) exceeds 1.
np.sum(r_list.mean(0) > 1)

15

In [28]:
# Tally how often each of the three action ids appears in a_list (indexed as
# [time, proc]) and record every (time, proc) position where action 2 was taken.
# Vectorised comparisons replace the nested Python loops; int() keeps the
# counters plain Python ints.
count_0 = int(np.count_nonzero(a_list == 0))
count_1 = int(np.count_nonzero(a_list == 1))
count_2 = int(np.count_nonzero(a_list == 2))

# np.argwhere returns matching indices in row-major order, i.e. sorted by time
# and then by process, which is the order a nested time/proc loop would produce.
# .tolist() yields plain Python ints inside each (time, proc) tuple.
info = [tuple(position) for position in np.argwhere(a_list == 2).tolist()]

print(count_2)

60


In [29]:
# Display the (time, proc) positions at which action 2 was selected.
info

[(0, 7),
 (2, 19),
 (5, 30),
 (8, 14),
 (11, 12),
 (13, 1),
 (26, 24),
 (31, 2),
 (31, 8),
 (32, 23),
 (34, 7),
 (38, 3),
 (40, 16),
 (49, 1),
 (52, 15),
 (52, 29),
 (54, 8),
 (55, 4),
 (55, 28),
 (56, 3),
 (58, 28),
 (62, 4),
 (68, 31),
 (75, 25),
 (80, 2),
 (80, 22),
 (83, 0),
 (84, 15),
 (84, 18),
 (87, 26),
 (89, 19),
 (98, 4),
 (100, 12),
 (101, 22),
 (105, 22),
 (108, 6),
 (109, 17),
 (112, 11),
 (113, 10),
 (114, 18),
 (125, 13),
 (129, 21),
 (130, 15),
 (133, 10),
 (135, 20),
 (135, 28),
 (138, 27),
 (139, 19),
 (139, 26),
 (140, 8),
 (141, 6),
 (141, 10),
 (142, 27),
 (146, 18),
 (155, 28),
 (156, 14),
 (156, 30),
 (161, 6),
 (164, 18),
 (174, 20)]

In [20]:
def fake_step(action, noise_var=0.01, theta=0.1):
    """Simulate a single noisy reward for ``action``.

    The noise-free reward is ``1 + theta / 10`` for action 0 and
    ``1 - theta / 10`` for action 1. Any other action yields 0.7 when
    ``theta > 0`` and 0.3 otherwise. Zero-mean Gaussian noise with variance
    ``noise_var`` is drawn from NumPy's global RNG and added to that value.

    Args:
        action: Action id; compared against 0 and 1.
        noise_var: Variance of the additive Gaussian noise.
        theta: Task parameter that shifts the rewards of actions 0 and 1 and
            selects the reward level of the remaining action.

    Returns:
        The noisy reward as a scalar (first element of the length-1 sample).
    """
    # Draw exactly one length-1 sample per call, before branching on the action.
    noise = np.random.normal(loc=0, scale=np.sqrt(noise_var), size=1)

    if action == 0:
        mean_reward = 1 + theta / 10
    elif action == 1:
        mean_reward = 1 - theta / 10
    elif theta > 0:
        mean_reward = 0.7
    else:
        mean_reward = 0.3

    return (mean_reward + noise)[0]

In [21]:
# Baseline: total simulated reward over 150 steps when action 0 or 1 is picked
# with probability 0.5 each, using fake_step's defaults (noise_var=0.01, theta=0.1).
# No seed is set, so the total changes from run to run.
r = 0
for _ in range(150):
    a = np.random.binomial(n=1, p=0.5)
    r += fake_step(a)
r

149.12589351704588

10

In [7]:
# Scratch check: one draw from a zero-mean normal with standard deviation sqrt(0.1).
# NOTE(review): variance 0.1 here differs from noise_var = 0.01 in the configuration cell — confirm which is intended.
np.random.normal(loc=0, scale=np.sqrt(0.1), size=1)
            

array([0.30240772])

In [12]:
# Scratch check (repeat): another draw from a zero-mean normal with standard deviation sqrt(0.1).
np.random.normal(loc=0, scale=np.sqrt(0.1), size=1)
            

array([-0.23285137])