In [1]:
%matplotlib inline
import numpy as np
import scipy as sp
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import seaborn
import scipy.stats as stt
seaborn.set(font_scale=1.5,style='ticks')
import os
import re
import sys
import itertools
import functools
import networkx as nx
from datetime import date, timedelta
from datetime import datetime
from scipy.ndimage import gaussian_filter1d

#sys.path.append(r"C:\Users\yweissenberger\Documents\code\line_loop-master")
#sys.path.append(r"C:\Users\yweissenberger\Documents\code\line_loop-master\packages")
sys.path.append("/Users/joshuakeeling/Documents/Python/Project/line_loop/packages/")

In [2]:
import mouse_poker as mpk
from mouse_poker.navi import *

In [3]:
def performance_arrays(state_seq,port_seq,rew_list,forced_seq,lineloop=None):
    """Score every navigation decision in a session and summarise it per trial.

    The session is first truncated at the final reward so it never ends half
    way through a trial. A non-reward step is scored correct (1) when the next
    step brings the animal exactly one state closer (absolute difference of
    state numbers) to the state rewarded at that point of the session, and
    incorrect (0) otherwise. Reward steps, forced steps and -- on 'line'
    graphs -- steps in the lowest/highest visited state are set to NaN.

    Parameters
    ----------
    state_seq, port_seq, rew_list, forced_seq : sequences of equal length
        Per-step state, poked port, reward flag and forced-choice flag.
    lineloop : str, optional
        Graph type; 'line' removes end-of-line choices from the scores. If
        None, it is read from ``mpk.load.get_metadata(lines)[5]`` using the
        notebook-global ``lines`` (the original behaviour) -- prefer passing
        it explicitly so the result does not depend on hidden state.

    Returns
    -------
    tuple
        state_seq, rew_list, port_seq, forced_seq (all truncated),
        rew_state, score_list, rew_change, rew_change_end, trial_perf,
        trial_rew_change_end, trial_dec_count

    Raises
    ------
    ValueError
        If ``rew_list`` contains no reward.
    """
    rew_steps = np.where(rew_list)[0]
    if len(rew_steps) == 0:
        raise ValueError("performance_arrays: rew_list contains no reward")

    ## End the used data with the final reward rather than half way through a trial
    final = rew_steps[-1]+1
    state_seq = state_seq[0:final]
    rew_list = rew_list[0:final]
    port_seq = port_seq[0:final]
    forced_seq = forced_seq[0:final]

    ## State being rewarded during each step: every step up to and including
    # a reward step takes that reward's state
    rew_state = np.zeros(len(rew_list))
    prev_rew_ind = 0
    for rew_ind in rew_steps:
        rew_state[prev_rew_ind:rew_ind+1] = state_seq[rew_ind]
        prev_rew_ind = rew_ind+1

    ## Whether each choice is correct: the step from i to i+1 is correct when
    # the distance to the rewarded state drops by exactly one
    dist = np.abs(np.asarray(state_seq, dtype=float) - rew_state)
    score_list = np.full(len(rew_list), np.nan)
    score_list[:-1] = np.where(dist[1:] - dist[:-1] == -1, 1.0, 0.0)
    # the reward "choice" is not scored; the last step is always a reward step
    score_list[np.asarray(rew_list) == True] = np.nan  # noqa: E712 (mirrors the original `== True`)

    ## Remove forced choices from performance scores
    score_list[np.where(forced_seq)[0]] = np.nan

    ## Remove end of line choices from performance scores if "line"
    if lineloop is None:
        # legacy fallback: depends on the notebook-global `lines`
        lineloop = mpk.load.get_metadata(lines)[5]
    if lineloop == 'line':
        states_arr = np.array(state_seq)
        ends = np.where((states_arr == min(state_seq)) | (states_arr == max(state_seq)))[0]
        score_list[ends] = np.nan

    ## Indices of reward changes (reward is in a new location after this index)
    rew_change = np.where(rew_state[:-1] != rew_state[1:])[0]
    # Append the last step so the final section is analysed too
    rew_change_end = np.append(rew_change, len(rew_list)-1)

    ## Per-trial performance: trial k covers the steps from the previous reward
    # step (inclusive, already NaN) up to but excluding reward step k
    trial_perf = []       # mean score on each trial
    trial_dec_count = []  # number of scored decisions on each trial
    low = 0
    for upp in rew_steps:
        upp = int(upp)
        window = score_list[low:upp]
        trial_perf.append(np.nanmean(window))
        trial_dec_count.append(np.sum(~np.isnan(window)))
        low = upp

    ## Trial number at which each reward change (and the final reward) happens;
    # every entry of rew_change_end is a reward step, so searchsorted finds it exactly
    trial_rew_change_end = np.searchsorted(rew_steps, rew_change_end).astype(float)

    return state_seq,rew_list,port_seq,forced_seq,rew_state,score_list,\
    rew_change,rew_change_end,trial_perf,trial_rew_change_end,trial_dec_count


