In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from datasets import Dataset, load_dataset
from IPython.core.interactiveshell import InteractiveShell

import deepchopper

# Display the value of every bare expression in a cell, not only the last one.
setattr(InteractiveShell, "ast_node_interactivity", "all")

In [None]:
from rich.console import Console
from rich.text import Text


def highlight_target(seq: str, start: int, end: int, style="bold magenta"):
    """Print ``seq`` to the terminal with the span from ``start`` to ``end`` styled.

    Args:
        seq: Sequence text to render.
        start: Start offset of the span to highlight.
        end: End offset of the span to highlight.
        style: Rich style string applied to the span.
    """
    highlighted = Text(seq)
    highlighted.stylize(style, start, end)
    Console().print(highlighted)


def hightlight_predict(
    seq: str,
    target_start: int,
    target_end: int,
    predict_start: int,
    predict_end: int,
    target_style: str = "#adb0b1",
    predict_style: str = "bold magenta",
):
    """Print ``seq`` with both the true target span and a predicted span styled.

    Args:
        seq: Sequence text to render.
        target_start: Start offset of the ground-truth target span.
        target_end: End offset of the ground-truth target span.
        predict_start: Start offset of the predicted span.
        predict_end: End offset of the predicted span.
        target_style: Rich style for the target span (default: grey ``#adb0b1``).
        predict_style: Rich style for the predicted span (default: ``bold magenta``,
            the same default ``highlight_target`` uses).

    The prediction style is applied second, so where the spans overlap it is
    layered over the target style.
    """
    text = Text(seq)
    text.stylize(target_style, target_start, target_end)
    text.stylize(predict_style, predict_start, predict_end)
    Console().print(text)


# Correctly spelled alias; the misspelled name above is kept for existing callers.
highlight_predict = hightlight_predict

In [None]:
import os

# Repository root. Set the DEEPCHOPPER_ROOT environment variable to point at a
# different checkout (e.g. /projects/b1171/ylk4626/project/DeepChopper on the
# cluster) instead of editing this cell; the default is the original local path.
root_dir = Path(os.environ.get("DEEPCHOPPER_ROOT", "/Users/ylk4626/ClionProjects/DeepChopper"))

In [None]:
# Parquet fixture from tests/data, loaded as the single "train" split.
train_file = root_dir.joinpath("tests", "data", "test_input.parquet")
dataset = load_dataset("parquet", data_files=dict(train=train_file.as_posix()))

In [None]:
# Inspect the loaded dataset (splits, columns, number of rows).
dataset

In [None]:
# Index the row first so only one record is materialized; column access
# (dataset["train"]["seq"]) builds the whole column just to take element 0.
seq1 = dataset["train"][0]["seq"]

In [None]:
# Inspect the raw sequence of the first record.
seq1

In [None]:
# Fetch the first record once and unpack the fields used below, instead of
# materializing three full columns to read one row.
first_record = dataset["train"][0]
seq = first_record["seq"]
qual = first_record["qual"]
target = first_record["target"]

In [None]:
# Highlight the target span of the first record (target is unpacked into start, end).
highlight_target(seq, *target)

In [None]:
# Overlay a hard-coded predicted span (1070 to 1120) on the target span.
hightlight_predict(seq, *target, 1070, 1120)

In [None]:
# Same overlay with the hard-coded predicted span starting earlier (1060 to 1120).
hightlight_predict(seq, *target, 1060, 1120)

# 1. Read Length of Direct RNA

