In [1]:
from typing import Any

import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
import mlflow
from torch.utils.data import DataLoader, Dataset

from src.ml.data.splitting import create_data_splits
from src.ml.modeling import (
    model_factory,
    optimizer_factory,
)
from src.ml.utils.set_seed import set_seed

from datetime import datetime
from loguru import logger
from pathlib import Path
from shutil import copy, copytree, rmtree

import yaml

from src.ml.train_neural_network import train_neural_network
from src.ml.data import data_sets_factory
from src.ml.preprocessing import preprocessing_factory
from src.ml.utils.set_seed import set_seed

In [2]:
CONFIG_FILE = Path("src/ml/config.yaml")

In [3]:
torch.set_default_device(torch.device("cpu"))

In [4]:
# load config file

logger.info("Loading config file.")

with open(CONFIG_FILE, "r") as f:
    config = yaml.safe_load(f)

[32m2024-12-13 00:46:52.241[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mLoading config file.[0m


In [5]:
# load data

logger.info("Loading data.")

data_sets = data_sets_factory(**config["data_set"])
logger.info("Loaded {} data sets.", len(data_sets))

# preprocess data

logger.info("Start preprocessing.")

for preprocessing_step in config["preprocessing"]:
    logger.info("Perform {} preprocessing.", preprocessing_step["name"])

    transform = preprocessing_factory(**preprocessing_step)
    data_sets = [transform(data_set) for data_set in data_sets]

dataset = data_sets[0]

[32m2024-12-13 00:46:52.245[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mLoading data.[0m
[32m2024-12-13 00:46:52.246[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mLoaded 1 data sets.[0m
[32m2024-12-13 00:46:52.247[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m10[0m - [1mStart preprocessing.[0m


In [6]:
train_dataset, val_dataset, test_dataset = create_data_splits(
    dataset, **config["training"]["splitting_config"]
)

In [7]:
train_loader = DataLoader(train_dataset, **config["training"]["dataloader_config"])
test_loader = DataLoader(test_dataset,batch_size=10)
val_loader = DataLoader(val_dataset)

optimizer = optimizer_factory(**config["training"]["optimizer_config"])
model = model_factory(
    optimizer=optimizer,
    dim=len(train_dataset[0]["branch_lengths"]),
    context_dim=len(train_dataset[0]["clades_one_hot"]),
    **config["training"]["model_config"],
)

In [8]:
from src.ml.modeling.conditional_tree_flow import ConditionalTreeFlow

model = ConditionalTreeFlow.load_from_checkpoint("ml_data/models/debug_simple_conditional_flow/epoch=19-val_loss=2300.96.ckpt")

In [9]:
sample = next(iter(test_loader))
latent = model.forward(sample)
output = model.inverse(latent)
sample, output

({'branch_lengths': tensor([[0.0519, 0.0486, 0.0490, 0.0489, 0.0498, 0.0497, 0.0502, 0.0517, 0.0495,
           0.0491],
          [0.0517, 0.0503, 0.0498, 0.0511, 0.0515, 0.0513, 0.0513, 0.0510, 0.0496,
           0.0488],
          [0.0492, 0.0503, 0.0492, 0.0509, 0.0517, 0.0494, 0.0508, 0.0507, 0.0504,
           0.0502],
          [0.0499, 0.0503, 0.0495, 0.0500, 0.0488, 0.0502, 0.0478, 0.0486, 0.0511,
           0.0499],
          [0.0492, 0.0489, 0.0484, 0.0508, 0.0504, 0.0518, 0.0516, 0.0516, 0.0517,
           0.0489],
          [0.0508, 0.0507, 0.0511, 0.0509, 0.0494, 0.0507, 0.0505, 0.0508, 0.0517,
           0.0492],
          [0.0493, 0.0510, 0.0499, 0.0492, 0.0498, 0.0492, 0.0497, 0.0493, 0.0499,
           0.0506],
          [0.0485, 0.0502, 0.0488, 0.0515, 0.0489, 0.0491, 0.0478, 0.0508, 0.0495,
           0.0485],
          [0.0510, 0.0502, 0.0507, 0.0511, 0.0515, 0.0502, 0.0494, 0.0499, 0.0500,
           0.0484],
          [0.0515, 0.0481, 0.0488, 0.0511, 0.0500, 0.05

In [10]:
sample = next(iter(test_loader))
sample

{'branch_lengths': tensor([[0.0509, 0.0490, 0.0498, 0.0499, 0.0494, 0.0505, 0.0491, 0.0493, 0.0522,
          0.0506],
         [0.0493, 0.0504, 0.0490, 0.0503, 0.0511, 0.0516, 0.0518, 0.0486, 0.0521,
          0.0492],
         [0.0484, 0.0479, 0.0492, 0.0507, 0.0497, 0.0505, 0.0489, 0.0499, 0.0500,
          0.0496],
         [0.0515, 0.0506, 0.0509, 0.0495, 0.0500, 0.0498, 0.0515, 0.0487, 0.0525,
          0.0530],
         [0.0511, 0.0496, 0.0491, 0.0508, 0.0510, 0.0492, 0.0493, 0.0511, 0.0494,
          0.0488],
         [0.0489, 0.0492, 0.0522, 0.0500, 0.0503, 0.0501, 0.0506, 0.0501, 0.0502,
          0.0494],
         [0.0492, 0.0498, 0.0490, 0.0488, 0.0495, 0.0508, 0.0489, 0.0509, 0.0516,
          0.0518],
         [0.0485, 0.0477, 0.0506, 0.0502, 0.0498, 0.0506, 0.0518, 0.0487, 0.0510,
          0.0514],
         [0.0516, 0.0505, 0.0506, 0.0487, 0.0501, 0.0496, 0.0493, 0.0506, 0.0484,
          0.0504],
         [0.0489, 0.0524, 0.0512, 0.0492, 0.0497, 0.0490, 0.0498, 0.0494,

In [11]:
latent = model.forward(sample)
latent

{'z': tensor([[-1.3902e-01, -2.6560e-03, -7.9882e-03,  4.9895e-02, -1.1855e-03,
          -4.6577e-02,  1.8443e-01,  5.9022e-03,  5.2162e-02, -1.0939e-01],
         [-1.6355e-01,  1.4747e-03, -1.1370e-02,  5.0322e-02,  1.1467e-02,
          -6.1454e-02,  2.2058e-01,  2.7310e-02,  5.2133e-02, -1.2408e-01],
         [-1.3487e-01, -3.3168e-02, -5.0880e-03,  5.0673e-02,  3.1648e-03,
          -4.2069e-02,  2.3836e-01,  2.1490e-02,  5.0024e-02, -1.3239e-01],
         [-1.2149e-01, -8.6101e-03, -4.7958e-03,  4.9473e-02, -4.1558e-03,
          -5.1250e-02,  1.5665e-01,  1.7641e-02,  5.2458e-02, -1.0876e-01],
         [-1.4606e-01, -1.7655e-02, -6.2293e-03,  5.0765e-02,  2.0278e-03,
          -8.5442e-02,  1.7926e-01,  2.9374e-02,  4.9440e-02, -1.2760e-01],
         [-1.5798e-01, -3.2042e-03, -8.6136e-03,  5.0026e-02, -2.9738e-03,
          -4.6176e-02,  2.0000e-01,  1.0446e-02,  5.0161e-02, -1.2808e-01],
         [-1.7710e-01, -1.5226e-02, -9.3870e-03,  4.8841e-02,  4.7702e-03,
          -5.6

In [12]:
torch.randn(1, 10)

tensor([[ 1.2818, -0.4083,  1.3057,  1.3644,  0.0135, -0.8820,  0.3757, -0.0398,
          1.8833,  0.0425]])

In [13]:
latent["z"] = torch.randn(1, 10)
output = model.inverse(latent)
output

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 10 for tensor number 1 in the list.