def get_poke_to_state_map(lines):
    """Return the state reached from each port, ordered by port number.

    The k-th line containing ``_POKEDPORT`` is paired with the k-th line
    containing ``_POKEDSTATE``. Unique (port, state) pairs are sorted by port
    (then state) and the state of each pair is returned.

    Parameters
    ----------
    lines : iterable of str
        Raw lines of a session log file (iterated twice).

    Returns
    -------
    list of int
        One state per unique (port, state) pair, in port order.
    """
    # `[0-9]+` so multi-digit port/state numbers are not truncated to their first digit
    ports = [int(re.findall('POKEDPORT_([0-9]+)', ln)[0]) for ln in lines if '_POKEDPORT' in ln]
    states = [int(re.findall('POKEDSTATE_([0-9]+)', ln)[0]) for ln in lines if '_POKEDSTATE' in ln]
    # a set de-duplicates in O(1) per pair; sorting restores port-then-state order
    unique_pairs = set(zip(ports, states))
    return [state for _port, state in sorted(unique_pairs)]

In [4]:
##BATCH ANALYSIS:
# Load every .txt session in ROOT with more than `minNrew` rewards, score it with
# performance_arrays and collect one row per session in DF.

#ROOT = '/Users/joshuakeeling/Documents/Python/Project/beh_data_newroom/line_loop_batch_3NAVI/'
ROOT = '/Users/joshuakeeling/Documents/Python/Project/beh_data_newroom/line_loop_batch_4_RUNNAVI/'
#ROOT = '/Users/joshuakeeling/Documents/Python/Project/beh_data_newroom/'

res_dict = {}
minNrew = 20  # sessions with this many rewards or fewer are skipped
today = datetime.now()

files = [filename for filename in os.listdir(ROOT) if filename.endswith(".txt")]

# One dict per session; the DataFrame is built once at the end
# (DataFrame.append was removed in pandas 2.0 and was quadratic anyway).
session_records = []

for file in files:
    try:
        fpath = os.path.join(ROOT, file)
        with open(fpath, 'r') as f:
            lines = f.readlines()

        #Get metadata (`sess_date` avoids shadowing the imported `datetime.date`):
        experiment_name, task_name, subject_id, task_nr, graph,lineloop,sess_date,\
            test,overview = mpk.load.get_metadata(lines)

        if overview['n_rewards'] > minNrew:

            state_seq,rew_list,port_seq,forced_seq = extract_navi_dat(lines)

            # NOTE: performance_arrays may read the global `lines` set just above
            state_seq,rew_list,port_seq,forced_seq,rew_state,score_list,\
                rew_change,rew_change_end,trial_perf,trial_rew_change_end,trial_dec_count\
                = performance_arrays(state_seq,port_seq,rew_list,forced_seq)

            poke_state_map = get_poke_to_state_map(lines)

            session_records.append({'subject_id':subject_id,'date_time':pd.to_datetime(sess_date),
                    'state_seq':state_seq,'rew_list':rew_list,'port_seq':port_seq,
                    'forced_seq':forced_seq,'rew_state':rew_state,'score_list':score_list,
                    'rew_change':rew_change,'rew_change_end':rew_change_end,
                    'trial_perf':trial_perf,'trial_rew_change_end':trial_rew_change_end,
                    'poke_state_map':poke_state_map,'trial_dec_count':trial_dec_count})

            mean_perf = np.nanmean(score_list)

            print(lineloop, int(100*mean_perf), file)
    except Exception as err:
        # report the failing file and the reason, then carry on with the rest
        print("error!!!!", file, repr(err))

DF = pd.DataFrame(session_records)

