In [20]:
import os
import random
import shutil
from pathlib import Path

import pandas as pd
import torch

from task.multi_task import filter_tasks
from utils.stim_io import HvMImageLoader, HvMMetaData, HvMImageMapper, _subpath_after

In [21]:
# Root of the original HvM image set. The default is a machine-specific path;
# set the HVM_DIR environment variable to run the notebook elsewhere without
# editing this cell.
hvm_dir = os.environ.get(
    'HVM_DIR',
    '/Users/markbai/PycharmProjects/RNN_NatAbs/data/original',
)
# Trial table written by build_csv_from_dataset (relative to the notebook dir).
OUTPUT_CSV = "../resources/interdms_position_identity_trials.csv"
# Image folder as it is written into the stimulus columns of the trial table.
IMAGES_DIR = Path('images')
# Placeholder image used for frames that carry no stimulus.
BG_FP = IMAGES_DIR / 'gray_background.png'
N_STIMS = 100
# pick_locations keeps two positions per object, so half as many objects as stimuli.
N_OBJS = N_STIMS // 2

# The dataset size is len(SESSIONS) * TRIALS_PER_SESSION trials.
SESSIONS = [1, 2, 3, 4, 5]
TRIALS_PER_SESSION = 20
FRAMES_PER_TRIAL = 4
RANDOM_SEED = 2025
GRID_SIZE = 3
# Maps an action code from the dataset to the value written in the act columns.
# NOTE(review): the recorded `inter_dms.actions` output has 2 on the first two
# frames of every trial and 0/1 only on the later frames, and trial_to_row
# skips frames whose code is 2. The labels below follow that reading; the
# previous inline labels had 0 and 2 the other way round -- confirm against
# the task definition.
ACTION_MAP = {
    0: 'b',  # presumably non-match
    1: 'x',  # match
    2: 'None',  # presumably no response expected on this frame
}

# Seeds Python's `random` module only; NumPy / torch generators are not seeded here.
random.seed(RANDOM_SEED)

In [22]:
def pick_objs(df, n_objs):
    """Pick ``n_objs`` (category, object) pairs spread evenly over categories.

    Each category found in ``df['cat_1b']`` contributes ``n_objs // n_cats``
    objects drawn from its ``id_1b`` values. The ``n_objs % n_cats`` leftover
    slots are given to randomly chosen categories, one extra object each.

    All draws go through the module-level ``random`` generator and always
    operate on sorted candidate lists, so a fixed ``random.seed`` gives the
    same selection on every run.

    Args:
        df: DataFrame with ``cat_1b`` and ``id_1b`` columns.
        n_objs: Total number of pairs requested.

    Returns:
        List of ``(cat, obj)`` tuples. It can be shorter than ``n_objs`` when a
        category chosen for an extra slot has no unused object left.

    Raises:
        ZeroDivisionError: If ``df`` contains no categories.
        ValueError: From ``random.sample`` when a category has fewer than
            ``n_objs // n_cats`` distinct objects.
    """
    cats = sorted(df['cat_1b'].unique())
    obj_per_cat, remainder = divmod(n_objs, len(cats))

    # Sorted object ids per category, computed once and reused for both passes.
    objs_by_cat = {
        cat: sorted(df.loc[df['cat_1b'] == cat, 'id_1b'].unique())
        for cat in cats
    }

    # Main balanced allocation.
    catids = []
    for cat in cats:
        for obj in random.sample(objs_by_cat[cat], obj_per_cat):
            catids.append((cat, obj))

    # Hand out the leftover slots to randomly chosen categories instead of
    # always favouring the first ones.
    if remainder > 0:
        extra_cats = random.sample(cats, k=remainder)
        for cat in extra_cats:
            already = {obj for (c, obj) in catids if c == cat}
            # Filter the sorted list rather than going through a set: set
            # iteration order is not stable across runs for string ids, which
            # would make the seeded choice below irreproducible.
            available = [obj for obj in objs_by_cat[cat] if obj not in already]
            if available:
                catids.append((cat, random.choice(available)))
    return catids


