In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import math

from seq_utils import generate_kv_mapping

In [2]:
# Smoke test of the helper imported from seq_utils. The array displayed below
# is one sample result for arguments (4, 4).
# NOTE(review): presumably a random mapping of 4 keys onto 4 values, so the
# output will differ between runs - confirm in seq_utils.
generate_kv_mapping(4, 4)

array([3, 0, 2, 1])

# Constants

In [3]:
# Destination folder for every CSV written by this notebook.
# NOTE: machine-specific absolute path; adjust before running elsewhere.
seq_folder = "/Users/ccnlab/Development/sequences/shaping/trans"

# Number of response keys and number of food items used throughout.
NUM_KEYS = 4
NUM_FOOD = 4

# Column order applied to every exported trial table: trial descriptors first,
# then the per-key, per-shop and per-trans columns, and finally the set size.
OUTPUT_COL_ORDER = (
    ["stim", "correct_key", "block", "img_folder"]
    + [f"key{idx}_trans" for idx in range(NUM_KEYS)]
    + [f"shop{idx}_food" for idx in range(NUM_FOOD)]
    + [f"trans{idx}_shop" for idx in range(NUM_KEYS)]
    + ["set_size"]
)

In [4]:
# Scratch check: concatenate six independent random permutations of 0..5, so
# every consecutive chunk of six values contains each number exactly once.
seq = [value for _chunk in range(6) for value in np.random.permutation(6)]
# Displayed result: 13 / 6 rounded up to the next whole number.
math.ceil(13 / 6)

3

# Helpers

In [5]:
from seq_utils import generate_sequence_optimized, shuffle_with_mask, swap_by_indices


def shuffle_with_consecutive_check(stim_seq, key_dir, idx_check=1):
    """Shuffle (stim, key_dir) pairs, trying to avoid equal neighbours.

    The two sequences are zipped into pairs, the pairs are shuffled with
    ``np.random.shuffle`` and a single left-to-right repair pass swaps away
    adjacent pairs that share the checked field. If a duplicate cannot be
    repaired the attempt is discarded and a fresh shuffle is tried.

    Args:
        stim_seq: Sequence of stimulus values.
        key_dir: Sequence of key/direction values, paired with ``stim_seq`` by
            position (``zip`` stops at the shorter of the two).
        idx_check: Which element of each pair must not repeat consecutively:
            0 for the stim value, 1 for the key_dir value.

    Returns:
        A list of ``(stim, key_dir)`` tuples. When every attempt fails, a
        warning is printed and the last attempted arrangement is returned,
        which may still contain consecutive duplicates.
    """
    max_attempts = 9000
    for _attempt in range(max_attempts):
        paired_data = list(zip(stim_seq, key_dir))

        np.random.shuffle(paired_data)

        consecutive_same = False
        for i in range(len(paired_data) - 1):
            current = paired_data[i][idx_check]
            if current == paired_data[i + 1][idx_check]:
                # Look further right for an element that can take position i
                # without matching its old value or its left neighbour.
                swap_idx = None
                for j in range(i + 1, len(paired_data)):
                    candidate = paired_data[j][idx_check]
                    if candidate == current:
                        continue
                    # Position 0 has no left neighbour. Indexing i - 1 there
                    # would wrap around to the last element and needlessly
                    # reject valid candidates.
                    if i > 0 and candidate == paired_data[i - 1][idx_check]:
                        continue
                    swap_idx = j
                    break

                if swap_idx is not None:
                    paired_data[i], paired_data[swap_idx] = (
                        paired_data[swap_idx],
                        paired_data[i],
                    )

            # Re-check after the optional swap; give up on this attempt if the
            # duplicate is still there.
            if paired_data[i][idx_check] == paired_data[i + 1][idx_check]:
                consecutive_same = True
                break

        if not consecutive_same:
            break
    else:
        print(
            f"Warning: Could not find arrangement without consecutive key_dir after {max_attempts} attempts"
        )
    return paired_data