In [None]:
def vis_bam_record_len(samples=None, data_dir=None):
    """Plot a KDE of read lengths for each direct-RNA sample, one panel per sample.

    Each sample's lengths are loaded from ``<data_dir>/<sample>.npy`` and drawn
    in its own panel of a two-column grid.

    Args:
        samples: Sample names to plot. Defaults to the six direct-RNA samples
            ``22Rv1, DU145, LNCaP, LuCaP, PC3, VCaP``.
        data_dir: Directory holding the ``.npy`` files. Defaults to
            ``root_dir / "data/direct_rna"``.

    Returns:
        list[np.ndarray]: The loaded arrays, in the same order as ``samples``.

    Raises:
        ValueError: If ``samples`` is empty.
    """
    if samples is None:
        samples = ["22Rv1", "DU145", "LNCaP", "LuCaP", "PC3", "VCaP"]
    samples = list(samples)
    if not samples:
        raise ValueError("samples must contain at least one sample name")
    data_dir = Path(data_dir) if data_dir is not None else root_dir / "data/direct_rna"

    data = [np.load(data_dir / f"{name}.npy") for name in samples]

    ncols = 2
    nrows = (len(samples) + ncols - 1) // ncols
    # squeeze=False keeps `axs` 2-D even for a single row, so flatten() always works.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 20), squeeze=False)
    flat_axs = axs.flatten()

    for ax, name, lengths in zip(flat_axs, samples, data):
        sns.kdeplot(lengths, fill=True, ax=ax)
        # Title each panel with the sample name rather than its loop index.
        ax.set_title(f"Sample {name}")
        ax.set_xlabel("Read length")
        # Rotate tick labels on every panel; plt.xticks() only touches the last one.
        ax.tick_params(axis="x", labelrotation=30)

    # Hide leftover panels when the sample count does not fill the grid.
    for ax in flat_axs[len(samples):]:
        ax.set_visible(False)

    # Figure-level title; plt.title() would only label the last panel.
    fig.suptitle("Read Length of Direct RNA")

    return data

In [None]:
# Keep the returned arrays in a variable so the cell shows only the figure,
# not six full read-length arrays.
read_lengths_by_sample = vis_bam_record_len()

In [None]:
# Load the read lengths of the first direct-RNA sample for a closer look.
# The sample list was only a local variable of the plotting function, so it is
# defined here; the array is read directly with np.load.
direc_rna_samples = ["22Rv1", "DU145", "LNCaP", "LuCaP", "PC3", "VCaP"]
data = np.load(root_dir / f"data/direct_rna/{direc_rna_samples[0]}.npy")

In [None]:
# Longest read. np.max is vectorized; builtin max() iterates the array element by element in Python.
np.max(data)

In [None]:
# Mutable list copy of the read lengths, so single values can be removed.
d2 = [*data]

In [None]:
# Drop the single extreme read length 103380 before plotting. The membership
# guard makes re-running this cell a no-op instead of raising ValueError.
if 103380 in d2:
    d2.remove(103380)

In [None]:
# Longest remaining read length after removing the outlier.
max(d2)

In [None]:
# Read-length KDE with the outlier removed; label the axes so the figure stands alone.
ax = sns.kdeplot(d2, fill=True)
ax.set(title="Read length distribution (outlier 103380 removed)", xlabel="Read length");

In [None]:
# Sort the read lengths ascending, in place (mutates `data`).
data.sort()

In [None]:
# Drop the longest reads so the long tail does not flatten the KDE. Sort a copy
# explicitly so this cell gives the same plot whether or not `data` was sorted
# in place earlier.
N_LONGEST_DROPPED = 800
sns.kdeplot(np.sort(data)[:-N_LONGEST_DROPPED], fill=True)

In [None]:
# Summary statistics of the read lengths (count, mean, std, min, quartiles, max).
des = pd.Series(data).describe()
des

# 2. Build Model

In [None]:
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    logging,
)

# Maximum sequence length supported by each HyenaDNA checkpoint (the "-seqlen" suffix).
max_lengths = {
    "hyenadna-tiny-1k-seqlen": 1_024,
    "hyenadna-small-32k-seqlen": 32_768,
    "hyenadna-medium-160k-seqlen": 160_000,
    "hyenadna-medium-450k-seqlen": 450_000,  # T4 up to here
    "hyenadna-large-1m-seqlen": 1_000_000,  # only A100 (paid tier)
}
model_checkpoints = [*max_lengths]

# Checkpoint used for the rest of the notebook, and its length limit.
checkpoint = "hyenadna-small-32k-seqlen"
max_length = max_lengths[checkpoint]
# Hugging Face Hub id of the selected checkpoint.
model_name = f"LongSafari/{checkpoint}-hf"
# trust_remote_code=True runs the tokenizer code published in the Hub repo;
# only use it with checkpoints you trust.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

