In [1]:
import gzip
import os
import shutil
import urllib
from pathlib import Path
from typing import List
from tqdm import tqdm
from ast import literal_eval

import re
import datasets
from datasets import Dataset, DatasetDict, Features, Value
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import DatasetInfo
from pyfaidx import Fasta
from abc import ABC, abstractmethod
from datasets import load_dataset, load_from_disk
from sklearn.model_selection import train_test_split

import warnings
warnings.filterwarnings('ignore')

In [3]:
# """
# ----------------------------------------------------------------------------------------
# Reference Genome URLS:
# ----------------------------------------------------------------------------------------
# """
# H38_REFERENCE_GENOME_URL = (
#     "https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/" "hg38.fa.gz"
# )
# H19_REFERENCE_GENOME_URL = (
#     "https://hgdownload.soe.ucsc.edu/goldenPath/hg19/bigZips/" "hg19.fa.gz"
# )

In [25]:
"""
----------------------------------------------------------------------------------------
Task Specific Handlers:
----------------------------------------------------------------------------------------
"""

class GenomicLRATaskHandler(ABC):
    """
    Abstract base class for the Genomic LRA task handlers.

    Subclasses must implement ``__init__``, ``get_info`` and
    ``generate_examples``. ``split_generators`` provides a default train/test
    split, and ``download_and_extract_gz`` is a shared helper for fetching
    gzip-compressed files into the cache directory.
    """

    @abstractmethod
    def __init__(self, **kwargs):
        pass

    @abstractmethod
    def get_info(self, description: str) -> DatasetInfo:
        """
        Returns the DatasetInfo for the task
        """
        pass

    def split_generators(
            self, dl_manager, cache_dir_root
    ) -> List[datasets.SplitGenerator]:
        """
        Returns the default split generators: one for train and one for test.

        The base implementation downloads nothing itself; both arguments are
        accepted only so that subclasses share a single signature.
        Args:
            dl_manager: download manager (unused here).
            cache_dir_root: root of the cache directory (unused here).
        """
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"handler": self, "split": "train"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST, gen_kwargs={"handler": self, "split": "test"}
            ),
        ]

    @abstractmethod
    def generate_examples(self, split):
        """
        A generator that yields examples for the specified split.
        """
        pass

    @staticmethod
    def hook(t):
        """
        Builds a ``reporthook`` callback for ``urllib.request.urlretrieve`` that
        forwards download progress to the progress bar ``t``.
        Args:
            t: progress bar object exposing a ``total`` attribute and an
                ``update`` method (a tqdm instance in this file).
        """
        # Mutable cell so the closure can remember the previous block count.
        last_b = [0]

        def inner(b=1, bsize=1, tsize=None):
            """
            b  : int, optional
                Number of blocks just transferred [default: 1].
            bsize  : int, optional
                Size of each block (in tqdm units) [default: 1].
            tsize  : int, optional
                Total size (in tqdm units). If [default: None] remains unchanged.
            """
            if tsize is not None:
                t.total = tsize
            # urlretrieve reports a cumulative block count; tqdm wants the delta.
            t.update((b - last_b[0]) * bsize)
            last_b[0] = b

        return inner

    def download_and_extract_gz(self, file_url, cache_dir_root):
        """
        Downloads and extracts a gz file into the given cache directory. Returns the
        full file path of the extracted gz file.

        The file is stored under ``<cache_dir_root>/downloads`` (created if
        missing). Both the download and the extraction are written to a
        temporary ``.part`` file and renamed only once complete, so an
        interrupted run never leaves a truncated file that would later be
        mistaken for a finished one. Existing files are reused.
        Args:
            file_url: url of the gz file to be downloaded and extracted.
            cache_dir_root: Directory to extract file into.
        """
        # A bare ``import urllib`` does not load the ``request`` submodule, so
        # import it explicitly where it is used.
        import urllib.request

        file_fname = Path(file_url).stem
        download_dir = os.path.join(cache_dir_root, "downloads")
        file_complete_path = os.path.join(download_dir, file_fname)
        gz_path = file_complete_path + ".gz"

        if not os.path.exists(file_complete_path):
            os.makedirs(download_dir, exist_ok=True)

            if not os.path.exists(gz_path):
                partial_gz_path = gz_path + ".part"
                try:
                    with tqdm(
                            unit="B",
                            unit_scale=True,
                            unit_divisor=1024,
                            miniters=1,
                            desc=file_url.split("/")[-1],
                    ) as t:
                        urllib.request.urlretrieve(
                            file_url, partial_gz_path, reporthook=self.hook(t)
                        )
                    os.replace(partial_gz_path, gz_path)
                finally:
                    # Only still present if the download failed part-way.
                    if os.path.exists(partial_gz_path):
                        os.remove(partial_gz_path)

            partial_out_path = file_complete_path + ".part"
            try:
                with gzip.open(gz_path, "rb") as file_in:
                    with open(partial_out_path, "wb") as file_out:
                        shutil.copyfileobj(file_in, file_out)
                os.replace(partial_out_path, file_complete_path)
            finally:
                # Only still present if the extraction failed part-way.
                if os.path.exists(partial_out_path):
                    os.remove(partial_out_path)
        return file_complete_path