def pick_locations(df, catid_to_positions, chosen_objs, loc_c=5):
    """Pick two image rows per object: one at ``loc_c`` and one elsewhere.

    For every ``(cat, obj)`` pair in ``chosen_objs`` the shared location
    ``loc_c`` is always kept and one further position is drawn at random from
    the object's remaining positions. For each of the two positions a single
    matching row of ``df`` is sampled.

    Both the position and the row are drawn from Python's ``random``
    generator, so seeding ``random`` makes the whole selection reproducible.

    Args:
        df: DataFrame with ``cat_1b``, ``id_1b`` and ``pos_1b`` columns.
        catid_to_positions: Mapping from ``(cat, obj)`` to the positions
            available for that object.
        chosen_objs: Iterable of ``(cat, obj)`` pairs to sample for.
        loc_c: Position every object must provide.

    Returns:
        List of one-row DataFrames, two per object, ordered ``loc_c`` first.

    Raises:
        ValueError: If an object has no image at ``loc_c``, has no other
            position, or no row of ``df`` matches a chosen position.
    """
    rows = []
    for catid in chosen_objs:
        cat, obj = catid
        # sorted() returns a new list, so the caller's mapping is not mutated.
        positions = sorted(catid_to_positions[catid])

        # Make sure there is a consistent location for task sampling.
        if loc_c not in positions:
            raise ValueError(
                f'object {catid} has no image at the shared location {loc_c}'
            )
        positions.remove(loc_c)
        if not positions:
            raise ValueError(
                f'object {catid} has no position other than {loc_c} to sample from'
            )
        chosen_pos = [loc_c, random.sample(positions, k=1)[0]]

        for pos in chosen_pos:
            subset = df[
                (df["cat_1b"] == cat) &
                (df["id_1b"] == obj) &
                (df["pos_1b"] == pos)
            ]
            # DataFrame.sample uses NumPy's global RNG by default, which
            # random.seed() does not control; derive its seed from `random`.
            row = subset.sample(1, random_state=random.randrange(2 ** 32))
            rows.append(row)
    return rows


def sample_df(hvm_dir, n_objs, grid_size, df_path=None, loc_c=5):
    """Build an image loader restricted to a sampled subset of stimuli.

    The loader is first prepared on the full metadata. The subset is then
    either read from ``df_path`` (when that file exists) or sampled with
    ``pick_objs`` / ``pick_locations`` and, if ``df_path`` is given, cached
    there. Finally the loader's frame is swapped for the subset and its task
    cache rebuilt.

    Args:
        hvm_dir: Root directory passed to ``HvMMetaData`` / ``HvMImageLoader``.
        n_objs: Number of (category, object) pairs to sample.
        grid_size: Grid size forwarded to ``prepare_for_tasks``.
        df_path: Optional CSV path used as a cache for the subset. When
            ``None`` nothing is read or written.
        loc_c: Position every sampled object must provide; forwarded to
            ``pick_locations``.

    Returns:
        Tuple ``(img_loader, meta)``.
    """
    meta = HvMMetaData(hvm_dir)
    img_loader = HvMImageLoader(
        root_dir=hvm_dir,
        metadata=meta,
        preload_images=False,
    )
    img_loader.prepare_for_tasks(grid_size)
    df = img_loader.df
    if df_path is not None and os.path.isfile(df_path):
        print(f'Loading subset csv at {df_path}')
        df_subset = pd.read_csv(df_path)
    else:
        catid_to_positions = img_loader._task_cache.catid_to_positions
        chosen_objs = pick_objs(df, n_objs)
        rows = pick_locations(df, catid_to_positions, chosen_objs, loc_c=loc_c)
        df_subset = pd.concat(rows, ignore_index=True)
        # Only touch the disk when a cache path was actually supplied;
        # to_csv(None) would just build a string and write nothing.
        if df_path is not None:
            print(f'Saving subset csv at {df_path}')
            # index=False keeps a reloaded subset column-identical to a fresh
            # one (no extra "Unnamed: 0" column on the next run).
            df_subset.to_csv(df_path, index=False)

    # Swap in the subset and drop the old cache so it is rebuilt from it.
    img_loader.df = df_subset
    img_loader._task_cache = None
    img_loader.prepare_for_tasks(grid_size)
    return img_loader, meta

In [23]:
# The sampled-subset cache lives in <project>/resources/subset.csv, where the
# project root is the parent of the notebook's working directory.
project_dir = Path.cwd().parent
images_dir = project_dir.joinpath('resources')
df_path = images_dir.joinpath('subset.csv')

img_loader, meta_data = sample_df(
    hvm_dir,
    N_OBJS,
    GRID_SIZE,
    df_path=df_path,
)
# Show the positions kept for each (category, object) pair after subsetting.
img_loader._task_cache.catid_to_positions

removed 0 rows from the original df
found 5760 local images after filtering
normalizing with stats: ([0, 0, 0], [1, 1, 1])
Loading subset csv at /Users/markbai/PycharmProjects/abs_nat_psychopy_mturk/resources/subset.csv


