In [None]:
# Explicit imports instead of `%pylab`: that magic star-imports numpy and pylab
# into the namespace (shadowing builtins such as sum/min/max) and hides where
# names come from. This notebook only refers to numpy and pyplot through the
# `np.` and `plt.` prefixes.
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import pickle

import pandas as pd
import popmachine
import pystan
import scipy
# `import scipy` alone does not load the stats submodule on older SciPy;
# scipy.stats.gaussian_kde is used below.
import scipy.stats

from pystan_cache.pystan_cache import caching_stan

In [None]:
# Compile the three Stan GP models from stan-models/ through pystan_cache's
# `caching_stan` wrapper (presumably reusing a cached compiled model when the
# .stan source is unchanged — confirm in pystan_cache):
#   gp_multi                -> null model, fitted below as `nullSamples`
#   gp_multi_marginal       -> model fitted below as `tsamples`
#   gp_multi_marginal_gamma -> alternative to gp_multi_marginal (only referenced in a comment)
gp_multi = caching_stan.stan_model(file='stan-models/gp_multi.stan')
gp_multi_marginal = caching_stan.stan_model(file='stan-models/gp_multi_marginal.stan')
gp_multi_marginal_gamma = caching_stan.stan_model(file='stan-models/gp_multi_marginal_gamma.stan')

In [None]:
# Record the installed PyStan version in the output; the fit/extract/plot
# APIs used below differ between PyStan releases.
pystan.__version__

In [None]:
# Location of the local popmachine SQLite database, relative to the notebook's
# working directory.
POPMACHINE_DB_URL = 'sqlite:///../popmachine_local/.popmachine.db'
machine = popmachine.Machine(POPMACHINE_DB_URL)

In [None]:
# Query the popmachine database: strain ura3 on plate '20150517 PQ 3', with
# 'mM PQ' search values [0.0] and 'M NaCl' search values [4.2, None].
ds = machine.search(
    plates=['20150517 PQ 3'],
    Strain='ura3',
    **{'mM PQ': [0.0], 'M NaCl': [4.2, None]}
)

# Multi-plate alternative (ura3 at 0.0 and 0.083 mM PQ):
# plates = [u'20150517 PQ 3', u'20150715 PQ 8', u'20150702 PQ 6',
#           u'20150630 PQ 5', u'20150704 PQ 7', u'20150717 PQ 9']
# ds = machine.search(plates=plates, Strain='ura3', **{'mM PQ':[0.0, .083], 'M NaCl':[4.2, None]})

# popmachine preprocessing steps, applied in this order (see popmachine for
# the exact semantics of each step).
ds.log()
ds.filter()
ds.trim(5)
ds.poly_scale(2, groupby=['plate', 'mM PQ'])

# Keep every third row of the data matrix.
# NOTE(review): presumably every third time point — confirm rows of ds.data
# are time points.
ds.data = ds.data.iloc[::3]

In [None]:
# Preprocessed growth curves: panels by plate, colored by PQ dose.
plt.figure(figsize=(4, 4))
ds.plot(colorby=['mM PQ'], columns=['plate'], colorLabels=False)

In [None]:
# Build model inputs from the dataset, grouping by PQ dose and plate.
# NOTE(review): from its use below, `xraw` is 2-D (time points x 1) and `y`
# has one row per time point and one column per replicate — confirm against
# popmachine's `build`.
(xraw, y,
 design, labels) = ds.build(['mM PQ', 'plate'], scale=True)

In [None]:
# Standardize the response and the time axis. ymean/ystd are kept so that later
# cells can map posterior functions back to the data scale (they multiply by
# ystd). NOTE: not idempotent — re-running this cell without re-running the
# ds.build cell standardizes y a second time and overwrites ymean/ystd.
ymean = y.mean()
ystd = y.std()
y = (y - ymean)/ystd

x = (xraw - xraw.mean())/xraw.std()

In [None]:
# Intercept-only design matrix: a single column of ones, one row per column of y.
dm = np.ones(shape=(y.shape[1], 1))

In [None]:
# Data dictionary passed to the Stan GP models; keys must match the models'
# `data` blocks. y is transposed so that each row is one column of y
# (replicate) across the N time points of x.
n = x.shape[0]

