# Tests

## Imports and setup

In [1]:
import logging
import sys
import tempfile
from copy import copy
from pathlib import Path

def append_path(s):
    """Add ``s`` to ``sys.path`` unless it is already there (safe to call repeatedly)."""
    if s not in sys.path:
        sys.path.append(s)


# Make the in-repo ``ml_lib`` package (which lives in ``..``) importable from here.
append_path("..")
#%load_ext autoreload
#%autoreload 2

In [2]:
import numpy as np
from numpy.random import default_rng
rng = default_rng()
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import networkx as nx
import torch
from torch import Tensor


In [3]:
import ml_lib
from ml_lib import feature_specification

In [4]:
ls ../ml_lib/datasets

base_classes.py  feature_specification.py  registration.py
datapoint.py     __init__.py               splitting.py
[0m[01;34mdatasets[0m/        [01;34m__pycache__[0m/              transforms.py


## Datasets

In [5]:
from ml_lib.datasets import register as dataset_register, transform_register, load_transform, load_dataset, Transform
from ml_lib.datasets.transforms import MultipleFunctionTransform

In [6]:
dataset0 = dataset_register["Torus4D"](1000)

In [7]:
transform = transform_register["CacheTransform"]() 
dataset1 = transform(dataset0)

In [8]:
dataset1._inner

<ml_lib.datasets.datasets.simple_shapes.Torus4D at 0x7bc815b12690>

In [9]:
dataset1[0]

tensor([[ 0.4502, -0.8844, -0.7997, -0.5998]])

In [10]:
dataset2 = transform_register["RenameTransform"]({"x": "_"})(dataset1)
dataset2[0]

<ml_lib.datasets.datapoint.DictDatapoint at 0x7bc816f8ef50>

In [11]:
dataset3 = transform_register["RenameTransform"]({"x": "x", "y":"x"})(dataset2)
dataset3[0]

<ml_lib.datasets.datapoint.DictDatapoint at 0x7bc815234310>

In [12]:
def normalize(x):
    """Scale ``x`` to unit norm by dividing it by ``x.norm()``.

    Written as a ``def`` rather than a lambda bound to a name (PEP 8), so it
    has a real ``__name__`` in reprs and tracebacks. Note: an all-zero input
    divides by zero (presumably never produced by the torus dataset).
    """
    return x / x.norm()
# output key -> (function, input key); presumably "gt" is the supervision target.
function_specification = {
    "x": (normalize, "x"),
    "gt": (lambda x: x, "y"),
}
make_transform = transform_register["MultipleFunctionTransform"]
dataset4 = make_transform(function_specification)(dataset3)
dataset4[0]

<ml_lib.datasets.datapoint.DictDatapoint at 0x7bc8162e5510>

In [13]:
@transform_register
class TestTransform(MultipleFunctionTransform):
    """Registered transform: normalised ``x`` as "x", ``y`` passed through as "gt"."""

    def __init__(self):
        functions = {
            "x": (normalize, "x"),
            "gt": (lambda value: value, "y"),
        }
        super().__init__(functions)

Great: subclassing `MultipleFunctionTransform` and registering it with `@transform_register` works.

## Environments

In [14]:
from ml_lib.environment import Environment, Scope, scopevar_of_str, str_of_scopevar, HierarchicEnvironment, ScopedEnvironment

In [15]:
scopevar_of_str("a/b/c")

(('a', 'b'), 'c')

In [16]:
str_of_scopevar(('a', 'b'), 'c')

'a/b/c'

In [17]:
env = Environment()

In [18]:
env

Environment(defaultdict(<class 'dict'>, {'_ipython_canary_method_should_not_exist_': {}, '_ipython_display_': {}, '_repr_mimebundle_': {}}))

In [19]:
env.record("hello", 1)

In [20]:
env.data

defaultdict(dict,
            {'_ipython_canary_method_should_not_exist_': {},
             '_ipython_display_': {},
             '_repr_mimebundle_': {},
             '_repr_html_': {},
             '_repr_markdown_': {},
             '_repr_svg_': {},
             '_repr_png_': {},
             '_repr_pdf_': {},
             '_repr_jpeg_': {},
             '_repr_latex_': {},
             '_repr_json_': {},
             '_repr_javascript_': {},
             'hello': {(): 1}})

In [21]:
env.record("world", 4, ("some", "scope"))
env.data

defaultdict(dict,
            {'_ipython_canary_method_should_not_exist_': {},
             '_ipython_display_': {},
             '_repr_mimebundle_': {},
             '_repr_html_': {},
             '_repr_markdown_': {},
             '_repr_svg_': {},
             '_repr_png_': {},
             '_repr_pdf_': {},
             '_repr_jpeg_': {},
             '_repr_latex_': {},
             '_repr_json_': {},
             '_repr_javascript_': {},
             'hello': {(): 1},
             'world': {('some', 'scope'): 4}})

In [22]:
env.get("hello")

1

In [23]:
env.get("world")

4

In [24]:
#import pdb; pdb.set_trace()
env.get("world", scope=("some",))

4

In [25]:
env.get("world", scope=("soe",))

In [26]:
hier_env = HierarchicEnvironment(parent=env)

In [27]:
hier_env.record("world", 2, ("some",))
hier_env.get("world", scope=("some",))


2

In [28]:
hier_env.get("world", scope=("some","scope"))


4

In [29]:
def f(world, hello=6, magic=3):
    """Toy target for ``run_function``: print the three arguments, then return 1.

    In this notebook ``hier_env.run_function(f)`` prints ``2 1 3``: ``world`` and
    ``hello`` come from the environment (presumably looked up by parameter name)
    and ``magic`` keeps its default.
    """
    values = (world, hello, magic)
    print(*values)
    return 1

In [30]:
hier_env.run_function(f)

2 1 3


1

## Model

In [31]:
from ml_lib.datasets.feature_specification import FeatureSpecification, MSEFeature

feature_spec = FeatureSpecification([MSEFeature("location", 4)])
# Serialising to a config and back must give an equal specification.
round_tripped_spec = feature_spec.from_config(feature_spec.to_config())
assert feature_spec == round_tripped_spec
feature_spec

FeatureSpecification(MSEFeature(name='location', dim=4, loss_coef=1.0))

In [33]:
from ml_lib.models import Model, Supervised, Hyperparameter, register as model_register
from ml_lib.models.layers import MLP

@model_register
class SimpleMLPModel(Supervised):
    """Minimal supervised model: an MLP trained with the feature specification's loss."""

    # Hyperparameters, declared through annotations and passed as constructor kwargs.
    dimensions: Hyperparameter[list[int]]
    feature_specification: Hyperparameter[FeatureSpecification]

    # Submodule, built in ``__setup__``.
    inner: MLP

    def __setup__(self):
        input_dim = self.feature_specification.dim
        self.inner = MLP(input_dim, *self.dimensions, batchnorm=False)

    def forward(self, x):
        prediction = self.inner(x)
        return prediction

    def loss_fun(self, x, gt):
        # NOTE(review): ``x`` is presumably the model's prediction here; confirm in Supervised.
        return self.feature_specification.compute_loss(x, gt, reduce=True)
        

# NOTE(review): in the recorded run this cell raised
# `TypeError: isinstance() arg 2 must be a type, a tuple of types, or a union`
# inside `ml_lib.misc.typing.advanced_type_check`, whose last line is
# `isinstance(value, origin)`. Presumably `origin` is not a class for one of the
# hyperparameter annotations (e.g. `Hyperparameter[list[int]]`). TODO fix it there;
# every later cell that uses `model` depends on this one succeeding.
model = SimpleMLPModel(dimensions=[5, 10, 4, 4],
                       feature_specification=feature_spec, 
                       name="test_model")
print(model)
print(torch.nn.Module.__repr__(model))  # plain nn.Module repr, bypassing the class's own __repr__
model.model_name

TypeError: isinstance() arg 2 must be a type, a tuple of types, or a union

In [None]:
# Post-mortem debugger for the exception raised by the model cell above.
# Enter it only when a previous cell actually raised. Without this guard,
# `pdb.pm()` fails on a clean "Restart & Run All" because no traceback exists.
if getattr(sys, "last_traceback", None) is not None:
    import pdb
    pdb.pm()

> /home/tris/Devoirs/research/implementations/tb_ml/ml_lib/misc/typing.py(38)advanced_type_check()
     34         return advanced_type_check(value, new_t)
     35     if origin == Literal:
     36         args = typing.get_args(t)
     37         return value in args
---> 38     return isinstance(value, origin)


In [None]:
model.get_hyperparameters(serializable=True)

## Saving to and loading from the database

In [None]:
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from ml_lib.pipeline.experiment_tracking import create_tables, Model as Database_Model


In [None]:
model.get_model_type()

In [None]:
model.to_database_object()

In [None]:
# Start every run from a fresh SQLite file in the system temp directory.
# This replaces `%rm /tmp/test.db`, which prints an error when the file does
# not exist yet and hard-codes a POSIX-only absolute path.
print_requests = False  # set to True to echo every SQL statement SQLAlchemy emits
db_path = Path(tempfile.gettempdir()) / "test.db"
db_path.unlink(missing_ok=True)
db_engine = create_engine(f"sqlite:///{db_path}", echo=print_requests)
create_tables(db_engine)
with Session(db_engine) as db_session:
    model.save_to_database(db_session, replace=True)
    print(model.get_database_object(db_session))
    # A `Session` used as a context manager does not commit on exit; persist
    # explicitly so the loading cell below (a new session) can see the model.
    db_session.commit()
model.id

In [None]:
with Session(db_engine) as db_session:
    model_object = db_session.get(Database_Model, model.id)
    print(model_object)
    print(type(db_session))
    # No checkpoint has been saved yet, so load the bare model without one.
    loaded_model = model_object.load_model(load_latest_checkpoint=False)
loaded_model

## Training

In [None]:
skip_training = False

In [None]:
from ml_lib.models import Model
from ml_lib.pipeline import Trainer, Training_parameters
from ml_lib.pipeline.training_hooks import TqdmHook, LoggerHook, CurveHook, DatabaseHook
from ml_lib.pipeline.experiment_tracking import Experiment as DBExperiment
from torch.utils.data import DataLoader

In [None]:
loader = DataLoader(dataset4, batch_size=None, shuffle=True)
next(iter(loader))

In [None]:
logging.basicConfig(level=logging.INFO, force=True)

curve = CurveHook()  # shared with the resume cell so both runs land on one curve
training_parameters = Training_parameters(n_epochs=4)

with Session(db_engine) as db_session:
    experiment = DBExperiment(name="test_experiment")
    db_session.add(experiment)
    trainer = Trainer(
        model, loader,
        training_parameters=training_parameters,
        device="cpu",
        step_hooks=[
            TqdmHook(),
            LoggerHook(interval=10),
            curve,
        ],
        database=db_session,
        db_experiment=experiment,
    )
    # A merely pending ORM object still has `id is None`; flush so the ids read
    # below are real. This is a no-op if Trainer already flushed.
    db_session.flush()
    # Kept for the resume cell, which opens a new session.
    experiment_id = experiment.id
    trainer_id = trainer.id
    if not skip_training:
        trainer.train()
    # The session does not commit on exit; persist explicitly (harmless if Trainer already did).
    db_session.commit()
curve.draw()

Now check that we can resume training from the trainer saved above, extending the run from 4 to 5 epochs.

In [None]:
# Same settings as the first run, but one more epoch.
new_training_parameters = copy(training_parameters)
new_training_parameters.n_epochs = 5
with Session(db_engine) as db_session:
    trainer = Trainer(
        model, loader,
        # Bug fix: previously passed `training_parameters` (4 epochs), so the
        # 5-epoch `new_training_parameters` built above was never used.
        training_parameters=new_training_parameters,
        device="cpu",
        step_hooks=[
            TqdmHook(),
            LoggerHook(interval=10),
            curve,
        ],
        database=db_session,
        db_experiment=experiment_id,
        resume_from=trainer_id,
    )
    if not skip_training:
        trainer.train()
    # The session does not commit on exit; persist explicitly (harmless if Trainer already did).
    db_session.commit()

In [None]:
curve.draw()

## Automated experiment

In [None]:
cat workdir/test_config.yaml

In [None]:
from ml_lib.pipeline.experiment import Experiment

TEST_CONFIG_PATH = "workdir/test_config.yaml"
with Session(db_engine) as db_session:
    exp = Experiment.from_yaml(TEST_CONFIG_PATH, database_session=db_session)
    exp.train_all()

In [None]:
# NOTE(review): identical to the previous cell. It presumably re-runs the same
# YAML config against the same database to check that a second `train_all()`
# on an already-trained experiment works. TODO confirm the intent and say so in markdown.
with Session(db_engine) as db_session:
    exp = Experiment.from_yaml("workdir/test_config.yaml", database_session=db_session)
    exp.train_all()

## Misc

In [None]:
from ml_lib.misc.matchers import EmptySet

In [None]:
isinstance(set(), EmptySet)

In [None]:
isinstance(set([1]), EmptySet)

In [None]:
match set():
    case EmptySet():
        print("Tristan, you are a terrible person")
    case _:
        print("ugh")

match {1, 2, 3}:
    case EmptySet():
        print("ugh")
    case _:
        print("But it works")

Someday I'll write a matcher like this that checks regular expressions… and that will be absolutely terrible.

In [None]:
from typing import Literal, get_origin

# `isinstance(x, Literal)` raises TypeError because `Literal` is a special form,
# not a class. The supported test for "is this a Literal[...] type?" compares its origin.
get_origin(Literal[3, 4]) is Literal



In [None]:
Literal[1, 2] == Literal[2, 1]

In [None]:
Literal[1, 2].__dict__

In [None]:
import typing
Hyperparameter[list[int]].__origin__.__origin__

In [None]:
typing.get_args(Hyperparameter[list[int]])