In [1]:
from teddy.data.dataset import MsaLabels
from teddy.lightning.datamodule import BDS_datamodule
import teddy.data.Alphabet as alphabet
from copy import deepcopy
from typing import Callable

import torch
from torch import eye, ones
from torch.optim import Adam, AdamW

from sbi.analysis import pairplot
from sbi.utils import BoxUniform


alphabet_instance = alphabet.Alphabet(list( "ATGCX-"))

train_ratio = 0.8
batch_size = 16
val_batch_size = 16

msa = MsaLabels(dir = "data/example/seq", alphabet=alphabet_instance, limit_size=200)
data = BDS_datamodule(data_dir = "data/example/seq", 
                      alphabet=alphabet_instance, 
                      train_ratio=train_ratio, 
                      val_batch_size=val_batch_size, 
                      batch_size=batch_size)

data.setup()

# Setting up dataloaders

In [2]:
from sbi.neural_nets.net_builders import build_nsf
import torch
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np

# Density estimator with first sequence and prior just for the dimensionality
torch.manual_seed(0)

# Define training params
learning_rate = 5e-4
validation_fraction = 0.1  # 10% of the data will be used for validation
stop_after_epochs = 20  # Stop training after 20 epochs with no improvement
max_num_epochs = 2**31 - 1

train_loader = data.train_dataloader()
val_loader = data.val_dataloader()

train_iter = iter(train_loader)
batch = next(train_iter)

x_batch = torch.flatten(torch.Tensor(np.array(batch[0][0])), start_dim=1)
theta_batch = torch.Tensor(np.array(batch[1]))

density_estimator = build_nsf(theta_batch, x_batch) # theta batch dimension, x batch dimension

  x_batch = torch.flatten(torch.Tensor(np.array(batch[0][0])), start_dim=1)
  theta_batch = torch.Tensor(np.array(batch[1]))


# Training Loop 

In [None]:
optimizer = Adam(list(density_estimator.parameters()), lr=learning_rate)
epoch, val_loss = 0, float("Inf")

best_val_loss = float("Inf")

converged = False

while epoch <= max_num_epochs and not converged:
    # Train for a single epoch.
    density_estimator.train()
    train_loss_sum = 0
    for batch in train_loader:
        optimizer.zero_grad()

        x_batch = torch.flatten(torch.Tensor(np.array(batch[0][0])), start_dim=1)
        theta_batch = torch.Tensor(np.array(batch[1]))

        train_losses = density_estimator.loss(theta_batch, x_batch)
        train_loss = torch.mean(train_losses)
        train_loss_sum += train_losses.sum().item()

        train_loss.backward()

        optimizer.step()

    epoch += 1

    train_loss_average = train_loss_sum / (
        len(train_loader) * train_loader.batch_size
    )

    # Calculate validation performance.
    density_estimator.eval()
    val_loss_sum = 0

    with torch.no_grad():
        for batch in val_loader:
            x_batch = torch.flatten(torch.Tensor(np.array(batch[0][0])), start_dim=1)
            theta_batch = torch.Tensor(np.array(batch[1]))
            val_losses = density_estimator.loss(
                theta_batch,
                x_batch,
            )
            val_loss_sum += val_losses.sum().item()

    # Take mean over all validation samples.
    val_loss = val_loss_sum / (len(val_loader) * val_loader.batch_size)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_since_last_improvement = 0
        best_model_state_dict = deepcopy(density_estimator.state_dict())
    else:
        epochs_since_last_improvement += 1

    # If no validation improvement over many epochs, stop training.
    if epochs_since_last_improvement > stop_after_epochs - 1:
        density_estimator.load_state_dict(best_model_state_dict)
        converged = True
        print(f'Neural network successfully converged after {epoch} epochs')

  x_batch = torch.flatten(torch.Tensor(np.array(batch[0][0])), start_dim=1)
  theta_batch = torch.Tensor(np.array(batch[1]))
  x_batch = torch.flatten(torch.Tensor(np.array(batch[0][0])), start_dim=1)
  theta_batch = torch.Tensor(np.array(batch[1]))


Neural network successfully converged after 25 epochs


: 

In [None]:
train_iter = iter(train_loader)
batch = next(train_iter)

x_o = torch.flatten(torch.Tensor(np.array(batch[0][0])), start_dim=1)

print(f"Shape of x_o: {x_o.shape}            # Must have a batch dimension")

samples = density_estimator.sample((1000,), condition=x_o).detach()
print(
    f"Shape of samples: {samples.shape}  # Samples are returned with a batch dimension."
)

samples = samples.squeeze(dim=1)
print(f"Shape of samples: {samples.shape}     # Removed batch dimension.")

  x_o = torch.flatten(torch.Tensor(np.array(batch[0][0])), start_dim=1)


Shape of x_o: torch.Size([16, 151702])            # Must have a batch dimension
