In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
from lib.config.config_dna import get_config
import time
import tqdm
import tabix
import pyBigWig
import pandas as pd
from matplotlib import pyplot as plt
from lib.models.ddsm import *
from selene_sdk.utils import NonStrandSpecific
from selene_sdk.targets import Target
from lib.sei.sei import Sei
from lib.datasets.datasets import TSSDatasetS, prepare_dna_valid_dataset
import os
import lib.utils.bookkeeping as bookkeeping
from lib.models.networks import DNAScoreNet
from lib.training.training import Trainer
from lib.sampling.sampling import Euler_Maruyama_sampler

In [None]:
# Either build a fresh config (and persist it to config.save_location) or
# reload the config saved by a previous run so that training can be resumed.
train_resume = False
if not train_resume:
    config = get_config()
    bookkeeping.save_config(config, config.save_location)
else:
    # Placeholders: point these at the saved-models directory, the run
    # directory (`date`) and the config file of the run to resume.
    # `path` and `date` are also read by the model/training cell when resuming.
    path = 'path/to/saved/models'
    date = 'date'
    config_name = 'config_name.yaml'
    config_path = os.path.join(path, date, config_name)

    # Bug fix: this was previously bound to `configfg`, which left `config`
    # undefined on resume and made the next statement fail with a NameError.
    config = bookkeeping.load_config(config_path)

# Load the pretrained Sei model and its feature table; both are passed to the
# validation-set preparation below and to trainer.train().
# NOTE(review): the feature file is read from `config.seifeatures_file` while the
# weights come from `config.sei.seimodel_file` — confirm both attribute paths
# exist in the config.
sei_features = pd.read_csv(config.seifeatures_file, sep='|', header=None)

# Positional arguments to Sei(); presumably (sequence_length, n_genomic_features)
# as in the upstream Sei model — TODO confirm against lib.sei.sei.
SEI_SEQ_LEN = 4096
SEI_N_FEATURES = 21907
sei = nn.DataParallel(NonStrandSpecific(Sei(SEI_SEQ_LEN, SEI_N_FEATURES)))
# The checkpoint is mapped onto the CPU and its 'state_dict' entry holds the weights.
# NOTE(review): torch.load unpickles the file — only point this at a trusted checkpoint.
sei.load_state_dict(torch.load(config.sei.seimodel_file, map_location='cpu')['state_dict'])
# Sei is not moved to the GPU in this cell (a `sei.cuda()` call was disabled);
# add it back if Sei should run on CUDA.

torch.set_default_dtype(torch.float32)

# Training data.
# NOTE(review): the original author's note (translated from German) read
# "with this, importance sampling, and with rand_offset=100 in train()", but
# rand_offset is 10 here — confirm which value is intended.
N_TRAIN_TSSES = 40000
TRAIN_RAND_OFFSET = 10
train_set = TSSDatasetS(config, n_tsses=N_TRAIN_TSSES, rand_offset=TRAIN_RAND_OFFSET)
data_loader = DataLoader(train_set, batch_size=config.data.batch_size, shuffle=True, num_workers=config.data.num_workers)

trainer = Trainer(config)
# Validation datasets/sequences are prepared using the Sei model and its feature table.
valid_datasets, valid_seqs = prepare_dna_valid_dataset(config, sei, sei_features)


In [None]:
# Build the score network and its optimizer, optionally restore a checkpoint,
# then hand everything to the trainer.
model = DNAScoreNet()
optimizer = Adam(model.parameters(), lr=config.optimizer.lr)
n_iter = 0
# Training state passed to trainer.train(); on resume it is replaced by the
# value returned from bookkeeping.load_state.
state = {"model": model, "optimizer": optimizer, "n_iter": n_iter}

if train_resume:
    # `path` and `date` are defined by the resume branch of the config cell.
    model_name = 'model_name.pt'
    checkpoint_path = os.path.join(path, date, model_name)
    state = bookkeeping.load_state(state, checkpoint_path)
    # Schedule overrides for the resumed run.
    # NOTE(review): `trainer` was already constructed from `config` in the
    # previous cell — confirm Trainer reads these fields at train() time rather
    # than caching them in __init__, otherwise these overrides have no effect.
    config.training.n_iters = 36000
    config.sampler.sample_freq = 36000
    config.saving.checkpoint_freq = 1000

sampler = Euler_Maruyama_sampler

trainer.train(state, sampler, sei, sei_features, data_loader, valid_datasets, valid_seqs)
