In [1]:
import torch
import pyro
import pyro.distributions as dist

class PloidyPrior:
    """
    Mixture-of-Normals prior over cancer cell line ploidy, sampled with Pyro.

    Each component is Normal(mu, sigma) chosen with probability proportional to
    its ``weight``; weights are normalised to sum to one at construction time.
    """

    # Default mixture with component means at 2, 3, 4 and 5. Held as a tuple on
    # the class and copied per instance, so no instance (or caller) can mutate
    # the defaults seen by later instances.
    DEFAULT_COMPONENTS = (
        {'mu': 2.0, 'sigma': 0.2, 'weight': 0.3},
        {'mu': 3.0, 'sigma': 0.3, 'weight': 0.3},
        {'mu': 4.0, 'sigma': 0.3, 'weight': 0.25},
        {'mu': 5.0, 'sigma': 0.4, 'weight': 0.15},
    )

    def __init__(self, components=None):
        """
        Implements a mixture model prior for cancer cell line ploidy using Pyro.
        
        Parameters:
        components: iterable of dicts, each containing 'mu', 'sigma' and 'weight'
            for one mixture component. Defaults to DEFAULT_COMPONENTS when None.
            The dicts are copied, so later changes to the caller's objects do
            not affect this prior.

        Raises:
        KeyError: if a component dict lacks 'mu', 'sigma' or 'weight'.
        ValueError: if there are no components, any sigma is not positive,
            any weight is negative, or all weights are zero.
        """
        if components is None:
            components = self.DEFAULT_COMPONENTS
        # Defensive copy: storing the argument by reference would alias the
        # caller's list (and, previously, a shared mutable default argument).
        self.components = [dict(c) for c in components]
        if not self.components:
            raise ValueError("components must contain at least one mixture component")

        # Force a floating dtype so integer-valued specs (e.g. mu=2) still work
        # with dist.Normal and the weight normalisation below.
        dtype = torch.get_default_dtype()
        self.mus = torch.tensor([c['mu'] for c in self.components], dtype=dtype)
        self.sigmas = torch.tensor([c['sigma'] for c in self.components], dtype=dtype)
        weights = torch.tensor([c['weight'] for c in self.components], dtype=dtype)

        if (self.sigmas <= 0).any():
            raise ValueError("every component sigma must be positive")
        if (weights < 0).any():
            raise ValueError("component weights must be non-negative")
        total_weight = weights.sum()
        if total_weight <= 0:
            # Normalising would divide by zero and yield NaN probabilities.
            raise ValueError("component weights must not all be zero")

        # Normalize weights so they form a probability vector for Categorical.
        self.weights = weights / total_weight
        
    def sample(self, sample_shape=(1,)):
        """
        Sample from the ploidy prior.

        Records two Pyro sample sites per draw inside a plate named "samples":
        "mixture_idx" (the chosen component) and "ploidy" (the value drawn from
        that component's Normal).

        NOTE: the Normal components are not truncated, so draws below zero are
        possible in principle (unlikely with the default parameters).

        Parameters:
        sample_shape: int or tuple of ints, shape of samples to generate. An
            empty tuple yields a single draw with shape (1,).

        Returns:
        samples: tensor of ploidy samples with shape ``sample_shape``
            (or (1,) for an empty shape)
        """
        if isinstance(sample_shape, int):
            sample_shape = (sample_shape,)
        shape = torch.Size(sample_shape)
        # Draw prod(shape) independent values in one flat plate, then reshape,
        # so every dimension of sample_shape is honoured (torch.Size(()).numel() == 1).
        n_draws = shape.numel()

        with pyro.plate("samples", n_draws):
            # Sample mixture component
            mixture_idx = pyro.sample(
                "mixture_idx",
                dist.Categorical(self.weights)
            )

            # Sample ploidy from selected component
            ploidy = pyro.sample(
                "ploidy",
                dist.Normal(
                    self.mus[mixture_idx],
                    self.sigmas[mixture_idx]
                )
            )

        return ploidy.reshape(shape) if len(shape) > 0 else ploidy

def plot_samples(prior, n_samples=1000, bins=50):
    """
    Generate samples from the prior and plot them as a normalised histogram.

    A dashed red vertical line is drawn at each component mean in ``prior.mus``.

    Parameters:
    prior: object exposing ``sample(sample_shape)`` that returns a tensor, and a
        ``mus`` tensor of component means (e.g. a PloidyPrior)
    n_samples: int, number of draws to histogram
    bins: int or sequence of bin edges, passed through to ``Axes.hist``

    Returns:
    The ``matplotlib.pyplot`` module, kept for backward compatibility so callers
    can still write ``plt = plot_samples(prior)`` and then ``plt.show()``.
    """
    import matplotlib.pyplot as plt

    # Generate samples (no autograd graph is needed just to plot them).
    with torch.no_grad():
        samples = prior.sample((n_samples,))

    # Explicit figure/axes instead of the pyplot state machine.
    fig, ax = plt.subplots(figsize=(10, 6))
    # Flatten/detach/cpu so multi-dimensional or non-CPU tensors still plot.
    ax.hist(samples.detach().cpu().reshape(-1).numpy(), bins=bins, density=True, alpha=0.6)
    ax.set_xlabel('Ploidy')
    ax.set_ylabel('Density')
    ax.set_title('Samples from Ploidy Prior')
    ax.grid(True, alpha=0.3)

    # Add vertical lines for component means
    for mu in prior.mus:
        ax.axvline(mu.item(), color='red', linestyle='--', alpha=0.3)

    return plt

In [None]:
# Example usage
# The draws below are random; seeding the RNGs Pyro samples from makes the
# printed values and the histogram reproducible under Restart & Run All.
SEED = 0
pyro.set_rng_seed(SEED)
pyro.clear_param_store()

# Create prior
prior = PloidyPrior()

# Generate single sample (default sample_shape is (1,), so .item() is valid)
single_sample = prior.sample()
print(f"Single ploidy sample: {single_sample.item():.2f}")

# Generate multiple samples
batch_samples = prior.sample((5,))
print(f"Batch of samples: {batch_samples.numpy()}")

# Plot distribution of samples
plt = plot_samples(prior)
plt.show()

In [3]:
# One draw from a freshly built default prior; the resulting tensor is the cell output.
PloidyPrior().sample(sample_shape=(1,))

tensor([1.7700])