In [1]:
import torch
from torch.utils.data import Dataset
import pandas as pd
import selene_sdk
import numpy as np


def one_hot_encode(seqs):
    """Encode every sequence in ``seqs`` and stack the results into one array.

    Each sequence is passed to
    ``selene_sdk.sequences.Genome.sequence_to_encoding`` and the returned
    matrix is transposed before stacking.

    Args:
        seqs: Iterable of sequences accepted by ``sequence_to_encoding``.

    Returns:
        A NumPy array whose first axis indexes the input sequences.
        NOTE(review): presumably shaped (n_seqs, 4, seq_len), assuming the
        selene encoding is (seq_len, 4) -- confirm against selene_sdk.

    Raises:
        ValueError: Raised by ``np.stack`` when ``seqs`` is empty or when the
            encoded matrices do not all share one shape (for example,
            sequences of different lengths).
    """
    encode = selene_sdk.sequences.Genome.sequence_to_encoding
    # Transpose each per-sequence matrix before stacking along a new axis 0.
    return np.stack([encode(seq).T for seq in seqs])

class SampleDataset(Dataset):
    """Map-style dataset pairing one-hot encoded sequences with float targets.

    The CSV file is read once at construction time and every sequence is
    encoded eagerly, so the whole encoded dataset is held in memory.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, data_path, sequence_key="sequence",
                 activity_key="expression_log2"):
        """Load the CSV at ``data_path`` and pre-compute inputs and targets.

        Args:
            data_path: Path (or any source accepted by ``pandas.read_csv``)
                of the CSV file to load.
            sequence_key: Name of the column holding the sequences to encode.
                Defaults to ``"sequence"``.
            activity_key: Name of the column holding the regression target.
                Defaults to ``"expression_log2"``.

        Raises:
            KeyError: If either column is missing from the CSV.
        """
        data_df = pd.read_csv(data_path)

        sequences = data_df[sequence_key]
        # Targets are stored as a 1-D float32 tensor, one entry per CSV row.
        self.target = torch.tensor(
            data_df[activity_key].values, dtype=torch.float
        )

        # Encoded inputs are kept as returned by one_hot_encode (a stacked
        # NumPy array); they are not converted to a torch tensor here.
        self.seqs_hot = one_hot_encode(sequences)

    def __len__(self):
        """Return the number of samples (the length of the target tensor)."""
        return len(self.target)

    def __getitem__(self, idx):
        """Return the ``(encoded_sequence, target)`` pair at position ``idx``."""
        x = self.seqs_hot[idx]
        y = self.target[idx]

        return x, y