In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
from lib.config.config_dna import get_config
import time
import tqdm
import tabix
import pyBigWig
import pandas as pd
from matplotlib import pyplot as plt
from lib.models.ddsm import *
from selene_sdk.utils import NonStrandSpecific
from selene_sdk.targets import Target
from lib.sei.sei import Sei
from lib.datasets.datasets import TSSDatasetS

In [None]:
# Experiment configuration; every path and hyper-parameter below is read from it.
config = get_config()

# Stick-breaking transform object (presumably provided by the wildcard import
# from lib.models.ddsm -- confirm).
sb = UnitStickBreakingTransform()
"""
Sei model is published in the following paper
Chen, K. M., Wong, A. K., Troyanskaya, O. G., & Zhou, J. (2022). A sequence-based global map of 
regulatory activity for deciphering human genetics. Nature genetics, 54(7), 940-949. 
[https://doi.org/10.1038/s41588-022-01102-2](https://doi.org/10.1038/s41588-022-01102-2)  
"""
# Sei feature table: pipe-delimited text file without a header row.
seifeatures = pd.read_csv(
    config.seifeatures_file,
    sep='|',
    header=None,
)

# Build the Sei network, wrap it in NonStrandSpecific, then in DataParallel.
# NOTE(review): 4096 and 21907 are presumably the input sequence length and
# the number of output features -- confirm against lib.sei.sei.
strand_symmetric_sei = NonStrandSpecific(Sei(4096, 21907))
sei = nn.DataParallel(strand_symmetric_sei)

# Deserialize the checkpoint onto the CPU, copy its 'state_dict' entry into
# the wrapped model, and only then move the model to the GPU.
sei_checkpoint = torch.load(config.seimodel_file, map_location='cpu')
sei.load_state_dict(sei_checkpoint['state_dict'])
del sei_checkpoint  # do not keep an extra reference to the checkpoint around
sei.cuda()

### LOAD WEIGHTS
# The diffusion weights file holds a 5-tuple of tensors:
# (v_one, v_zero, v_one_loggrad, v_zero_loggrad, timepoints).
#
# map_location='cpu' makes torch.load restore every tensor directly on the
# CPU. Without it the tensors are first restored to the device they were saved
# from, which fails on a machine lacking that GPU and needlessly allocates GPU
# memory before the copy back to the CPU. With it, the separate `.cpu()` calls
# are redundant and have been dropped. This also matches how the Sei
# checkpoint is loaded.
#
# NOTE(review): torch.load unpickles the file -- only load weight files that
# come from a trusted source.
(
    v_one,
    v_zero,
    v_one_loggrad,
    v_zero_loggrad,
    timepoints,
) = torch.load(config.diffusion_weights_file, map_location='cpu')

# Two float32 vectors of length ncat - 1: alpha is all ones, beta counts down
# from ncat - 1 to 1. The dtype is given explicitly so the result does not
# depend on the current default dtype.
# NOTE(review): presumably the per-dimension parameters of the stick-breaking
# diffusion -- confirm against lib.models.ddsm.
alpha = torch.ones(config.ncat - 1, dtype=torch.float32)
beta = torch.arange(config.ncat - 1, 0, -1, dtype=torch.float32)

### TIME DEPENDENT WEIGHTS ###
# Pin the default floating-point dtype to float32 before the dataset and the
# accumulators below are created, so they do not inherit another default.
torch.set_default_dtype(torch.float32)

# Training dataset and a shuffling loader over it; batch size and worker
# count are taken from the config.
train_set = TSSDatasetS(config, n_tsses=40000, rand_offset=10)
data_loader = DataLoader(
    train_set,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
)

# Two zero-initialised accumulators with one slot per diffusion time step,
# both living on config.device.
# NOTE(review): presumably a running sum and a sample count used to derive
# per-time-step weights; the code that fills them is outside this section.
time_dependent_cums = torch.zeros(config.n_time_steps, device=config.device)
time_dependent_counts = torch.zeros_like(time_dependent_cums)
