In [1]:
using CSV
using CategoricalArrays
using DataFrames
using HypothesisTests
using LinearAlgebra
using LogExpFunctions
using MixedModels
using SpecialFunctions
using Statistics

In [2]:
# load the EM library

full = false    # Maintain full covariance matrix (vs a diagonal one) at the group level
emtol = 1e-2    # stopping condition (relative change) for EM

directory = "/Volumes/Tim-1/Ephys Data/"

# Add the directory only once, so re-running this cell does not keep growing LOAD_PATH
# with duplicate entries.
directory in LOAD_PATH || push!(LOAD_PATH, directory)
using EM

In [3]:
# load the data and keep only the usable frames in a single filtering pass

data = let raw = CSV.read("/Volumes/Tim/Photometry/10MfRatDataSet/hexLevelDf_cornerHexCorrectAlignment.csv", DataFrame)
    # frames that have a location
    located = raw.pairedHexState .>= 0
    # sessions that have a ramp: drop rat IM-1292 and sessions 90, 92 and 94
    hasRamp = (raw.rat .!= "IM-1292") .& .!in.(raw.session, Ref((90, 92, 94)))
    raw[located .& hasRamp, :]
end;

data.session = CategoricalArray(data.session);
# pairedHexState is kept numeric: the likelihood functions use it as an integer state index.


In [4]:
"""
    phi(x)

Standard normal cumulative distribution function Φ(x).

Maps unconstrained parameters onto [0, 1]; it is applied, for example, to the
learning rates, the discount and the trace decay.

Written with `erfc` rather than `0.5 + 0.5 * erf(x / sqrt(2))`. The two are
mathematically equal, but the `erf` form loses precision in the lower tail. It
returns exactly `0.0` once `erf` rounds to `-1`, roughly for x < -8.5. The `erfc`
form stays accurate there.

The saturating sentinels used by the reduced models still evaluate exactly:
`phi(-1e99) == 0.0` and `phi(1e99) == 1.0`.
"""
function phi(x)
    return erfc(-x / sqrt(2)) / 2
end

phi (generic function with 1 method)

In [5]:
# these are reduced versions of our favorite model (etliksep).
# Each wrapper takes a shorter parameter vector x and splices fixed values into it
# before calling etliksep. etliksep passes params[1:5] through phi, which evaluates
# to 0 at -1e99 and to 1 at 1e99. The ±1e99 sentinels therefore pin the matching
# rate exactly at 0 or 1.
# etliksep parameter order: [lrT, lrV, lr0, gamma, lambda, bint, bV, log(sig2)]

# NoMF: lr0 = 0 and gamma = 0 (the TD0 updates do nothing).
# x = [lrT, lrV, lambda, bint, bV, log(sig2)]
etliksepNoMF = (x,d) -> etliksep([x[1:2]; -1e99; -1e99; x[3:end]],d)
# NoMB: lrT = 1, so each port's path trace is overwritten by the current eligibility
# trace at every port visit.
# x = [lrV, lr0, gamma, lambda, bint, bV, log(sig2)]
# NOTE(review): lrV is still free here, so the TD1/MB value update remains active.
# Confirm that pinning only lrT is the intended "no MB" reduction.
etliksepNoMB = (x,d) -> etliksep([1e99; x[1:end]],d)
# NoTD1MB: lrT = 1, lrV = 0 (no TD1/MB value update) and lambda = 0.
# x = [lr0, gamma, bint, bV, log(sig2)]
etliksepNoTD1MB = (x,d) -> etliksep([1e99; -1e99; x[1:2]; -1e99; x[3:end]],d)

