In [None]:
from NaiveDataModule import NaiveDataModule, NaiveTEACHDataset
import os
from itertools import chain
from collections import Counter
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# Root directory holding the TEACh dataset and the word2vec vectors.
# Set the TEACH_PARENT_DIR environment variable to point elsewhere instead of editing this cell.
# The default ends with "/" on purpose: on Windows, os.path.join("E:", "x") yields the
# drive-relative path "E:x" (resolved against drive E's current directory), not "E:/x".
parent_dir_path = os.environ.get("TEACH_PARENT_DIR", "E:/")
data_parent_dir_path = os.path.join(parent_dir_path, 'teach-dataset')
w2v_path = os.path.join(parent_dir_path, 'GoogleNews-vectors-negative300.bin.gz')

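## Plot action distribution on EDH

Count the target action of every sample in the train, val and val_unseen splits, then plot the counts by category (navigation / interaction) and by individual action.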

In [None]:
# Keyword options for the data module. The counting loop further down discards the
# model inputs (it unpacks `_, batched_y`); the include_x_* flags presumably skip
# loading those inputs — confirm against NaiveDataModule.
datamodule_kwargs = dict(
    x_text_pad_length=1024,
    use_small_dataset=False,
    num_workers=4,
    include_x_cur_image=False,
    include_x_prev_actions=False,
)
# NOTE(review): the positional 16 is presumably the batch size — TODO confirm.
naive_datamodule = NaiveDataModule(data_parent_dir_path, w2v_path, 16, **datamodule_kwargs)

In [None]:
# Prepare each split whose dataloader is iterated below, in the same order as before.
for stage in ("train", "val", "val_unseen"):
    naive_datamodule.setup(stage)

In [None]:
# Action vocabulary. action_one_hot_to_name decodes a label by indexing this list with
# the label's argmax, so the order here must match the label encoding.
all_agent_actions = [
    "Stop", "Forward", "Backward", "Turn Left", "Turn Right", "Look Up", "Look Down", "Pan Left", "Pan Right",
    "Pickup", "Place", "Open", "Close", "ToggleOn", "ToggleOff", "Slice", "Pour",
]
# Coarse categories used for counting and plotting. "Stop" is in neither set.
navigation_actions = {"Forward", "Backward", "Turn Left", "Turn Right", "Look Up", "Look Down", "Pan Left", "Pan Right"}
interaction_actions = {"Pickup", "Place", "Open", "Close", "ToggleOn", "ToggleOff", "Slice", "Pour"}

# The categories are maintained by hand: fail fast on a typo, an overlap, or a duplicate name
# instead of silently mis-categorizing actions in the counts below.
assert len(set(all_agent_actions)) == len(all_agent_actions), "duplicate action names"
assert navigation_actions <= set(all_agent_actions), navigation_actions - set(all_agent_actions)
assert interaction_actions <= set(all_agent_actions), interaction_actions - set(all_agent_actions)
assert navigation_actions.isdisjoint(interaction_actions), navigation_actions & interaction_actions

In [None]:
def action_one_hot_to_name(onehot_tensor):
    """Decode a one-hot action label into its action name.

    The position of the largest element (``argmax`` over all elements) is used
    as an index into the module-level ``all_agent_actions`` list.
    """
    best_index = onehot_tensor.argmax()
    return all_agent_actions[best_index]

In [None]:
# Count the target action of every sample across all splits.
# The counters and dataloaders are created in this same cell so that re-running it starts
# from zero (no double counting) and never iterates an already-exhausted iterator.
action_counts = Counter()
navigation_actions_count = Counter()
interaction_actions_count = Counter()

# Keep the dataloaders themselves rather than iter(...) of them: chain() only calls iter()
# on each loader when it reaches it, so the splits are consumed one after another and the
# list stays re-iterable.
# NOTE(review): presumably these are torch DataLoaders; with num_workers=4, calling iter()
# on all three up front would start worker processes for every split at once.
naive_dls = [
    naive_datamodule.train_dataloader(),
    naive_datamodule.val_dataloader(),
    naive_datamodule.val_unseen_dataloader(),
]

for _, batched_y in tqdm(chain(*naive_dls), desc="Counting actions"):
    # assumes batched_y holds one one-hot label per sample along dim 0 — TODO confirm
    for i in range(batched_y.size(0)):
        action_name = action_one_hot_to_name(batched_y[i])
        action_counts[action_name] += 1
        if action_name in navigation_actions:
            navigation_actions_count[action_name] += 1
        if action_name in interaction_actions:
            interaction_actions_count[action_name] += 1

In [None]:
# Per-action label counts over the train, val and val_unseen splits.
action_counts

In [None]:
# Counts restricted to actions in `navigation_actions`.
navigation_actions_count

In [None]:
# Counts restricted to actions in `interaction_actions`.
interaction_actions_count

In [None]:
# Total number of interaction actions.
sum(interaction_actions_count.values())

In [None]:
# Total number of navigation actions.
sum(navigation_actions_count.values())

In [None]:
# Left panel: navigation vs. interaction totals. Right panel: one bar per action.
# Separate panels because the category totals dwarf most individual actions on a shared axis.
category_labels = ['Navigation', 'Interaction']
category_counts = [sum(navigation_actions_count.values()), sum(interaction_actions_count.values())]

# Order bars by the fixed vocabulary rather than by Counter insertion order, which
# depends on the order in which batches happened to arrive.
plotted_actions = [action for action in all_agent_actions if action in action_counts]
plotted_counts = [action_counts[action] for action in plotted_actions]
# Blue = navigation, red = interaction, grey = neither category (e.g. "Stop").
plotted_colors = [
    'blue' if action in navigation_actions else 'red' if action in interaction_actions else 'gray'
    for action in plotted_actions
]

fig, (category_ax, action_ax) = plt.subplots(
    1, 2, figsize=(14, 5), gridspec_kw={'width_ratios': [1, 4]}
)
category_ax.bar(category_labels, category_counts, color=['blue', 'red'])
category_ax.set(title='By category', ylabel='Count')
action_ax.bar(plotted_actions, plotted_counts, color=plotted_colors)
action_ax.set(title='By action', ylabel='Count')
action_ax.tick_params(axis='x', labelrotation=90)  # 17 labels overlap when horizontal
fig.suptitle('Action distribution (train + val + val_unseen)')
fig.tight_layout()
plt.show()