# Shapes must agree with the dimensions declared below.
assert y.shape[0] == n, "y must have one row per time point in x"
assert dm.shape[0] == y.shape[1], "design must have one row per column of y"

train_data = {
    'N': n,
    'P': y.shape[1],
    'K': dm.shape[1],
    'L': 1,
    'prior': [1],
    'design': dm,
    'y': y.T,
    'x': x[:, 0],
    'alpha_prior': [[1, 1]],
    # other values tried: [[14, 1.0]], [[2, .5]]
    'length_scale_prior': [[.5, .5]],
    'marginal_alpha_prior': [.5, .1],
    'marginal_lengthscale_prior': [8, 2.0],
    'sigma_prior': [.1, 1.5],
}

train_data

In [None]:
# Fit the marginal GP model. A fixed seed makes the posterior draws (and every
# figure derived from them) reproducible on Restart & Run All.
tsamples = gp_multi_marginal.sampling(data=train_data, chains=4, iter=2000,
                                      seed=1234, control={'adapt_delta': 0.8})
# Alternative: the `_gamma` variant of the marginal model.
# tsamples = gp_multi_marginal_gamma.sampling(data=train_data, chains=4, iter=2000, seed=1234, control={'adapt_delta': 0.8})

In [None]:
# Display the fit object (for a PyStan 2 fit, its repr is a per-parameter summary).
tsamples

In [None]:
# Per-parameter posterior summary as a DataFrame indexed by parameter name
# (it includes an `Rhat` column used below). The raw dict returned by PyStan
# keeps its own name rather than being overwritten by the DataFrame.
fit_summary = tsamples.summary()
summary = pd.DataFrame(fit_summary['summary'],
                       columns=fit_summary['summary_colnames'],
                       index=fit_summary['summary_rownames'])

In [None]:
# Distribution of the R-hat convergence diagnostic across parameters.
summary['Rhat'].describe()

In [None]:
# Parameters with the largest R-hat (NaN values sort last). Values noticeably
# above 1 indicate chains that have not mixed.
summary['Rhat'].sort_values(ascending=False).head(10)

In [None]:
# Raw R-hat values as a numpy array.
summary['Rhat'].values

In [None]:
# Histogram of R-hat over the parameters that have one (NaN R-hat values are dropped).
plt.hist(summary['Rhat'].dropna().values)
plt.xlabel('R-hat')
plt.ylabel('number of parameters');

In [None]:
# First five rows of the posterior summary.
summary.iloc[:5]

In [None]:
# Trace plots for the GP hyperparameters and the log density.
# PyStan 2's traceplot draws into a new figure of its own, so a preceding
# plt.figure(figsize=...) would only leave an empty figure behind. Instead,
# resize the figure the traceplot made current.
tsamples.traceplot(['length_scale', 'alpha', 'sigma', 'lp__'])
trace_fig = plt.gcf()
trace_fig.set_size_inches(4, 10)
trace_fig.tight_layout()

In [None]:
# Trace plots for the noise and marginal-component hyperparameters.
# PyStan 2's traceplot draws into a new figure of its own, so resize the figure
# it made current rather than opening an empty one beforehand.
tsamples.traceplot(['sigma', 'marginal_alpha', 'marginal_lengthscale'])
trace_fig = plt.gcf()
trace_fig.set_size_inches(4, 10)
trace_fig.tight_layout()

In [None]:
# Posterior draws (chains merged and permuted, warmup excluded) keyed by parameter name.
tsamp = tsamples.extract(permuted=True, inc_warmup=False)

In [None]:
# Posterior mean +/- 2 sd of latent function i under the marginal model
# (standardized units) against raw time.
i = 0
f_i = tsamp['f'][:, i, :]
f_mean, f_sd = f_i.mean(0), f_i.std(0)

plt.plot(xraw, f_mean)
plt.fill_between(xraw[:, 0], f_mean - 2*f_sd, f_mean + 2*f_sd, alpha=.3)
plt.xlabel('time (h)')
plt.ylabel('f_%d (standardized)' % i);