# this is our favorite model
"""
    etliksep(params, data; nstates = 126)

Return the negative log-likelihood of the DA signal under the combined TD0 + TD1/MB model.

`params` are unconstrained, in this order:
1. `lrT` – learning rate of the per-port path traces `Tl` (via `phi`)
2. `lrV` – learning rate of the TD1/MB value update applied at ports (via `phi`)
3. `lr0` – TD0 learning rate (via `phi`)
4. `gamma` – discount factor (via `phi`)
5. `lambda` – per-step decay of the eligibility trace (via `phi`)
6. `bint` – intercept of the DA regression
7. `bV` – slope of DA on the current state's value
8. log of `sig2`, the Gaussian DA noise variance

`data` must provide these columns:
- `pairedHexState`: a 0-based state index.
- `rwd`: the reward.
- `DA`: the dopamine signal.
- `port`: `-100` when not at a port. Otherwise a 0-based index into the 3 rows of `Tl`.

Only frames away from a port contribute a Gaussian log density to the likelihood.
The last frame serves only as the successor state of the frame before it.

`nstates` is the length of the value and trace vectors. It defaults to 126, the
previous hard-coded size.
"""
function etliksep(params, data; nstates = 126)
    lrT = phi(params[1])
    lrV = phi(params[2])
    lr0 = phi(params[3])
    gamma = phi(params[4])
    lambda = phi(params[5])
    bint = params[6]
    bV = params[7]
    sig2 = exp(params[8])

    paramType = typeof(params[1])

    hexStates::Array{Int64,1} = data.pairedHexState .+ 1
    rwd::Array{Int64,1} = data.rwd
    da::Array{Float64,1} = data.DA
    atport::Array{Int64,1} = data.port

    V = zeros(paramType, nstates) .+ 0.2   # state values
    Tl = zeros(paramType, 3, nstates)      # one path trace per port (row = port + 1)
    el = zeros(paramType, nstates)         # eligibility trace since the last port visit

    # Start the accumulator in the promoted type. This stops it changing type inside
    # the loop, e.g. from Float64 to a dual number under automatic differentiation.
    lik = zero(promote_type(paramType, Float64))
    # These Gaussian terms are the same for every frame, so compute them once.
    logNorm = -1/2 * log(2 * pi * sig2)
    invTwoSig2 = 1/(2*sig2)

    for i in 1:(length(hexStates) - 1)
        s = hexStates[i]
        snext = hexStates[i+1]

        if atport[i] == -100
            # Not at port: Gaussian log likelihood of DA given the current value.
            lik += logNorm - invTwoSig2 * (da[i] - (bint + bV * V[s]))^2

            # TD0 update towards the successor's discounted value.
            V[s] += lr0 * (rwd[i] + gamma * V[snext] - V[s])
            # Decay every trace, then mark the current state.
            el .*= lambda
            el[s] = 1
        else
            # At port: move this port's path trace towards the current eligibility trace.
            p = atport[i] + 1
            @views Tl[p, :] .= (1 - lrT) .* Tl[p, :] .+ lrT .* el

            # Truncated TD0 update (no successor value at a port).
            V[s] += lr0 * (rwd[i] - V[s])

            # TD1 / MB update: pull every state's value towards this reward,
            # weighted by the port's path trace (in place, no reallocation).
            @views V .+= lrV .* Tl[p, :] .* (rwd[i] .- V)

            # Reset the eligibility trace in place.
            fill!(el, 0)
        end
    end

    return -lik
end

