In [None]:
import h5py
import numpy as np
import pandas as pd

# Values substituted into the `ra{ra}` directory component of each dataset path
# (presumably Rayleigh numbers - TODO confirm).
ras = [10000]
# File stems of the HDF5 files inside each directory (`{type}.h5`); presumably
# the control policy that generated the episodes - TODO confirm.
types = ["zero", "ppo", "random"]
# Dataset split directories, visited in this order.
splits = ["val", "test", "train"]
# Names for indices 0, 1, 2 of the channel axis of each state array. Index 0 is
# the one range-checked as temperature further down; u/w are presumably velocity
# components - TODO confirm.
channels = ["T", "u", "w"]


def check(check_f):
    """Apply ``check_f`` to every dataset file.

    Files are visited split by split, then by ``ra`` value, then by type.
    Each one is opened read-only and closed again as soon as ``check_f``
    returns (or raises).

    Args:
        check_f: Callable invoked as ``check_f(file, path)`` with the open
            ``h5py.File`` and the path it was opened from.
    """
    # Lazily enumerate the paths in the same order as a split/ra/type nest.
    dataset_paths = (
        f"../data/2D-control/{split_name}/ra{ra_value}/{kind}.h5"
        for split_name in splits
        for ra_value in ras
        for kind in types
    )
    for path in dataset_paths:
        with h5py.File(path, "r") as handle:
            check_f(handle, path)

# Dataset Info

In [3]:
def dataset_info(file, path):
    """Print the episode count and base seed stored in a file's attributes.

    Args:
        file: Open file object exposing ``attrs["episodes"]`` and
            ``attrs["base_seed"]``.
        path: Path of the file, used only in the printed message.
    """
    episode_count = file.attrs["episodes"]
    seed = file.attrs["base_seed"]
    summary = f"file {path} has {episode_count} episodes with seed {seed}"
    print(summary)


# Print the stored episode count and base seed of every dataset file.
check(dataset_info)

file ../data/2D-control/val/ra10000/zero.h5 has 20 episodes with seed 400
file ../data/2D-control/val/ra10000/ppo.h5 has 20 episodes with seed 400
file ../data/2D-control/val/ra10000/random.h5 has 20 episodes with seed 400
file ../data/2D-control/test/ra10000/zero.h5 has 20 episodes with seed 400
file ../data/2D-control/test/ra10000/ppo.h5 has 20 episodes with seed 400
file ../data/2D-control/test/ra10000/random.h5 has 20 episodes with seed 400
file ../data/2D-control/train/ra10000/zero.h5 has 20 episodes with seed 400
file ../data/2D-control/train/ra10000/ppo.h5 has 20 episodes with seed 400
file ../data/2D-control/train/ra10000/random.h5 has 20 episodes with seed 400


# Validate Step Count, State Shape and Temperature Range

In [5]:
def validate(file, path, expected_shape=(3, 64, 96), temp_bounds=(1, 2)):
    """Validate the structure and temperature range of one dataset file.

    For every episode the number of stored steps must match
    ``file.attrs["steps"]`` and every step must have ``expected_shape``.
    Channel 0 of each step is treated as temperature and must lie within
    ``temp_bounds`` (inclusive). Every offending step is printed, followed by
    a one-line pass/fail summary for the file.

    Args:
        file: Open file object exposing ``attrs["episodes"]``,
            ``attrs["steps"]`` and one ``states{i}`` dataset per episode.
        path: Path of the file, used only in printed messages.
        expected_shape: Required shape of a single step.
        temp_bounds: ``(low, high)`` inclusive bounds for channel 0.

    Raises:
        AssertionError: If an episode has the wrong number of steps or a step
            has the wrong shape.
    """
    passed = True
    episodes = file.attrs["episodes"]
    steps = file.attrs["steps"]
    temp_lo, temp_hi = temp_bounds
    expected_shape = tuple(expected_shape)

    for episode in range(episodes):
        states = file[f"states{episode}"]
        assert len(states) == steps, (
            f"Mismatch in number of steps for episode {episode}: expected {steps}, got {len(states)}"
        )

        for step in range(steps):
            # Index the step once: on an h5py dataset every indexing
            # operation is a separate read from the file.
            state = states[step]
            assert state.shape == expected_shape, (
                f"Unexpected shape at episode {episode}, step {step}: {state.shape}"
            )
            # validate temperature
            temperature = state[0]
            t_min, t_max = temperature.min(), temperature.max()
            # Phrased as "not inside the range" rather than "outside the
            # range": comparisons against NaN are always False, so a NaN
            # min/max would otherwise slip through as valid.
            if not (t_min >= temp_lo and t_max <= temp_hi):
                print(
                    f"in file {path} at episode {episode} and step {step} - Temperature: min={t_min}, max={t_max}"
                )
                passed = False
    if not passed:
        print(f"File {path} failed validation.")
    else:
        print(
            f"File {path} passed validation with {episodes} episodes and {steps} steps each."
        )


