In [1]:
import copy
import json
import os
import time

import numpy as np
import pandas as pd
from PIL import Image

import emip_toolkit as emtk
import filter_fixation as ff

%matplotlib tk      
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

%load_ext autoreload
%autoreload 2

In [2]:
# Parse every participant's .asc recording; the result is indexed by participant id
# (e.g. '001') and exposes each participant's trials via `.trial[trial_id]`.
experiments = emtk.AlMadi_dataset(
    path='datasets/AlMadi2018/ASCII/',
)

parsing file: datasets/AlMadi2018/ASCII/001.asc
parsing file: datasets/AlMadi2018/ASCII/003.asc
parsing file: datasets/AlMadi2018/ASCII/002.asc
parsing file: datasets/AlMadi2018/ASCII/006.asc
parsing file: datasets/AlMadi2018/ASCII/007.asc
parsing file: datasets/AlMadi2018/ASCII/005.asc
parsing file: datasets/AlMadi2018/ASCII/004.asc
parsing file: datasets/AlMadi2018/ASCII/008.asc


In [71]:
# Export each trial's fixations as {order: [duration, x, y]}, dropping fixations with
# duration < 100 (presumably ms), so the correction tool can reload them from JSON
# without re-parsing the raw recordings.
os.makedirs('originals_AlMadi', exist_ok=True)
for participant_id in experiments:
    # NOTE(review): assumes every participant has exactly 16 trials — TODO confirm.
    for trial_id in range(16):
        fixations = experiments[participant_id].trial[trial_id].get_fixations()
        cens = {}
        for order in fixations:
            # Fields 3..5 of a fixation record are duration, x and y.
            duration, x_cord, y_cord = fixations[order].get_fixation()[3:6]
            if duration < 100:
                continue
            cens[order] = [duration, x_cord, y_cord]
        with open(f'originals_AlMadi/{participant_id}_{trial_id}.json', 'w') as file:
            json.dump(cens, file)

In [3]:
def draw(centers):
    """Plot the scan path of all visible fixations on the current pyplot axes.

    Parameters
    ----------
    centers : dict
        Maps fixation order -> ``[x, y, duration]``. Entries whose x and y are both
        ``-1`` (hidden, waiting to be re-placed) or both ``-2`` (removed) are skipped.

    Returns
    -------
    (Line2D, PathCollection)
        The orange path joining visible fixations and the scatter of their dots.
        Dots are yellow; the first fixation is red only when it is itself visible.
        Works (draws nothing) when no fixation is visible.
    """
    def hidden_or_removed(data):
        return (data[0] == -1 and data[1] == -1) or (data[0] == -2 and data[1] == -2)

    # reshape keeps the array 2-D even when nothing is visible, so the column
    # slicing below never fails on an empty plot.
    fixs = np.array(
        [data[:2] for data in centers.values() if not hidden_or_removed(data)],
        dtype=float,
    ).reshape(-1, 2)

    line, = plt.plot(fixs[:, 0], fixs[:, 1], alpha=0.4, c='orange')

    scatter_color = ['yellow'] * len(fixs)
    first = next(iter(centers.values()), None)
    # Only highlight when the first fixation is itself drawn; otherwise the red
    # marker would land on a later fixation.
    if scatter_color and first is not None and not hidden_or_removed(first):
        scatter_color[0] = 'red'
    scatter = plt.scatter(fixs[:, 0], fixs[:, 1], alpha=0.5, s=100,
                          c=scatter_color if scatter_color else None)
    return line, scatter

def is_fixation(x, y, centers, radius=12):
    """Return the key of the fixation under the point ``(x, y)``, or ``None``.

    A fixation is hit when its centre lies strictly within ``radius`` pixels of the
    point along both axes (a square window, not a circle).

    ``None`` is returned when:
      * ``x`` or ``y`` is ``None`` (e.g. a click outside the axes);
      * some fixation is currently hidden (``[-1, -1, ...]``) — the next click must
        place it, so no other fixation may be selected;
      * no fixation is under the point.
    Removed fixations (``[-2, -2, ...]``) are never hit, so they cannot be revived.
    """
    if x is None or y is None:
        return None
    # Explicit check instead of the truthiness of a key, which fails for key 0 / "".
    if any(c[0] == -1 and c[1] == -1 for c in centers.values()):
        return None
    for index, (c_x, c_y, _duration) in centers.items():
        if c_x == -2 and c_y == -2:
            continue
        if c_x - radius < x < c_x + radius and c_y - radius < y < c_y + radius:
            return index
    return None

