In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SBSim: A tutorial of using Reinforcement Learning for Optimizing Energy Use and Minimizing Carbon Emission in Office Buildings

___

Commercial office buildings contribute 17 percent of Carbon Emissions in the US, according to the US Energy Information Administration (EIA), and improving their efficiency will reduce their environmental burden and operating cost. Major contributors to energy consumption in these buildings are the Heating, Ventilation, and Air Conditioning (HVAC) devices. HVAC devices form a complex and interconnected thermodynamic system with the building and outside weather conditions, and current setpoint control policies are not fully optimized for minimizing energy use and carbon emission. Given a suitable training environment, a Reinforcement Learning (RL) agent is able to improve upon these policies, but training such a model, especially in a way that scales to thousands of buildings, presents many practical challenges. Most existing work on applying RL to this important task either makes use of proprietary data, or focuses on expensive and proprietary simulations that may not be grounded in the real world. We present the Smart Buildings Control Suite, the first open source interactive HVAC control dataset extracted from live sensor measurements of devices in real office buildings. The dataset consists of two components: real-world historical data from two buildings, for offline RL, and a lightweight interactive simulator for each of these buildings, calibrated using the historical data, for online and model-based RL. For ease of use, our RL environments are all compatible with the OpenAI gym environment standard. We believe this benchmark will accelerate progress and collaboration on HVAC optimization.

---

This notebook accompanies the paper titled, **Real-World Data and Calibrated Simulation Suite for Offline Training of Reinforcement Learning Agents to Optimize Energy and Emission in Office Buildings** by Judah Goldfeder and John Sipple.

In [5]:
import warnings

warnings.filterwarnings("ignore")

In [6]:
# @title Imports
from dataclasses import dataclass
import datetime, pytz
import enum
import functools
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import time
from typing import Final, Sequence
from typing import Optional
from typing import Union, cast
os.environ['WRAPT_DISABLE_EXTENSIONS'] = 'true'
from absl import logging
import gin
import gin
from matplotlib import patches
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import pandas as pd
import reverb
import mediapy as media
from IPython.display import clear_output
import sys

2024-11-29 00:55:18.583442: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-29 00:55:19.716717: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-11-29 00:55:22.754199: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-29 00:55:22.758189: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-29 00:55:23.135402: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to

In [6]:
sys.path

['/burg/home/ssa2206/miniforge3/envs/sbsim2/lib/python310.zip',
 '/burg/home/ssa2206/miniforge3/envs/sbsim2/lib/python3.10',
 '/burg/home/ssa2206/miniforge3/envs/sbsim2/lib/python3.10/lib-dynload',
 '',
 '/burg/home/ssa2206/miniforge3/envs/sbsim2/lib/python3.10/site-packages',
 '/burg/home/ssa2206/sbsim_dual_control',
 '/burg/home/ssa2206/miniforge3/envs/sbsim2/lib/python3.10/site-packages/setuptools/_vendor']

In [7]:
from smart_control.environment import environment

from smart_control.proto import smart_control_building_pb2, smart_control_normalization_pb2
from smart_control.reward import electricity_energy_cost, natural_gas_energy_cost, setpoint_energy_carbon_reward, setpoint_energy_carbon_regret

from smart_control.simulator import randomized_arrival_departure_occupancy, rejection_simulator_building
from smart_control.simulator import simulator_building, step_function_occupancy, stochastic_convection_simulator

from smart_control.utils import bounded_action_normalizer, building_renderer, controller_reader
from smart_control.utils import controller_writer, conversion_utils, observation_normalizer, reader_lib
from smart_control.utils import writer_lib, histogram_reducer, environment_utils

In [8]:
import tensorflow as tf
from tf_agents.agents.sac import sac_agent, tanh_normal_projection_network
from tf_agents.drivers import py_driver
from tf_agents.keras_layers import inner_reshape
from tf_agents.metrics import py_metrics
from tf_agents.networks import nest_map, sequential
from tf_agents.policies import greedy_policy, py_tf_eager_policy, random_py_policy, tf_policy
from tf_agents.replay_buffers import reverb_replay_buffer, reverb_utils
from tf_agents.specs import tensor_spec
from tf_agents.train import actor, learner, triggers
from tf_agents.train.utils import spec_utils, train_utils
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory as trajectory_lib
from tf_agents.trajectories import trajectory
from tf_agents.typing import types

In [19]:
# @title Set local runtime configurations


def logging_info(*args):
    """Send a message to the logger and echo it to the cell output.

    Args:
      *args: Forwarded unchanged, first to `logging.info` and then to `print`.
    """
    # Same order as before: log first, then print.
    for sink in (logging.info, print):
        sink(*args)

data_path = "/burg/home/ssa2206/sbsim_dual_control/smart_control/configs/resources/sb1/" #@param {type:"string"}
metrics_path = "/burg/home/ssa2206/sbsim_dual_control/metrics" #@param {type:"string"}
output_data_path = '/burg/home/ssa2206/sbsim_dual_control/output' #@param {type:"string"}
root_dir = "/burg/home/ssa2206/sbsim_dual_control/root" #@param {type:"string"}


@gin.configurable
def get_histogram_reducer():
    """Build a HistogramReducer that reads protos from the notebook's `data_path`.

    Relies on the notebook-level globals `data_path` and
    `histogram_parameters_tuples`; both must be defined before this is called.

    Returns:
      A `histogram_reducer.HistogramReducer` created with `normalize_reduce=True`.
    """
    proto_reader = controller_reader.ProtoReader(data_path)
    return histogram_reducer.HistogramReducer(
        reader=proto_reader,
        histogram_parameters_tuples=histogram_parameters_tuples,
        normalize_reduce=True,
    )

!mkdir -p $root_dir
!mkdir -p $output_data_path
!mkdir -p $metrics_path




def remap_filepath(filepath) -> str:
    """Return `filepath` converted to `str`; no remapping is applied locally."""
    as_text = str(filepath)
    return as_text


ValueError: A different configurable matching '__main__.get_histogram_reducer' already exists.

To allow re-registration of configurables in an interactive environment, use:

    gin.enter_interactive_mode()

In [8]:
# @title Plotting Utilities
# Reward post-processing and productivity constants. They are not referenced in
# the visible cells; presumably consumed by the star-imported plotting helpers
# or later cells -- TODO confirm.
reward_shift = 0
reward_scale = 1.0
person_productivity_hour = 300.0

# Offset between the Kelvin and Celsius scales (0 C == 273.15 K).
KELVIN_TO_CELSIUS = 273.15

def render_env(env: environment.Environment):
    """Draws the simulated building's current temperature field in the notebook.

    Args:
      env: Environment whose underlying simulator building is rendered.
    """
    # The simulator does not expose its building publicly, so the floor plan,
    # temperatures and heat input are read through private attributes.
    # Ideally these would be publicly accessible fields.
    sim_building = env.building._simulator._building

    renderer = building_renderer.BuildingRenderer(sim_building._floor_plan, 1)

    # Colour scale limits for the temperature map (presumably Kelvin).
    rendered = renderer.render(
        sim_building.temp,
        cmap='bwr',
        vmin=285,
        vmax=305,
        colorbar=False,
        input_q=sim_building.input_q,
        diff_range=0.5,
        diff_size=1,
    )
    image = rendered.convert('RGB')
    media.show_image(image, title='Environment %s' % env.current_simulation_timestamp)

In [9]:
from output_formatting import *
from plot_utils import *

# Load up the environment

In this section we load up the Smart Buildings simulator environment.

In [10]:
# @title Utils for importing the environment.

def load_environment(gin_config_file: str):
    """Parses a gin config file and builds the Environment it configures.

    Args:
      gin_config_file: Path to the `.gin` file describing the environment.

    Returns:
      A new `environment.Environment` created while the gin config is unlocked.
    """
    # Global definition is required by Gin library to instantiate Environment.
    global environment  # pylint: disable=global-variable-not-assigned
    with gin.unlock_config():
        gin.parse_config_file(gin_config_file)
        loaded_env = environment.Environment()  # pylint: disable=no-value-for-parameter
    return loaded_env


def get_latest_episode_reader(metrics_path: str,) -> controller_reader.ProtoReader:
    """Returns a ProtoReader for the episode whose index label sorts last.

    Args:
      metrics_path: Directory containing one sub-directory per episode.

    Returns:
      A `controller_reader.ProtoReader` rooted at that episode's directory
      (presumably the most recent run -- depends on how episodes are labelled).
    """
    sorted_episodes = controller_reader.get_episode_data(metrics_path).sort_index()
    latest_episode = sorted_episodes.index[-1]
    return controller_reader.ProtoReader(os.path.join(metrics_path, latest_episode))

@gin.configurable
def get_histogram_path():
    """Gin-configurable accessor returning the notebook-level `data_path`."""
    histogram_dir = data_path
    return histogram_dir


@gin.configurable
def get_reset_temp_values():
    """Loads the array stored in `reset_temps.npy` under `data_path`."""
    joined_path = os.path.join(data_path, "reset_temps.npy")
    return np.load(remap_filepath(joined_path))