line 78 '456675_3'-2021-03-02-103339.txt
line 68 '456675_3'-2021-02-15-113953.txt
line 70 '456675_3'-2021-02-23-102808.txt
line 69 '456675_3'-2021-02-25-110427.txt


  score = np.nanmean(score_list[low2:upp])


line 56 '456675_3'-2021-03-05-120244.txt
line 69 '456675_3'-2021-03-01-110336.txt
line 79 '456675_3'-2021-03-04-105909.txt
line 69 '456675_3'-2021-02-12-121633.txt
line 82 '456675_3'-2021-03-03-105842.txt
line 53 '456675_3'-2021-02-22-112819.txt
line 65 '456675_3'-2021-02-11-122740.txt
line 76 '456675_3'-2021-02-26-104433.txt
line 61 '456675_3'-2021-02-24-105205.txt


In [5]:
##Performance in each state for a given reward location
# Pools all sessions: fraction of correct decisions for every (reward location, current state).
# The state range and reward locations come from the first session and are applied to all sessions.

min_state = np.min(DF['state_seq'][0]) #uses values from the first session for all sessions
max_state = np.max(DF['state_seq'][0]) #uses values from the first session for all sessions
rew_locs = (np.unique(DF['rew_state'][0])).astype(int) #uses values from the first session for all sessions

Ncols = len(np.arange(min_state,max_state+1,1))
Nrows = len(np.unique(rew_locs))
Nsess = len(DF)

score_matrix = np.zeros((Nsess,Nrows,Ncols))
count_matrix = np.zeros((Nsess,Nrows,Ncols))

for sess in np.arange(Nsess):
    STATE = np.array(DF['state_seq'][sess])
    SCORE = DF['score_list'][sess]
    REW_LOC = (DF['rew_state'][sess]).astype(int)

    for col in np.arange(min_state,max_state+1,1):
        for row, rew_loc in enumerate(rew_locs):
            selected = SCORE[(STATE==col)&(REW_LOC==rew_loc)]
            # column offset by min_state so a state range not starting at 0 still fits
            score_matrix[sess,row,col-min_state] = np.nansum(selected)
            count_matrix[sess,row,col-min_state] = np.sum(~np.isnan(selected))

# cells never scored give 0/0 = NaN (drawn blank); silence the expected warning
with np.errstate(divide='ignore', invalid='ignore'):
    score_matrix = np.sum(score_matrix, axis = 0)/np.sum(count_matrix, axis = 0)

  score_matrix = np.sum(score_matrix, axis = 0)/np.sum(count_matrix, axis = 0)


#TODO: work out the logic for selecting trials.
#If N is the trial on which the new reward location is discovered,
#look at trial N+1 and check whether the mouse moves towards the old reward location or the new one.
#This can only be tested when trial N+1 starts at a state between the old and new reward locations.

#Need the trial on which the reward location changes.

In [6]:
#Trial rew change gives the trial before reward is moved (last reward at the old location).
#We want the first decision after the trial that finds the new location (trial +1),
#kept only when that decision starts strictly between the old and new reward states.
score_frames = []
for index in DF.index:
    score_arr = []
    rew_idx = np.where(DF['rew_list'][index])[0]  # step index of every reward
    n_steps = len(DF['state_seq'][index])
    for num, rew_change in enumerate(DF['trial_rew_change_end'][index][0:-1]): #dont use "end"
        change_I = rew_idx[int(rew_change)]   # last reward at the old location
        plus1_I = rew_idx[int(rew_change)+1]  # first reward at the new location
        if plus1_I+1 >= n_steps:
            # that reward ends the session: there is no later decision to score
            continue

        state = DF['state_seq'][index][plus1_I+1]
        old_rew = DF['rew_state'][index][change_I]
        new_rew = DF['rew_state'][index][plus1_I+1]

        # Do they go towards the new reward?
        if (old_rew > state > new_rew) or (old_rew < state < new_rew):
            score_val = DF['score_list'][index][plus1_I+1]
            score_arr.append(score_val)

    if score_arr:
        score_frames.append(pd.DataFrame(score_arr))

# DataFrame.append was removed in pandas 2.0: concatenate once instead
score_df = pd.concat(score_frames) if score_frames else pd.DataFrame()

