### DeepHAM for the Krusell–Smith (KS) model: policy-function optimization with blanks (code on Nuvolos)

## Instructions

> In DeepHAM, the policy function is trained with a dedicated class.
> You will now **implement this class directly** in the notebook instead of importing it.
>
> We give you the skeleton (`PolicyTrainer` and `KSPolicyTrainer`) with some **TODO blanks**:
>
> * Write the **objective** (`loss`) for Krusell–Smith.
> * Use **`stop_gradient`** for fictitious play (game case).
> * Implement the **household dynamics** using the budget constraint, and update wealth and next-period capital.


## Change to the code directory on Nuvolos

In [None]:
import os
os.chdir('/files/day2/Yang/code/DeepHAM_nuvolos/src')
os.getcwd()

#### code on local machine starts here

In [11]:
# Define the configurations directly instead of using absl flags
config_path = "./configs/KS/game_nn_n50_0fm1gm_test.json"
exp_name = "1gm_test"
seed_index = 3

### Import everything except the policy class to be written

In [12]:
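# Imports from the original script
import os                 # needed below even if the Nuvolos cell above was skipped
import json
import time
import datetime
import numpy as np        # used by KSPolicyTrainer.current_c_policy
import tensorflow as tf   # used by KSPolicyTrainer.loss
from param import KSParam
from dataset import KSInitDataSet
from value import ValueTrainer
from util import print_elapsedtime
from util import set_random_seed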

In [13]:
# Load the configuration from the JSON file
with open(config_path, 'r') as f:
    config = json.load(f)

if "random_seed" in config:
    seed = config["random_seed"][seed_index]
    set_random_seed(seed)
    print(f"Using seed {seed} (index {seed_index})")

print("Solving the problem based on the config path {}".format(config_path))

Using seed 789 (index 3)
Solving the problem based on the config path ./configs/KS/game_nn_n50_0fm1gm_test.json


In [14]:
mparam = KSParam(config["n_agt"], config["beta"], config["mats_path"])
# save config at the beginning for checking
model_path = "../data/simul_results/KS/{}_{}_n{}_{}".format(
    "game" if config["policy_config"]["opt_type"] == "game" else "sp",
    config["dataset_config"]["value_sampling"],
    config["n_agt"],
    exp_name,
)
config["model_path"] = model_path
config["current_time"] = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
os.makedirs(model_path, exist_ok=True)
with open(os.path.join(model_path, "config_beg.json"), 'w') as f:
    json.dump(config, f)

## Step 1 — Dataset & policy initialization (one-time setup)

In [15]:
# --- Setup & dataset ---
start_time   = time.monotonic()
init_ds      = KSInitDataSet(mparam, config)
value_config = config["value_config"]

# --- Initial policy choice used to build the first value-training dataset ---
if config["init_with_bchmk"]:
    init_policy = init_ds.k_policy_bchmk     # PDE / bspline benchmark policy
    policy_type = "pde"
else:
    init_policy = init_ds.c_policy_const_share  # constant consumption share NN policy
    policy_type = "nn_share"

# --- Build value-training datasets from the chosen initial policy (supervised targets) ---
train_vds, valid_vds = init_ds.get_valuedataset(
    init_policy,
    policy_type,
    update_init=False,
)

Average of total utility 20.068157.
The dataset has 4608 samples in total.


## Step 2 — Initial value-function training (before any policy optimization)


In [16]:
%%time
vtrainers = []
for i in range(value_config["num_vnet"]):
    config["vnet_idx"] = str(i)
    vtrainers.append(ValueTrainer(config))

for vtr in vtrainers:
    vtr.train(train_vds, valid_vds, value_config["num_epoch"], value_config["batch_size"])

Value function learning epoch: 0
Value function learning epoch: 20
Value function learning epoch: 40
Value function learning epoch: 0
Value function learning epoch: 20
Value function learning epoch: 40
Value function learning epoch: 0
Value function learning epoch: 20
Value function learning epoch: 40
CPU times: user 30.5 s, sys: 2.12 s, total: 32.6 s
Wall time: 29.4 s


> **Notes for readers:**
>
> * We pre-train the value network(s) once on data generated by an initial policy.
> * Each `ValueTrainer.train(...)` runs for `value_config["num_epoch"]` epochs over the prepared datasets.
> * These $V$ nets will be used as the terminal bootstrap $\beta^T V(s_T)$ inside policy optimization.
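Written out, the objective that each policy step maximizes has roughly the following form. This is a sketch read off the `loss` skeleton in Block B, assuming that `self.discount[t]` equals $\beta^t$, that utility is logarithmic, and that $T$ corresponds to `t_unroll - 1`. The expectation is over simulated shock paths, $\theta$ denotes the policy-network weights, and $\bar V$ is the average over the $N_V$ = `num_vnet` value networks.

$$
\max_{\theta}\; \mathbb{E}\left[\sum_{t=0}^{T-1} \beta^{t} \log c_t^{\theta} \;+\; \beta^{T}\, \bar V(s_T)\right],
\qquad
\bar V(s) = \frac{1}{N_V}\sum_{j=1}^{N_V} V_j(s).
$$

The training loss is the negative of a mini-batch estimate of this expectation, which is why `loss` returns `-tf.reduce_mean(...)`.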


## Step 3 — Policy training with periodic value function updating

**What happens here.**

* We launch `KSPolicyTrainer.train(num_step, batch_size)`.
* **Every step**:

  * draw a **fresh mini-batch** from `policy_ds` and **simulate new shocks** (`sampler`),
  * take **one policy gradient step** (`train_step`).
* **Policy-dataset refresh (from simulation)**:

  * Inside `sampler`, when `policy_ds.epoch_used > epoch_resample`, we **rebuild the dataset** via `update_policydataset(update_init)`.
  * With your config `epoch_resample = 0`, this means: **rebuild after each full pass over the dataset** (cadence ≈ `ceil(dataset_rows / batch_size)` steps).
  * If `update_init=True` (set right after value retraining), the next rebuild is a **hard refresh**: we also update dataset stats from the new simulation.
* **Every `freq_valid` steps**: run validation on a fixed validation set.
* **Every `freq_update_v` steps** (if `value_sampling != "bchmk"`):

  * **rebuild value datasets** under the **current policy**,
  * **retrain** each value net for `value_config["num_epoch"]` epochs,
  * set `update_init=True` so the **next policy-dataset rebuild** performs a **hard refresh**.
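The schedule described above can be mimicked with a small counter-only simulation. This is a toy sketch, not the `PolicyTrainer.train` implementation: `ToyPolicyDataset` and `run_schedule` are hypothetical names, shocks and gradient steps are omitted, and it assumes that a rebuilt dataset starts a new pass and that only the first rebuild after a value retrain is a hard refresh.

```python
class ToyPolicyDataset:
    """Hypothetical stand-in for `policy_ds`: only tracks completed passes."""

    def __init__(self, n_rows):
        self.n_rows = n_rows
        self.cursor = 0
        self.epoch_used = 0

    def next_batch(self, batch_size):
        # Advance the cursor; count one epoch once a full pass is completed.
        self.cursor += batch_size
        if self.cursor >= self.n_rows:
            self.cursor = 0
            self.epoch_used += 1


def run_schedule(num_step, batch_size, n_rows, freq_valid, freq_update_v,
                 epoch_resample=0):
    ds = ToyPolicyDataset(n_rows)
    update_init = False
    counts = {"soft_rebuild": 0, "hard_refresh": 0,
              "validation": 0, "value_retrain": 0}
    for step in range(1, num_step + 1):
        ds.next_batch(batch_size)               # fresh mini-batch every step
        if ds.epoch_used > epoch_resample:      # same test as in `sampler`
            counts["hard_refresh" if update_init else "soft_rebuild"] += 1
            ds = ToyPolicyDataset(n_rows)       # assumption: rebuilt dataset starts fresh
            update_init = False                 # assumption: hard refresh is one-off
        if step % freq_valid == 0:
            counts["validation"] += 1
        if step % freq_update_v == 0:
            counts["value_retrain"] += 1
            update_init = True                  # next rebuild becomes a hard refresh
    return counts


counts = run_schedule(num_step=24, batch_size=4, n_rows=10,
                      freq_valid=6, freq_update_v=12)
print(counts)
# {'soft_rebuild': 7, 'hard_refresh': 1, 'validation': 4, 'value_retrain': 2}
```

With 10 rows and a batch size of 4, a full pass takes `ceil(10/4) = 3` steps, so the toy dataset is rebuilt every third step. The single hard refresh is the first rebuild after the value retrain at step 12.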


In [17]:
# Iterative policy and value training
policy_config = config["policy_config"]


## Block A — PolicyTrainer (base class)

> This is the **generic policy trainer** (abstract base class). It handles:
>
> * State preparation (`prepare_state`),
> * Policy evaluation (`policy_fn`),
> * Gradient calculation (`grad`),
> * Training loop (`train`).
>
> You **do not need to change this**. Read through it carefully.


In [19]:
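# Assumes policy.py exposes DTYPE, NP_DTYPE, EPSILON and KS at module level, since KSPolicyTrainer below uses them
from policy import PolicyTrainer, DTYPE, NP_DTYPE, EPSILON, KS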

## Block B — KSPolicyTrainer (student exercise)

> Now implement the **objective**.
>
> Fill in the blanks:
>
> 1. **Policy output:** use `self.policy_fn(full_state_dict)[...,0]`.
> 2. **Game case:** apply `tf.stop_gradient` to the other agents' consumption shares, so that gradients flow only through agent 0's decision.
> 3. **Factor prices:** compute `R` and `wage` with the KS formulas.
> 4. **Budget constraint:** update `wealth`, `csmp`, and next-period `k_cross`.
> 5. **Utility accumulation:** add `β^t log(csmp)`.
> 6. **Terminal bootstrap:** average value net predictions at `t_unroll-1`.
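Before translating steps 3 and 4 into batched TensorFlow, it can help to see them for a single household in plain Python. The sketch below uses the textbook competitive prices for Cobb-Douglas production $Y = a K^{\alpha} L^{1-\alpha}$. Treating `R` as the gross return (including $1-\delta$) is an assumption, suggested by the fact that `R` multiplies `k_cross` in the wealth hint. The function names and all parameter values (`alpha`, `delta`, `tau`, `l_bar`, `mu`, `EPS`) are illustrative and not taken from `KSParam` or the config.

```python
EPS = 1e-3  # stand-in for the notebook's EPSILON (value assumed)


def factor_prices(a, k_mean, emp, alpha=0.36, delta=0.025):
    """Competitive prices implied by Y = a * K^alpha * L^(1 - alpha)."""
    k_over_l = k_mean / emp
    R = 1.0 - delta + a * alpha * k_over_l ** (alpha - 1.0)  # gross return on capital
    wage = a * (1.0 - alpha) * k_over_l ** alpha
    return R, wage


def household_step(k, employed, c_share, R, wage, tau, l_bar, mu):
    """One period of the budget constraint for a single household."""
    wealth = (R * k
              + (1.0 - tau) * wage * l_bar * employed   # after-tax labor income
              + mu * wage * (1.0 - employed))           # unemployment benefit
    csmp = min(max(c_share * wealth, EPS), wealth - EPS)  # same clipping as in `loss`
    k_next = wealth - csmp                                # savings become next capital
    return wealth, csmp, k_next


# Illustrative numbers only.
a, k_mean, emp = 1.01, 40.0, 1.0
R, wage = factor_prices(a, k_mean, emp)
wealth, csmp, k_next = household_step(
    k=35.0, employed=1.0, c_share=0.1, R=R, wage=wage, tau=0.05, l_bar=1.0, mu=0.15
)
print(R, wage, wealth, csmp, k_next)
```

In the batched version the same arithmetic applies elementwise: aggregate quantities have shape `[batch, 1]` and broadcast against `k_cross` of shape `[batch, n_agt]`.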


In [20]:
class KSPolicyTrainer(PolicyTrainer):
    def __init__(self, vtrainers, init_ds, policy_path=None):
        super().__init__(vtrainers, init_ds, policy_path)
        if self.config["init_with_bchmk"]:
            init_policy = self.init_ds.k_policy_bchmk
            policy_type = "pde"
        else:
            init_policy = self.init_ds.c_policy_const_share
            policy_type = "nn_share"
        self.policy_ds = self.init_ds.get_policydataset(init_policy, policy_type, update_init=False)

    @tf.function
    def loss(self, input_data):
        k_cross = input_data["k_cross"]
        ashock, ishock = input_data["ashock"], input_data["ishock"]
        util_sum = 0

        for t in range(self.t_unroll):
            k_mean = tf.reduce_mean(k_cross, axis=1, keepdims=True)
            k_mean_tmp = tf.tile(k_mean, [1, self.mparam.n_agt])
            k_mean_tmp = tf.expand_dims(k_mean_tmp, axis=-1)
            i_tmp = ishock[:, :, t:t+1]
            a_tmp = tf.tile(ashock[:, t:t+1], [1, self.mparam.n_agt])
            a_tmp = tf.expand_dims(a_tmp, axis=2)

            basic_s_tmp = tf.concat(
                [tf.expand_dims(k_cross, axis=-1), k_mean_tmp, a_tmp, i_tmp],
                axis=-1
            )
            basic_s_tmp = self.init_ds.normalize_data(basic_s_tmp, key="basic_s", withtf=True)
            full_state_dict = {
                "basic_s": basic_s_tmp,
                "agt_s": self.init_ds.normalize_data(tf.expand_dims(k_cross, axis=-1), key="agt_s", withtf=True)
            }

            if t == self.t_unroll - 1:
                # --- terminal bootstrap ---
                value = 0
                for vtr in self.vtrainers:
                    value += self.init_ds.unnormalize_data(
                        vtr.value_fn(full_state_dict)[..., 0], key="value", withtf=True
                    )
                value /= self.num_vnet
                util_sum += self.discount[t]*value
                continue

            # (1) policy output
            c_share = ...   # TODO: call self.policy_fn(full_state_dict)[...,0]

            # (2) game case
            if self.policy_config["opt_type"] == "game":
                c_share = tf.concat(
                    [c_share[:, 0:1], tf.stop_gradient(c_share[:, 1:])],
                    axis=1
                )

            # (3) prices
            tau = tf.where(ashock[:, t:t+1] < 1, self.mparam.tau_b, self.mparam.tau_g)
            emp = tf.where(
                ashock[:, t:t+1] < 1,
                self.mparam.l_bar*self.mparam.er_b,
                self.mparam.l_bar*self.mparam.er_g
            )
            tau, emp = tf.cast(tau, DTYPE), tf.cast(emp, DTYPE)
            R    = ...   # TODO: rental rate formula
            wage = ...   # TODO: wage formula

            # (4) budget
            # TODO: R*k_cross + (1-tau)*wage*l_bar*ishock_t + mu*wage*(1-ishock_t),
            #       where ishock_t = ishock[:, :, t] has the same shape as k_cross
            wealth = ...
            csmp   = tf.clip_by_value(c_share * wealth, EPSILON, wealth-EPSILON)
            k_cross = wealth - csmp

            # (5) utility
            util_sum += self.discount[t] * tf.math.log(csmp)

        if self.policy_config["opt_type"] == "socialplanner":
            m_util = -tf.reduce_mean(util_sum)          # average over all agents
        elif self.policy_config["opt_type"] == "game":
            m_util = -tf.reduce_mean(util_sum[:, 0])    # only agent 0 optimizes
        else:
            raise ValueError("Unknown opt_type: {}".format(self.policy_config["opt_type"]))
        return {"m_util": m_util, "k_end": tf.reduce_mean(k_cross)}

    def update_policydataset(self, update_init=False):
        self.policy_ds = self.init_ds.get_policydataset(self.current_c_policy, "nn_share", update_init)

    def get_valuedataset(self, update_init=False):
        return self.init_ds.get_valuedataset(self.current_c_policy, "nn_share", update_init)

    def current_c_policy(self, k_cross, ashock, ishock):
        k_mean = np.mean(k_cross, axis=1, keepdims=True)
        k_mean = np.repeat(k_mean, self.mparam.n_agt, axis=1)
        ashock = np.repeat(ashock, self.mparam.n_agt, axis=1)
        basic_s = np.stack([k_cross, k_mean, ashock, ishock], axis=-1)
        basic_s = self.init_ds.normalize_data(basic_s, key="basic_s")
        basic_s = basic_s.astype(NP_DTYPE)
        full_state_dict = {
            "basic_s": basic_s,
            "agt_s": self.init_ds.normalize_data(k_cross[:, :, None], key="agt_s")
        }
        c_share = self.policy_fn(full_state_dict)[..., 0]
        return c_share

    def simul_shocks(self, n_sample, T, mparam, state_init):
        return KS.simul_shocks(n_sample, T, mparam, state_init)


## Block C — Training call

> Once you have filled in the blanks, you can run the training just like before:

In [None]:
ptrainer = KSPolicyTrainer(vtrainers, init_ds)
ptrainer.train(policy_config["num_step"], policy_config["batch_size"])


> **Notes for readers:**
>
> * **Mini-batch & shocks:** new **every step**.
> * **Policy-dataset cadence:** with `t_sample=200` and `t_skip=4`, each path contributes ~50 time-slices (`200/4`),
>   so dataset rows ≈ `n_path * 50` (minus NaN rows). The dataset is rebuilt after each full pass:
>   `steps_per_dataset ≈ ceil(dataset_rows / batch_size)`.
>   Example: if `n_path=384` and `batch_size=384`, rows ≈ `384*50=19,200` → `19,200/384=50` steps per rebuild.
> * **Validation:** every `freq_valid=500` steps (20 times for `num_step=10,000`).
> * **Value retrain:** every `freq_update_v=2000` steps (5 times total). This sets `update_init=True`; the **next** dataset rebuild then also updates dataset statistics from the new simulation (hard refresh).
>
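The arithmetic in these notes can be checked in a few lines. The values below are the ones quoted in the notes; `batch_size=384` is assumed from the divisor in the example, and the row count is an upper bound because NaN rows are dropped in practice.

```python
import math

t_sample, t_skip = 200, 4
n_path, batch_size = 384, 384          # batch_size assumed from the example's divisor
num_step, freq_valid, freq_update_v = 10_000, 500, 2_000

slices_per_path = t_sample // t_skip                      # time-slices kept per path
dataset_rows = n_path * slices_per_path                   # before dropping NaN rows
steps_per_dataset = math.ceil(dataset_rows / batch_size)  # steps per full pass
n_validation = num_step // freq_valid
n_value_retrain = num_step // freq_update_v
n_rebuild = num_step // steps_per_dataset

print(slices_per_path, dataset_rows, steps_per_dataset,
      n_validation, n_value_retrain, n_rebuild)
# 50 19200 50 20 5 200
```

Under these numbers the policy dataset would be rebuilt about 200 times over the run, far more often than the 5 value retrains.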

In [None]:
# Save config and models
with open(os.path.join(model_path, "config.json"), 'w') as f:
    json.dump(config, f)

for i, vtr in enumerate(vtrainers):
    vtr.save_model(os.path.join(model_path, "value{}.weights.h5".format(i)))

ptrainer.save_model(os.path.join(model_path, "policy.weights.h5"))

end_time = time.monotonic()
print_elapsedtime(end_time - start_time)

In [None]:
model_path