In [1]:
import os
from time import time

import matplotlib.pyplot as plt
import pandas as pd
from jax import random
import jax
from PIL import Image
import numpyro.distributions as dist
from splotch import get_input_data, register, run_nuts, run_svi
from splotch.models import get_default_priors
from splotch.visualization import (
    plot_annotations_in_common_coordinate_system,
    plot_annotations_on_slides,
    plot_coefficients,
    plot_rates_in_common_coordinate_system,
    plot_rates_on_slides,
    plot_tissue_sections_on_slides,
    plot_variable_on_slides,
)

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

Image.MAX_IMAGE_PIXELS = 1000000000

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

In [2]:
# Build the model input from the samples listed in metadata.tsv. The call
# reports which tissue sections / spots it discards (see the output below).
splotch_input_data = get_input_data(
    "metadata.tsv",
    3,  # NOTE(review): meaning of this positional argument isn't visible here; consider passing it by keyword
    min_detection_rate=0.1,
    num_of_neighbors=8,
)

Discarding a tissue section with 1 spots
Discarding a tissue section with 1 spots
Discarding a tissue section with 2 spots
Discarding a tissue section with 1 spots
Discarding a tissue section with 2 spots
Discarding a tissue section with 2 spots
Discarding a tissue section with 1 spots
Discarding a tissue section with 1 spots
Discarding a tissue section with 8 spots
Discarding 1 spots due to low sequencing depth.
Discarding a tissue section with 1 spots


## NUTS

### map

Sampling is done sequentially over chains and genes.

In [3]:
# Drop JAX's compilation caches so the measured time includes (re)compilation,
# keeping the map / vmap / pmap timings comparable.
jax.clear_caches()

# Fixed seed; the subkey `key_` drives this run.
key, key_ = random.split(random.PRNGKey(0))

start = time()
splotch_result_nuts = run_nuts(
    key_,
    ["Slc5a7", "Gfap"],
    splotch_input_data,
    map_method="map",
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
# JAX dispatches work asynchronously: wait for the samples before stopping the clock.
splotch_result_nuts.posterior_samples["length"].block_until_ready()
print(time() - start)

58.82572793960571


### vmap

Sampling is vectorized over chains and genes.

Not recommended on CPU (compare the runtime below with the `map` run above). This strategy can be beneficial on GPUs.

In [4]:
# Drop JAX's compilation caches so the measured time includes (re)compilation,
# keeping the map / vmap / pmap timings comparable.
jax.clear_caches()

# Fixed seed; the subkey `key_` drives this run.
key, key_ = random.split(random.PRNGKey(0))

start = time()
splotch_result_nuts = run_nuts(
    key_,
    ["Slc5a7", "Gfap"],
    splotch_input_data,
    map_method="vmap",
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
# JAX dispatches work asynchronously: wait for the samples before stopping the clock.
splotch_result_nuts.posterior_samples["length"].block_until_ready()
print(time() - start)

184.9018521308899


### pmap

Sampling is parallelized over chains and genes.

Note that the `XLA_FLAGS` environment variable has to be set accordingly, before JAX is imported. For instance, if you want to run four chains for two genes simultaneously on CPU, then set `XLA_FLAGS=--xla_force_host_platform_device_count=8` (4 chains × 2 genes = 8 devices).

In [5]:
# Drop JAX's compilation caches so the measured time includes (re)compilation,
# keeping the map / vmap / pmap timings comparable.
jax.clear_caches()

# Fixed seed; the subkey `key_` drives this run.
key, key_ = random.split(random.PRNGKey(0))

start = time()
splotch_result_nuts = run_nuts(
    key_,
    ["Slc5a7", "Gfap"],
    splotch_input_data,
    map_method="pmap",
    num_warmup=500,
    num_samples=500,
    num_chains=4,
)
# JAX dispatches work asynchronously: wait for the samples before stopping the clock.
splotch_result_nuts.posterior_samples["length"].block_until_ready()
print(time() - start)

28.73485493659973


## SVI

### map

Sampling is done sequentially over chains and genes.

In [6]:
# Drop JAX's compilation caches so the measured time includes (re)compilation,
# keeping the map / vmap / pmap timings comparable.
jax.clear_caches()

# Fixed seed; pass the subkey `key_` to the run, as in the NUTS cells.
key = random.PRNGKey(0)
key, key_ = random.split(key)

t = time()

splotch_result_svi_batch_1 = run_svi(
    key_,
    ["Slc5a7", "Gfap"],
    splotch_input_data,
    map_method="map",
    num_steps=10_000,
    num_samples=500,
)

# JAX dispatches work asynchronously: wait for *this* run's samples before
# stopping the clock.
# NOTE(review): assumes the SVI result exposes posterior_samples["length"]
# like the NUTS result does — confirm.
splotch_result_svi_batch_1.posterior_samples["length"].block_until_ready()

print(time()-t)

44.09098696708679


### vmap

Sampling is vectorized over chains and genes.

This strategy can be especially beneficial on GPUs.

In [7]:
# Drop JAX's compilation caches so the measured time includes (re)compilation,
# keeping the map / vmap / pmap timings comparable.
jax.clear_caches()

# Fixed seed; pass the subkey `key_` to the run, as in the NUTS cells.
key = random.PRNGKey(0)
key, key_ = random.split(key)

t = time()

splotch_result_svi_batch_1 = run_svi(
    key_,
    ["Slc5a7", "Gfap"],
    splotch_input_data,
    map_method="vmap",
    num_steps=10_000,
    num_samples=500,
)

# JAX dispatches work asynchronously: wait for *this* run's samples before
# stopping the clock.
# NOTE(review): assumes the SVI result exposes posterior_samples["length"]
# like the NUTS result does — confirm.
splotch_result_svi_batch_1.posterior_samples["length"].block_until_ready()

print(time()-t)

44.795427083969116


### pmap

Sampling is parallelized over chains and genes.

Note that the `XLA_FLAGS` environment variable has to be set accordingly, before JAX is imported. For instance, if you want to run four chains for two genes simultaneously on CPU, then set `XLA_FLAGS=--xla_force_host_platform_device_count=8`.

In [8]:
# Drop JAX's compilation caches so the measured time includes (re)compilation,
# keeping the map / vmap / pmap timings comparable.
jax.clear_caches()

# Fixed seed; pass the subkey `key_` to the run, as in the NUTS cells.
key = random.PRNGKey(0)
key, key_ = random.split(key)

t = time()

splotch_result_svi_batch_1 = run_svi(
    key_,
    ["Slc5a7", "Gfap"],
    splotch_input_data,
    map_method="pmap",
    num_steps=10_000,
    num_samples=500,
)

# JAX dispatches work asynchronously: wait for *this* run's samples before
# stopping the clock.
# NOTE(review): assumes the SVI result exposes posterior_samples["length"]
# like the NUTS result does — confirm.
splotch_result_svi_batch_1.posterior_samples["length"].block_until_ready()

print(time()-t)

29.260001182556152
