In [None]:
%pylab

In [None]:
%matplotlib inline

In [None]:
import pystan, pickle, popmachine

from pystan_cache.pystan_cache import caching_stan

In [None]:
# Open the popmachine store through a SQLite URL. The path inside the URL is
# relative ('../popmachine_local/...'), so it is presumably resolved against the
# notebook's working directory -- confirm before running from another location.
machine = popmachine.Machine('sqlite:///../popmachine_local/.popmachine.db')

In [None]:
# Query the store for Strain 'ura3' with 'mM PQ' given as [0.0, 0.083]. The
# second keyword contains a space, so it has to be passed through a ** dict.
ds = machine.search(Strain='ura3', **{'mM PQ':[0.0, .083]})
# Preprocessing calls. Their return values are discarded, so they presumably
# modify `ds` in place and the order matters -- TODO confirm against popmachine.
ds.log()
ds.filter()
ds.trim(5)
ds.poly_scale(2, groupby=['plate', 'mM PQ'])

# Thin the data table: keep every third row, all columns.
ds.data = ds.data.iloc[::3,:]

In [None]:
# Quick look at the selected data on a wide, short canvas (12 x 2 inches).
# NOTE(review): presumably draws one panel per plate, coloured by 'mM PQ' -- confirm in popmachine.
plt.figure(figsize=(12,2))
ds.plot(columns=['plate'], colorby=['mM PQ'])

In [None]:
# Build the model inputs from the dataset using the 'mM PQ' and 'plate' fields.
# From later use: x is 2-D (indexed as x[:,0]), y is 2-D (y.shape[1] is read),
# and `design` is a table with 'mM PQ' and 'plate' columns.
# `labels` is not used in the visible cells.
x,y, design, labels = ds.build(['mM PQ', 'plate'],scale=True)

In [None]:
# Standardize both arrays with their own overall mean and standard deviation
# (a single mean/std over all entries, not per column).
x, y = [(arr - arr.mean()) / arr.std() for arr in (x, y)]

In [None]:
# Display the design table returned by ds.build.
design

In [None]:
# Sanity check: shapes of the response and input arrays.
y.shape, x.shape

In [None]:
# Design matrix with one row per row of `design` and columns laid out as
#   0          : intercept (all ones)
#   1          : +/-1 contrast derived from 'mM PQ'
#   2+2i, 3+2i : copies of columns 0 and 1, zeroed outside plate i
#
# The row count is taken from `design` instead of a hard-coded 144: the
# assignment into dm[:,1] below already requires both lengths to agree, so a
# different selection or thinning step no longer breaks this cell.
n_rows = design.shape[0]
dm = np.zeros((n_rows, 2 + 2*ds.meta.plate.unique().shape[0]))
dm[:,0] = 1
# Maps 0 -> +1 and 1 -> -1; assumes design['mM PQ'] is a 0/1 code -- TODO confirm.
dm[:,1] = 1 - 2*design['mM PQ']

for i in range(design.plate.unique().shape[0]):
    # NOTE(review): assumes plates are coded 0..n_plates-1 in design.plate.
    # The boolean mask is broadcast over the two shared columns.
    dm[:,2+i*2:4+i*2] = dm[:,:2] * (design.plate==i).values[:,None]

In [None]:
# Visual check of the design matrix as an image; aspect='auto' lets it fill the
# axes instead of forcing square pixels.
plt.imshow(dm, aspect='auto')

In [None]:
# Load the Stan program models/gp_multi.stan through the caching helper
# (presumably compiled once and reused on later runs -- confirm in pystan_cache).
gp_multi = caching_stan.stan_model(file='models/gp_multi.stan')

In [None]:
# random effect
#
# Base data dictionary handed to the Stan model; later cells copy it and add
# the observations and prior settings.
n = x.shape[0]    # number of rows of x
p = dm.shape[1]   # number of design-matrix columns

# One entry per design column: 1 and 2 for the two shared columns, then the
# pair (3, 4) repeated once per plate. Presumably these index the L = 4 prior
# groups -- confirm against gp_multi.stan.
n_plates = ds.meta.plate.unique().shape[0]
priors = [1, 2] + n_plates * [3, 4]

# NOTE(review): 'length_scale' and 'alpha' carry 3 values while L is 4 -- check
# whether the model reads them at all.
sim_data = dict(
    N=n,
    P=y.shape[1],
    K=p,
    L=4,
    prior=priors,
    length_scale=[1, .5, .3],
    alpha=[1, .4, .3],
    sigma=.2,
    design=dm
)

In [None]:
# Later cells call scipy.stats.gamma and scipy.stats.gaussian_kde. A bare
# `import scipy` does not guarantee that the `stats` submodule is loaded on
# older SciPy releases, so import the submodule explicitly. This still binds
# the name `scipy`, so existing `scipy.stats.*` references keep working.
import scipy.stats