def generate_seq_pair(num_stims, num_iter_per_stim, num_directions=4):
    """Build a shuffled stimulus sequence with a matching key-direction sequence.

    Every stimulus index ``0..num_stims-1`` is repeated ``num_iter_per_stim``
    times and paired by position with the directions ``0..num_directions-1``
    cycled in order. The pairs are then shuffled twice with
    ``shuffle_with_consecutive_check``: first checking the stim values, then
    the key_dir values. Each pass reshuffles the pairs from scratch, so only
    the constraint of the final pass (no consecutive key_dir) is attempted on
    the returned order.

    Args:
        num_stims: Number of distinct stimuli.
        num_iter_per_stim: Repetitions of each stimulus.
        num_directions: Number of distinct key directions.

    Returns:
        Tuple ``(stim_seq, key_dir)`` of equal-length numpy arrays. If the
        total trial count is not a multiple of ``num_directions``, the trailing
        stimulus entries are dropped because pairing stops at the shorter
        sequence.
    """
    stim_seq = np.repeat(np.arange(num_stims), num_iter_per_stim)
    # Tile by the requested number of directions. The cycle count used to be
    # hard-coded as // 4, which only matched the default num_directions.
    key_dir = np.tile(np.arange(num_directions), len(stim_seq) // num_directions)
    for idx_check in (0, 1):
        # best effort to avoid consecutive in stim seq first and then key_dir
        paired_data = shuffle_with_consecutive_check(stim_seq, key_dir, idx_check)
        stim_seq, key_dir = zip(*paired_data)

    return np.array(stim_seq), np.array(key_dir)


def generate_shaping_block(num_keys, num_food, num_iter_per_stim, trans_shop_mapping):
    """Build the trial table for one shaping block.

    The ``stim`` column holds a food index drawn from repeated random
    permutations of ``0..num_food-1``; the correct trans and correct key of
    each trial come from ``generate_seq_pair``. For every trial a per-key
    trans layout and a per-shop food layout are produced with the
    ``swap_by_indices`` / ``shuffle_with_mask`` helpers from seq_utils.
    NOTE(review): presumably the swap puts the correct item at the target
    position and the mask keeps it fixed while the rest is shuffled - confirm
    in seq_utils.

    Args:
        num_keys: Number of response keys.
        num_food: Number of food items.
        num_iter_per_stim: Repetitions of each trans.
        trans_shop_mapping: Indexable mapping from trans index to shop index.

    Returns:
        DataFrame with ``stim``, ``correct_key``, ``shop{i}_food``,
        ``key{i}_trans`` and constant ``trans{k}_shop`` columns.
    """
    num_trans = len(trans_shop_mapping)
    trans_stim_seq, correct_key_seq = generate_seq_pair(
        num_trans, num_iter_per_stim, num_keys
    )
    num_trials = len(correct_key_seq)

    # Enough back-to-back food permutations to cover every trial.
    food_seq = [
        food
        for _ in range(math.ceil(num_trials / num_food))
        for food in np.random.permutation(num_food)
    ]

    seq_data = {"stim": [], "correct_key": correct_key_seq}
    for slot in range(num_trans):
        seq_data[f"shop{slot}_food"] = []
        seq_data[f"key{slot}_trans"] = []

    trans_ids = np.arange(num_trans)
    food_ids = np.arange(num_food)
    for trial, (target_trans, target_key) in enumerate(
        zip(trans_stim_seq, correct_key_seq)
    ):
        shown_food = food_seq[trial]
        seq_data["stim"].append(shown_food)

        # Layout of trans across keys for this trial.
        key_layout = shuffle_with_mask(
            swap_by_indices(trans_ids, target_trans, target_key),
            np.array([key == target_key for key in range(num_keys)]),
        )
        for key, trans in enumerate(key_layout):
            seq_data[f"key{key}_trans"].append(trans)

        # Layout of food across shops for this trial.
        target_shop = trans_shop_mapping[target_trans]
        shop_layout = shuffle_with_mask(
            swap_by_indices(food_ids, target_shop, shown_food),
            np.array([shop == target_shop for shop in range(num_food)]),
        )
        for shop, food in enumerate(shop_layout):
            seq_data[f"shop{shop}_food"].append(food)

    # The trans -> shop mapping is constant, repeated on every row.
    for trans, shop in enumerate(trans_shop_mapping):
        seq_data[f"trans{trans}_shop"] = [shop] * len(trans_stim_seq)

    return pd.DataFrame(seq_data)


def generate_non_shaping_block(
    num_keys, num_food, num_iter_per_stim, trans_shop_mapping, stim_food_mapping
):
    """Build the trial table for one non-shaping block.

    The ``stim`` column holds a villager index from ``generate_seq_pair``; the
    correct trans of each trial is drawn from repeated random permutations of
    ``0..num_trans-1`` and the correct food is looked up in
    ``stim_food_mapping``. Per-key trans layouts and per-shop food layouts are
    produced with the ``swap_by_indices`` / ``shuffle_with_mask`` helpers from
    seq_utils.
    NOTE(review): presumably the swap puts the correct item at the target
    position and the mask keeps it fixed while the rest is shuffled - confirm
    in seq_utils.

    Args:
        num_keys: Number of response keys.
        num_food: Number of food items.
        num_iter_per_stim: Repetitions of each villager.
        trans_shop_mapping: Indexable mapping from trans index to shop index.
        stim_food_mapping: Indexable mapping from villager index to food index.

    Returns:
        DataFrame with ``stim``, ``correct_key``, ``shop{i}_food``,
        ``key{i}_trans`` and constant ``trans{k}_shop`` columns.
    """
    num_villagers = len(stim_food_mapping)
    num_trans = len(trans_shop_mapping)
    villager_seq, correct_key_seq = generate_seq_pair(
        num_villagers, num_iter_per_stim, num_keys
    )
    num_trials = len(correct_key_seq)

    # Enough back-to-back trans permutations to cover every trial.
    trans_seq = [
        trans
        for _ in range(math.ceil(num_trials / num_trans))
        for trans in np.random.permutation(num_trans)
    ]

    seq_data = {"stim": [], "correct_key": correct_key_seq}
    for slot in range(num_keys):
        seq_data[f"shop{slot}_food"] = []
        seq_data[f"key{slot}_trans"] = []

    trans_ids = np.arange(num_trans)
    food_ids = np.arange(num_food)
    for trial, (villager, target_key) in enumerate(
        zip(villager_seq, correct_key_seq)
    ):
        seq_data["stim"].append(villager)

        # Layout of trans across keys for this trial.
        target_trans = trans_seq[trial]
        key_layout = shuffle_with_mask(
            swap_by_indices(trans_ids, target_trans, target_key),
            np.array([key == target_key for key in range(num_keys)]),
        )
        for key, trans in enumerate(key_layout):
            seq_data[f"key{key}_trans"].append(trans)

        # Layout of food across shops for this trial.
        target_shop = trans_shop_mapping[target_trans]
        target_food = stim_food_mapping[villager]
        shop_layout = shuffle_with_mask(
            swap_by_indices(food_ids, target_shop, target_food),
            np.array([shop == target_shop for shop in range(num_food)]),
        )
        for shop, food in enumerate(shop_layout):
            seq_data[f"shop{shop}_food"].append(food)

    # The trans -> shop mapping is constant, repeated on every row.
    for trans, shop in enumerate(trans_shop_mapping):
        seq_data[f"trans{trans}_shop"] = [shop] * len(villager_seq)

    return pd.DataFrame(seq_data)

# Task sequences

In [11]:
def generate_shaping_round(
    bs,
    img_set,
    num_iter_per_stim,
    trans_shop_mapping,
    last_nonshaping_block,
    num_food=NUM_FOOD,
):
    """Build one shaping round: a new shaping block, then the given last block.

    Args:
        bs: Round index; the shaping block is numbered ``bs * 2`` and the
            appended block ``bs * 2 + 1``.
        img_set: Image-set index; ``img_folder`` is written as ``img_set + 1``.
        num_iter_per_stim: Repetitions per trans in the shaping block.
        trans_shop_mapping: Indexable mapping from trans index to shop index.
        last_nonshaping_block: DataFrame with a ``set_size`` column. It is
            copied, not modified.
        num_food: Number of food items.

    Returns:
        DataFrame with the shaping block stacked on top of the copied block.
        The original row indices of both parts are kept.
    """
    shaping_block = generate_shaping_block(
        NUM_KEYS, num_food, num_iter_per_stim, trans_shop_mapping
    )

    shaping_block["block"] = bs * 2
    shaping_block["img_folder"] = img_set + 1
    # Assigning a Series aligns on the row index, so a shaping block longer
    # than the last block would get NaN in its extra rows. When the last block
    # holds one set size, broadcast that value instead. Otherwise keep the
    # index-aligned assignment.
    set_sizes = last_nonshaping_block["set_size"]
    distinct_sizes = set_sizes.unique()
    shaping_block["set_size"] = (
        distinct_sizes[0] if len(distinct_sizes) == 1 else set_sizes
    )

    nonshaping_block = last_nonshaping_block.copy()
    nonshaping_block["block"] = bs * 2 + 1
    nonshaping_block["img_folder"] = img_set + 1
    return pd.concat([shaping_block, nonshaping_block])


def generate_nonshaping_round(
    bs,
    img_set,
    num_iter_per_stim,
    trans_shop_mapping,
    stim_food_mapping,
    last_nonshaping_block,
    num_food=NUM_FOOD,
):
    """Build one non-shaping round: a new non-shaping block, then the last block.

    Args:
        bs: Round index; the new block is numbered ``bs * 2`` and the appended
            block ``bs * 2 + 1``.
        img_set: Image-set index; ``img_folder`` is written as ``img_set + 1``.
        num_iter_per_stim: Repetitions per villager in the new block.
        trans_shop_mapping: Indexable mapping from trans index to shop index.
        stim_food_mapping: Indexable mapping from villager index to food index.
        last_nonshaping_block: DataFrame with a ``set_size`` column. It is
            copied, not modified.
        num_food: Number of food items.

    Returns:
        DataFrame with the new block stacked on top of the copied block. The
        original row indices of both parts are kept.
    """
    nonshaping_block = generate_non_shaping_block(
        NUM_KEYS,
        num_food,
        num_iter_per_stim,
        trans_shop_mapping,
        stim_food_mapping,
    )
    nonshaping_block["block"] = bs * 2
    nonshaping_block["img_folder"] = img_set + 1
    # Assigning a Series aligns on the row index, so a new block longer than
    # the last block would get NaN in its extra rows. When the last block holds
    # one set size, broadcast that value instead. Otherwise keep the
    # index-aligned assignment.
    set_sizes = last_nonshaping_block["set_size"]
    distinct_sizes = set_sizes.unique()
    nonshaping_block["set_size"] = (
        distinct_sizes[0] if len(distinct_sizes) == 1 else set_sizes
    )

    last_block = last_nonshaping_block.copy()
    last_block["block"] = bs * 2 + 1
    last_block["img_folder"] = img_set + 1

    return pd.concat([nonshaping_block, last_block])


def generate_learning_round(
    last_num_stim_iter,
    num_villagers,
    img_set,
    iter_by_setsz={4: 12, 6: 8},
    num_trans=4,
    num_food=NUM_FOOD,
):
    """Generate the shaping and non-shaping learning tables for one image set.

    Fresh trans->shop and villager->food mappings are drawn with
    ``generate_kv_mapping``. One shared "last" non-shaping block is generated
    and appended to both the shaping round and the non-shaping round.

    Args:
        last_num_stim_iter: Repetitions per villager in the shared last block.
        num_villagers: Number of villagers; also written as ``set_size``.
        img_set: Image-set index forwarded to the round builders.
        iter_by_setsz: Lookup from set size to repetitions per stimulus. It is
            only read here, never modified.
        num_trans: Key into ``iter_by_setsz`` for the shaping round.
        num_food: Number of food items.

    Returns:
        Tuple ``(shaping_blocks, nonshaping_blocks, trans_shop_mapping,
        stim_food_mapping)``; the first two are DataFrames.
    """
    trans_shop_mapping = generate_kv_mapping(NUM_KEYS, NUM_KEYS)
    stim_food_mapping = generate_kv_mapping(num_villagers, num_food)

    last_nonshaping_block = generate_non_shaping_block(
        NUM_KEYS,
        num_food,
        last_num_stim_iter,
        trans_shop_mapping,
        stim_food_mapping,
    )
    last_nonshaping_block["set_size"] = num_villagers
    # Forward num_food explicitly. Without it both round builders fall back to
    # their NUM_FOOD default and ignore the value passed to this function.
    shaping_blocks = generate_shaping_round(
        0,
        img_set,
        iter_by_setsz[num_trans],
        trans_shop_mapping,
        last_nonshaping_block,
        num_food=num_food,
    )
    nonshaping_blocks = generate_nonshaping_round(
        0,
        img_set,
        iter_by_setsz[num_villagers],
        trans_shop_mapping,
        stim_food_mapping,
        last_nonshaping_block,
        num_food=num_food,
    )
    return (
        shaping_blocks,
        nonshaping_blocks,
        trans_shop_mapping,
        stim_food_mapping,
    )

def self_paced_seq(sr_mapping):
    """Return a table with one row per mapping entry.

    ``stim`` is the position ``0..len(sr_mapping)-1`` and ``correct_key`` is
    the mapped value at that position.
    """
    stim_ids = np.arange(len(sr_mapping))
    columns = {"stim": stim_ids, "correct_key": sr_mapping}
    return pd.DataFrame(columns)

In [None]:
# Generate and export two sequence sets (learning, testing, self-paced testing).
# Iteration counts: repetitions per stimulus for the shared last block, for
# the learning blocks (looked up by set size) and for the testing block.
LAST_BLOCK_ITER = 24
SZ_TO_ITER = {4: 15, 6: 10}
NUM_TEST_ITER = 4

img_set = 0
set_size = 6
for seq_idx in range(2):
    # Learning tables plus the mappings they were generated from.
    (
        shaping_blocks,
        nonshaping_blocks,
        trans_shop_mapping,
        stim_food_mapping,
    ) = generate_learning_round(
        LAST_BLOCK_ITER, set_size, img_set, SZ_TO_ITER, 4, NUM_FOOD
    )

    # Testing block reuses the same mappings.
    testing_data = generate_non_shaping_block(
        NUM_KEYS, NUM_FOOD, NUM_TEST_ITER, trans_shop_mapping, stim_food_mapping
    )
    testing_data["block"] = 1
    testing_data["img_folder"] = img_set
    testing_data["set_size"] = set_size
    testing_data = testing_data[OUTPUT_COL_ORDER]

    # Self-paced tables: one per mapping, tagged with the mapping type.
    self_paced_parts = []
    for mapping_type, mapping in (
        ("trans", trans_shop_mapping),
        ("food", stim_food_mapping),
    ):
        part = self_paced_seq(mapping)
        part["type"] = mapping_type
        self_paced_parts.append(part)

    self_paced_seq_data = pd.concat(self_paced_parts)
    self_paced_seq_data["set_size"] = set_size
    self_paced_seq_data["img_folder"] = img_set

    # Learning files: shift block numbers by one (in place), then reorder.
    for name, round_data in (
        ("shaping", shaping_blocks),
        ("nonshaping", nonshaping_blocks),
    ):
        round_data["block"] += 1
        round_data[OUTPUT_COL_ORDER].to_csv(
            f"{seq_folder}/{name}_trans{seq_idx}_learning.csv", index=False
        )

    testing_data.to_csv(f"{seq_folder}/trans{seq_idx}_testing.csv", index=False)
    self_paced_seq_data.to_csv(
        f"{seq_folder}/trans{seq_idx}_self_paced_testing.csv", index=False
    )

