# 1. Action Sequence Data Generation

In [56]:
import blosc
from collections import defaultdict
from copy import copy
import io
import os
import pickle as pkl
from PIL import Image
import random
from tqdm import tqdm

In [67]:
def get_negative_action_seq(actions, tolerance=20):    
    if ("right" not in actions) and ("left" not in actions):
        # If all the actions are "forward", then we do not swap
        negative_case_id = random.choice([1,2])
    else:
        negative_case_id = random.choice([0,1,2])

    if negative_case_id == 0:
        # Swap action tokens
        id_1 = random.randint(0, len(actions)-1)
        count = 0
        while count < tolerance:
            count += 1
            id_2 = random.randint(0, len(actions)-1)
            if id_1 == id_2:
                id_2 = random.randint(0, len(actions)-1)
                continue
            if actions[id_1] == actions[id_2]:
                continue
            break
        actions[id_1], actions[id_2] = actions[id_2], actions[id_1]

    elif negative_case_id == 1:
        # Delete action_token
        id_1 = random.randint(0, len(actions)-1)
        actions.pop(id_1)

    elif negative_case_id == 2:
        # Add action token
        new_token = random.choice(["left", "right"])
        new_token_id = random.choice(list(range(0, len(actions)-1)))  # Don't add as the last action
        actions.insert(new_token_id, new_token)
    return(actions)


In [102]:
def get_pretraining_data(
    env,
    data,
    output_dir_path,
    is_train,
    split_id=0,
):
    def image_to_bytes(image):
        bytes_io = io.BytesIO()
        image.save(bytes_io, "PNG")
        return bytes_io.getvalue()
    
    output_dir = os.path.join(output_dir_path, env)
    os.makedirs(output_dir, exist_ok=True)
    
    data_split = []
    for ex_id, ex in tqdm(enumerate(data), total=len(data)):
        if ex[-1] == "No":
            # Ignore unsolvable cases now
            continue
            
        if len(ex[6]) <= 3 or len(ex[6]) >= 50:
            # Ignore solvable cases with too short/long action sequence
            continue
        
        img = blosc.unpack_array(ex[2])
        img = Image.fromarray(img[0])  # Always select the first frame
        img_encoded = image_to_bytes(img)
        action = ",".join(ex[6])
        
        # Positive case
        data_split.append({
            "img": img_encoded,
            "mission": ex[0],
            "action": action,
            "text_input": f"Goal: {ex[0]}. Action: {action}.",
            "label": "Yes"
        })
        
        # Negative case
        action_tokens = copy(ex[6])
        negative_action = get_negative_action_seq(action_tokens)
        #text_input = ["Goal", ":"] + ex[0] + [".", "Action", ":"] + negative_action + "."
        negative_action = ",".join(negative_action)
        data_split.append({
            "img": img_encoded,
            "mission": ex[0],
            "action": negative_action,
            "text_input": f"Goal: {ex[0]}. Action: {negative_action}.",
            "label": "No"
        })
        
        """
        if (ex_id+1) % split_size == 0:
            if is_train:
                output_name = f"split_train_{split_start_id + split_id}"
            else:
                output_name = f"split_valid_{split_start_id + split_id}"
            
            output_path = os.path.join(output_dir, f"{output_name}.pkl")
            with open(output_path, "wb") as f:
                pkl.dump(data_split, f)
            data_split = []
            split_id += 1
        """
    
    print(f"Collected: {len(data_split)}")
    if data_split:
        if is_train:
            output_name = f"split_train_{split_id}"
        else:
            output_name = f"split_valid_{split_id}"

        output_path = os.path.join(output_dir, f"{output_name}.pkl")
        print(output_path)
        with open(output_path, "wb") as f:
            pkl.dump(data_split, f)
    

In [103]:
level_name = "BossLevel"

for split_id in range(10):
    print(split_id)
    with open(f"../../data/task_0/{level_name}/Task0_{level_name}_train_{split_id}.pkl", "rb") as f:
        demos_train = pkl.load(f)
    
    get_pretraining_data(
        level_name, 
        demos_train, 
        output_dir_path="../../data/clip_experiment/action_seq_binary",
        is_train=True,
        split_id=split_id,
    )

0


100%|██████████████████████████████████████| 5000/5000 [00:29<00:00, 170.42it/s]


Collected: 2014
../../data/clip_experiment/action_seq_binary/BossLevel/split_train_0.pkl
1


100%|██████████████████████████████████████| 5000/5000 [00:27<00:00, 183.81it/s]


Collected: 1860
../../data/clip_experiment/action_seq_binary/BossLevel/split_train_1.pkl
2


100%|██████████████████████████████████████| 5000/5000 [00:46<00:00, 108.00it/s]


Collected: 1892
../../data/clip_experiment/action_seq_binary/BossLevel/split_train_2.pkl
3


100%|██████████████████████████████████████| 5000/5000 [00:27<00:00, 180.38it/s]


Collected: 1850
../../data/clip_experiment/action_seq_binary/BossLevel/split_train_3.pkl
4


100%|██████████████████████████████████████| 5000/5000 [00:27<00:00, 180.40it/s]


Collected: 1880
../../data/clip_experiment/action_seq_binary/BossLevel/split_train_4.pkl
5


100%|██████████████████████████████████████| 5000/5000 [00:28<00:00, 177.34it/s]


Collected: 1898
../../data/clip_experiment/action_seq_binary/BossLevel/split_train_5.pkl
6


100%|██████████████████████████████████████| 5000/5000 [00:27<00:00, 183.95it/s]


Collected: 1874
../../data/clip_experiment/action_seq_binary/BossLevel/split_train_6.pkl
7


100%|██████████████████████████████████████| 5000/5000 [00:27<00:00, 183.38it/s]


