### Visualize strokes
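This notebook supports exploratory analysis of user strokes by visualizing them. Set `EXPERIMENT_GROUP` in the first code cell to choose which experiment group's raw data is loaded.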

In [61]:
# Which experiment group to visualize; it is substituted into DATA_DIRECTORY.
EXPERIMENT_GROUP: str = "0.0"

# Relative paths below resolve against the kernel's current working directory.
ROOT_DIR: str = "../../.."
DATA_DIRECTORY: str = "data_experiment/laps_{}/raw"  # "{}" <- EXPERIMENT_GROUP
DATA_FILE: str = "raw_experiment_data.json"

IMAGES_BASE_DIRECTORY: str = "static/images"

import os, json, argparse, random, copy
import pathlib
from datetime import datetime
from collections import defaultdict

# Resolve the image root and the raw-data file for the selected experiment group.
image_directory = os.path.join(ROOT_DIR, IMAGES_BASE_DIRECTORY)
experiment_file = os.path.join(ROOT_DIR, DATA_DIRECTORY.format(EXPERIMENT_GROUP), DATA_FILE)
print(f"Visualizing data from experiment file: {experiment_file}")
# JSON text is UTF-8 by specification; state it explicitly so loading does not
# depend on the platform's locale encoding (e.g. cp1252 on Windows), which would
# mis-decode or reject non-ASCII text in the data.
with open(experiment_file, encoding="utf-8") as f:
    experiment_data = json.load(f)

Visualizing data from experiment file: ../../../data_experiment/laps_0.0/raw/raw_experiment_data.json


In [59]:
ALL = "all"  # sentinel: put this in a selection list to mean "every experiment"

# Ids to iterate over, and the per-experiment records looked up by those ids.
experiment_ids = experiment_data['metadata']['experiment_ids']
experiments = experiment_data['experiment_ids']

#### Experiment summary statistics

In [48]:
# Print each experiment id, then the number of subjects recorded per condition.
for experiment_id in experiment_ids:
    print(experiment_id)
    condition_map = experiments[experiment_id]['conditions']
    for condition, subjects in condition_map.items():
        print(f"\t{condition}: {len(subjects)} subject")

0_baselines_priors__train-none__test-default__neurips_2020
	all: 1 subject
1_no_provided_language__train-im-dr__test-default__neurips_2020
3_producing_language__train-im-de__test-default__neurips_2020
	condition_S12: 1 subject
3_producing_language__train-im-dr-de__test-default__neurips_2020


#### Visualize images and strokes

In [73]:
from IPython.display import HTML, Image

def _src_from_data(data):
    """Return a base64 ``data:`` URI for raw image bytes, usable as an <img> src.

    IPython's ``Image`` builds the mimebundle; the first ``image/*`` entry found
    across the returned bundles is used. Returns ``None`` when no entry has an
    ``image/`` mimetype.
    """
    bundles = Image(data=data)._repr_mimebundle_()
    candidates = (
        f'data:{mime};base64,{payload}'
        for bundle in bundles
        for mime, payload in bundle.items()
        if mime.startswith('image/')
    )
    return next(candidates, None)

def visualizer_gallery_html(images, descriptions=None, row_height='auto'):
    """Build an HTML gallery of images that flexes with the width of the notebook.

    Parameters
    ----------
    images: list of str or bytes
        URLs/paths of images to display, or raw image bytes (embedded inline
        via ``_src_from_data``).

    descriptions: list of str, optional
        Caption text for each image, matched to ``images`` by position. When an
        entry is missing (list too short, ``None`` or empty), ``str`` images are
        captioned with their URL/path and ``bytes`` images get no caption.

    row_height: str
        CSS height value to assign to all images. Set to 'auto' by default to show images
        with their native dimensions. Set to a value like '250px' to make all rows
        in the gallery equal height.

    Returns
    -------
    str
        HTML for a flex-wrapped ``<div>`` containing one ``<figure>`` per image.
    """
    # Local import: the notebook binds a global variable named ``html``, which
    # would shadow a module-level ``import html``.
    from html import escape

    if descriptions is None:
        descriptions = []
    height = escape(str(row_height))
    figures = []
    for image_idx, image in enumerate(images):
        if isinstance(image, bytes):
            src = _src_from_data(image)
            fallback = ''
        else:
            src = image
            fallback = image
        # Tolerate a descriptions list shorter than images instead of raising IndexError.
        description = descriptions[image_idx] if image_idx < len(descriptions) else None
        if description is None or description == '':
            caption_text = fallback
        else:
            caption_text = str(description)
        # Escape user-provided text so characters like '<' or '"' cannot break the markup.
        caption = (
            f'<figcaption style="font-size: 0.6em">{escape(caption_text)}</figcaption>'
            if caption_text else ''
        )
        figures.append(f'''
            <figure style="margin: 5px !important;">
              <img src="{escape(str(src))}" style="height: {height}">
              {caption}
            </figure>
        ''')
    return f'''
        <div style="display: flex; flex-flow: row wrap; text-align: center;">
        {''.join(figures)}
        </div>
    '''

def text_html(text):
    """Wrap ``text`` in a bare ``<div>`` so it renders as its own block of HTML."""
    template = "<div>{}</div>"
    return template.format(text)

# Render, per experiment and condition, each user's images as an HTML gallery.
html = ""
experiments_to_load = [ALL]  # experiment ids to render; include ALL to render every one
load_everything = ALL in experiments_to_load
for experiment_id in experiment_ids:
    # Compare this experiment's id (not the whole id list) against the selection.
    should_load = load_everything or experiment_id in experiments_to_load
    if not should_load:
        continue
    this_experiment = experiments[experiment_id]
    has_users = this_experiment['summary']['total_users'] > 0
    if not has_users:
        continue
    html += text_html(f"Visualizing strokes for: {experiment_id}")

    for condition, condition_users in this_experiment['conditions'].items():
        html += text_html(f"Condition: {condition}")
        # start=1 so the label reads "User 1/N" ... "User N/N".
        for idx, user_id in enumerate(condition_users, start=1):
            html += text_html(f"\nUser {idx}/{len(condition_users)}")
            user_images = this_experiment['images'][user_id]
            user_descriptions = this_experiment['descriptions'][user_id]
            user_strokes = this_experiment['strokes'][user_id]  # TODO: strokes are loaded but not rendered yet

            # Drop missing images while keeping descriptions aligned with the images kept.
            # NOTE(review): assumes descriptions are parallel to the raw (unfiltered)
            # image list — confirm against the data export.
            kept = [i for i, img in enumerate(user_images) if img is not None]
            user_images = [os.path.join(image_directory, user_images[i]) for i in kept]
            user_descriptions = [
                user_descriptions[i] if i < len(user_descriptions) else None
                for i in kept
            ]

            html += visualizer_gallery_html(user_images, user_descriptions, row_height='150px')

HTML(data=html)
    