In [11]:
def prev_prob_f(state, sess=None, trial=None, segment=None, lookback=None, data=None):
    """Return 1 minus the mean decision score in `state` just before a reward change.

    The window starts at the reward `lookback` rewards before reward number
    `trial`, clamped to the session's first reward. It runs up to, but
    excluding, step ``rew_change_end[segment]`` of session `sess`.

    The optional arguments default to the notebook globals the original
    version read implicitly: ``index`` (sess), ``rew_change`` (trial), ``num``
    (segment), ``n_trials`` (lookback) and ``DF`` (data). Existing calls such
    as ``prev_prob_f(state)`` inside the analysis loop therefore keep working.

    Returns NaN, with a RuntimeWarning from np.nanmean, when `state` has no
    scored decision in the window.
    """
    sess = index if sess is None else sess
    trial = rew_change if trial is None else trial
    segment = num if segment is None else segment
    lookback = n_trials if lookback is None else lookback
    data = DF if data is None else data

    rew_idx = np.where(data['rew_list'][sess])[0]
    # clamp at 0: a negative position would wrap to the END of the session
    # and yield an empty window
    start = rew_idx[max(int(trial - lookback), 0)]
    stop = data['rew_change_end'][sess][segment]
    window_states = np.asarray(data['state_seq'][sess][start:stop])
    window_scores = data['score_list'][sess][start:stop]
    return 1-np.nanmean(window_scores[window_states==state])

In [16]:
##THE ONE with various parameters
# For every reward change, score decisions made after the first reward at the new
# location, on trials that start strictly between the old and new reward states in a
# state not visited since the change. The flags choose which decisions are scored.
#Parameters
allow_move_into_middle = False
full_trial = False #True -> score for each decision on a trial, false -> first decision
first_visit = False #True + full_trial True -> give score of first decision in each state on a trial
next_to_reward = False #True & full_trial True & first_visit False ->only look at decisions next to reward
not_next_to_reward = False #True & full_trial True & first_visit False & next_to_reward False -> only look at 
                            #decisions not next to reward
comp_across_prev_trial = False  # NOTE(review): not read in this cell
comp_across_trials = False      # NOTE(review): not read in this cell

n_trials = 20  # look-back (in rewards) read by prev_prob_f through the global `n_trials`

# Per-session lists are concatenated once at the end (DataFrame.append was removed in pandas 2.0)
score_frames = []
prev_prob_frames = []