{(1, 2): array([5, 9]),
 (1, 6): array([5, 9]),
 (1, 4): array([5, 8]),
 (1, 8): array([2, 5]),
 (1, 1): array([5, 8]),
 (1, 5): array([5, 7]),
 (2, 7): array([1, 5]),
 (2, 5): array([5, 7]),
 (2, 2): array([4, 5]),
 (2, 1): array([5, 7]),
 (2, 4): array([4, 5]),
 (2, 6): array([5, 6]),
 (3, 2): array([3, 5]),
 (3, 1): array([5, 6]),
 (3, 7): array([5, 7]),
 (3, 8): array([5, 6]),
 (3, 4): array([5, 7]),
 (3, 6): array([2, 5]),
 (4, 7): array([3, 5]),
 (4, 1): array([5, 6]),
 (4, 5): array([5, 6]),
 (4, 8): array([3, 5]),
 (4, 6): array([1, 5]),
 (4, 4): array([5, 7]),
 (5, 4): array([3, 5]),
 (5, 1): array([5, 8]),
 (5, 6): array([5, 7]),
 (5, 2): array([4, 5]),
 (5, 8): array([5, 6]),
 (5, 5): array([1, 5]),
 (6, 6): array([5, 8]),
 (6, 5): array([1, 5]),
 (6, 1): array([4, 5]),
 (6, 7): array([4, 5]),
 (6, 8): array([4, 5]),
 (6, 3): array([4, 5]),
 (7, 4): array([3, 5]),
 (7, 2): array([5, 9]),
 (7, 1): array([3, 5]),
 (7, 7): array([5, 6]),
 (7, 6): array([1, 5]),
 (7, 8): array([

In [7]:
# Instantiate one dataset per selected task, wired to the subset image loader.
# NOTE(review): 20 looks like the selector for the
# 'interdms_ABAB_position_identity' task (the only key in the recorded
# output) -- confirm against filter_tasks.
dataloaders_dict = filter_tasks([20])
datasets = {}
for task_name, (DatasetClass, kwargs) in dataloaders_dict.items():
    # Task defaults first, then the notebook-specific overrides.
    dataset_kwargs = {
        **kwargs,
        'hvm_loader': img_loader,
        'pad_to': FRAMES_PER_TRIAL,
        'dataset_size': len(SESSIONS) * TRIALS_PER_SESSION,
    }
    tmp_task = DatasetClass(**dataset_kwargs)
    tmp_task.reset()
    datasets[task_name] = tmp_task
datasets

{'interdms_ABAB_position_identity': <task.dms.InterDMSDataset at 0x17dbcec40>}

In [17]:
def trial_to_row(ds, emb, action, session_id):
    """Flatten one trial into a dict suitable for a CSV row.

    Frames whose embedding row is entirely (close to) zero get the gray
    background image; every other frame gets the image path resolved through
    ``HvMImageMapper`` and re-rooted under ``IMAGES_DIR``.

    Args:
        ds: Dataset handed to ``HvMImageMapper``.
        emb: Tensor of per-frame embeddings; reduced along ``dim=1``, so it is
            assumed to be (frames, features) -- TODO confirm.
        action: Iterable of per-frame action codes (tensor elements).
        session_id: Value stored under the ``'session'`` key.

    Returns:
        Dict with ``'session'``, ``'stim<k>'`` for every frame and ``'act<k>'``
        for frames after the first whose action code is not 2.
    """
    # Locate the frames that actually carry a stimulus.
    zero = torch.tensor(0.0, dtype=emb.dtype)
    is_blank = torch.isclose(emb, zero).all(dim=1)
    nonzero_idx = torch.nonzero(~is_blank).squeeze(1)
    stim_frames = set(nonzero_idx.tolist())

    # presumably returns one source file per embedding, in order -- verify
    mapper = HvMImageMapper(ds)
    files, _ = mapper._batch_decode_and_find(emb[nonzero_idx])

    fp_list = []
    cursor = 0  # index of the next unused entry in `files`
    for frame in range(FRAMES_PER_TRIAL):
        if frame not in stim_frames:
            fp_list.append(str(BG_FP))
            continue
        source = Path(files[cursor])
        cursor += 1
        relative = _subpath_after(source, segment='HvM_with_discfade')
        fp_list.append(str(IMAGES_DIR / relative))

    row = {'session': session_id}
    for number, (code, stim_path) in enumerate(zip(action, fp_list), start=1):
        # only save action frames
        if number != 1 and code != 2:
            row[f'act{number}'] = ACTION_MAP[code.item()]
        row[f'stim{number}'] = stim_path
    return row


def build_csv_from_dataset(dataset, out_csv: str = OUTPUT_CSV) -> pd.DataFrame:
    """Convert every trial of ``dataset`` into a row and save the table as CSV.

    Trials are numbered in iteration order and grouped into consecutive
    sessions of ``TRIALS_PER_SESSION`` trials, starting at session 1.

    Args:
        dataset: Iterable yielding ``(emb, action, task_index)`` triples.
        out_csv: Destination path of the CSV file.

    Returns:
        The DataFrame that was written, with columns ``session``,
        ``stim1..stimN`` and ``act3..actN`` (N = ``FRAMES_PER_TRIAL``).
    """
    records = [
        trial_to_row(dataset, emb, action, index // TRIALS_PER_SESSION + 1)
        for index, (emb, action, _task_index) in enumerate(dataset)
    ]

    last_frame = FRAMES_PER_TRIAL + 1
    columns = ["session"]
    columns.extend(f"stim{k}" for k in range(1, last_frame))
    columns.extend(f"act{k}" for k in range(3, last_frame))

    table = pd.DataFrame(records, columns=columns)
    table.to_csv(out_csv, index=False)
    print(f"Saved: {out_csv} (rows={len(table)})")
    return table


def move_images_to_local_folder(df: pd.DataFrame, images_dir: Path = IMAGES_DIR) -> None:
    """Copy every image listed in ``df['filename']`` into the resources folder.

    Despite the name, files are copied, not moved. The destination root is
    ``<cwd parent>/resources/<images_dir>``; each file keeps the sub-path that
    ``_subpath_after`` returns for the ``'HvM_with_discfade'`` segment. The
    three known variation folders under the destination are removed first so
    images from a previous run do not linger.

    Args:
        df: DataFrame with a ``filename`` column of source image paths.
        images_dir: Folder, relative to ``resources``, that receives the copies.

    Returns:
        None.
    """
    project_dir = Path.cwd().parent
    target_root = project_dir / 'resources' / images_dir

    # Clear earlier copies; a folder that does not exist yet is not an error.
    stale_dirs = (
        'Variation00_20110203',
        'Variation03_20110128',
        'Variation06_20110131',
    )
    for dir_name in stale_dirs:
        shutil.rmtree(target_root / dir_name, ignore_errors=True)

    for fp in df['filename'].unique():
        src = Path(fp)
        dst = target_root / _subpath_after(src, segment='HvM_with_discfade')
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(src, dst.absolute())


# Stage the subset's images under resources/, then export the trial table
# for the inter-DMS task.
move_images_to_local_folder(df=img_loader.df, images_dir=IMAGES_DIR)
inter_dms = datasets["interdms_ABAB_position_identity"]
out_df = build_csv_from_dataset(dataset=inter_dms, out_csv=OUTPUT_CSV)

Saved: ../resources/interdms_position_identity_trials.csv (rows=100)


In [28]:
# Count how many rows of the sampled subset fall under each value of 'var'
# (presumably the HvM variation level -- verify against the metadata).
img_loader.df['var'].value_counts()

var
3    51
6    40
0     9
Name: count, dtype: int64

In [12]:
# Display the dataset's `actions` tensor for a visual check of the action codes.
inter_dms.actions

tensor([[2, 2, 0, 0],
        [2, 2, 1, 0],
        [2, 2, 0, 1],
        [2, 2, 0, 0],
        [2, 2, 0, 1],
        [2, 2, 1, 0],
        [2, 2, 0, 0],
        [2, 2, 0, 0],
        [2, 2, 1, 1],
        [2, 2, 0, 0],
        [2, 2, 0, 0],
        [2, 2, 1, 1],
        [2, 2, 1, 1],
        [2, 2, 0, 0],
        [2, 2, 0, 0],
        [2, 2, 0, 1],
        [2, 2, 0, 1],
        [2, 2, 1, 1],
        [2, 2, 0, 0],
        [2, 2, 0, 1],
        [2, 2, 1, 1],
        [2, 2, 1, 1],
        [2, 2, 0, 0],
        [2, 2, 1, 1],
        [2, 2, 0, 0],
        [2, 2, 1, 0],
        [2, 2, 0, 1],
        [2, 2, 1, 1],
        [2, 2, 0, 1],
        [2, 2, 0, 0],
        [2, 2, 1, 1],
        [2, 2, 0, 0],
        [2, 2, 1, 0],
        [2, 2, 0, 1],
        [2, 2, 1, 1],
        [2, 2, 1, 0],
        [2, 2, 1, 1],
        [2, 2, 0, 0],
        [2, 2, 1, 1],
        [2, 2, 1, 1],
        [2, 2, 1, 0],
        [2, 2, 0, 1],
        [2, 2, 0, 0],
        [2, 2, 1, 0],
        [2, 2, 1, 0],
        [2

In [9]:
# Sanity check: print every stimulus path from the trial table that does not
# resolve to a file under <project>/resources. No output means all are present.
project_dir = Path.cwd().parent
images_dir = project_dir / 'resources'
stim_columns = [name for name in out_df.columns if 'stim' in name]
for col in stim_columns:
    for img in out_df[col]:
        if os.path.isfile(images_dir / Path(img)):
            continue
        print(img)