In [1]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from utils.metrics import UnbiasedExponentialMovingAverage as uema

In [2]:
# Days whose one-day-ahead transitions are kept for the GP dataset, keyed by
# "E<experiment>/zone<zone>". A day is used as an index into that zone's list of
# daytime segments (and day + 1 as the "next day"). An empty list excludes the
# zone entirely.
# NOTE(review): day 8 is missing from every E11 entry — presumably an
# experiment-wide problem that day; TODO document why.
good_days = {
    # ---- experiment E11 ----
    "E11/zone1": [1, 2, 3, 4, 5, 6, 7, 9, 10, 11],
    "E11/zone2": [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11],
    "E11/zone3": [0, 1, 2, 3, 4, 5, 6, 7, 9, 10],
    "E11/zone4": [1, 2, 3, 4, 5, 6, 7, 9, 10],
    "E11/zone5": [],
    "E11/zone6": [1, 2, 3, 4, 5, 6, 7],
    # maybe add day 9 to zone6, but the daily curve was very noisy
    "E11/zone7": [],
    "E11/zone8": [1, 2, 3, 4, 5, 6, 7],
    "E11/zone9": [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12],
    "E11/zone10": [1, 2, 3, 4, 5, 6, 7, 9, 10, 11],
    "E11/zone11": [1, 2, 3, 4, 5, 6, 7, 9, 10, 11],
    "E11/zone12": [1, 2, 3, 4, 5, 6, 7, 9, 10, 11],
    # ---- experiment E12: days 0-7 for every usable zone ----
    "E12/zone1": list(range(8)),
    "E12/zone2": list(range(8)),
    "E12/zone3": list(range(8)),
    "E12/zone4": list(range(8)),
    "E12/zone5": list(range(8)),
    "E12/zone6": list(range(8)),
    "E12/zone7": [],
    "E12/zone8": list(range(8)),
    "E12/zone9": list(range(8)),
    "E12/zone10": list(range(8)),
    "E12/zone11": list(range(8)),
    "E12/zone12": list(range(8)),
}

# Zones with at least one good day, per experiment. Derived from good_days
# (rather than hand-maintained) so the two tables cannot drift apart.
good_zones = {
    exp: [
        int(key.split("/zone")[1])
        for key, days in good_days.items()
        if key.startswith(f"{exp}/") and days
    ]
    for exp in ("E11", "E12")
}