In [None]:
def generate_label_from_target(data):
    """Return ``{"label": ...}`` with per-base labels for one example.

    The labels are produced by ``deepchopper.vertorize_target`` from the
    ``data["target"]`` span and the length of ``data["seq"]``.
    """
    return {"label": deepchopper.vertorize_target(*data["target"], len(data["seq"]))}


def tokenize_and_align_labels_and_quals(data, tokenizer, max_length, pad_qual=0, pad_label=-100):
    """Tokenize one example and attach per-token labels and base qualities.

    Args:
        data: Example with ``"seq"``, ``"target"`` and ``"qual"`` fields.
        tokenizer: Callable tokenizer returning a mapping with ``"input_ids"``.
        max_length: Truncation length passed to the tokenizer.
        pad_qual: Quality value for the trailing non-base token.
        pad_label: Label for the trailing non-base token (``-100`` by default).

    Returns:
        The tokenizer output with ``"label"`` and ``"input_qual"`` added, each
        the same length as ``input_ids``.
    """
    tokenized_inputs = tokenizer(data["seq"], max_length=max_length, truncation=True, padding=True)

    # Labels/qualities hold one entry per base plus one pad entry for the
    # trailing token. NOTE(review): this assumes the tokenizer appends exactly one
    # special token after the bases — confirm for the HyenaDNA tokenizer.
    # When truncation shortens input_ids, trim labels/quals to the surviving bases
    # so they stay aligned with the tokens.
    n_bases = max(len(tokenized_inputs["input_ids"]) - 1, 0)
    labels = list(deepchopper.vertorize_target(*data["target"], len(data["seq"])))[:n_bases] + [pad_label]
    quals = list(data["qual"])[:n_bases] + [pad_qual]

    tokenized_inputs.update({"label": labels, "input_qual": quals})

    return tokenized_inputs


from functools import partial

# Tokenize every example and attach per-token labels/qualities. Reuse the
# selected checkpoint's `max_length` rather than repeating the literal 32768.
tokenize_dataset = dataset.map(
    partial(tokenize_and_align_labels_and_quals, tokenizer=tokenizer, max_length=max_length)
)

In [None]:
from transformers import AutoModel

# Standalone copy of the pretrained backbone, loaded for inspection.
hyena_dna_model = AutoModel.from_pretrained(
    pretrained_model_name_or_path=model_name,
    trust_remote_code=True,
)

In [None]:
# Show the backbone's module structure (nn.Module repr).
hyena_dna_model

In [None]:
import evaluate

# Span-level sequence-labelling metric used by compute_metrics.
seqeval = evaluate.load(path="seqeval")

In [None]:
import numpy as np


def compute_metrics(p, label_list=("O", "I-TARGET")):
    """Compute seqeval span metrics for per-token predictions.

    Args:
        p: ``(predictions, labels)`` pair as passed by ``Trainer``. ``predictions``
            are per-token scores with the class axis last; ``labels`` are class
            ids where ``-100`` marks positions to ignore.
        label_list: Tag string for each class id, in the IOB format seqeval
            expects. NOTE(review): assumes id 0 = outside and id 1 = inside the
            target span, as produced by ``deepchopper.vertorize_target`` — confirm.

    Returns:
        dict with overall ``precision``, ``recall``, ``f1`` and ``accuracy``.
    """
    predictions, labels = p
    pred_ids = np.argmax(predictions, axis=-1)

    true_predictions = []
    true_labels = []
    for pred_row, label_row in zip(pred_ids, labels):
        # Keep only positions that carry a real label (drop -100 padding/special tokens).
        kept = [(pred, gold) for pred, gold in zip(pred_row, label_row) if gold != -100]
        true_predictions.append([label_list[pred] for pred, _ in kept])
        true_labels.append([label_list[gold] for _, gold in kept])

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

In [None]:
import torch
from torch import nn
from transformers import AutoModel
from transformers.modeling_outputs import TokenClassifierOutput


