In [None]:
import torch
from torch import nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils import logging

In [None]:
from typing import List

import torch
from torch import nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel,AutoTokenizer
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils import logging


In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from datasets import Dataset, load_dataset
from IPython.core.interactiveshell import InteractiveShell

import deepchopper

InteractiveShell.ast_node_interactivity = "all"

In [None]:
import platform 
from pathlib import Path 

print(f"{platform.system()=}")
if platform.system() == "Linux":
    root_dir = Path("/projects/b1171/ylk4626/project/DeepChopper")
else:
    root_dir = Path("/Users/ylk4626/ClionProjects/DeepChopper")

In [None]:
train_file = root_dir / "tests/data/test_input.parquet"
data_files = {"train": train_file.as_posix()}

num_proc = 8
train_dataset = load_dataset(
    "parquet",
    data_files=data_files,
    num_proc=num_proc,
    split="train[:80%]",
).with_format("torch")
val_dataset = load_dataset(
    "parquet", data_files=data_files, num_proc=num_proc, split="train[80%:90%]"
).with_format("torch")
test_dataset = load_dataset(
    "parquet", data_files=data_files, num_proc=num_proc, split="train[90%:]"
).with_format("torch")

print(f"train_dataset: {train_dataset}")
print(f"val_dataset: {val_dataset}")
print(f"test_dataset: {test_dataset}")

In [None]:
train_dataset

In [None]:
from torch.utils.data import DataLoader

In [None]:
def load_tokenizer_from_hyena_model(model_name):
    max_lengths = {
        "hyenadna-tiny-1k-seqlen": 1024,
        "hyenadna-small-32k-seqlen": 32768,
        "hyenadna-medium-160k-seqlen": 160000,
        "hyenadna-medium-450k-seqlen": 450000,  # T4 up to here
        "hyenadna-large-1m-seqlen": 1_000_000,  # only A100 (paid tier)
    }

    if model_name not in max_lengths:
        msg = f"Model name {model_name} not found in available models."
        raise ValueError(msg)

    max_length = max_lengths[model_name]
    # bfloat16 for better speed and reduced memory usage
    model_name = f"LongSafari/{model_name}-hf"
    return AutoTokenizer.from_pretrained(
        model_name, max_length=max_length, truncation=True, padding=True, trust_remote_code=True
    )

In [None]:
hyenadna_name = "hyenadna-small-32k-seqlen"
tokenizer = load_tokenizer_from_hyena_model(hyenadna_name)

In [None]:
from deepchopper.models.hyena import (
    IGNORE_INDEX,
    # DataCollatorForTokenClassificationWithQual,
    compute_metrics,
    # tokenize_and_align_labels_and_quals,
)

In [95]:
from transformers import (
    AutoTokenizer,
    DataCollatorForTokenClassification,
)

import deepchopper

def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer."""
    # To avoid errors when using Feature extractors
    if not hasattr(tokenizer, "deprecation_warnings"):
        return tokenizer.pad(*pad_args, **pad_kwargs)

    # Save the state of the warning, then disable it
    warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

    try:
        padded = tokenizer.pad(*pad_args, **pad_kwargs)
    finally:
        # Restore the state of the warning.
        tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state

    return padded


class DataCollatorForTokenClassificationWithQual(DataCollatorForTokenClassification):
    def torch_call(self, features):
        import torch

        label_name = "label" if "label" in features[0] else "labels"
        labels = (
            [feature[label_name] for feature in features] if label_name in features[0] else None
        )

        qual_name = "input_quals"
        qual_pad_token_id = 0
        input_quals = [feature[qual_name] for feature in features]

        no_labels_features = [
            {k: v for k, v in feature.items() if k not in [qual_name, label_name]}
            for feature in features
        ]

        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)
        
        if padding_side == "right":
            batch[label_name] = [
                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label))
                for label in labels
            ]
            batch[qual_name] = [
                to_list(qual) + [qual_pad_token_id] * (sequence_length - len(qual))
                for qual in input_quals
            ]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label)
                for label in labels
            ]
            batch[qual_name] = [
                [qual_pad_token_id] * (sequence_length - len(qual)) + to_list(qual)
                for qual in input_quals
            ]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        batch[qual_name] = torch.tensor(batch[qual_name], dtype=torch.float32)
        print(batch)        
        return batch



In [96]:

tt = DataCollatorForTokenClassificationWithQual(tokenizer)


In [97]:
def tokenize_and_align_labels_and_quals(
    data, tokenizer, max_length, pad_qual=0, pad_label=IGNORE_INDEX
):
    tokenized_inputs = tokenizer(data["seq"], max_length=max_length, truncation=True, padding=True)
    labels = torch.tensor(
        [*deepchopper.vertorize_target(*data["target"], len(data["seq"])), pad_label]
    )
    quals = torch.cat((data["qual"], torch.tensor([pad_qual]))).float()
    normalized_quals = torch.nn.functional.normalize(quals, dim=0)
    tokenized_inputs.update({"labels": labels, "input_quals": normalized_quals})
    return tokenized_inputs

In [98]:
tokenized_train_dataset = train_dataset.map(
                    partial(
                        tokenize_and_align_labels_and_quals,
                        tokenizer=tokenizer,
                        max_length=tokenizer.max_len_single_sentence,
                    ),
                    batched=False,
                    num_proc=12,
                    desc="Running tokenizer on train dataset",
                ).remove_columns(["id", "seq", "qual", "target"])



  StockPickler.save(self, obj, save_persistent_id)
  StockPickler.save(self, obj, save_persistent_id)


Running tokenizer on train dataset (num_proc=12):   0%|          | 0/4000 [00:00<?, ? examples/s]

In [99]:
tokenized_train_dataset[0]

{'input_ids': tensor([9, 8, 7,  ..., 8, 7, 1]),
 'labels': tensor([   0,    0,    0,  ...,    0,    0, -100]),
 'input_quals': tensor([0.0098, 0.0113, 0.0210,  ..., 0.0188, 0.0218, 0.0000])}

In [100]:
from functools import partial
dataloader = DataLoader(tokenized_train_dataset, batch_size=4, collate_fn=tt.torch_call)

In [101]:
for batch in dataloader:
    print(batch)
    break

{'input_ids': tensor([[ 4,  4,  4,  ...,  8,  7,  1],
        [ 4,  4,  4,  ...,  9,  7,  1],
        [ 9, 10,  7,  ..., 10,  8,  1],
        [ 4,  4,  4,  ...,  7,  7,  1]]), 'labels': tensor([[-100, -100, -100,  ...,    0,    0, -100],
        [-100, -100, -100,  ...,    0,    0, -100],
        [   0,    0,    0,  ...,    0,    0, -100],
        [-100, -100, -100,  ...,    0,    0, -100]]), 'input_quals': tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0188, 0.0218, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0219, 0.0228, 0.0000],
        [0.0141, 0.0196, 0.0201,  ..., 0.0087, 0.0082, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0299, 0.0276, 0.0000]])}
{'input_ids': tensor([[ 4,  4,  4,  ...,  8,  7,  1],
        [ 4,  4,  4,  ...,  9,  7,  1],
        [ 9, 10,  7,  ..., 10,  8,  1],
        [ 4,  4,  4,  ...,  7,  7,  1]]), 'labels': tensor([[-100, -100, -100,  ...,    0,    0, -100],
        [-100, -100, -100,  ...,    0,    0, -100],
        [   0,    0,    0,  ...,    0,   

In [None]:
batch['input_ids']

In [None]:
batch['labels']

In [None]:
batch['input_quals'][2][-200:]

In [None]:
batch['input_quals'][2].sum()