In [3]:
def get_daytime_segments(data, window=60):
    """Split a 1-D signal into "daytime" segments separated by long zero runs.

    At each index ``i`` the look-ahead window ``data[i : i + window]`` is summed.
    If it sums to zero (for non-negative data: the next ``window`` samples are
    all zero, presumably night when the lights are off), the current segment is
    closed and the value at ``i`` is dropped. Otherwise the value is kept, with
    one exception: zeros seen before a segment has started are skipped. As a
    result, every segment begins with a non-zero value, while shorter zero gaps
    inside a segment are retained. Near the end of ``data`` the window is
    truncated, so a trailing zero run shorter than ``window`` still closes the
    last segment.

    Parameters
    ----------
    data : 1-D array-like of numbers
        The signal to split (in this notebook: per-timestep action or plant area).
    window : int, default 60
        Length of the look-ahead window. The default preserves the value that
        was previously hard-coded.

    Returns
    -------
    list of list
        The segments in order; each is a list of the original elements.

    Raises
    ------
    ValueError
        If ``window`` is smaller than 1 (an empty window sums to zero and would
        close every segment immediately, returning ``[]``).
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    segments = []
    current_segment = []
    for i, value in enumerate(data):
        if np.sum(data[i : i + window]) == 0:
            # Long zero run ahead: this is the gap between two segments.
            if current_segment:
                segments.append(current_segment)
                current_segment = []
        elif value == 0:
            # Short zero gap: keep it only if we are already inside a segment.
            if current_segment:
                current_segment.append(value)
        else:
            current_segment.append(value)
    if current_segment:
        segments.append(current_segment)
    return segments

In [4]:
def get_action_trace_history(action_segments):
    """Encode each day's action as a one-hot vector plus exponential traces.

    For every segment the action is read from its 11th sample (index 10) and
    bucketed into a one-hot over (red, white, blue): value < 10 -> red,
    value < 20 -> white, otherwise blue. Each one-hot channel is then fed into
    five ``uema`` trackers with alpha = 0.9, 0.7, 0.5, 0.3, 0.1. The traces are
    labelled by beta = 1 - alpha, i.e. trace1, trace3, trace5, trace7, trace9.

    Parameters
    ----------
    action_segments : list of sequences
        Daytime segments of the action signal (in this notebook: the
        per-timestep mean of ``action.0`` across plants).

    Returns
    -------
    list of list
        One entry per segment. For segments with more than 10 samples the entry
        is ``one_hot(3) + trace1(3) + trace3(3) + trace5(3) + trace7(3) +
        trace9(3)``. Shorter segments (e.g. an incomplete final day) yield
        ``[]`` and do not update the traces, so entries stay index-aligned with
        ``action_segments``.
    """
    # Ordered beta = 0.1, 0.3, ..., 0.9; output columns follow this order.
    # Alphas are written out literally because 1 - beta does not round-trip
    # exactly in floating point (e.g. 1 - 0.7 != 0.3).
    alphas = (0.9, 0.7, 0.5, 0.3, 0.1)
    # traces[k][c] tracks colour channel c (red, white, blue) with alphas[k].
    traces = [[uema(alpha=a) for _ in range(3)] for a in alphas]

    history = []
    for segment in action_segments:
        if len(segment) <= 10:  # avoid short segments (on incomplete final days)
            history.append([])
            continue

        level = segment[10]
        if level < 10:
            one_hot = [1, 0, 0]  # red
        elif level < 20:
            one_hot = [0, 1, 0]  # white
        else:
            one_hot = [0, 0, 1]  # blue

        for channel_traces in traces:
            for tracker, bit in zip(channel_traces, one_hot):
                tracker.update(bit)

        row = list(one_hot)
        for channel_traces in traces:
            row.extend(tracker.compute().item() for tracker in channel_traces)
        history.append(row)
    return history

In [5]:
# Build one-day-ahead plant-size transitions for the GP dataset.
# Each row is [day, area_today] + action features + [area_tomorrow], where the
# action features come from get_action_trace_history (one-hot + traces) and the
# areas are the SAMPLE_IDX-th sample of that day's / the next day's segment.
#   EVERY_DATA  - one row per plant per good day
#   MEAN_DATA   - one row per zone per good day, from the across-plant mean area
#   MEDIAN_DATA - same, from the across-plant median area
DATA_ROOT = "/home/lolanff/plant-rl/data/online"  # TODO: make configurable, not an absolute local path
N_PLANTS = 18  # raw.csv rows are reshaped to (timestep, plant) with this many plants per timestep
SAMPLE_IDX = 4  # which sample within a daytime segment represents that day's area

EVERY_DATA = []
MEAN_DATA = []
MEDIAN_DATA = []
for exp_id in [11, 12]:
    for zone_id in good_zones[f"E{exp_id}"]:
        # Zone directories are zero-padded to two digits (zone01 ... zone12).
        data_path = (
            f"{DATA_ROOT}/E{exp_id}/P1/DiscreteRandom{zone_id}"
            f"/alliance-zone{zone_id:02d}/raw.csv"
        )
        # NOTE(review): read_csv emits a warning for every file (see cell output);
        # likely mixed column dtypes — consider passing dtype= or low_memory=False.
        df = pd.read_csv(data_path)

        # Zone-level action signal: mean over plants at each timestep.
        actions = np.reshape(df["action.0"].to_numpy(), (-1, N_PLANTS))
        zone_action = actions.mean(axis=1)
        action_trace_history = get_action_trace_history(
            get_daytime_segments(zone_action)
        )

        areas = np.reshape(df["clean_area"].to_numpy(), (-1, N_PLANTS))
        mean_segments = get_daytime_segments(areas.mean(axis=1))
        median_segments = get_daytime_segments(np.median(areas, axis=1))
        # Each plant's own segmentation does not depend on the day, so compute it
        # once per zone instead of once per (day, plant).
        # NOTE(review): assumes every plant's segments line up day-for-day with the
        # zone-level ones (a long run of zero area for one plant would shift them).
        plant_segments = [
            get_daytime_segments(areas[:, plant_id])
            for plant_id in range(areas.shape[1])
        ]

        for day in good_days[f"E{exp_id}/zone{zone_id}"]:
            day_action = action_trace_history[day]
            # Segments too short to classify yield [] features; such rows would be
            # ragged and make np.vstack fail later, so fail here with context.
            assert day_action, (
                f"no action features for E{exp_id}/zone{zone_id} day {day}"
            )
            MEAN_DATA.append(
                [day, mean_segments[day][SAMPLE_IDX]]
                + day_action
                + [mean_segments[day + 1][SAMPLE_IDX]]
            )
            MEDIAN_DATA.append(
                [day, median_segments[day][SAMPLE_IDX]]
                + day_action
                + [median_segments[day + 1][SAMPLE_IDX]]
            )
            for segments in plant_segments:
                EVERY_DATA.append(
                    [day, segments[day][SAMPLE_IDX]]
                    + day_action
                    + [segments[day + 1][SAMPLE_IDX]]
                )

  df = pd.read_csv(data_path)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
  df = pd.read_csv(data_path)
  df = pd.read_csv(data_path)
  df = pd.read_csv(data_path)
  df = pd.read_csv(data_path)
  df = pd.read_csv(data_path)
  df = pd.read_csv(data_path)
  df = pd.read_csv(data_path)
  df = pd.read_csv(data_path)
  df = pd.read_csv(data_path)


In [6]:
# remove data points where plant area goes down after 1 full day (e.g. dead plants)
def trim_dead_plants(x):
    """Drop transitions whose plant area shrank over one day.

    Parameters
    ----------
    x : sequence of rows
        Each row is ``[day, area_today, *action_features, area_tomorrow]``. Only
        ``row[1]`` (today's area) and ``row[-1]`` (tomorrow's area) are read.

    Returns
    -------
    np.ndarray
        The kept rows stacked into a 2-D array. A row is kept when
        ``row[-1] >= row[1]``, so rows with a NaN in either area are dropped.
        If no rows remain (or ``x`` is empty), an empty ``(0, n_columns)``
        array is returned. Previously ``np.vstack([])`` raised in that case.

    Prints how many rows were removed.
    """
    trimmed_data = [row for row in x if row[-1] >= row[1]]
    counter = len(x) - len(trimmed_data)
    print(f"Removed {counter} out of {len(x)} data points")
    if not trimmed_data:
        n_columns = len(x[0]) if len(x) > 0 else 0
        return np.empty((0, n_columns))
    return np.vstack(trimmed_data)

In [7]:
# Drop shrinking-plant transitions and persist each dataset for GP fitting.
# np.save does not create parent directories, so make sure ./GP_data exists;
# otherwise Restart & Run All on a fresh checkout fails with FileNotFoundError.
os.makedirs("./GP_data", exist_ok=True)
for out_path, rows in (
    ("./GP_data/every_size_1day.npy", EVERY_DATA),
    ("./GP_data/mean_size_1day.npy", MEAN_DATA),
    ("./GP_data/median_size_1day.npy", MEDIAN_DATA),
):
    np.save(out_path, trim_dead_plants(rows))

Removed 39 out of 3294 data points
Removed 1 out of 183 data points
Removed 0 out of 183 data points
