Skip to content

Commit

Permalink
fix(models): add observation flag to inference model
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
  • Loading branch information
cameronraysmith committed Apr 22, 2024
1 parent ba5b674 commit 719db25
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions src/pyrovelocity/models/_deterministic_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
def deterministic_transcription_splicing_probabilistic_model(
times: TimeTensor,
data: MultiModalTranscriptomeTensor,
observation_flag: bool = True,
):
num_genes, num_cells, num_timepoints, num_modalities = data.shape

Expand Down Expand Up @@ -109,18 +110,19 @@ def model_solver(gene_index, cell_index):
# likelihood
sigma = numpyro.sample(
"sigma",
dist.HalfNormal(scale=1.0),
dist.HalfNormal(scale=0.1),
sample_shape=(num_modalities,),
)
sigma_expanded = sigma.reshape(1, 1, 1, num_modalities)

numpyro.sample(
"observations",
dist.Normal(
predictions,
sigma_expanded,
dist.TruncatedNormal(
low=0.001,
loc=predictions,
scale=sigma_expanded,
),
obs=data,
obs=data if observation_flag else None,
)


Expand Down Expand Up @@ -356,7 +358,12 @@ def generate_prior_inference_data(
batch_ndims=1,
parallel=False,
)
prior_predictions = predictive(rng_key_, times=times, data=data)
prior_predictions = predictive(
rng_key_,
times=times,
data=data,
observation_flag=False,
)

modality_labels = ["pre-mRNA", "mRNA"]
assert len(modality_labels) == num_modalities
Expand Down Expand Up @@ -416,7 +423,12 @@ def generate_posterior_inference_data(
batch_ndims=1,
parallel=False,
)
prior_predictions = prior_predictive(rng_key_, times=times, data=data)
prior_predictions = prior_predictive(
rng_key_,
times=times,
data=data,
observation_flag=False,
)

kernel = NUTS(model)
mcmc = MCMC(
Expand All @@ -430,10 +442,15 @@ def generate_posterior_inference_data(
mcmc.run(rng_key_, times=times, data=data)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()
posterior_samples = mcmc.get_samples(group_by_chain=False)

rng_key, rng_key_ = jax.random.split(rng_key)
posterior_predictive = Predictive(model, posterior_samples)
posterior_predictions = posterior_predictive(
rng_key_, times=times, data=data
rng_key_,
times=times,
data=data,
observation_flag=False,
)

modality_labels = ["pre-mRNA", "mRNA"]
Expand Down

0 comments on commit 719db25

Please sign in to comment.