@gin.configurable
def get_zone_path():
    """Returns the path of `double_resolution_zone_1_2.npy` under `data_path`."""
    zone_file = os.path.join(data_path, "double_resolution_zone_1_2.npy")
    return remap_filepath(zone_file)


@gin.configurable
def get_metrics_path():
    """Returns the `metrics` sub-directory of the notebook-level `metrics_path`."""
    metrics_dir = os.path.join(metrics_path, "metrics")
    return metrics_dir


@gin.configurable
def get_weather_path():
    """Returns the path of the local weather CSV stored under `data_path`."""
    weather_csv = os.path.join(
        data_path, "local_weather_moffett_field_20230701_20231122.csv"
    )
    return remap_filepath(weather_csv)

In the cell below, we load the collect and eval environments. Although both are loaded from the same configuration here, it would be useful to load the same building over nearby but non-overlapping time windows.

In [12]:
"""
@gin.configurable
def to_timestamp(date_str: str) -> pd.Timestamp:
    # Utilty macro for gin config
    return pd.Timestamp(date_str)


@gin.configurable
def local_time(time_str: str) -> pd.Timedelta:
    #Utilty macro for gin config.
    return pd.Timedelta(time_str)

"""
# @gin.configurable
# def enumerate_zones(
#     n_building_x: int, n_building_y: int
# ) -> Sequence[tuple[int, int]]:
#   # Utilty macro for gin config.
#   zone_coordinates = []
#   for x in range(n_building_x):
#     for y in range(n_building_y):
#       zone_coordinates.append((x, y))
#   return zone_coordinates


# @gin.configurable
# def set_observation_normalization_constants(
#     field_id: str, sample_mean: float, sample_variance: float
# ) -> smart_control_normalization_pb2.ContinuousVariableInfo:
#   return smart_control_normalization_pb2.ContinuousVariableInfo(
#       id=field_id, sample_mean=sample_mean, sample_variance=sample_variance
#   )


# @gin.configurable
# def set_action_normalization_constants(
#     min_native_value,
#     max_native_value,
#     min_normalized_value,
#     max_normalized_value,
# ) -> bounded_action_normalizer.BoundedActionNormalizer:
#   return bounded_action_normalizer.BoundedActionNormalizer(
#       min_native_value,
#       max_native_value,
#       min_normalized_value,
#       max_normalized_value,
#   )


# @gin.configurable
# def get_zones_from_config(
#     configuration_path: str,
# ) -> Sequence[smart_control_building_pb2.ZoneInfo]:
#   """Loads up the zones as a gin macro."""
#   with gin.unlock_config():
#     reader = reader_lib_google.RecordIoReader(input_dir=configuration_path)
#     zone_infos = reader.read_zone_infos()
#     return zone_infos


# @gin.configurable
# def get_devices_from_config(
#     configuration_path: str,
# ) -> Sequence[smart_control_building_pb2.DeviceInfo]:
#   """Loads up HVAC devices as a gin macro."""
#   with gin.unlock_config():
#     reader = reader_lib_google.RecordIoReader(input_dir=configuration_path)
#     device_infos = reader.read_device_infos()
#     return device_infos

# @title Load the environments

# Read by `get_histogram_reducer` when gin instantiates the environment, so it
# must be defined before any `load_environment` call below.
# Each entry pairs a measurement name with a tuple of numeric values
# (presumably histogram bin boundaries -- TODO confirm against HistogramReducer).
histogram_parameters_tuples = (
        ('zone_air_temperature_sensor',(285., 286., 287., 288, 289., 290., 291., 292., 293., 294., 295., 296., 297., 298., 299., 300.,301,302,303)),
        ('supply_air_damper_percentage_command',(0.0, 0.2, 0.4, 0.6, 0.8, 1.0)),
        ('supply_air_flowrate_setpoint',( 0., 0.05, .1, .2, .3, .4, .5,  .7,  .9)),
    )

time_zone = 'US/Pacific'
# Collect and eval currently point at the same gin file, so all three
# environments below are built from an identical configuration.
collect_scenario_config = os.path.join(data_path, "sim_config.gin")
print(collect_scenario_config)
eval_scenario_config = os.path.join(data_path, "sim_config.gin")
print(eval_scenario_config)

collect_env = load_environment(collect_scenario_config)

# For efficiency, set metrics_path to None
collect_env._metrics_path = None
# The same normalization constant is applied to all three environments.
collect_env._occupancy_normalization_constant = 125.0

eval_env = load_environment(eval_scenario_config)
# eval_env._label += "_eval"
eval_env._metrics_path = metrics_path
eval_env._occupancy_normalization_constant = 125.0

# Built from the eval config and, like eval_env, pointed at `metrics_path`.
initial_collect_env = load_environment(eval_scenario_config)

initial_collect_env._metrics_path = metrics_path
initial_collect_env._occupancy_normalization_constant = 125.0

/burg/home/ssa2206/sbsim_dual_control/smart_control/configs/resources/sb1/sim_config.gin
/burg/home/ssa2206/sbsim_dual_control/smart_control/configs/resources/sb1/sim_config.gin


KeyError: "Ambiguous selector 'to_timestamp', matches ['__main__.to_timestamp', 'smart_control.utils.environment_utils.to_timestamp']."
  In file "/burg/home/ssa2206/sbsim_dual_control/smart_control/configs/resources/sb1/sim_config.gin", line 166
        sim/to_timestamp.date_str = %start_timestamp

In the section below, we'll define a function that accepts the environment and a policy, and runs a fixed number of episodes. The policy can be a rules-based policy or an RL-based policy.

In [13]:
# @title Define a method to execute the policy on the environment.


def get_trajectory(time_step, current_action: policy_step.PolicyStep):
    """Builds the trajectory entry for a time step and the action taken.

    Args:
      time_step: Time step providing observation, reward, discount and the
        first/last markers.
      current_action: Policy step whose `action` is stored in the trajectory.

    Returns:
      The result of `trajectory.first`, `trajectory.last` or `trajectory.mid`,
      chosen from the time step's position in the episode. The policy info is
      always the empty tuple.
    """
    if time_step.is_first():
        build = trajectory.first
    elif time_step.is_last():
        build = trajectory.last
    else:
        build = trajectory.mid

    return build(
        time_step.observation,
        current_action.action,
        (),
        time_step.reward,
        time_step.discount,
    )


def compute_avg_return(environment, policy, num_episodes=1, time_zone: str = "US/Pacific", 
                       render_interval_steps: int = 24,trajectory_observers=None,):
    """Runs the policy on the environment and returns the mean episode return.

    Progress is printed after every step; every `render_interval_steps` steps
    the building is rendered and, when the environment has a metrics path, the
    time-series charts of the latest episode are re-plotted.

    Args:
      environment: environment.Environment to run. Note that this parameter
        shadows the imported `environment` module inside the function.
      policy: policy exposing `action(time_step)`.
      num_episodes: total number of episodes to run.
      time_zone: time zone used to display the simulation time and the charts.
      render_interval_steps: Number of steps to take between rendering.
      trajectory_observers: optional list of callables, each invoked with the
        trajectory of every step.

    Returns:
      The sum of the episode returns divided by `num_episodes`.
    """
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        # t0 marks the end of the previous step; epoch marks the episode start.
        t0 = time.time()
        epoch = t0
        step_id = 0
        execution_times = []
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)

            if trajectory_observers is not None:
                # The trajectory pairs the post-step time_step with the action
                # that produced it.
                traj = get_trajectory(time_step, action_step)
                for observer in trajectory_observers:
                    observer(traj)

            episode_return += time_step.reward
            t1 = time.time()
            # Wall-clock duration of this step, including observer calls and
            # any rendering done at the end of the previous iteration.
            dt = t1 - t0
            episode_seconds = t1 - epoch
            execution_times.append(dt)
            sim_time = environment.current_simulation_timestamp.tz_convert(time_zone)

            print("Step %5d Sim Time: %s, Reward: %8.2f, Return: %8.2f, Mean Step Time:"
                  " %8.2f s, Episode Time: %8.2f s" % (step_id, sim_time.strftime("%Y-%m-%d %H:%M"),
                                                       time_step.reward, episode_return, 
                                                       np.mean(execution_times), episode_seconds,)
                 )
            # Step 0 is never rendered; afterwards render on every multiple of
            # render_interval_steps.
            if (step_id > 0) and (step_id % render_interval_steps == 0):
                if environment._metrics_path:
                    # Replace the accumulated output with fresh charts.
                    clear_output(wait=True)
                    reader = get_latest_episode_reader(environment._metrics_path)
                    plot_timeseries_charts(reader, time_zone)
                render_env(environment)

            t0 = t1
            step_id += 1
        total_return += episode_return

    # NOTE(review): raises ZeroDivisionError when num_episodes is 0.
    avg_return = total_return / num_episodes
    return avg_return

# Rules-based Control (RBC)