for index in DF.index: #for all sessions; `index`, `num`, `rew_change` are also read by prev_prob_f
    score_arr = []
    prev_prob_arr = []
    sess_states = DF['state_seq'][index]
    sess_scores = DF['score_list'][index]
    sess_rew_state = DF['rew_state'][index]
    sess_changes = DF['trial_rew_change_end'][index]
    rew_idx = np.where(DF['rew_list'][index])[0]   # step index of every reward

    for num, rew_change in enumerate(sess_changes[0:-1]): #-1 to not use "end"

        #set the number of rewards possible to the number of rew in each section -1
        permissible_num_rewards = int(sess_changes[num+1]-sess_changes[num])-1

        change_I = rew_idx[int(rew_change)]   #index of change in rew
        old_rew = sess_rew_state[change_I]    #old rew state

        prev_states = []                      #states visited since rew_change

        for num_rews in np.arange(permissible_num_rewards)+1:
            plus_num_rew_I = rew_idx[int(rew_change)+num_rews]+1    # first step of the trial
            plus_num_rew2_I = rew_idx[int(rew_change)+num_rews+1]   # reward step ending the trial

            state = sess_states[plus_num_rew_I]
            new_rew = sess_rew_state[plus_num_rew_I]

            prev_states = sess_states[change_I+1:plus_num_rew_I]

            if allow_move_into_middle:
                # score the first step of the trial that lands between the old and new
                # reward states in a state not visited since the change
                for num_steps, states in enumerate(sess_states[plus_num_rew_I:plus_num_rew2_I]):
                    if ((old_rew > states > new_rew) or (old_rew < states < new_rew)) and (states not in prev_states):
                        score_val = sess_scores[plus_num_rew_I+num_steps]
                        score_arr.append(score_val)
                        prev_prob = prev_prob_f(states)
                        prev_prob_arr.append(prev_prob)
                        print(
                            'ID:', DF['subject_id'][index], DF['date_time'][index],
                            '\n', index, plus_num_rew_I,change_I, old_rew, new_rew, score_val,num_steps,
                            '\n', sess_states[plus_num_rew_I-5:plus_num_rew_I+20],
                            '\n---------------------------\n')
                        break

            elif ((old_rew > state > new_rew) or (old_rew < state < new_rew)) and (state not in prev_states):
                if full_trial:
                    trial_state_prevs = []
                    if first_visit:
                        # first decision in each distinct state on this trial
                        score_val = []
                        for num_steps, states in enumerate(sess_states[plus_num_rew_I:plus_num_rew2_I]):
                            if states not in trial_state_prevs:
                                trial_state_prevs.append(states)
                                score_val.append(sess_scores[plus_num_rew_I+num_steps])
                                prev_prob = prev_prob_f(states)
                                prev_prob_arr.append(prev_prob)

                    elif next_to_reward:
                        arr = np.asarray(sess_states[plus_num_rew_I:plus_num_rew2_I])
                        next_to = np.where((arr==new_rew+1)|(arr==new_rew-1))[0]
                        score_val = np.asarray(sess_scores[plus_num_rew_I:plus_num_rew2_I])[next_to]

                    elif not_next_to_reward:
                        arr = np.asarray(sess_states[plus_num_rew_I:plus_num_rew2_I])
                        not_next_to = np.where((arr!=new_rew+1)&(arr!=new_rew-1))[0]
                        score_val = np.asarray(sess_scores[plus_num_rew_I:plus_num_rew2_I])[not_next_to]

                    else: #all decisions on a trial
                        # NOTE(review): the +1 also includes the (NaN-scored) reward step,
                        # unlike the other modes -- "Check indexing" in the original
                        score_val = sess_scores[plus_num_rew_I:plus_num_rew2_I+1]

                    # NOTE(review): only first_visit fills prev_prob_arr here, so in the other
                    # full_trial modes prev_prob_df does not line up with score_df

                else: #standard conditions: first decision of the trial
                    score_val = sess_scores[plus_num_rew_I]

                    prev_prob = prev_prob_f(state)
                    prev_prob_arr.append(prev_prob)

                score_arr.append(score_val)
                print(
                    'ID:', DF['subject_id'][index], DF['date_time'][index],
                    '\n', index, plus_num_rew_I,change_I, old_rew, new_rew, score_val,
                    '\n', sess_states[plus_num_rew_I-5:plus_num_rew_I+20],
                    '\n---------------------------\n')

    if score_arr:
        score_frames.append(pd.DataFrame(score_arr))
    if prev_prob_arr:
        prev_prob_frames.append(pd.DataFrame(prev_prob_arr))

score_df = pd.concat(score_frames) if score_frames else pd.DataFrame()
prev_prob_df = pd.concat(prev_prob_frames) if prev_prob_frames else pd.DataFrame()

ID: 456675_3 2021-02-12 12:16:33 
 7 709 703 7.0 3.0 0.0 
 [1, 0, 1, 2, 3, 5, 6, 7, 8, 7, 6, 7, 6, 7, 8, 7, 6, 7, 6, 7, 6, 7, 6, 5, 6] 
---------------------------

ID: 456675_3 2021-02-26 10:44:33 
 11 767 762 7.0 3.0 1.0 
 [7, 0, 1, 2, 3, 5, 4, 3, 8, 7, 6, 5, 6, 5, 4, 3, 5, 4, 3, 7, 6, 5, 4, 3, 1] 
---------------------------



In [14]:
# Display prev_prob (1 - mean score in the same state before the change) for each
# decision scored by the parameterised cell; the index restarts at 0 for every session.
prev_prob_df

Unnamed: 0,0
0,0.4375
0,0.153846


In [15]:
# Display the score (1 correct, 0 incorrect, NaN unscored) of each selected decision;
# the index restarts at 0 for every session.
score_df

Unnamed: 0,0
0,0.0
0,1.0


In [None]:
# Run performance_arrays on test sequences saved as .npy files.
NAVI_TEST_DIR = '/Users/joshuakeeling/Documents/Python/Project/line_loop/project_notebooks/line_loop/navi/'
forced_path = os.path.join(NAVI_TEST_DIR, 'test_forced_seq.npy')
state_path = os.path.join(NAVI_TEST_DIR, 'test_state_seq.npy')
rew_path = os.path.join(NAVI_TEST_DIR, 'test_rew_list.npy')
# NOTE(review): the port path points at test_state_seq.npy as in the original (likely a
# copy-paste); port_seq is only truncated and returned, never used for scoring.
port_path = os.path.join(NAVI_TEST_DIR, 'test_state_seq.npy')

