In [1]:
from pathlib import Path

import torch

In [3]:
# project_path = Path("/projects/b1171/ylk4626/project/Chimera")
project_path = Path("../")

# Load Data from fq

In [4]:
from needletail import parse_fastx_file


def read_records(fastq):
    """Read every record of a FASTA/FASTQ file into a list.

    Args:
        fastq: Location of the file, given either as a string or as a
            ``pathlib.Path``. ``Path`` objects are converted to their POSIX
            string form before being handed to ``parse_fastx_file``.

    Returns:
        list: All records yielded by ``parse_fastx_file``, fully materialised.
    """
    location = fastq.as_posix() if isinstance(fastq, Path) else fastq
    return [*parse_fastx_file(location)]

In [5]:
dataset_folder = project_path / "data/train_data/80000"

In [6]:
train_data_path = dataset_folder / "train.fq.target.fq.gz"
val_data_path = dataset_folder / "val.fq.target.fq.gz"
test_data_path = dataset_folder / "test.fq.target.fq.gz"

In [7]:
# NOTE(review): the records are read from ``test_data_path`` although the variable is
# called ``train_data`` — presumably the test split is used here as a small sample
# for exploration; confirm before relying on this name.
train_data = read_records(test_data_path)

In [8]:
train_data[0].seq

'GGTAGGCGGGTTTCAGGGGCTCTTTGGTGAAGAGTTTTATGGCGTCAGCGAAGGGTTGTAGTAGCCCGTAGGGGCCTACAACGTTGGGGCCTTTGCGTAGTTGCTGTATCGCCTAGAATTTTTCGTTCGGTAAGCATTAGGAATGCCATTGCGATTAGAATGGGTACAATGAGGAGTAGGAGGTTGGCCATGGGTATGTTGTTAAGAAGAGGAATTGAACCTCTGACTGTAAAGTTTTAAGTTTTATGCGATTACCGGGCTCTGCCATCTTAACAAACCCCTGTTCTTGGGTGGGTGTGGGTATAATGCTAAGTTGAGATGATATCATTTACGGGGGAAGGCGCTTTGTGAAGTAGGCCTTATTTCTCTTGTCCTTTCGTACAGGGAGGAATTTGAAGTAGATAGAAACCGACCTGGATTACTCCGGTCTGAACTCAGATCACGTAGGACTTTAATGGTTGAACAAACGAACCTTTAATAGCGGCTGCACCATTGGGATGTCCTGATCCAACATCGAGGTCGTAAACCCTATTGTTGATATGGACTCTAGATAGGATTGCGCTGTTATCCCTAGGGTAACTTGTTCCGTTGGTCAAGTTATTGGATCAATTGAGTATAGTAGTTCGCTTTGACTGGTGAAGTCTTAGCATGTACTGCTCGGAGGTTGGGTTCTGCTCCGAGGTCGCCCCAACCGAAATTTTTAGATGCCGGTTTGGTCGTTTAGGACCTGTGGGTTTGTTAGGTACTGTTTGCATTAATAAATTAAAGCTCCATAGGGTCTTCTCGTCTTGCTGTGTCATGCCCGCCTCTTCACGGGCAGGTCAATTTCACTGGTTAAAAGTAAGAGACAGCTGAACCCTCGTGGAGCCATTCATACAGGTCCCTATTTAAGGAACAAGTGATTATGCTACCTTTGCACGGTTAGGGTACCAGGACCATTAAACATGTGTCACTGGGCAGGCGGTGCCTGATACTGGTGATGCTAGAGGTGATGTTTTT

In [9]:
from deepbiop import fq

# Class-name <-> class-id lookup tables; ``id2label`` is the exact inverse of ``label2id``.
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
id2label = {index: label for label, index in label2id.items()}


def parse_target(name):
    """Split a record name of the form ``"<read id>|<target>"``.

    The split is made on the LAST ``|`` so that read ids which themselves
    contain ``|`` are still parsed instead of failing to unpack.

    Args:
        name (str): Record name ending in ``|<integer target>``.

    Returns:
        tuple[str, int]: The read id (everything before the last ``|``) and
        the target converted to ``int``.

    Raises:
        ValueError: If ``name`` contains no ``|`` or the part after the last
            ``|`` is not an integer.
    """
    rid, target = name.rsplit("|", 1)
    return rid, int(target)


def encode_qual(qual, offset=33):
    """Encode a quality string via ``fq.encode_qual`` and return the result as a list.

    Args:
        qual: Quality string of one read, forwarded unchanged.
        offset: Offset forwarded to ``fq.encode_qual``; defaults to 33.
            NOTE(review): presumably the Phred+33 ASCII offset — confirm.

    Returns:
        list: The values produced by ``fq.encode_qual(qual, offset)``.
    """
    encoded = fq.encode_qual(qual, offset)
    return [*encoded]

In [52]:
from datasets import load_dataset

# NOTE(review): the "train" split is backed by the *test* parquet file — presumably
# a small sample for exploration; confirm before using this for real training.
data_files = {
    "train": (dataset_folder / "test.fq.target.fq.parquet").as_posix(),
}

# No ``split=`` is given, so the result is indexed by split name
# (``train_dataset["train"]`` below). ``with_format("torch")`` makes rows come
# back as torch tensors where possible.
train_dataset = load_dataset(
    "parquet",
    data_files=data_files,
    num_proc=2,
).with_format("torch")

Setting num_proc from 2 back to 1 for the train split to disable multiprocessing as it only contains one shard.


Generating train split: 0 examples [00:00, ? examples/s]

In [53]:
train_dataset["train"][0]

