# Start up commands/load relevant functions

In [None]:
parallel = true # Run on multiple CPUs. If you are having trouble, set parallel = false: easier to debug
full = false    # Maintain full covariance matrix (vs a diagonal one) at the group level
emtol = 1e-3    # stopping condition (relative change) for EM

using Distributed
# Start worker processes (addprocs() with no argument adds one per logical CPU).
# Only add them when no workers exist yet, so re-running this cell does not keep
# spawning extra processes.
if parallel && nprocs() == 1
    addprocs()
end

# Load the required packages; @everywhere makes sure they are
# available on every process (the main one and all workers).

@everywhere using DataFrames
@everywhere using SharedArrays
@everywhere using ForwardDiff
@everywhere using Optim
@everywhere using LinearAlgebra       # for tr, diagonal
@everywhere using StatsFuns           # logsumexp
@everywhere using SpecialFunctions    # for erf
@everywhere using Statistics          # for mean
@everywhere using Distributions
@everywhere using GLM
@everywhere using CSV #for reading/writing csv files

# Load Daw's EM code (em driver, shared helpers, likelihood functions) on every
# process. Change `directory` to wherever you keep the latest copy of that code.
@everywhere begin
    directory = "/Users/neil/General/Gems/learning_models/em"
    include("$directory/em.jl")
    include("$directory/common.jl")
    include("$directory/likfuns.jl")
end


# Read and process the data

### Read in data

In [None]:
# Read in the trial-level data.
df = readtable("/Users/neil/General/Gems/data/gem_dat_fMRI.csv")

# Drop trials on which no response was made.
df = df[df[:missed_trial] .!= 1, :]

# em appears to look for a subject column named "sub", so mirror participantID
# into it. copy() keeps the two columns independent of each other.
df[:sub] = copy(df[:participantID]);

# Recode the market: 1 = the market of the dependent condition (blockType 1);
# 2 and 3 = the two markets of the independent condition.
# `.+` is required here: `vector + scalar` is a MethodError on Julia 1.0+.
df[:market_presented] = df[:market_presented] .+ 1
df[df[:blockType] .== 1, :market_presented] = 1

# Chosen door as used by the model: 1 = black, 2 = white (assumes pick_black is 0/1).
# Built as a fresh Vector{Int}. Assigning df[:pick_black] and then recoding in place
# would alias the two columns and overwrite pick_black itself.
df[:state_chosen] = Int[p == 0 ? 2 : p for p in df[:pick_black]]

head(df)

In [None]:
# Exclude participants 21 and 28 from every analysis below.
df = df[[!(id in (21, 28)) for id in df[:participantID]], :];


In [None]:
# Condition regressor: 1 = dependent (blockType 1), -1 = independent (blockType 2).
# Built as a new vector (same element type as blockType). Assigning df[:blockType]
# and then recoding in place would alias the columns and overwrite blockType.
df[:condition_recode] = [b == 2 ? oftype(b, -1) : b for b in df[:blockType]];

# RL Model

In [None]:
@everywhere function rl_model(params, data)
    # Negative log-likelihood of one subject's choices. The model keeps running
    # estimates of the probability that the black door leads to the gain/loss state
    # (one per gem group and one per gem), and values both doors from a weighted
    # blend of the two estimates.
    #
    # params (unconstrained values, as optimised by em):
    #   params[1]  beta_mb - softmax inverse temperature on the model-based values
    #   params[2]  w_raw   - group-vs-gem weight, mapped to [0, 1] by the normal CDF
    #   params[3]  lr_raw  - learning rate, mapped to [0, 1] by the normal CDF
    # data: one subject's trials; the columns read are listed below.
    # Returns -log-likelihood, the form em consumes. Only free-choice trials
    # (forcedTrial == 0) are scored, but every trial updates the estimates.

    beta_mb = params[1]
    # The probit transforms are the same on every trial, so compute them once.
    w = 0.5 + 0.5 * erf(params[2] / sqrt(2))
    lr = 0.5 + 0.5 * erf(params[3] / sqrt(2))

    c1 = data[:state_chosen]      # choice: 1 = black door, 2 = white door
    s = data[:outcomeState]       # 1 = gain/loss state reached, 2 = neutral state reached
    gem = data[:gem_presented]    # gem shown; indexes SR_gem, so expected to be in 1:4
    rew_loss = data[:rew_loss]    # multiplies the door values on each trial
    forced = data[:forcedTrial]   # 0 = free choice (scored); anything else is not scored

    T = typeof(beta_mb)           # Float64, or a ForwardDiff dual type during differentiation
    # Estimates of P(black door -> gain/loss state), all initialised to 0.5:
    # SR_m has one entry per gem group (gems 1-2, gems 3-4); SR_gem has one per gem.
    # NOTE(review): the group index comes from the gem, not from the
    # market_presented column - confirm this is intended.
    SR_m = zeros(T, 2) .+ 0.5
    SR_gem = zeros(T, 4) .+ 0.5

    lik = zero(T)                 # typed accumulator (avoids Int -> Float/Dual switch)
    for i in eachindex(c1)
        g = gem[i]
        m = g < 3 ? 1 : 2

        # Weighted blend of the group-level and gem-level estimates.
        Vtot = w * SR_m[m] + (1 - w) * SR_gem[g]

        if forced[i] == 0
            # Black door is worth Vtot*rew_loss, white door (1 - Vtot)*rew_loss.
            Qd = beta_mb .* [Vtot * rew_loss[i], (1 - Vtot) * rew_loss[i]]
            # Softmax log-probability of the chosen door; logsumexp avoids
            # overflow in exp() when beta_mb is large.
            lik += Qd[c1[i]] - logsumexp(Qd)
        end

        # Learning target: 1 when the outcome implies the black door leads to the
        # gain/loss state (black chosen and gain/loss state reached, or white chosen
        # and neutral state reached), otherwise 0. `&&` is required: the previous
        # `s[i]==1 & c1[i]==1` parsed as the chained comparison `s[i] == (1 & c1[i]) == 1`.
        target = ((s[i] == 1 && c1[i] == 1) || (s[i] == 2 && c1[i] == 2)) ? 1 : 0

        # Both estimates share the prediction error, split according to w, and are
        # kept inside [0, 1].
        SR_m[m] = clamp(SR_m[m] + w * lr * (target - Vtot), 0, 1)
        SR_gem[g] = clamp(SR_gem[g] + (1 - w) * lr * (target - Vtot), 0, 1)
    end

    return -lik
