# electricity_market_env

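> A Gymnasium environment in which an agent charges or discharges a battery, observing its state of charge together with the current electricity demand and price.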

In [None]:
#| default_exp electricity_market_env

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import gymnasium as gym
import numpy as np

In [None]:
#| export
class ElectricityMarketEnv(gym.Env):
    """Gymnasium environment for charging/discharging a battery against demand and price.

    The observation is ``[state_of_charge, demand_of_electricity, price]`` and the
    action is a single signed amount added to the battery's state of charge.
    Subclasses must implement :meth:`_yield_demand_of_electricity` and
    :meth:`_yield_price`, which supply the demand and price after each step.

    Args:
        env_config: mapping with the keys ``'battery_capacity'``,
            ``'init_state_of_charge'``, ``'init_current_demand_of_electricity'``
            and ``'init_current_price'``.
    """

    def __init__(self, env_config):
        # Environment configuration; the config is kept so reset() can restore
        # the initial state.
        self._battery_capacity = env_config['battery_capacity']
        self._config = env_config

        # One action moves the state of charge by at most one full capacity in
        # either direction.
        self.action_space = gym.spaces.Box(
            low=-self._battery_capacity,
            high=self._battery_capacity,
            shape=(1,),
            dtype=np.float64,
        )
        self.observation_space = gym.spaces.Box(
            low=np.zeros(3, dtype=np.float64),
            high=np.array([self._battery_capacity, np.inf, np.inf], dtype=np.float64),
            shape=(3,),
            dtype=np.float64,
        )

        # State of the environment.
        self._restore_initial_state()

    def _restore_initial_state(self):
        """(Re)load state of charge, demand and price from the initial config."""
        self._current_state_of_charge = self._config['init_state_of_charge']
        self._current_demand_of_electricity = self._config['init_current_demand_of_electricity']
        self._current_price = self._config['init_current_price']

    @staticmethod
    def _as_scalar_action(action) -> float:
        """Return ``action`` as a Python float.

        Accepts a plain number or a size-1 array (the shape produced by
        ``action_space.sample()``).

        Raises:
            ValueError: if ``action`` does not contain exactly one value.
        """
        # ndarray.item() raises ValueError for arrays whose size is not 1.
        return float(np.asarray(action, dtype=np.float64).item())

    def _is_action_valid(self, action) -> bool:
        """Return True if applying ``action`` keeps the charge within [0, capacity]."""
        target_state_of_charge = self._current_state_of_charge + self._as_scalar_action(action)
        # Both bounds are inclusive: fully empty and fully charged are valid states.
        return 0 <= target_state_of_charge <= self._battery_capacity

    def step(self, action):
        """Apply ``action`` and advance the environment by one step.

        Args:
            action: signed amount added to the state of charge (number or size-1 array).

        Returns:
            ``(observation, reward, done, info)`` where ``done`` is always False
            and ``info`` is an empty dict.

        Raises:
            ValueError: if ``action`` is not a single value or would move the state
                of charge outside ``[0, battery_capacity]``.
        """
        # Normalise once so the state of charge stays a scalar even when the
        # action comes from the (shape (1,)) action_space.
        action = self._as_scalar_action(action)
        if not self._is_action_valid(action):
            raise ValueError('Invalid action')

        # The reward uses the demand/price in effect *before* they are advanced.
        reward = self._calculate_reward(action)

        self._current_state_of_charge += action
        self._current_demand_of_electricity = self._yield_demand_of_electricity()
        self._current_price = self._yield_price()

        # NOTE(review): Gymnasium's step API returns a 5-tuple
        # (obs, reward, terminated, truncated, info); the 4-tuple is kept here
        # for existing callers.
        done = False
        return self._get_obs(), reward, done, {}

    def _calculate_reward(self, action) -> float:
        """Return the current price times the smaller of current demand and ``action``."""
        # NOTE(review): a positive action charges the battery, yet it is what is
        # multiplied by the price here — confirm the intended sign convention.
        return self._current_price * min(self._current_demand_of_electricity, action)

    def _yield_demand_of_electricity(self) -> float:
        """Return the demand for the next step; must be implemented by subclasses."""
        raise NotImplementedError()

    def _yield_price(self) -> float:
        """Return the price for the next step; must be implemented by subclasses."""
        raise NotImplementedError()

    def reset(self, seed=None, options=None):
        """Restore the initial state from the config and return the first observation.

        NOTE(review): Gymnasium's reset API returns ``(observation, info)``; only
        the observation is returned here, for existing callers.
        """
        # Seeds self.np_random.
        super().reset(seed=seed)
        self._restore_initial_state()
        return self._get_obs()

    def _get_obs(self):
        """Return ``[state_of_charge, demand, price]`` as float64, matching observation_space."""
        return np.array(
            [self._current_state_of_charge, self._current_demand_of_electricity, self._current_price],
            dtype=np.float64,
        )



In [None]:
#| hide
class _ConstantMarketEnv(ElectricityMarketEnv):
    """Demo subclass with fixed demand and price so that step() can run."""

    def _yield_demand_of_electricity(self) -> float:
        return 20.0

    def _yield_price(self) -> float:
        return 10.0


env = _ConstantMarketEnv({
    "battery_capacity": 100,
    "init_state_of_charge": 100,
    "init_current_demand_of_electricity": 0,
    "init_current_price": 10
})

env.reset()

# The battery starts full (state of charge == capacity), so discharge 10 units;
# charging a further +100 would exceed the capacity of 100.
env.step(-10)


In [None]:
#| hide
import nbdev
nbdev.nbdev_export()