In [201]:
from abc import ABC, abstractmethod
from typing import Any, List

class Distribution(ABC):
    """Abstract bounded range for a single model parameter.

    The bounds are kept as floats in ``self.low`` and ``self.high``. Concrete
    subclasses implement :meth:`discretize` to turn a continuous draw from
    that range into the value type the parameter actually uses.
    """

    def __init__(self, low: float, high: float):
        """Validate the bounds and store them as floats.

        Args:
            low: Lower bound of the range.
            high: Upper bound of the range.

        Raises:
            TypeError: If either bound is not an ``int`` or ``float``.
            ValueError: If ``low >= high``.
        """
        # Both bounds are type-checked before they are compared.
        for bound in (low, high):
            if not isinstance(bound, (int, float)):
                raise TypeError("low and high must be numeric (int or float)")
        if high <= low:
            raise ValueError("low must be less than high")

        self.low = float(low)
        self.high = float(high)

    @abstractmethod
    def discretize(self, n: float) -> Any:
        """Map a continuous draw ``n`` to this distribution's value type.

        The base implementation returns ``n`` unchanged; subclasses must
        override it because the method is abstract.
        """
        return n


class FloatDistribution(Distribution):
    """Continuous range whose draws are used exactly as sampled."""

    def __init__(self, low: float, high: float):
        # Validation and storage of the bounds is handled by the base class.
        super().__init__(low=low, high=high)

    def discretize(self, n: float) -> float:
        """Return ``n`` unchanged."""
        return n


class IntDistribution(Distribution):
    """Integer-valued range; continuous draws are converted with ``round``."""

    def __init__(self, low: int, high: int):
        # Convert to float before delegating to the base class constructor.
        lower, upper = float(low), float(high)
        super().__init__(lower, upper)

    def discretize(self, n: float) -> int:
        """Return ``n`` rounded with the built-in ``round``."""
        return round(n)


class ChoiceDistribution(Distribution):
    """Categorical parameter encoded as a continuous index range.

    Choice ``i`` owns the unit-width bin ``[i - 0.5, i + 0.5)``, so the stored
    range is ``[-0.5, len(choices) - 0.5)``. Every choice therefore gets the
    same probability when the range is sampled uniformly. With the previous
    ``[0, len(choices) - 1]`` range the first and last choice only received
    half a bin each, and a single-element list was rejected because the two
    bounds coincided.
    """

    def __init__(self, choices: List[Any]):
        """Store ``choices`` and derive the index range from its length.

        Args:
            choices: Non-empty sequence of candidate values.

        Raises:
            ValueError: If ``choices`` is empty.
        """
        if len(choices) == 0:
            raise ValueError("choices must contain at least one element")
        self.choices = choices
        super().__init__(-0.5, len(choices) - 0.5)

    def discretize(self, n: float) -> Any:
        """Return the choice whose bin contains ``n``.

        Values outside the range are clamped to the nearest valid index, so a
        draw that lands exactly on the upper bound (or slightly outside the
        range) never wraps around through negative indexing or raises
        ``IndexError``.
        """
        last_index = len(self.choices) - 1
        clamped = min(max(n, self.low), self.high)
        # clamped + 0.5 is never negative, so int() truncation equals floor().
        index = min(int(clamped + 0.5), last_index)
        return self.choices[index]

# human readable prior
from dataclasses import dataclass, fields
from torch.distributions import Uniform
import torch

@dataclass
class ModelParams:
    """One fully concrete set of simulation parameters.

    Every field holds a plain value (no distributions). Instances are built by
    ``ModelPrior.get_params_from_sample`` and converted with ``asdict`` into
    the keyword arguments passed to ``roleR.roleParams``.
    """
    # Short descriptions below follow the bound comments in the manual-prior
    # cell of this notebook; see the roleR documentation for exact semantics.
    individuals_local: int  # J: individuals in the local community
    individuals_meta: int  # JM: individuals in the metacommunity
    species_meta: int  # SM: metacommunity species richness
    speciation_local: float  # local speciation rate
    speciation_meta: float  # metacommunity speciation rate
    extinction_meta: float  # metacommunity extinction rate
    env_sigma: float  # NOTE(review): meaning not established here - confirm in roleR docs
    trait_sigma: float  # NOTE(review): meaning not established here - confirm in roleR docs
    comp_sigma: float  # NOTE(review): meaning not established here - confirm in roleR docs
    dispersal_prob: float  # immigration rate from the metacommunity
    mutation_rate: float  # mutation rate
    equilib_escape: float  # NOTE(review): meaning not established here - confirm in roleR docs
    num_basepairs: int  # sequence length in base pairs
    init_type: str  # this notebook uses 'oceanic_island'; other values not shown here
    niter: int  # presumably the number of simulation iterations - TODO confirm
    niterTimestep: int  # presumably iterations between recorded timesteps - TODO confirm