In [None]:
# Compare draws of latent function i at the posterior samples with the smallest
# and largest length_scale. np.argmin/argmax + unravel_index yield a single
# sample index (the first one, on ties) whether length_scale is 1-D or 2-D
# (draws x L). np.where returns a tuple of index arrays, which mis-indexes `f`
# for 2-D input and needed a hard-coded reshape length.
i = 0
length_scale = tsamp['length_scale']

s = np.unravel_index(np.argmin(length_scale), length_scale.shape)[0]
plt.plot(xraw[:, 0], tsamp['f'][s, i, :], label='min length_scale draw')

s = np.unravel_index(np.argmax(length_scale), length_scale.shape)[0]
plt.plot(xraw[:, 0], tsamp['f'][s, i, :], label='max length_scale draw')

plt.legend()

In [None]:
# Draw of latent function 0 at the posterior sample with the largest
# length_scale. The sample index is recomputed here rather than read from a
# variable set in another cell, and no time-series length is hard-coded.
ls_draws = tsamp['length_scale']
tsamp['f'][np.unravel_index(np.argmax(ls_draws), ls_draws.shape)[0], 0, :]

In [None]:
# Peak-normalized posterior densities (Gaussian KDE) of the marginal model's
# scale parameters, on a log x-axis.
for par in ['alpha', 'sigma', 'marginal_alpha']:

    # gaussian_kde reads a 2-D input as (dimensions, points). Squeezing turns a
    # (draws, 1) array into 1-D draws; already 1-D draws are unchanged.
    temp = np.squeeze(tsamp[par])

    z = np.linspace(temp.min(), temp.max())
    density = scipy.stats.gaussian_kde(temp)(z)

    plt.plot(z, density/density.max(), label=par)

plt.semilogx()
plt.xlabel('parameter value')
plt.ylabel('density / max density')
plt.legend()

# plt.savefig('figures/ura3_0.083mMPQ-alpha-stan.pdf', bbox_inches='tight')

In [None]:
# Shapes of the posterior draws plotted above, computed explicitly instead of
# inspecting whatever loop variable a previous cell left behind.
{par: tsamp[par].shape for par in ['alpha', 'sigma', 'marginal_alpha']}

# Null model ($M_0$): refit the same data with `gp_multi`

In [None]:
# Fit the null GP model (gp_multi) to the same data. A fixed seed makes the
# draws reproducible on Restart & Run All.
nullSamples = gp_multi.sampling(data=train_data, chains=4, iter=2000,
                                seed=1234, control={'adapt_delta': 0.8})

In [None]:
# Display the null-model fit object (for a PyStan 2 fit, its repr is a per-parameter summary).
nullSamples

In [None]:
# Trace plots for the null model's hyperparameters and log density.
# PyStan 2's traceplot draws into a new figure of its own, so resize the figure
# it made current rather than opening an empty one beforehand.
nullSamples.traceplot(['length_scale', 'alpha', 'sigma', 'lp__'])
trace_fig = plt.gcf()
trace_fig.set_size_inches(10, 4)
trace_fig.tight_layout()

In [None]:
# Null-model posterior draws (chains merged and permuted, warmup excluded).
nsamp = nullSamples.extract(permuted=True, inc_warmup=False)

In [None]:
# Posterior mean +/- 2 sd of each null-model latent function (one per design
# column), one panel each, against standardized time `x`.
n_funcs = dm.shape[1]
ncol = 2
nrow = (n_funcs - 1)//ncol + 1   # ceil(n_funcs / ncol), always an int

plt.figure(figsize=(4*ncol, 4*nrow))

for i in range(n_funcs):

    # Use the same integer grid that sized the figure; the former
    # `dm.shape[1]/ncol + 1` is a float under Python 3 and disagreed with nrow.
    plt.subplot(nrow, ncol, i + 1)

    f_i = nsamp['f'][:, i, :]
    f_mean, f_sd = f_i.mean(0), f_i.std(0)

    plt.plot(x, f_mean)
    plt.fill_between(x[:, 0], f_mean - 2*f_sd, f_mean + 2*f_sd, alpha=.1)
    plt.plot([x.min(), x.max()], [0, 0], lw=3, c='k')

    if i > 1:
        plt.ylim(-.48, .48)