def find_index(centers):
    """Return the key of the first hidden fixation (x and y both ``-1``), else ``None``."""
    return next(
        (key for key, coords in centers.items() if coords[0] == -1 and coords[1] == -1),
        None,
    )

def changed_fix(participant_id, trial_id,
                originals_dir='originals_AlMadi', corrections_dir='corrections_AlMadi'):
    """Count fixations whose position differs between original and corrected files.

    Original records are ``[duration, x, y]``; corrected records are
    ``[x, y, duration]``, where x/y may be the sentinels -1 (hidden) or -2 (removed).
    A fixation counts as changed when its corrected (x, y) differs from the original.
    The corrections file may be either the raw mapping or the summary written on
    close, which nests the mapping under ``'fixations_data'``.

    Returns
    -------
    (int, int)
        ``(changed, total)``, where ``total`` is the number of original fixations.
    """
    original_path = f'{originals_dir}/{participant_id}_{trial_id}.json'
    with open(original_path, 'r') as file:
        original = json.load(file)

    corrected_path = f'{corrections_dir}/{participant_id}_{trial_id}_CORRECTED.json'
    with open(corrected_path, 'r') as file:
        corrected = json.load(file)
    if 'fixations_data' in corrected:
        corrected = corrected['fixations_data']

    # Compare both coordinates, not just x.
    count = sum(
        1 for key, record in original.items()
        if corrected[key][0:2] != record[1:3]
    )
    return count, len(original)

In [4]:
def read_AlMadi(path):
    """Load and return the JSON document stored at ``path``."""
    with open(path) as handle:
        return json.load(handle)

In [67]:
# Paste each stimulus onto a 1024x768 opaque black canvas at offset (10, 375) and save
# it under Imagess/, which is where the correction tool loads trial images from.
lst = '-8005777751935009053.png', '-8430727859316488961.png', '-4351795153660429395.png', '2370169755802222682.png', '-4820041588910330858.png', '-1446515203633643784.png', '5667346413132987794.png', '-471492104380388071.png', '3505254255594601664.png', '-3307581271792277746.png', '-173769458936473370.png', '-8419304329526813939.png', '-504446773508767833.png', '7818276589008410372.png', '-2683174645433713120.png', '5659673816338139458.png', '1706913863030379050.png', '-6803693801363870677.png', '8178455215069693319.png', '6253885407738466738.png', '179618635161020985.png'
os.makedirs('Imagess', exist_ok=True)
for file_name in lst:
    # `with` closes the source file; convert() has already loaded the pixels.
    with Image.open('images/' + file_name) as source:
        # NOTE: convert('RGB') discards any alpha channel, so the paste is fully opaque
        # (the original RGBA mask built from this RGB image was all-255, i.e. a no-op).
        img_ = source.convert('RGB')
    background = Image.new('RGBA', (1024, 768), (0, 0, 0, 255))
    background.paste(img_, (10, 375))
    background.save(f'Imagess/{file_name}')

In [76]:
start = time.time()  # session start; onclose reports elapsed time from here
fig = plt.figure(figsize=(12, 8))

path = 'originals_AlMadi/001_9.json'
fixations = read_AlMadi(path)

# Trial files are named '<participant>_<trial>.json' (e.g. '001_9.json'); split on the
# last underscore instead of relying on fixed-width slices.
stem = os.path.splitext(os.path.basename(path))[0]
participant_id, trial_str = stem.rsplit('_', 1)
trial_id = int(trial_str)

imagename = experiments[participant_id].trial[trial_id].get_trial_image()
imagepath = f'Imagess/{imagename}'
img = mpimg.imread(imagepath)

