### Solving the KS model with DeepHAM (code on Nuvolos)

In [None]:
# Run settings, hard-coded here rather than parsed from absl flags.
seed_index = 0  # position in the config's "random_seed" list used for this run
exp_name = "1fm1"  # experiment tag; becomes the suffix of the results directory name
config_path = "./configs/KS/game_nn_n50.json"  # JSON experiment configuration loaded below

In [None]:
import os
# Switch the working directory to the project's source folder so that the
# relative paths used later ("./configs/...", "../data/...") resolve from there.
# NOTE(review): this absolute path is specific to one workspace layout;
# adjust it when running the notebook anywhere else.
src_dir = '/files/day2/Yang/code/DeepHAM_nuvolos/src'
os.chdir(src_dir)
# Echo the resulting working directory as the cell output.
os.getcwd()

In [None]:
# Imports from the original script
import json
import time
import datetime
from param import KSParam
from dataset import KSInitDataSet
from value import ValueTrainer
from policy import KSPolicyTrainer
from util import print_elapsedtime
from util import set_random_seed

In [None]:
# Load the configuration from the JSON file
with open(config_path, 'r') as f:
    config = json.load(f)

if "random_seed" in config:
    seed = config["random_seed"][seed_index]
    set_random_seed(seed)
    print(f"Using seed {seed} (index {seed_index})")

print("Solving the problem based on the config path {}".format(config_path))

Using seed 996 (index 0)
Solving the problem based on the config path ./configs/KS/game_nn_n50.json


In [None]:
# Model parameters are built from the agent count, beta and the path to the
# stored matrices, all taken from the loaded config.
mparam = KSParam(config["n_agt"], config["beta"], config["mats_path"])

# The results directory name encodes, in order: the optimisation type
# ("game", or "sp" for anything else), the value-sampling scheme, the number
# of agents and the experiment tag.
opt_tag = "game" if config["policy_config"]["opt_type"] == "game" else "sp"
model_path = "../data/simul_results/KS/{}_{}_n{}_{}".format(
    opt_tag,
    config["dataset_config"]["value_sampling"],
    config["n_agt"],
    exp_name,
)

# Record where the results go and when the run started inside the config itself.
config.update({
    "model_path": model_path,
    "current_time": datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
})

# Write a snapshot of the config before training starts, for later checking.
os.makedirs(model_path, exist_ok=True)
snapshot_file = os.path.join(model_path, "config_beg.json")
with open(snapshot_file, 'w') as f:
    json.dump(config, f)

In [None]:
# Reference point for the run's elapsed time, which is printed once the
# models have been saved.
start_time = time.monotonic()

# ---- Stage 1: initial value training ----
init_ds = KSInitDataSet(mparam, config)
value_config = config["value_config"]

# Choose the policy used to build the first value dataset: the benchmark
# capital policy (tagged "pde") when `init_with_bchmk` is set, otherwise the
# constant-share consumption policy (tagged "nn_share").
init_policy, policy_type = (
    (init_ds.k_policy_bchmk, "pde")
    if config["init_with_bchmk"]
    else (init_ds.c_policy_const_share, "nn_share")
)

train_vds, valid_vds = init_ds.get_valuedataset(init_policy, policy_type, update_init=False)

# Create `num_vnet` value trainers. The shared config dict is stamped with the
# index of the network being created right before each trainer is built.
vtrainers = []
for net_idx in range(value_config["num_vnet"]):
    config["vnet_idx"] = str(net_idx)
    vtrainers.append(ValueTrainer(config))

# Fit every value trainer on the same train/validation datasets.
n_epoch, v_batch = value_config["num_epoch"], value_config["batch_size"]
for trainer in vtrainers:
    trainer.train(train_vds, valid_vds, n_epoch, v_batch)

# ---- Stage 2: iterative policy and value training ----
policy_config = config["policy_config"]
ptrainer = KSPolicyTrainer(vtrainers, init_ds)
ptrainer.train(policy_config["num_step"], policy_config["batch_size"])