In [1]:
# @title Utils for RBC

# We're concerned with controlling Heatpumps/ACs and Hot Water Systems (HWS).
class DeviceType(enum.Enum):
    """Kinds of HVAC device that the rules-based controller issues setpoints to."""
    AC = 0   # Heat pump / air conditioner.
    HWS = 1  # Hot water system.


# Identifies the setpoint on a device.
SetpointName = str
# Setpoint value.
SetpointValue = Union[float, int, bool]


@dataclass
class ScheduleEvent:
    """One entry of a setpoint schedule.

    Attributes:
      start_time: Offset from the start of the day at which the event applies.
      device: Device the setpoint belongs to.
      setpoint_name: Name of the setpoint being commanded. This field was
        previously misspelled `setpointaa_name`, which did not match the
        `setpoint_name` attribute read by `find_schedule_action`.
      setpoint_value: Value to command from `start_time` onwards.
    """
    start_time: pd.Timedelta
    device: DeviceType
    setpoint_name: SetpointName
    setpoint_value: SetpointValue


# A schedule is a list of times and setpoints for a device.
Schedule = list[ScheduleEvent]
# Ordered (device, setpoint name) pairs; the order fixes the layout of the
# action vector produced by the schedule policy.
ActionSequence = list[tuple[DeviceType, SetpointName]]


def to_rad(sin_theta: float, cos_theta: float) -> float:
    """Converts a sin and cos theta to radians to extract the time.

    The quadrant is chosen from the signs of the two inputs so that the result
    lies in [0, 2*pi].

    Args:
      sin_theta: Sine of the angle.
      cos_theta: Cosine of the angle.

    Returns:
      The angle in radians.
    """
    # The previous version ended with an unreachable `return` that referenced
    # an undefined name (`rad_offset`); it has been removed.
    if sin_theta >= 0 and cos_theta >= 0:
        # First quadrant: [0, pi/2].
        return np.arccos(cos_theta)
    elif sin_theta >= 0 and cos_theta < 0:
        # Second quadrant: [pi/2, pi].
        return np.pi - np.arcsin(sin_theta)
    elif sin_theta < 0 and cos_theta < 0:
        # Third quadrant: arcsin is negative here, giving (pi, 3*pi/2).
        return np.pi - np.arcsin(sin_theta)
    else:
        # Fourth quadrant: [3*pi/2, 2*pi].
        return 2 * np.pi - np.arccos(cos_theta)


def to_dow(sin_theta: float, cos_theta: float) -> float:
    """Recovers the day-of-week index from its cyclical sin/cos encoding.

    Args:
      sin_theta: Sine component of the day-of-week angle.
      cos_theta: Cosine component of the day-of-week angle.

    Returns:
      The floor of the angle's position on a 7-step circle.
    """
    angle = to_rad(sin_theta, cos_theta)
    week_position = 7 * angle / 2 / np.pi
    return np.floor(week_position)


def to_hod(sin_theta: float, cos_theta: float) -> float:
    """Recovers the hour-of-day index from its cyclical sin/cos encoding.

    Args:
      sin_theta: Sine component of the hour-of-day angle.
      cos_theta: Cosine component of the hour-of-day angle.

    Returns:
      The floor of the angle's position on a 24-step circle.
    """
    angle = to_rad(sin_theta, cos_theta)
    day_position = 24 * angle / 2 / np.pi
    return np.floor(day_position)


def find_schedule_action(schedule: Schedule, device: DeviceType, 
                         setpoint_name: SetpointName, timestamp: pd.Timedelta,) -> SetpointValue:
    """Finds the action for a schedule event for a time and schedule.

    Args:
      schedule: Events to search; they may be listed in any order.
      device: Device whose setpoint is requested.
      setpoint_name: Name of the requested setpoint.
      timestamp: Time of day, as an offset from midnight.

    Returns:
      The value of the latest matching event that starts at or before
      `timestamp`. If no matching event has started yet, the value of the
      latest event of the day (the schedule wraps around midnight).

    Raises:
      ValueError: If the schedule has no event for the device and setpoint.
    """
    # Get all the schedule events for the device and the setpoint. A later
    # event with the same start time overrides an earlier one.
    device_schedule_dict = {
        event.start_time: event.setpoint_value
        for event in schedule
        if event.device == device and event.setpoint_name == setpoint_name
    }
    if not device_schedule_dict:
        raise ValueError(
            'No schedule events found for device %s and setpoint %s.'
            % (device, setpoint_name)
        )

    # Sort by start time so that "last" below means latest in the day rather
    # than last in the order the events were listed.
    device_schedule = pd.Series(device_schedule_dict).sort_index()

    # Get the indexes of the schedule events that fall before the timestamp.
    preceding_start_times = device_schedule.index[device_schedule.index <= timestamp]

    # If there are no events preceding the time, then choose the last
    # (assuming it wraps around).
    if preceding_start_times.empty:
        return device_schedule.iloc[-1]
    return device_schedule.loc[preceding_start_times[-1]]

NameError: name 'enum' is not defined

In [2]:
# @title Define a schedule policy

class SchedulePolicy(tf_policy.TFPolicy):
    """TF Policy implementation of the Schedule policy.

    The policy reads only the cyclical day-of-week and hour-of-day features
    from the observation, looks up the scheduled setpoint of every entry in
    `action_sequence`, normalizes the values and returns them as the action.
    """
    def __init__(self, time_step_spec,
      action_spec: types.NestedTensorSpec,
      action_sequence: ActionSequence,
      weekday_schedule_events: Schedule,
      weekend_holiday_schedule_events: Schedule,
      dow_sin_index: int,
      dow_cos_index: int,
      hod_sin_index: int,
      hod_cos_index: int,
      action_normalizers,
      local_start_time: pd.Timestamp = pd.Timestamp,
      policy_state_spec: types.NestedTensorSpec = (),
      info_spec: types.NestedTensorSpec = (),
      training: bool = False,
      name: Optional[str] = None,
    ):
        """Initializes the schedule policy.

        Args:
          time_step_spec: Spec of the time steps the policy receives.
          action_spec: Spec of the actions the policy emits.
          action_sequence: Ordered (device, setpoint name) pairs; the action
            vector follows this order.
          weekday_schedule_events: Schedule applied when the decoded day of
            week is less than 5.
          weekend_holiday_schedule_events: Schedule applied otherwise.
          dow_sin_index: Observation index of the day-of-week sine feature.
          dow_cos_index: Observation index of the day-of-week cosine feature.
          hod_sin_index: Observation index of the hour-of-day sine feature.
          hod_cos_index: Observation index of the hour-of-day cosine feature.
          action_normalizers: Mapping from a normalizer key to an object with
            an `agent_value` method; a key applies to a setpoint when it ends
            with the setpoint name.
          local_start_time: Timestamp whose `utcoffset()` is added to the
            decoded hour of day. NOTE(review): the default is the
            `pd.Timestamp` class itself, kept only so the signature is
            unchanged; callers must pass a timezone-aware timestamp.
          policy_state_spec: Ignored; the policy is always stateless.
          info_spec: Spec of the policy info.
          training: Unused.
          name: Optional name of the policy.
        """
        self.weekday_schedule_events = weekday_schedule_events
        self.weekend_holiday_schedule_events = weekend_holiday_schedule_events
        self.dow_sin_index = dow_sin_index
        self.dow_cos_index = dow_cos_index
        self.hod_sin_index = hod_sin_index
        self.hod_cos_index = hod_cos_index
        self.action_sequence = action_sequence
        self.action_normalizers = action_normalizers
        self.local_start_time = local_start_time
        # Identity de-normalization of the time features (scale 1, shift 0).
        self.norm_mean = 0.0
        self.norm_std = 1.0

        # The policy keeps no state, whatever spec the caller passed in.
        policy_state_spec = ()

        super().__init__(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            policy_state_spec=policy_state_spec,
            info_spec=info_spec,
            clip=False,
            observation_and_action_constraint_splitter=None,
            name=name,
        )

    def _normalize_action_map(
      self, action_map: dict[tuple[DeviceType, SetpointName], SetpointValue]
    ) -> dict[tuple[DeviceType, SetpointName], SetpointValue]:
        """Converts native setpoint values to the agent's normalized values.

        Entries for which no normalizer key ends with the setpoint name are
        left out of the result; when several keys match, the last one wins.
        """
        normalized_action_map = {}

        for device_setpoint, native_value in action_map.items():
            setpoint_name = device_setpoint[1]
            for normalizer_key, normalizer in self.action_normalizers.items():
                if normalizer_key.endswith(setpoint_name):
                    normalized_action_map[device_setpoint] = normalizer.agent_value(native_value)
        return normalized_action_map

    def _get_action(self, time_step) -> dict[tuple[DeviceType, SetpointName], SetpointValue]:
        """Looks up the scheduled native setpoint for every action entry."""
        observation = time_step.observation
        dow_sin = (observation[self.dow_sin_index] * self.norm_std) + self.norm_mean
        dow_cos = (observation[self.dow_cos_index] * self.norm_std) + self.norm_mean
        hod_sin = (observation[self.hod_sin_index] * self.norm_std) + self.norm_mean
        hod_cos = (observation[self.hod_cos_index] * self.norm_std) + self.norm_mean

        dow = to_dow(dow_sin, dow_cos)
        hod = to_hod(hod_sin, hod_cos)

        # NOTE(review): the sum is not wrapped into [0 h, 24 h); confirm which
        # time zone the hour-of-day feature is expressed in.
        timestamp = (
            pd.Timedelta(hod, unit='hour') + self.local_start_time.utcoffset()
        )

        if dow < 5:  # weekday
            schedule_events = self.weekday_schedule_events
        else:  # Weekend
            schedule_events = self.weekend_holiday_schedule_events

        # Use the sequence given to this instance, not a notebook-level global.
        return {
            (entry[0], entry[1]): find_schedule_action(
                schedule_events, entry[0], entry[1], timestamp
            )
            for entry in self.action_sequence
        }

    def _action(self, time_step, policy_state, seed):
        """Returns the normalized scheduled setpoints as a float32 tensor."""
        del seed
        action_map = self._get_action(time_step)
        normalized_action_map = self._normalize_action_map(action_map)

        # The order of self.action_sequence fixes the layout of the vector.
        action = np.array(
            [normalized_action_map[device_setpoint]
             for device_setpoint in self.action_sequence],
            dtype=np.float32,
        )

        t_action = tf.convert_to_tensor(action)

        return policy_step.PolicyStep(t_action, (), ())

