In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from datasets import Dataset, load_dataset
from IPython.core.interactiveshell import InteractiveShell

import deepchopper

# Ask IPython to display the value of every top-level expression in a cell,
# not only the last one.
InteractiveShell.ast_node_interactivity = "all"

In [None]:
from rich.console import Console
from rich.text import Text


def highlight_target(seq: str, start: int, end: int, style="bold magenta"):
    """Print ``seq`` to the console with one span rendered in ``style``.

    Args:
        seq: Sequence text to print.
        start: Offset where the styled span begins.
        end: Offset where the styled span stops.
        style: Rich style string applied to the span.
    """
    styled_seq = Text(seq)
    styled_seq.stylize(style, start, end)
    Console().print(styled_seq)


def hightlight_predict(
    seq: str, target_start: int, target_end: int, predict_start: int, predict_end: int
):
    """Print ``seq`` with the target span and the predicted span styled differently.

    The target span is styled first (grey, ``#adb0b1``) and the predicted span
    second (``bold magenta``); the application order is kept as-is.

    Args:
        seq: Sequence text to print.
        target_start: Offset where the ground-truth span begins.
        target_end: Offset where the ground-truth span stops.
        predict_start: Offset where the predicted span begins.
        predict_end: Offset where the predicted span stops.
    """
    styled_seq = Text(seq)
    spans = (
        ("#adb0b1", target_start, target_end),
        ("bold magenta", predict_start, predict_end),
    )
    for span_style, span_start, span_end in spans:
        styled_seq.stylize(span_style, span_start, span_end)

    Console().print(styled_seq)

In [None]:
import platform

# Report the detected operating system, then pick the matching project checkout.
print(f"{platform.system()=}")
root_dir = (
    Path("/projects/b1171/ylk4626/project/DeepChopper")
    if platform.system() == "Linux"
    else Path("/Users/ylk4626/ClionProjects/DeepChopper")
)

In [None]:
train_file = root_dir / "tests/data/test_input.parquet"
data_files = {"train": train_file.as_posix()}

num_proc = 8


def _load_split(split: str):
    """Load one slice of the parquet file and expose it in torch format."""
    loaded = load_dataset(
        "parquet",
        data_files=data_files,
        num_proc=num_proc,
        split=split,
    )
    return loaded.with_format("torch")


# 80% / 10% / 10% slices of the single "train" file.
train_dataset = _load_split("train[:80%]")
val_dataset = _load_split("train[80%:90%]")
test_dataset = _load_split("train[90%:]")

for split_label, split_dataset in (
    ("train_dataset", train_dataset),
    ("val_dataset", val_dataset),
    ("test_dataset", test_dataset),
):
    print(f"{split_label}: {split_dataset}")

In [None]:
# Show the column schema of the training split.
train_dataset.features

In [None]:
import pandas as pd


def show_example_for_dataset(dataset, split=None, first_examples: int = 10):
    """Build a small DataFrame previewing the first rows of a dataset.

    Args:
        dataset: Object indexable by column name that exposes the ``id``,
            ``seq``, ``qual`` and ``target`` columns. When ``split`` is given,
            ``dataset[split]`` must be such an object instead.
        split: Optional key selecting a split inside ``dataset``.
        first_examples: Number of leading rows to include.

    Returns:
        A ``pandas.DataFrame`` with columns ``id``, ``seq``, ``qual`` and
        ``target``. ``qual`` and ``target`` entries are converted with
        ``.tolist()`` when they provide it (tensors, arrays) and are kept
        unchanged otherwise.
    """
    # Resolve the split once instead of duplicating every column lookup.
    source = dataset if split is None else dataset[split]
    preview = slice(0, first_examples)

    def as_plain(value):
        # Tensors/arrays become plain lists; values without tolist() pass through.
        return value.tolist() if hasattr(value, "tolist") else value

    return pd.DataFrame(
        {
            "id": source["id"][preview],
            "seq": source["seq"][preview],
            "qual": [as_plain(item) for item in source["qual"][preview]],
            "target": [as_plain(item) for item in source["target"][preview]],
        }
    )