@dataclass
class ModelPrior:
    """Human-readable prior: each field is a fixed value or a ``Distribution``.

    Fields holding a ``Distribution`` are the free parameters. In dataclass
    field order they define the dimensions of the joint uniform returned by
    :meth:`get_joint_uniform`, and the positions read from a sample by
    :meth:`get_params_from_sample`.
    """
    individuals_local: int | IntDistribution
    individuals_meta: int | IntDistribution
    species_meta: int | IntDistribution
    speciation_local: float | FloatDistribution
    speciation_meta: float | FloatDistribution
    extinction_meta: float | FloatDistribution
    env_sigma: float | FloatDistribution
    trait_sigma: float | FloatDistribution
    comp_sigma: float | FloatDistribution
    dispersal_prob: float | FloatDistribution
    mutation_rate: float | FloatDistribution
    equilib_escape: float | FloatDistribution
    num_basepairs: int | IntDistribution
    init_type: str | ChoiceDistribution
    niter: int
    niterTimestep: int

    def get_joint_uniform(self) -> Uniform:
        """Build a ``Uniform`` over the free parameters.

        Returns:
            A ``torch.distributions.Uniform`` whose ``low`` / ``high`` tensors
            hold the bounds of every ``Distribution`` field, in field order.
        """
        free = [
            entry
            for entry in (getattr(self, spec.name) for spec in fields(self))
            if isinstance(entry, Distribution)
        ]
        lows = [dist.low for dist in free]
        highs = [dist.high for dist in free]
        return Uniform(low=torch.tensor(lows), high=torch.tensor(highs))

    def get_params_from_sample(self, sample: torch.Tensor) -> ModelParams:
        """Combine one sampled vector with the fixed values into ``ModelParams``.

        Args:
            sample: Tensor indexed by position; entry ``k`` belongs to the
                ``k``-th ``Distribution`` field in field order.

        Returns:
            ``ModelParams`` in which each free field holds its discretized
            sample entry and every other field keeps its fixed value.
        """
        resolved = {}
        cursor = 0
        for spec in fields(self):
            entry = getattr(self, spec.name)
            if not isinstance(entry, Distribution):
                # Fixed value: copy through untouched.
                resolved[spec.name] = entry
                continue
            # Free parameter: consume the next sample entry.
            resolved[spec.name] = entry.discretize(sample[cursor].item())
            cursor += 1
        return ModelParams(**resolved)

# Prior used for this experiment. Fields given as a Distribution are the free
# parameters (one dimension each in prior.get_joint_uniform(), in field order);
# every other field is held at the constant shown.
prior = ModelPrior(
    # Free parameters (2 dimensions).
    individuals_local=IntDistribution(50, 300),
    individuals_meta=IntDistribution(400, 1000),
    # Fixed values.
    species_meta=50,
    speciation_local=0.05,
    speciation_meta=0.05,
    extinction_meta=0.05,
    env_sigma=0.5,
    trait_sigma=1,
    comp_sigma=0.5,
    dispersal_prob=0.1,
    mutation_rate=0.01,
    equilib_escape=1,
    num_basepairs=250,
    init_type='oceanic_island',
    niter=2000,
    niterTimestep=10
)

In [185]:
import warnings

# Import necessary modules
# rpy2 is the Python <-> R bridge used below; re-raise a missing install with
# an actionable message while keeping the original error chained.
try:
    import rpy2.robjects as robjects
    from rpy2.robjects.packages import importr
    from rpy2.robjects import pandas2ri
except ImportError as e:
    raise ImportError(
        "The 'rpy2' library is required but not installed. "
        "Install it with 'pip install rpy2'."
    ) from e
    
# Activate pandas conversion for rpy2
# NOTE(review): newer rpy2 releases may deprecate the global activate() in
# favour of local converters - confirm against the installed version.
pandas2ri.activate()


# Install R packages if they are not already installed
# This attempt runs every time the cell is executed. Any failure is downgraded
# to a warning, so the importr('roleR') call below decides whether the notebook
# can continue.
# NOTE(review): consider guarding the install so a re-run does not need
# network access.
try:
    remotes = importr('remotes')
    print("Installing the 'roleR' R package from GitHub...")
    remotes.install_github("role-model/roleR", dependencies=True)