forced_seq = np.load(forced_path)
state_seq = np.load(state_path)
rew_list = np.load(rew_path)
port_seq = np.load(port_path)  # previously the path string itself was passed in and sliced

# NOTE: performance_arrays may read the global `lines` from an earlier cell (line vs loop)
state_seq,rew_list,port_seq,forced_seq,rew_state,score_list,\
            rew_change,rew_change_end,trial_perf,trial_rew_change_end,trial_dec_count\
            = performance_arrays(state_seq,port_seq,rew_list,forced_seq)



In [None]:
# #Test case 1: hand-worked example for performance_arrays
state_seq = [0,1,2,8,7,6,5,3,4,5]
port_seq = [0,1,2,8,7,6,5,3,4,5]
rew_list = np.array([0,0,1,0,0,0,1,0,0,1])==1
forced_seq = np.array([1,0,0,1,0,0,0,0,0,0])==1

state_seq,rew_list,port_seq,forced_seq,rew_state,score_list,\
            rew_change,rew_change_end,trial_perf,trial_rew_change_end,trial_dec_count\
            = performance_arrays(state_seq,port_seq,rew_list,forced_seq)

# Expected values worked out by hand. The end-of-line states (0 and 8) are forced
# steps here, so the result is identical for 'line' and 'loop' sessions.
np.testing.assert_array_equal(rew_state, [2,2,2,5,5,5,5,5,5,5])
np.testing.assert_array_equal(score_list, [np.nan,1,np.nan,np.nan,1,1,np.nan,1,1,np.nan])
np.testing.assert_array_equal(rew_change, [2])
np.testing.assert_array_equal(rew_change_end, [2,9])
np.testing.assert_array_equal(trial_perf, [1.0,1.0,1.0])
np.testing.assert_array_equal(trial_dec_count, [1,2,2])
np.testing.assert_array_equal(trial_rew_change_end, [0,2])


In [None]:
#Plot of performance with given reward locations (rows) by current state (columns)

fig, ax = plt.subplots()
im = ax.imshow(score_matrix,vmin=0,vmax=1,cmap='RdBu_r')
ax.set_xlabel("Current State")
ax.set_ylabel("Rew Location")
# label rows with the reward locations actually used to build score_matrix
# (previously hard-coded as "(1,3,5,7)" in the axis label)
ax.set_yticks(np.arange(len(rew_locs)))
ax.set_yticklabels(rew_locs)
cbar = fig.colorbar(im, ax=ax)
cbar.set_label("Fraction correct")

In [None]:
#Performance across sessions: fraction correct in each state, one row per session

min_state = np.min(DF['state_seq'][0])
max_state = np.max(DF['state_seq'][0])


Ncols = len(np.arange(min_state,max_state+1,1))
Nrows = np.shape(DF)[0]
# NaN (not 0) where a session has no decision in a state, so it is not drawn as 0% correct
empty = np.full((Nrows,Ncols), np.nan)

for row in np.arange(Nrows):

    STATE = np.array(DF['state_seq'][row])
    SCORE = DF['score_list'][row]

    # columns are offset by min_state; states outside the first session's range are ignored
    for col in np.arange(min_state,max_state+1,1):
        in_state = STATE==col
        if in_state.any():
            empty[row,col-min_state] = np.nanmean(SCORE[in_state])


In [None]:
# Heatmap of the per-session performance matrix: rows are sessions, columns are states.
fig, ax = plt.subplots()
heatmap = ax.imshow(empty, vmin=0, vmax=1, cmap='RdBu_r')
ax.set_xlabel("Current State")
ax.set_ylabel("Session")
cbar = fig.colorbar(heatmap, ax=ax)
cbar.set_label("Fraction correct")

In [None]:
##SINGLE FILE ANALYSIS
# Load and score one session file; this overwrites the per-session globals
# (lines, state_seq, score_list, ...) left by the batch cell.