NameError: name 'tf_policy' is not defined

Next, we parameterize the setpoint schedule.

We distinguish between weekday and weekend/holiday schedules:

* For **weekdays, between 6:00 am and 7:00 pm local time** we maintain occupancy conditions:
  * AC/Heatpump supply air heating setpoint is 292 K (about 19 C)
  * Supply water temperature is 350 K (about 77 C)
* For **weekdays, before 6:00 am and after 7:00 pm local time** we maintain efficiency conditions (setback):
  * AC/Heatpump supply air heating setpoint is 285 K (about 12 C)
  * Supply water temperature is 315 K (about 42 C)

* For **weekends and holidays**, all day, we maintain efficiency conditions (setback):
  * AC/Heatpump supply air heating setpoint is 285 K (about 12 C)
  * Supply water temperature is 315 K (about 42 C)


In [10]:
# @title Configure the schedule parameters

# Positions of the cyclical time features inside the observation vector; the
# schedule policy uses them to decode hour of day and day of week.
hod_cos_index = collect_env._field_names.index('hod_cos_000')
hod_sin_index = collect_env._field_names.index('hod_sin_000')
dow_cos_index = collect_env._field_names.index('dow_cos_000')
dow_sin_index = collect_env._field_names.index('dow_sin_000')


# Note that temperatures are specified in Kelvin:
# Weekdays: occupancy setpoints from 06:00, setback setpoints from 19:00.
weekday_schedule_events = [
    ScheduleEvent(
        pd.Timedelta(6, unit='hour'),
        DeviceType.AC,
        'supply_air_heating_temperature_setpoint',
        292.0,
    ),
    ScheduleEvent(
        pd.Timedelta(19, unit='hour'),
        DeviceType.AC,
        'supply_air_heating_temperature_setpoint',
        285.0,
    ),
    ScheduleEvent(
        pd.Timedelta(6, unit='hour'),
        DeviceType.HWS,
        'supply_water_setpoint',
        350.0,
    ),
    ScheduleEvent(
        pd.Timedelta(19, unit='hour'),
        DeviceType.HWS,
        'supply_water_setpoint',
        315.0,
    ),
]


# Weekends and holidays: both events carry the setback value, so the setpoints
# stay constant all day.
weekend_holiday_schedule_events = [
    ScheduleEvent(
        pd.Timedelta(6, unit='hour'),
        DeviceType.AC,
        'supply_air_heating_temperature_setpoint',
        285.0,
    ),
    ScheduleEvent(
        pd.Timedelta(19, unit='hour'),
        DeviceType.AC,
        'supply_air_heating_temperature_setpoint',
        285.0,
    ),
    ScheduleEvent(
        pd.Timedelta(6, unit='hour'),
        DeviceType.HWS,
        'supply_water_setpoint',
        315.0,
    ),
    ScheduleEvent(
        pd.Timedelta(19, unit='hour'),
        DeviceType.HWS,
        'supply_water_setpoint',
        315.0,
    ),
]

# Order of the entries in the action vector built by the schedule policy.
# NOTE(review): presumably this must match the order of the environment's
# action spec -- confirm.
action_sequence = [
    (DeviceType.HWS, 'supply_water_setpoint'),
    (DeviceType.AC, 'supply_air_heating_temperature_setpoint'),
]

We instantiate the schedule policy below.

In [11]:
# @title Instantiate the Schedule RBC policy
# Reset first so that `current_simulation_timestamp` reflects the episode start.
# NOTE(review): `ts` is also the alias under which
# `tf_agents.trajectories.time_step` is imported in the imports cell; this
# assignment rebinds that name to the reset result for the rest of the notebook.
ts = collect_env.reset()
# Local start time; the schedule policy calls `utcoffset()` on it.
local_start_time = collect_env.current_simulation_timestamp.tz_convert(tz = 'US/Pacific')

action_normalizers = collect_env._action_normalizers

observation_spec, action_spec, time_step_spec = spec_utils.get_tensor_specs(collect_env)
schedule_policy = SchedulePolicy(
    time_step_spec= time_step_spec,
    action_spec= action_spec,
    action_sequence = action_sequence,
    weekday_schedule_events = weekday_schedule_events,
    weekend_holiday_schedule_events = weekend_holiday_schedule_events,
    dow_sin_index=dow_sin_index,
    dow_cos_index=dow_cos_index,
    hod_sin_index=hod_sin_index,
    hod_cos_index=hod_cos_index,
    local_start_time=local_start_time,
    action_normalizers=action_normalizers,

)


Next, we will run the static control setpoints on the environment to establish baseline performance.

**Note:** This will take some time to execute. Feel free to skip this step if you want to jump directly to the RL section below.

In [12]:
# @title Optionally, execute the schedule policy on the environment
# Optional
# compute_avg_return(eval_env, schedule_policy, 1, time_zone="US/Pacific", render_interval_steps=144, trajectory_observers=None)

# Reinforcement Learning Control
In the previous section we used a simple schedule to control the HVAC setpoints. In this section, we configure and train a Reinforcement Learning (RL) agent.



In [9]:

# @title Utilities to configure networks for the RL Agent.
# Factory for the hidden layers of the actor and critic networks: a Keras Dense
# layer with ReLU activation and Glorot-uniform kernel initialization. Call it
# with the number of units, e.g. `dense(128)`.
dense = functools.partial(
    tf.keras.layers.Dense,
    activation=tf.keras.activations.relu,
    kernel_initializer='glorot_uniform',
)


def logging_info(*args):
    """Send a message to the logger and echo it to the cell output.

    This repeats the definition from the runtime-configuration cell so that
    this cell can run on its own.

    Args:
      *args: Forwarded unchanged, first to `logging.info` and then to `print`.
    """
    for sink in (logging.info, print):
        sink(*args)


def create_fc_network(layer_units):
    """Stacks one `dense` layer per entry of `layer_units` into a Sequential.

    Args:
      layer_units: Iterable with the number of units of each layer, in order.

    Returns:
      A `sequential.Sequential` network made of the created layers.
    """
    hidden_layers = []
    for num_units in layer_units:
        hidden_layers.append(dense(num_units))
    return sequential.Sequential(hidden_layers)


def create_identity_layer():
    """Returns a Keras Lambda layer that passes its input through unchanged."""
    identity = tf.keras.layers.Lambda(lambda x: x)
    return identity


### SAC Critic Network

- obs network learns meaningful representation of state 
- action network learns meaningful representation of action
- joint network $f(z_a, z_s) \rightarrow \hat{Q}(s, a)$

In [15]:


def create_sequential_critic_network(obs_fc_layer_units, action_fc_layer_units, joint_fc_layer_units):
    """Create a sequential critic network.

    Args:
      obs_fc_layer_units: Units of the observation branch; when empty the
        branch is an identity layer.
      action_fc_layer_units: Units of the action branch; when empty the branch
        is an identity layer.
      joint_fc_layer_units: Units of the layers applied to the concatenated
        branch outputs; when empty an identity layer is used.

    Returns:
      A `sequential.Sequential` named 'sequential_critic' that maps an
      (observation, action) pair to a value with the trailing size-1 dimension
      removed.
    """
    def split_inputs(inputs):
        # The network input is an (observation, action) pair.
        return {'observation': inputs[0], 'action': inputs[1]}

    def build_branch(layer_units):
        # A fully connected stack when units are given, otherwise a pass-through.
        if layer_units:
            return create_fc_network(layer_units)
        return create_identity_layer()

    # Layers are created in the same order as before: observation, action,
    # joint, then the value head.
    obs_network = build_branch(obs_fc_layer_units)
    action_network = build_branch(action_fc_layer_units)
    joint_network = build_branch(joint_fc_layer_units)
    value_layer = tf.keras.layers.Dense(1, kernel_initializer='glorot_uniform')

    critic_layers = [
        tf.keras.layers.Lambda(split_inputs),
        nest_map.NestMap({'observation': obs_network, 'action': action_network}),
        nest_map.NestFlatten(),
        tf.keras.layers.Concatenate(),
        joint_network,
        value_layer,
        inner_reshape.InnerReshape(current_shape=[1], new_shape=[]),
    ]
    return sequential.Sequential(critic_layers, name='sequential_critic')



