# AsTRiQue Showcase

This is a showcase of [AsTRiQue](https://github.com/prokophanzl/astrique), an active machine learning framework for reducing stimulus count in perception experiments.

In this showcase, a virtual agent — a lookup table of a participant's recorded answers — is asked to classify sounds as either /s, ʃ/ or /z, ʒ/ (/s/ and /ʃ/ are grouped under /s/; /z/ and /ʒ/ are grouped under /z/). Once the model's predictions for the remaining stimuli are confident enough, its performance is evaluated on the sounds the virtual agent did not classify, and a chart is plotted showing the virtual agent's answers alongside the model's predictions.

In [None]:
# @title Setup
# @markdown Run this cell to resolve imports and Google Colab-specific features (even if you are not running this notebook in Google Colab).

import os

IN_COLAB = os.getenv("COLAB_RELEASE_TAG") is not None
if IN_COLAB:
	print("Notebook is running in Google Colab. Getting data...")
	
	# check if data has already been downloaded - check for data directory
	if not os.path.exists('data'):
		!git clone https://github.com/prokophanzl/AsTRiQue/
		%cd AsTRiQue
		print("Data downloaded.")
	else:
		print("Data already downloaded.")
else:
	print("Notebook is not running in Google Colab. Skipping data download.")

import astrique_module
import pandas as pd

PREDICTOR1 = 'voicing'
PREDICTOR2 = 'duration'
FILENAME_COL = 'filename'
LABEL_MAPPING = {'s': 0, 'z': 1}

TARGET = 'answer_batch'
DATA_PATH = 'data/data.csv'
PARTICIPANT_CSV_DIR = 'data/participants'
PROCESSED_PATH = 'data_processed.csv'

astrique_module.set_seed(42)

In [None]:
# @title 2. Config
# @markdown In this cell, you can change the following config values:
# @markdown <hr>
STRATIFIED_SAMPLING_RESOLUTION = 3 # @param {type:"slider", min: 1, max: 10, step: 1}
# @markdown resolution of the initial stratified sampling over the two predictors; up to STRATIFIED_SAMPLING_RESOLUTION^2 stimuli are collected in this phase
# @markdown <hr>
MIN_ITERATIONS = 30           # @param {type:"slider", min: 0, max: 104, step: 1}
# @markdown minimum number of classifications to collect (stratified and random samples included) before training may stop
# NOTE(review): the slider maximum of 104 presumably matches the number of stimuli in data/data.csv — confirm
# @markdown <hr>
CLEANSER_FREQUENCY = 0        # @param {type:"slider", min: 0, max: 20, step: 1}
# @markdown insert a high-certainty sample every nth iteration to reduce participant fatigue; 0 disables this
# @markdown <hr>
MODEL_CERTAINTY_CUTOFF = 0.95 # @param {type:"slider", min: 0.5 , max: 1, step: 0.01}
# @markdown training stops once every unclassified stimulus is predicted with at least this certainty (and the minimum number of iterations has been reached)
# @markdown <hr>
PARTICIPANT_TO_MODEL = 'p03'  # @param ["p01", "p02", "p03", "p04", "p05", "p06", "p07", "p08", "p09", "p10", "p11", "p12", "p13", "p14", "p15", "p16", "p17", "p18", "p19", "p20", "p21", "p22", "p23", "p24", "p25", "p26", "p27", "p28", "p29", "p30", "p31"]
# @markdown ID of the participant whose recorded answers the virtual agent replays (read from `data/participants/<ID>.csv`)
# @markdown <hr>

In [None]:
# @title 3. Virtual Agent Functions
from functools import lru_cache


@lru_cache(maxsize=None)
def _load_participant_answers(csv_path):
    """Read a participant's answer lookup table once and reuse it afterwards.

    Re-running this cell redefines the function, which also clears the cache.
    """
    return pd.read_csv(csv_path)


def query_participant_classification(filename):
    """
    Query the virtual agent for its classification of a single stimulus.

    The virtual agent replays the answers stored in
    PARTICIPANT_CSV_DIR/<PARTICIPANT_TO_MODEL>.csv: the (first) row whose
    FILENAME_COL equals `filename` is looked up and its TARGET label is
    encoded through LABEL_MAPPING. The CSV is read from disk only once per
    path instead of on every query.

    Parameters
    ----------
    filename : str
        Value of FILENAME_COL identifying the stimulus.

    Returns
    -------
    The LABEL_MAPPING value for the participant's answer.

    Raises
    ------
    IndexError
        If the participant's table has no row for `filename`.
    KeyError
        If the recorded answer is not a key of LABEL_MAPPING.
    """
    csv_path = os.path.join(PARTICIPANT_CSV_DIR, PARTICIPANT_TO_MODEL + '.csv')
    participant_answers = _load_participant_answers(csv_path)

    matches = participant_answers.loc[participant_answers[FILENAME_COL] == filename, TARGET]
    if matches.empty:
        # same exception type the previous `.values[0]` lookup raised, but with a useful message
        raise IndexError(f"No answer for stimulus {filename!r} in {csv_path}")
    return LABEL_MAPPING[matches.iloc[0]]

In [None]:
# @title 4. Main Execution


def _record_classification(stimuli, filename, order, classification_type, classification):
    """Store one virtual-agent answer on every row of `stimuli` whose FILENAME_COL equals `filename`.

    Parameters
    ----------
    stimuli : pandas.DataFrame
        Stimulus table; updated in place.
    filename : str
        Stimulus identifier used for the lookup.
    order : int
        1-based iteration at which the answer was collected ('classification_order').
    classification_type : str
        How the stimulus was selected ('classification_type'), e.g. 'stratified' or 'random'.
    classification : int
        Encoded answer ('participant_classification').
    """
    idx = stimuli[FILENAME_COL] == filename
    stimuli.loc[idx, 'classification_order'] = order
    stimuli.loc[idx, 'classification_type'] = classification_type
    stimuli.loc[idx, 'participant_classification'] = classification


# create stimuli dataframe
stimuli = pd.read_csv(DATA_PATH)

astrique_module.initialize_dataframe(stimuli)

# every class in LABEL_MAPPING must be observed at least once before a model can be trained
required_classes = set(LABEL_MAPPING.values())
collected_classes = set()

iteration = 1  # starting at 1 for human readability

# phase 1: stratified sampling over the PREDICTOR1 x PREDICTOR2 space
stratified_samples = astrique_module.get_stratified_samples(stimuli, PREDICTOR1, PREDICTOR2, STRATIFIED_SAMPLING_RESOLUTION)

for _, sample in stratified_samples.iterrows():
    print(f"Iteration {iteration}: Stratified sampling")

    filename = sample[FILENAME_COL]
    classification = int(query_participant_classification(filename))
    collected_classes.add(classification)
    _record_classification(stimuli, filename, iteration, 'stratified', classification)

    iteration += 1

# phase 2: random sampling until every class has been seen at least once
while not required_classes <= collected_classes:
    unlabeled = stimuli[stimuli['participant_classification'].isna()]
    if unlabeled.empty:
        # without this guard `.sample(1)` fails on an empty frame with an unhelpful error
        raise RuntimeError(
            "All stimuli have been classified but not every class was observed "
            f"(observed: {sorted(collected_classes)}); cannot train the model."
        )

    print(f"Iteration {iteration}: Random sampling to balance classes")

    sample = unlabeled.sample(1)
    filename = sample[FILENAME_COL].values[0]
    classification = int(query_participant_classification(filename))
    collected_classes.add(classification)
    _record_classification(stimuli, filename, iteration, 'random', classification)

    iteration += 1

# number of samples collected before active learning starts (passed to get_sample)
init_samples = iteration - 1

# phase 3: active learning
while True:
    # retrain model to get up-to-date predictions on remaining unlabeled samples;
    # this also runs before the first check, so no separate initial training is needed
    model = astrique_module.train_model(stimuli, PREDICTOR1, PREDICTOR2)

    # get updated unanswered subset
    unanswered = stimuli[stimuli['participant_classification'].isna()]

    # check stopping condition
    below_cutoff = unanswered['prediction_certainty'] < MODEL_CERTAINTY_CUTOFF
    if not below_cutoff.any() and iteration > MIN_ITERATIONS:
        print("Stopping active learning: all predictions above certainty threshold "
              f"({MODEL_CERTAINTY_CUTOFF}) and minimum iterations met ({MIN_ITERATIONS}).")
        break

    if unanswered.empty:
        # only reachable while iteration <= MIN_ITERATIONS: there is nothing left
        # to query, and get_sample would otherwise receive an empty frame
        print("Stopping active learning: every stimulus has been classified.")
        break

    # select next sample using uncertainty sampling (with optional cleanser)
    sample, sample_type = astrique_module.get_sample(unanswered, iteration, CLEANSER_FREQUENCY, init_samples)

    filename = sample[FILENAME_COL].values[0]
    classification = int(query_participant_classification(filename))
    _record_classification(stimuli, filename, iteration, sample_type, classification)

    iteration += 1

# evaluate model
astrique_module.evaluate_model(stimuli, FILENAME_COL, query_participant_classification)

plot_title = f'Virtual Agent Results (participant: {PARTICIPANT_TO_MODEL})'

# plot the results
astrique_module.plot_results(stimuli, model, plot_title, PREDICTOR1, PREDICTOR2, STRATIFIED_SAMPLING_RESOLUTION, LABEL_MAPPING)

# export data if desired
astrique_module.export_data(stimuli, PROCESSED_PATH)