In [None]:
# Preview the first rows of the training split (default: 10 rows).
show_example_for_dataset(train_dataset)

In [None]:
# NOTE(review): `seq` and `target` are not assigned by any earlier cell in this
# notebook, so this fails on a fresh kernel — presumably they came from a
# since-removed cell holding one dataset record; confirm and define them here.
highlight_target(seq, *target)

In [None]:
# Overlay a hand-picked predicted span (1070-1120) on the target span.
# NOTE(review): relies on notebook-level `seq` / `target` that no earlier cell defines.
hightlight_predict(seq, *target, 1070, 1120)

In [None]:
# Same overlay with the predicted span starting 10 positions earlier (1060-1120).
# NOTE(review): relies on notebook-level `seq` / `target` that no earlier cell defines.
hightlight_predict(seq, *target, 1060, 1120)

# 1. Read Length of Direct RNA

In [None]:
def vis_bam_record_len():
    """Plot a read-length density for each direct-RNA sample and return the data.

    One ``.npy`` file per sample is loaded from ``root_dir / "data/direct_rna"``
    and drawn as a filled KDE on its own subplot of a 3x2 grid.

    Returns:
        A list with one loaded array per sample, in the order of the sample
        names listed in the function body.
    """
    direc_rna_samples = ["22Rv1", "DU145", "LNCaP", "LuCaP", "PC3", "VCaP"]
    data = [np.load(root_dir / f"data/direct_rna/{p}.npy") for p in direc_rna_samples]

    # Create the figure at its final size instead of resizing it afterwards.
    fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(20, 20))

    for ax, sample, lengths in zip(axs.flatten(), direc_rna_samples, data):
        sns.kdeplot(lengths, fill=True, ax=ax)
        # Title each panel with the sample name, not its positional index.
        ax.set_title(f"Sample {sample}")
        # Rotate tick labels on every panel; plt.xticks only touches the last axes.
        ax.tick_params(axis="x", labelrotation=30)

    # Figure-level title: plt.title would overwrite the last subplot's own title.
    fig.suptitle("Read Length of  Direc RNA")

    return data

In [None]:
# Draw the per-sample read-length densities (also returns the loaded arrays).
vis_bam_record_len()

In [None]:
# Load the read lengths of the first direct-RNA sample ("22Rv1") on their own.
# vis_bam_record_len() takes no arguments and its sample list is local to the
# function, so the file is read directly here; the cells below (max, remove,
# sort) operate on this single array.
data = np.load(root_dir / "data/direct_rna/22Rv1.npy")

In [None]:
# Largest value in `data`.
max(data)

In [None]:
# Copy into a plain list so a single value can be removed without touching `data`.
d2 = list(data)

In [None]:
# Drop one occurrence of 103380 — presumably the extreme value reported by
# max(data) above; TODO confirm. list.remove raises ValueError if it is absent,
# so this cell is not re-runnable without rebuilding `d2`.
d2.remove(103380)

In [None]:
# Largest remaining value after the removal above.
max(d2)

In [None]:
# Density of the values with that single entry removed.
sns.kdeplot(d2, fill=True)

In [None]:
# Sorts `data` in place (ascending); this mutates the object used by earlier cells.
data.sort()

In [None]:
# Density with the last 800 entries dropped; after the in-place sort above these
# are the 800 largest values.
sns.kdeplot(data[:-800], fill=True)

In [None]:
# Summary statistics of `data`, kept in `des` for later inspection.
des = pd.Series(data).describe()

# 2. Build Model

In [None]:
import torch
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    logging,
)