# Validate every dataset file: out-of-range temperatures are printed, while a
# wrong step count or state shape raises an AssertionError.
check(validate)

File ../data/2D-control/val/ra10000/zero.h5 passed validation with 20 episodes and 400 steps each.
File ../data/2D-control/val/ra10000/ppo.h5 passed validation with 20 episodes and 400 steps each.
File ../data/2D-control/val/ra10000/random.h5 passed validation with 20 episodes and 400 steps each.
File ../data/2D-control/test/ra10000/zero.h5 passed validation with 20 episodes and 400 steps each.
File ../data/2D-control/test/ra10000/ppo.h5 passed validation with 20 episodes and 400 steps each.
File ../data/2D-control/test/ra10000/random.h5 passed validation with 20 episodes and 400 steps each.
File ../data/2D-control/train/ra10000/zero.h5 passed validation with 20 episodes and 400 steps each.
File ../data/2D-control/train/ra10000/ppo.h5 passed validation with 20 episodes and 400 steps each.
File ../data/2D-control/train/ra10000/random.h5 passed validation with 20 episodes and 400 steps each.


# Mean, Min, Max and Std per channel

In [22]:
def get_file_stats(file, path, ra, split):
    """Compute per-episode, per-channel summary statistics for one file.

    Each ``states{i}`` dataset is loaded fully into memory and, for every
    entry of the module-level ``channels`` list, the mean, min, max and
    standard deviation over all steps and both spatial axes are recorded.

    Args:
        file: Open file object exposing ``attrs["episodes"]`` and one
            ``states{i}`` dataset per episode.
        path: Path of the file, stored in the ``file`` column.
        ra: Value stored in the ``ra`` column.
        split: Value stored in the ``split`` column.

    Returns:
        A ``pandas.DataFrame`` with one row per (episode, channel) and the
        columns ``file, ra, split, episode, channel, mean, min, max, std``.
        The columns are present even when the file contains no episodes.
    """
    columns = ["file", "ra", "split", "episode", "channel", "mean", "min", "max", "std"]
    records = []
    for episode in range(file.attrs["episodes"]):
        states = np.array(file[f"states{episode}"])

        for ch, name in enumerate(channels):
            # Take the channel slice once and reuse it for all four statistics.
            channel_data = states[:, ch, :, :]
            records.append(
                {
                    "file": path,
                    "ra": ra,
                    "split": split,
                    "episode": episode,
                    "channel": name,
                    "mean": channel_data.mean(),
                    "min": channel_data.min(),
                    "max": channel_data.max(),
                    "std": channel_data.std(),
                }
            )

    # Passing the columns explicitly keeps the schema stable for an empty
    # file, so downstream concat/groupby calls still find their columns.
    return pd.DataFrame(records, columns=columns)


# Collect one statistics frame per (split, type, ra) file, then stack them.
dfs = []
combinations = [(split, typ, ra) for split in splits for typ in types for ra in ras]
for split, typ, ra in combinations:
    print(f"Calculating mean, min, max for split {split}, type {typ} and ra {ra}")
    path = f"../data/2D-control/{split}/ra{ra}/{typ}.h5"
    with h5py.File(path, "r") as file:
        file_stats = get_file_stats(file, path, ra, split)
    dfs.append(file_stats)

# A single concat at the end; the per-file row indices are discarded.
df = pd.concat(dfs, ignore_index=True)
print(df.columns)

Calculating mean, min, max for split val, type zero and ra 10000
Calculating mean, min, max for split val, type ppo and ra 10000
Calculating mean, min, max for split val, type random and ra 10000
Calculating mean, min, max for split test, type zero and ra 10000
Calculating mean, min, max for split test, type ppo and ra 10000
Calculating mean, min, max for split test, type random and ra 10000
Calculating mean, min, max for split train, type zero and ra 10000
Calculating mean, min, max for split train, type ppo and ra 10000
Calculating mean, min, max for split train, type random and ra 10000
Index(['file', 'ra', 'split', 'episode', 'channel', 'mean', 'min', 'max',
       'std'],
      dtype='object')


In [23]:
# Keep only the rows for ra == 10000, then average each statistic over all
# episodes and files, grouped by channel.
ra_mask = df["ra"] == 10000
df_filtered = df.loc[ra_mask]
stat_columns = ["mean", "min", "max", "std"]
agg = df_filtered.groupby("channel")[stat_columns].agg("mean")
print(agg)

                 mean       min       max       std
channel                                            
T        1.500000e+00  1.009462  1.990526  0.173923
u       -9.312225e-07 -0.618703  0.618899  0.280684
w        5.487214e-11 -0.792265  0.791930  0.343374