In [None]:
# Compare two gamma densities with shape 1.5 on [0, 10]: one with scale 0.4
# and one with scale 2.
z = np.linspace(0, 10)

for gamma_scale in (.4, 2):
    plt.plot(z, scipy.stats.gamma.pdf(z, 1.5, scale=gamma_scale))

In [None]:
# Data for the full fit: start from a shallow copy of the base dictionary and
# add the observations plus the hyperparameter prior settings.
train_data = sim_data.copy()
train_data.update({
    'N': x.shape[0],
    'y': y.T,        # transposed response
    'x': x[:,0],     # first column of x as a 1-D input
    # one pair per prior group (4 groups)
    'alpha_prior': [[1,1], [1,1], [.1,1], [.1,1]],
    'length_scale_prior': [[1.5,2]] * 4,
})

# Show the assembled dictionary.
train_data

In [None]:
# Fit the full model: 2 chains of 2000 iterations each.
# A fixed seed is passed so that Restart & Run All reproduces the same draws;
# without it every run gives a different posterior sample.
# NOTE(review): assumes gp_multi exposes PyStan's StanModel.sampling signature
# (which accepts `seed`) -- confirm for the caching wrapper.
tsamples = gp_multi.sampling(data=train_data, chains=2, iter=2000, seed=1234,
                             control = {'adapt_delta': 0.8})

In [None]:
# Display the fit object for the full model (presumably its posterior summary).
tsamples

In [None]:
# Trace plots of 'length_scale', 'alpha', 'sigma' and 'lp__' for the full fit.
# NOTE(review): traceplot may open its own figure, in which case the plt.figure
# call here only creates an empty one -- confirm.
plt.figure(figsize=(10,4))
tsamples.traceplot(['length_scale', 'alpha', 'sigma','lp__'])
plt.tight_layout()

In [None]:
# Extract the posterior draws of the full fit, keyed by parameter name
# ('f', 'alpha', 'sigma' and 'length_scale' are read in the cells below).
# NOTE(review): with permuted=True the draws are presumably merged across
# chains and shuffled, so their order is not iteration order -- confirm.
tsamp = tsamples.extract(permuted=True)

In [None]:
# One panel per latent function (one per design-matrix column) on a grid that
# is `ncol` panels wide: posterior mean with a +/- 2 posterior-sd band and a
# zero reference line.
ncol = 2
n_panels = dm.shape[1]
# Integer ceiling division. The subplot grid now uses this same `nrow` as the
# figure size: the previous `dm.shape[1]/ncol + 1` is a float on Python 3
# (rejected by plt.subplot) and added an extra empty row to the grid.
nrow = (n_panels - 1) // ncol + 1

plt.figure(figsize=(4*ncol, 4*nrow))

for i in range(n_panels):

    plt.subplot(nrow, ncol, i + 1)

    # tsamp['f'][:, i, :] holds the draws of function i; axis 0 is reduced away.
    f_mean = tsamp['f'][:,i,:].mean(0)
    f_sd = tsamp['f'][:,i,:].std(0)

    plt.plot(x, f_mean)
    plt.fill_between(x[:,0], f_mean - 2*f_sd, f_mean + 2*f_sd, alpha=.1)
    plt.plot([x.min(), x.max()], [0, 0], lw=3, c='k')

    # Every panel after the first two shares a fixed y-range so the
    # plate-specific functions can be compared by eye.
    if i > 1:
        plt.ylim(-.48, .48)

In [None]:
# Per draw, add up latent functions 3, 5, 7, ... (every second function from
# index 3), then plot the posterior mean of that sum with a +/- 2 sd band.
summed = tsamp['f'][:, 3::2].sum(1)
center = summed.mean(0)
half_width = 2 * summed.std(0)

plt.plot(x, center)
plt.fill_between(x[:, 0], center - half_width, center + half_width, alpha=.1)

In [None]:
# Draw-by-draw values of every alpha component on a log y-axis.
# plt.plot returns one line per column of tsamp['alpha']; the lines are passed
# to the legend explicitly because a bare plt.legend() finds no labelled
# artists here and would draw an empty legend with a warning.
alpha_lines = plt.plot(tsamp['alpha'], alpha=.4)
plt.semilogy()
plt.legend(alpha_lines, ['alpha %d' % k for k in range(len(alpha_lines))])

In [None]:
# Shape of the extracted alpha draws.
tsamp['alpha'].shape

In [None]:
# Kernel density estimates of the posterior draws of the four alpha components
# and of sigma, each rescaled to a peak height of 1 so the curves are
# comparable on a shared log x-axis.
curves = [('alpha %d' % k, tsamp['alpha'][:, k]) for k in range(4)]
curves.append(('sigma', tsamp['sigma']))

for curve_label, draws in curves:
    # Evaluate the KDE on an evenly spaced grid spanning the observed draws.
    grid = np.linspace(draws.min(), draws.max())
    density = scipy.stats.gaussian_kde(draws)(grid)
    plt.plot(grid, density / density.max(), label=curve_label)

