<a href="https://colab.research.google.com/github/rsarpongstreetor/Environment-customenv/blob/main/DDataenv_Class.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title BenchMARL action preference
import ipdb
# Instead of importing private functions, use public functions if available
from typing import Optional, Any, Dict, List, Union
from typing_extensions import TypedDict
import numpy as np
import pandas as pd
import os
import random
import torch

# TorchRL imports

from torchrl.data import (
    Unbounded,
    Composite,
    Bounded,
    Binary,
    Categorical,
    UnboundedContinuousTensorSpec,
    Composite ,
    DiscreteTensorSpec,
    MultiCategorical  # Make sure to import this class
)


from torchrl.envs import (
    Compose,
    DoubleToFloat,
    ObservationNorm,
    StepCounter,
    TransformedEnv,
    EnvBase
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import (
    check_env_specs,
    ExplorationType,
    set_exploration_type,
    step_mdp
)
from torchrl.modules import (
    ProbabilisticActor,
    TanhNormal,
    ValueOperator,
    MultiAgentConvNet
)
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

# Other imports
import pygame
from gym.spaces import Discrete, Box, Dict, Tuple, MultiBinary, MultiDiscrete
import math
import matplotlib.pyplot as plt
from collections import namedtuple, deque, defaultdict
from itertools import count
import plotly.express as px
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
import pyvirtualdisplay

import multiprocessing
from typing import Generic, TypeVar
from typing import Dict # Import Dict from typing instead of gym.spaces
import matplotlib.pyplot as plt
import networkx as nx  # For graph visualization


# Type parameter used by ``Generic[T]`` below. It must be bound before the
# class statement executes, otherwise defining the class raises NameError.
T = TypeVar("T")


class DDataenv(Generic[T]):
    """Serve rows of a tabular dataset as observation dictionaries.

    The dataset is read lazily from ``data_path`` with ``torch.load`` and kept
    as a ``pandas.DataFrame`` whose columns are labelled with ``data_columns``.
    Each call to :meth:`get_observation` returns one row, split by position
    into three 13-wide groups (0:13, 13:26, 26:39) plus the last element of
    the row, together with per-column minima and maxima of the whole frame.

    Args:
        data_path: Path of the file handed to ``torch.load``.
        data_columns: Column labels applied to the loaded data. A column
            labelled ``'Date'`` is required by :meth:`get_observation`.
        data_type: dtype each returned row is cast to.
        allow_repeat: When true, every observation is drawn at a random row
            index; otherwise rows are served in order and wrap around to the
            first row after the last one.
    """

    def __init__(self, data_path: str, data_columns: List[str], data_type: Any = np.float32, allow_repeat: bool = True):
        self.data_path = data_path
        self.data_columns = data_columns
        self.data_type = data_type
        self.data = None  # populated lazily by load_data()
        self.current_index = 0  # row served by the next sequential read
        self.allow_repeat = allow_repeat
        # (frame, min-row, max-row) computed for the frame currently in use.
        self._bounds_cache = None

    def load_data(self) -> pd.DataFrame:
        """Load the dataset from ``data_path`` into ``self.data``.

        Returns:
            The loaded data as a DataFrame labelled with ``data_columns``.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data file not found at {self.data_path}")

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only point this at trusted files.
        with open(self.data_path, 'rb') as f:
            loaded = torch.load(f, weights_only=False)

        values = np.array(loaded)
        if values.ndim >= 3:
            # Keep only the 2nd and 3rd axes; reshape raises if the remaining
            # axes hold more than one element in total.
            values = values.reshape(values.shape[1], values.shape[2])

        # Assign only once the conversion succeeded, so a failed load leaves
        # ``self.data`` untouched and a later call can retry.
        self.data = pd.DataFrame(values, columns=self.data_columns)
        self._bounds_cache = None
        return self.data

    def _column_bounds(self):
        """Return the ``(min, max)`` rows of ``self.data.describe()``.

        The summary is computed once per DataFrame object and reused, instead
        of being recomputed for every observation. Replacing ``self.data``
        with a different frame refreshes it; in-place edits of the same frame
        are not detected.
        """
        cached = getattr(self, "_bounds_cache", None)
        if cached is None or cached[0] is not self.data:
            summary = self.data.describe()
            cached = (self.data, summary.loc['min'], summary.loc['max'])
            self._bounds_cache = cached
        return cached[1], cached[2]

    def get_observation(self) -> Dict[str, Union[np.ndarray, Dict[str, float]]]:
        """Return one row of the dataset plus per-column bounds.

        Loads the data on first use. With ``allow_repeat`` the row index is
        drawn uniformly at random; otherwise rows are served sequentially and
        the index wraps to 0 after the last row. ``current_index`` is left
        pointing one past the row that was served.

        Returns:
            A dict holding the row split by position ('obsState&Fuel' = 0:13,
            'rewardState&reward' = 13:26, 'actionState&action' = 26:39,
            'Date' = last element) and, for each group, ``*_min`` / ``*_max``
            entries taken from the frame-wide column minima and maxima.
        """
        if self.data is None:
            self.load_data()

        if self.allow_repeat:
            # Sample with replacement: any row may be served on any call.
            self.current_index = random.randint(0, len(self.data) - 1)
        elif self.current_index >= len(self.data):
            # Sequential mode: wrap around once the data is exhausted.
            self.current_index = 0

        observation = self.data.iloc[self.current_index, :].to_numpy().astype(self.data_type)

        # Advance so the next sequential read serves the following row.
        self.current_index += 1

        lows, highs = self._column_bounds()

        def bounds(row, start, stop):
            # Positional slice; copied so callers cannot mutate the cache.
            return row.iloc[start:stop].to_numpy(copy=True)

        return {
            'obsState&Fuel': observation[0:13],
            'Date': observation[-1],
            'rewardState&reward': observation[13:26],
            'actionState&action': observation[26:39],
            'obsState&Fuel_max': bounds(highs, 0, 13),
            'obsState&Fuel_min': bounds(lows, 0, 13),
            # Read the 'max'/'min' statistics themselves. Reducing over the
            # whole describe() column would also consider count, mean and std.
            'Date_max': highs['Date'],
            'Date_min': lows['Date'],
            'rewardState&reward_max': bounds(highs, 13, 26),
            'rewardState&reward_min': bounds(lows, 13, 26),
            'actionState&action_max': bounds(highs, 26, 39),
            'actionState&action_min': bounds(lows, 26, 39),
        }