# The event handlers write this trial's corrections file on every edit.
os.makedirs('corrections_AlMadi', exist_ok=True)

# Working state keyed by fixation order: [x, y, duration]; fixations with
# duration < 100 (presumably ms) are dropped, matching the export filter.
centers = {
    order: [x_cord, y_cord, duration]
    for order, (duration, x_cord, y_cord) in fixations.items()
    if not duration < 100
}

line, scatter = draw(centers)
history = []  # stack of deep-copied `centers` snapshots; 'z' pops one to undo

def _redraw_and_save():
    """Replace the path/scatter artists with a fresh plot of `centers` and save it."""
    global line, scatter
    line.remove()
    scatter.remove()
    line, scatter = draw(centers)
    fig.canvas.draw()
    with open(f'corrections_AlMadi/{participant_id}_{trial_id}_CORRECTED.json', 'w') as file:
        json.dump(centers, file)


def onclick(event):
    """Mouse handler for the correction figure.

    Left click on a fixation hides it (x, y := -1) and pushes an undo snapshot; a
    left click elsewhere places the hidden fixation at the click position (no extra
    snapshot, so one undo reverts the whole move). Right click on a fixation sets it
    to -2, which removes it from the plot, and pushes an undo snapshot. After any
    left/right click the plot is redrawn and the corrections JSON is rewritten.
    Clicks outside the axes are ignored.
    """
    # Clicks outside the axes carry no data coordinates.
    if event.xdata is None or event.ydata is None:
        return
    if event.button not in (1, 3):
        return

    x, y = event.xdata, event.ydata
    index = is_fixation(x, y, centers)

    if event.button == 1:
        if index is not None:
            history.append(copy.deepcopy(centers))
            centers[index][0] = -1
            centers[index][1] = -1
        else:
            # Fill the hidden slot (if any) with the clicked position.
            index = find_index(centers)
            if index is not None:
                centers[index][0] = x
                centers[index][1] = y
    elif index is not None:
        history.append(copy.deepcopy(centers))
        centers[index][0] = -2
        centers[index][1] = -2

    _redraw_and_save()

    
def onpress(event):
    """Key handler: 'z' restores the most recent undo snapshot, redraws and saves it."""
    global history, line, scatter, centers

    if event.key != 'z' or not history:
        return

    centers = history.pop()
    for artist in (line, scatter):
        artist.remove()
    line, scatter = draw(centers)
    fig.canvas.draw()

    out_path = f'corrections_AlMadi/{participant_id}_{trial_id}_CORRECTED.json'
    with open(out_path, 'w') as file:
        file.write(json.dumps(centers))

def onclose(event):
    """Close handler: write a summary of the session to the corrections file.

    The file ends up as ``{'total_time', 'changed_fixations', 'total_fixations',
    'fixations_data'}``, where ``fixations_data`` is the final ``centers`` mapping and
    ``total_time`` is seconds since ``start``, rounded to 2 decimals.
    """
    total_time = round(time.time() - start, 2)
    corrected_path = f'corrections_AlMadi/{participant_id}_{trial_id}_CORRECTED.json'

    # The click/key handlers only write this file after an edit; write the final state
    # now so closing without any edit does not fail in changed_fix.
    with open(corrected_path, 'w') as file:
        json.dump(centers, file)

    changed_fixations, total_fixations = changed_fix(participant_id, trial_id)

    output = {
        'total_time': total_time,
        'changed_fixations': changed_fixations,
        'total_fixations': total_fixations,
        'fixations_data': centers,
    }
    with open(corrected_path, 'w') as file:
        json.dump(output, file)

# Register the interactive handlers (in this order), then draw the stimulus image
# beneath the fixation overlay and open the window.
cid, cid2, cid3 = (
    fig.canvas.mpl_connect(event_name, handler)
    for event_name, handler in (
        ('button_press_event', onclick),
        ('close_event', onclose),
        ('key_press_event', onpress),
    )
)
imgplot = plt.imshow(img)
plt.show()