{'id': '349022f1-b68c-447e-bf8d-79d9e4c939ee|0',
 'seq': 'GGTAGGCGGGTTTCAGGGGCTCTTTGGTGAAGAGTTTTATGGCGTCAGCGAAGGGTTGTAGTAGCCCGTAGGGGCCTACAACGTTGGGGCCTTTGCGTAGTTGCTGTATCGCCTAGAATTTTTCGTTCGGTAAGCATTAGGAATGCCATTGCGATTAGAATGGGTACAATGAGGAGTAGGAGGTTGGCCATGGGTATGTTGTTAAGAAGAGGAATTGAACCTCTGACTGTAAAGTTTTAAGTTTTATGCGATTACCGGGCTCTGCCATCTTAACAAACCCCTGTTCTTGGGTGGGTGTGGGTATAATGCTAAGTTGAGATGATATCATTTACGGGGGAAGGCGCTTTGTGAAGTAGGCCTTATTTCTCTTGTCCTTTCGTACAGGGAGGAATTTGAAGTAGATAGAAACCGACCTGGATTACTCCGGTCTGAACTCAGATCACGTAGGACTTTAATGGTTGAACAAACGAACCTTTAATAGCGGCTGCACCATTGGGATGTCCTGATCCAACATCGAGGTCGTAAACCCTATTGTTGATATGGACTCTAGATAGGATTGCGCTGTTATCCCTAGGGTAACTTGTTCCGTTGGTCAAGTTATTGGATCAATTGAGTATAGTAGTTCGCTTTGACTGGTGAAGTCTTAGCATGTACTGCTCGGAGGTTGGGTTCTGCTCCGAGGTCGCCCCAACCGAAATTTTTAGATGCCGGTTTGGTCGTTTAGGACCTGTGGGTTTGTTAGGTACTGTTTGCATTAATAAATTAAAGCTCCATAGGGTCTTCTCGTCTTGCTGTGTCATGCCCGCCTCTTCACGGGCAGGTCAATTTCACTGGTTAAAAGTAAGAGACAGCTGAACCCTCGTGGAGCCATTCATACAGGTCCCTATTTAAGGAACAAGTGATTATGCTACCTTTGCACGGTTAGGGTACCAGGACCATTAA

In [54]:
from pathlib import Path

from transformers import PreTrainedTokenizer


class Tokenizer(PreTrainedTokenizer):
    """Character-level nucleotide tokenizer for Hugging Face transformers.

    Every character of the input is one token. The alphabet is fixed to
    ``A, C, G, T, N``; any other character maps to ``[UNK]``. Encoded inputs
    have the layout ``[CLS] tokens [SEP]`` (plus ``tokens [SEP]`` for a pair).
    """

    model_input_names = ["input_ids"]

    def __init__(
        self,
        model_max_length: int,
        bos_token="[BOS]",
        eos_token="[SEP]",
        sep_token="[SEP]",
        cls_token="[CLS]",
        pad_token="[PAD]",
        mask_token="[MASK]",
        unk_token="[UNK]",
        **kwargs,
    ):
        """Build the fixed vocabulary and initialise the base tokenizer.

        Special tokens and their ids:
            "[CLS]": 0
            "[SEP]": 1
            "[BOS]": 2
            "[MASK]": 3
            "[PAD]": 4
            "[RESERVED]": 5
            "[UNK]": 6
        The characters ``A, C, G, T, N`` get ids 7..11 in that order.

        Args:
            model_max_length (int): Model maximum sequence length.
            bos_token, eos_token, sep_token, cls_token, pad_token, mask_token,
            unk_token: Special-token strings forwarded to the base class.
            **kwargs: Extra options for the base class. ``add_prefix_space``
                defaults to ``False`` and ``padding_side`` to ``"right"``.
        """
        self.characters = ("A", "C", "G", "T", "N")
        self.model_max_length = model_max_length

        # The vocabulary is created BEFORE ``super().__init__`` so that it already
        # exists when the base class is initialised with the special tokens.
        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(self.characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
        add_prefix_space = kwargs.pop("add_prefix_space", False)
        padding_side = kwargs.pop("padding_side", "right")

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=add_prefix_space,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary (special tokens included)."""
        return len(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> list[str]:
        """Split ``text`` into single characters."""
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        """Return the id of ``token``; unknown tokens map to the ``[UNK]`` id."""
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        """Return the token for ``index``; raises ``KeyError`` for an unknown id."""
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        """Concatenate tokens without any separator."""
        return "".join(tokens)

    def get_special_tokens_mask(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
        *,
        already_has_special_tokens: bool = False,
    ) -> list[int]:
        """Return a mask with 1 at special-token positions and 0 elsewhere.

        When ``already_has_special_tokens`` is false the mask describes the
        sequence that ``build_inputs_with_special_tokens`` would produce, i.e.
        ``[CLS] ids_0 [SEP]`` optionally followed by ``ids_1 [SEP]``.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        # Leading 1 accounts for the [CLS] token that build_inputs_with_special_tokens
        # prepends; without it the mask is one element short and shifted.
        result = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: list[int] | None = None
    ) -> list[int]:
        """Wrap ids as ``[CLS] ids_0 [SEP]`` (and ``ids_1 [SEP]`` for a pair)."""
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        result = cls + token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_vocab(self) -> dict[str, int]:
        """Return the token -> id mapping."""
        return self._vocab_str_to_int

    def decode(self, token_ids, skip_special_tokens=False, **kwargs):
        """Decode ids back to a sequence string.

        Args:
            token_ids: A single id, a flat sequence of ids, a batch (only the
                first sequence is decoded), a ``torch.Tensor`` of either
                shape, or a mapping with an ``"input_ids"`` entry such as the
                object returned by calling the tokenizer.
            skip_special_tokens (bool): Drop special tokens from the output.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            str: The concatenated tokens; ``""`` for empty input.

        Raises:
            KeyError: If an id is not part of the vocabulary.
        """
        from collections.abc import Mapping

        # Unwrap tokenizer output; ``Mapping`` covers plain dicts as well as
        # dict-like wrappers that are not ``dict`` subclasses.
        if isinstance(token_ids, Mapping):
            token_ids = token_ids["input_ids"]

        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        if isinstance(token_ids, int):
            token_ids = [token_ids]  # a single id (e.g. from a 0-d tensor)

        token_ids = list(token_ids)
        if token_ids and isinstance(token_ids[0], (list, tuple)):
            token_ids = list(token_ids[0])  # Take first sequence if batch

        tokens = [self._convert_id_to_token(int(token_id)) for token_id in token_ids]
        if skip_special_tokens:
            special_tokens = set(self.all_special_tokens)
            tokens = [token for token in tokens if token not in special_tokens]

        return self.convert_tokens_to_string(tokens)

In [55]:
tokenizer = Tokenizer(model_max_length=100000)

In [56]:
tokenizer(train_dataset["train"][0]["seq"])

{'input_ids': [0, 9, 9, 10, 7, 9, 9, 8, 9, 9, 9, 10, 10, 10, 8, 7, 9, 9, 9, 9, 8, 10, 8, 10, 10, 10, 9, 9, 10, 9, 7, 7, 9, 7, 9, 10, 10, 10, 10, 7, 10, 9, 9, 8, 9, 10, 8, 7, 9, 8, 9, 7, 7, 9, 9, 9, 10, 10, 9, 10, 7, 9, 10, 7, 9, 8, 8, 8, 9, 10, 7, 9, 9, 9, 9, 8, 8, 10, 7, 8, 7, 7, 8, 9, 10, 10, 9, 9, 9, 9, 8, 8, 10, 10, 10, 9, 8, 9, 10, 7, 9, 10, 10, 9, 8, 10, 9, 10, 7, 10, 8, 9, 8, 8, 10, 7, 9, 7, 7, 10, 10, 10, 10, 10, 8, 9, 10, 10, 8, 9, 9, 10, 7, 7, 9, 8, 7, 10, 10, 7, 9, 9, 7, 7, 10, 9, 8, 8, 7, 10, 10, 9, 8, 9, 7, 10, 10, 7, 9, 7, 7, 10, 9, 9, 9, 10, 7, 8, 7, 7, 10, 9, 7, 9, 9, 7, 9, 10, 7, 9, 9, 7, 9, 9, 10, 10, 9, 9, 8, 8, 7, 10, 9, 9, 9, 10, 7, 10, 9, 10, 10, 9, 10, 10, 7, 7, 9, 7, 7, 9, 7, 9, 9, 7, 7, 10, 10, 9, 7, 7, 8, 8, 10, 8, 10, 9, 7, 8, 10, 9, 10, 7, 7, 7, 9, 10, 10, 10, 10, 7, 7, 9, 10, 10, 10, 10, 7, 10, 9, 8, 9, 7, 10, 10, 7, 8, 8, 9, 9, 9, 8, 10, 8, 10, 9, 8, 8, 7, 10, 8, 10, 10, 7, 7, 8, 7, 7, 7, 8, 8, 8, 8, 10, 9, 10, 10, 8, 10, 10, 9, 9, 9, 10, 9, 9, 9, 10, 9, 1

In [57]:
tokenizer.decode(tokenizer(train_dataset["train"][0]["seq"]))

'[CLS]GGTAGGCGGGTTTCAGGGGCTCTTTGGTGAAGAGTTTTATGGCGTCAGCGAAGGGTTGTAGTAGCCCGTAGGGGCCTACAACGTTGGGGCCTTTGCGTAGTTGCTGTATCGCCTAGAATTTTTCGTTCGGTAAGCATTAGGAATGCCATTGCGATTAGAATGGGTACAATGAGGAGTAGGAGGTTGGCCATGGGTATGTTGTTAAGAAGAGGAATTGAACCTCTGACTGTAAAGTTTTAAGTTTTATGCGATTACCGGGCTCTGCCATCTTAACAAACCCCTGTTCTTGGGTGGGTGTGGGTATAATGCTAAGTTGAGATGATATCATTTACGGGGGAAGGCGCTTTGTGAAGTAGGCCTTATTTCTCTTGTCCTTTCGTACAGGGAGGAATTTGAAGTAGATAGAAACCGACCTGGATTACTCCGGTCTGAACTCAGATCACGTAGGACTTTAATGGTTGAACAAACGAACCTTTAATAGCGGCTGCACCATTGGGATGTCCTGATCCAACATCGAGGTCGTAAACCCTATTGTTGATATGGACTCTAGATAGGATTGCGCTGTTATCCCTAGGGTAACTTGTTCCGTTGGTCAAGTTATTGGATCAATTGAGTATAGTAGTTCGCTTTGACTGGTGAAGTCTTAGCATGTACTGCTCGGAGGTTGGGTTCTGCTCCGAGGTCGCCCCAACCGAAATTTTTAGATGCCGGTTTGGTCGTTTAGGACCTGTGGGTTTGTTAGGTACTGTTTGCATTAATAAATTAAAGCTCCATAGGGTCTTCTCGTCTTGCTGTGTCATGCCCGCCTCTTCACGGGCAGGTCAATTTCACTGGTTAAAAGTAAGAGACAGCTGAACCCTCGTGGAGCCATTCATACAGGTCCCTATTTAAGGAACAAGTGATTATGCTACCTTTGCACGGTTAGGGTACCAGGACCATTAAACATGTGTCACTGGGCAGGCGGTGCCTGATACTGGTGATGCTAGAGGTGATG

In [58]:
train_dataset["train"][0]["seq"]

'GGTAGGCGGGTTTCAGGGGCTCTTTGGTGAAGAGTTTTATGGCGTCAGCGAAGGGTTGTAGTAGCCCGTAGGGGCCTACAACGTTGGGGCCTTTGCGTAGTTGCTGTATCGCCTAGAATTTTTCGTTCGGTAAGCATTAGGAATGCCATTGCGATTAGAATGGGTACAATGAGGAGTAGGAGGTTGGCCATGGGTATGTTGTTAAGAAGAGGAATTGAACCTCTGACTGTAAAGTTTTAAGTTTTATGCGATTACCGGGCTCTGCCATCTTAACAAACCCCTGTTCTTGGGTGGGTGTGGGTATAATGCTAAGTTGAGATGATATCATTTACGGGGGAAGGCGCTTTGTGAAGTAGGCCTTATTTCTCTTGTCCTTTCGTACAGGGAGGAATTTGAAGTAGATAGAAACCGACCTGGATTACTCCGGTCTGAACTCAGATCACGTAGGACTTTAATGGTTGAACAAACGAACCTTTAATAGCGGCTGCACCATTGGGATGTCCTGATCCAACATCGAGGTCGTAAACCCTATTGTTGATATGGACTCTAGATAGGATTGCGCTGTTATCCCTAGGGTAACTTGTTCCGTTGGTCAAGTTATTGGATCAATTGAGTATAGTAGTTCGCTTTGACTGGTGAAGTCTTAGCATGTACTGCTCGGAGGTTGGGTTCTGCTCCGAGGTCGCCCCAACCGAAATTTTTAGATGCCGGTTTGGTCGTTTAGGACCTGTGGGTTTGTTAGGTACTGTTTGCATTAATAAATTAAAGCTCCATAGGGTCTTCTCGTCTTGCTGTGTCATGCCCGCCTCTTCACGGGCAGGTCAATTTCACTGGTTAAAAGTAAGAGACAGCTGAACCCTCGTGGAGCCATTCATACAGGTCCCTATTTAAGGAACAAGTGATTATGCTACCTTTGCACGGTTAGGGTACCAGGACCATTAAACATGTGTCACTGGGCAGGCGGTGCCTGATACTGGTGATGCTAGAGGTGATGTTTTT

In [59]:
tokenizer.all_special_tokens

['[BOS]', '[SEP]', '[UNK]', '[PAD]', '[CLS]', '[MASK]']

In [66]:
train_dataset["train"][0]["qual"]

tensor([6, 7, 7,  ..., 3, 3, 3])

In [60]:
from functools import partial

from datasets import Dataset
from transformers import (
    AutoTokenizer,
)

IGNORE_INDEX = -100


def tokenize_and_align_labels_and_quals(data, tokenizer, max_length, pad_qual=0):
    """Tokenize one record and attach its normalised quality scores and label.

    Args:
        data: One dataset row with the keys ``"seq"``, ``"qual"`` and ``"id"``.
            ``data["qual"]`` must be a tensor (it is sliced and passed to ``torch.cat``).
        tokenizer: Callable tokenizer, invoked with ``truncation=True`` and ``padding=True``.
        max_length (int): Maximum token length passed to the tokenizer.
        pad_qual: Score appended after the (possibly truncated) quality scores.

    Returns:
        The tokenizer output, extended with ``"input_quals"`` (float tensor,
        L2-normalised along dim 0) and ``"label"`` (int parsed from ``data["id"]``).
    """
    encoded = tokenizer(data["seq"], max_length=max_length, truncation=True, padding=True)

    # Reads at least ``max_length`` long keep only their first ``max_length - 1``
    # scores; shorter reads keep every score. One pad score follows in both cases.
    # NOTE(review): only ONE pad score is appended — if the tokenizer adds a token
    # at both ends, the scores are offset by one relative to ``input_ids``; confirm.
    kept_quals = data["qual"][: max_length - 1] if len(data["seq"]) >= max_length else data["qual"]
    padded_quals = torch.cat((kept_quals, torch.tensor([pad_qual]))).float()
    normalized_quals = torch.nn.functional.normalize(padded_quals, dim=0)

    # The read id is not needed here, only the target encoded after "|".
    _, target = parse_target(data["id"])

    encoded.update({"input_quals": normalized_quals, "label": target})
    return encoded


def tokenize_and_align_labels_and_quals_ids(
    data, tokenizer, max_length, pad_qual=0, pad_label=IGNORE_INDEX, max_id_length=256
):
    """Tokenize one record and attach quality scores, label and an encoded read id.

    The ``"id"`` entry of the result is a list of exactly ``max_id_length``
    integers laid out as::

        [len(data["id"]), truncated flag (0/1), ord(c) for c in read id..., 0, 0, ...]

    Args:
        data: One dataset row with the keys ``"seq"``, ``"qual"`` and ``"id"``.
        tokenizer: Callable tokenizer, invoked with ``truncation=True`` and ``padding=True``.
        max_length (int): Maximum token length passed to the tokenizer.
        pad_qual: Score appended after the (possibly truncated) quality scores.
        pad_label: Accepted but not used by this function.
        max_id_length (int): Fixed length of the encoded id (truncated or zero padded).

    Returns:
        The tokenizer output, extended with ``"input_quals"``, ``"id"`` and ``"label"``.
    """
    encoded = tokenizer(data["seq"], max_length=max_length, truncation=True, padding=True)

    was_truncated = len(data["seq"]) >= max_length
    kept_quals = data["qual"][: max_length - 1] if was_truncated else data["qual"]
    normalized_quals = torch.nn.functional.normalize(
        torch.cat((kept_quals, torch.tensor([pad_qual]))).float(),
        dim=0,
    )

    rid, target = parse_target(data["id"])

    # Header: length of the FULL id string (including "|<target>") and the truncation flag,
    # followed by the code points of the read id itself.
    id_codes = [len(data["id"]), int(was_truncated)] + [ord(char) for char in rid]

    # Force the encoded id to exactly ``max_id_length`` entries.
    id_codes = id_codes[:max_id_length]
    id_codes += [0] * (max_id_length - len(id_codes))

    encoded.update({"input_quals": normalized_quals, "id": id_codes, "label": target})
    return encoded


def tokenize_dataset(dataset, tokenizer, max_length):
    """Tokenize a Hugging Face dataset and attach labels and quality scores.

    Every row is processed with ``tokenize_and_align_labels_and_quals``; the raw
    ``"id"``, ``"seq"`` and ``"qual"`` columns are removed afterwards.

    Args:
        dataset (datasets.Dataset): The input dataset to be tokenized.
        tokenizer: The tokenizer to be used for tokenization.
        max_length (int): The maximum length of the tokenized sequences.

    Returns:
        Tokenized dataset with aligned labels and qualities.

    Raises:
        ValueError: If the dataset is empty/falsy or the tokenizer is falsy.
        TypeError: If the dataset is not a ``datasets.Dataset``.
    """
    if not dataset:
        raise ValueError("Input dataset is empty")
    if not tokenizer:
        raise ValueError("Tokenizer is not provided")

    # Imported locally under an alias: a later cell of this notebook rebinds the
    # global name ``Dataset`` to ``torch.utils.data.Dataset``, which would make
    # this check reject every Hugging Face dataset.
    from datasets import Dataset as HuggingFaceDataset

    if not isinstance(dataset, HuggingFaceDataset):
        raise TypeError("Input dataset must be of type Dataset")

    tokenize_row = partial(tokenize_and_align_labels_and_quals, tokenizer=tokenizer, max_length=max_length)
    return dataset.map(tokenize_row).remove_columns(["id", "seq", "qual"])

In [61]:
# ``train_dataset`` is indexed by split name, so mapping over it yields an object
# that is also indexed by split name (``data_train["train"]`` below).
# The raw columns are dropped once the tokenized features have been added.
data_train = train_dataset.map(
    partial(
        tokenize_and_align_labels_and_quals,
        tokenizer=tokenizer,
        max_length=100000,
    ),
    num_proc=4,  # type: ignore
).remove_columns(["id", "seq", "qual"])

Map (num_proc=4):   0%|          | 0/8000 [00:00<?, ? examples/s]

In [65]:
data_train["train"][0]

{'input_ids': tensor([ 0,  9,  9,  ...,  7, 10,  1]),
 'input_quals': tensor([0.0044, 0.0052, 0.0052,  ..., 0.0022, 0.0022, 0.0000]),
 'label': tensor(0)}

In [78]:
data_collator = DataCollator(tokenizer)

In [79]:
# ``data_train`` is keyed by split name (see ``data_train["train"][0]`` above), so the
# DataLoader must be given the "train" split itself, not the container of splits.
# ``torch_call`` is passed explicitly, exactly as ``FqDataModule`` does, so that the
# quality-aware collation defined on ``DataCollator`` is the code that runs.
train_dataloader = DataLoader(
    dataset=data_train["train"],
    batch_size=12,
    num_workers=1,
    collate_fn=data_collator.torch_call,
    shuffle=True,
)

In [76]:
import multiprocessing
from functools import partial
from pathlib import Path
from typing import Any

from datasets import Dataset as HuggingFaceDataset
from datasets import load_dataset
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from transformers import DataCollatorWithPadding


def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """Call ``tokenizer.pad`` while the fast-tokenizer padding warning is muted.

    Objects without a ``deprecation_warnings`` attribute (e.g. feature
    extractors) are padded directly. Otherwise the
    ``"Asking-to-pad-a-fast-tokenizer"`` flag is set to ``True`` for the
    duration of the call and put back to its previous value afterwards, even
    when ``pad`` raises.

    Args:
        tokenizer: Object exposing a ``pad`` method.
        *pad_args: Positional arguments forwarded to ``tokenizer.pad``.
        **pad_kwargs: Keyword arguments forwarded to ``tokenizer.pad``.

    Returns:
        Whatever ``tokenizer.pad`` returns.
    """
    warning_key = "Asking-to-pad-a-fast-tokenizer"

    # Nothing to silence on objects that do not track deprecation warnings.
    if not hasattr(tokenizer, "deprecation_warnings"):
        return tokenizer.pad(*pad_args, **pad_kwargs)

    # Remember the current flag, then mark the warning as already emitted.
    previous_state = tokenizer.deprecation_warnings.get(warning_key, False)
    tokenizer.deprecation_warnings[warning_key] = True
    try:
        return tokenizer.pad(*pad_args, **pad_kwargs)
    finally:
        # Runs on success and on failure alike.
        tokenizer.deprecation_warnings[warning_key] = previous_state


class DataCollator(DataCollatorWithPadding):
    """Collator that pads token ids and per-base quality scores to a common length.

    Besides the tokenizer features, each input feature must carry
    ``"input_quals"``; it may carry a scalar ``"label"``/``"labels"`` and an
    encoded ``"id"`` (prediction datasets).
    """

    def __call__(self, features):
        """Collate ``features`` with :meth:`torch_call`.

        Defined explicitly so that using the instance itself as ``collate_fn``
        runs the quality-aware collation below.
        """
        return self.torch_call(features)

    def torch_call(self, features):
        """Build one batch of torch tensors from a list of feature dicts.

        Args:
            features: Non-empty list of per-sample dicts as described on the class.

        Returns:
            The padded batch returned by the tokenizer, extended with
            ``"input_quals"`` (float32, padded with 0 to the padded token
            length), the class labels under the same key the features used
            (int64, only when present) and ``"id"`` (int8, only when present).
        """
        import torch

        label_name = "label" if "label" in features[0] else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0] else None

        qual_name = "input_quals"
        qual_pad_token_id = 0
        input_quals = [feature[qual_name] for feature in features]

        id_name = "id"  # for prediction dataset

        # Only the tokenizer features go through ``tokenizer.pad``; quality scores,
        # labels and ids are handled by hand below.
        handled_here = {qual_name, label_name, id_name}
        token_features = [{k: v for k, v in feature.items() if k not in handled_here} for feature in features]

        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            token_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        # Pad the quality scores on the same side as the tokens. This is done
        # whether or not labels are present, so unlabeled batches keep their scores.
        if padding_side == "right":
            padded_quals = [
                to_list(qual) + [qual_pad_token_id] * (sequence_length - len(qual)) for qual in input_quals
            ]
        else:
            padded_quals = [
                [qual_pad_token_id] * (sequence_length - len(qual)) + to_list(qual) for qual in input_quals
            ]

        batch[qual_name] = torch.tensor(padded_quals, dtype=torch.float32)

        # The labels were stripped before padding; put them back so the batch can
        # be used for training and evaluation.
        if labels is not None:
            batch[label_name] = torch.tensor([int(label) for label in labels], dtype=torch.long)

        # for prediction dataset and save id feature
        # NOTE(review): int8 only holds values up to 127; the first id entry is the
        # length of the id string, so ids longer than 127 characters would not fit — confirm.
        if id_name in features[0]:
            batch[id_name] = torch.tensor([to_list(feature[id_name]) for feature in features], dtype=torch.int8)

        return batch


class FqDataModule(LightningDataModule):
    """`LightningDataModule` for the fq dataset.

    A `LightningDataModule` implements 7 key methods:

    ```python
        def prepare_data(self):
        # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
        # Download data, pre-process, split, save to disk, etc...

        def setup(self, stage):
        # Things to do on every process in DDP.
        # Load data, set variables, etc...

        def train_dataloader(self):
        # return train dataloader

        def val_dataloader(self):
        # return validation dataloader

        def test_dataloader(self):
        # return test dataloader

        def predict_dataloader(self):
        # return predict dataloader

        def teardown(self, stage):
        # Called on every process in DDP.
        # Clean up after fit or test.
    ```

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        train_data_path: Path,
        val_data_path: Path | None = None,
        test_data_path: Path | None = None,
        predict_data_path: Path | None = None,
        train_val_test_split: tuple[float, float, float] = (0.7, 0.2, 0.1),
        batch_size: int = 12,
        num_workers: int = 0,
        max_train_samples: int | None = None,
        max_val_samples: int | None = None,
        max_test_samples: int | None = None,
        max_predict_samples: int | None = None,
        *,
        pin_memory: bool = False,
    ) -> None:
        """Initialize a `FqDataModule`.

        :param tokenizer: Tokenizer used to encode the sequences and to pad batches.
        :param train_data_path: Parquet file with the training data.
        :param val_data_path: Optional Parquet file with the validation data.
        :param test_data_path: Optional Parquet file with the test data.
        :param predict_data_path: Optional Parquet file used by the predict stage.
        :param train_val_test_split: Shares used to slice the training file when the validation
            or the test file is missing. Either fractions summing to at most 1 or whole
            percentages. Defaults to `(0.7, 0.2, 0.1)`.
        :param batch_size: The total batch size over all devices. Defaults to `12`.
        :param num_workers: The number of workers. Defaults to `0`.
        :param max_train_samples: Optional cap on the number of training samples.
        :param max_val_samples: Optional cap on the number of validation samples.
        :param max_test_samples: Optional cap on the number of test samples.
        :param max_predict_samples: Optional cap on the number of prediction samples.
        :param pin_memory: Whether to pin memory. Defaults to `False`.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.data_train: Dataset | None = None
        self.data_val: Dataset | None = None
        self.data_test: Dataset | None = None
        # Initialised here so the attribute exists even before the predict stage ran.
        self.data_predict: Dataset | None = None
        self.batch_size_per_device = batch_size
        self.data_collator = DataCollator(tokenizer)

    @property
    def num_classes(self) -> int:
        """Get the number of classes."""
        return 2

    @staticmethod
    def _as_data_file(path) -> str:
        """Return ``path`` as a POSIX string, the form handed to ``load_dataset``."""
        return Path(path).as_posix()

    @staticmethod
    def _limit_samples(dataset, max_samples):
        """Keep at most ``max_samples`` leading rows; ``None`` keeps the dataset unchanged."""
        if max_samples is None:
            return dataset
        limit = min(max_samples, len(dataset))
        return HuggingFaceDataset.from_dict(dataset[:limit]).with_format("torch")

    def _split_bounds(self) -> tuple[int, int]:
        """Return the integer percentages at which the train and validation slices end.

        Split slices are written with whole percentages, so fractional shares such as
        ``(0.7, 0.2, 0.1)`` are scaled to ``70`` and ``90``; shares that are already
        percentages (sum greater than 1) are only rounded.
        """
        shares = self.hparams.train_val_test_split
        scale = 100 if sum(shares) <= 1 + 1e-6 else 1
        train_end = round(shares[0] * scale)
        val_end = round((shares[0] + shares[1]) * scale)
        if not 0 <= train_end <= val_end <= 100:
            msg = f"Invalid train/val/test split {tuple(shares)}."
            raise ValueError(msg)
        return train_end, val_end

    def _tokenize_split(self, dataset, tokenize_fn, drop_columns, num_proc):
        """Apply ``tokenize_fn`` to every row, then drop the listed raw columns that exist."""
        tokenizer = self.hparams.tokenizer
        tokenized = dataset.map(
            partial(
                tokenize_fn,
                tokenizer=tokenizer,
                max_length=tokenizer.max_len_single_sentence,
            ),
            num_proc=num_proc,  # type: ignore
        )
        # Only remove columns that are really there, so a file without one of the
        # optional raw columns does not abort the run.
        present = [column for column in drop_columns if column in tokenized.column_names]
        return tokenized.remove_columns(present)

    def _make_dataloader(self, dataset, *, shuffle: bool) -> DataLoader[Any]:
        """Build a dataloader over ``dataset`` with the settings shared by all stages."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=self.data_collator.torch_call,
            shuffle=shuffle,
        )

    def prepare_data(self) -> None:
        """Validate that every configured data file is a Parquet file.

        No conversion is performed here. The configured paths are additionally
        stored back into ``self.hparams`` as POSIX strings.

        :raises ValueError: If a configured path does not end in ``.parquet``.
        """
        configured = {"train_data_path": self.hparams.train_data_path}
        for name in ("val_data_path", "test_data_path", "predict_data_path"):
            if self.hparams[name] is not None:
                configured[name] = self.hparams[name]

        # Validate everything first so nothing is modified when one path is wrong.
        for data_path in configured.values():
            if Path(data_path).suffix != ".parquet":
                msg = f"Data file {data_path} is not in Parquet format."
                raise ValueError(msg)

        for name, data_path in configured.items():
            self.hparams[name] = self._as_data_file(data_path)

    def setup(self, stage: str | None = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
        `self.setup()` once the data is prepared and available for use.

        For the `"predict"` stage only `self.data_predict` is set.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        :raises RuntimeError: If the batch size is not divisible by the number of devices.
        :raises ValueError: If the predict stage is requested without a predict data path, or the split is invalid.
        """
        # Divide batch size by the number of devices.
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                msg = f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
                raise RuntimeError(msg)
            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size

        # Leave one CPU free, but always use at least one process.
        num_proc = max(1, min(self.hparams.num_workers, multiprocessing.cpu_count() - 1))

        if stage == "predict":
            if not self.hparams.predict_data_path:
                msg = "Predict data path is required for prediction stage."
                raise ValueError(msg)

            # Paths are normalised here as well: ``prepare_data`` is not run on every process.
            data_files = {"predict": self._as_data_file(self.hparams.predict_data_path)}
            predict_dataset = load_dataset(
                "parquet",
                data_files=data_files,
                num_proc=num_proc,
            ).with_format("torch")["predict"]
            predict_dataset = self._limit_samples(predict_dataset, self.hparams.max_predict_samples)

            # "id" is kept on purpose: the tokenize function replaces it with the encoded read id.
            self.data_predict = self._tokenize_split(
                predict_dataset,
                tokenize_and_align_labels_and_quals_ids,
                drop_columns=("seq", "qual", "target"),
                num_proc=num_proc,
            )
            del predict_dataset
            return

        # load and split datasets only if not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            train_file = self._as_data_file(self.hparams.train_data_path)

            if self.hparams.val_data_path is None or self.hparams.test_data_path is None:
                # At least one of validation/test is missing: all three splits are sliced
                # from the training file (a validation or test file that was given is not used).
                train_end, val_end = self._split_bounds()
                split_specs = (
                    f"train[:{train_end}%]",
                    f"train[{train_end}%:{val_end}%]",
                    f"train[{val_end}%:]",
                )
                train_dataset, val_dataset, test_dataset = (
                    load_dataset(
                        "parquet",
                        data_files={"train": train_file},
                        num_proc=num_proc,
                        split=split_spec,
                    ).with_format("torch")
                    for split_spec in split_specs
                )
            else:
                data_files = {
                    "train": train_file,
                    "validation": self._as_data_file(self.hparams.val_data_path),
                    "test": self._as_data_file(self.hparams.test_data_path),
                }
                raw_datasets = load_dataset("parquet", data_files=data_files, num_proc=num_proc).with_format(
                    "torch"
                )

                train_dataset = raw_datasets["train"]
                val_dataset = raw_datasets["validation"]
                test_dataset = raw_datasets["test"]

            train_dataset = self._limit_samples(train_dataset, self.hparams.max_train_samples)
            val_dataset = self._limit_samples(val_dataset, self.hparams.max_val_samples)
            test_dataset = self._limit_samples(test_dataset, self.hparams.max_test_samples)

            raw_columns = ("id", "seq", "qual")
            self.data_train = self._tokenize_split(
                train_dataset, tokenize_and_align_labels_and_quals, raw_columns, num_proc
            )
            self.data_val = self._tokenize_split(
                val_dataset, tokenize_and_align_labels_and_quals, raw_columns, num_proc
            )
            self.data_test = self._tokenize_split(
                test_dataset, tokenize_and_align_labels_and_quals, raw_columns, num_proc
            )

            del train_dataset, val_dataset, test_dataset

    def train_dataloader(self) -> DataLoader[Any]:
        """Create and return the train dataloader.

        :return: The train dataloader.
        """
        return self._make_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Create and return the validation dataloader.

        :return: The validation dataloader.
        """
        return self._make_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """Create and return the test dataloader.

        :return: The test dataloader.
        """
        return self._make_dataloader(self.data_test, shuffle=False)

    def predict_dataloader(self) -> DataLoader[Any]:
        """Create and return the predict dataloader.

        :return: The predict dataloader.
        """
        return self._make_dataloader(self.data_predict, shuffle=False)

    def teardown(self, stage: str | None = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,.

        `trainer.test()`, and `trainer.predict()`.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
            Defaults to ``None``.
        """

    def state_dict(self) -> dict[Any, Any]:
        """Called when saving a checkpoint. Implement to generate and save the datamodule state.

        :return: A dictionary containing the datamodule state that you want to save.
        """
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement to reload datamodule state given datamodule.

        `state_dict()`.

        :param state_dict: The datamodule state returned by `self.state_dict()`.
        """

In [84]:
# ``FqDataModule.prepare_data`` rejects every path that does not end in ".parquet",
# so the Parquet file is passed here rather than the gzipped FASTQ.
fq_data_module = FqDataModule(tokenizer, dataset_folder / "test.fq.target.fq.parquet")

In [87]:
import chimera

# Same pipeline as above, but using the packaged tokenizer and data module.
# NOTE(review): the "train" path points at the *test* parquet file — presumably a
# small sample for a quick check; confirm.
train_data_path = "../data/train_data/80000/test.fq.target.fq.parquet"
max_len = 100000
# Build the tokenizer and the data module, then load and tokenize the data
tokenizer = chimera.data.tokenizer.Tokenizer(model_max_length=max_len)
fq_data_module = chimera.data.fq.DataModule(tokenizer, train_data_path)
# Called by hand here because no Lightning trainer drives the data module.
fq_data_module.prepare_data()
fq_data_module.setup()
train_data_loader = fq_data_module.train_dataloader()

Map:   0%|          | 0/5600 [00:00<?, ? examples/s]

Map:   0%|          | 0/1600 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

In [88]:
data_iterator = iter(train_data_loader)

In [91]:
b1 = next(data_iterator)

In [95]:
b1["input_ids"].shape

torch.Size([12, 8265])

In [97]:
b1["input_quals"].shape

torch.Size([12, 8265])

In [65]:
def resume_read_name(bytes_data: torch.Tensor | list[int]):
    # Convert bytes to string
    if isinstance(bytes_data, torch.Tensor):
        if bytes_data.numel() == 0:
            return ""
        bytes_data = bytes_data.tolist()
    elif not bytes_data:
        return ""
        
    try:
        read_name_length = bytes_data[0]
        if read_name_length <= 0 or read_name_length >= len(bytes_data):
            return ""
        read_name = bytes_data[1:1+read_name_length] 
        return "".join(chr(b) for b in read_name if 32 <= b <= 126)  # Only valid ASCII printable chars
    except (IndexError, TypeError):
        return ""

In [103]:
from dataclasses import dataclass

@dataclass
class Predict:
    """One prediction for a single read: the read name and its predicted class id."""

    # Read name, as recovered by ``resume_read_name`` or read back by ``load_predicts``.
    name: str
    # Predicted class id; the helpers in this notebook count ``label == 1`` separately.
    label: int
    # Optional extra string; never set by the code in this notebook.
    # NOTE(review): presumably a structural-variant annotation — confirm.
    sv: str | None = None

def collect_predict_from_file(path):
    """Load one saved prediction file and turn it into ``Predict`` records.

    The file is read with ``torch.load(..., weights_only=True)`` and must hold
    the entries ``"id"`` (encoded read names, one per prediction) and
    ``"prediction"`` (a tensor whose argmax over dim 1 is the predicted label).

    Args:
        path: Location of the ``.pt`` file.

    Returns:
        A generator of ``Predict(name, label)``. The file is loaded and decoded
        immediately; only the ``Predict`` objects are created lazily.
    """
    payload = torch.load(path, weights_only=True)
    names = list(map(resume_read_name, payload["id"]))
    predicted_labels = payload["prediction"].argmax(dim=1).tolist()
    return (Predict(*pair) for pair in zip(names, predicted_labels))

def collect_predict_from_folder(folder: Path| str):
    if isinstance(folder, str):
        folder = Path(folder)

    for file in folder.glob("*.pt"):
        yield from collect_predict_from_file(file)

def summarize_predict(predicts):
    """Count all predictions and those whose label equals 1.

    Args:
        predicts: Iterable of objects with a ``label`` attribute. It is
            consumed once, so a generator is exhausted afterwards.

    Returns:
        tuple[int, int]: ``(number of predictions, number with label == 1)``.
    """
    seen = 0
    positives = 0
    for item in predicts:
        seen += 1
        positives += 1 if item.label == 1 else 0
    return seen, positives

def write_predicts(predicts, path):
    """Write predictions to ``path`` as ``name<TAB>label`` lines and count them.

    Args:
        predicts: Iterable of objects with ``name`` and ``label`` attributes;
            it is consumed once.
        path: Output file; created or overwritten.

    Returns:
        tuple[int, int]: ``(number of predictions written, number with label == 1)``.
    """
    written, positives = 0, 0

    with open(path, "w") as handle:
        for item in predicts:
            written += 1
            positives += 1 if item.label == 1 else 0
            handle.write(f"{item.name}\t{item.label}\n")

    return written, positives


def load_predicts(path) -> list[Predict]:
    """Read back a file written by ``write_predicts``.

    Each non-blank line must be ``name<TAB>label``. Only the line ending is
    removed before splitting, so a line with an empty name (which
    ``write_predicts`` emits when a read name could not be recovered) is
    parsed as ``Predict("", label)``. The split is made on the last tab, so
    names containing tabs are kept intact.

    Args:
        path: Location of the tab-separated predictions file.

    Returns:
        list[Predict]: One entry per non-blank line, in file order.

    Raises:
        ValueError: If a non-blank line has no tab or its label is not an integer.
    """
    predicts = []
    with open(path) as f:
        for line in f:
            record = line.rstrip("\r\n")
            if not record.strip():
                continue  # tolerate blank lines, e.g. a trailing newline added by hand
            name, label = record.rsplit("\t", 1)
            predicts.append(Predict(name, int(label)))
    return predicts

In [104]:
predict_path = Path("/projects/b1171/ylk4626/project/Chimera/logs/eval/runs/2025-02-20_19-01-17")
# /projects/b1171/ylk4626/project/Chimera/logs/eval/runs/2025-02-20_20-30-32
# /projects/b1171/ylk4626/project/Chimera/logs/eval/runs/2025-02-20_22-06-13


In [105]:
# evaluate the memory usage
# ``collect_predict_from_folder`` is a generator function: nothing is read here,
# the ``.pt`` files are loaded one at a time once ``pt`` is iterated.
pt = collect_predict_from_folder(predict_path / "predicts/0")

In [106]:
write_predicts(pt, predict_path / "predicts/predicts.txt")

: 

: 

: 

In [2]:
 # data/sv/PC3_10_cells_MDA_Mk1c_dirty/chimeric_reads_mapping/cutesv.vcf.sv.read.sup.txt