end

# Parameter optimisation

### setup variables for em


In [None]:

# Subject identifiers as they appear in the data (em fits one subject per entry).
subs = unique(df[:participantID])

# em looks for a column called "sub"; mirror participantID into it.
df[:sub] = df[:participantID];

NS = length(subs)
X = ones(NS)            # one covariate per subject - presumably an intercept-only group model
betas = zeros(1, 3)     # starting group-level means: one column per rl_model parameter
sigma = fill(0.5, 3);   # starting group-level spread (presumably variances), one per parameter


### Run em to get best fit parameters for each subject


In [None]:
# Fit the group-level model with EM.
#   x    - per-subject parameter estimates (not the design matrix X)
#   l, h - per-subject likelihoods and Hessians
betas, sigma, x, l, h = em(df, subs, X, betas, sigma, rl_model;
                           emtol=emtol, parallel=parallel, full=full);


In [None]:
# Aggregate fit statistic computed from em's outputs. iaic comes from Daw's em code
# (presumably the integrated AIC - confirm against em.jl/common.jl).
aggll_iaic = iaic(x, l, h, betas, sigma)

### Generate model statistics (LOOCV)

In [None]:

# Unbiased per-subject marginal likelihoods via leave-one-out cross-validation.
liks = loocv(df, subs, x, X, betas, sigma, rl_model;
             emtol=emtol, parallel=parallel, full=full)

# Print the sum of the per-subject cross-validated scores.
liks |> sum |> print


### Write LOOCV scores to a CSV file

(only if you have run LOOCV above)

In [None]:

# One row per subject: its identifier and its LOOCV score.
loocv_scores = DataFrame(sub = subs, liks = vec(liks));

CSV.write("loocv_scores.csv", loocv_scores)


### Calculate and write p values, std error and covariance

In [None]:

# Standard errors, p-values and covariance for the group-level means of the
# subject parameters, from an asymptotic Gaussian approximation.
# These may be inflated, especially for small numbers of subjects.
standarderrors, pvalues, covmtx = emerrors(df, subs, x, X, h, betas, sigma, rl_model);


In [None]:
# Display the standard errors computed by emerrors.
standarderrors

In [None]:

# Group-level summary: one row per model parameter, holding its standard error,
# p-value and the matching column of the covariance matrix (covmtx_1, covmtx_2, ...).
model_stats = DataFrame(stderror = vec(standarderrors),
                        pvalues = vec(pvalues));

# One covariance column per parameter, generated from covmtx itself, so the table
# stays complete if rl_model's parameter count changes (was hard-coded to 3).
for k in 1:size(covmtx, 2)
    model_stats[Symbol("covmtx_", k)] = vec(covmtx[:, k])
end

# save model stats to csv file
CSV.write("model_stats.csv", model_stats);


### Write per subject model parameters to csv files and save

In [None]:

# Per-subject parameter estimates from em (one row per subject).
d = x;

# Columns hold the raw (untransformed) values, in rl_model's parameter order:
#   slope  - params[1], beta_mb (softmax inverse temperature; header kept as-is
#            for downstream files)
#   w_raw  - params[2], probit-scale weight
#   lr_raw - params[3], probit-scale learning rate
# NOTE(review): Distributions also exports `params`; if that binding was already
# used in this session, this global assignment fails - consider renaming.
params = DataFrame(sub = subs,
                   slope = vec(d[:, 1]),
                   w_raw = vec(d[:, 2]),
                   lr_raw = vec(d[:, 3]));

CSV.write("subject_params.csv", params)