def load_tokenizer_from_hyena_model(model_name):
    """Load the tokenizer of a HyenaDNA checkpoint from the ``LongSafari`` hub namespace.

    Args:
        model_name: Short checkpoint name, e.g. ``"hyenadna-small-32k-seqlen"``.
            It must be one of the keys of the length table below.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained`` for
        ``LongSafari/<model_name>-hf``, configured with the checkpoint's
        maximum length, truncation and padding.

    Raises:
        ValueError: If ``model_name`` is not a known checkpoint name.
    """
    supported_max_lengths = {
        "hyenadna-tiny-1k-seqlen": 1024,
        "hyenadna-small-32k-seqlen": 32768,
        "hyenadna-medium-160k-seqlen": 160000,
        "hyenadna-medium-450k-seqlen": 450000,  # T4 up to here
        "hyenadna-large-1m-seqlen": 1_000_000,  # only A100 (paid tier)
    }

    max_length = supported_max_lengths.get(model_name)
    if max_length is None:
        msg = f"Model name {model_name} not found in available models."
        raise ValueError(msg)

    hub_id = f"LongSafari/{model_name}-hf"
    return AutoTokenizer.from_pretrained(
        hub_id, max_length=max_length, truncation=True, padding=True, trust_remote_code=True
    )

In [None]:
import evaluate
import numpy as np

# Single metric object bundling accuracy, F1, precision and recall; used by
# compute_metrics below.
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])


def compute_metrics(p):
    """Compute token-level classification metrics, skipping ignored positions.

    Args:
        p: Pair ``(predictions, labels)``. ``predictions`` holds per-class
            scores and is reduced with ``argmax`` over axis 2; ``labels`` holds
            the reference class per position, with ``-100`` marking positions
            to ignore.

    Returns:
        The result of ``clf_metrics.compute()`` over all non-ignored positions.
    """
    scores, labels = p

    predicted_classes = np.argmax(scores, axis=2)
    labels = np.asarray(labels)

    # Boolean mask instead of a per-token Python loop; every kept position from
    # every sequence is scored together, as the accumulated batches were before.
    keep = labels != -100
    clf_metrics.add_batch(predictions=predicted_classes[keep], references=labels[keep])

    return clf_metrics.compute()

In [None]:
from typing import List

import torch
from torch import nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger("transformers")


class TokenClassificationHead(nn.Module):
    """Per-token classifier combining hidden states with per-token quality values.

    Args:
        input_size: Width of the incoming hidden states.
        lin1_size: Output width of the first linear layer.
        lin2_size: Output width of the second linear layer. The residual sum in
            ``forward`` adds this layer's output to its own input, so the two
            widths have to be broadcast-compatible (in practice equal).
        num_class: Number of output classes per token.
        use_identity_layer_for_qual: When true the quality values are passed
            through unchanged; otherwise they go through ``nn.Linear(1, lin1_size)``.
    """

    def __init__(
        self,
        input_size: int,
        lin1_size: int,
        lin2_size: int,
        num_class: int,
        *,
        use_identity_layer_for_qual: bool,
    ):
        super().__init__()
        self.activation = nn.ReLU()
        self.linear1 = nn.Linear(input_size, lin1_size)
        self.linear2 = nn.Linear(lin1_size, lin2_size)
        self.linear3 = nn.Linear(lin2_size, num_class)
        if use_identity_layer_for_qual:
            self.qual_linear1 = nn.Identity()
        else:
            self.qual_linear1 = nn.Linear(1, lin1_size)

    def forward(self, x: torch.Tensor, input_quals: torch.Tensor) -> torch.Tensor:
        """Return per-token class scores for ``x`` given ``input_quals``."""
        seq_features = self.activation(self.linear1(x))
        # A trailing axis is added so each quality value lines up with one token.
        qual_features = self.qual_linear1(input_quals.unsqueeze(-1))
        fused = seq_features + qual_features  # may add activation
        hidden = self.activation(self.linear2(fused) + fused)
        return self.linear3(hidden)