In [None]:
# Posterior densities (Gaussian KDE) of each null-model alpha component, one
# curve per column of nsamp['alpha'], on a log x-axis. The loop follows the
# actual number of columns instead of assuming two.
for i in range(nsamp['alpha'].shape[1]):

    temp = nsamp['alpha'][:, i]

    # grid from 0.7x the smallest to 1.3x the largest draw
    z = np.linspace(temp.min()*.7, temp.max()*1.3)
    kde = scipy.stats.gaussian_kde(temp)

    plt.plot(z, kde(z), label='alpha %d' % i)

plt.semilogx()
plt.xlabel('alpha')
plt.ylabel('posterior density')
plt.legend()

In [None]:
# Posterior densities (Gaussian KDE) of each null-model length_scale component,
# one curve per column of nsamp['length_scale'], on a log x-axis. The loop
# follows the actual number of columns instead of assuming two.
for i in range(nsamp['length_scale'].shape[1]):

    temp = nsamp['length_scale'][:, i]

    z = np.linspace(temp.min(), temp.max())
    kde = scipy.stats.gaussian_kde(temp)

    plt.plot(z, kde(z), label='length_scale %d' % i)

plt.semilogx()
plt.xlabel('length_scale')
plt.ylabel('posterior density')
plt.legend()

In [None]:
# Latent function f_1 mapped back to the data scale (times ystd) for the
# marginal model (M_2, `tsamp`) and the null model (M_0, `nsamp`):
# posterior mean +/- 2 sd, with a zero reference line.
# NOTE(review): the extra factor 2 is carried over from the original analysis;
# its justification (e.g. a +/-1 coded two-condition design) is not visible
# here, and the current `dm` is intercept-only — confirm before relying on it.
assert tsamp['f'].shape[1] > 1 and nsamp['f'].shape[1] > 1, (
    "f has no second latent function (index 1); this cell needs a design "
    "matrix with at least two columns, but dm has %d" % dm.shape[1])

for draws, model_label in [(tsamp, '$M_2$'), (nsamp, '$M_0$')]:
    f1 = ystd*2*draws['f'][:, 1]
    f1_mean, f1_sd = f1.mean(0), f1.std(0)
    plt.plot(xraw, f1_mean, label=model_label)
    plt.fill_between(xraw[:, 0], f1_mean - 2*f1_sd, f1_mean + 2*f1_sd, alpha=.3)

plt.plot([xraw.min(), xraw.max()], [0, 0], lw=3, c='k')

plt.legend(fontsize=12)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.ylabel('log(OD)', fontsize=16)
plt.xlabel('time (h)', fontsize=16)

plt.savefig('figures/ura3_0.083mM-PQ_f1_m02-stan.pdf', bbox_inches='tight')

In [None]:
# Raw-data check: mean of treated curves minus mean of control curves, row by
# row of ds.data. pd.to_numeric makes the control test (dose == 0) independent
# of whether 'mM PQ' is stored as text ('0.0') or as numbers (0.0). Entries
# that cannot be parsed become NaN and count as non-control.
pq_dose = pd.to_numeric(ds.meta['mM PQ'], errors='coerce')
treated_mean = ds.data.loc[:, pq_dose != 0].mean(1)
control_mean = ds.data.loc[:, pq_dose == 0].mean(1)

plt.plot(treated_mean - control_mean)
plt.xlabel('ds.data row index')
plt.ylabel('treated - control');

In [None]:
# Latent function f_0 mapped back to the data scale (times ystd) for the
# marginal model (M_2, `tsamp`) and the null model (M_0, `nsamp`):
# posterior mean +/- 2 sd, with a zero reference line.
for draws, model_label in ((tsamp, '$M_2$'), (nsamp, '$M_0$')):
    temp = ystd*draws['f'][:, 0]
    centre, spread = temp.mean(0), temp.std(0)
    plt.plot(xraw, centre, label=model_label)
    plt.fill_between(xraw[:, 0], centre - 2*spread, centre + 2*spread, alpha=.3)

plt.plot([xraw.min(), xraw.max()], [0, 0], lw=3, c='k')

plt.legend(fontsize=12)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.ylabel('log(OD)', fontsize=16)
plt.xlabel('time (h)', fontsize=16)

# plt.savefig('figures/ura3_0.083mM-PQ_f0_m02-stan.pdf', bbox_inches='tight')