In [None]:
from online.Environment import BatteryScheduling
from online.EnergyDataset import EnergyDataset
from torchrl.envs.utils import check_env_specs
from hydra import initialize, compose
from torchrl.modules import MLP, Actor,OrnsteinUhlenbeckProcessModule
import torch
from torchrl.envs import (
    CatTensors,
    TransformedEnv,
    UnsqueezeTransform,
    Compose,
    InitTracker,
)

from tensordict.nn import TensorDictModule, TensorDictSequential

from torchrl.objectives import ValueEstimators, SoftUpdate, DDPGLoss
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, ReplayBuffer, RandomSampler

In [None]:
# Both splits read the same processed CSV files; only the split name differs.
ENERGY_CSV = '../data/1_processed/energy.csv'
PRICE_CSV = '../data/1_processed/price.csv'

# NOTE(review): the positional arguments 10 and 1 are forwarded unchanged; their
# meaning is defined by EnergyDataset, which is not visible here -- confirm.
ds_train, ds_test = (
    EnergyDataset(ENERGY_CSV, PRICE_CSV, 10, 1, split_name)
    for split_name in ('train', 'test')
)

In [None]:
def makeEnv(cfg, seed, ds):
    """Create a BatteryScheduling environment wrapped with the agent's transforms.

    Args:
        cfg: Configuration object forwarded to ``BatteryScheduling``.
        seed: Seed forwarded to ``BatteryScheduling``.
        ds: Dataset forwarded to ``BatteryScheduling``.

    Returns:
        A ``TransformedEnv`` whose transforms (in order) track episode starts,
        unsqueeze the last dimension of 'soe', 'prosumption', 'price' and
        'cost', and concatenate the feature entries into 'observation'.
    """
    unsqueezed_keys = ['soe', 'prosumption', 'price', 'cost']
    observation_parts = [
        'time_feature',
        'soe',
        'prosumption',
        'prosumption_forecast',
        'price',
        'price_forecast',
    ]

    # Forward the caller's seed; it was previously ignored in favour of a
    # hard-coded 42.
    env = BatteryScheduling(cfg, seed, ds)

    transforms = Compose(
        InitTracker(),
        # Separate list copies so the forward and inverse key lists stay independent.
        UnsqueezeTransform(
            dim=-1,
            in_keys=list(unsqueezed_keys),
            in_keys_inv=list(unsqueezed_keys),
        ),
        # del_keys=False keeps the individual entries next to 'observation'.
        CatTensors(
            dim=-1,
            in_keys=observation_parts,
            out_key='observation',
            del_keys=False,
        ),
    )
    return TransformedEnv(env, transforms)