# TD lambda
"""
    tdlambdalik(params, data; nstates = 126)

Return the negative log-likelihood of the DA signal under a TD(λ) model.

`params` are unconstrained, in this order:
1. `lr` – learning rate (via `phi`)
2. `gamma` – discount factor (via `phi`)
3. `lambda` – trace decay (via `phi`)
4. `bint` – intercept of the DA regression
5. `bV` – slope of DA on the current state's value
6. log of `sig2`, the Gaussian DA noise variance

`data` must provide these columns:
- `pairedHexState`: a 0-based state index.
- `rwd`: the reward.
- `DA`: the dopamine signal.
- `port`: `-100` when not at a port.

Away from a port, the algorithm works as follows:
1. The frame adds a Gaussian log density to the likelihood.
2. The trace decays by `lambda * gamma`.
3. The current state's trace entry is set to 1.
4. All values receive the TD error weighted by the trace.

At a port, a truncated TD error (no successor value) is applied through the existing
trace, and then the trace is reset. The last frame serves only as a successor state.

`nstates` is the length of the value and trace vectors. It defaults to 126, the
previous hard-coded size.
"""
function tdlambdalik(params, data; nstates = 126)
    lr = phi(params[1])
    gamma = phi(params[2])
    lambda = phi(params[3])
    bint = params[4]
    bV = params[5]
    sig2 = exp(params[6])

    paramType = typeof(params[1])

    hexStates::Array{Int64,1} = data.pairedHexState .+ 1
    rwd::Array{Int64,1} = data.rwd
    da::Array{Float64,1} = data.DA
    atport::Array{Int64,1} = data.port

    V = zeros(paramType, nstates) .+ 0.2   # state values
    e = zeros(paramType, nstates)          # eligibility trace

    # Start the accumulator in the promoted type so it keeps one type through the loop.
    lik = zero(promote_type(paramType, Float64))
    # Loop-invariant quantities.
    logNorm = -1/2 * log(2 * pi * sig2)
    invTwoSig2 = 1/(2*sig2)
    traceDecay = lambda * gamma

    for i in 1:(length(hexStates) - 1)
        s = hexStates[i]
        snext = hexStates[i+1]

        if atport[i] == -100
            # Not at port: Gaussian log likelihood of DA given the current value.
            lik += logNorm - invTwoSig2 * (da[i] - (bint + bV * V[s]))^2

            e .*= traceDecay
            e[s] = 1

            delta = rwd[i] + gamma * V[snext] - V[s]
            V .+= lr .* e .* delta
        else
            # At port: truncated TD update through the existing trace.
            delta = rwd[i] - V[s]
            V .+= lr .* e .* delta

            # Reset the eligibility trace in place.
            fill!(e, 0)
        end
    end

    return -lik
end

tdlambdalik (generic function with 1 method)

In [6]:
# Sessions play the role of EM "subjects". em is given subs below and presumably
# selects each subject's rows through data.sub.
data.sub = data.session

subs = levels(data.sub)
NS = length(subs)

# One covariate per session: a column of ones (intercept-only group-level design, presumably).
X = fill(1.0, NS);

In [None]:
# fit TD lambda (6 parameters: lr, gamma, lambda, bint, bV, log sig2)

startbetas = zeros(1, 6)   # 1×6 row of starting group-level betas
startsigma = fill(1, 6)    # starting group-level sigmas, passed straight to em

(betastdl,sigmatdl,xtdl,ltdl,htdl) = em(data,subs,X,startbetas,startsigma,tdlambdalik; emtol=emtol, full=full,maxiter=1000);


iter: 7
betas: [-0.48 1.18 1.59 -0.12 0.84 -0.76]
sigma: [0.24, 0.2, 0.33, 0.02, 0.23, 0.18]
free energy: -242133.295137
change: [-0.001842, 0.000134, 0.00134, -0.000129, 2.3e-5, -4.0e-6, 0.001335, 0.004612, 0.006658, 0.000212, 0.000445, 5.0e-6]
max: 0.006658


In [None]:
# xtdl holds the subject-level (per-session) parameter estimates, one row per session

In [9]:
# group-level parameter estimates (betas) of the TD(λ) fit, as returned by em
betastdl

1×6 adjoint(::Vector{Float64}) with eltype Float64:
 -0.484016  1.17699  1.58907  -0.123824  0.842665  -0.763422

In [21]:
using CSV

In [24]:
# Per-session trace decay λ of the TD(λ) fit. Column 3 of xtdl is tdlambdalik's
# unconstrained lambda parameter, so map it with the same transform the model uses
# (phi) instead of re-implementing it inline.
sesh_lambdas = phi.(xtdl[:, 3])

70-element Vector{Float64}:
 0.9689179493479061
 0.9925631314063585
 0.9263554547553853
 0.9172896217749235
 0.9387464347147162
 0.9323326607102616
 0.9549280492980134
 0.953739250351694
 0.9592432845691254
 0.9871611636664033
 0.9488295359570362
 0.9809666372797796
 0.9779938132588593
 ⋮
 0.9721604782130102
 0.946777354579918
 0.9070495832781289
 0.9380125222191908
 0.9358491937262206
 0.9833487870860862
 0.9666078396688357
 0.918543296987325
 0.9604625643485172
 0.9031778237314549
 0.9590516696121829
 0.9750130664965286