In [26]:
class CagePredictionHandler(GenomicLRATaskHandler):
    """
    Handler for the CAGE prediction task.
    """

    NUM_TRAIN = 33891
    NUM_TEST = 1922
    NUM_VALID = 2195
    DEFAULT_LENGTH = 114688  # 896 x 128bp
    TARGET_SHAPE = (
        896,
        50,
    )  # 50 is a subset of CAGE tracks from the original enformer dataset
    NPZ_SPLIT = 1000  # number of files per npz file.
    NUM_BP_PER_BIN = 128  # number of base pairs per bin in labels

    def __init__(self, sequence_length=DEFAULT_LENGTH, **kwargs):
        """
        Creates a new handler for the CAGE task.
        Args:
            sequence_length: allows for increasing sequence context. Sequence length
            must be an even multiple of 128 to align with binned labels. Note:
            increasing sequence length may decrease the number of usable samples.
        Raises:
            AssertionError: if sequence_length is not a positive even multiple
            of 128.
        """
        self.reference_genome = None
        self.coordinate_csv_file = None
        self.target_files_by_split = {}

        # Both conditions are needed: an exact multiple of the bin size AND an
        # even number of bins. Checking only the floored bin count would accept
        # lengths such as 114500, for which the label trimming in
        # generate_examples would produce an empty target array.
        num_bins, remainder = divmod(sequence_length, self.NUM_BP_PER_BIN)
        assert sequence_length > 0 and remainder == 0 and num_bins % 2 == 0, (
            "Requested sequence length must be an even multiple of 128 to align "
            "with the binned labels."
        )

        self.sequence_length = sequence_length

        # Shorter sequences keep only the central bins of the default target.
        if self.sequence_length < self.DEFAULT_LENGTH:
            self.TARGET_SHAPE = (self.sequence_length // 128, 50)

    def get_info(self, description: str) -> DatasetInfo:
        """
        Returns the DatasetInfo for the CAGE dataset. Each example
        includes a genomic sequence and a 2D array of labels
        """
        features = datasets.Features(
            {
                # DNA sequence
                "sequence": datasets.Value("string"),
                # array of sequence length x num_labels
                "labels": datasets.Array2D(shape=self.TARGET_SHAPE, dtype="float32"),
                # chromosome number
                "chromosome": datasets.Value(dtype="string"),
                # start
                "labels_start": datasets.Value(dtype="int32"),
                # stop
                "labels_stop": datasets.Value(dtype="int32")
            }
        )
        return datasets.DatasetInfo(
            # This is the description that will appear on the datasets page.
            description=description,
            # This defines the different columns of the dataset and their types
            features=features,
        )

    def split_generators(self, dl_manager, cache_dir_root):
        """
        Separates files by split and stores filenames in instance variables.
        The CAGE dataset requires reference genome, coordinate
        csv file, and npz target files to be saved.
        Args:
            dl_manager: download manager used to fetch the csv and npz files.
            cache_dir_root: cache directory the reference genome is stored in.
        """

        # Manually download the reference genome since there are difficulties when
        # streaming
        reference_genome_file = self.download_and_extract_gz(
            H38_REFERENCE_GENOME_URL, cache_dir_root
        )
        self.reference_genome = Fasta(reference_genome_file, one_based_attributes=False)

        self.coordinate_csv_file = dl_manager.download_and_extract(
            "cage_prediction/sequences_coordinates.csv"
        )

        # (split name used by this handler, split name used in the npz file
        # names, number of targets). Each split maps the floored npy index to
        # the downloaded npz file containing it.
        split_specs = (
            ("train", "train", self.NUM_TRAIN),
            ("test", "test", self.NUM_TEST),
            ("validation", "valid", self.NUM_VALID),
        )
        for split_name, npz_split_name, num_targets in split_specs:
            self.target_files_by_split[split_name] = {
                npz_key: dl_manager.download(npz_filename)
                for npz_key, npz_filename in self.generate_npz_filenames(
                    npz_split_name, num_targets,
                    folder="cage_prediction/targets_subset"
                )
            }

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"handler": self, "split": "train"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                gen_kwargs={"handler": self, "split": "validation"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={"handler": self, "split": "test"}
            ),
        ]

    def generate_examples(self, split):
        """
        A generator which produces examples for the given split, each with a sequence
        and the corresponding labels. The sequences are padded to the correct
        sequence length and standardized before returning.
        Args:
            split: one of the keys stored by split_generators
                ("train", "test" or "validation").
        Yields:
            (key, example) tuples where example holds "labels", "sequence",
            "chromosome", "labels_start" and "labels_stop".
        """

        target_files = self.target_files_by_split[split]

        # npy entries inside the archives are keyed by ["train", "test", "valid"]
        npz_split_name = "valid" if split == "validation" else split

        coordinates_dataframe = pd.read_csv(self.coordinate_csv_file)
        filtered = coordinates_dataframe[coordinates_dataframe["split"] == split]

        # Number of label bins to drop from each end when the requested
        # sequence is shorter than 114688 (DEFAULT_LENGTH).
        trim_bins = 0
        if self.sequence_length < self.DEFAULT_LENGTH:
            trim_bins = (self.DEFAULT_LENGTH - self.sequence_length) // 2 // 128

        key = 0
        # Keep a single archive open and reuse it for consecutive rows instead
        # of re-opening (and never closing) one archive per row.
        open_npz_key = None
        open_npz = None
        try:
            for _, row in filtered.iterrows():
                start, stop = int(row["start"]) - 1, int(
                    row["stop"]) - 1  # -1 since coords are 1-based

                chromosome = row['chrom']

                padded_sequence, new_start, new_stop = pad_sequence(
                    chromosome=self.reference_genome[chromosome],
                    start=start,
                    sequence_length=self.sequence_length,
                    end=stop,
                    return_new_start_stop=True
                )

                # Labels always cover the original interval unless the
                # sequence (and therefore the labels) was shortened.
                if self.sequence_length >= self.DEFAULT_LENGTH:
                    new_start = start
                    new_stop = stop

                # floor npy_idx to the nearest 1000
                npz_key = int((row["npy_idx"] // self.NPZ_SPLIT) * self.NPZ_SPLIT)
                if npz_key != open_npz_key:
                    if open_npz is not None:
                        open_npz.close()
                    open_npz, open_npz_key = None, None
                    open_npz = np.load(target_files[npz_key])
                    open_npz_key = npz_key

                targets = open_npz[f"target-{npz_split_name}-{row['npy_idx']}.npy"][
                    0]  # select 0 since there is extra dimension

                # A slice ending at -0 would be empty, so only trim when
                # there is something to trim.
                if trim_bins > 0:
                    targets = targets[trim_bins:-trim_bins]

                if padded_sequence:
                    yield key, {
                        "labels": targets,
                        "sequence": standardize_sequence(padded_sequence),
                        "chromosome": re.sub("chr", "", chromosome),
                        "labels_start": new_start,
                        "labels_stop": new_stop
                    }
                    key += 1
        finally:
            if open_npz is not None:
                open_npz.close()

    @staticmethod
    def generate_npz_filenames(split, total, folder, npz_size=NPZ_SPLIT):
        """
        Generates a list of filenames for the npz files stored in the dataset.
        Yields a tuple of floored multiple of 1000, filename
        Args:
            split: split to generate filenames for. Must be in ['train', 'test', 'valid']
                due to the naming of the files.
            total: total number of npy targets for given split
            folder: folder where data is stored.
            npz_size: number of npy files per npz. Defaults to 1000 because
                this is the number currently used in the dataset.
        """

        for first_idx in range(0, total, npz_size):
            # The final archive may hold fewer than npz_size targets.
            last_idx = min(first_idx + npz_size, total) - 1
            yield first_idx, f"{folder}/targets-{split}-{first_idx}-{last_idx}.npz"



In [27]:
class BulkRnaExpressionHandler(GenomicLRATaskHandler):
    """
    Task handler for Bulk RNA Expression prediction.
    """

    DEFAULT_LENGTH = 100000

    def __init__(self, sequence_length=DEFAULT_LENGTH, **kwargs):
        """
        Builds a handler for the Bulk RNA Expression Prediction Task.
        Args:
            sequence_length: Length of the sequence around the TSS_CAGE start site
        """
        self.sequence_length = sequence_length
        # The remaining attributes are filled in by split_generators.
        self.reference_genome = None
        self.coordinate_csv_file = None
        self.labels_csv_file = None

    def get_info(self, description: str) -> DatasetInfo:
        """
        Builds the DatasetInfo for the Bulk RNA Expression dataset. An example
        consists of a genomic sequence, a list of label values, the chromosome
        and a position.
        """
        feature_spec = {
            # DNA sequence
            "sequence": datasets.Value("string"),
            # list of expression values in each tissue
            "labels": datasets.Sequence(datasets.Value("float32")),
            # chromosome number
            "chromosome": datasets.Value(dtype="string"),
            # position
            "position": datasets.Value(dtype="int32"),
        }
        features = datasets.Features(feature_spec)
        # description: text shown on the datasets page
        # features: columns of the dataset and their types
        return datasets.DatasetInfo(description=description, features=features)

    def split_generators(self, dl_manager, cache_dir_root):
        """
        Fetches everything the task needs and records the file locations on the
        instance: the hg19 reference genome, the coordinate csv file and the
        label csv file. Splits are then delegated to the parent class.
        """
        genome_path = self.download_and_extract_gz(
            H19_REFERENCE_GENOME_URL, cache_dir_root
        )
        self.reference_genome = Fasta(genome_path, one_based_attributes=False)

        coordinates_remote = "bulk_rna_expression/gene_coordinates.csv"
        self.coordinate_csv_file = dl_manager.download_and_extract(coordinates_remote)

        labels_remote = "bulk_rna_expression/rna_expression_values.csv"
        self.labels_csv_file = dl_manager.download_and_extract(labels_remote)

        return super().split_generators(dl_manager, cache_dir_root)

    def generate_examples(self, split):
        """
        Yields (key, example) pairs for the requested split. Every example holds
        the padded, standardized sequence together with its labels, chromosome
        and position. Rows for which no sequence is returned are skipped.
        """
        all_coordinates = pd.read_csv(self.coordinate_csv_file)
        all_labels = pd.read_csv(self.labels_csv_file)

        in_split = all_coordinates["split"] == split
        split_coordinates = all_coordinates[in_split]

        next_key = 0
        for row_label, gene in split_coordinates.iterrows():
            tss = gene["CAGE_representative_TSS"]
            # -1 since coords are 1-based
            start = tss - 1

            chromosome = gene["chrom"]
            # labels are looked up by the same row label as the coordinates
            expression_values = all_labels.loc[row_label].values

            sequence = pad_sequence(
                chromosome=self.reference_genome[chromosome],
                start=start,
                sequence_length=self.sequence_length,
                negative_strand=gene["strand"] == "-",
            )
            if not sequence:
                continue

            example = {
                "labels": expression_values,
                "sequence": standardize_sequence(sequence),
                "chromosome": re.sub("chr", "", chromosome),
                "position": tss,
            }
            yield next_key, example
            next_key += 1



In [28]:
class VariantEffectCausalEqtl(GenomicLRATaskHandler):
    """
    Task handler for Variant Effect Causal eQTL prediction.
    """

    DEFAULT_LENGTH = 100000

    def __init__(self, sequence_length=DEFAULT_LENGTH, **kwargs):
        """
        Builds a handler for the Variant Effect Causal eQTL Task.
        Args:
            sequence_length: Length of the sequence to pad around the SNP position
        """
        self.sequence_length = sequence_length
        # Filled in by split_generators.
        self.reference_genome = None

    def get_info(self, description: str) -> DatasetInfo:
        """
        Builds the DatasetInfo for the Variant Effect Causal eQTL dataset. An
        example carries the genomic sequence with the reference allele, the same
        sequence with the alternative allele, a binary label and metadata
        (tissue, chromosome, position, distance to nearest TSS).
        """
        feature_spec = {
            # DNA sequence
            "ref_forward_sequence": datasets.Value("string"),
            "alt_forward_sequence": datasets.Value("string"),
            # binary label
            "label": datasets.Value(dtype="int8"),
            # tissue type
            "tissue": datasets.Value(dtype="string"),
            # chromosome number
            "chromosome": datasets.Value(dtype="string"),
            # variant position
            "position": datasets.Value(dtype="int32"),
            # distance to nearest tss
            "distance_to_nearest_tss": datasets.Value(dtype="int32")
        }
        features = datasets.Features(feature_spec)

        # description: text shown on the datasets page
        # features: columns of the dataset and their types
        return datasets.DatasetInfo(description=description, features=features)

    def split_generators(self, dl_manager, cache_dir_root):
        """
        Fetches the hg38 reference genome and the csv holding coordinates and
        labels, records them on the instance, then delegates the split
        definitions to the parent class.
        """

        # The reference genome is downloaded manually since there are
        # difficulties when streaming
        genome_path = self.download_and_extract_gz(
            H38_REFERENCE_GENOME_URL, cache_dir_root
        )
        self.reference_genome = Fasta(genome_path, one_based_attributes=False)

        self.coordinates_labels_csv_file = dl_manager.download_and_extract(
            f"variant_effect_causal_eqtl/All_Tissues.csv"
        )

        return super().split_generators(dl_manager, cache_dir_root)

    def generate_examples(self, split):
        """
        Yields (key, example) pairs for the requested split. Every example holds
        the reference and alternative sequences (standardized), the binary label
        and the variant metadata. Rows for which no sequence is returned are
        skipped.
        """

        all_variants = pd.read_csv(self.coordinates_labels_csv_file)
        in_split = all_variants["split"] == split
        split_variants = all_variants[in_split]

        next_key = 0
        for _, variant in split_variants.iterrows():
            position = variant["POS"]
            # sub 1 to create idx since coords are 1-based
            start = position - 1
            alt_allele = variant["ALT"]
            label = variant["label"]
            tissue = variant['tissue']
            chromosome = variant["CHROM"]
            distance = int(variant["distance_to_nearest_TSS"])

            # reference forward sequence around the variant
            ref_forward = pad_sequence(
                chromosome=self.reference_genome[chromosome],
                start=start,
                sequence_length=self.sequence_length,
                negative_strand=False,
            )

            # nothing to emit unless a valid sequence came back
            if not ref_forward:
                continue

            # Substitute the alt allele at the SNP position, which is always
            # centered in the string returned from pad_sequence
            center = self.sequence_length // 2
            bases = list(ref_forward)
            bases[center] = alt_allele
            alt_forward = "".join(bases)

            example = {
                "label": label,
                "tissue": tissue,
                "chromosome": re.sub("chr", "", chromosome),
                "ref_forward_sequence": standardize_sequence(ref_forward),
                "alt_forward_sequence": standardize_sequence(alt_forward),
                "distance_to_nearest_tss": distance,
                "position": position,
            }
            yield next_key, example
            next_key += 1


In [29]:
class VariantEffectPathogenicHandler(GenomicLRATaskHandler):
    """
    Handler for the Variant Effect Pathogenic Prediction tasks.
    """

    DEFAULT_LENGTH = 100000

    def __init__(self, sequence_length=DEFAULT_LENGTH, task_name=None, subset=False,
                 **kwargs):
        """
        Creates a new handler for the Variant Effect Pathogenic Tasks.
        Args:
            sequence_length: Length of the sequence to pad around the SNP position
            task_name: 'variant_effect_pathogenic_clinvar' or
                'variant_effect_pathogenic_omim'. Any other value leaves the
                handler without a data file; split_generators then raises a
                ValueError.
            subset: Whether to return a pre-determined subset of the data. Only
                affects the 'variant_effect_pathogenic_omim' task.

        """
        self.sequence_length = sequence_length

        # Populated by split_generators.
        self.reference_genome = None
        self.coordinates_labels_csv_file = None

        # Stays None for an unrecognised task_name so that the problem is
        # reported clearly in split_generators rather than as an AttributeError.
        self.data_file_name = None

        if task_name == 'variant_effect_pathogenic_clinvar':
            self.data_file_name = "variant_effect_pathogenic/vep_pathogenic_coding.csv"
        elif task_name == 'variant_effect_pathogenic_omim':
            if subset:
                self.data_file_name = "variant_effect_pathogenic/" \
                                      "vep_pathogenic_non_coding_subset.csv"
            else:
                self.data_file_name = \
                    "variant_effect_pathogenic/vep_pathogenic_non_coding.csv"

    def get_info(self, description: str) -> DatasetInfo:
        """
        Returns the DatasetInfo for the Variant Effect Pathogenic datasets. Each example
        includes a  genomic sequence with the reference allele as well as the genomic
        sequence with the alternative allele, and a binary label.
        """
        features = datasets.Features(
            {
                # DNA sequence
                "ref_forward_sequence": datasets.Value("string"),
                "alt_forward_sequence": datasets.Value("string"),
                # binary label
                "label": datasets.Value(dtype="int8"),
                # chromosome number
                "chromosome": datasets.Value(dtype="string"),
                # position
                "position": datasets.Value(dtype="int32")
            }
        )

        return datasets.DatasetInfo(
            # This is the description that will appear on the datasets page.
            description=description,
            # This defines the different columns of the dataset and their types
            features=features,
        )

    def split_generators(self, dl_manager, cache_dir_root):
        """
        Separates files by split and stores filenames in instance variables.
        The variant effect prediction datasets require the reference hg38 genome and
        coordinates_labels_csv_file to be saved.

        The non-coding data file only provides a test split; the coding data
        file uses the default splits of the parent class.
        Raises:
            ValueError: if the handler was created with an unsupported task_name.
        """

        # Fail before the (large) genome download if there is no data file.
        if self.data_file_name is None:
            raise ValueError(
                "Unsupported task_name for VariantEffectPathogenicHandler; "
                "expected 'variant_effect_pathogenic_clinvar' or "
                "'variant_effect_pathogenic_omim'."
            )

        reference_genome_file = self.download_and_extract_gz(
            H38_REFERENCE_GENOME_URL, cache_dir_root
        )

        self.reference_genome = Fasta(reference_genome_file, one_based_attributes=False)
        self.coordinates_labels_csv_file = dl_manager.download_and_extract(
            self.data_file_name)

        if 'non_coding' in self.data_file_name:
            return [
                datasets.SplitGenerator(
                    name=datasets.Split.TEST,
                    gen_kwargs={"handler": self, "split": "test"}
                ), ]
        else:
            return super().split_generators(dl_manager, cache_dir_root)

    def generate_examples(self, split):
        """
        A generator which produces examples each with ref/alt allele
        and corresponding binary label. The sequences are extended to
        the desired sequence length and standardized before returning.
        Args:
            split: value of the csv "split" column to generate examples for.
        Yields:
            (key, example) tuples; rows for which no sequence is returned are
            skipped.
        """

        coordinates_df = pd.read_csv(self.coordinates_labels_csv_file)
        coordinates_split_df = coordinates_df[coordinates_df["split"] == split]

        key = 0
        for _, row in coordinates_split_df.iterrows():
            start = row["POS"] - 1  # sub 1 to create idx since coords are 1-based
            alt_allele = row["ALT"]
            label = row["INT_LABEL"]
            chromosome = row["CHROM"]

            # get reference forward sequence
            ref_forward = pad_sequence(
                chromosome=self.reference_genome[chromosome],
                start=start,
                sequence_length=self.sequence_length,
                negative_strand=False,
            )

            # only if a valid sequence returned
            if ref_forward:
                # Mutate sequence with the alt allele at the SNP position,
                # which is always centered in the string returned from pad_sequence
                alt_forward = list(ref_forward)
                alt_forward[self.sequence_length // 2] = alt_allele
                alt_forward = "".join(alt_forward)

                yield key, {
                    "label": label,
                    "chromosome": re.sub("chr", "", chromosome),
                    "ref_forward_sequence": standardize_sequence(ref_forward),
                    "alt_forward_sequence": standardize_sequence(alt_forward),
                    "position": row['POS']
                }
                key += 1




In [30]:
class ChromatinFeaturesHandler(GenomicLRATaskHandler):
    """
    Handler for the histone marks and DNA accessibility tasks also referred to
    collectively as Chromatin features.
    """

    DEFAULT_LENGTH = 100000

    def __init__(self, task_name=None, sequence_length=DEFAULT_LENGTH, subset=False,
                 **kwargs):
        """
        Creates a new handler for the Deep Sea Histone and DNase tasks.
        Args:
            task_name: name of the task. A name containing 'histone' selects
            the 'HISTONES' label column; otherwise a name containing 'dna'
            selects the 'DNASE' label column.
            sequence_length: Length of the sequence around and including the
            annotated 200bp bin
            subset: Whether to return a pre-determined subset of the entire dataset.
        Raises:
            ValueError: if sequence_length is below 200 or task_name is None.

        """
        if sequence_length < 200:
            raise ValueError(
                'Sequence length for this task must be greater or equal to 200 bp')

        # Previously a missing task_name surfaced as an opaque
        # "argument of type 'NoneType' is not iterable" TypeError.
        if task_name is None:
            raise ValueError(
                "task_name is required and must contain 'histone' or 'dna'")

        self.sequence_length = sequence_length

        # Populated by split_generators.
        self.reference_genome = None
        self.coordinate_csv_file = None

        # Stays None for an unrecognised task_name; generate_examples reports
        # that clearly instead of failing with an AttributeError.
        self.label_name = None
        if 'histone' in task_name:
            self.label_name = 'HISTONES'
        elif 'dna' in task_name:
            self.label_name = 'DNASE'

        self.data_file_name = "chromatin_features/histones_and_dnase_subset.csv" if \
            subset else "chromatin_features/histones_and_dnase.csv"

    def get_info(self, description: str) -> DatasetInfo:
        """
        Returns the DatasetInfo for the histone marks and dna accessibility datasets.
        Each example includes a genomic sequence and a list of label values.
        """
        features = datasets.Features(
            {
                # DNA sequence
                "sequence": datasets.Value("string"),
                # list of binary chromatin marks
                "labels": datasets.Sequence(datasets.Value("int8")),
                # chromosome number
                "chromosome": datasets.Value(dtype="string"),
                # starting position in genome which corresponds to label
                "label_start": datasets.Value(dtype="int32"),
                # end position in genome which corresponds to label
                "label_stop": datasets.Value(dtype="int32"),
            }
        )
        return datasets.DatasetInfo(
            # This is the description that will appear on the datasets page.
            description=description,
            # This defines the different columns of the dataset and their types
            features=features,

        )

    def split_generators(self, dl_manager, cache_dir_root):
        """
        Separates files by split and stores filenames in instance variables.
        The histone marks and dna accessibility datasets require the reference hg19
        genome and coordinate csv file to be saved.
        """
        reference_genome_file = self.download_and_extract_gz(
            H19_REFERENCE_GENOME_URL, cache_dir_root
        )
        self.reference_genome = Fasta(reference_genome_file, one_based_attributes=False)

        self.coordinate_csv_file = dl_manager.download_and_extract(self.data_file_name)

        return super().split_generators(dl_manager, cache_dir_root)

    def generate_examples(self, split):
        """
        A generator which produces examples for the given split, each with a sequence
        and the corresponding labels. The sequences are padded to the correct sequence
        length and standardized before returning.
        Args:
            split: value of the csv "split" column to generate examples for.
        Yields:
            (key, example) tuples; rows for which no sequence is returned are
            skipped.
        Raises:
            ValueError: if the handler's task_name matched neither 'histone'
            nor 'dna', so no label column is known.
        """
        if self.label_name is None:
            raise ValueError(
                "Unsupported task_name for ChromatinFeaturesHandler; it must "
                "contain 'histone' or 'dna'.")

        coordinates_df = pd.read_csv(self.coordinate_csv_file)
        coordinates_split_df = coordinates_df[coordinates_df["split"] == split]

        key = 0
        for _, coordinates_row in coordinates_split_df.iterrows():
            start = coordinates_row['POS'] - 1  # -1 since saved coords are 1-based
            chromosome = coordinates_row["CHROM"]

            # literal eval used since lists are saved as strings in csv
            labels_row = literal_eval(coordinates_row[self.label_name])

            padded_sequence = pad_sequence(
                chromosome=self.reference_genome[chromosome],
                start=start,
                sequence_length=self.sequence_length,
            )
            if padded_sequence:
                yield key, {
                    "labels": labels_row,
                    "sequence": standardize_sequence(padded_sequence),
                    "chromosome": re.sub("chr", "", chromosome),
                    # 200 positions from POS - 100 to POS + 99 inclusive
                    "label_start": coordinates_row['POS']-100,
                    "label_stop": coordinates_row['POS'] + 99,
                }
                key += 1


class RegulatoryElementHandler(GenomicLRATaskHandler):
    """
    Handler for the Regulatory Element Prediction tasks.

    The task name picks the promoter or the enhancer coordinate csv; every example
    is a reference-genome window of ``sequence_length`` bases around a labeled
    interval.
    """
    DEFAULT_LENGTH = 100000

    def __init__(self, task_name=None, sequence_length=DEFAULT_LENGTH, subset=False,
                 **kwargs):
        """
        Creates a new handler for the Regulatory Element Prediction tasks.
        Args:
            task_name: Task identifier; must contain 'promoter' or 'enhancer',
                which selects the coordinate csv to load.
            sequence_length: Length of the sequence around the element/non-element
            subset: Whether to return a pre-determined subset of the entire dataset.
            **kwargs: Accepted for signature compatibility; not used here.
        Raises:
            ValueError: If sequence_length is below 200, or if task_name names
                neither a promoter nor an enhancer task.
        """

        if sequence_length < 200:
            raise ValueError(
                'Sequence length for this task must be greater or equal to 200 bp')

        self.sequence_length = sequence_length

        # The default task_name is None; normalise it so the substring tests
        # below cannot fail with a TypeError.
        requested_task = task_name or ''

        if 'promoter' in requested_task:
            self.data_file_name = 'regulatory_elements/promoter_dataset'

        elif 'enhancer' in requested_task:
            self.data_file_name = 'regulatory_elements/enhancer_dataset'

        else:
            # Without this branch data_file_name would stay undefined and the
            # suffix concatenation below would fail with an unrelated error.
            raise ValueError(
                f"Unknown regulatory element task {task_name!r}: the task name "
                "must contain 'promoter' or 'enhancer'")

        if subset:
            self.data_file_name += '_subset.csv'
        else:
            self.data_file_name += '.csv'

    def get_info(self, description: str) -> DatasetInfo:
        """
        Returns the DatasetInfo for the Regulatory Element Prediction Tasks.
        Each example includes a genomic sequence and a label.
        Args:
            description: Text stored as the dataset description.
        """
        features = datasets.Features(
            {
                # DNA sequence
                "sequence": datasets.Value("string"),
                # label corresponding to whether the sequence has
                # the regulatory element of interest or not
                "labels": datasets.Value("int8"),
                # chromosome number
                "chromosome": datasets.Value(dtype="string"),
                # start
                "label_start": datasets.Value(dtype="int32"),
                # stop
                "label_stop": datasets.Value(dtype="int32"),
            }
        )
        return datasets.DatasetInfo(
            # This is the description that will appear on the datasets page.
            description=description,
            # This defines the different columns of the dataset and their types
            features=features,

        )

    def split_generators(self, dl_manager, cache_dir_root):
        """
        Prepares the files needed to generate examples and returns the splits.

        The hg38 reference genome is downloaded/extracted into ``cache_dir_root``
        and opened with pyfaidx, the coordinate csv is resolved through
        ``dl_manager``, and the split list itself comes from the parent class.
        """
        reference_genome_file = self.download_and_extract_gz(
            H38_REFERENCE_GENOME_URL, cache_dir_root
        )
        self.reference_genome = Fasta(reference_genome_file, one_based_attributes=False)

        self.coordinate_csv_file = dl_manager.download_and_extract(
            self.data_file_name
        )

        return super().split_generators(dl_manager, cache_dir_root)

    def generate_examples(self, split):
        """
        A generator which produces examples for the given split, each with a sequence
        and the corresponding label. The sequences are padded to the correct sequence
        length and standardized before returning. Rows for which no padded sequence
        can be produced are skipped, so keys stay consecutive.
        """
        coordinates_df = pd.read_csv(self.coordinate_csv_file)

        coordinates_split_df = coordinates_df[coordinates_df["split"] == split]

        key = 0
        for _, coordinates_row in coordinates_split_df.iterrows():
            # both bounds are shifted by -1 because the stored coords are 1-based
            start = coordinates_row["START"] - 1
            end = coordinates_row["STOP"] - 1
            chromosome = coordinates_row["CHROM"]

            label = coordinates_row['label']

            padded_sequence = pad_sequence(
                chromosome=self.reference_genome[chromosome],
                start=start,
                end=end,
                sequence_length=self.sequence_length,
            )

            if padded_sequence:
                yield key, {
                    "labels": label,
                    "sequence": standardize_sequence(padded_sequence),
                    # every "chr" occurrence is removed from the chromosome name
                    "chromosome": re.sub("chr", "", chromosome),
                    # the original (unshifted) interval bounds are reported
                    "label_start": coordinates_row["START"],
                    "label_stop": coordinates_row["STOP"]
                }
                key += 1


In [31]:
"""
----------------------------------------------------------------------------------------
Dataset loader:
----------------------------------------------------------------------------------------
"""

# Description text that the dataset builder hands to each handler's get_info().
_DESCRIPTION = """
Dataset for benchmark of genomic deep learning models. 
"""

# Maps every supported task name to the handler class that builds its dataset.
# GenomicsLRAConfig looks the class up by task name and instantiates it with that
# same name, so one handler class can serve several related tasks.
_TASK_HANDLERS = {
    "cage_prediction": CagePredictionHandler,
    "bulk_rna_expression": BulkRnaExpressionHandler,
    "variant_effect_causal_eqtl": VariantEffectCausalEqtl,
    "variant_effect_pathogenic_clinvar": VariantEffectPathogenicHandler,
    "variant_effect_pathogenic_omim": VariantEffectPathogenicHandler,
    "chromatin_features_histone_marks": ChromatinFeaturesHandler,
    "chromatin_features_dna_accessibility": ChromatinFeaturesHandler,
    "regulatory_element_promoter": RegulatoryElementHandler,
    "regulatory_element_enhancer": RegulatoryElementHandler,
}


# define dataset configs
class GenomicsLRAConfig(datasets.BuilderConfig):
    """
    BuilderConfig that carries the task handler for the selected task.
    """

    def __init__(self, *args, task_name: str, **kwargs):  # type: ignore
        """
        Builds the config and instantiates the handler registered for ``task_name``.
        Args:
            *args: accepted but not used.
            task_name: key of ``_TASK_HANDLERS`` selecting the handler class.
            **kwargs: keyword arguments passed on to the handler constructor
                (they are not forwarded to ``BuilderConfig``).
        """
        super().__init__()
        handler_class = _TASK_HANDLERS[task_name]
        self.handler = handler_class(task_name=task_name, **kwargs)


# DatasetBuilder
class GenomicsLRATasks(datasets.GeneratorBasedBuilder):
    """
    Dataset builder that delegates every step to the task handler stored on
    its config.
    """

    VERSION = datasets.Version("1.1.0")
    BUILDER_CONFIG_CLASS = GenomicsLRAConfig

    def _info(self) -> DatasetInfo:
        """
        Returns the handler's DatasetInfo, built with the shared description.
        """
        task_handler = self.config.handler
        return task_handler.get_info(description=_DESCRIPTION)

    def _split_generators(
            self, dl_manager: datasets.DownloadManager
    ) -> List[datasets.SplitGenerator]:
        """
        Lets the handler fetch its data files and describe the dataset splits.
        """
        task_handler = self.config.handler
        cache_root = self._cache_dir_root
        return task_handler.split_generators(dl_manager, cache_root)

    def _generate_examples(self, handler, split):
        """
        Yields the examples the handler produces for one split.
        Args:
            handler: The handler for the current task
            split: Name of the split to generate, as stored in the split
                generator's gen_kwargs
        """
        examples = handler.generate_examples(split)
        yield from examples



In [2]:
"""
----------------------------------------------------------------------------------------
Global Utils:
----------------------------------------------------------------------------------------
"""


def standardize_sequence(sequence: str):
    """
    Returns the sequence in upper case with every character other than
    A, T, C or G replaced by N.
    Args:
        sequence: genomic sequence to standardize
    """
    upper_cased = sequence.upper()
    # anything outside the four canonical bases collapses to N
    return re.sub("[^ATCG]", "N", upper_cased)


def pad_sequence(chromosome, start, sequence_length, end=None, negative_strand=False,
                 return_new_start_stop=False):
    """
    Extends a given sequence to exactly sequence_length bases. If the extended
    window does not fit inside the chromosome, returns None.
    Args:
        chromosome: Chromosome from pyfaidx extracted Fasta. It must support
            len() and slicing, and a slice must expose `.seq` (plus
            `.reverse.complement` when negative_strand is set).
        start: Start index of original sequence.
        sequence_length: Desired sequence length. If the number of bases to add
            is odd, the remainder is added to the end of the sequence.
        end: End index of original sequence. If no end is specified, it creates a
            centered sequence around the start index.
        negative_strand: If negative_strand, returns the reverse complement of the
            sequence
        return_new_start_stop: If set, also returns the bounds of the extended
            window.
    Returns:
        None when the window falls outside the chromosome; otherwise the sequence
        string, or the tuple (sequence, start, end) when return_new_start_stop
        is set.
    """
    if end is not None:
        # Spread the missing bases over both sides. The parity of the *missing*
        # length (not of sequence_length) decides whether one extra base goes to
        # the end, which keeps the window at exactly sequence_length bases.
        missing = sequence_length - (end - start)
        pad = missing // 2
        start = start - pad
        end = end + pad + (missing % 2)
    else:
        pad = sequence_length // 2
        end = start + pad + (sequence_length % 2)
        start = start - pad

    # [start:end] is a half-open slice, so end == len(chromosome) is still valid
    if start < 0 or end > len(chromosome):
        return

    window = chromosome[start:end]
    if negative_strand:
        window = window.reverse.complement
    sequence = window.seq

    if return_new_start_stop:
        return sequence, start, end

    return sequence

In [33]:
# def download_and_extract_gz(file_url, cache_dir_root):
#         """
#         Downloads and extracts a gz file into the given cache directory. Returns the
#         full file path of the extracted gz file.
#         Args:
#             file_url: url of the gz file to be downloaded and extracted.
#             cache_dir_root: Directory to extract file into.
#         """
#         file_fname = Path(file_url).stem
#         file_complete_path = os.path.join(cache_dir_root, "downloads", file_fname)

#         if not os.path.exists(file_complete_path):
#             if not os.path.exists(file_complete_path + ".gz"):
#                 with tqdm(
#                         unit="B",
#                         unit_scale=True,
#                         unit_divisor=1024,
#                         miniters=1,
#                         desc=file_url.split("/")[-1],
#                 ) as t:
#                     urllib.request.urlretrieve(
#                         file_url, file_complete_path + ".gz", reporthook=self.hook(t) 
#                     )
#             with gzip.open(file_complete_path + ".gz", "rb") as file_in:
#                 with open(file_complete_path, "wb") as file_out:
#                     shutil.copyfileobj(file_in, file_out)
#         return file_complete_path




# def download_and_extract_gz(file_url, cache_dir_root):
#     file_complete_path = os.path.join(cache_dir_root, file_url.split("/")[-1])
#     if not os.path.exists(file_complete_path):
#         # Download the file if not exists
#         with tqdm(
#             unit="B",
#             unit_scale=True,
#             unit_divisor=1024,
#             miniters=1,
#             desc=file_url.split("/")[-1],
#         ) as t:
#             urllib.request.urlretrieve(
#                 file_url, file_complete_path + ".gz", reporthook=lambda b, bs, tbs: t.update(bs)
#             )
#         with gzip.open(file_complete_path + ".gz", "rb") as file_in:
#             with open(file_complete_path, "wb") as file_out:
#                 shutil.copyfileobj(file_in, file_out)

#     return file_complete_path


In [34]:
# def PreProcess(data_filename, label):
#     # df=pd.read_csv(data_filename, delimiter='\t',  names=['chromosome', 'position', 'REF','ALT','info'])
#     df=pd.read_csv(data_filename, header=0, sep='\t', index_col=None)   # With head column row

#     df=df[(df['REF'].str.len() == 1) & (df['ALT'].str.len() == 1)]
#     # df['chromosome'] = df['chromosome'].str.replace('chr', '')
#     df['label']=label
#     df['ref_forward_sequence']=''
#     df['alt_forward_sequence']=''

#     df.drop(columns=['Consequence'], inplace=True)
#     df = df.reset_index(drop=True)
#     # df = df[~df['CHROM'].isin(['9', '10'])] # chromosome

#     df = df.rename(columns={'CHROM': 'chromosome', 'POS':'position'})

#     return df
    
# df1 = PreProcess('gnomad.v4.1.exon.txt',0)
# df2 = PreProcess('gnomad.v4.1.intergenic.txt',1)
# df3 = PreProcess('gnomad.v4.1.proximity.txt',2)
# # df = df.drop(['Consequence'], axis=1)
# # df = df.reset_index(drop=True)
# # df = df[~df['chromosome'].isin(['9', '10'])]


# df = pd.concat([df1, df2,df3], ignore_index=True)
# # df.to_csv('gnomad.v4.1.caduceus.csv', index=False)
# df

In [5]:
# datafile='clinvar_20240805.noncoding'
# csv_filename = datafile+'.txt'
# df=pd.read_csv(datafile+'.txt', delimiter='\t')
# df

Unnamed: 0,ID,CHROM,POS,REF,ALT,CLNDN,CLNREVSTAT,CLNSIGCONF,Pathogenicity,Symbol,Transcript,Consequence,AAChange,gnomAD_AF_popmax,UKBB_AF,Ensembl,OMIM
0,753921,1,1020383,C,T,not_provided,"criteria_provided,_single_submitter",.,B,AGRN,.,intron,.,-1.000000,0.000005,ENSG00000188157,"Myasthenic syndrome, congenital, 8, with pre- ..."
1,1616405,1,1020390,C,T,Congenital_myasthenic_syndrome_8,"criteria_provided,_single_submitter",.,B,AGRN,.,intron,.,-1.000000,-1.000000,ENSG00000188157,"Myasthenic syndrome, congenital, 8, with pre- ..."
2,2970137,1,1020391,C,G,Congenital_myasthenic_syndrome_8,"criteria_provided,_single_submitter",.,B,AGRN,.,intron,.,-1.000000,0.000024,ENSG00000188157,"Myasthenic syndrome, congenital, 8, with pre- ..."
3,1642546,1,1020392,C,T,Congenital_myasthenic_syndrome_8,"criteria_provided,_single_submitter",.,B,AGRN,.,intron,.,0.000713,0.000024,ENSG00000188157,"Myasthenic syndrome, congenital, 8, with pre- ..."
4,1663222,1,1020392,C,G,Congenital_myasthenic_syndrome_8,"criteria_provided,_single_submitter",.,B,AGRN,.,intron,.,0.000994,0.000015,ENSG00000188157,"Myasthenic syndrome, congenital, 8, with pre- ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95755,1297951,X,154991043,A,G,not_provided,"criteria_provided,_single_submitter",.,B,F8,.,intron,.,-1.000000,-1.000000,ENSG00000185010,"Thrombophilia 13, X-linked, due to factor VIII..."
95756,994413,X,154992919,C,T,Hereditary_factor_VIII_deficiency_disease,"criteria_provided,_single_submitter",.,B,F8,.,intron,.,0.015161,0.000309,ENSG00000185010,"Thrombophilia 13, X-linked, due to factor VIII..."
95757,368125,X,154993157,G,A,not_provided|Hereditary_factor_VIII_deficiency...,"criteria_provided,_multiple_submitters,_no_con...",.,B,F8,.,intron,.,0.027232,0.000142,ENSG00000185010,"Thrombophilia 13, X-linked, due to factor VIII..."
95758,804143,X,154997114,C,T,not_provided|Hereditary_factor_VIII_deficiency...,"criteria_provided,_multiple_submitters,_no_con...",.,B,F8,.,intron,.,0.020523,0.000588,ENSG00000185010,"Thrombophilia 13, X-linked, due to factor VIII..."


In [3]:
# Build the per-variant coordinate table from the ClinVar non-coding export.
# csv_filename = 'Homo_sapiens.GRCh38.109.txt'
datafile='clinvar_20240805.noncoding'
csv_filename = datafile+'.txt'

# Window width stored in the 'size' column of every row
max_length= 128 # 186

columns_to_keep=['CHROM','POS','ID','REF','ALT','Pathogenicity']

# .copy() yields an independent frame, so the column assignments below never
# write through a slice of the full table.
df = pd.read_csv(csv_filename, delimiter='\t')[columns_to_keep].copy()

# Chromosome values equal to the integers 1..22 are rewritten as strings so the
# column can be handled with string operations.
for i in range(1,23):
    df.loc[df['CHROM']==i,'CHROM']=str(i)

# Drop rows without a chromosome and rows whose chromosome name contains
# 'KI' or 'GL' (presumably unplaced scaffolds - confirm).
df=df[~df['CHROM'].isna()]
df = df[~df['CHROM'].str.contains('KI')]
df = df[~df['CHROM'].str.contains('GL')]

df['START']= df['POS']- max_length //2 -1
df['END']  = df['START'] + max_length
df['size'] = max_length

# Pathogenicity codes outside this mapping become NaN labels.
Pathogenicity_dict={'B':0,'P':1}
df['label'] = df['Pathogenicity'].map(Pathogenicity_dict)
# df['CHROM'].value_counts()

# df['CHROM'] ='chr'+df['CHROM']
# NOTE(review): the midpoint of [START, END] works out to the original POS - 1
# for an even max_length, so 'position' ends up shifted by one - confirm that
# this is intended.
df['POS'] = ((df['START'] + df['END']) / 2).round().astype(int)

# NOTE(review): the real alleles are replaced by the placeholder 'A'; any code
# that writes ALT into a sequence will therefore insert 'A' - confirm.
df['REF']='A'
df['ALT']='A'
df['ref_forward_sequence']=''
df['alt_forward_sequence']=''

df = df.rename(columns={'CHROM': 'chromosome', 'POS':'position'})
df = df.drop(columns=['ID','Pathogenicity','START','END'])

df

Unnamed: 0,chromosome,position,REF,ALT,size,label,ref_forward_sequence,alt_forward_sequence
0,1,1020382,A,A,128,0,,
1,1,1020389,A,A,128,0,,
2,1,1020390,A,A,128,0,,
3,1,1020391,A,A,128,0,,
4,1,1020391,A,A,128,0,,
...,...,...,...,...,...,...,...,...
95755,X,154991042,A,A,128,0,,
95756,X,154992918,A,A,128,0,,
95757,X,154993156,A,A,128,0,,
95758,X,154997113,A,A,128,0,,


In [4]:
# csv_filename = 'Homo_sapiens.GRCh38.109.txt'
# datafile='clinvar_20240805.noncoding'
# csv_filename = datafile+'.txt'
# df=pd.read_csv(csv_filename, delimiter='\t')
# # df['cluster'].value_counts()
# cluster_dict = {
#     'first_exon': 0,
#     'first_intron': 1,
#     'first_three_prime_UTR': 2,
#     'first_five_prime_UTR': 3,
#     'ncRNA_gene': 4,
#     'pseudogene': 5,
#     'smallRNA': 6
# }

# # Assuming df is your DataFrame and 'cluster' is the column with string labels
# # df['chrom'] ='chr'+df['chrom']
# df['POS'] = ((df['start'] + df['end']) / 2).round().astype(int)
# df['REF']='A'
# df['ALT']='A'
# df['label'] = df['cluster'].map(cluster_dict)
# df['ref_forward_sequence']=''
# df['alt_forward_sequence']=''

# df = df.rename(columns={'chrom': 'chromosome', 'POS':'position'})
# # df.columns = df.columns.str.upper()
# df.drop(columns=['start','end','type','cluster','size'], inplace=True)

# # filtered_df = df[df['chromosome'].str.contains('KI')]
# df = df[~df['chromosome'].str.contains('KI')]
# df = df[~df['chromosome'].str.contains('GL')]

# df

In [46]:
# filtered_df = df[df['chromosome'].str.contains('GL')]
# filtered_df 
# df_filtered = df[~df['chromosome'].str.contains('KI')]
# df_filtered
# print(filtered_df)

In [33]:
# df.loc[df['label']==0]
# df.loc[df['label']==1]
# len(df)

In [34]:
# def generate_examples(df, sequence_length=131072):
#         """
#         A generator which produces examples each with ref/alt allele
#         and corresponding binary label. The sequences are extended to
#         the desired sequence length and standardized before returning.
#         """

#         cache_dir_root="/home/sunhuaikuan/ondemand/blue_caduceus/"
#         reference_genome_file = download_and_extract_gz(H38_REFERENCE_GENOME_URL, cache_dir_root)
#         reference_genome = Fasta(reference_genome_file, one_based_attributes=False)

#         key = 0
#         for idx, row in df.iterrows():
#             start = row["position"] - 1  # sub 1 to create idx since coords are 1-based
#             alt_allele = row["ALT"]
#             label = row["label"]
#             chromosome = row["chromosome"]

#             # get reference forward sequence
#             ref_forward = pad_sequence(
#                 chromosome=reference_genome[chromosome],
#                 start=start,
#                 sequence_length=sequence_length,
#                 negative_strand=False,
#             )

#             # only if a valid sequence returned
#             if ref_forward:
#                 # Mutate sequence with the alt allele at the SNP position,
#                 # which is always centered in the string returned from pad_sequence
#                 alt_forward = list(ref_forward)
#                 alt_forward[sequence_length // 2] = alt_allele
#                 alt_forward = "".join(alt_forward)

#                 yield key, {
#                     "label": label,
#                     # "chromosome": re.sub("chr", "", chromosome),
#                     "chromosome": chromosome,
#                     "ref_forward_sequence": standardize_sequence(ref_forward),
#                     "alt_forward_sequence": standardize_sequence(alt_forward),
#                     "position": row['position']
#                 }
#                 key += 1


# df2=generate_examples(df)

# max_rows = 2
# count = 0
# for key, example in df2:
#     if count >= max_rows:
#         break
        
#     # print(f"Key: {key}")
#     # print(f"Label: {example['label']}")
#     # print(f"Chromosome: {example['chromosome']}")
#     # print(f"Reference Forward Sequence: {example['ref_forward_sequence'][:100]}")
#     # print(f"Alternate Forward Sequence: {len(example['alt_forward_sequence'])}")
#     # print(f"Position: {example['position']}")
#     # print()
    
#     list1 =example['ref_forward_sequence']
#     list2 =example['alt_forward_sequence'] # [10, 8, 10, 10, 10, 10, 8, 10, 7]
    
#     print('list lenth='+str(len(list1)))
    
#     # Find the differing entries
#     differing_entries = [(i, list1[i], list2[i]) for i in range(len(list1)) if list1[i] != list2[i]]
    
#     # Print the results
#     if differing_entries:
#         print("Entries where the values differ:")
#         for entry in differing_entries:
#             index, value1, value2 = entry
#             print(f"Index {index}: List1 has {value1}, List2 has {value2}")
#     else:
#         print("The lists are identical.")
    
    
#     count += 1
# df2

In [35]:
# The approach below exhausts memory: it accumulates every ref/alt sequence (131,072 characters each) in Python lists before building the dataset.

# ref_forward_sequences = []
# alt_forward_sequences = []

# generator=generate_examples(df)

# # Iterate over the generator and populate the dictionaries
# for key, example in generator:
#     ref_forward_sequences.append(example['ref_forward_sequence'])
#     alt_forward_sequences.append(example['alt_forward_sequence'])


# df['ref_forward_sequence'] = ref_forward_sequences
# df['alt_forward_sequence'] = alt_forward_sequences

# dataset = Dataset.from_pandas(df)
# # Save the dataset to disk
# dataset.save_to_disk('prof_dataset')

In [36]:
# def generate_and_yield_examples(df, sequence_length=131072):
#         """
#         A generator which produces examples each with ref/alt allele
#         and corresponding binary label. The sequences are extended to
#         the desired sequence length and standardized before returning.
#         """

#         cache_dir_root="/home/sunhuaikuan/ondemand/blue_caduceus/"
#         reference_genome_file = download_and_extract_gz(H38_REFERENCE_GENOME_URL, cache_dir_root)
#         reference_genome = Fasta(reference_genome_file, one_based_attributes=False)

      
#         for idx, row in df.iterrows():
#             start = row["position"] - 1  # sub 1 to create idx since coords are 1-based
#             alt_allele = row["ALT"]
#             label = row["label"]
#             chromosome = row["chromosome"]

#             # get reference forward sequence
#             ref_forward = pad_sequence(
#                 chromosome=reference_genome[chromosome],
#                 start=start,
#                 sequence_length=sequence_length,
#                 negative_strand=False,
#             )

#             # only if a valid sequence returned
#             if ref_forward:
#                 # Mutate sequence with the alt allele at the SNP position,
#                 # which is always centered in the string returned from pad_sequence
#                 alt_forward = list(ref_forward)
#                 alt_forward[sequence_length // 2] = alt_allele
#                 alt_forward = "".join(alt_forward)

#                 yield  {
#                     "label": label,
#                     # "chromosome": re.sub("chr", "", chromosome),
#                     "chromosome": chromosome,
#                     "ref_forward_sequence": standardize_sequence(ref_forward),
#                     "alt_forward_sequence": standardize_sequence(alt_forward),
#                     "position": row['position']
#                 }
                
                
        
# features = Features({
#     'label': Value('int64'),
#     'chromosome': Value('string'),
#     'ref_forward_sequence': Value('string'),
#     'alt_forward_sequence': Value('string'),
#     'position': Value('int64')
# })

# dataset = Dataset.from_generator(lambda: generate_and_yield_examples(df), features=features)

# dataset.save_to_disk('prof_dataset')

In [10]:
%%time

def generate_and_yield_examples(df, sequence_length=131072):
    """
    Generator yielding one example dict per usable row of ``df``.

    Each example holds the label, the chromosome, the row position and two
    standardized windows of ``sequence_length`` bases: the reference window and
    a copy whose center base is replaced by the row's ALT allele. Rows for
    which no window can be cut are skipped.
    """
    cache_dir_root="/home/sunhuaikuan/ondemand/blue_caduceus/"
    genome_path = download_and_extract_gz(H38_REFERENCE_GENOME_URL, cache_dir_root)
    genome = Fasta(genome_path, one_based_attributes=False)

    # the window starts sequence_length // 2 bases before the variant, so this
    # is the index of the variant inside the window
    center = sequence_length // 2

    for _, row in df.iterrows():
        start = row["position"] - 1  # sub 1 to create idx since coords are 1-based
        alt_allele = row["ALT"]
        label = row["label"]
        chromosome = row["chromosome"]

        # reference forward sequence around the variant
        ref_forward = pad_sequence(
            chromosome=genome[chromosome],
            start=start,
            sequence_length=sequence_length,
            negative_strand=False,
        )

        # skip rows for which no valid sequence was returned
        if not ref_forward:
            continue

        # mutate the center base to obtain the alternate sequence
        bases = list(ref_forward)
        bases[center] = alt_allele
        alt_forward = "".join(bases)

        yield {
            "label": label,
            "chromosome": chromosome,
            "ref_forward_sequence": standardize_sequence(ref_forward),
            "alt_forward_sequence": standardize_sequence(alt_forward),
            "position": row['position']
        }
                        
# Schema of the examples produced by generate_and_yield_examples.
features = Features({
    'label': Value('int64'),
    'chromosome': Value('string'),
    'ref_forward_sequence': Value('string'),
    'alt_forward_sequence': Value('string'),
    'position': Value('int64')
})

# Build the whole dataset in one pass from the module-level frame `df`.
# NOTE(review): the recorded run of this cell ended in a KeyboardInterrupt.
dataset = Dataset.from_generator(lambda: generate_and_yield_examples(df), features=features, keep_in_memory=True)
# dataset.save_to_disk('prof_dataset')

# Only a 'train' split is stored.
dataset_dict = DatasetDict({
    'train': dataset,
    # 'test': dataset
})

save_path="pathgenicity_noncoding_dataset"
dataset_dict.save_to_disk(save_path)

# The message mentions batches, but this cell saves the dataset in a single call.
print("Dataset saved in batches successfully.")

# Read the saved dataset back from disk.
dataset_dict = load_from_disk(save_path)
train = dataset_dict['train']

# `df` is rebound here: from now on it holds the reloaded examples, not the
# coordinate frame that was fed to the generator above.
df = train.to_pandas()
df.head(2)

Generating train split: 0 examples [00:00, ? examples/s]

KeyboardInterrupt: 

In [5]:
%%time

# os.environ["HF_DATASETS_CACHE"] = "/blue/xiaofan/sunhuaikuan/larger/cache"


import os
import pandas as pd
from datasets import Dataset, DatasetDict, Features, Value, concatenate_datasets
from datasets.utils.file_utils import cached_path


# Define the function to generate examples from a chunk of data
def generate_and_yield_examples(df_chunk, sequence_length=131072,
                                reference_fasta='/home/sunhuaikuan/ondemand/blue_caduceus/hg38.fa'):
    """
    Generator yielding one example dict per usable row of ``df_chunk``.

    Args:
        df_chunk: DataFrame with the columns 'position', 'ALT', 'label' and
            'chromosome', and optionally a per-row 'size' column.
        sequence_length: Window length used when ``df_chunk`` has no 'size'
            column. When the column exists, each row's own 'size' is used.
        reference_fasta: Path of the reference genome FASTA opened with pyfaidx.
    Yields:
        Dicts with the label, chromosome, position and two standardized windows:
        the reference window and a copy whose center base is replaced by the
        row's ALT allele. Rows for which no window can be cut are skipped.
    """
    reference_genome = Fasta(reference_fasta, one_based_attributes=False)

    # Decide once whether rows carry their own window length.
    has_row_size = "size" in df_chunk.columns

    for _, row in df_chunk.iterrows():
        start = row["position"] - 1  # sub 1 to create idx since coords are 1-based
        alt_allele = row["ALT"]
        label = row["label"]
        chromosome = row["chromosome"]
        row_length = row["size"] if has_row_size else sequence_length

        # Get reference forward sequence
        ref_forward = pad_sequence(
            chromosome=reference_genome[chromosome],
            start=start,
            sequence_length=row_length,
            negative_strand=False,
        )

        # Only if a valid sequence returned
        if not ref_forward:
            continue

        # Mutate sequence with the alt allele at the SNP position; the window
        # starts row_length // 2 bases before it, hence this index
        alt_forward = list(ref_forward)
        alt_forward[row_length // 2] = alt_allele
        alt_forward = "".join(alt_forward)

        yield {
            "label": label,
            "chromosome": chromosome,
            "ref_forward_sequence": standardize_sequence(ref_forward),
            "alt_forward_sequence": standardize_sequence(alt_forward),
            "position": row['position']
        }

# Define the features of the dataset
features = Features({
    'label': Value('int64'),
    'chromosome': Value('string'),
    'ref_forward_sequence': Value('string'),
    'alt_forward_sequence': Value('string'),
    'position': Value('int64')
})


# Define the save path
save_path = "pathgenicity_noncoding_multisets"


# Reload only when save_path really holds a saved DatasetDict. The directory also
# exists when it merely contains train_<i> chunk folders, and loading it as a
# DatasetDict in that state fails, which would break every re-run of this cell.
if os.path.isfile(os.path.join(save_path, "dataset_dict.json")):
    # Load the existing dataset
    dataset_dict = DatasetDict.load_from_disk(save_path)
    existing_dataset = dataset_dict['train']
else:
    # No saved DatasetDict to start from
    existing_dataset = None



def clear_cache(cache_dir="/home/sunhuaikuan/.cache"):
    """
    Deletes every entry directly under ``cache_dir``, keeping the directory itself.

    Files and symlinks are unlinked and sub-directories are removed recursively.
    Deletion is best-effort: a failure on one entry is printed and the remaining
    entries are still processed. A missing ``cache_dir`` is a no-op.

    Args:
        cache_dir: Directory to empty. Everything inside it is removed, so point
            it at a directory that only holds disposable cache data.
    """
    if not os.path.isdir(cache_dir):
        return

    for filename in os.listdir(cache_dir):
        file_path = os.path.join(cache_dir, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f"Failed to delete {file_path}. Reason: {e}")



import numpy as np
from datasets import load_from_disk, concatenate_datasets
import gc


# Export `df` in fixed-size row chunks, one saved dataset folder per chunk.
# Define chunk size
chunk_size = 30000  # Adjust chunk size based on available memory

# Split df into chunks: ceiling division, so a final partial chunk is included
num_chunks = len(df) // chunk_size + int(len(df) % chunk_size != 0)

# Reset unconditionally; the name is not read anywhere in this loop.
existing_dataset = None

# save_path_temp = save_path+'_mytemp'

for i in range(num_chunks):
    # Get the current chunk
    df_chunk = df.iloc[i * chunk_size: (i + 1) * chunk_size]

    # Create a new dataset from the current chunk.
    # The lambda reads the module-level df_chunk when it is called
    # (presumably right away, inside from_generator - confirm).
    new_dataset = Dataset.from_generator(lambda: generate_and_yield_examples(df_chunk), features=features, keep_in_memory=True,cache_dir=None)  #,cache_dir=None

    # Save the dataset chunk to disk as <save_path>/train_<i>
    chunk_save_path = f"{save_path}/train_{i}"
    new_dataset.save_to_disk(chunk_save_path)

    print(f"Saved dataset chunk {i + 1}")
    
    # Empty the cache directory after every chunk.
    clear_cache()
    

print(f"completed")

Generating train split: 0 examples [00:00, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/30000 [00:00<?, ? examples/s]

Saved dataset chunk 1


Generating train split: 0 examples [00:00, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/30000 [00:00<?, ? examples/s]

Saved dataset chunk 2


Generating train split: 0 examples [00:00, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/30000 [00:00<?, ? examples/s]

Saved dataset chunk 3


Generating train split: 0 examples [00:00, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5760 [00:00<?, ? examples/s]

Saved dataset chunk 4
completed
CPU times: user 11.5 s, sys: 573 ms, total: 12.1 s
Wall time: 43 s


## Combine all the `save_path/train_<i>` chunks into one dataset

In [6]:
# Merge the per-chunk datasets written under <save_path>/train_<i> into a
# single dataset and save it to <save_path>/combined_dataset.
import gc  # imported here so the cell does not rely on another cell having run

from datasets import load_from_disk, concatenate_datasets

save_path = "pathgenicity_noncoding_multisets"

# Must equal the number of train_<i> directories present under save_path
# (indices 0 .. num_chunks - 1 are loaded).
num_chunks = 4

# Load each dataset chunk from disk, in index order.
chunk_datasets = []
for i in range(num_chunks):
    chunk_save_path = f"{save_path}/train_{i}"
    chunk_datasets.append(load_from_disk(chunk_save_path))

# Concatenate all chunks in one call instead of pairwise inside the loop.
combined_dataset = concatenate_datasets(chunk_datasets)

# Optional: clean up memory
del chunk_datasets
gc.collect()

# Save the combined dataset to disk
combined_save_path = f"{save_path}/combined_dataset"
combined_dataset.save_to_disk(combined_save_path)


Saving the dataset (0/1 shards):   0%|          | 0/95760 [00:00<?, ? examples/s]

In [7]:
# Reload the combined dataset from disk and view it as a pandas DataFrame.
save_path = "pathgenicity_noncoding_multisets"

# Load and convert in one expression so `df` is only ever bound to the DataFrame.
df = load_from_disk(f"{save_path}/combined_dataset").to_pandas()

# Bare last expression -> rich DataFrame display in the notebook.
df

Unnamed: 0,label,chromosome,ref_forward_sequence,alt_forward_sequence,position
0,0,1,GGACGGTGGAGGAGATCCTCAACGTGGACCCGGTGCAGCACACGTA...,GGACGGTGGAGGAGATCCTCAACGTGGACCCGGTGCAGCACACGTA...,1020382
1,0,1,GGAGGAGATCCTCAACGTGGACCCGGTGCAGCACACGTACTCCTGC...,GGAGGAGATCCTCAACGTGGACCCGGTGCAGCACACGTACTCCTGC...,1020389
2,0,1,GAGGAGATCCTCAACGTGGACCCGGTGCAGCACACGTACTCCTGCA...,GAGGAGATCCTCAACGTGGACCCGGTGCAGCACACGTACTCCTGCA...,1020390
3,0,1,AGGAGATCCTCAACGTGGACCCGGTGCAGCACACGTACTCCTGCAA...,AGGAGATCCTCAACGTGGACCCGGTGCAGCACACGTACTCCTGCAA...,1020391
4,0,1,AGGAGATCCTCAACGTGGACCCGGTGCAGCACACGTACTCCTGCAA...,AGGAGATCCTCAACGTGGACCCGGTGCAGCACACGTACTCCTGCAA...,1020391
...,...,...,...,...,...
95755,0,X,ACATCATAGCTAAGGTCAAGCAATGGAAGGGGAAAGACAGAGAGAA...,ACATCATAGCTAAGGTCAAGCAATGGAAGGGGAAAGACAGAGAGAA...,154991042
95756,0,X,AAATGCTAACTGTTATAGATAGATTCAGTTGTTTGTACTTCTCTGC...,AAATGCTAACTGTTATAGATAGATTCAGTTGTTTGTACTTCTCTGC...,154992918
95757,0,X,GGAAGACTTTATCATCTTCTTTCTCCCTTTGACTGGTCTGATCATC...,GGAAGACTTTATCATCTTCTTTCTCCCTTTGACTGGTCTGATCATC...,154993156
95758,0,X,GACCACTGTATCATAAACCTCAGCCTGGATGGTAGGACCTAGCAGA...,GACCACTGTATCATAAACCTCAGCCTGGATGGTAGGACCTAGCAGA...,154997113


## Demo: loading from a DatasetDict

In [12]:
# from datasets import load_dataset, load_from_disk

# mydatasets = load_from_disk(save_path)
# print(mydatasets)

# train = mydatasets['train']

# df = train.to_pandas()
# df
# df.head(2)