In [1]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import h5py
import pandas as pd
from Bio import SeqIO
import numpy as np


In [2]:
genome = "resources/genomes/hg38.ml.fa"
# Materialise every FASTA record in memory, keyed by record id; the dataset
# below slices `genome_dict[chrom].seq` directly for each example.
genome_dict = SeqIO.to_dict(
    SeqIO.parse(genome, format="fasta")
)

In [35]:
# Path string (not an open handle): EnformerDataset opens it itself via h5py.File.
train_data = "test_train_data/h5py/all_train.h5"

In [33]:
# `train_data` is a path string, not an open h5py.File, so indexing it directly
# fails on a fresh kernel. Open the file explicitly; the context manager closes
# the handle again once the shapes have been read.
with h5py.File(train_data, "r") as train_h5:
    train_shapes = (train_h5["sequence"].shape, train_h5["target"].shape)
train_shapes

((35138, 196608), (35138, 896, 10))

In [10]:
# Headerless, tab-separated file with columns chrom / start / end / fold.
bed_file = "test_train_data/sequences.bed"

In [53]:
# NOTE(review): EnformerDataset is defined in cell In [39] further down; on a
# fresh kernel that cell must run first (ideally move the class above this cell).
# NOTE(review): shuffle=False on the training fold — confirm this is intentional.
train_dataset = EnformerDataset(
    train_data,
    bed_file,
    seqlen=196608,
    genome_dict=genome_dict,
    shift_aug=True,
    rc_aug=True,
    fold="train",
)
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=False,
    drop_last=True,
    num_workers=0,
)

In [54]:
# Pull one batch to sanity-check shapes: sequences are (batch, seqlen, 4) one-hot.
batch = next(iter(train_loader))
(batch["sequence"].shape, batch["target"].shape)

(torch.Size([4, 196608, 4]), torch.Size([4, 896, 10]))

In [39]:

class EnformerDataset(Dataset):
    """One-hot DNA windows from a reference genome paired with HDF5 targets.

    Rows of a headerless, tab-separated BED-like file (chrom, start, end, fold)
    are filtered to ``fold``. Each remaining row yields a ``seqlen``-bp window
    centred on the interval (``N``-padded past chromosome ends), one-hot
    encoded as float32 with columns in A/C/G/T order, together with the
    matching row of the HDF5 ``target`` dataset.

    Parameters
    ----------
    file_path : str
        HDF5 file with a ``target`` dataset. Its rows must align 1:1 either
        with all rows of the BED file or with only the rows of ``fold``.
    bed_path : str
        BED-like file with columns chrom, start, end, fold.
    seqlen : int
        Length of every returned sequence window.
    genome_dict : dict
        Chromosome name -> record exposing ``.seq`` (sliceable) and ``len()``,
        e.g. the output of ``Bio.SeqIO.to_dict``.
    shift_aug : bool
        If True, shift each window by a random integer offset in [-3, 3] bp.
    rc_aug : bool
        If True, with probability 0.5 reverse-complement the sequence and
        reverse the target along its first axis.
    fold : str
        Value of the BED ``fold`` column to keep (e.g. "train", "valid", "test").

    Raises
    ------
    ValueError
        If the number of HDF5 target rows matches neither the full BED file
        nor the selected fold.

    Notes
    -----
    All targets are read into memory at construction. The HDF5 handle stays
    open in ``self.h5_file`` until :meth:`close` is called.
    """

    # ASCII code -> one-hot row. A/C/G/T (either case) set one column; every
    # other character (N, IUPAC ambiguity codes, ...) stays an all-zero row.
    _ONEHOT_TABLE = np.zeros((256, 4), dtype=np.float32)
    _ONEHOT_TABLE[np.frombuffer(b"ACGT", dtype=np.uint8), np.arange(4)] = 1.0
    _ONEHOT_TABLE[np.frombuffer(b"acgt", dtype=np.uint8), np.arange(4)] = 1.0

    def __init__(self, file_path, bed_path, seqlen, genome_dict, shift_aug, rc_aug, fold):
        self.file_path = file_path
        self.bed_path = bed_path
        self.fold = fold  # train / valid / test

        self.h5_file = h5py.File(self.file_path, 'r')
        all_targets = self.h5_file['target'][:]  # load all first

        bed_file = pd.read_csv(bed_path, sep='\t', names=["chrom", "start", "end", "fold"])
        # Build the mask BEFORE reset_index: after reset_index(drop=True) the
        # index is 0..k-1 and no longer points at this fold's original rows.
        fold_mask = (bed_file["fold"] == self.fold).to_numpy()
        self.bed_file = bed_file[fold_mask].reset_index(drop=True)

        if len(all_targets) == len(bed_file):
            # HDF5 rows align with the full BED file: keep only this fold's rows.
            self.targets = all_targets[fold_mask]
        elif len(all_targets) == len(self.bed_file):
            # HDF5 already holds only this fold's rows, in BED order.
            self.targets = all_targets
        else:
            self.h5_file.close()
            raise ValueError(
                f"Mismatch: {len(all_targets)} targets in {file_path!r} match neither "
                f"the {len(bed_file)} BED rows nor the {len(self.bed_file)} rows "
                f"of fold {fold!r}"
            )

        self.seqlen = seqlen
        self.genome_dict = genome_dict
        self.chrom_length = {chrom: len(genome_dict[chrom]) for chrom in genome_dict}

        self.shift_aug = shift_aug
        self.rc_aug = rc_aug

    def resize_interval(self, chrom, start, end):
        """Centre a ``seqlen``-bp window on [start, end) and clip it to the chromosome.

        Returns ``(trimmed_start, trimmed_end, left_pad, right_pad)``, where the
        pads count window bases lying before position 0 or past the chromosome end.
        """
        mid_point = (start + end) // 2
        extend_start = mid_point - self.seqlen // 2
        # Derive the end from the start so the window is exactly seqlen long,
        # including when seqlen is odd.
        extend_end = extend_start + self.seqlen
        trimmed_start = max(0, extend_start)
        left_pad = trimmed_start - extend_start
        trimmed_end = min(self.chrom_length[chrom], extend_end)
        right_pad = extend_end - trimmed_end
        return trimmed_start, trimmed_end, left_pad, right_pad

    def get_sequence(self, chrom, start, end):
        """Return the upper-cased window around [start, end), ``N``-padded at chromosome edges."""
        trimmed_start, trimmed_end, left_pad, right_pad = self.resize_interval(chrom, start, end)
        sequence = str(self.genome_dict[chrom].seq[trimmed_start:trimmed_end]).upper()
        return 'N' * left_pad + sequence + 'N' * right_pad

    def sequence_to_onehot(self, sequence):
        """One-hot encode a DNA string as a ``(len(sequence), 4)`` float32 array.

        Columns are A, C, G, T. Lower-case bases are encoded like upper-case
        ones; any other ASCII character (``N``, IUPAC codes, ...) becomes an
        all-zero row. Non-ASCII input raises ``UnicodeEncodeError``.
        """
        # Vectorised table lookup instead of building a Python list per base.
        codes = np.frombuffer(sequence.encode('ascii'), dtype=np.uint8)
        return self._ONEHOT_TABLE[codes]

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``{'sequence': one-hot window, 'target': target row}`` for item ``idx``."""
        target = self.targets[idx]
        row = self.bed_file.iloc[idx]
        chrom, start, end = row['chrom'], row['start'], row['end']

        if self.shift_aug:
            # Random integer shift in [-3, 3] (randint's upper bound is exclusive).
            shift = np.random.randint(-3, 4)
            start += shift
            end += shift

        sequence = self.get_sequence(chrom, start, end)
        onehot = self.sequence_to_onehot(sequence)

        if self.rc_aug and np.random.rand() < 0.5:
            # Reversing rows gives the reverse strand; reversing the A/C/G/T
            # columns (-> T/G/C/A) complements each base. copy() removes the
            # negative strides, which torch tensor conversion does not accept.
            onehot = onehot[::-1, ::-1].copy()
            target = target[::-1].copy() if isinstance(target, np.ndarray) else target

        return {
            'sequence': onehot,
            'target': target,
        }

    def close(self):
        """Close the HDF5 file opened in ``__init__``."""
        self.h5_file.close()