plt.semilogx()
plt.legend()

In [None]:
# Kernel density estimates of the posterior draws of the four length_scale
# components (unnormalized) on a log x-axis.
for k in range(4):
    draws = tsamp['length_scale'][:, k]
    grid = np.linspace(draws.min(), draws.max())
    estimate = scipy.stats.gaussian_kde(draws)
    plt.plot(grid, estimate(grid), label='length_scale %d' % k)

plt.semilogx()
plt.legend()

# Null model: refit using only the first two design-matrix columns

In [None]:
# Data for the null fit: same observations as the full fit, but restricted to
# the first two design columns and the first two entries of each prior setting.
null_train_data = train_data.copy()
null_train_data['design'] = train_data['design'][:, :2]

for prior_key in ('prior', 'length_scale_prior', 'alpha_prior'):
    null_train_data[prior_key] = null_train_data[prior_key][:2]

# Both the number of design columns (K) and of prior groups (L) drop to 2.
null_train_data['K'] = null_train_data['L'] = 2

In [None]:
# Fit the null model with the same sampler settings as the full fit.
# A fixed seed is passed so that Restart & Run All reproduces the same draws;
# without it every run gives a different posterior sample.
# NOTE(review): assumes gp_multi exposes PyStan's StanModel.sampling signature
# (which accepts `seed`) -- confirm for the caching wrapper.
nullSamples = gp_multi.sampling(data=null_train_data, chains=2, iter=2000, seed=1234,
                                control = {'adapt_delta': 0.8})

In [None]:
# Display the fit object for the null model (presumably its posterior summary).
nullSamples

In [None]:
# Trace plots of 'length_scale', 'alpha', 'sigma' and 'lp__' for the null fit.
# NOTE(review): traceplot may open its own figure, in which case the plt.figure
# call here only creates an empty one -- confirm.
plt.figure(figsize=(10,4))
nullSamples.traceplot(['length_scale', 'alpha', 'sigma','lp__'])
plt.tight_layout()

In [None]:
# Extract the posterior draws of the null fit, keyed by parameter name
# ('f', 'alpha' and 'length_scale' are read in the cells below).
nsamp = nullSamples.extract(permuted=True)

In [None]:
# Posterior mean with a +/- 2 posterior-sd band and a zero reference line for
# each of the two latent functions plotted from the null fit.
#
# The grid and the figure are sized for the two panels actually drawn. The
# earlier version sized both from dm.shape[1] (the full model's column count),
# leaving a tall, mostly empty figure, and passed a float row count to
# plt.subplot on Python 3. Its `if i > 1` y-limit branch could never run inside
# range(2) and has been removed.
ncol = 2
n_null = 2
nrow = (n_null - 1) // ncol + 1

plt.figure(figsize=(4*ncol, 4*nrow))

for i in range(n_null):

    plt.subplot(nrow, ncol, i + 1)

    # nsamp['f'][:, i, :] holds the draws of function i; axis 0 is reduced away.
    f_mean = nsamp['f'][:,i,:].mean(0)
    f_sd = nsamp['f'][:,i,:].std(0)

    plt.plot(x, f_mean)
    plt.fill_between(x[:,0], f_mean - 2*f_sd, f_mean + 2*f_sd, alpha=.1)
    plt.plot([x.min(), x.max()], [0, 0], lw=3, c='k')

In [None]:
# Kernel density estimates of the two alpha components of the null fit on a
# log x-axis. The evaluation grid runs from 0.7x the smallest draw to 1.3x the
# largest, so the tails of each curve are visible.
for k in range(2):
    draws = nsamp['alpha'][:, k]
    grid = np.linspace(.7 * draws.min(), 1.3 * draws.max())
    estimate = scipy.stats.gaussian_kde(draws)
    plt.plot(grid, estimate(grid), label='alpha %d' % k)

plt.semilogx()
plt.legend()

In [None]:
# Kernel density estimates of the two length_scale components of the null fit
# on a log x-axis, evaluated over the range of the observed draws.
for k in range(2):
    draws = nsamp['length_scale'][:, k]
    grid = np.linspace(draws.min(), draws.max())
    estimate = scipy.stats.gaussian_kde(draws)
    plt.plot(grid, estimate(grid), label='length_scale %d' % k)

plt.semilogx()
plt.legend()

In [None]:
# Overlay twice the second latent function (index 1) from the full fit and
# then from the null fit: posterior mean with a +/- 2 sd band for each,
# followed by a zero reference line.
for fit_draws in (tsamp, nsamp):
    doubled = 2 * fit_draws['f'][:, 1]
    center = doubled.mean(0)
    half_width = 2 * doubled.std(0)

    plt.plot(x, center)
    plt.fill_between(x[:, 0], center - half_width, center + half_width, alpha=.1)

plt.plot([x.min(), x.max()], [0, 0], lw=3, c='k')