Collected: 1846
../../data/clip_experiment/action_seq_binary/BossLevel/split_train_7.pkl
8


100%|██████████████████████████████████████| 5000/5000 [00:27<00:00, 180.72it/s]


Collected: 1872
../../data/clip_experiment/action_seq_binary/BossLevel/split_train_8.pkl
9


100%|██████████████████████████████████████| 5000/5000 [00:28<00:00, 172.94it/s]

Collected: 1944
../../data/clip_experiment/action_seq_binary/BossLevel/split_train_9.pkl





In [104]:
level_name = "BossLevel"

for split_id in range(10):
    print(split_id)
    with open(f"../../data/task_0/{level_name}/Task0_{level_name}_valid_{split_id}.pkl", "rb") as f:
        demos_train = pkl.load(f)
    
    get_pretraining_data(
        level_name, 
        demos_train, 
        output_dir_path="../../data/clip_experiment/action_seq_binary",
        is_train=False,
        split_id=split_id,
    )

0


100%|██████████████████████████████████████| 1000/1000 [00:05<00:00, 180.65it/s]


Collected: 378
../../data/clip_experiment/action_seq_binary/BossLevel/split_valid_0.pkl
1


100%|██████████████████████████████████████| 1000/1000 [00:05<00:00, 169.59it/s]


Collected: 394
../../data/clip_experiment/action_seq_binary/BossLevel/split_valid_1.pkl
2


100%|██████████████████████████████████████| 1000/1000 [00:05<00:00, 186.13it/s]


Collected: 366
../../data/clip_experiment/action_seq_binary/BossLevel/split_valid_2.pkl
3


100%|██████████████████████████████████████| 1000/1000 [00:05<00:00, 193.25it/s]


Collected: 358
../../data/clip_experiment/action_seq_binary/BossLevel/split_valid_3.pkl
4


100%|██████████████████████████████████████| 1000/1000 [00:05<00:00, 176.67it/s]


Collected: 380
../../data/clip_experiment/action_seq_binary/BossLevel/split_valid_4.pkl
5


100%|██████████████████████████████████████| 1000/1000 [00:05<00:00, 173.25it/s]


Collected: 384
../../data/clip_experiment/action_seq_binary/BossLevel/split_valid_5.pkl
6


100%|██████████████████████████████████████| 1000/1000 [00:05<00:00, 176.64it/s]


Collected: 384
../../data/clip_experiment/action_seq_binary/BossLevel/split_valid_6.pkl
7


100%|██████████████████████████████████████| 1000/1000 [00:05<00:00, 174.88it/s]


Collected: 390
../../data/clip_experiment/action_seq_binary/BossLevel/split_valid_7.pkl
8


100%|██████████████████████████████████████| 1000/1000 [00:05<00:00, 188.73it/s]


Collected: 362
../../data/clip_experiment/action_seq_binary/BossLevel/split_valid_8.pkl
9


100%|██████████████████████████████████████| 1000/1000 [00:05<00:00, 176.20it/s]

Collected: 384
../../data/clip_experiment/action_seq_binary/BossLevel/split_valid_9.pkl





# 1.1. Number of examples

In [105]:

count = 0
for split_id in range(10):
    with open(f"../../data/clip_experiment/action_seq_binary/BossLevel/split_train_{split_id}.pkl", "rb") as f:
        data = pkl.load(f)
        count += len(data)
print(f"Train data: {count}")

count = 0
for split_id in range(10):
    with open(f"../../data/clip_experiment/action_seq_binary/BossLevel/split_valid_{split_id}.pkl", "rb") as f:
        data = pkl.load(f)
        count += len(data)
print(f"Valid data: {count}")

Train data: 18930
Valid data: 3780


# 1.2. Length of Text Features

In [113]:
ex["action"]

'forward,left,forward,forward,forward'

In [116]:
train_text_input_lens = defaultdict(int)
for split_id in range(10):
    with open(f"../../data/clip_experiment/action_seq_binary/BossLevel/split_train_{split_id}.pkl", "rb") as f:
        data = pkl.load(f)
        for ex_id, ex in enumerate(data):
            mission = ex["mission"]
            action = ex["action"]
            train_text_input_lens[str(len(text_input))] += 1
print(f"Train data: {len(mission.split(' ')) + len(action.split(','))}")

valid_text_input_lens = defaultdict(int)
for split_id in range(10):
    with open(f"../../data/clip_experiment/action_seq_binary/BossLevel/split_valid_{split_id}.pkl", "rb") as f:
        data = pkl.load(f)
        for ex_id, ex in enumerate(data):
            text_input = ex["text_input"]
            valid_text_input_lens[str(len(text_input))] += 1
print(f"Valid data: {len(mission.split(' ')) + len(action.split(','))}")

Train data: 51
Valid data: 51


In [119]:
valid_text_input_lens

defaultdict(int,
            {'284': 14,
             '289': 7,
             '294': 4,
             '286': 8,
             '84': 20,
             '77': 25,
             '255': 10,
             '250': 6,
             '155': 14,
             '149': 14,
             '136': 17,
             '142': 8,
             '97': 25,
             '89': 16,
             '87': 13,
             '79': 28,
             '181': 14,
             '176': 17,
             '283': 5,
             '375': 5,
             '381': 9,
             '429': 4,
             '434': 1,
             '360': 6,
             '352': 10,
             '345': 10,
             '350': 9,
             '113': 14,
             '282': 7,
             '147': 7,
             '141': 14,
             '406': 7,
             '411': 2,
             '390': 5,
             '382': 8,
             '191': 17,
             '196': 18,
             '90': 20,
             '319': 9,
             '377': 8,
             '178': 15,
             '170': 21,
  