In [25]:
# rat of each session, taken from the session's first row (assumes one rat per session)
rats = [first(data.rat[data.session .== sub]) for sub in subs];
sessionLambdaDf = DataFrame(rat = rats, lambda = sesh_lambdas, session = subs)

Unnamed: 0_level_0,rat,lambda,session
Unnamed: 0_level_1,String7,Float64,Int64
1,IM-1272,0.968918,0
2,IM-1272,0.992563,1
3,IM-1272,0.926355,2
4,IM-1272,0.91729,3
5,IM-1272,0.938746,4
6,IM-1272,0.932333,5
7,IM-1272,0.954928,6
8,IM-1272,0.953739,7
9,IM-1273,0.959243,9
10,IM-1273,0.987161,11


In [26]:
# save the per-session TD(λ) lambda table
CSV.write("/Volumes/Tim/Photometry/10MfRatDataSet/optLambdasBySession.csv",sessionLambdaDf)

"/Volumes/Tim/Photometry/10MfRatDataSet/optLambdasBySession.csv"

In [None]:
# simple completion marker
print("done")

In [7]:
# fit 2 component model (8 parameters: lrT, lrV, lr0, gamma, lambda, bint, bV, log sig2)

startbetas = zeros(1, 8)   # 1×8 row of starting group-level betas
startsigma = fill(1, 8)    # starting group-level sigmas, passed straight to em

(betasa,sigmaa,xa,la,ha) = em(data,subs,X,startbetas,startsigma,etliksep; emtol=emtol, full=full,maxiter=1000);


iter: 20
betas: [0.55 -0.07 -1.43 -2.2 0.73 -0.14 0.78 -0.76]
sigma: [1.04, 0.39, 0.41, 1.98, 0.17, 0.03, 0.22, 0.18]
free energy: -242528.338895
change: [2.3e-5, -2.1e-5, -9.0e-6, -0.002115, 2.2e-5, -3.0e-6, 6.0e-6, -1.0e-6, 0.00015, 2.2e-5, 0.000129, 0.009436, 6.9e-5, 1.6e-5, 1.5e-5, 0.0]
max: 0.009436


In [9]:
# this computes standard errors, p values and the covariance matrix for all 8
# group-level parameters of the 2 component model. Entry 7 is bV ("betaV"), the
# coefficient of DA on the state value in etliksep.

(standarderrors,pvalues,covmtx) = emerrors(data,subs,xa,X,ha,betasa,sigmaa,etliksep)

([0.1573833122846947, 0.08692815021329896, 0.08245609215213251, 0.3225372812924188, 0.05163631287309376, 0.019705464749661356, 0.05782991883747753, 0.050294892939845674], [0.0005654221649205428, 0.42969704443301115, 8.41908865809748e-54, 2.53184617192817e-11, 7.12070052146316e-39, 1.1749752159827929e-11, 2.8454540963956194e-36, 1.7285854580370576e-43], [0.02476950698570174 -0.0003684214016176538 … 2.3404765229258165e-5 -2.214531277678762e-6; -0.0003684214016176537 0.007556503299505867 … -0.00031921625595149794 6.940080679648974e-8; … ; 2.3404765229258192e-5 -0.00031921625595149794 … 0.0033442995127492387 -4.772021120873036e-7; -2.2145312776787643e-6 6.940080679648985e-8 … -4.772021120873036e-7 0.002529576255830538])

In [51]:
# group-level estimate of bV (parameter 7 of etliksep)
betasa[7]

0.783911158983642

In [50]:
# standard error of bV (parameter 7 of etliksep)
standarderrors[7]

0.05782991883747753

In [10]:
# here it is: the p value for bV (parameter 7 of etliksep)

pvalues[7]

2.8454540963956194e-36

In [18]:
using JLD2

In [25]:
# persist the 2 component model fit (one JLD2 file per object returned by em)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/betasa.jld2","betasa",betasa)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/sigmaa.jld2","sigmaa",sigmaa)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/xa.jld2","xa",xa)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/la.jld2","la",la)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/ha.jld2","ha",ha)