In [11]:

class _TanhNormalProjectionNetworkWrapper(
        tanh_normal_projection_network.TanhNormalProjectionNetwork
):
    """Projection network that always uses a fixed, predefined `outer_rank`."""

    def __init__(self, sample_spec, predefined_outer_rank=1):
        """Stores the outer rank to use on every call.

        Args:
          sample_spec: Spec forwarded to the underlying projection network.
          predefined_outer_rank: Value passed as `outer_rank` on every call.
        """
        super().__init__(sample_spec)
        self.predefined_outer_rank = predefined_outer_rank

    def call(self, inputs, network_state=(), **kwargs):
        """Calls the projection network with the predefined `outer_rank`.

        Any `step_type` keyword is dropped and `network_state` is not
        forwarded to the parent call.
        """
        kwargs.pop('step_type', None)
        kwargs['outer_rank'] = self.predefined_outer_rank
        return super().call(inputs, **kwargs)


def create_sequential_actor_network(actor_fc_layers, action_tensor_spec):
    """Assemble the actor as a `sequential.Sequential` stack.

    Args:
      actor_fc_layers: Iterable of unit counts; one `dense(...)` layer is
        created per entry, in order.
      action_tensor_spec: (Possibly nested) action spec. The output of the
        dense stack is replicated into this structure and one
        `_TanhNormalProjectionNetworkWrapper` is created per spec leaf.

    Returns:
      A `sequential.Sequential` made of the dense layers, the replication
      step, and the per-leaf projection networks.
    """

    def _replicate_for_each_action(trunk_output):
        # Hand the same trunk output to every leaf of the action structure.
        return tf.nest.map_structure(lambda _spec: trunk_output, action_tensor_spec)

    hidden_layers = [dense(width) for width in actor_fc_layers]
    fan_out_layer = tf.keras.layers.Lambda(_replicate_for_each_action)
    projection_nets = tf.nest.map_structure(
        _TanhNormalProjectionNetworkWrapper, action_tensor_spec
    )

    return sequential.Sequential(
        hidden_layers + [fan_out_layer, nest_map.NestMap(projection_nets)]
    )

Set the configuration parameters for the SAC agent.

In [12]:
# @title Set the RL Agent's parameters

# --- Network architecture (units per fully connected layer) ---
actor_fc_layers = (128, 128)         # actor network
critic_obs_fc_layers = (128, 64)     # critic: observation branch
critic_action_fc_layers = (128, 64)  # critic: action branch
critic_joint_fc_layers = (128, 64)   # critic: joint layers after the branches merge

# --- Optimisation ---
batch_size = 256
actor_learning_rate = 3e-4
critic_learning_rate = 3e-4
alpha_learning_rate = 3e-4
gamma = 0.99
target_update_tau = 0.005
target_update_period = 1
reward_scale_factor = 1.0

# --- Replay and diagnostics ---
# NOTE: `replay_capacity` is assigned again in the replay-buffer cell below,
# so the value set here is not the one the Reverb table ends up using.
replay_capacity = 1_000_000
debug_summaries = True
summarize_grads_and_vars = True


## Initialize the SAC agent

Of all the reinforcement learning algorithms, we have chosen [Soft Actor-Critic (SAC)](https://arxiv.org/abs/1801.01290) because of its proven performance on environments with high-dimensional states and real-valued actions.

In this notebook we illustrate the use of the building control environment using the SAC implementation in [TF-Agents](https://www.tensorflow.org/agents).

In [13]:
# @title Construct the SAC agent

# Requires `collect_env` from an earlier cell. Only the action and time-step
# specs are used; the first element returned by `get_tensor_specs` is discarded.
_, action_tensor_spec, time_step_tensor_spec = spec_utils.get_tensor_specs(
    collect_env
)

# Actor network built from the layer sizes in the parameters cell.
actor_net = create_sequential_actor_network(
    actor_fc_layers=actor_fc_layers, action_tensor_spec=action_tensor_spec
)

# Critic network with separate observation/action branches and joint layers.
critic_net = create_sequential_critic_network(
    obs_fc_layer_units=critic_obs_fc_layers,
    action_fc_layer_units=critic_action_fc_layers,
    joint_fc_layer_units=critic_joint_fc_layers,
)


NameError: name 'collect_env' is not defined

In [14]:
# Step counter handed to the agent here and, in later cells, to the learner,
# its triggers, and both actors.
train_step = train_utils.create_train_step()
# Build the SAC agent from the specs, networks and hyperparameters defined in
# the preceding cells. Actor, critic and alpha each get their own Adam
# optimizer with the corresponding learning rate.
agent = sac_agent.SacAgent(
    time_step_tensor_spec,
    action_tensor_spec,
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=tf.keras.optimizers.Adam(learning_rate=actor_learning_rate),
    critic_optimizer=tf.keras.optimizers.Adam(
        learning_rate=critic_learning_rate
    ),
    alpha_optimizer=tf.keras.optimizers.Adam(learning_rate=alpha_learning_rate),
    target_update_tau=target_update_tau,
    target_update_period=target_update_period,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=gamma,
    reward_scale_factor=reward_scale_factor,
    gradient_clipping=None,
    debug_summaries=debug_summaries,
    summarize_grads_and_vars=summarize_grads_and_vars,
    train_step_counter=train_step,
)
agent.initialize()

NameError: name 'time_step_tensor_spec' is not defined

Below we construct a replay buffer using Reverb. The replay buffer is populated with state-action-reward-state tuples during collect. This allows the agent to relive past experiences, and prevents the model from overfitting in the local neighborhood.

During training, the agent samples from the replay buffer. This helps decorrelate the training data in the same way that shuffling a training set would in supervised learning. Otherwise, in most environments the experience in a window of time is highly correlated.

In [24]:
# @title Set up the replay buffer
# NOTE: this reassigns `replay_capacity`, replacing the value set in the
# agent-parameters cell.
replay_capacity = 50000
table_name = 'uniform_table'
# Table with uniform sampling and FIFO removal once `max_size` is reached;
# sampling is allowed as soon as the table holds one item.
table = reverb.Table(
    table_name,
    max_size=replay_capacity,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
)

# Requires `output_data_path` from an earlier cell.
reverb_checkpoint_dir = output_data_path + "/reverb_checkpoint"
# NOTE(review): with port=None the saved output shows the server reporting a
# port of its own choosing (46071) - presumably auto-assigned; confirm.
reverb_port = None
print('reverb_checkpoint_dir=%s' %reverb_checkpoint_dir)
reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer(
    path=reverb_checkpoint_dir
)
reverb_server = reverb.Server(
    [table], port=reverb_port, checkpointer=reverb_checkpointer
)
logging_info('reverb_server_port=%d' %reverb_server.port)
# Replay buffer view over the local server; sequence_length=2 matches the
# `num_steps=2` used when the buffer is read back as a dataset.
reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
    agent.collect_data_spec,
    sequence_length=2,
    table_name=table_name,
    local_server=reverb_server,
)
# Observer that actors call with each trajectory to write it into the table.
rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
    reverb_replay.py_client, table_name, sequence_length=2, stride_length=1
)
print('num_frames in replay buffer=%d' %reverb_replay.num_frames())

reverb_checkpoint_dir=/burg/home/ssa2206/sbsim_dual_control/output/reverb_checkpoint


[reverb/cc/platform/tfrecord_checkpointer.cc:162]  Initializing TFRecordCheckpointer in /burg/home/ssa2206/sbsim_dual_control/output/reverb_checkpoint.
[reverb/cc/platform/tfrecord_checkpointer.cc:565] Loading latest checkpoint from /burg/home/ssa2206/sbsim_dual_control/output/reverb_checkpoint
[reverb/cc/platform/default/server.cc:71] Started replay server on port 46071


reverb_server_port=46071
num_frames in replay buffer=0


For simplicity, we'll grab the eval and collect policies and give them short variable names.

In [25]:
# @title Access the eval and collect policies
# Short aliases for the agent's two policies. NOTE: `collect_policy` is
# rebound to a `PyTFEagerPolicy` wrapper in the actor-definition cell below.
eval_policy = agent.policy
collect_policy = agent.collect_policy

In the next section we define observer classes that print model and environment output as the scenario evolves, showing you the percentage of the episode completed, the timestamp in the scenario, the cumulative reward, and the execution time.

