In [None]:
%load_ext autoreload
%autoreload 2

In [24]:
import wandb
import pytorch_lightning as pl
from pathlib import Path
from src.utils.io import HDFReader
from src.data.preprocess import preprocess
from src.data.modules.paralog import ParalogousGeneDataModule
from src.models.baseline import ConvolutionalModel
from pytorch_lightning.loggers import WandbLogger

In [25]:
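# Disabled preprocessing call, kept for reference. Presumably it produces the
# files under data/processed that the cells below read — TODO confirm before
# re-enabling.
# preprocess(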
#     "data/genome/gff_file.gff",
#     "data/genome/fasta_file.fsa",
#     "data/embeddings",
#     "data/waern_2013",
#     "data/samples.json",
#     "data/processed",
#     500,
# )

In [None]:
# Locations of the preprocessed artifacts consumed by the rest of the notebook.
processed_dir = Path("data/processed")
summary_path = processed_dir / "summary.csv"
h5_path = processed_dir / "genewise.h5"

# Displays True only when both artifacts are present on disk.
all(artifact.exists() for artifact in (summary_path, h5_path))

In [27]:
# Build the reader from the shared h5_path rather than a second copy of the
# literal, so the existence check above and the reader always target the same
# file. str() keeps the argument a plain string, as it was before.
h5_reader = HDFReader(str(h5_path))

In [28]:
# Build a data module for a single fold (index 1) and pull its loaders.
# These names are reassigned by the cross-validation loop further down.
dm = ParalogousGeneDataModule(h5_reader, summary_path, 1)
train_loader, test_loader = dm.train_dataloader(), dm.test_dataloader()

In [None]:
# Hyperparameters shared by every fold; tune them here rather than inside the loop.
lr = 1e-3
batch_size = 32
n_folds = 5
max_epochs = 20
seed = 42

# Train one model per fold, logging each fold as its own Weights & Biases run.
for fold in range(n_folds):
    # Re-seed at the start of every fold so a single fold can be reproduced
    # without replaying the folds that came before it.
    pl.seed_everything(seed, workers=True)

    dm = ParalogousGeneDataModule(h5_reader, summary_path, fold, batch_size=batch_size)
    train_loader = dm.train_dataloader()
    test_loader = dm.test_dataloader()

    model = ConvolutionalModel(pooling_type="max", learning_rate=lr)

    try:
        wandb_logger = WandbLogger(project="RNA_prediction", name=f"fold_{fold}, max convolution, lr={lr}, batch_size={batch_size}")
        trainer = pl.Trainer(max_epochs=max_epochs, logger=wandb_logger)
        # NOTE(review): the test loader is passed as the validation loader, so
        # validation metrics are computed on the test split — confirm this is
        # intended and that no model selection is based on them.
        trainer.fit(model, train_loader, test_loader)
    finally:
        # Always close the run, even when training fails, so the next fold
        # does not log into a stale run.
        wandb.finish()