In [29]:
# fit 3 reduced versions of model
# (this cell: etliksepNoMF, 6 parameters: lrT, lrV, lambda, bint, bV, log sig2)
startbetas = zeros(1, 6)
startsigma = fill(1, 6)

(betasb,sigmab,xb,lb,hb) = em(data,subs,X,startbetas,startsigma,etliksepNoMF; emtol=emtol, full=full,maxiter=1000);


iter: 7
betas: [0.48 0.2 0.52 -0.15 0.61 -0.75]
sigma: [1.25, 0.52, 0.29, 0.03, 0.21, 0.18]
free energy: -243580.163114
change: [0.007924, 0.000311, 0.000177, -3.5e-5, 1.9e-5, -2.0e-6, 0.009999, 0.002037, 0.001359, 0.000234, 0.000372, 4.0e-6]
max: 0.009999


In [30]:
# persist the NoMF reduced model fit (one JLD2 file per object returned by em)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/betasb.jld2","betasb",betasb)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/sigmab.jld2","sigmab",sigmab)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/xb.jld2","xb",xb)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/lb.jld2","lb",lb)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/hb.jld2","hb",hb)

In [10]:
# etliksepNoMB: 7 parameters (lrV, lr0, gamma, lambda, bint, bV, log sig2)
startbetas = zeros(1, 7)
startsigma = fill(1, 7)

(betasc,sigmac,xc,lc,hc) = em(data,subs,X,startbetas,startsigma,etliksepNoMB; emtol=emtol, full=full,maxiter=1000);


iter: 12
betas: [-0.48 -1.25 -1.29 0.74 -0.14 0.99 -0.76]
sigma: [0.73, 0.23, 2.48, 0.11, 0.03, 0.53, 0.19]
free energy: -230933.687864
change: [-3.2e-5, -0.000101, -0.002211, 5.1e-5, -2.2e-5, 1.5e-5, -1.0e-6, 1.0e-6, 0.001318, 0.008537, 0.00082, 5.1e-5, 6.6e-5, 1.0e-6]
max: 0.008537


In [31]:
# etliksepNoTD1MB: 5 parameters (lr0, gamma, bint, bV, log sig2)
startbetas = zeros(1, 5)
startsigma = fill(1, 5)

(betasd,sigmad,xd,ld,hd) = em(data,subs,X,startbetas,startsigma,etliksepNoTD1MB; emtol=emtol, full=full,maxiter=1000);


iter: 5
betas: [-0.03 0.76 -0.04 0.39 -0.72]
sigma: [0.7, 1.57, 0.02, 0.5, 0.17]
free energy: -246635.773324
change: [-0.004795, 0.000218, -0.00019, 7.5e-5, -1.0e-6, 0.000712, 0.000818, 0.000322, 0.001086, 2.0e-6]
max: 0.004795


In [32]:
# persist the NoTD1MB reduced model fit (one JLD2 file per object returned by em)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/betasd.jld2","betasd",betasd)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/sigmad.jld2","sigmad",sigmad)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/xd.jld2","xd",xd)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/ld.jld2","ld",ld)
save("/Volumes/Tim/Photometry/10MfRatDataSet/julia_model_outputs/hd.jld2","hd",hd)

In [34]:
# Integrated AIC (lower is better) for three models:
#   - the full 2 component model (a)
#   - the NoMF reduction (b)
#   - the NoTD1MB reduction (d)
# The 2 component model scores best on aggregate among these.
# NOTE(review): the NoMB reduction (c) and the TD(λ) fit are not scored here.
# Confirm whether they belong in this comparison.

println(iaic(xa,la,ha,betasa,sigmaa))
println(iaic(xb,lb,hb,betasb,sigmab))
println(iaic(xd,ld,hd,betasd,sigmad))


241696.09787314467
242816.92249796362
245908.16968047104


In [13]:
# compute cross validated model scores for each session
# (leave-one-out over the em subjects, i.e. sessions; TD(λ) model)

lootdl = loocv(data,subs,xtdl,X,betastdl,sigmatdl,tdlambdalik;emtol=emtol, full=full)

