In [6]:
from typing import Any

import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
import mlflow
from torch.utils.data import DataLoader, Dataset

from src.ml.data.splitting import create_data_splits
from src.ml.modeling import (
    model_factory,
    optimizer_factory,
)
from src.ml.utils.set_seed import set_seed

from datetime import datetime
from loguru import logger
from pathlib import Path
from shutil import copy, copytree, rmtree

import yaml

from src.ml.train_neural_network import train_neural_network
from src.ml.data import data_sets_factory
from src.ml.preprocessing import preprocessing_factory
from src.ml.utils.set_seed import set_seed

In [7]:
CONFIG_FILE = Path("src/ml/config.yaml")

In [8]:
torch.set_default_device(torch.device("cpu"))

In [9]:
# load config file

logger.info("Loading config file.")

with open(CONFIG_FILE, "r") as f:
    config = yaml.safe_load(f)

[32m2024-12-15 12:44:28.436[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mLoading config file.[0m


In [10]:
# load data

logger.info("Loading data.")

_, data_sets = next(iter(data_sets_factory(**config["data_set"])))

# preprocess data

logger.info("Start preprocessing.")

for preprocessing_step in config["preprocessing"]:
    logger.info("Perform {} preprocessing.", preprocessing_step["name"])

    transform = preprocessing_factory(**preprocessing_step)
    data_sets = transform(data_sets)

dataset = data_sets

[32m2024-12-15 12:44:28.727[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mLoading data.[0m
[32m2024-12-15 12:44:40.507[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mStart preprocessing.[0m
[32m2024-12-15 12:44:40.507[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1mPerform add_taxa_names preprocessing.[0m
[32m2024-12-15 12:44:41.414[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1mPerform add_clade_information preprocessing.[0m
[32m2024-12-15 12:44:41.937[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1mPerform remove_tree preprocessing.[0m


In [11]:
train_dataset, val_dataset, test_dataset = create_data_splits(
    dataset, **config["training"]["splitting_config"]
)

In [12]:
train_loader = DataLoader(train_dataset, **config["training"]["dataloader_config"])
test_loader = DataLoader(test_dataset,batch_size=2)
val_loader = DataLoader(val_dataset)

optimizer = optimizer_factory(**config["training"]["optimizer_config"])
model = model_factory(
    optimizer=optimizer,
    input_example=train_dataset[0],
    **config["training"]["model_config"],
)

In [13]:
from src.ml.modeling.conditional_tree_flow import ConditionalTreeFlow

model = ConditionalTreeFlow.load_from_checkpoint("ml_data/models/debug_dummy_2024_12_15_12_40_53/yule-10_140/epoch=19-val_loss=-41.61.ckpt")
model = model.eval()

In [14]:
i = iter(test_loader)

In [15]:
sample = next(i)

In [16]:
model.forward(sample)["z"] 

tensor([[ -1.1223,  -6.1402,  -7.9776,  -9.0680,  -1.7755, -18.2997,  -3.7653,
          -2.1172,  -1.7462],
        [ -1.4162,  -8.1876,  -6.5458,  -4.7786,   0.1151, -13.2932,  -4.6749,
          -3.8043,  -2.5084]], grad_fn=<AddBackward0>)

In [32]:
prior = model.prior.sample((2, 9,))
latent = model.forward(sample)
latent["z"] = prior.clone()
model.inverse(latent)["branch_lengths"]

tensor([[5.8314e+01, 1.0600e+00, 2.4746e-01, 3.0824e+00, 7.3227e-01, 5.5507e-01,
         1.9308e-01, 1.3201e+00, 1.9309e+00],
        [1.1974e+02, 1.2081e+00, 1.8354e-01, 5.3316e+00, 6.5801e+00, 3.4609e+00,
         4.4080e+00, 3.7368e+00, 1.1210e-01]], grad_fn=<ExpBackward0>)

In [14]:
transformed = latent

In [15]:
transformed["z"]

tensor([[  1.0683,   5.3490, 168.6009,   3.1450,   0.3396,   2.7504,   1.9404,
           0.2198,   4.1802]], grad_fn=<ExpBackward0>)

In [16]:
model.flows[-33]

IndexError: index -33 is out of range

In [None]:
transformed["z"] = result["z"]
result = model.flows[-33].inverse(**transformed)

result["z"]

tensor([[2.2004e-01, 8.4510e+02, 3.1984e+01, 0.0000e+00, 5.3590e-03, 3.6059e-01,
         8.8577e+04, 1.3899e+22, 3.2386e-04]], grad_fn=<ExpBackward0>)

In [None]:
output["z"], output["branch_lengths"]

(tensor([[4.5399e+01, 9.2058e-05,        inf, 1.1351e+17, 9.7374e+00, 2.7609e+01,
          3.2553e-15, 5.0594e+25, 1.2663e+03]], grad_fn=<ExpBackward0>),
 tensor([[4.5399e+01, 9.2058e-05,        inf, 1.1351e+17, 9.7374e+00, 2.7609e+01,
          3.2553e-15, 5.0594e+25, 1.2663e+03]], grad_fn=<ExpBackward0>))

In [None]:
from torch import tensor


clade_bitstrings = [
    tensor([3]),
    tensor([12]),
    tensor([19]),
    tensor([31]),
    tensor([96]),
    tensor([224]),
    tensor([255]),
    tensor([768]),
    tensor([1023]),
]

In [None]:
clade_bitstrings = torch.tensor(clade_bitstrings)

In [None]:
import torch

mask = 2**torch.arange(10)

In [None]:
clade_bitstrings.unsqueeze(-1).bitwise_and(mask).ne(0).byte()

tensor([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.uint8)