In [1]:
import gym
import torch as th
import numpy as np
import inventory_model
import pandas as pd
from evaluate import *
from ppo_evaluate import ppo_evaluate
import matplotlib.pyplot as plt


from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.env_util import make_vec_env

In [2]:
import os

# Presumably a workaround for the Intel OpenMP "libiomp5 already initialized" abort that
# can occur when torch and another MKL-linked package each load their own OpenMP runtime.
# NOTE(review): this flag is reportedly an unsupported workaround — confirm it is still
# needed in the current environment.
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'True'})

In [3]:
#wrapper for cont env with PPO, plot result for several steps
def ppo_eval_interval(p, L, t_t, n_iter, n_step, gae, learning_rate=0.0003,
                      n_envs=4, n_final_eval=50000):
    """Train PPO on the continuous inventory env, checkpointing and re-evaluating the best model.

    Training runs in chunks of ``2 * n_envs * n_step`` timesteps. After every chunk the
    current policy is scored with ``ppo_evaluate``. The negated mean is recorded as a cost.
    Whenever it is the lowest cost seen so far (ties included), the model is saved to
    ``"ppo_min_model_<p>_<L>"``. When training ends, the cost curve is plotted and the best
    checkpoint is reloaded and evaluated once more with ``n_final_eval`` iterations.

    Args:
        p: Value of ``'p'`` in the env config (presumably the shortage penalty; the holding
            cost ``'h'`` is fixed at 1 — TODO confirm against inventory_model).
        L: Value of ``'L'`` in the env config (presumably the lead time).
        t_t: Training budget; chunks keep running while cumulative timesteps <= t_t.
        n_iter: Iteration count passed to ``ppo_evaluate`` for each intermediate evaluation.
        n_step: PPO ``n_steps`` (rollout length per environment).
        gae: PPO ``gae_lambda``.
        learning_rate: PPO learning rate.
        n_envs: Number of parallel training environments (previously hard-coded to 4).
        n_final_eval: Iteration count for the final evaluation of the best checkpoint
            (previously hard-coded to 50000).

    Returns:
        ``(res_mean_arr, res_std_arr, mean_min, std_min)``. The first two are the per-chunk
        costs (negated ``ppo_evaluate`` means) and standard deviations. The last two are the
        raw (NOT negated) mean and the std that ``ppo_evaluate`` returns for the best checkpoint.

    Raises:
        RuntimeError: If no checkpoint was saved during this call (e.g. ``t_t < 0`` or every
            evaluation was NaN). Previously ``PPO.load`` would then fail, or silently load a
            stale checkpoint left by an earlier run with the same (p, L).
    """
    env_id = 'inventory_cont_config_fix_model-v0'
    cont_config = {'h': 1, 'p': p, 'L': L, 'lambda': 1, 'action': 20}
    policy_config = dict(activation_fn=th.nn.Tanh,
                         net_arch=[dict(pi=[8, 8, 8], vf=[64, 64])])
    model_path = "ppo_min_model_" + str(p) + "_" + str(L)
    # Timesteps per training chunk (presumably two PPO rollouts of n_step on each env).
    chunk_timesteps = 2 * n_envs * n_step

    cont_env = make_vec_env(env_id, n_envs=n_envs, env_kwargs=cont_config)
    env_eval = None
    try:
        print("Running PPO w/: p=", p, ", L=",L)
        cont_model = PPO(MlpPolicy, cont_env, verbose=0, gamma=1, gae_lambda=gae, n_epochs=32,
                         learning_rate=learning_rate, use_sde=False, n_steps=n_step,
                         policy_kwargs=policy_config)
        env_eval = make_vec_env(env_id, n_envs=1, env_kwargs=cont_config)

        res_mean_arr = []
        res_std_arr = []
        # Track the best cost directly instead of re-scanning the list each chunk.
        # A NaN evaluation can never become the best. With min() over a list that
        # starts with NaN, the result stays NaN and checkpointing is silently disabled.
        best_cost = float('inf')
        saved = False
        timesteps = 0
        while timesteps <= t_t:
            # NOTE(review): the "- 1" is kept from the original budget; SB3 collects whole
            # rollouts, so this presumably still trains exactly one chunk — confirm.
            cont_model.learn(total_timesteps=chunk_timesteps - 1)
            timesteps += chunk_timesteps

            res_mean, res_std = ppo_evaluate(cont_model, env_eval, n_iter)
            cost = -res_mean  # recorded and plotted as "Average cost"
            res_mean_arr.append(cost)
            res_std_arr.append(res_std)

            if cost <= best_cost:  # ties re-save, matching the original behaviour
                best_cost = cost
                cont_model.save(model_path)
                saved = True

        _, ax = plt.subplots()
        ax.plot(res_mean_arr)
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Average cost")
        ax.set_title("L="+str(L)+", p="+str(p))
        plt.show()

        if not saved:
            raise RuntimeError("no checkpoint was saved for p=" + str(p) + ", L=" + str(L)
                               + "; refusing to load a possibly stale " + model_path)

        min_model = PPO.load(model_path)
        mean_min, std_min = ppo_evaluate(min_model, env_eval, n_final_eval)
        print("p="+str(p)+"， L="+str(L)+": mean "+str(-mean_min)+", std_dev: "+str(std_min))
    finally:
        # Release the vectorized envs even if training is interrupted (e.g. KeyboardInterrupt).
        cont_env.close()
        if env_eval is not None:
            env_eval.close()

    return res_mean_arr, res_std_arr, mean_min, std_min