We also provide a plot observer that periodically outputs the performance charts and the temperature gradient across both floors of the building.

In [26]:
# @title Define Observers
class RenderAndPlotObserver:
    """Observer that periodically redraws the charts and the environment rendering.

    Every call accumulates the trajectory's reward and bumps a step counter.
    On every `render_interval_steps`-th call (and only when an environment was
    supplied) the cell output is cleared, the time-series charts for the
    latest episode are plotted if the environment has a metrics path, and the
    environment is rendered.

    Args:
      render_interval_steps: Number of calls between two renders.
      environment: Environment to render; when `None` the observer only
        accumulates the reward and step count.
    """
    def __init__(self, render_interval_steps: int = 10, environment=None,):
        self._counter = 0
        self._render_interval_steps = render_interval_steps
        self._environment = environment
        self._cumulative_reward = 0.0
        self._start_time = None
        if self._environment is not None:
            self._num_timesteps_in_episode = self._environment._num_timesteps_in_episode

    def __call__(self, trajectory: trajectory_lib.Trajectory) -> None:
        """Record one step and render when the interval is reached."""
        self._cumulative_reward += trajectory.reward
        self._counter += 1
        if self._start_time is None:
            self._start_time = pd.Timestamp.now()

        if self._counter % self._render_interval_steps != 0 or not self._environment:
            return

        clear_output(wait=True)
        # Charts are only available when the environment writes metrics.
        if self._environment._metrics_path is not None:
            reader = get_latest_episode_reader(self._environment._metrics_path)
            # `time_zone` is a notebook-level variable defined in an earlier cell.
            plot_timeseries_charts(reader, time_zone)

        render_env(self._environment)

class PrintStatusObserver:
    """Observer that prints a one-line progress report while an episode runs.

    Every call accumulates the trajectory's reward and bumps a step counter.
    On every `status_interval_steps`-th call (and only when an environment was
    supplied) it prints the step number, percentage complete, simulation
    time, step reward, cumulative reward, elapsed and mean execution time,
    and optionally the replay buffer size.

    Args:
      status_interval_steps: Number of calls between two status lines.
      environment: Environment being stepped; when `None` nothing is printed.
      replay_buffer: Optional object exposing `num_frames()`; when given, its
        size is appended to the status line.
    """

    def __init__(self, status_interval_steps: int = 1, environment=None, replay_buffer=None):
        self._counter = 0
        self._status_interval_steps = status_interval_steps
        self._environment = environment
        self._cumulative_reward = 0.0
        self._replay_buffer = replay_buffer

        self._start_time = None
        if self._environment is not None:
            self._num_timesteps_in_episode = (
                    self._environment._num_timesteps_in_episode
            )

    def __call__(self, trajectory: trajectory_lib.Trajectory) -> None:
        """Record one step and print the status line when the interval is reached."""

        reward = trajectory.reward
        self._cumulative_reward += reward
        self._counter += 1
        if self._start_time is None:
            self._start_time = pd.Timestamp.now()

        if self._counter % self._status_interval_steps != 0 or not self._environment:
            return

        execution_time = pd.Timestamp.now() - self._start_time
        mean_execution_time = execution_time.total_seconds() / self._counter

        # `time_zone` is a notebook-level variable defined in an earlier cell.
        sim_time = self._environment.current_simulation_timestamp.tz_convert(
                time_zone
        )

        # Guard against a zero-length episode instead of dividing by zero.
        total_steps = self._num_timesteps_in_episode
        if total_steps:
            percent_complete = int(100.0 * (self._counter / total_steps))
        else:
            percent_complete = 0

        if self._replay_buffer is not None:
            rb_size = self._replay_buffer.num_frames()
            rb_string = " Replay Buffer Size: %d" % rb_size
        else:
            rb_string = ""

        print(
                "Step %5d of %5d (%3d%%) Sim Time: %s Reward: %2.2f Cumulative"
                " Reward: %8.2f Execution Time: %s Mean Execution Time: %3.2fs %s"
                % (
                        self._environment._step_count,
                        self._num_timesteps_in_episode,
                        percent_complete,
                        sim_time.strftime("%Y-%m-%d %H:%M"),
                        reward,
                        self._cumulative_reward,
                        execution_time,
                        mean_execution_time,
                        rb_string,
                )
        )


# One render/plot observer and one status observer per environment
# (initial-collect, collect, eval). Rendering happens every 144 steps; status
# lines are printed on every step and include the replay buffer size.
initial_collect_render_plot_observer = RenderAndPlotObserver(
    render_interval_steps=144, environment=initial_collect_env
)
initial_collect_print_status_observer = PrintStatusObserver(
    status_interval_steps=1,
    environment=initial_collect_env,
    replay_buffer=reverb_replay,
)
collect_render_plot_observer = RenderAndPlotObserver(
    render_interval_steps=144, environment=collect_env
)
collect_print_status_observer = PrintStatusObserver(
    status_interval_steps=1,
    environment=collect_env,
    replay_buffer=reverb_replay,
)
eval_render_plot_observer = RenderAndPlotObserver(
    render_interval_steps=144, environment=eval_env
)
eval_print_status_observer = PrintStatusObserver(
    status_interval_steps=1, environment=eval_env, replay_buffer=reverb_replay
)



In the following cell, we run the baseline control on the scenario to populate the replay buffer. We use the schedule policy we built above to simulate training off-policy from recorded telemetry.

In [27]:
# @title Populate the replay buffer with data from baseline control
# initial_collect_actor = actor.Actor(
#   initial_collect_env,
#   schedule_policy,
#   train_step,
#   steps_per_run=initial_collect_env._num_timesteps_in_episode,
#   observers=[rb_observer, initial_collect_print_status_observer, initial_collect_render_plot_observer])
# initial_collect_actor.run()
# reverb_replay.py_client.checkpoint()

Next wrap the replay buffer into a TF dataset.

In [28]:
# @title Make a TF Dataset
# Dataset generates trajectories with shape [Bx2x...]
# Each element is a (Trajectory, SampleInfo) pair, as the element_spec
# displayed below this cell shows; `num_steps=2` matches the buffer's
# `sequence_length=2`.
dataset = reverb_replay.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=batch_size,
    num_steps=2).prefetch(50)

# Bare expression so the notebook displays the dataset's element_spec.
dataset

<_PrefetchDataset element_spec=(Trajectory(
{'action': TensorSpec(shape=(256, 2, 2), dtype=tf.float32, name=None),
 'discount': TensorSpec(shape=(256, 2), dtype=tf.float32, name=None),
 'next_step_type': TensorSpec(shape=(256, 2), dtype=tf.int32, name=None),
 'observation': TensorSpec(shape=(256, 2, 53), dtype=tf.float32, name=None),
 'policy_info': (),
 'reward': TensorSpec(shape=(256, 2), dtype=tf.float32, name=None),
 'step_type': TensorSpec(shape=(256, 2), dtype=tf.int32, name=None)}), SampleInfo(key=TensorSpec(shape=(256, 2), dtype=tf.uint64, name=None), probability=TensorSpec(shape=(256, 2), dtype=tf.float64, name=None), table_size=TensorSpec(shape=(256, 2), dtype=tf.int64, name=None), priority=TensorSpec(shape=(256, 2), dtype=tf.float64, name=None), times_sampled=TensorSpec(shape=(256, 2), dtype=tf.int32, name=None)))>

Here, we extract the collect and evaluation policies for training.

In [29]:
# @title Convert the policies into TF Eager Policies

# Wrap the agent's collect policy in a PyTFEagerPolicy; `agent_collect_policy`
# is the policy later handed to the model-based rollout generator.
tf_collect_policy = agent.collect_policy
agent_collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy, use_tf_function=True)

# Same wrapping for the agent's main policy.
tf_policy = agent.policy
agent_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy, use_tf_function=True)

We will set the interval for saving the policies and for writing the critic, actor, and alpha losses.


In [30]:
# Save the policy after every learning step.
policy_save_interval = 1
# Produce a summary of the critic, actor, and alpha losses after every
# gradient update step.
learner_summary_interval = 1

In the following cell we will define the agent learner, a TF-Agents wrapper around the process that performs gradient-based updates to the actor and critic networks in the agent.

You should see a statement that shows you where the policies will be saved.

In [31]:
# @title Define an Agent Learner

# The learner expects a zero-argument callable returning the dataset.
experience_dataset_fn = lambda: dataset

# Requires `root_dir` from an earlier cell.
saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
print('Policies will be saved to saved_model_dir: %s' % saved_model_dir)
# Environment-step metric; also registered as an observer on the collect
# actor and attached as metadata to saved policies.
env_step_metric = py_metrics.EnvironmentSteps()
learning_triggers = [
      # Saves the policies every `policy_save_interval` train steps.
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric},
      ),
      # Steps-per-second logging every 10 train steps.
      triggers.StepPerSecondLogTrigger(train_step, interval=10),
]