Subject: 1..2..3..4..5..6..7..8..9..10..11..12..13..14..15..16..17..18..19..20..21..22..23..24..25..26..27..28..29..30..31..32..33..34..35..36..37..38..39..40..41..42..43..44..45..46..47..48..49..50..51..52..53..54..55..56..57..58..59..60..61..62..63..64..65..66..67..68..69..70..

70-element Vector{Float64}:
 5912.735828657604
 6444.623889321575
 5884.231789515827
 5834.856269221079
 3229.7074418038396
 3266.707190377055
 4034.5741338489347
 3189.1107616129275
 2152.407384595823
 2700.300074376121
 2565.0459276756433
 2256.073872237307
 2323.551667564118
    ⋮
 2121.1631752720045
 2519.8562528515426
 1645.4125784268817
 1352.540124745271
 1838.23784211435
 1379.1452096735738
 2461.2569315258543
 6452.962959058842
 4178.368171355523
 4947.535625025737
 4572.848998867438
 3346.9245543394236

In [None]:
# cross validated scores per session for the 2 component model (etliksep)
looa = loocv(data,subs,xa,X,betasa,sigmaa,etliksep;emtol=emtol, full=full)

Subject: 1..2..3..4..5..6..7..8..9..10..11..12..13..14..15..16..17..18..19..20..21..22..23..24..25..26..27..28..29..30..31..32..33..34..35..36..37..38..39..40..41..42..43..44..45..46..47..48..49..50..51..52..53..54..55..56..57..58..59..60..61..62..63..64..65..66..67..68..69..70..

70-element Vector{Float64}:
 6067.641795617361
 6667.1204550752755
 6053.842603266715
 6058.384956910505
 3451.3567141827248
 3407.21551223077
 4285.406812609134
 3334.1362715072632
 2180.3525739537326
 2780.8606314097597
 2685.2113812375287
 2378.732493078734
 2403.87741180545
    ⋮
 2329.9640153801993
 2716.7012945360666
 1790.1501220153032
 1677.7757725866502
 2610.347100999705
 1425.2119985701877
 2816.7812554234574
 6484.203583787722
 4246.864964766661
 5015.109903069645
 4628.325473537063
 3418.6152163191728

In [41]:
# per-session cross validated scores of the 2 component model, labelled by rat
# (rat taken from each session's first row)
rats = [first(data.rat[data.session .== sub]) for sub in subs];
sessiondf = DataFrame(rat = rats, looa = looa)
CSV.write("/Volumes/Tim/Photometry/10MfRatDataSet/loocvDualModelBySession.csv", sessiondf)

"/Volumes/Tim/Photometry/10MfRatDataSet/loocvDualModelBySession.csv"

In [None]:
# cross validated scores per session for the NoMF reduction
loob = loocv(data,subs,xb,X,betasb,sigmab,etliksepNoMF;emtol=emtol, full=full)

Subject: 1..2..3..4..5..6..7..8..9..10..11..12..13..14..15..16..17..18..19..20..21..22..23..24..25..26..27..28..29..30..31..32..33..34..35..36..37..38..39..40..41..42..43..44..45..46..47..48..49..50..51..52..53..54..55..56..57..58..59..60..61..62..63..64..65..66..67..68..69..70..

70-element Vector{Float64}:
 6095.167082639339
 6704.6372410717695
 6054.474159620529
 6074.492400131443
 3445.0969810423953
 3419.806903513694
 4319.226415681134
 3338.8912615870076
 2192.266436624914
 2804.0989818792227
 2701.165629110244
 2392.669285860079
 2409.9517521701064
    ⋮
 2365.711748638728
 2761.933680471531
 1796.0696940349621
 1686.5647924764585
 2642.7384361832273
 1459.3264447446152
 2873.81051277185
 6485.39867666618
 4265.245208412661
 5044.818645733594
 4646.355891645376
 3453.8419155907764

In [44]:
# per-session cross validated scores of the NoMF reduction, labelled by rat
# (rat taken from each session's first row)
rats = [first(data.rat[data.session .== sub]) for sub in subs];
sessiondf = DataFrame(rat = rats, loob = loob)
CSV.write("/Volumes/Tim/Photometry/10MfRatDataSet/loocvGlobalOnlyBySession.csv", sessiondf)