class TokenClassificationHead(nn.Module):
    def __init__(
        self,
        input_size: int,
        lin1_size: int,
        lin2_size: int,
        output_size: int,
    ):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, lin1_size),
            nn.BatchNorm1d(lin1_size),
            nn.ReLU(),
            nn.Linear(lin1_size, lin2_size),
            nn.BatchNorm1d(lin2_size),
            nn.ReLU(),
            nn.Linear(lin2_size, output_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class HyenaDNAForTokenClassification(nn.Module):
    def __init__(
        self,
        backbone_model_name: str,
        input_size: int = 256,
        lin1_size: int = 2048,
        lin2_size: int = 1024,
    ):
        super().__init__()
        self.backbone_max_length = {
            "hyenadna-tiny-1k-seqlen": 1024,
            "hyenadna-small-32k-seqlen": 32768,
            "hyenadna-medium-160k-seqlen": 160000,
            "hyenadna-medium-450k-seqlen": 450000,  # T4 up to here
            "hyenadna-large-1m-seqlen": 1_000_000,  # only A100 (paid tier)
        }
        assert backbone_model_name in self.backbone_max_length.keys()

        self.backbone_model_name = f"LongSafari/{backbone_model_name}-hf"
        self.backbone = AutoModel.from_pretrained(self.backbone_model_name, trust_remote_code=True)
        self.head = TokenClassificationHead(
            input_size=input_size,
            lin1_size=lin1_size,
            lin2_size=lin2_size,
            output_size=self.backbone_max_length[backbone_model_name],
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        quals: torch.Tensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> torch.Tensor:
        transformer_outputs = self.hyena(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        batch_size = input_ids.shape[0]

        hidden_states = transformer_outputs[0]
        logits = self.head(hidden_states)

        sequence_lengths = (
            torch.eq(input_ids, self.hyena.config.pad_token_id).long().argmax(-1) - 1
        ).to(logits.device)
        labels = labels.to(logits.device)

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
        )

In [None]:
from transformers import DataCollatorForTokenClassification

# Collator that pads token-classification batches.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# Token-classification model on top of the 32k-context HyenaDNA backbone.
test_model = HyenaDNAForTokenClassification(
    backbone_model_name="hyenadna-small-32k-seqlen",
)

In [None]:
training_args = TrainingArguments(
    output_dir="my_awesome_model",
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
)

# The parquet file is loaded as a single "train" split. When no "test" split
# exists, evaluate on the training split so epoch-level evaluation (and
# load_best_model_at_end) still has data; this then only measures fit, not
# generalisation.
train_split = tokenize_dataset["train"]
eval_split = tokenize_dataset["test"] if "test" in tokenize_dataset else train_split

trainer = Trainer(
    model=test_model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Train with the Trainer configured in the previous cell.
trainer.train()

In [None]:
# Smoke test: train test_model for one epoch on the single record loaded earlier
# (`seq`, `target`), with tiny batches plus gradient accumulation so long
# sequences fit in memory.
sequence = [seq]
tokenized = tokenizer(sequence)["input_ids"]
# One per-token label list per sequence: labels for every base from the target
# span, plus -100 (ignored by CrossEntropyLoss) for the trailing token.
# NOTE(review): assumes the tokenizer appends exactly one token — confirm.
labels = [deepchopper.vertorize_target(*target, len(seq)) + [-100]]

# Create a dataset for training
ds = Dataset.from_dict({"input_ids": tokenized, "labels": labels})
ds.set_format("pt")

args = {
    "output_dir": "tmp",
    "num_train_epochs": 1,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 4,
    # test_model is a plain nn.Module with no gradient_checkpointing_enable().
    # NOTE(review): Trainer calls that method when this flag is True — confirm
    # for the installed transformers version before re-enabling.
    "gradient_checkpointing": False,
    "learning_rate": 2e-5,
}
training_args = TrainingArguments(**args)

trainer = Trainer(model=test_model, args=training_args, train_dataset=ds)
result = trainer.train()

# test_model is a plain nn.Module (not a PreTrainedModel), so it has no
# save_pretrained()/push_to_hub(); persist it with
# torch.save(test_model.state_dict(), path) instead.
result

In [None]:
# Display the object returned by trainer.train() in the smoke-test cell.
result