agent_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=None,
      summary_interval=learner_summary_interval,
)


Policies will be saved to saved_model_dir: /burg/home/ssa2206/sbsim_dual_control/root/policies




INFO:tensorflow:Assets written to: /burg/home/ssa2206/sbsim_dual_control/root/policies/policy/assets


INFO:tensorflow:Assets written to: /burg/home/ssa2206/sbsim_dual_control/root/policies/policy/assets


INFO:tensorflow:Assets written to: /burg/home/ssa2206/sbsim_dual_control/root/policies/collect_policy/assets


INFO:tensorflow:Assets written to: /burg/home/ssa2206/sbsim_dual_control/root/policies/collect_policy/assets


INFO:tensorflow:Assets written to: /burg/home/ssa2206/sbsim_dual_control/root/policies/greedy_policy/assets


INFO:tensorflow:Assets written to: /burg/home/ssa2206/sbsim_dual_control/root/policies/greedy_policy/assets


Set the number of training steps in a training iteration. This is the number of collect steps between gradient updates.

Here we set the number of training steps to the length of a full episode.

In [32]:
# Collect one full episode of environment steps per training iteration.
# (The identifier's spelling is kept because the actor cell refers to it.)
collect_steps_per_treining_iteration = collect_env._num_timesteps_in_episode

Next, we will define a *collect actor* and an *eval actor* that wrap the policy and the environment, and can execute and collect metrics.

The principal difference between the collect actor and the eval actor is that the collect actor chooses actions by sampling from the actor network's distribution, favoring actions with a high probability over actions with a lower probability. This stochastic property enables the agent to explore better actions and improve the policy.

However, the eval actor always chooses the action associated with the highest probability.

In [25]:
# @title Define a TF-Agents Actor for collect and eval
# Rebinds `tf_collect_policy` and `collect_policy` (both assigned in earlier
# cells); `collect_policy` is now the PyTFEagerPolicy wrapper.
tf_collect_policy = agent.collect_policy
collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
    tf_collect_policy, use_tf_function=True
)
# Collect actor: runs one episode's worth of steps per `run()` call and feeds
# every trajectory to the replay buffer, the env-step metric and the
# status/plot observers.
collect_actor = actor.Actor(
    collect_env,
    collect_policy,
    train_step,
    steps_per_run=collect_steps_per_treining_iteration,
    metrics=actor.collect_metrics(1),
    summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
    summary_interval=1,
    observers=[
        rb_observer,
        env_step_metric,
        collect_print_status_observer,
        collect_render_plot_observer,
    ],
)

# Greedy wrapper around the agent's policy, used for evaluation.
tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
    tf_greedy_policy, use_tf_function=True
)

# Eval actor: runs one episode per call.
# NOTE(review): `rb_observer` is among the eval observers, so evaluation
# trajectories are also written to the replay buffer - confirm this is intended.
eval_actor = actor.Actor(
    eval_env,
    eval_greedy_policy,
    train_step,
    episodes_per_run=1,
    metrics=actor.eval_metrics(1),
    summary_dir=os.path.join(root_dir, 'eval'),
    summary_interval=1,
    observers=[rb_observer, eval_print_status_observer, eval_render_plot_observer],
)

# Define the World Model




In [28]:
class TransformerBlock(tf.keras.layers.Layer):
    """Self-attention block: attention and feed-forward, each with a residual and layer norm.

    Args:
      embed_dim: Width of the block's input/output; also passed as `key_dim`
        to the attention layer.
      num_heads: Number of attention heads.
      ff_dim: Hidden width of the two-layer feed-forward sub-network.
      rate: Dropout rate applied after attention and after the feed-forward
        sub-network.
    """
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        self.att = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        feed_forward_layers = [
            tf.keras.layers.Dense(ff_dim, activation="relu"),
            tf.keras.layers.Dense(embed_dim),
        ]
        self.ffn = tf.keras.Sequential(feed_forward_layers)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, inputs, training=False):
        # Sub-layer 1: self-attention (query and value are both `inputs`),
        # dropout, residual connection, layer norm.
        attended = self.dropout1(self.att(inputs, inputs), training=training)
        attention_out = self.layernorm1(inputs + attended)

        # Sub-layer 2: feed-forward, dropout, residual connection, layer norm.
        transformed = self.dropout2(self.ffn(attention_out), training=training)
        return self.layernorm2(attention_out + transformed)

class TransformerWorldModel(tf.keras.Model):
    """Transformer-based world model for predicting next states and rewards.

    States and actions are embedded to `hidden_dim`, summed with a learned
    position embedding, passed through a stack of `TransformerBlock`s, and
    decoded by three linear heads: next state, reward, and a per-state-feature
    uncertainty value (used like a log-variance in `train_step`).

    Inputs may be sequences `[batch, seq_len, dim]` or single steps
    `[batch, dim]`. Single-step inputs are treated as sequences of length 1
    and the outputs are returned without the sequence axis.

    Args:
      state_dim: Size of the state vector; sets the width of the next-state
        and uncertainty heads.
      action_dim: Size of the action vector. Not used to size any layer (the
        action embedding infers its input width); kept for interface
        compatibility.
      hidden_dim: Embedding / transformer width.
      num_heads: Attention heads per transformer block.
      num_transformer_blocks: Number of stacked transformer blocks.
      ff_dim: Hidden width of each block's feed-forward sub-network.
      dropout_rate: Dropout rate inside each block.
      name: Keras model name.
      max_sequence_length: Number of rows in the position-embedding table;
        input sequences must not be longer than this.
    """
    def __init__(
        self,
        state_dim,
        action_dim,
        hidden_dim=256,
        num_heads=8,
        num_transformer_blocks=3,
        ff_dim=512,
        dropout_rate=0.1,
        name="transformer_world_model",
        max_sequence_length=50,
    ):
        super().__init__(name=name)

        self.state_embedding = tf.keras.layers.Dense(
            hidden_dim,
            name="state_embedding"
        )
        self.action_embedding = tf.keras.layers.Dense(
            hidden_dim,
            name="action_embedding"
        )

        self.position_embedding = tf.keras.layers.Embedding(
            input_dim=max_sequence_length,
            output_dim=hidden_dim,
            name="position_embedding"
        )

        self.transformer_blocks = [
            TransformerBlock(hidden_dim, num_heads, ff_dim, dropout_rate)
            for _ in range(num_transformer_blocks)
        ]

        self.next_state_head = tf.keras.layers.Dense(
            state_dim,
            name="next_state_prediction"
        )
        self.reward_head = tf.keras.layers.Dense(
            1,
            name="reward_prediction"
        )

        self.uncertainty_head = tf.keras.layers.Dense(
            state_dim,
            name="uncertainty_estimation"
        )

    def call(self, inputs, training=False):
        """Predict next states, rewards and uncertainties.

        Args:
          inputs: Pair `(states, actions)`, each `[batch, seq_len, dim]` or
            `[batch, dim]`.
          training: Whether dropout is active.

        Returns:
          `(next_states, rewards, uncertainties)`. For rank-3 inputs the
          shapes are `[batch, seq_len, state_dim]`, `[batch, seq_len, 1]`,
          `[batch, seq_len, state_dim]`; for rank-2 inputs the `seq_len` axis
          is removed.
        """
        states, actions = inputs

        # Callers in this notebook pass single steps shaped [batch, dim].
        # Without a sequence axis the feature axis would be mistaken for the
        # sequence length, so add a length-1 sequence axis and remove it
        # again from the outputs.
        single_step = len(states.shape) == 2
        if single_step:
            states = tf.expand_dims(states, axis=1)
            actions = tf.expand_dims(actions, axis=1)

        positions = tf.range(start=0, limit=tf.shape(states)[1], delta=1)
        positions = tf.expand_dims(positions, axis=0)

        state_emb = self.state_embedding(states)  # [batch_size, seq_len, hidden_dim]
        action_emb = self.action_embedding(actions)  # [batch_size, seq_len, hidden_dim]
        pos_emb = self.position_embedding(positions)  # [1, seq_len, hidden_dim]

        x = state_emb + action_emb + pos_emb

        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, training=training)

        next_states = self.next_state_head(x)
        rewards = self.reward_head(x)
        uncertainties = self.uncertainty_head(x)

        if single_step:
            next_states = tf.squeeze(next_states, axis=1)
            rewards = tf.squeeze(rewards, axis=1)
            uncertainties = tf.squeeze(uncertainties, axis=1)

        return next_states, rewards, uncertainties

    def train_step(self, data):
        """Run one gradient update.

        Args:
          data: Tuple `(states, actions, next_states, rewards)`. `rewards`
            may omit the trailing size-1 axis of the reward prediction.

        Returns:
          Dict with the scalar tensors 'loss', 'state_loss', 'reward_loss'
          and 'uncertainty_loss'.

        The model must have been compiled so that `self.optimizer` is set.
        """
        states, actions, next_states, rewards = data

        with tf.GradientTape() as tape:
            pred_next_states, pred_rewards, uncertainties = self((states, actions), training=True)

            # The reward head emits a trailing axis of size 1. Reshape the
            # targets to match so the subtraction is element-wise rather
            # than a [B] vs [B, 1] broadcast into a [B, B] matrix.
            rewards = tf.reshape(
                tf.cast(rewards, pred_rewards.dtype), tf.shape(pred_rewards)
            )

            squared_state_error = tf.square(next_states - pred_next_states)
            state_loss = tf.reduce_mean(squared_state_error)
            reward_loss = tf.reduce_mean(tf.square(rewards - pred_rewards))

            # Error weighted by exp(-u) plus u: large predicted uncertainty
            # discounts the error but is itself penalised.
            uncertainty_loss = tf.reduce_mean(
                tf.exp(-uncertainties) * squared_state_error + uncertainties
            )

            total_loss = state_loss + reward_loss + 0.1 * uncertainty_loss

        grads = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        return {
            "loss": total_loss,
            "state_loss": state_loss,
            "reward_loss": reward_loss,
            "uncertainty_loss": uncertainty_loss
        }

