In [1]:
from online.Environment import BatteryScheduling
from online.EnergyDataset import EnergyDataset
from torchrl.envs.utils import check_env_specs
from hydra import initialize, compose
from torchrl.modules import MLP, EGreedyModule, QValueModule, Actor,OrnsteinUhlenbeckProcessModule
import torch
from torchrl.envs import (
    CatTensors,
    EnvBase,
    Transform,
    TransformedEnv,
    UnsqueezeTransform,
    Compose,
    InitTracker
)

from tensordict.nn import TensorDictModule, TensorDictSequential

from torchrl.objectives import DQNLoss, ValueEstimators, SoftUpdate, DDPGLoss
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, ReplayBuffer, RandomSampler

In [2]:
# Training split of the processed energy and price series.
# NOTE(review): the positional ints 48 and 1 are interpreted by EnergyDataset;
# the 48 forecast values visible in the observation printout further down
# suggest 48 is a window/horizon length — confirm against EnergyDataset.
ds = EnergyDataset(
    '../data/1_processed/energy.csv',
    '../data/1_processed/price.csv',
    48,
    1,
    'train',
)

In [3]:
# Scratch check: adding a trailing dim turns shape (3,) into (3, 1).
torch.tensor([1, 2, 3]).unsqueeze(-1)

tensor([[1],
        [2],
        [3]])

In [4]:
from tensordict import TensorDict, TensorDictBase

In [5]:
# Scratch check: an existing TensorDict entry can be overwritten with item
# assignment; reading the scalar entry back yields a 0-d tensor.
test = TensorDict({'bla': 0}, batch_size=[])
test['bla'] = 1
test['bla']

tensor(1)

In [6]:
# Scratch check: a 1-element float tensor displayed next to a plain list.
(torch.ones(1), [1])

(tensor([1.]), [1])

In [7]:
def makeEnv(cfg, seed, ds):
    """Build a ``BatteryScheduling`` env wrapped with observation transforms.

    Args:
        cfg: composed Hydra config, forwarded to ``BatteryScheduling``.
        seed: seed forwarded to ``BatteryScheduling``.
        ds: dataset forwarded to ``BatteryScheduling``.

    Returns:
        A ``TransformedEnv`` whose transforms run in this order:
        ``InitTracker``; an ``UnsqueezeTransform`` adding a trailing dim to
        ``soe``/``prosumption``/``price``/``cost``; and ``CatTensors``
        concatenating those keys plus ``price_forecast`` along the last dim
        into ``'observation'``. The source keys are kept (``del_keys=False``).
    """
    # Forward the caller's seed instead of a hard-coded constant.
    env = BatteryScheduling(cfg, seed, ds)

    # Keys that get a trailing feature dim so they can be concatenated.
    # 'price_forecast' is deliberately not unsqueezed.
    # NOTE(review): assumes 'price_forecast' already carries a trailing
    # feature dim — TODO confirm against BatteryScheduling's observation spec.
    unsqueeze_keys = ['soe', 'prosumption', 'price', 'cost']
    observation_keys = ['soe', 'prosumption', 'price', 'price_forecast', 'cost']

    transforms = Compose(
        # NOTE(review): InitTracker marks reset steps; presumably needed by the
        # OrnsteinUhlenbeck exploration module used with this env — confirm.
        InitTracker(),
        UnsqueezeTransform(
            dim=-1,
            in_keys=list(unsqueeze_keys),
            in_keys_inv=list(unsqueeze_keys),
        ),
        CatTensors(
            dim=-1,
            in_keys=observation_keys,
            out_key='observation',
            del_keys=False,
        ),
    )
    return TransformedEnv(env, transforms)

In [13]:
# Seed torch's global RNG so the MLP weight initialisation below (and,
# presumably, the OU exploration noise) is reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

with initialize(version_base=None, config_path="conf/"):
    cfg = compose(config_name='config.yaml')
    env_transformed = makeEnv(cfg, SEED, ds)
    # Spec validation is disabled; re-enable to check the transformed env.
    # check_env_specs(env_transformed)

    # Actor: maps the concatenated 'observation' to 'action'.
    # Tanh activations, two hidden layers of 400 and 300 units.
    policy_net = MLP(
        in_features=env_transformed.observation_spec['observation'].shape[-1],
        out_features=env_transformed.action_spec.shape.numel(),
        depth=2,
        num_cells=[400, 300],
        activation_class=torch.nn.Tanh,
    )

    policy_module = TensorDictModule(
        module=policy_net,
        in_keys=['observation'],
        out_keys=['action'],
    )

    # NOTE(review): policy_module already reads/writes these keys; wrapping it
    # in Actor again is redundant (Actor could wrap policy_net directly).
    policy = Actor(
        module=policy_module,
        spec=env_transformed.full_action_spec['action'],
        in_keys=['observation'],
        out_keys=['action'],
    )

    # Exploration policy: the actor followed by Ornstein-Uhlenbeck noise on 'action'.
    exploration_policy = TensorDictSequential(
        policy,
        OrnsteinUhlenbeckProcessModule(
            spec=policy.spec,
        ),
    )

    # Critic: scores an (observation, action) pair as 'state_action_value'.
    # Its input width is obs width + action width.
    # NOTE(review): not used in this cell yet; presumably intended for
    # DDPGLoss (imported above) — TODO wire up.
    critic_module = TensorDictModule(
        module=MLP(
            in_features=env_transformed.observation_spec['observation'].shape[-1] + env_transformed.full_action_spec['action'].shape.numel(),
            out_features=1,
            depth=2,
            num_cells=[400, 300],
            activation_class=torch.nn.Tanh,
        ),
        in_keys=['observation', 'action'],
        out_keys=['state_action_value'],
    )

    # Pass a factory, matching the `create_env_fn` contract; the collector
    # builds its own env instance from it.
    collector = SyncDataCollector(
        create_env_fn=lambda: makeEnv(cfg, SEED, ds),
        policy=exploration_policy,
        frames_per_batch=100,
        total_frames=100,
    )

    replay_buffer = ReplayBuffer(
        storage=LazyMemmapStorage(
            max_size=1000000,  # capacity: up to 1,000,000 stored transitions
        ),
        sampler=RandomSampler(),  # uniform random sampling
        batch_size=1,  # number of transitions returned per sample() call
    )

    # NOTE(review): batches are only printed; nothing is added to
    # replay_buffer and no optimisation step runs yet — TODO.
    for iteration, batch in enumerate(collector):
        # Debug output: the first action in the batch and the observation
        # at index 1.
        print(batch['action'][0])
        print(batch['observation'][1])


tensor([0.1123])
tensor([-0.5595,  0.3062,  0.2998,  0.2998,  0.2805,  0.2805,  0.2770,  0.2770,
         0.2836,  0.2836,  0.2867,  0.2867,  0.3532,  0.3532,  0.4202,  0.4202,
         0.4154,  0.4154,  0.3690,  0.3690,  0.3422,  0.3422,  0.3199,  0.3199,
         0.3004,  0.3004,  0.2877,  0.2877,  0.2943,  0.2943,  0.2999,  0.2999,
         0.3004,  0.3004,  0.3494,  0.3494,  0.3728,  0.3728,  0.4258,  0.4258,
         0.4242,  0.4242,  0.3968,  0.3968,  0.3698,  0.3698,  0.2839,  0.2839,
         0.2874,  0.2874,  1.7150,  0.1123])