class TokenClassificationConfig(PretrainedConfig):
    """Configuration holding the sizes of the token-classification head.

    Args:
        input_size: Width of the hidden states fed to the head (default 256).
        lin1_size: Width of the head's first linear layer (default 1024).
        lin2_size: Width of the head's second linear layer (default 1024).
        num_class: Number of classes predicted per token (default 2).
        use_identity_layer_for_qual: Whether the head passes quality values
            through unchanged instead of through a linear layer (default True).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "token-classification"

    def __init__(
        self,
        input_size: int = 256,
        lin1_size: int = 1024,
        lin2_size: int = 1024,
        num_class: int = 2,
        *,
        use_identity_layer_for_qual: bool = True,
        **kwargs,
    ):
        # Head hyper-parameters are stored on the config before the base
        # initializer runs with the remaining keyword arguments.
        self.input_size = input_size
        self.lin1_size = lin1_size
        self.lin2_size = lin2_size
        self.num_class = num_class
        self.use_identity_layer_for_qual = use_identity_layer_for_qual
        super().__init__(**kwargs)


class TokenClassification(PreTrainedModel):
    """HyenaDNA backbone followed by a quality-aware token-classification head.

    Args:
        config: A ``TokenClassificationConfig`` describing the head.
        hyenadna_model: Short name of the HyenaDNA checkpoint; the backbone is
            loaded from ``LongSafari/<hyenadna_model>-hf``.
        **kwargs: Forwarded to ``PreTrainedModel.__init__``.
    """

    config_class = TokenClassificationConfig

    def __init__(
        self,
        config,
        hyenadna_model: str = "hyenadna-small-32k-seqlen",
        **kwargs,
    ):
        super().__init__(config, **kwargs)
        self.num_class = config.num_class
        self.hyenadna_model_name = hyenadna_model
        self.hyenadna = AutoModel.from_pretrained(
            f"LongSafari/{hyenadna_model}-hf", trust_remote_code=True
        )

        self.head = TokenClassificationHead(
            input_size=config.input_size,
            lin1_size=config.lin1_size,
            lin2_size=config.lin2_size,
            num_class=config.num_class,
            use_identity_layer_for_qual=config.use_identity_layer_for_qual,
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        input_quals: torch.Tensor,
        inputs_embeds: torch.FloatTensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> TokenClassifierOutput:
        """Run the backbone and head, and compute the token-level loss.

        Args:
            input_ids: Token ids passed to the backbone.
            labels: Reference class per token; flattened for the loss.
            input_quals: Per-token quality values passed to the head.
            inputs_embeds: Optional embeddings forwarded to the backbone.
            output_hidden_states: Forwarded to the backbone.
            return_dict: Forwarded to the backbone.

        Returns:
            A ``TokenClassifierOutput`` carrying ``loss``, ``logits`` and the
            backbone's ``hidden_states`` (``None`` when the backbone output
            does not expose that attribute).
        """
        # The backbone is stored as `self.hyenadna` in __init__; no `backbone`
        # attribute is ever assigned on this class.
        transformer_outputs = self.hyenadna(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 0 holds the last hidden state for both tuple and dict outputs.
        hidden_states = transformer_outputs[0]

        logits = self.head(hidden_states, input_quals)
        labels = labels.to(logits.device)

        # CrossEntropyLoss ignores targets equal to -100 by default.
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_class), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            # Plain-tuple outputs (return_dict=False) have no named attribute.
            hidden_states=getattr(transformer_outputs, "hidden_states", None),
        )

In [None]:
from transformers import DataCollatorForTokenClassification


def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """Call ``tokenizer.pad`` with the fast-tokenizer padding warning muted.

    Objects without a ``deprecation_warnings`` attribute (e.g. feature
    extractors) are padded directly. Otherwise the
    ``"Asking-to-pad-a-fast-tokenizer"`` flag is set for the duration of the
    call and put back to its previous value afterwards, even if ``pad`` raises.

    Args:
        tokenizer: Object exposing a ``pad`` method.
        *pad_args: Positional arguments forwarded to ``tokenizer.pad``.
        **pad_kwargs: Keyword arguments forwarded to ``tokenizer.pad``.

    Returns:
        Whatever ``tokenizer.pad`` returns.
    """
    if hasattr(tokenizer, "deprecation_warnings"):
        flag = "Asking-to-pad-a-fast-tokenizer"
        previous_state = tokenizer.deprecation_warnings.get(flag, False)
        # Marking the warning as already emitted silences it.
        tokenizer.deprecation_warnings[flag] = True
        try:
            return tokenizer.pad(*pad_args, **pad_kwargs)
        finally:
            tokenizer.deprecation_warnings[flag] = previous_state

    return tokenizer.pad(*pad_args, **pad_kwargs)


class DataCollatorForTokenClassificationWithQual(DataCollatorForTokenClassification):
    """Token-classification collator that also pads per-token quality values.

    Each feature must carry an ``"input_quals"`` entry. Labels are read from
    ``"label"`` or ``"labels"`` when present.
    """

    def torch_call(self, features):
        """Pad a list of features into one batch of tensors.

        Args:
            features: List of dicts with the tokenizer fields, ``"input_quals"``
                and optionally ``"label"``/``"labels"``.

        Returns:
            The padded batch. ``"input_quals"`` is a ``float32`` tensor padded
            with 0; labels, when present, are an ``int64`` tensor padded with
            ``self.label_pad_token_id``. Both follow the tokenizer's padding side.
        """
        import torch

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = (
            [feature[label_name] for feature in features]
            if label_name in features[0].keys()
            else None
        )

        qual_name = "input_quals"
        qual_pad_token_id = 0
        input_quals = [feature[qual_name] for feature in features]

        # The tokenizer only pads its own fields; labels and quals are handled below.
        no_labels_features = [
            {k: v for k, v in feature.items() if k not in [qual_name, label_name]}
            for feature in features
        ]

        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        sequence_length = batch["input_ids"].shape[1]
        pad_on_right = self.tokenizer.padding_side == "right"

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        def pad_to_length(values, pad_value):
            # Extend one sequence to the batch length on the tokenizer's padding side.
            values = to_list(values)
            padding = [pad_value] * (sequence_length - len(values))
            return values + padding if pad_on_right else padding + values

        # Quality values are floats upstream; an int64 cast would truncate them
        # and would not be accepted by a linear quality layer.
        # They are padded even without labels so the batch always carries them.
        batch[qual_name] = torch.tensor(
            [pad_to_length(qual, qual_pad_token_id) for qual in input_quals],
            dtype=torch.float32,
        )

        if labels is not None:
            batch[label_name] = torch.tensor(
                [pad_to_length(label, self.label_pad_token_id) for label in labels],
                dtype=torch.int64,
            )

        return batch

In [None]:
from functools import partial

from transformers import DataCollatorForTokenClassification


def tokenize_and_align_labels_and_quals(data, tokenizer, max_length, pad_qual=0, pad_label=-100):
    """Tokenize one record and attach its per-token labels and quality values.

    Args:
        data: Record with ``"seq"``, ``"target"`` and ``"qual"`` entries.
        tokenizer: Callable tokenizer applied to ``data["seq"]``.
        max_length: Maximum length passed to the tokenizer (with truncation).
        pad_qual: Quality value appended after the sequence's own values.
        pad_label: Label appended after the sequence's own labels.

    Returns:
        The tokenizer output extended with ``"labels"`` and ``"input_quals"``,
        each one element longer than the sequence.
    """
    tokenized_inputs = tokenizer(data["seq"], max_length=max_length, truncation=True, padding=True)

    # One extra trailing label/quality — presumably for a special token the
    # tokenizer appends; TODO confirm against the tokenizer.
    labels = torch.tensor(
        deepchopper.vertorize_target(*data["target"], len(data["seq"])) + [pad_label]
    )
    quals = torch.cat((data["qual"], torch.tensor([pad_qual]))).float()

    # NOTE(review): an L2-normalized copy of `quals` used to be computed here
    # and then discarded; the raw values are what reach the model. If
    # normalization is intended, pass the normalized tensor as "input_quals".
    tokenized_inputs.update({"labels": labels, "input_quals": quals})
    return tokenized_inputs


def tokenize_dataset(dataset, tokenizer, max_length):
    """Tokenize every record of ``dataset`` and drop the raw input columns.

    Args:
        dataset: Dataset exposing ``map`` and ``remove_columns``.
        tokenizer: Tokenizer handed to ``tokenize_and_align_labels_and_quals``.
        max_length: Maximum length handed to the same function.

    Returns:
        The mapped dataset without the ``id``, ``seq``, ``qual`` and ``target`` columns.
    """
    tokenize_record = partial(
        tokenize_and_align_labels_and_quals, tokenizer=tokenizer, max_length=max_length
    )
    tokenized = dataset.map(tokenize_record)
    return tokenized.remove_columns(["id", "seq", "qual", "target"])


# Checkpoint whose tokenizer is used for all three splits.
hyenadna_name = "hyenadna-small-32k-seqlen"
tokenizer = load_tokenizer_from_hyena_model(hyenadna_name)

# Tokenize each split with the tokenizer's single-sentence maximum length.
tokenize_train_dataset = tokenize_dataset(
    train_dataset, tokenizer, max_length=tokenizer.max_len_single_sentence
)
tokenize_val_dataset = tokenize_dataset(
    val_dataset, tokenizer, max_length=tokenizer.max_len_single_sentence
)
tokenize_test_dataset = tokenize_dataset(
    test_dataset, tokenizer, max_length=tokenizer.max_len_single_sentence
)

In [None]:
# data_collator = DataCollatorForTokenClassification(tokenizer)
# Use the quality-aware collator and a model built from the default head config.
data_collator = DataCollatorForTokenClassificationWithQual(tokenizer)
model_config = TokenClassificationConfig()
model = TokenClassification(model_config)

In [None]:
# Display the module tree of the model.
model

In [None]:
# Display the tokenized training split.
tokenize_train_dataset

In [None]:
from accelerate import Accelerator

accelerator = Accelerator()

# One-epoch run that evaluates and checkpoints at the end of each epoch and
# reports to Weights & Biases.
training_args = TrainingArguments(
    output_dir="hyena_model_use_qual_test",
    learning_rate=2e-5,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=12,
    num_train_epochs=1,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
    torch_compile=False,
    # tf32=True,
    report_to="wandb",
    run_name="hyena_model_use_qual",
    resume_from_checkpoint=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenize_train_dataset,
    # Evaluate on the validation split: with load_best_model_at_end=True the
    # evaluation set drives checkpoint selection, so the held-out test split
    # must not be used here.
    eval_dataset=tokenize_val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# NOTE(review): Accelerator.prepare is documented for models, optimizers,
# dataloaders and schedulers; confirm that passing a Trainer here is intended.
trainer = accelerator.prepare(trainer)

In [None]:
# Run training; the returned training summary is kept in `result`.
result = trainer.train()

In [None]:
# resume_config = TokenClassificationConfig.from_pretrained("./hyena_model_test2/checkpoint-1000/")
# Reload a saved checkpoint for comparison with the in-memory model.
# NOTE(review): this path is under "hyena_model_test2", not the
# "hyena_model_use_qual_test" output_dir configured for the trainer in this
# notebook — confirm which run the checkpoint belongs to.
resume_model = TokenClassification.from_pretrained("./hyena_model_test2/checkpoint-500/")

In [None]:
# Report every parameter/buffer whose values differ between the trained model
# and the reloaded checkpoint. The state dicts are built once, and the training
# summary stored in `result` is no longer overwritten.
resume_state = resume_model.state_dict()
for k, trained_value in model.state_dict().items():
    if k not in resume_state:
        print(f"{k} is missing from resume_model")
        continue
    # Compare on CPU: the two models may live on different devices.
    if not torch.equal(resume_state[k].cpu(), trained_value.cpu()):
        print(f"{k} is not equal")

In [None]:
from safetensors import safe_open

# Read every tensor stored in the checkpoint file into a name -> tensor dict.
tensors = {}
with safe_open(
    "./hyena_model_test2/checkpoint-500/model.safetensors", framework="pt", device=0
) as checkpoint:
    tensors.update((name, checkpoint.get_tensor(name)) for name in checkpoint.keys())

In [None]:
# Evaluate on the trainer's configured eval_dataset and display the metrics.
trainer.evaluate()

In [None]:
# Run inference on the tokenized validation slice (train[80%:90%]); later cells
# index the result as [0] predictions and [1] labels.
predicts = trainer.predict(tokenize_val_dataset)

In [None]:
def summary_predict(predictions, labels):
    """Turn raw scores into class ids and drop positions labelled ``-100``.

    Args:
        predictions: Per-class scores; ``argmax`` is taken over axis 2.
        labels: Reference class per position, ``-100`` marking ignored positions.

    Returns:
        ``(true_predictions, true_labels)``: two lists with one inner list per
        sequence, holding the predicted and reference classes at the kept
        positions, in order.
    """
    ignored_label = -100
    predicted_classes = np.argmax(predictions, axis=2)

    true_predictions, true_labels = [], []
    for row_predictions, row_labels in zip(predicted_classes, labels):
        # Keep prediction/label pairs together so both lists stay aligned.
        kept_pairs = [
            (predicted, reference)
            for predicted, reference in zip(row_predictions, row_labels)
            if reference != ignored_label
        ]
        true_predictions.append([predicted for predicted, _ in kept_pairs])
        true_labels.append([reference for _, reference in kept_pairs])

    return true_predictions, true_labels


from rich.console import Console
from rich.highlighter import RegexHighlighter
from rich.theme import Theme


class LabelHighlighter(RegexHighlighter):
    """Highlight runs of one or more consecutive ``1`` characters.

    The regex captures each run in a named group ``label``; combined with
    ``base_style`` this yields the style name ``label.label``.
    """

    base_style = "label."
    highlights = [r"(?P<label>1+)"]


def alignment_predict(prediction, label):
    """Print predictions above labels, wrapped line by line, for visual comparison.

    Both sequences are turned into strings (one character run per element via
    ``str``), wrapped with ``textwrap.wrap`` and printed in pairs: a ``P:``
    line for the prediction followed by an ``L:`` line for the label. Printing
    stops at the shorter of the two wrapped sequences.

    Args:
        prediction: Iterable of predicted classes.
        label: Iterable of reference classes.
    """
    import textwrap

    prediction_rows = textwrap.wrap("".join(str(token) for token in prediction))
    label_rows = textwrap.wrap("".join(str(token) for token in label))

    console = Console(
        highlighter=LabelHighlighter(),
        theme=Theme({"label.label": "bold magenta"}),
    )
    for prediction_row, label_row in zip(prediction_rows, label_rows):
        console.print(f"P:{prediction_row}\nL:{label_row}")

In [None]:
# predicts[0] holds the scores and predicts[1] the labels passed to summary_predict.
true_predictions, true_labels = summary_predict(predicts[0], predicts[1])

In [None]:
# Compare prediction and label for the example at index 6 (arbitrary pick).
alignment_predict(true_predictions[6], true_labels[6])

# 3. Load Pretrained Model

In [None]:
from functools import partial
from pathlib import Path

import numpy as np
from accelerate import Accelerator
from datasets import load_dataset
from transformers import AutoTokenizer, Trainer, TrainingArguments

from deepchopper.data import load_and_split_dataset
from deepchopper.models.hyena import (
    IGNORE_INDEX,
    DataCollatorForTokenClassificationWithQual,
    HyenadnaMaxLengths,
    TokenClassification,
    TokenClassificationConfig,
    compute_metrics,
    tokenize_and_align_labels_and_quals,
)
from deepchopper.utils import alignment_predict, highlight_target, summary_predict

In [None]:
# check_point = root_dir / "notebooks/deepchopper_train/checkpoint-26672"
# NOTE(review): this overrides the platform-dependent root_dir chosen earlier
# with a fixed absolute path, so this section only runs where that path exists.
root_dir = Path("/projects/b1171/ylk4626/project/DeepChopper")
check_point = root_dir / "notebooks/cdc_train100000_20ep_18b/checkpoint-20007/"
# Tokenizer and model are both restored from the same checkpoint directory.
resume_tokenizer = AutoTokenizer.from_pretrained(check_point, trust_remote_code=True)
resume_model = TokenClassification.from_pretrained(check_point)

In [None]:
def random_show_seq(dataset, sample=3):
    """Print randomly chosen records of ``dataset`` with their target span highlighted.

    Args:
        dataset: Indexable collection of records with ``"id"``, ``"seq"`` and
            ``"target"`` entries; ``len(dataset)`` must be supported.
        sample: Number of records to show. Indices are drawn independently, so
            the same record can appear more than once.
    """
    import secrets

    total = len(dataset)
    if total == 0:
        # secrets.randbelow requires a positive bound; nothing to show anyway.
        return

    for _ in range(sample):
        # Read from the dataset that was passed in, not a notebook-level global.
        record = dataset[secrets.randbelow(total)]
        print(f"id: {record['id']}")
        highlight_target(record["seq"], *record["target"])

In [None]:
import multiprocessing

data_path = root_dir / "data/fqs/PC3.internal.parquet"

# Load only the first 200 records of the file, using one loader process per CPU.
eval_dataset = load_dataset(
    "parquet",
    data_files={"test": data_path.as_posix()},
    num_proc=multiprocessing.cpu_count(),
    split="test[:200]",
).with_format("torch")

In [None]:
# Show a few randomly picked evaluation records (default: 3).
random_show_seq(eval_dataset)

In [None]:
# Tokenize the evaluation records, drop the raw columns and shuffle.
tokenized_eval_dataset = (
    eval_dataset.map(
        partial(
            tokenize_and_align_labels_and_quals,
            tokenizer=resume_tokenizer,
            max_length=resume_tokenizer.max_len_single_sentence,
        ),
        num_proc=multiprocessing.cpu_count(),
        # The data being tokenized is the evaluation set, not the training set.
        desc="Running tokenizer on eval dataset",
    )
    .remove_columns(["id", "seq", "qual", "target"])
    # Fixed seed: later cells pick examples by position, which is only
    # reproducible if the shuffled order is the same on every run.
    .shuffle(seed=42)
)

In [None]:
data_collator = DataCollatorForTokenClassificationWithQual(resume_tokenizer)

# Same argument set as the training run above; no train_dataset is given to
# the Trainer below, which the following cell uses only for prediction.
training_args = TrainingArguments(
    output_dir="hyena_model_use_qual_testt",
    learning_rate=2e-5,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=12,
    num_train_epochs=1,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
    torch_compile=False,
    # tf32=True,
    report_to="wandb",
    run_name="eval",
    resume_from_checkpoint=False,
)


# Initialize our Trainer
trainer = Trainer(
    model=resume_model,
    args=training_args,
    # train_dataset=train_dataset,
    eval_dataset=tokenized_eval_dataset,
    tokenizer=resume_tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Run the restored model on the tokenized evaluation set.
predicts = trainer.predict(tokenized_eval_dataset)

In [None]:
from rich import print

# Third element of the prediction result — presumably the metrics dict; TODO confirm.
# `print` here is rich's print, which shadows the builtin from this cell onward.
print(predicts[2])

In [None]:
# Convert scores to class ids and drop positions labelled -100.
true_prediction, true_label = summary_predict(predictions=predicts[0], labels=predicts[1])

In [None]:
# Compare prediction and label for the example at index 4 (arbitrary pick).
alignment_predict(true_prediction[4], true_label[4])