#This one for 3:
root = '/Users/joshuakeeling/Documents/Python/Project/beh_data_newroom/line_loop_batch_3NAVI/'
#file = "'456675_10'-2021-02-23-102809.txt"
#file = "'456675_10'-2021-03-04-105905.txt"
file = "'460175_10'-2021-03-02-142443.txt"
#file = "'460175_10'-2021-03-02-132053.txt"

# #This one for 4:
# root = '/Users/joshuakeeling/Documents/Python/Project/beh_data_newroom/line_loop_batch_4_RUNNAVI/'
# #file = "'456675_3'-2021-02-12-121633.txt"
# file = "'456675_3'-2021-03-02-103339.txt"

fpath = os.path.join(root,file)
with open(fpath) as f:
    lines = f.readlines()

#Get metadata (`sess_date` avoids shadowing the imported `datetime.date`):
experiment_name, task_name, subject_id, task_nr, graph,lineloop,sess_date,\
    test,overview = mpk.load.get_metadata(lines)

state_seq,rew_list,port_seq,forced_seq = extract_navi_dat(lines)

state_seq,rew_list,port_seq,forced_seq,rew_state,score_list,\
rew_change,rew_change_end,trial_perf,trial_rew_change_end, trial_dec_count \
= performance_arrays(state_seq,port_seq,rew_list,forced_seq)

# last expression of the cell, so the session overview is actually displayed
overview

In [None]:
# One scatter of per-trial proportion correct per session, in date order;
# each session's mean is printed just before its figure is shown.
new_DF = DF.sort_values('date_time')

for session_perf in new_DF['trial_perf']:
    fig, ax = plt.subplots()
    ax.scatter(np.arange(len(session_perf)), session_perf, color='k', s=50, edgecolors='none')
    ax.set_xlabel("trial #")
    ax.set_ylabel("Proportion\n correct")
    ax.set_ylim(.1, 1.1)
    seaborn.despine(ax=ax)
    print(np.nanmean(session_perf))
    plt.show()

In [None]:
for index in DF.index:
    # this session's arrays (the original read the globals `rew_change` /
    # `rew_change_end`, i.e. whatever file was loaded last)
    sess_rew_change = DF.rew_change[index]
    sess_rew_change_end = DF.rew_change_end[index]

    z = 0

    print("Mean perf in each rew location")
    for i in sess_rew_change_end:
        y = np.nanmean(DF.score_list[index][z:i-1])
        z = i-1
        # int(nan) would raise when a section has no scored decision
        print(DF.rew_state[index][i-1], int(y*100) if np.isfinite(y) else y)

    print("\nDoes perf imprv with succ trials with rew at same loc??")
    if len(sess_rew_change_end) < 2:
        print("only one reward location in this session - skipped")
        continue
    NM = 10 #arbitrary number of sections to take mean of
    low = 0
    arr = np.linspace(sess_rew_change[0],sess_rew_change_end[1],num=NM)

    for i in np.arange(NM):
        upp = arr[i]
        y = np.nanmean(DF.score_list[index][int(low):int(upp)])
        low = arr[i]
        print(int(y*100) if np.isfinite(y) else y)

In [None]:
print("Plot perf in first set of rewards")
# Session to plot: previously the loop variable `index` leaked from an earlier cell,
# which after a linear run is the last session in DF. Change this to plot another session.
PLOT_SESSION = DF.index[-1]
first_change = int(DF.trial_rew_change_end[PLOT_SESSION][0])
x = range(first_change-1)
y = DF.trial_perf[PLOT_SESSION][0:first_change-1]

plt.scatter(x, y,color='k',s=128,edgecolors='none')
plt.xlabel("# of successive rewards at some location")
plt.ylabel("Proportion\n correct")
plt.ylim(.1,1.1)
seaborn.despine()

In [None]:
print("Plot perf for successive trials at a reward location")
# y_dat: one row per reward segment (all sessions), columns = trial position in the segment.
# Frames are concatenated once at the end (DataFrame.append was removed in pandas 2.0).
segment_frames = []

for index in DF.index:
    low = 1  # reset for every session (it used to carry over from the previous session)
    y_DAT = []
    for i in np.arange(len(DF.trial_rew_change_end[index])):
        upp = int(DF.trial_rew_change_end[index][i])
        x = np.arange(upp-low)
        y = DF.trial_perf[index][low:upp]
        low = upp + 1
        y_DAT.append(y)
        plt.scatter(x, y,color='k',s=10,edgecolors='none')
        plt.xlabel("# of successive rewards at some location")
        plt.ylabel("Proportion\n correct")
        plt.ylim(0.1,1.1)
        seaborn.despine()
        if len(y) > 0:  # nothing to smooth for an empty segment
            plt.plot(gaussian_filter1d(y,2.25,mode='nearest'),linewidth=1)
    if y_DAT:
        segment_frames.append(pd.DataFrame(y_DAT))