In [None]:
with initialize(version_base=None, config_path="conf/"):
    # Load the Hydra config and build the training environment.
    cfg = compose(config_name='config.yaml')
    env_train = makeEnv(cfg, 42, ds_train)

    # Uncomment to validate the environment specs before training:
    # check_env_specs(env_train)

    # ---- Actor: MLP mapping 'observation' -> 'action' ----
    policy_net = MLP(
        in_features=env_train.observation_spec['observation'].shape[-1],
        out_features=env_train.action_spec.shape.numel(),
        depth=2,
        num_cells=[400,300],
        activation_class=torch.nn.ReLU,
    )

    policy_module = TensorDictModule(
        module=policy_net,
        in_keys=['observation'],
        out_keys=['action']
    )

    actor = Actor(
        module=policy_module,
        spec=env_train.full_action_spec['action'],
        in_keys=['observation'],
        out_keys=['action'],
    )

    # ---- Exploration: Ornstein-Uhlenbeck noise added on top of the actor ----
    # Both annealing horizons of the OU module are set to 25_000 steps.
    ou = OrnsteinUhlenbeckProcessModule(
        annealing_num_steps=25_000,
        n_steps_annealing=25_000,
        spec=actor.spec.clone(),
    )

    exploration_policy = TensorDictSequential(
        actor,
        ou
    )

    # ---- Critic: MLP mapping ('observation', 'action') -> 'state_action_value' ----
    critic_module = TensorDictModule(
        module=MLP(
            in_features=env_train.observation_spec['observation'].shape[-1] + env_train.full_action_spec['action'].shape.numel(),
            out_features=1,
            depth=2,
            num_cells=[400,300],
            activation_class=torch.nn.ReLU,
        ),
        in_keys=['observation', 'action'],
        out_keys=['state_action_value']
    )

    # ---- Data collection: a dedicated env instance driven by the noisy policy ----
    collector = SyncDataCollector(create_env_fn=makeEnv(cfg, 42, ds_train),
                                  policy=exploration_policy,
                                  frames_per_batch=100,
                                  total_frames=100_000,)

    # ---- Replay buffer: memory-mapped storage, uniform sampling ----
    replay_buffer = ReplayBuffer(
        storage=LazyMemmapStorage(
            max_size=250_000,  # Maximum number of stored transitions
        ),
        sampler=RandomSampler(),
        batch_size=32,  # Number of transitions per sampled mini-batch
    )

    # ---- DDPG loss with a delayed (target) value network ----
    loss_module = DDPGLoss(
        actor_network=actor,
        value_network=critic_module,
        delay_value=True,
    )

    loss_module.make_value_estimator(ValueEstimators.TD0, gamma=0.99)

    target_updater = SoftUpdate(loss_module,tau=0.001)

    # Optimise the parameter containers exposed by the loss module. The
    # `actor_network` / `value_network` attributes of the loss are stateless
    # copies in recent TorchRL versions, so `.parameters()` on them would not
    # hand the optimiser the tensors that actually receive gradients.
    optimisers = {
        "loss_actor": torch.optim.Adam(
            list(loss_module.actor_network_params.flatten_keys().values()), lr=1e-4
        ),
        "loss_value": torch.optim.Adam(
            list(loss_module.value_network_params.flatten_keys().values()), lr=1e-3
        ),
    }

    # Number of gradient updates performed after each collected batch.
    updates_per_batch = 1

    for iteration, batch in enumerate(collector):
        # Advance the OU annealing schedule by the number of collected frames.
        current_frames = batch.numel()
        exploration_policy[-1].step(current_frames)
        replay_buffer.extend(batch)

        for _ in range(updates_per_batch):
            sample = replay_buffer.sample()
            loss_vals = loss_module(sample)
            # Actor and critic are optimised separately, each with its own optimiser.
            for loss_name in ["loss_actor", "loss_value"]:
                loss = loss_vals[loss_name]
                loss.backward()
                optimiser = optimisers[loss_name]
                optimiser.step()
                optimiser.zero_grad()

            # Soft-update the target network after every optimisation step.
            target_updater.step()

    # Release the collector (and the environment it owns) once training is done.
    collector.shutdown()
    

In [None]:
# Evaluate the trained actor (without the OU exploration noise) on the test split.
env_test = makeEnv(cfg, 42, ds_test)
env_test.reset()
# Evaluation needs no gradients: disabling autograd avoids keeping a graph
# alive across the whole rollout.
# NOTE(review): max_steps is very large, presumably so the rollout only stops
# when the environment reports done -- confirm.
with torch.no_grad():
    tensordict_result = env_test.rollout(max_steps=10000000, policy=actor)

In [None]:
# Show the 'cost' entry stored under 'next' for the last step of the rollout.
last_step = tensordict_result[-1]
last_step['next', 'cost']

In [None]:
import pandas as pd

# Plot the 'soe' entry over the first 100 steps of the test rollout.
soe_values = tensordict_result[0:100]['soe'].detach().numpy()
pd.DataFrame(soe_values).plot()

In [None]:
# Plot the 'action' entry for rollout steps 5700 to 5799.
action_values = tensordict_result[5700:5800]['action'].detach().numpy()
pd.DataFrame(action_values).plot()

In [None]:
# Plot the 'price' entry for rollout steps 5700 to 5799.
price_values = tensordict_result[5700:5800]['price'].detach().numpy()
pd.DataFrame(price_values).plot()

In [None]:
# Plot the reward stored under 'next' for rollout steps 5700 to 5799.
reward_values = tensordict_result[5700:5800]['next', 'reward'].detach().numpy()
pd.DataFrame(reward_values).plot()