In [1]:
import numpy as np
import pandas as pd

from os.path import join, exists
from os import makedirs
from glob import glob
import re

from tqdm import tqdm

import scipy.io

import matplotlib.pyplot as plt
import seaborn as sns

from matplotlib import patches
from matplotlib import animation, rc

In [2]:
# Root of the behavioural data tree (absolute, machine-specific location).
dir_behav = join('/mnt/ext5', 'GP', 'behav_data')

In [3]:
dir_reg = join(dir_behav, 'regressors/AM/0s_shifted')
makedirs(dir_reg, exist_ok=True)

# Onset shift in seconds, parsed from the leading number of the leaf directory name:
# '0s_shifted' -> 0.0, '4s_shifted' -> 4.0, '4.5s_shifted' -> 4.5.
# (A decimal-only pattern would silently turn '4s_shifted' into 0.0.)
m_shift = re.match(r'(\d+(?:\.\d+)?)s', dir_reg.split('/')[-1])
if m_shift is None:
    raise ValueError('cannot parse the onset shift from %r' % dir_reg)
shift = float(m_shift.group(1))
print(shift)

0.0


In [9]:
# Output directory for record files under the behavioural data root; created if missing.
# NOTE(review): `dir_record` is not referenced by the other cells shown here.
dir_record = join(dir_behav, 'records')
makedirs(name=dir_record, exist_ok=True)

In [4]:
# Zero-padded subject numbers, in processing order; each is prefixed with 'GP'
# to build the subject ID used in file names.
list_subj = (
    '09 10 18 21 22 24 27 34 35 36 '
    '38 42 08 11 17 19 20 26 32 33 '
    '37 39 40 41 43 44 45 46 47 48 '
    '49 50 51 53 54 55 56 57 58 59 '
    '61 62'
).split()

In [5]:
def convert_ID(ID):
    """Map a target ID on the 5x5 grid to integer (x, y) grid coordinates.

    IDs are numbered row by row from the top-left; the centre cell is (0, 0),
    x grows to the right and y grows upwards::

        ##################   ##################
        #  1  2  3  4  5 #   #        2       #
        #  6  7  8  9 10 #   #        1       #
        # 11 12 13 14 15 # = # -2 -1  0  1  2 #
        # 16 17 18 19 20 #   #       -1       #
        # 21 22 23 24 25 #   #       -2       #
        ##################   ##################

    Parameters
    ----------
    ID : int-like
        Target ID in 1..25. Integral numpy scalars (including unsigned and
        float-valued ones) are accepted.

    Returns
    -------
    ndarray of int, shape (2,)
        ``[x, y]`` grid coordinates, each in -2..2.

    Raises
    ------
    ValueError
        If ``ID`` is not an integer in 1..25. (Table indexing would silently
        wrap ID 0 to the last cell.)
    """
    # Normalise to a Python int so unsigned numpy dtypes cannot underflow below.
    idx = int(ID) - 1
    if idx + 1 != ID or not 0 <= idx < 25:
        raise ValueError('target ID must be an integer in 1..25, got %r' % (ID,))
    row, col = divmod(idx, 5)
    return np.array((col - 2, 2 - row))

In [6]:
def func_AMregressor(datum):
    """Extract per-trial onset times and target-hit counts for one subject.

    Parameters
    ----------
    datum : dict
        Mapping returned by ``scipy.io.loadmat`` for one ``GPxx-fmri.mat`` file.
        Keys read: ``nSampleTrial``, ``LearnTrialStartTime``, ``targetID``,
        ``boxSize`` and ``allXY``.

    Returns
    -------
    onsettime : ndarray of float, shape (nrun, tpr) = (3, 97)
        ``LearnTrialStartTime`` of each trial multiplied by 0.001 (presumably
        ms -> s; TODO confirm), one row per run.
    cnt_hit : ndarray of int, shape (3, 97)
        Number of cursor samples in each trial that fall inside the target box.

    Raises
    ------
    AssertionError
        If the data do not match the hard-coded design. The design is 300 samples
        per trial, run boundaries exactly ``tpr`` trials apart, and at least
        ``nrun`` complete runs.
    """
    sec = 5                                  # trial duration (s)
    hz = 60                                  # cursor samples per second
    nS = int(datum['nSampleTrial'][0][0])    # samples per trial: 5 s * 60 Hz = 300
    assert sec * hz == nS, 'expected %d samples per trial, got %d' % (sec * hz, nS)

    # Design constants are hard-coded rather than read from
    # datum['nTrialperRun'] / datum['nRun'].
    ntrial = 12                              # trials per block
    nblock = 8                               # blocks per run
    tpr = 97                                 # trials per run = 1 + 12 trials/block * 8 blocks
    assert 1 + ntrial * nblock == tpr
    nrun = 3                                 # number of runs analysed

    ## run boundaries: trial indices whose onset time is smaller than the previous one
    raw_onset = datum['LearnTrialStartTime'][0]
    # Compares onset[i] > onset[i+1] for i = 0 .. len-3 (the last pair is not compared)
    # and records i+1 as the first trial of a new run.
    idx_editpoint = np.flatnonzero(raw_onset[:-2] > raw_onset[1:-1]) + 1
    assert (np.diff(idx_editpoint) == tpr).all(), 'run boundaries are not %d trials apart' % tpr
    # The diff check above passes vacuously with < 2 boundaries, so check the count explicitly.
    assert len(idx_editpoint) >= nrun, \
        'found %d run boundaries, need %d' % (len(idx_editpoint), nrun)
    assert idx_editpoint[nrun - 1] + tpr <= len(raw_onset), 'the last run is truncated'

    ## onset times, one row per run
    onsettime = np.array(
        [raw_onset[start:start + tpr] * 0.001 for start in idx_editpoint[:nrun]],
        dtype=float,
    )

    ## target IDs: zeros are dropped, then the first tpr*nrun (= 291) trials are kept
    target_all = datum['targetID'][0]
    targetID = target_all[target_all != 0][:tpr * nrun]

    ## per-sample hit test: is the cursor inside the target box on both axes?
    box = float(np.asarray(datum['boxSize']).item())   # target box size (scalar)
    half_box = box * 0.5
    hit_or_not = np.zeros((tpr * nrun, nS), dtype=bool)  # (trial, sample within trial)
    for t, ID in enumerate(targetID):
        target_xy = box * convert_ID(ID)                    # target position [x, y], (2,)
        cursor_xy = datum['allXY'][:, nS * t:nS * (t + 1)]  # cursor [x; y] in trial t, (2, nS)
        err = cursor_xy - target_xy[:, np.newaxis]          # cursor minus target, (2, nS)
        hit_or_not[t] = (np.abs(err) <= half_box).all(axis=0)

    ## hits per trial = number of in-box samples
    cnt_hit = hit_or_not.reshape(nrun, tpr, nS).sum(axis=2)

    return onsettime, cnt_hit