y_dat = pd.concat(segment_frames) if segment_frames else pd.DataFrame()

In [None]:
# For every session and reward segment: correct decisions and total decisions per trial.
# Frames are concatenated once at the end (DataFrame.append was removed in pandas 2.0).
correct_frames = []
decision_frames = []

count_arr = []  # number of reward segments in each session

for session in DF.index:
    low = 1  # reset for every session (it used to carry over from the previous session)
    num_corr = []
    num_dec = []
    counter = 0
    for rew_segment in np.arange(len(DF.trial_rew_change_end[session])):
        upp = int(DF.trial_rew_change_end[session][rew_segment])
        perf = DF.trial_perf[session][low:upp]
        count = DF.trial_dec_count[session][low:upp]
        correct = (np.asarray(perf))*(np.asarray(count))  # correct decisions on each trial
        low = upp + 1
        num_corr.append(correct)
        num_dec.append(count)
        counter = counter + 1
    count_arr.append(counter)
    print(session,counter)
    # rows = reward segments, columns = trial position within the segment
    score = pd.DataFrame(num_corr)/pd.DataFrame(num_dec)
    mn_score = np.nanmean(score,0)
    plt.scatter(np.arange(len(mn_score)),mn_score)
    plt.ylim(0.45,1.05)
    plt.xlim(-1,23.5)
    plt.show()
    if num_corr:
        correct_frames.append(pd.DataFrame(num_corr))
        decision_frames.append(pd.DataFrame(num_dec))

num_correct = pd.concat(correct_frames) if correct_frames else pd.DataFrame()
num_decisions = pd.concat(decision_frames) if decision_frames else pd.DataFrame()

In [None]:
# Pooled over sessions and segments: total correct / total decisions at each trial position,
# with a Gaussian-smoothed curve on top.
total_correct = np.asarray(num_correct.sum(axis=0))
total_decisions = np.asarray(num_decisions.sum(axis=0))
mean_score = total_correct/total_decisions
x = np.arange(len(mean_score))

fig, ax = plt.subplots()
ax.scatter(x, mean_score, color='k', s=10, edgecolors='none')
ax.set_xlabel("# of successive rewards at some location")
ax.set_ylabel("Proportion \n correct")
ax.set_ylim(0.45, 1.05)
ax.set_xlim(-1, 23.5)
seaborn.despine(ax=ax)
ax.plot(gaussian_filter1d(mean_score, 2.25, mode='nearest'), linewidth=1)

In [None]:
# Unweighted version: every segment counts equally (mean over the rows of y_dat per trial position).
y = np.nanmean(y_dat, 0)
x = np.arange(len(y))

fig, ax = plt.subplots()
ax.scatter(x, y, color='k', s=10, edgecolors='none')
ax.set_xlabel("# of successive rewards at some location")
ax.set_ylabel("Proportion correct \n (each trial counts the same)")
ax.set_ylim(0.45, 1.05)
ax.set_xlim(-1, 23.5)
seaborn.despine(ax=ax)
ax.plot(gaussian_filter1d(y, 2.25, mode='nearest'), linewidth=1)

In [None]:
print("Plot perf for each section of reward location")
# Session to plot: previously the loop variable `index` leaked from an earlier cell,
# which after a linear run is the last session in DF. Change this to plot another session.
PLOT_SESSION = DF.index[-1]
low = 0
for i in np.arange(len(DF.trial_rew_change_end[PLOT_SESSION])):
    upp = int(DF.trial_rew_change_end[PLOT_SESSION][i]-1)
    x = np.arange(upp-low)
    y = DF.trial_perf[PLOT_SESSION][low:upp]
    low = upp

    plt.scatter(x, y,color='k',s=128,edgecolors='none')
    plt.xlabel("# of successive rewards at some location")
    plt.ylabel("Proportion\n correct")
    plt.ylim(.3,1.1)
    seaborn.despine()
    plt.show()