In [4]:
# Sweep over (p, L) pairs: train one PPO agent per pair and collect the evaluation of its
# best checkpoint into `ppo_res`.
listp = [99, 39, 9, 4, 1, 0.25]
listL = [1]
n_iter = 500
gae_lambda = 0.98
learning_rate = 0.0003

ppo_records = []
try:
    for p in listp:
        for L in listL:
            # Rollout length and training budget are derived from L for each run.
            n_step = 64 * L
            t_t = 1000 * 4 * n_step
            res_mean, res_std, mean_min, std_min = ppo_eval_interval(
                p, L, t_t, n_iter, n_step, gae_lambda, learning_rate)
            # ppo_eval_interval returns ppo_evaluate's raw mean; negate it to store a cost.
            ppo_records.append({'p': p, 'L': L, 'res_mean': -mean_min, 'res_std': std_min})
finally:
    # DataFrame.append was removed in pandas 2.0, so the table is built once from records.
    # It is built in `finally` so that results of completed runs survive an interrupted sweep.
    ppo_res = pd.DataFrame(ppo_records, columns=['p', 'L', 'res_mean', 'res_std'])

Running PPO w/: p= 99 , L= 1
mean:  -61.60724522173628
standard deviation: 4.1429070143495625
mean:  -65.43010354671162
standard deviation: 10.52831064683356
mean:  -59.93526548729782
standard deviation: 4.923238117578097
mean:  -60.64484372057183
standard deviation: 8.455298955380814
mean:  -54.41851272907199
standard deviation: 3.0920547062975663
mean:  -53.04170559876343
standard deviation: 6.249860386760652
mean:  -49.61016013756928
standard deviation: 3.575446207266965
mean:  -48.2898352184737
standard deviation: 8.217014867128286
mean:  -50.35992684871162
standard deviation: 8.161435986396615
mean:  -44.934393528509986
standard deviation: 5.255422444690076
mean:  -42.41826848475593
standard deviation: 5.101441226065303
mean:  -44.14282631023396
standard deviation: 7.0418961916401015
mean:  -27.86918228722531
standard deviation: 4.492130644054722
mean:  -27.042798164814748
standard deviation: 8.065550880778463
mean:  -22.134285351022335
standard deviation: 3.1838664263187844
mean:

