# AsTRiQue Showcase

This is a showcase of [AsTRiQue](https://github.com/prokophanzl/astrique), an active machine learning framework for reducing stimulus count in perception experiments.

In this showcase, you will be asked to classify sounds as either /s/ or /z/ (/s/ and /ʃ/ are grouped under /s/; /z/ and /ʒ/ are grouped under /z/). Once the model's predictions for the remaining stimuli are certain enough, a chart will be plotted showing your answers alongside the model's predictions.

Afterwards, you can choose to classify the remaining sounds to evaluate the model's performance.

In [None]:
# @title 1. Setup
# @markdown Run this cell to resolve imports and Google Colab-specific features (even if you are not running this notebook in Google Colab).

import os

# Environment detection: the notebook treats the presence of the
# COLAB_RELEASE_TAG environment variable as "running in Google Colab".
IN_COLAB = os.getenv("COLAB_RELEASE_TAG") is not None
if IN_COLAB:
	print("Notebook is running in Google Colab. Getting data...")
	
	# check if data has already been downloaded - check for data directory
	# ('data' is resolved against the current working directory, which the
	# %cd below moves into the cloned repository)
	if not os.path.exists('data'):
		!git clone https://github.com/prokophanzl/AsTRiQue/
		%cd AsTRiQue
		print("Data downloaded.")
	else:
		print("Data already downloaded.")
	
	# runs on every execution of this cell in Colab, whether or not the data
	# was just downloaded
	print("Getting dependencies...")
	!pip install jupyter_ui_poll
else:
	# outside Colab nothing is downloaded or installed by this cell
	print("Notebook is not running in Google Colab. Skipping data download.")

import astrique_module
import pandas as pd
import time
from ipywidgets import Button, HBox, VBox, Output
from jupyter_ui_poll import ui_events

if IN_COLAB:
	from IPython.display import Audio, display
else:
	from playsound import playsound

# The two predictor names handed to astrique_module.train_model and
# astrique_module.plot_results (presumably column names in the stimuli
# table - confirm against astrique_module).
PREDICTOR1 = 'voicing'
PREDICTOR2 = 'duration'
# Column of the stimuli dataframe that holds each stimulus' audio file name.
FILENAME_COL = 'filename'
# Keys are used as the captions of the answer buttons; values are the integer
# codes stored as the participant's classification.
LABEL_MAPPING = {'s': 0, 'z': 1}

# NOTE(review): not referenced by any cell visible in this notebook -
# presumably consumed elsewhere (e.g. by astrique_module); confirm.
TARGET = 'answer_batch'
# CSV loaded into the stimuli dataframe (relative to the working directory).
DATA_PATH = 'data/data.csv'
# Destination passed to astrique_module.export_data.
PROCESSED_PATH = 'data_processed.csv'
# Directory joined with a stimulus' file name to locate its audio.
AUDIO_FOLDER = 'data/audio'

In [None]:
# @title 2. Config
# @markdown In this cell, you can change the following config values:
# @markdown <hr>
# NOTE: the `# @title`, `# @markdown` and `# @param` comments in this cell are
# Google Colab form directives; edit them with care.
# Read by the random-sampling loop of the main cell, which also keeps sampling
# beyond this count until both classes have been collected.
INIT_RANDOM_SAMPLES = 10      # @param {type:"slider", min: 0, max: 104, step: 1}
# @markdown number of random samples to collect in the initial random sampling phase
# @markdown <hr>
# Compared against the 1-based iteration counter in the main cell's stopping check.
MIN_ITERATIONS = 30           # @param {type:"slider", min: 0, max: 104, step: 1}
# @markdown minimum number of iterations
# @markdown <hr>
# Forwarded unchanged to astrique_module.get_sample.
CLEANSER_FREQUENCY = 0        # @param {type:"slider", min: 0, max: 20, step: 1}
# @markdown insert a high-certainty sample every nth iteration to prevent participant fatigue; 0 to disable
# @markdown <hr>
# The main loop may stop only once no unanswered stimulus has a
# 'prediction_certainty' below this value.
MODEL_CERTAINTY_CUTOFF = 0.95 # @param {type:"slider", min: 0.5 , max: 1, step: 0.01}
# @markdown certainty threshold to end training
# @markdown <hr>

In [None]:
# @title 3. Participant Functions

def play_sound(filename):
    """
    Play the audio file at `filename`.

    Outside Google Colab the file is handed to `playsound`; inside Colab an
    autoplaying `Audio` element is displayed in the cell output instead.
    """
    if not IN_COLAB:
        playsound(filename)
        return

    player = Audio(filename, autoplay=True)
    display(player)

def query_participant_classification(filename):
    """
    Plays a stimulus and blocks until the participant picks a label.

    Args:
        filename: name of the audio file inside AUDIO_FOLDER.

    Returns:
        The LABEL_MAPPING value of the label the participant clicked, or
        None (after printing a notice) when the audio file does not exist.
        Callers must be prepared for the None case.

    Raises:
        KeyboardInterrupt: raised by the "exit" button's handler; an interrupt
            of the kernel while waiting surfaces the same way. The button row
            is closed in either case.
    """

    filepath = os.path.join(AUDIO_FOLDER, filename)
    if not os.path.exists(filepath):
        print(f"Missing file: {filepath}. Skipping.")
        return None

    # mutable holder so the button callbacks can report back to the wait loop
    state = {"done": False, "selected": None}

    def make_click_fn(label=None, special=None):
        # builds the on_click handler for one button; `special` marks the
        # control buttons, `label` the answer buttons
        def fn(btn):
            if special == "exit":
                raise KeyboardInterrupt("User exited.")
            elif special == "replay":
                play_sound(filepath)
            else:
                state["selected"] = label
                state["done"] = True
        return fn

    # create classification buttons (one per LABEL_MAPPING key)
    label_buttons = [Button(description=l) for l in LABEL_MAPPING]
    for btn in label_buttons:
        btn.on_click(make_click_fn(label=btn.description))

    # create control buttons
    replay_btn = Button(description="replay")
    exit_btn = Button(description="exit")
    replay_btn.on_click(make_click_fn(special="replay"))
    exit_btn.on_click(make_click_fn(special="exit"))

    # display buttons
    controls = VBox([HBox(label_buttons), HBox([replay_btn, exit_btn])])
    display(controls)

    try:
        # play sound at the beginning
        play_sound(filepath)

        # wait for user input; poll() lets pending widget events run while
        # this cell is still executing
        with ui_events() as poll:
            while not state["done"]:
                poll(10)
                time.sleep(0.1)
    finally:
        # clean up even when the wait is aborted (exit button, kernel
        # interrupt, playback error) so no stale buttons stay on screen
        controls.close()

    return LABEL_MAPPING[state["selected"]]

In [None]:
# @title 4. Main Execution

# create stimuli dataframe
stimuli = pd.read_csv(DATA_PATH)

astrique_module.initialize_dataframe(stimuli)

# display output widget
output_widget = Output()
display(output_widget)

# stimuli whose audio file is missing (query returned None); they are never
# offered to the participant again during this run
skipped_files = set()

# initial random sampling with class balance
collected_classes = set()

# `iteration` is the 1-based number of the NEXT query, i.e. iteration - 1
# answers have been recorded so far
iteration = 1
while iteration <= INIT_RANDOM_SAMPLES or len(collected_classes) < 2:
    # candidates: unanswered stimuli that have not been skipped
    candidates = stimuli[
        stimuli['participant_classification'].isna()
        & ~stimuli[FILENAME_COL].isin(skipped_files)
    ]
    if candidates.empty:
        if len(collected_classes) < 2:
            raise RuntimeError(
                "Ran out of stimuli before both classes were collected."
            )
        # every stimulus has been handled already; nothing left to sample
        break

    print(f"Iteration {iteration}: Random sampling")

    # select a random stimulus where real class is unknown
    sample = candidates.sample(1)
    filename = sample[FILENAME_COL].values[0]

    # get classification, querying filename
    classification = query_participant_classification(filename)
    if classification is None:
        # audio file missing: skip without consuming an iteration
        skipped_files.add(filename)
        continue
    classification = int(classification)

    collected_classes.add(classification)

    # update row in dataframe
    idx = stimuli[FILENAME_COL] == filename
    stimuli.loc[idx, 'classification_order'] = iteration
    stimuli.loc[idx, 'classification_type'] = 'random'
    stimuli.loc[idx, 'participant_classification'] = classification

    iteration += 1

random_samples = iteration - 1

# train initial model
# NOTE(review): the loop below retrains before doing anything else, so this
# call looks redundant - confirm train_model has no one-time side effects
# before removing it.
model = astrique_module.train_model(stimuli, PREDICTOR1, PREDICTOR2)

while True:
    # retrain model to get up-to-date predictions on remaining unlabeled samples
    model = astrique_module.train_model(stimuli, PREDICTOR1, PREDICTOR2)

    # get updated unanswered subset
    unanswered = stimuli[stimuli['participant_classification'].isna()]

    # check stopping condition; `iteration > MIN_ITERATIONS` means at least
    # MIN_ITERATIONS answers have been recorded (iteration counts the next query)
    below_cutoff = unanswered['prediction_certainty'] < MODEL_CERTAINTY_CUTOFF
    if not below_cutoff.any() and iteration > MIN_ITERATIONS:
        print("Stopping active learning: all predictions above certainty threshold "f"({MODEL_CERTAINTY_CUTOFF}) and minimum iterations met ({MIN_ITERATIONS}).")
        break

    # stimuli that can still be presented to the participant
    candidates = unanswered[~unanswered[FILENAME_COL].isin(skipped_files)]
    if candidates.empty:
        print("Stopping active learning: no stimuli left to classify.")
        break

    # select next sample using uncertainty sampling (with optional cleanser)
    sample, sample_type = astrique_module.get_sample(candidates, iteration, CLEANSER_FREQUENCY, random_samples)
    filename = sample[FILENAME_COL].values[0]

    # query real classification
    classification = query_participant_classification(filename)
    if classification is None:
        # audio file missing: exclude it from further sampling and move on
        skipped_files.add(filename)
        continue
    classification = int(classification)

    # update row in dataframe
    idx = stimuli[FILENAME_COL] == filename
    stimuli.loc[idx, 'classification_order'] = iteration
    stimuli.loc[idx, 'classification_type'] = sample_type
    stimuli.loc[idx, 'participant_classification'] = classification

    iteration += 1

plot_title = 'Participant Results'

# plot the results
astrique_module.plot_results(stimuli, model, plot_title, PREDICTOR1, PREDICTOR2, LABEL_MAPPING)

# export data if desired
astrique_module.export_data(stimuli, PROCESSED_PATH)

In [None]:
# @title 5. Model Evaluation
# @markdown If you'd like to evaluate the performance of the model, you can run this cell to classify the remaining unanswered stimuli for something to compare the model's predictions against.

# A fresh Output widget is created and shown for this cell; it rebinds the
# notebook-level name `output_widget` used by the main cell.
output_widget = Output()
display(output_widget)
# Hands the stimuli table, the filename column, the participant-query callback
# and the widget to astrique_module. NOTE(review): how evaluate_model handles
# a None result from the callback (missing audio file) is not visible here -
# confirm.
astrique_module.evaluate_model(stimuli, FILENAME_COL, query_participant_classification, output_widget)