except Exception as e:
    warnings.warn(
        f"Error installing R packages: {e}\n"
        "Make sure you have R and the 'remotes' package installed correctly.",
        RuntimeWarning
    )

# Import the R package
# Unlike the install step, a failure here is fatal: the simulator calls roleR.
try:
    roleR = importr('roleR')
except Exception as e:
    raise ImportError(
        f"Error importing the 'roleR' R package: {e}\n"
        "Ensure the package is installed and available in your R environment."
    ) from e

R[write to console]: Using GitHub PAT from the git credential store.



Installing the 'roleR' R package from GitHub...


R[write to console]: Skipping install of 'roleR' from a github remote, the SHA1 (cc6546a1) has not changed since last install.
  Use `force = TRUE` to force installation



In [202]:
from dataclasses import asdict
import numpy as np

class Simulator:
    """Run the roleR simulation for sampled parameter vectors.

    Attributes:
        prior: ``ModelPrior`` used to turn each sampled vector into a full
            parameter set.
        columns: Names of the summary-statistic columns to keep from each
            simulation's output.
    """

    def __init__(self, prior: ModelPrior, columns: list[str]):
        self.prior = prior
        self.columns = columns

    def simulate(self, theta: torch.Tensor) -> list[torch.Tensor]:
        """Simulate once per entry of ``theta`` and collect summary statistics.

        Args:
            theta: Iterable of sampled parameter vectors (e.g. the rows of a
                2-D tensor); each one is passed to
                ``prior.get_params_from_sample``.

        Returns:
            A list with one tensor per entry of ``theta``. Each tensor holds
            the rows of that simulation's summary statistics, restricted to
            ``self.columns`` (in the order they appear in the simulation
            output) with rows containing NaN dropped. The row count can differ
            between simulations, which is why the result is a list and not a
            single stacked tensor.
        """
        # Built once; membership is then checked per output column.
        wanted = set(self.columns)
        results = []
        for t in theta:
            params = asdict(self.prior.get_params_from_sample(t))
            p = roleR.roleParams(**params)

            model = roleR.runRole(roleR.roleModel(p))
            stats = roleR.getSumStats(model)

            stats_df = pandas2ri.rpy2py(stats)
            kept_columns = [col for col in stats_df.columns if col in wanted]
            stats_df = stats_df[kept_columns].dropna()

            results.append(torch.Tensor(np.array(stats_df)))
        return results
            

# Draw 200 parameter vectors from the joint prior.
theta_samples = prior.get_joint_uniform().sample((200,))

# Keep only these summary-statistic columns from each simulation.
simulator = Simulator(
    prior=prior,
    columns=[
        "richness",
        "hill_abund_1",
        "hill_abund_2",
        "hill_abund_3",
        "hill_abund_4",
        "hill_trait_1",
        "hill_trait_2",
        "hill_trait_3",
        "hill_trait_4",
    ],
)

# One tensor of summary-statistic rows per parameter vector.
x_samples = simulator.simulate(theta_samples)

# x_samples_shape = x_samples.shape
# x_samples = x_samples.reshape(x_samples_shape[0] * x_samples_shape[1], x_samples_shape[2])
# theta_samples = np.repeat(theta_samples, x_samples_shape[1])
# print(x_samples.shape, theta_samples.shape)

In [203]:
# Flatten the per-simulation outputs into one training set: every row of a
# simulation's statistics is paired with the parameter vector that produced it.
# The pieces are collected first and concatenated once, so the growing tensors
# are not re-copied on every iteration.
x_parts = []
theta_parts = []
for x, theta in zip(x_samples, theta_samples):
    x_parts.append(x)
    # Repeat theta once per statistics row of this simulation.
    theta_parts.append(torch.tile(theta, (x.shape[0], 1)))

# torch.cat rejects an empty sequence; fall back to an empty tensor so the
# result with no simulations is unchanged.
x_samples_transformed = torch.cat(x_parts) if x_parts else torch.tensor([])
theta_samples_transformed = torch.cat(theta_parts) if theta_parts else torch.tensor([])

print(x_samples_transformed.shape, theta_samples_transformed.shape)

torch.Size([39933, 9]) torch.Size([39933, 2])


In [188]:
from sbi.inference import SNPE

In [204]:
# Set up the SNPE inference object over the joint uniform built from `prior`
# (the same bounds the training parameters were sampled from).
snpe = SNPE(prior=prior.get_joint_uniform())

# Register the flattened (theta, x) pairs and train; the return value of
# train() is kept as the density estimator.
density_estimator = snpe.append_simulations(theta_samples_transformed, x_samples_transformed).train()

