<a href="https://colab.research.google.com/github/reyhanehtt/B-and-Non-B-DNA/blob/main/HMM/Pyro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pyro-ppl torch

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam




Collecting pyro-ppl
  Downloading pyro_ppl-1.9.1-py3-none-any.whl.metadata (7.8 kB)
Collecting pyro-api>=0.1.1 (from pyro-ppl)
  Downloading pyro_api-0.1.2-py3-none-any.whl.metadata (2.5 kB)
Downloading pyro_ppl-1.9.1-py3-none-any.whl (755 kB)
Downloading pyro_api-0.1.2-py3-none-any.whl (11 kB)
Installing collected packages: pyro-api, pyro-ppl
Successfully installed pyro-api-0.1.2 pyro-ppl-1.9.1


In [None]:
def hmm_model(data, num_states):
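    # Learnable transition and emission parameters (pyro.param point estimates, not priors)
    # Each row of the initial transition matrix is uniform so that it lies on the simplex
    transition_probs = pyro.param("transition_probs_param", torch.ones(num_states, num_states) / num_states, constraint=dist.constraints.simplex)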
    emission_means = pyro.param("means_param", torch.zeros(num_states))  # Initialize means
    emission_stds = pyro.param("stds_param", torch.ones(num_states), constraint=dist.constraints.positive)  # Initialize stds

    # Define the initial state distribution
    initial_probs = pyro.param("initial_probs_param", torch.ones(num_states) / num_states, constraint=dist.constraints.simplex)

    # Sample the initial state
    state = pyro.sample("state_0", dist.Categorical(initial_probs))

    # Loop through the observed data
    for t, observation in enumerate(data):
        # Sample the current state given the previous state
        state = pyro.sample(f"state_{t+1}", dist.Categorical(transition_probs[state]))

        # Sample the observation given the current state
        pyro.sample(f"obs_{t+1}", dist.Normal(emission_means[state], emission_stds[state]), obs=observation)


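transition_probs: This is a num_states × num_states matrix (2×2 here) in which row i holds the probabilities of moving from state i to each state, so every row sums to 1. It is declared with pyro.param under a simplex constraint, which makes it a learnable point estimate; no Dirichlet prior is placed on it.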

initial_probs: This is the distribution over the initial hidden state (state_0). It starts out uniform, so the chain is equally likely to begin in B-DNA or Non-B-DNA.

emission_means and emission_stds: These hold the mean and standard deviation of the Gaussian emission distribution for each state. Each observation (DNA feature) is drawn from the Gaussian of the state that is active at that position.

state: This tracks the current hidden state (B-DNA or Non-B-DNA). The state at time t is sampled from the row of transition_probs selected by the state at time t-1.
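
The generative process described above can be sketched without Pyro. This is a minimal illustration using only the standard library; `sample_hmm` and the parameter values passed to it are hypothetical and not part of the notebook.

```python
import random

def sample_hmm(initial_probs, transition_probs, means, stds, length, seed=0):
    """Draw (states, observations) from a Gaussian-emission HMM."""
    rng = random.Random(seed)
    num_states = len(initial_probs)
    # state_0 comes from the initial distribution and emits nothing,
    # mirroring the structure of hmm_model
    state = rng.choices(range(num_states), weights=initial_probs)[0]
    states, observations = [], []
    for _ in range(length):
        # Row `state` of the transition matrix is the next-state distribution
        state = rng.choices(range(num_states), weights=transition_probs[state])[0]
        states.append(state)
        # The active state selects which Gaussian generates the observation
        observations.append(rng.gauss(means[state], stds[state]))
    return states, observations

# Illustrative values only: two well-separated states that tend to persist
states, observations = sample_hmm(
    initial_probs=[0.5, 0.5],
    transition_probs=[[0.9, 0.1], [0.2, 0.8]],
    means=[0.0, 2.0],
    stds=[0.5, 0.5],
    length=9,
)
print(states)
print([round(x, 2) for x in observations])
```

Fixing the seed makes the draw reproducible, which helps when comparing a fitted model against data whose true states are known.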

In [None]:
def hmm_guide(data, num_states):
    # The guide has no pyro.sample sites; it only registers the same learnable parameters as the model
    # Transition probabilities: each row starts uniform so that it lies on the simplex
    pyro.param("transition_probs_param", torch.ones(num_states, num_states) / num_states, constraint=dist.constraints.simplex)

    # Initial state probabilities, starting uniform
    pyro.param("initial_probs_param", torch.ones(num_states) / num_states, constraint=dist.constraints.simplex)

    # Learnable parameters for the Gaussian means and std deviations for each state
    pyro.param("means_param", torch.zeros(num_states))
    pyro.param("stds_param", torch.ones(num_states), constraint=dist.constraints.positive)


transition_probs_param: This is a learnable parameter that represents the transition matrix.

initial_probs_param: A learnable parameter for the initial state probabilities.

means_param and stds_param: These are learnable parameters for the Gaussian emission distributions (mean and standard deviation) for the B-DNA and Non-B-DNA states.

Now, we’ll define the optimizer and the SVI (Stochastic Variational Inference) object that will be used to optimize the parameters in the guide.

In [None]:
# Create the optimizer
optimizer = Adam({"lr": 0.01})

# Create the SVI object
svi = SVI(hmm_model, hmm_guide, optimizer, loss=Trace_ELBO())


Adam: We use the Adam optimizer to update the variational parameters.

SVI: This is the Pyro object that performs variational inference. It requires the model, the guide, the optimizer, and the loss function (Trace_ELBO).
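
The loss that Trace_ELBO reports estimates the negative ELBO, which is an upper bound on the negative log marginal likelihood of the data. For an HMM with a small number of discrete states, that marginal likelihood can be computed exactly with the forward algorithm, which gives a reference value for the loss. The sketch below is standard-library Python; the helper names are hypothetical, and it assumes all probabilities are strictly positive.

```python
import math

def log_normal_pdf(x, mean, std):
    """Log density of a Gaussian at x."""
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2

def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a list of finite values."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def hmm_log_likelihood(data, initial_probs, transition_probs, means, stds):
    """Exact log p(data), summing over every hidden state path."""
    k = len(initial_probs)
    # log_alpha[j] = log p(observations so far, current state = j);
    # it starts at state_0, which emits nothing (as in hmm_model)
    log_alpha = [math.log(p) for p in initial_probs]
    for x in data:
        log_alpha = [
            logsumexp([log_alpha[i] + math.log(transition_probs[i][j]) for i in range(k)])
            + log_normal_pdf(x, means[j], stds[j])
            for j in range(k)
        ]
    return logsumexp(log_alpha)

# The notebook's example data, evaluated at uniform probabilities, zero means, unit stds
data = [2.1, 1.5, -0.2, 0.3, 1.8, 1.7, -1.1, 2.2, 0.1]
log_lik = hmm_log_likelihood(
    data,
    initial_probs=[0.5, 0.5],
    transition_probs=[[0.5, 0.5], [0.5, 0.5]],
    means=[0.0, 0.0],
    stds=[1.0, 1.0],
)
print("negative log-likelihood:", -log_lik)
```

Pyro can perform the same summation inside the loss through enumeration with TraceEnum_ELBO, which avoids sampling the hidden states.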

In [None]:
def train_hmm(data, num_states, num_steps=1000):
    pyro.clear_param_store()

    for step in range(num_steps):
        # Perform a step of optimization
        loss = svi.step(data, num_states)

        # Print the loss every 50 steps
        if step % 50 == 0:
            print(f"Step {step}: Loss = {loss}")

# Define the number of hidden states (e.g., 2 states for B-DNA and Non-B-DNA)
num_states = 2

# Example data: a 1D array of observed features (this should represent your DNA feature data)
data = torch.tensor([2.1, 1.5, -0.2, 0.3, 1.8, 1.7, -1.1, 2.2, 0.1])

# Train the model on the data
train_hmm(data, num_states)




Step 0: Loss = 24.691918432712555
Step 50: Loss = 21.330583810806274
Step 100: Loss = 20.97357738018036
Step 150: Loss = 20.649157464504242
Step 200: Loss = 20.672191560268402
Step 250: Loss = 20.545710861682892
Step 300: Loss = 20.166686832904816
Step 350: Loss = 20.137859106063843
Step 400: Loss = 21.817894160747528
Step 450: Loss = 21.11650174856186
Step 500: Loss = 19.975821137428284
Step 550: Loss = 20.256884574890137
Step 600: Loss = 21.24632740020752
Step 650: Loss = 20.992319494485855
Step 700: Loss = 20.927818477153778
Step 750: Loss = 20.048939287662506
Step 800: Loss = 21.273469001054764
Step 850: Loss = 20.16432449221611
Step 900: Loss = 19.971367090940475
Step 950: Loss = 20.89401423931122


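svi.step(data, num_states): This function performs one step of optimization. Its arguments are passed to both hmm_model and hmm_guide, the parameters are updated, and the return value is an estimate of the loss (the negative ELBO). Because the hidden states are sampled at every step, this estimate is noisy, which is why the printed values fluctuate instead of decreasing steadily.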

Training Loop: The loop runs for a number of iterations (num_steps), printing the loss every 50 steps to monitor the training process.
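
Since individual loss values are noisy, a running average can make the trend easier to read. The sketch below applies a simple trailing average to the losses printed above, rounded to two decimals; `moving_average` and the window size are illustrative choices, not part of the notebook.

```python
def moving_average(values, window):
    """Average each value with up to `window - 1` values that precede it."""
    smoothed = []
    for i in range(len(values)):
        chunk = values[max(0, i - window + 1): i + 1]
        smoothed.append(sum(chunk) / len(chunk))
    return smoothed

# Losses printed above at steps 0, 50, ..., 950 (rounded)
losses = [24.69, 21.33, 20.97, 20.65, 20.67, 20.55, 20.17, 20.14, 21.82, 21.12,
          19.98, 20.26, 21.25, 20.99, 20.93, 20.05, 21.27, 20.16, 19.97, 20.89]

for step, value in zip(range(0, 1000, 50), moving_average(losses, window=5)):
    print(f"Step {step}: smoothed loss = {value:.2f}")
```

In a real run it may be more useful to collect the loss at every step inside train_hmm and smooth that full series.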

In [None]:
# Print the learned transition probabilities
print("Learned transition probabilities:", pyro.param("transition_probs_param").data)

# Print the learned initial probabilities
print("Learned initial state probabilities:", pyro.param("initial_probs_param").data)

# Print the learned Gaussian means
print("Learned means:", pyro.param("means_param").data)

# Print the learned Gaussian standard deviations
print("Learned standard deviations:", pyro.param("stds_param").data)


Learned transition probabilities: tensor([[0.4928, 0.5072],
        [0.6465, 0.3535]])
Learned initial state probabilities: tensor([0.7535, 0.2465])
Learned means: tensor([0.9213, 0.9440])
Learned standard deviations: tensor([1.0812, 1.1518])
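
Once parameters are available, each position can be labeled with its most likely hidden state using the Viterbi algorithm. The sketch below is standard-library Python using the learned values printed above and the example data. The `viterbi` helper is hypothetical, and which index corresponds to B-DNA or Non-B-DNA is not determined by the model. Note that the two learned means are nearly identical in this run, so the decoded labels carry little information here.

```python
import math

def log_normal_pdf(x, mean, std):
    """Log density of a Gaussian at x."""
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2

def viterbi(data, initial_probs, transition_probs, means, stds):
    """Most likely sequence of emitting states, one per observation."""
    k = len(initial_probs)
    # score[j] = best log-probability of any path ending in state j;
    # it starts at state_0, which emits nothing (as in hmm_model)
    score = [math.log(p) for p in initial_probs]
    backpointers = []
    for x in data:
        new_score, pointers = [], []
        for j in range(k):
            # Best predecessor for landing in state j at this position
            best_i = max(range(k), key=lambda i: score[i] + math.log(transition_probs[i][j]))
            pointers.append(best_i)
            new_score.append(
                score[best_i]
                + math.log(transition_probs[best_i][j])
                + log_normal_pdf(x, means[j], stds[j])
            )
        score = new_score
        backpointers.append(pointers)
    # Trace back from the best final state; backpointers[0] points at the
    # non-emitting state_0, so it is skipped
    state = max(range(k), key=lambda j: score[j])
    path = [state]
    for pointers in reversed(backpointers[1:]):
        state = pointers[state]
        path.append(state)
    path.reverse()
    return path

data = [2.1, 1.5, -0.2, 0.3, 1.8, 1.7, -1.1, 2.2, 0.1]
path = viterbi(
    data,
    initial_probs=[0.7535, 0.2465],
    transition_probs=[[0.4928, 0.5072], [0.6465, 0.3535]],
    means=[0.9213, 0.9440],
    stds=[1.0812, 1.1518],
)
print("Most likely state sequence:", path)
```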