"/Volumes/Tim/Photometry/10MfRatDataSet/loocvGlobalOnlyBySession.csv"

In [45]:
# cross validated scores per session for the NoTD1MB reduction
lood = loocv(data,subs,xd,X,betasd,sigmad,etliksepNoTD1MB;emtol=emtol, full=full)

Subject: 1..2..3..4..5..6..7..8..9..10..11..12..13..14..15..16..17..18..19..20..21..22..23..24..25..26..27..28..29..30..31..32..33..34..35..36..37..38..39..40..41..42..43..44..45..46..47..48..49..50..51..52..53..54..55..56..57..58..59..60..61..62..63..64..65..66..67..68..69..70..

70-element Vector{Float64}:
 6243.019021973626
 6705.412686110986
 6061.721054303925
 6237.836674320178
 3486.649499786647
 3447.191490503531
 4380.371840757959
 3336.5874412031935
 2179.3893341440303
 2816.416042347567
 2731.3779519663194
 2438.0082422900923
 2440.5901991403975
    ⋮
 2397.4792437369624
 2873.9780544763225
 1849.5831509206528
 1724.125802158099
 2695.080446594977
 1522.2099289051573
 2908.0748011166766
 6542.925235878019
 4280.382225466262
 5051.088342614199
 4676.162702022047
 3476.9785586127814

In [46]:
# per-session cross validated scores of the NoTD1MB reduction, labelled by rat
# (rat taken from each session's first row)
rats = [first(data.rat[data.session .== sub]) for sub in subs];
sessiondf = DataFrame(rat = rats, lood = lood)
CSV.write("/Volumes/Tim/Photometry/10MfRatDataSet/loocvTDOnlyBySession.csv", sessiondf)

"/Volumes/Tim/Photometry/10MfRatDataSet/loocvTDOnlyBySession.csv"

In [None]:
# display the per-session cross validated scores of the 2 component model
looa

In [47]:
# all per-session cross validated scores in one table (rat from each session's first row)
rats = [first(data.rat[data.session .== sub]) for sub in subs];
sessiondf = DataFrame(rat = rats, looa = looa, loob = loob, lood = lood)
CSV.write("/Volumes/Tim/Photometry/10MfRatDataSet/loocvsBySession.csv", sessiondf)
# per-rat totals; combine names the outputs looa_sum, loob_sum and lood_sum
ratdf = combine(groupby(sessiondf, :rat), [:looa, :loob, :lood] .=> sum)
CSV.write("/Volumes/Tim/Photometry/10MfRatDataSet/loocvsByRat.csv", ratdf)

"/Volumes/Tim/Photometry/10MfRatDataSet/loocvsByRat.csv"

In [49]:
# these are t tests on model comparisons over rats.
# Each test takes the per-rat summed cross validated score of the 2 component model
# minus that of one reduction. Negative differences mean the 2 component model has
# the lower score, which is presumably better.
print(OneSampleTTest(ratdf.looa_sum-ratdf.loob_sum))
print(OneSampleTTest(ratdf.looa_sum-ratdf.lood_sum))

One sample t-test
-----------------
Population details:
    parameter of interest:   Mean
    value under h_0:         0
    point estimate:          -124.16
    95% confidence interval: (-169.4, -78.96)

Test summary:
    outcome with 95% confidence: reject h_0
    two-sided p-value:           0.0002

Details:
    number of observations:   9
    t-statistic:              -6.334910969619532
    degrees of freedom:       8
    empirical standard error: 19.599398732139736
One sample t-test
-----------------
Population details:
    parameter of interest:   Mean
    value under h_0:         0
    point estimate:          -467.583
    95% confidence interval: (-679.6, -255.5)

Test summary:
    outcome with 95% confidence: reject h_0
    two-sided p-value:           0.0009

Details:
    number of observations:   9
    t-statistic:              -5.085133486376718
    degrees of freedom:       8
    empirical standard error: 91.9509034842352
