In [32]:
from typing import Any

import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
import mlflow
from torch.utils.data import DataLoader, Dataset

from src.ml.data.splitting import create_data_splits
from src.ml.modeling import (
    model_factory,
    optimizer_factory,
)
from src.ml.utils.set_seed import set_seed

from datetime import datetime
from loguru import logger
from pathlib import Path
from shutil import copy, copytree, rmtree

import yaml

from src.ml.train_neural_network import train_neural_network
from src.ml.data import data_sets_factory
from src.ml.preprocessing import preprocessing_factory
from src.ml.utils.set_seed import set_seed

In [33]:
CONFIG_FILE = Path("src/ml/config.yaml")

In [34]:
torch.set_default_device(torch.device("cpu"))

In [35]:
# load config file

logger.info("Loading config file.")

with open(CONFIG_FILE, "r") as f:
    config = yaml.safe_load(f)

[32m2024-12-13 01:07:38.556[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mLoading config file.[0m


In [36]:
# load data

logger.info("Loading data.")

data_sets = data_sets_factory(**config["data_set"])
logger.info("Loaded {} data sets.", len(data_sets))

# preprocess data

logger.info("Start preprocessing.")

for preprocessing_step in config["preprocessing"]:
    logger.info("Perform {} preprocessing.", preprocessing_step["name"])

    transform = preprocessing_factory(**preprocessing_step)
    data_sets = [transform(data_set) for data_set in data_sets]

dataset = data_sets[0]

[32m2024-12-13 01:07:38.564[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mLoading data.[0m
[32m2024-12-13 01:07:38.564[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mLoaded 1 data sets.[0m
[32m2024-12-13 01:07:38.565[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m10[0m - [1mStart preprocessing.[0m


In [37]:
train_dataset, val_dataset, test_dataset = create_data_splits(
    dataset, **config["training"]["splitting_config"]
)

In [38]:
train_loader = DataLoader(train_dataset, **config["training"]["dataloader_config"])
test_loader = DataLoader(test_dataset,batch_size=1)
val_loader = DataLoader(val_dataset)

optimizer = optimizer_factory(**config["training"]["optimizer_config"])
model = model_factory(
    optimizer=optimizer,
    dim=len(train_dataset[0]["branch_lengths"]),
    **config["training"]["model_config"],
)

In [39]:
from src.ml.modeling.conditional_tree_flow import ConditionalTreeFlow

model = ConditionalTreeFlow.load_from_checkpoint("ml_data/models/debug_simple_conditional_flow/epoch=19-val_loss=507080.16.ckpt")

In [40]:
sample = next(iter(test_loader))
latent = model.forward(sample)
output = model.inverse(latent)
sample, output

({'branch_lengths': tensor([[ 9.8746e-04,  2.8777e-05, -1.5181e-04,  1.8398e-03,  1.2263e-03,
           -1.7070e-03,  3.0562e-04,  1.3443e-05, -2.9067e-04,  6.3688e-04]]),
  'clades_one_hot': tensor([[0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 0., 1.,
           0., 1.]])},
 {'branch_lengths': tensor([[ 9.8746e-04,  2.8777e-05, -1.5181e-04,  1.8398e-03,  1.2263e-03,
           -1.7070e-03,  3.0562e-04,  1.3443e-05, -2.9067e-04,  6.3689e-04]],
         grad_fn=<SubBackward0>),
  'log_dj': tensor([0.5187], grad_fn=<AddBackward0>)})

In [41]:
sample = next(iter(test_loader))
sample

{'branch_lengths': tensor([[-0.0008, -0.0003, -0.0007, -0.0001,  0.0016, -0.0006,  0.0007,  0.0010,
           0.0011,  0.0009]]),
 'clades_one_hot': tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1.,
          0., 0.]])}

In [42]:
latent = model.forward(sample)
latent

{'z': tensor([[-0.0008, -0.0003, -0.0007,  0.2210,  0.0016, -0.0006,  0.0163,  0.0010,
           0.0011, -0.1399]], grad_fn=<MulBackward0>),
 'log_dj': tensor([0.5175], grad_fn=<AddBackward0>)}

In [46]:
latent["z"] = torch.distributions.normal.Normal(loc=0.0, scale=1.0).sample((10,))
output = model.inverse(latent)
output

{'branch_lengths': tensor([-1.2467,  0.8291,  0.2576, -1.9265, -0.9135,  1.1018, -0.0609, -0.3466,
          2.3803,  0.4958], grad_fn=<SubBackward0>),
 'log_dj': tensor([0.5175], grad_fn=<AddBackward0>)}