In [1]:
from analysis import dataset as ds
from analysis.base import BaseModel, ModelConfig, TrainConfig

# Sanity check of what the analysis package exposes: the phenotype name
# lists on the dataset module and the default field values of both configs.
real_phenotypes = ds.phenotype_names
print("Num real phenotypes:", len(real_phenotypes))
print("First 10 phenotypes:", real_phenotypes[:10])

# Looked up with a default so that a dataset module lacking the attribute
# reports 0 here instead of raising AttributeError.
synthetic_phenotypes = getattr(ds, "phenotype_names_synthetic", [])
print("Num synthetic phenos:", len(synthetic_phenotypes))

# Both configs are constructed with no arguments, so what gets printed is
# whatever each one fills in by default.
print("TrainConfig defaults:\n", TrainConfig())
print("ModelConfig defaults:\n", ModelConfig())

Num real phenotypes: 18
First 10 phenotypes: ['23C', '25C', '27C', '30C', '33C', '35C', '37C', 'cu', 'suloc', 'ynb']
Num synthetic phenos: 1
TrainConfig defaults:
 TrainConfig(data_dir=PosixPath('data'), save_dir=PosixPath('models'), name_prefix='', phenotypes=['23C', '25C', '27C', '30C', '33C', '35C', '37C', 'cu', 'suloc', 'ynb', 'eth', 'gu', 'li', 'mann', 'mol', 'raff', 'sds', '4NQO'], optimizer='adam', patience=200, batch_size=64, learning_rate=0.001, lr_schedule=False, weight_decay=0.0, max_epochs=200, num_workers=1, gradient_clip_val=0.0, use_cache=True, use_modal=False, modal_detach=True, seed=None, synthetic_data=False)
ModelConfig defaults:
 ModelConfig(model_type='rijal_et_al', seq_length=1164, embedding_dim=13, num_layers=3, init_scale=0.03, skip_connections=False, scaled_attention=False, layer_norm=False, dropout_rate=0.0, nhead=4, dim_feedforward=1048)


In [None]:
import inspect


def all_subclasses(cls):
    """Return every direct and indirect subclass of ``cls`` in a stable order.

    The class hierarchy below ``cls`` is walked iteratively via
    ``__subclasses__()``; each subclass is reported once even when it is
    reachable through several parents (multiple inheritance).

    Args:
        cls: The class whose descendants should be collected. ``cls`` itself
            is not part of the result.

    Returns:
        A list of classes sorted by ``(__module__, __name__)``. Classes that
        share both values (e.g. a class re-defined by re-running a notebook
        cell) keep the order in which they were discovered, so the result is
        reproducible from run to run.
    """
    seen = set()   # fast membership test only
    found = []     # discovery order; this is what makes tie-breaking stable
    pending = [cls]
    while pending:
        current = pending.pop()
        for sub in current.__subclasses__():
            if sub not in seen:
                seen.add(sub)
                found.append(sub)
                pending.append(sub)
    # sorted() is stable, so equal keys fall back to discovery order rather
    # than to the hash-dependent iteration order of a set.
    return sorted(found, key=lambda c: (c.__module__, c.__name__))


# List every model class currently registered as a BaseModel subclass,
# together with the constructor signature it advertises.
subs = all_subclasses(BaseModel)
for cls in subs:
    qualified_name = f"{cls.__module__}.{cls.__name__}"
    try:
        sig = str(inspect.signature(cls))
    except ValueError:
        # inspect could not provide a signature for this class; print a
        # placeholder so the class still shows up in the listing.
        sig = "(...)"
    print(f"{qualified_name}{sig}")

analysis.modified_rijal_et_al.ModifiedRijalEtAl(model_config: analysis.base.ModelConfig, train_config: analysis.base.TrainConfig)
analysis.rijal_et_al.RijalEtAl(model_config: analysis.base.ModelConfig, train_config: analysis.base.TrainConfig)
analysis.transformer.Transformer(model_config: analysis.base.ModelConfig, train_config: analysis.base.TrainConfig)


In [None]:
import torch
from analysis import dataset as ds
from analysis.base import ModelConfig, TrainConfig
from analysis.modified_rijal_et_al import ModifiedRijalEtAl
from analysis.rijal_et_al import RijalEtAl
from analysis.transformer import Transformer

# Smoke test: build each model with a small config and push one random batch
# through it to see the output shape.
mc = ModelConfig(
    seq_length=1164,
    embedding_dim=16,
    num_layers=2,
    nhead=4,
    dim_feedforward=256,
)

# One config with the first four phenotypes, one with only the first.
tc_multi = TrainConfig(phenotypes=ds.phenotype_names[:4])
tc_single = TrainConfig(phenotypes=[ds.phenotype_names[0]])

# Random integers in {0, 1, 2}, cast to float, one row per batch element.
# NOTE(review): presumably this mimics the genotype encoding the models
# expect -- confirm against the dataset loader.
B = 8
x = torch.randint(low=0, high=3, size=(B, mc.seq_length)).float()

# The input is drawn before the models are constructed; keep that order so
# the random stream is consumed the same way on every run.
models = [
    (
        "ModifiedRijalEtAl",
        ModifiedRijalEtAl(model_config=mc, train_config=tc_multi),
    ),
    (
        "Transformer",
        Transformer(model_config=mc, train_config=tc_multi),
    ),
    (
        # NOTE(review): paired with the single-phenotype config; presumably
        # this model handles only one phenotype at a time -- TODO confirm.
        "RijalEtAl (single)",
        RijalEtAl(model_config=mc, train_config=tc_single),
    ),
]

# Only shapes are inspected, so gradient tracking is switched off.
with torch.no_grad():
    for name, model in models:
        y = model(x)
        print(f"{name:<22} in {tuple(x.shape)} -> out {tuple(y.shape)}")

ModifiedRijalEtAl      in (8, 1164) -> out (8, 4)
Transformer            in (8, 1164) -> out (8, 4)
RijalEtAl (single)     in (8, 1164) -> out (8, 1)