Modify the replay buffer to use the model

In [29]:
class ModelAssistedReplayBuffer:
    """Generates short imagined rollouts from a learned world model.

    Attributes stored from the constructor: `real_buffer`, `world_model`,
    `rollout_length`, `rollout_ratio`, `uncertainty_threshold`. Within this
    class only `world_model`, `rollout_length` and `uncertainty_threshold`
    are read; `real_buffer` and `rollout_ratio` are kept for callers.
    """
    def __init__(
        self,
        real_buffer: reverb.Client,
        world_model: TransformerWorldModel,
        rollout_length: int = 5,
        rollout_ratio: float = 0.5,
        uncertainty_threshold: float = 0.5
    ):
        self.real_buffer = real_buffer
        self.world_model = world_model
        self.rollout_length = rollout_length
        self.rollout_ratio = rollout_ratio
        self.uncertainty_threshold = uncertainty_threshold

    def generate_rollouts(self, initial_state, initial_action, policy):
        """Roll the world model forward from one (state, action) pair.

        Args:
          initial_state: Starting state; wrapped into a batch of one.
          initial_action: Action taken in `initial_state`; wrapped likewise.
          policy: Object whose `action(time_step)` result has an `.action`
            attribute; queried for the action at each predicted state.

        Returns:
          List of dicts with keys 'state', 'action', 'reward', 'next_state'
          and 'uncertainty', one per accepted step. The rollout stops early
          as soon as the mean predicted uncertainty exceeds
          `uncertainty_threshold`, so the list may be shorter than
          `rollout_length` (or empty).
        """
        state_batch = tf.convert_to_tensor([initial_state], dtype=tf.float32)
        action_batch = tf.convert_to_tensor([initial_action], dtype=tf.float32)

        rollout = []

        for _step in range(self.rollout_length):
            model_inputs = (state_batch, action_batch)
            predicted_state, predicted_reward, predicted_uncertainty = self.world_model(
                model_inputs, training=False
            )

            # Stop imagining once the model is no longer confident.
            mean_uncertainty = tf.reduce_mean(predicted_uncertainty)
            if mean_uncertainty > self.uncertainty_threshold:
                break

            # Ask the policy what it would do in the predicted state.
            scalar_reward = predicted_reward[0][0]
            imagined_time_step = ts.transition(
                predicted_state[0], scalar_reward, discount=1.0
            )
            following_action = policy.action(imagined_time_step).action

            step_record = {
                'state': state_batch[0],
                'action': action_batch[0],
                'reward': scalar_reward,
                'next_state': predicted_state[0],
                'uncertainty': predicted_uncertainty[0]
            }
            rollout.append(step_record)

            # Advance: the prediction becomes the next input.
            state_batch = predicted_state
            action_batch = tf.convert_to_tensor([following_action], dtype=tf.float32)

        return rollout

Define function to train the world model

In [30]:
def train_world_model(
    world_model: TransformerWorldModel,
    replay_buffer: reverb.Client,
    batch_size: int = 256,
    training_steps: int = 1000,
    learning_rate: float = 1e-4,
):
    """Train world model on real experience.

    Args:
      world_model: Model to train. It is compiled with an Adam optimizer on
        first use; later calls reuse the existing optimizer so its state
        carries over between training iterations.
      replay_buffer: Object exposing `as_dataset(...)` (this notebook passes
        `reverb_replay`). Each dataset element is expected to be a
        `(trajectory, sample_info)` pair whose trajectory fields have a
        leading `[batch, 2, ...]` shape.
      batch_size: Number of two-step sequences per sampled batch.
      training_steps: Number of batches (gradient updates) to run.
      learning_rate: Adam learning rate used when the model is compiled here.

    Returns:
      Scalar tensor with the mean total loss over all steps of this call.
    """
    # Compile only when the model has no optimizer yet. Re-compiling on
    # every call would create a fresh Adam instance and discard its state.
    if getattr(world_model, 'optimizer', None) is None:
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        world_model.compile(optimizer=optimizer)

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2
    ).prefetch(3)

    losses = []
    # The dataset yields (trajectory, sample_info); only the trajectory is used.
    for step, (experience, _) in enumerate(dataset.take(training_steps)):
        # Index 0 along the time axis is the current step, index 1 the next.
        states = experience.observation[:, 0]
        actions = experience.action[:, 0]
        next_states = experience.observation[:, 1]
        rewards = experience.reward[:, 0]

        metrics = world_model.train_step((states, actions, next_states, rewards))
        losses.append(metrics['loss'])

        if step % 100 == 0:
            avg_loss = tf.reduce_mean(losses[-100:])
            logging_info(f'World Model Training Step {step}, Avg Loss: {avg_loss:.4f}')

    return tf.reduce_mean(losses)

Finally we're ready to execute the RL training loop with SAC!

You can specify the total number of training iterations, and the number of gradient steps per iteration. With fewer steps, the model will train more slowly, but more steps may make the agent less stable.

In [33]:
# @title Execute the training loop

num_training_iterations = 10
num_gradient_updates_per_training_iteration = 100

# Collect the performance results with the untrained model.
eval_actor.run_and_log()

logging_info('Training.')

# World model sized from the environment's observation and action specs.
world_model = TransformerWorldModel(
    state_dim=collect_env.observation_spec().shape[0],
    action_dim=collect_env.action_spec().shape[0],
    hidden_dim=256,
    num_heads=8,
    num_transformer_blocks=3
)

# Rollout generator that imagines short trajectories with the world model.
model_buffer = ModelAssistedReplayBuffer(
    real_buffer=reverb_replay,
    world_model=world_model,
    rollout_length=5,
    rollout_ratio=0.5,
    uncertainty_threshold=0.5
)


# log_dir = root_dir + '/train'
# with tf.summary.create_file_writer(log_dir).as_default() as writer:

# The loop variable must not be called `iter`: the body uses the builtin
# `iter()` to pull a batch from the dataset.
for training_iteration in range(num_training_iterations):
    print('Training iteration: ', training_iteration)

    # 1. Collect one episode of real experience with the current policy.
    collect_actor.run()

    # 2. Fit the world model on the real experience gathered so far.
    world_model_loss = train_world_model(
        world_model=world_model,
        replay_buffer=reverb_replay,
        batch_size=batch_size,
        training_steps=1000
    )
    logging_info(f'World Model Loss: {world_model_loss:.4f}')

    # 3. Sample one real batch to seed the imagined rollouts. The dataset
    #    yields (trajectory, sample_info) pairs; only the trajectory is needed.
    real_batch, _ = next(iter(reverb_replay.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2
    )))

    initial_states = real_batch.observation[:, 0]
    initial_actions = real_batch.action[:, 0]

    for initial_state, initial_action in zip(initial_states, initial_actions):
        generated_experience = model_buffer.generate_rollouts(
            initial_state,
            initial_action,
            agent_collect_policy
        )

        # NOTE(review): `trajectory.from_episode` is documented for
        # whole-episode inputs with a leading time axis, while each `exp`
        # here is a single step. Confirm this is the intended constructor
        # for writing imagined steps to the replay buffer.
        for exp in generated_experience:
            rb_observer(trajectory.from_episode(
                observation=exp['state'],
                action=exp['action'],
                policy_info=(),
                reward=exp['reward'],
                discount=1.0
            ))

    # 4. Update the SAC agent from the (real + imagined) replay buffer.
    loss_info = agent_learner.run(
        iterations=num_gradient_updates_per_training_iteration
    )

    logging_info(
        'Actor Loss: %6.2f, Critic Loss: %6.2f, Alpha Loss: %6.2f'
        % (
            loss_info.extra.actor_loss.numpy(),
            loss_info.extra.critic_loss.numpy(),
            loss_info.extra.alpha_loss.numpy(),
        )
    )

    # 5. Evaluate the updated policy on a fresh episode.
    eval_env.reset()
    eval_actor.run_and_log()

rb_observer.close()
reverb_server.stop()

NameError: name 'eval_actor' is not defined