Average of total utility 20.070721.
The dataset has 4608 samples in total.
Epoch: 0, validation loss: 0.851127
Epoch: 20, validation loss: 0.492391
Epoch: 40, validation loss: 0.493855
Epoch: 60, validation loss: 0.486579
Epoch: 80, validation loss: 0.491024
Epoch: 100, validation loss: 0.483124
Epoch: 120, validation loss: 0.492418
Epoch: 140, validation loss: 0.493509
Epoch: 160, validation loss: 0.491173
Epoch: 180, validation loss: 0.493959
Epoch: 200, validation loss: 0.477271
Epoch: 0, validation loss: 0.838573
Epoch: 20, validation loss: 0.494163
Epoch: 40, validation loss: 0.49039
Epoch: 60, validation loss: 0.488238
Epoch: 80, validation loss: 0.486452
Epoch: 100, validation loss: 0.489583
Epoch: 120, validation loss: 0.484199
Epoch: 140, validation loss: 0.490537
Epoch: 160, validation loss: 0.488818
Epoch: 180, validation loss: 0.48991
Epoch: 200, validation loss: 0.479048
Epoch: 0, validation loss: 1.37506
Epoch: 20, validation loss: 0.500991
Epoch: 40, validation loss: 0.4

100%|██████████| 500/500 [06:15<00:00,  1.33it/s]


Step: 500, valid util: 56.9754, k_end: 17.1912


100%|██████████| 500/500 [03:26<00:00,  2.42it/s]


Step: 1000, valid util: 87.7788, k_end: 36.0737


100%|██████████| 500/500 [03:25<00:00,  2.43it/s]


Step: 1500, valid util: 88.3609, k_end: 34.9668


100%|██████████| 500/500 [03:25<00:00,  2.43it/s]


Step: 2000, valid util: 88.4709, k_end: 34.6618


  0%|          | 0/500 [00:00<?, ?it/s]

Average of total utility 20.070721.
The dataset has 4608 samples in total.
Epoch: 0, validation loss: 186.807
Epoch: 20, validation loss: 95.9493
Epoch: 40, validation loss: 57.5016
Epoch: 60, validation loss: 35.2899
Epoch: 80, validation loss: 19.1556
Epoch: 100, validation loss: 7.98536
Epoch: 120, validation loss: 1.91826
Epoch: 140, validation loss: 0.783189
Epoch: 160, validation loss: 0.779965
Epoch: 180, validation loss: 0.76577
Epoch: 200, validation loss: 0.745737
Epoch: 0, validation loss: 187.053
Epoch: 20, validation loss: 104.368
Epoch: 40, validation loss: 64.2171
Epoch: 60, validation loss: 40.1398
Epoch: 80, validation loss: 22.622
Epoch: 100, validation loss: 10.331
Epoch: 120, validation loss: 3.02649
Epoch: 140, validation loss: 0.764497
Epoch: 160, validation loss: 0.780034
Epoch: 180, validation loss: 0.757978
Epoch: 200, validation loss: 0.759521
Epoch: 0, validation loss: 185.717
Epoch: 20, validation loss: 92.1938
Epoch: 40, validation loss: 52.029
Epoch: 60, v

  0%|          | 1/500 [00:57<8:02:09, 57.98s/it]

Epoch: 200, validation loss: 0.539093
{'current': 61181696, 'peak': 1513204480}


100%|██████████| 500/500 [04:28<00:00,  1.86it/s]


Step: 2500, valid util: 104.214, k_end: 32.5469


100%|██████████| 500/500 [03:27<00:00,  2.41it/s]


Step: 3000, valid util: 104.231, k_end: 32.4812


100%|██████████| 500/500 [03:30<00:00,  2.38it/s]


Step: 3500, valid util: 104.318, k_end: 32.6791


100%|██████████| 500/500 [03:30<00:00,  2.38it/s]


Step: 4000, valid util: 104.334, k_end: 32.4697


  0%|          | 0/500 [00:00<?, ?it/s]

