In [1]:
from typing import Any

import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
import mlflow
from torch.utils.data import DataLoader, Dataset

from src.ml.data.splitting import create_data_splits
from src.ml.modeling import (
    model_factory,
    optimizer_factory,
)
from src.ml.utils.set_seed import set_seed

from datetime import datetime
from loguru import logger
from pathlib import Path
from shutil import copy, copytree, rmtree

import yaml

from src.ml.train_neural_network import train_neural_network
from src.ml.data import data_sets_factory
from src.ml.preprocessing import preprocessing_factory
from src.ml.utils.set_seed import set_seed

In [2]:
CONFIG_FILE = Path("src/ml/config.yaml")

In [3]:
torch.set_default_device(torch.device("cpu"))

In [4]:
# load config file

logger.info("Loading config file.")

with open(CONFIG_FILE, "r") as f:
    config = yaml.safe_load(f)

[32m2024-12-13 11:06:27.253[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mLoading config file.[0m


In [5]:
# load data

logger.info("Loading data.")

data_sets = data_sets_factory(**config["data_set"])
logger.info("Loaded {} data sets.", len(data_sets))

# preprocess data

logger.info("Start preprocessing.")

for preprocessing_step in config["preprocessing"]:
    logger.info("Perform {} preprocessing.", preprocessing_step["name"])

    transform = preprocessing_factory(**preprocessing_step)
    data_sets = [transform(data_set) for data_set in data_sets]

dataset = data_sets[0]

[32m2024-12-13 11:06:28.083[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mLoading data.[0m
[32m2024-12-13 11:06:38.671[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mLoaded 1 data sets.[0m
[32m2024-12-13 11:06:38.672[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m10[0m - [1mStart preprocessing.[0m
[32m2024-12-13 11:06:38.672[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mPerform add_taxa_names preprocessing.[0m
[32m2024-12-13 11:06:39.803[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mPerform add_clade_information preprocessing.[0m
[32m2024-12-13 11:06:42.713[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mPerform remove_tree preprocessing.[0m


In [6]:
train_dataset, val_dataset, test_dataset = create_data_splits(
    dataset, **config["training"]["splitting_config"]
)

In [7]:
train_loader = DataLoader(train_dataset, **config["training"]["dataloader_config"])
test_loader = DataLoader(test_dataset,batch_size=1)
val_loader = DataLoader(val_dataset)

optimizer = optimizer_factory(**config["training"]["optimizer_config"])
model = model_factory(
    optimizer=optimizer,
    dim=len(train_dataset[0]["branch_lengths"]),
    **config["training"]["model_config"],
)

In [13]:
from src.ml.modeling.conditional_tree_flow import ConditionalTreeFlow

model = ConditionalTreeFlow.load_from_checkpoint("ml_data/models/debug_yule_10/epoch=49-val_loss=2628.66.ckpt")
model = model.eval()

In [14]:
sample = next(iter(test_loader))
sample

{'trees_file': ['data/mcmc_runs/yule-10_140.trees'],
 'tree_index': tensor([34217]),
 'taxa_names': [('0',),
  ('8',),
  ('6',),
  ('5',),
  ('7',),
  ('3',),
  ('4',),
  ('9',),
  ('1',),
  ('2',)],
 'clades': [tensor([3]),
  tensor([12]),
  tensor([19]),
  tensor([31]),
  tensor([96]),
  tensor([224]),
  tensor([255]),
  tensor([768]),
  tensor([1023])],
 'branch_lengths': tensor([[0.0105, 0.0028, 0.0049, 0.0012, 0.0004, 0.0042, 0.0075, 0.0053, 0.0362]]),
 'clades_one_hot': tensor([[1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1.]])}

In [15]:
latent = model.forward(sample)
latent

{'z': tensor([[ 0.5680,  0.2705, -0.0076, -0.8350, -0.8668,  0.4092,  0.1815,  0.2427,
          -0.1909]], grad_fn=<MulBackward0>),
 'context': tensor([[1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1.]]),
 'log_dj': tensor([51.6559], grad_fn=<AddBackward0>)}

In [16]:
output = model.inverse(latent)
output, sample

({'branch_lengths': tensor([[0.0105, 0.0028, 0.0049, 0.0012, 0.0004, 0.0042, 0.0075, 0.0053, 0.0362]],
         grad_fn=<ExpBackward0>),
  'clades_one_hot': tensor([[1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
           0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 1.]]),
  'log_dj': tensor([51.6559], grad_fn=<AddBackward0>)},
 {'trees_file': ['data/mcmc_runs/yule-10_140.trees'],
  'tree_index': tensor([34217]),
  'taxa_names': [('0',),
   ('8',),
   ('6',),
   ('5',),
   ('7',),
   ('3',),
   ('4',),
   ('9',),
   ('1',),
   ('2',)],
  'clades': [tensor([3]),
   tensor([12]),
   tensor([19]),
   tensor([31]),
   tensor([96]),
   tensor([224]),
   tensor([255]),
   tensor([768]),
   tensor([1023])],
  'branch_lengths': tensor([[0.0105, 0.0028, 0.0049, 0.0012, 0.0004, 0.0042, 0.0075, 0.0053, 0.0362]]),
  'cla

In [28]:
prior = model.prior.sample((9,))
latent["z"] = prior.clone()
output = model.inverse(latent)
output

{'branch_lengths': tensor([0.0107, 0.0007, 0.0020, 0.0016, 0.0017, 0.0011, 0.0045, 0.0052, 0.0260],
        grad_fn=<ExpBackward0>),
 'clades_one_hot': tensor([[1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1.]]),
 'log_dj': tensor([51.6559], grad_fn=<AddBackward0>)}