# Build the posterior object that is sampled from below, conditioned on an
# observation x.
posterior = snpe.build_posterior(density_estimator)

 Neural network successfully converged after 216 epochs.

In [242]:
# Draw one parameter vector from the prior to act as the true-but-unknown value.
theta_obs = prior.get_joint_uniform().sample((1,))
print(theta_obs)

# simulate() returns one tensor per parameter vector; the last row of the last
# tensor is used as the observation.
x_obs = simulator.simulate(theta_obs)[-1][-1]
print("Observed simulation output:", x_obs)

# Sample inferred parameters from the learned posterior given the observation.
posterior_samples = posterior.sample((1000,), x=x_obs)
print("Posterior samples shape:", posterior_samples.shape)

# Point estimate: the posterior mean over the drawn samples.
posterior_mean = torch.mean(posterior_samples, dim=0)
print("Posterior mean estimate:", posterior_mean)

tensor([[100.1724, 490.1338]])
Observed simulation output: tensor([16.5849, 12.4378, 10.4587,  9.3509, 16.3699, 12.3007, 10.5423,  9.6749,
        25.0000])


Drawing 1000 posterior samples: 1091it [00:00, 184478.36it/s]           

Posterior samples shape: torch.Size([1000, 2])
Posterior mean estimate: tensor([106.0260, 689.8876])





In [None]:

# prior = {
#     "individuals_local": int,
#     "individuals_meta": int,
#     "species_meta": int,
#     "speciation_local": float,
#     "speciation_meta": float,
#     "extinction_meta": float,
#     "env_sigma": float,
#     "trait_sigma": float,
#     "comp_sigma": float,
#     "dispersal_prob": float,
#     "mutation_rate": float,
#     "equilib_escape": int,
#     "num_basepairs": int,
#     "init_type": str,
#     "niter": int,
#     "niterTimestep": int,
# }

# torch
# human readable discrete

# generate dataset

# class ChoiceDistribution(Distribution):
#     def __init__(self, choices: list):
#         super().__init__(0, len(choices))  # Initialize base class
#         self.choices = choices



In [None]:
import torch
from torch.distributions import Uniform



# Lower bounds for the 13 numeric parameters listed below. The order must match
# prior_high, because the two tensors are paired element-wise by Uniform.
prior_low = torch.tensor([
    50.0,        # individuals_local (J): local community individuals
    1000.0,      # individuals_meta (JM): metacommunity individuals
    20.0,        # species_meta (SM): metacommunity species richness
    0.0,         # speciation_local (ν): local speciation rate in [0,1]
    0.001,       # speciation_meta (Λ): metacommunity speciation rate
    0.0,         # extinction_meta (Ε): extinction rate in [0,1]
    0.01,        # env_sigma: metacommunity trait evolution variance (σ²)
    0.01,        # trait_sigma: local trait evolution variance (σ²)
    0.1,         # comp_sigma: strength of ecological filtering (SE)
    0.0,         # dispersal_prob (m): immigration rate from metacommunity in [0,1]
    1e-9,        # mutation_rate (µ): mutation rate
    1.0,         # equilib_escape (α): abundance/Ne scaling factor
    100.0,       # num_basepairs (L): sequence length (bp)
])

# Upper bounds for the same 13 numeric parameters, in the same order:
prior_high = torch.tensor([
    1000.0,      # individuals_local (J)
    10000.0,     # individuals_meta (JM)
    80.0,        # species_meta (SM)
    1.0,         # speciation_local (ν)
    1.0,         # speciation_meta (Λ)
    1.0,         # extinction_meta (Ε)
    5.0,         # env_sigma (σ² metacommunity)
    5.0,         # trait_sigma (σ² local)
    10.0,        # comp_sigma (SE)
    1.0,         # dispersal_prob (m)
    1e-6,        # mutation_rate (µ)
    100.0,       # equilib_escape (α)
    10000.0,     # num_basepairs (L)
])

# Create the joint uniform prior over these 13 parameters.
joint_prior = Uniform(low=prior_low, high=prior_high)

# Example: draw a single sample (one value per parameter) from this joint prior.
samples = joint_prior.sample()
print(samples)


tensor([6.1428e+02, 3.9715e+03, 3.5400e+01, 1.5399e-01, 7.8941e-01, 6.8638e-01,
        3.3611e+00, 1.2061e-01, 9.6169e+00, 6.3239e-01, 4.6293e-08, 9.7214e+01,
        5.8275e+03])