In [7]:
# Write one reward regressor file per subject; subjects whose file already exists are skipped.
for nn in tqdm(list_subj):
    subj = 'GP' + nn
    path_out = join(dir_reg, '%s_reward.txt' % subj)
    if exists(path_out):
        continue
    datum = scipy.io.loadmat(join(dir_behav, '%s-fmri.mat' % subj))

    onsettime, cnt_hit = func_AMregressor(datum)
    # Take the run count and per-trial sample count from the data instead of
    # re-hard-coding 3 runs and 5 s * 60 Hz here.
    nrun = cnt_hit.shape[0]
    nS = int(datum['nSampleTrial'][0][0])
    # reward = fraction of the trial's samples with the cursor inside the target box
    reward = cnt_hit / nS

    # One row per run of "onset*amplitude" entries (presumably AFNI married timing
    # format; TODO confirm). Onsets are offset by `shift` seconds and every trial is kept
    # (an alternative that dropped the first trial of each run was not used).
    AM2 = []
    for run in range(nrun):
        AM2.append(['%.1f*%.3f' % (o, r) for o, r in zip(onsettime[run] + shift, reward[run])])

    np.savetxt(path_out, X=AM2, fmt='%s', delimiter=' ', newline='\n')

100%|██████████| 42/42 [00:17<00:00,  2.36it/s]


In [14]:
# Sanity check: shifted onsets of the last run for the last subject processed by the loop.
# Uses an explicit last-run index instead of the leaked loop variable `run`; still needs
# `onsettime` and `shift` from earlier cells (`onsettime` is undefined if every subject was skipped).
onsettime[-1] + shift

array([  2.   ,   6.995,  11.995,  16.994,  21.994,  26.994,  31.993,
        36.993,  41.993,  46.992,  51.992,  56.992,  61.991,  66.991,
        71.991,  76.99 ,  81.99 ,  86.99 ,  91.99 ,  96.992, 101.99 ,
       106.989, 111.989, 116.988, 121.988, 126.988, 131.991, 136.987,
       141.989, 146.986, 151.986, 156.985, 161.986, 166.989, 171.988,
       176.992, 181.988, 186.987, 191.985, 196.988, 201.982, 206.983,
       211.983, 216.991, 221.983, 226.981, 231.988, 236.987, 241.987,
       246.979, 251.979, 256.979, 261.978, 266.978, 271.978, 276.977,
       281.978, 286.977, 291.976, 296.976, 301.976, 306.977, 311.975,
       316.975, 321.975, 326.974, 331.974, 336.974, 341.974, 346.973,
       351.973, 356.972, 361.976, 366.972, 371.971, 376.971, 381.971,
       386.97 , 391.97 , 396.97 , 401.969, 406.974, 411.973, 416.969,
       421.973, 426.974, 431.967, 436.968, 441.967, 446.973, 451.967,
       456.967, 461.973, 466.972, 471.972, 476.966, 481.971])

In [15]:
# Sanity check: shifted onsets of the last run for the last subject processed by the loop.
# Uses an explicit last-run index instead of the leaked loop variable `run`; still needs
# `onsettime` and `shift` from earlier cells (`onsettime` is undefined if every subject was skipped).
onsettime[-1] + shift

array([  4.5  ,   9.495,  14.495,  19.494,  24.494,  29.494,  34.493,
        39.493,  44.493,  49.492,  54.492,  59.492,  64.491,  69.491,
        74.491,  79.49 ,  84.49 ,  89.49 ,  94.49 ,  99.492, 104.49 ,
       109.489, 114.489, 119.488, 124.488, 129.488, 134.491, 139.487,
       144.489, 149.486, 154.486, 159.485, 164.486, 169.489, 174.488,
       179.492, 184.488, 189.487, 194.485, 199.488, 204.482, 209.483,
       214.483, 219.491, 224.483, 229.481, 234.488, 239.487, 244.487,
       249.479, 254.479, 259.479, 264.478, 269.478, 274.478, 279.477,
       284.478, 289.477, 294.476, 299.476, 304.476, 309.477, 314.475,
       319.475, 324.475, 329.474, 334.474, 339.474, 344.474, 349.473,
       354.473, 359.472, 364.476, 369.472, 374.471, 379.471, 384.471,
       389.47 , 394.47 , 399.47 , 404.469, 409.474, 414.473, 419.469,
       424.473, 429.474, 434.467, 439.468, 444.467, 449.473, 454.467,
       459.467, 464.473, 469.472, 474.472, 479.466, 484.471])