mean:  -6.983425579056144
standard deviation: 0.3277664533277393
mean:  -7.103620792041719
standard deviation: 0.5323892672116894
mean:  -6.86984136604094
standard deviation: 0.5649682325229798
mean:  -7.0049884513579315
standard deviation: 0.20341739942716408
mean:  -6.6595255146414045
standard deviation: 0.49499873518295506
mean:  -6.857848854669927
standard deviation: 0.1260374794451398
mean:  -7.437669921177625
standard deviation: 0.34850465508879824
mean:  -7.358137523750216
standard deviation: 0.5737798846848939
mean:  -7.031444584965707
standard deviation: 0.4539390457559106
mean:  -7.229991945950687
standard deviation: 0.4009488232364402
mean:  -6.557199849286675
standard deviation: 0.526250593717769
mean:  -6.814221502785758
standard deviation: 0.6715336410480339
mean:  -6.634175636188873
standard deviation: 0.9288877240815858
mean:  -6.1189550342378904
standard deviation: 0.1181319145488217
mean:  -6.700543656486272
standard deviation: 0.3105716062661559
mean:  -6.74520423337

mean:  -5.992506903605909
standard deviation: 0.7503723427037453
mean:  -6.2460880819275975
standard deviation: 0.6197926032799465
mean:  -6.057183438532054
standard deviation: 0.3007338414510118
mean:  -5.741676189488918
standard deviation: 0.7167542092313609
mean:  -6.543662485968322
standard deviation: 0.6997994808149113
mean:  -6.342811105871946
standard deviation: 0.9870424982355138
mean:  -6.356624690902234
standard deviation: 0.47232352220543405
mean:  -5.88022987742275
standard deviation: 0.3200178712476764
mean:  -5.654181745234132
standard deviation: 0.3844100498113493
mean:  -6.258592036044598
standard deviation: 0.8350171158192277
mean:  -6.006814744737372
standard deviation: 0.6417974800322257
mean:  -5.792916135816998
standard deviation: 0.4693346449788002
mean:  -5.8245060395419594
standard deviation: 0.3471718575376567
mean:  -5.852687025876343
standard deviation: 0.3334203722232003
mean:  -6.319356152275205
standard deviation: 0.6224907794929427
mean:  -5.7991049274712

mean:  -6.465559857933222
standard deviation: 0.16486996123053824
mean:  -6.795507230414449
standard deviation: 0.4887401709119934
mean:  -6.726506464030221
standard deviation: 0.515658404627281
mean:  -6.943096205686778
standard deviation: 0.8037957212502996
mean:  -6.73134875520207
standard deviation: 0.430206052369456
mean:  -6.519880163811147
standard deviation: 0.41893608426914025
mean:  -6.442621116772666
standard deviation: 0.4482513191536485
mean:  -6.630745329272747
standard deviation: 0.5736308787230027
mean:  -6.789702352074533
standard deviation: 0.7133546920700725
mean:  -6.424033088665456
standard deviation: 0.6134120461052863
mean:  -6.688933206101881
standard deviation: 0.6940803707927755
mean:  -6.547278483343869
standard deviation: 0.4141348875563405
mean:  -6.867293529148119
standard deviation: 0.6780911861139319
mean:  -6.434828203308582
standard deviation: 0.4626774187899377
mean:  -6.799395734652876
standard deviation: 0.7304795821393653
mean:  -6.701262954594941


KeyboardInterrupt: 

In [None]:
# Sweep summary: one row per (p, L) with the best checkpoint's mean cost ('res_mean') and std ('res_std').
ppo_res

In [None]:
# Cost curve of the LAST (p, L) run of the sweep: `res_mean`, `p` and `L` are loop variables
# left over from the sweep cell, so this only works after that cell has run.
_, ax = plt.subplots()
ax.plot(res_mean)
ax.set(xlabel="Iteration", ylabel="Average cost", title="L=" + str(L) + ", p=" + str(p))
plt.show()

## Obsolete: the cells below are outdated and not part of the current workflow.