Average of total utility 20.070721.
The dataset has 4608 samples in total.
Epoch: 0, validation loss: 0.573441
Epoch: 20, validation loss: 0.448515
Epoch: 40, validation loss: 0.333241
Epoch: 60, validation loss: 0.321577
Epoch: 80, validation loss: 0.321664
Epoch: 100, validation loss: 0.322814
Epoch: 120, validation loss: 0.316713
Epoch: 140, validation loss: 0.315873
Epoch: 160, validation loss: 0.325676
Epoch: 180, validation loss: 0.321788
Epoch: 200, validation loss: 0.323633
Epoch: 0, validation loss: 0.59317
Epoch: 20, validation loss: 0.527919
Epoch: 40, validation loss: 0.377336
Epoch: 60, validation loss: 0.322659
Epoch: 80, validation loss: 0.320294
Epoch: 100, validation loss: 0.321874
Epoch: 120, validation loss: 0.318747
Epoch: 140, validation loss: 0.318832
Epoch: 160, validation loss: 0.314218
Epoch: 180, validation loss: 0.318597
Epoch: 200, validation loss: 0.316636
Epoch: 0, validation loss: 0.481291
Epoch: 20, validation loss: 0.348955
Epoch: 40, validation loss: 0

  0%|          | 1/500 [00:58<8:03:21, 58.12s/it]

Epoch: 200, validation loss: 0.310321
{'current': 56378880, 'peak': 1530509568}


100%|██████████| 500/500 [04:26<00:00,  1.87it/s]


Step: 4500, valid util: 104.082, k_end: 40.3925


100%|██████████| 500/500 [03:28<00:00,  2.39it/s]


Step: 5000, valid util: 104.101, k_end: 39.8909


100%|██████████| 500/500 [03:25<00:00,  2.43it/s]


Step: 5500, valid util: 104.104, k_end: 39.896


100%|██████████| 500/500 [03:29<00:00,  2.39it/s]


Step: 6000, valid util: 104.096, k_end: 40.1626


  0%|          | 0/500 [00:00<?, ?it/s]

Average of total utility 20.070721.
The dataset has 4608 samples in total.
Epoch: 0, validation loss: 0.365925
Epoch: 20, validation loss: 0.236676
Epoch: 40, validation loss: 0.228967
Epoch: 60, validation loss: 0.229102
Epoch: 80, validation loss: 0.23387
Epoch: 100, validation loss: 0.229223
Epoch: 120, validation loss: 0.234386
Epoch: 140, validation loss: 0.227782
Epoch: 160, validation loss: 0.225267
Epoch: 180, validation loss: 0.227703
Epoch: 200, validation loss: 0.229515
Epoch: 0, validation loss: 0.346858
Epoch: 20, validation loss: 0.229901
Epoch: 40, validation loss: 0.233521
Epoch: 60, validation loss: 0.230503
Epoch: 80, validation loss: 0.229877
Epoch: 100, validation loss: 0.231319
Epoch: 120, validation loss: 0.232255
Epoch: 140, validation loss: 0.226353
Epoch: 160, validation loss: 0.228707
Epoch: 180, validation loss: 0.225967
Epoch: 200, validation loss: 0.23268
Epoch: 0, validation loss: 0.36822
Epoch: 20, validation loss: 0.23493
Epoch: 40, validation loss: 0.23

  0%|          | 1/500 [00:58<8:04:15, 58.23s/it]

Epoch: 200, validation loss: 0.223517
{'current': 59098112, 'peak': 1530509568}


100%|██████████| 500/500 [04:28<00:00,  1.87it/s]


Step: 6500, valid util: 104.632, k_end: 38.8698


  0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
# Save config and models
with open(os.path.join(model_path, "config.json"), 'w') as f:
    json.dump(config, f)

for i, vtr in enumerate(vtrainers):
    vtr.save_model(os.path.join(model_path, "value{}.weights.h5".format(i)))

ptrainer.save_model(os.path.join(model_path, "policy.weights.h5"))

# for i, vtr in enumerate(vtrainers):
#     vtr.save_model(os.path.join(model_path, "value{}.h5".format(i)))

# ptrainer.save_model(os.path.join(model_path, "policy.h5"))

end_time = time.monotonic()
print_elapsedtime(end_time - start_time)

In [None]:
# Display the directory that the config files and model weights were written to.
model_path

In [None]:
# elapsed_time = end_time - start_time

# # Calculate and format elapsed time
# hours, rem = divmod(elapsed_time, 3600)
# minutes, seconds = divmod(rem, 60)
# formatted_time = f"{int(hours):02}:{int(minutes):02}:{seconds:05.2f}"

# # Print and save elapsed time to file
# elapsed_message = f"Solving the problem based on the config path {config_path} takes {formatted_time}"
# print(elapsed_message)

# with open(os.path.join(model_path, "time.txt"), 'w') as time_file:
#     time_file.write(elapsed_message)