In [16]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from datasets import Dataset, load_dataset
from IPython.core.interactiveshell import InteractiveShell

import deepchopper

InteractiveShell.ast_node_interactivity = "all"

In [17]:
from rich.console import Console
from rich.text import Text


def highlight_target(seq: str, start: int, end: int, style="bold magenta"):
    text = Text(seq)
    console = Console()
    text.stylize(style, start, end)
    console.print(text)


def hightlight_predict(
    seq: str, target_start: int, target_end: int, predict_start: int, predict_end: int
):
    text = Text(seq)
    console = Console()

    text.stylize("#adb0b1", target_start, target_end)
    text.stylize("bold magenta", predict_start, predict_end)

    console.print(text)

In [18]:
import platform

print(f"{platform.system()=}")
if platform.system() == "Linux":
    root_dir = Path("/projects/b1171/ylk4626/project/DeepChopper")
else:
    root_dir = Path("/Users/ylk4626/ClionProjects/DeepChopper")

platform.system()='Linux'


In [19]:
train_file = root_dir / "tests/data/test_input.parquet"
data_files = {"train": train_file.as_posix()}

num_proc = 8
train_dataset = load_dataset(
    "parquet",
    data_files=data_files,
    num_proc=num_proc,
    split="train[:80%]",
).with_format("torch")
val_dataset = load_dataset(
    "parquet", data_files=data_files, num_proc=num_proc, split="train[80%:90%]"
).with_format("torch")
test_dataset = load_dataset(
    "parquet", data_files=data_files, num_proc=num_proc, split="train[90%:]"
).with_format("torch")

print(f"train_dataset: {train_dataset}")
print(f"val_dataset: {val_dataset}")
print(f"test_dataset: {test_dataset}")

train_dataset: Dataset({
    features: ['id', 'seq', 'qual', 'target'],
    num_rows: 4000
})
val_dataset: Dataset({
    features: ['id', 'seq', 'qual', 'target'],
    num_rows: 500
})
test_dataset: Dataset({
    features: ['id', 'seq', 'qual', 'target'],
    num_rows: 500
})


In [20]:
import pandas as pd


def show_example_for_dataset(dataset, split=None, first_examples: int = 10):
    if split is not None:
        id = dataset[split]["id"][0:first_examples]
        seq = dataset[split]["seq"][0:first_examples]
        qual = dataset[split]["qual"][0:first_examples]
        target = dataset[split]["target"][0:first_examples]
    else:
        id = dataset["id"][0:first_examples]
        seq = dataset["seq"][0:first_examples]
        qual = dataset["qual"][0:first_examples]
        target = dataset["target"][0:first_examples]

    qual = [i.tolist() for i in qual]
    target = [i.tolist() for i in target]
    df = pd.DataFrame({"id": id, "seq": seq, "qual": qual, "target": target})
    return df

In [21]:
show_example_for_dataset(train_dataset)

Unnamed: 0,id,seq,qual,target
0,1065:1135|393d635c-64f0-41ed-8531-12174d8efb28...,GCAGCTATGAATGCAAGGCCACAAGGTGGATGGAAGAGTTGTGGAA...,"[13, 15, 28, 28, 30, 50, 50, 50, 50, 50, 50, 5...","[1065, 1135]"
1,1573:1653|0e1e016e-02fb-4611-b1e3-6b688615e04f...,AGCGGAGAGCGGCACCATGGCCCGCGGGGCGGCGGCGGCCGCGGCC...,"[21, 24, 25, 25, 50, 49, 50, 40, 42, 42, 43, 5...","[1573, 1653]"
2,607:689|29bcb833-d8d6-44d3-a0e8-5128ee173825+f...,GTAACAATACAAATGGATTTTGGGAGTGACTCAAGAAGTGAAGAAT...,"[31, 43, 44, 43, 42, 43, 43, 44, 50, 50, 39, 5...","[607, 689]"
3,512:569|fea21b55-f0e1-445b-b656-565ecf669bde+5...,GTGTGAACATGCTCAACATCTCCCTTTTCTTTGGGCTGGTCATCCA...,"[7, 6, 13, 30, 50, 50, 50, 50, 50, 50, 50, 42,...","[512, 569]"
4,1128:1194|85667758-fd99-4c0a-9ca0-9fbb979bad53...,TGCGAAAGCCCCGGACTCGTGGAGTTGTTGAACGCCATGGACTCCG...,"[12, 17, 20, 50, 50, 50, 45, 39, 30, 32, 33, 2...","[1128, 1194]"
5,504:600|8111f832-c1ae-4c0f-9a30-f5c876f28d8a+9...,GCGCAGCCATTTTGGCTTCCTGACCTTGGGCTACGGCTGACCGTTT...,"[23, 27, 37, 41, 41, 50, 50, 46, 12, 11, 11, 1...","[504, 600]"
6,673:734|a9979431-9f1f-4d52-a5bc-1400722e0b3d+3...,GGCTGCCGAAGATGGCGGAGGTGCAGGTCCACCTGGTGCTTGATGG...,"[27, 32, 32, 35, 36, 38, 40, 50, 50, 50, 50, 5...","[673, 734]"
7,716:788|a1dd6bda-66eb-4466-adda-39297b7ee523+e...,TTGCAGCGCGATTGCCTCCGAGACCGCGAGACATACACGCAGCGAA...,"[12, 21, 41, 44, 42, 22, 11, 11, 13, 15, 26, 2...","[716, 788]"
8,485:555|4249a00c-2aa9-4f0e-9780-2faf41d69d50+1...,GACATCTCTGACGAGGCTGCGGTGTCTGCTGCATTCCCGCTGGCTC...,"[32, 36, 16, 16, 9, 10, 9, 10, 50, 48, 49, 44,...","[485, 555]"
9,695:776|1034f167-87a9-4d06-8857-1b26becd2242+b...,TGGGGAACAAGCAGCTGTCCCTGAGCCCAGAAGAGTATGTGTTTGC...,"[12, 37, 44, 40, 42, 43, 44, 49, 50, 49, 49, 5...","[695, 776]"


In [None]:
highlight_target(seq, *target)

In [None]:
hightlight_predict(seq, *target, 1070, 1120)

In [None]:
hightlight_predict(seq, *target, 1060, 1120)

# 1. Read Len of Direct RNA

In [None]:
def vis_bam_record_len():
    direc_rna_samples = ["22Rv1", "DU145", "LNCaP", "LuCaP", "PC3", "VCaP"]
    data = [np.load(root_dir / f"data/direct_rna/{p}.npy") for p in direc_rna_samples]
    # plt.rc('font', family='Times New Roman')

    fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(10, 6))

    flat_axs = axs.flatten()

    for i, sample in enumerate(range(len(direc_rna_samples))):
        # Create the density plot
        sns.kdeplot(data[i], fill=True, ax=flat_axs[i])
        flat_axs[i].set_title(f"Sample {sample}")

    # _ = ax1.set_xlabel('Threshold', fontsize=20)
    # _ = ax1.set_ylabel('Length of itemsets', fontsize=20)

    # ax1.legend(['Sliding window average'],fontsize=18,loc='lower left',edgecolor='k',fancybox=True)

    # ax1.tick_params(axis='y', labelsize=15)
    # ax1.tick_params(axis='x', labelsize=15
    fig.set_size_inches(20, 20)

    # Adding labels and title
    plt.title("Read Length of  Direc RNA")
    plt.xticks(rotation=30)

    return data

In [None]:
vis_bam_record_len()

In [None]:
data = vis_bam_record_len(root_dir / f"data/direct_rna/{direc_rna_samples[0]}.npy")

In [None]:
max(data)

In [None]:
d2 = list(data)

In [None]:
d2.remove(103380)

In [None]:
max(d2)

In [None]:
sns.kdeplot(d2, fill=True)

In [None]:
data.sort()

In [None]:
sns.kdeplot(data[:-800], fill=True)

In [None]:
des = pd.Series(data).describe()

# 2. Build Model

In [22]:
import torch
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    logging,
)

def load_tokenizer_from_hyena_model(model_name):
    max_lengths = {
        "hyenadna-tiny-1k-seqlen": 1024,
        "hyenadna-small-32k-seqlen": 32768,
        "hyenadna-medium-160k-seqlen": 160000,
        "hyenadna-medium-450k-seqlen": 450000,  # T4 up to here
        "hyenadna-large-1m-seqlen": 1_000_000,  # only A100 (paid tier)
    }

    if model_name not in max_lengths:
        msg = f"Model name {model_name} not found in available models."
        raise ValueError(msg)

    max_length = max_lengths[model_name]
    # bfloat16 for better speed and reduced memory usage
    model_name = f"LongSafari/{model_name}-hf"
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, max_length=max_length, truncation=True, padding=True, trust_remote_code=True
    )
    return tokenizer

In [23]:
import evaluate
import numpy as np

clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])


def compute_metrics(p):
    predictions, labels = p

    # print(f"{predictions.shape=}")
    # print(f"{labels.shape=}")

    predictions = np.argmax(predictions, axis=2)
    # Initialize lists to hold the filtered predictions and labels
    true_predictions = []
    true_labels = []

    # Filter out '-100' labels and correspondingly filter predictions
    for prediction, label in zip(predictions, labels):
        filtered_prediction = []
        filtered_label = []

        for p, l in zip(prediction, label):
            if l != -100:
                filtered_prediction.append(p)
                filtered_label.append(l)
        true_predictions.append(filtered_prediction)
        true_labels.append(filtered_label)

    for preds, refs in zip(true_predictions, true_labels):
        clf_metrics.add_batch(predictions=preds, references=refs)

    result = clf_metrics.compute()
    return result

In [28]:
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils import logging
from transformers import PretrainedConfig
from typing import List

logging.set_verbosity_info()
logger = logging.get_logger("transformers")


class TokenClassificationHead(nn.Module):
    def __init__(
        self,
        input_size: int,
        lin1_size: int,
        lin2_size: int,
        num_class: int,
        use_identity_layer_for_qual: bool,
    ):
        super().__init__()
        self.activation  = nn.ReLU()
        self.linear1 =  nn.Linear(input_size, lin1_size)
        self.linear2 =  nn.Linear(lin1_size, lin2_size)
        self.linear3 =  nn.Linear(lin2_size, num_class)
        self.qual_linear1 = nn.Identity() if use_identity_layer_for_qual else nn.Linear(1, lin1_size) 

    def forward(self, x: torch.Tensor, input_quals: torch.Tensor) -> torch.Tensor:
        output   = self.activation(self.linear1(x))
        residual =  output + self.qual_linear1(input_quals.unsqueeze(-1)) # may add activation
        output   = self.activation(self.linear2(residual) + residual)
        return self.linear3(output)
                

class TokenClassificationConfig(PretrainedConfig):
    model_type = "token-classification"

    def __init__(self,    
        input_size: int = 256,
        lin1_size: int = 1024,
        lin2_size: int= 1024,
        num_class: int = 2,
        use_identity_layer_for_qual: bool = True,
        **kwargs):
        self.input_size  = input_size
        self.lin1_size = lin1_size
        self.lin2_size = lin2_size
        self.num_class = num_class
        self.use_identity_layer_for_qual = use_identity_layer_for_qual
        super().__init__(**kwargs)
    

class TokenClassification(PreTrainedModel):
    config_class = TokenClassificationConfig
    
    def __init__(
        self,
        config,
        hyenadna_model: str = "hyenadna-small-32k-seqlen",
        **kwargs,
    ):
        super().__init__(config, **kwargs)
        self.num_class = config.num_class
        self.hyenadna_model_name = hyenadna_model
        self.hyenadna = AutoModel.from_pretrained(f"LongSafari/{hyenadna_model}-hf", trust_remote_code=True)
        
        self.head = TokenClassificationHead(
            input_size=config.input_size,
            lin1_size=config.lin1_size,
            lin2_size=config.lin2_size,
            num_class=config.num_class,
            use_identity_layer_for_qual=config.use_identity_layer_for_qual,
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        input_quals: torch.Tensor,
        inputs_embeds: torch.FloatTensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> torch.Tensor:
        transformer_outputs = self.hyenadna(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        batch_size = input_ids.shape[0]
        hidden_states = transformer_outputs[0]
        
        logits = self.head(hidden_states, input_quals)
        labels = labels.to(logits.device)
        loss_fct = nn.CrossEntropyLoss()

        loss = loss_fct(logits.view(-1, self.num_class), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
        )

In [29]:
from transformers import DataCollatorForTokenClassification

def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """
    Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
    """

    # To avoid errors when using Feature extractors
    if not hasattr(tokenizer, "deprecation_warnings"):
        return tokenizer.pad(*pad_args, **pad_kwargs)

    # Save the state of the warning, then disable it
    warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

    try:
        padded = tokenizer.pad(*pad_args, **pad_kwargs)
    finally:
        # Restore the state of the warning.
        tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state

    return padded
    
class DataCollatorForTokenClassificationWithQual(DataCollatorForTokenClassification):
    
    def torch_call(self, features):
        import torch
        
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

        qual_name = "input_quals"
        qual_pad_token_id = 0
        input_quals = [feature[qual_name] for feature in features]
        
        no_labels_features = [{k: v for k, v in feature.items() if k not in [qual_name, label_name]} 
                              for feature in features]

        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        if padding_side == "right":
            batch[label_name] = [
                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
            ]
            batch[qual_name] = [
               to_list(qual) + [qual_pad_token_id] * (sequence_length - len(qual)) for qual in input_quals
            ]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]
            batch[qual_name] = [
             [qual_pad_token_id] * (sequence_length - len(qual)) +  to_list(qual)  for qual in input_quals
            ]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        batch[qual_name]  = torch.tensor( batch[qual_name], dtype=torch.int64)
        return batch

In [30]:
from functools import partial

from transformers import DataCollatorForTokenClassification

def tokenize_and_align_labels_and_quals(data, tokenizer, max_length, pad_qual=0, pad_label=-100):
    tokenized_inputs = tokenizer(data["seq"], max_length=max_length, truncation=True, padding=True)
    labels = torch.tensor(
        deepchopper.vertorize_target(*data["target"], len(data["seq"])) + [pad_label]
    )
    quals = torch.cat((data["qual"], torch.tensor([pad_qual]))).float()
    normalized_quals =  torch.nn.functional.normalize(quals, dim=0)
    tokenized_inputs.update({"labels": labels, "input_quals": quals})
    return tokenized_inputs


def tokenize_dataset(dataset, tokenizer, max_length):
    return dataset.map(
        partial(tokenize_and_align_labels_and_quals, tokenizer=tokenizer, max_length=max_length)
    ).remove_columns(["id", "seq", "qual", "target"])


hyenadna_name = "hyenadna-small-32k-seqlen"
tokenizer = load_tokenizer_from_hyena_model(hyenadna_name)

tokenize_train_dataset = tokenize_dataset(
    train_dataset, tokenizer, max_length=tokenizer.max_len_single_sentence
)
tokenize_val_dataset = tokenize_dataset(
    val_dataset, tokenizer, max_length=tokenizer.max_len_single_sentence
)
tokenize_test_dataset = tokenize_dataset(
    test_dataset, tokenizer, max_length=tokenizer.max_len_single_sentence
)

loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at /tmp/ylk4626-jupyter//xdg_cache_home/huggingface/hub/models--LongSafari--hyenadna-small-32k-seqlen-hf/snapshots/8fe770c78eb13fe33bf81501612faeddf4d6f331/special_tokens_map.json
loading file tokenizer_config.json from cache at /tmp/ylk4626-jupyter//xdg_cache_home/huggingface/hub/models--LongSafari--hyenadna-small-32k-seqlen-hf/snapshots/8fe770c78eb13fe33bf81501612faeddf4d6f331/tokenizer_config.json
loading file tokenizer.json from cache at None
  StockPickler.save(self, obj, save_persistent_id)
  StockPickler.save(self, obj, save_persistent_id)


Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [31]:
# data_collator = DataCollatorForTokenClassification(tokenizer)
data_collator = DataCollatorForTokenClassificationWithQual(tokenizer)
model_config = TokenClassificationConfig()
model = TokenClassification(model_config)

loading configuration file config.json from cache at /tmp/ylk4626-jupyter//xdg_cache_home/huggingface/hub/models--LongSafari--hyenadna-small-32k-seqlen-hf/snapshots/8fe770c78eb13fe33bf81501612faeddf4d6f331/config.json
loading configuration file config.json from cache at /tmp/ylk4626-jupyter//xdg_cache_home/huggingface/hub/models--LongSafari--hyenadna-small-32k-seqlen-hf/snapshots/8fe770c78eb13fe33bf81501612faeddf4d6f331/config.json
Model config HyenaConfig {
  "_name_or_path": "LongSafari/hyenadna-small-32k-seqlen-hf",
  "activation_freq": 10,
  "architectures": [
    "HyenaDNAForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "LongSafari/hyenadna-small-32k-seqlen-hf--configuration_hyena.HyenaConfig",
    "AutoModel": "LongSafari/hyenadna-small-32k-seqlen-hf--modeling_hyena.HyenaDNAModel",
    "AutoModelForCausalLM": "LongSafari/hyenadna-small-32k-seqlen-hf--modeling_hyena.HyenaDNAForCausalLM",
    "AutoModelForSequenceClassification": "LongSafari/hyenadna-small-32k-seqlen-hf--modelin

In [32]:
model

TokenClassification(
  (hyenadna): HyenaDNAModel(
    (backbone): HyenaLMBackbone(
      (embeddings): HyenaEmbeddings(
        (word_embeddings): Embedding(16, 256)
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (layers): ModuleList(
        (0-3): 4 x HyenaBlock(
          (mixer): HyenaOperator(
            (dropout): Dropout(p=0.0, inplace=False)
            (in_proj): Linear(in_features=256, out_features=768, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
            (short_filter): Conv1d(768, 768, kernel_size=(3,), stride=(1,), padding=(2,), groups=768)
            (filter_fn): HyenaFilter(
              (dropout): Dropout(p=0.0, inplace=False)
              (pos_emb): HyenaPositionalEmbedding()
              (implicit_filter): Sequential(
                (0): Linear(in_features=5, out_features=64, bias=True)
                (1): HyenaSin()
                (2): Linear(in_features=64, out_features=64, bias=True)
             

In [33]:
tokenize_train_dataset

Dataset({
    features: ['input_ids', 'labels', 'input_quals'],
    num_rows: 4000
})

In [34]:
training_args = TrainingArguments(
    output_dir="hyena_model_test2",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
    torch_compile=False,
    report_to = "wandb",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenize_train_dataset,
    eval_dataset=tokenize_test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
result = trainer.train(resume_from_checkpoint = False)

PyTorch: setting up devices
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
***** Running training *****
  Num examples = 4,000
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 500
  Number of trainable parameters = 4,592,386
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myangyangli[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.023,0.004554,0.998527,0.974879,0.975158,0.974599


***** Running Evaluation *****
  Num examples = 500
  Batch size = 8
Checkpoint destination directory hyena_model_test2/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Saving model checkpoint to hyena_model_test2/checkpoint-500
Configuration saved in hyena_model_test2/checkpoint-500/config.json
Model weights saved in hyena_model_test2/checkpoint-500/model.safetensors
tokenizer config file saved in hyena_model_test2/checkpoint-500/tokenizer_config.json
Special tokens file saved in hyena_model_test2/checkpoint-500/special_tokens_map.json


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from hyena_model_test2/checkpoint-500 (score: 0.0045535205863416195).
There were missing keys in the checkpoint model loaded: ['hyenadna.backbone.layers.0.mixer.filter_fn.implicit_filter.3.freq', 'hyenadna.backbone.layers.0.mixer.filter_fn.implicit_filter.5.freq', 'hyenadna.backbone.layers.1.mixer.

In [46]:
# resume_config = TokenClassificationConfig.from_pretrained("./hyena_model_test2/checkpoint-1000/")
resume_model  = TokenClassification.from_pretrained("./hyena_model_test2/checkpoint-500/")

loading configuration file ./hyena_model_test2/checkpoint-500/config.json
Model config TokenClassificationConfig {
  "architectures": [
    "TokenClassification"
  ],
  "input_size": 256,
  "lin1_size": 1024,
  "lin2_size": 1024,
  "model_type": "token-classification",
  "num_class": 2,
  "torch_dtype": "float32",
  "transformers_version": "4.38.1",
  "use_identity_layer_for_qual": true
}

loading weights file ./hyena_model_test2/checkpoint-500/model.safetensors
loading configuration file config.json from cache at /tmp/ylk4626-jupyter//xdg_cache_home/huggingface/hub/models--LongSafari--hyenadna-small-32k-seqlen-hf/snapshots/8fe770c78eb13fe33bf81501612faeddf4d6f331/config.json
loading configuration file config.json from cache at /tmp/ylk4626-jupyter//xdg_cache_home/huggingface/hub/models--LongSafari--hyenadna-small-32k-seqlen-hf/snapshots/8fe770c78eb13fe33bf81501612faeddf4d6f331/config.json
Model config HyenaConfig {
  "_name_or_path": "LongSafari/hyenadna-small-32k-seqlen-hf",
  "activ

In [52]:
model.cpu()

TokenClassification(
  (hyenadna): HyenaDNAModel(
    (backbone): HyenaLMBackbone(
      (embeddings): HyenaEmbeddings(
        (word_embeddings): Embedding(16, 256)
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (layers): ModuleList(
        (0-3): 4 x HyenaBlock(
          (mixer): HyenaOperator(
            (dropout): Dropout(p=0.0, inplace=False)
            (in_proj): Linear(in_features=256, out_features=768, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
            (short_filter): Conv1d(768, 768, kernel_size=(3,), stride=(1,), padding=(2,), groups=768)
            (filter_fn): HyenaFilter(
              (dropout): Dropout(p=0.0, inplace=False)
              (pos_emb): HyenaPositionalEmbedding()
              (implicit_filter): Sequential(
                (0): Linear(in_features=5, out_features=64, bias=True)
                (1): HyenaSin()
                (2): Linear(in_features=64, out_features=64, bias=True)
             

In [53]:
model.state_dict()['head.linear1.weight']

tensor([[ 0.0211, -0.0359, -0.0023,  ...,  0.0611,  0.0329,  0.0383],
        [ 0.0439,  0.0458, -0.0358,  ...,  0.0044,  0.0060,  0.0362],
        [-0.0601,  0.0475,  0.0528,  ..., -0.0582, -0.0536, -0.0541],
        ...,
        [ 0.0551,  0.0277,  0.0052,  ...,  0.0281, -0.0534, -0.0359],
        [ 0.0395,  0.0079,  0.0401,  ..., -0.0006, -0.0034, -0.0282],
        [-0.0273, -0.0044,  0.0127,  ...,  0.0619, -0.0128, -0.0341]])

In [54]:
resume_model.state_dict()['head.linear1.weight']

tensor([[ 0.0211, -0.0359, -0.0023,  ...,  0.0611,  0.0329,  0.0383],
        [ 0.0439,  0.0458, -0.0358,  ...,  0.0044,  0.0060,  0.0362],
        [-0.0601,  0.0475,  0.0528,  ..., -0.0582, -0.0536, -0.0541],
        ...,
        [ 0.0551,  0.0277,  0.0052,  ...,  0.0281, -0.0534, -0.0359],
        [ 0.0395,  0.0079,  0.0401,  ..., -0.0006, -0.0034, -0.0282],
        [-0.0273, -0.0044,  0.0127,  ...,  0.0619, -0.0128, -0.0341]])

In [66]:
for k in model.state_dict():
    v1 = resume_model.state_dict()[k]
    v2 = model.state_dict()[k]
    result = v1.eq(v2)
    if  not torch.all(result) :
        print(f"{k} is not equal")

In [60]:
from safetensors import safe_open

tensors = {}
with safe_open("./hyena_model_test2/checkpoint-500/model.safetensors", framework="pt", device=0) as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)


In [36]:
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 500
  Batch size = 8


{'eval_loss': 0.0045535205863416195,
 'eval_accuracy': 0.9985270825136446,
 'eval_f1': 0.9748788391500101,
 'eval_precision': 0.9751584865609133,
 'eval_recall': 0.9745993520827958,
 'eval_runtime': 58.2788,
 'eval_samples_per_second': 8.579,
 'eval_steps_per_second': 1.081,
 'epoch': 1.0}

In [35]:
predicts = trainer.predict(tokenize_val_dataset)

The following columns in the test set don't have a corresponding argument in `TokenClassification.forward` and have been ignored: input_quals. If input_quals are not expected by `TokenClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 500
  Batch size = 8


predictions.shape=(500, 9524, 2)
labels.shape=(500, 9524)


In [51]:
def summary_predict(predictions, labels):
    predictions = np.argmax(predictions, axis=2)
    # Initialize lists to hold the filtered predictions and labels
    true_predictions = []
    true_labels = []

    # Filter out '-100' labels and correspondingly filter predictions
    for prediction, label in zip(predictions, labels):
        filtered_prediction = []
        filtered_label = []

        for p, l in zip(prediction, label):
            if l != -100:
                filtered_prediction.append(p)
                filtered_label.append(l)
        true_predictions.append(filtered_prediction)
        true_labels.append(filtered_label)

    return true_predictions, true_labels

from rich.console import Console
from rich.highlighter import RegexHighlighter
from rich.theme import Theme


class LabelHighlighter(RegexHighlighter):
    """Apply style to anything that looks like an email."""

    base_style = "label."
    highlights = [r"(?P<label>1+)"]



def alignment_predict(prediction, label):
    import textwrap

    prediction_str = "".join(map(lambda x: str(x), prediction))
    label_str = "".join(map(lambda x: str(x), label))
    
    front2 = "L:"
    front1 = "P:"
    theme = Theme({"label.label": "bold magenta"})
    console = Console(highlighter=LabelHighlighter(), theme=theme)
    for l1, l2 in zip(textwrap.wrap(prediction_str), textwrap.wrap(label_str)):
        ss = f"{front1}{l1}\n{front2}{l2}"
        console.print(ss)

In [38]:
true_predictions, true_labels = summary_predict(predicts[0], predicts[1])

In [52]:
alignment_predict(true_predictions[4], true_labels[4])

# Train with Pytorch

# Train with native model

In [None]:
import math
from collections import namedtuple
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from torchvision.ops import StochasticDepth

In [None]:
# @title Hyena layer


def fftconv(u, k, D):
    """
    We apply a convolution through the fourier domain (from the Convolution Theorem)

    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen

    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]

    out = y + u * D.unsqueeze(-1)
    return out.to(dtype=u.dtype)


@torch.jit.script
def mul_sum(q, y):
    return (q * y).sum(dim=1)


class OptimModule(nn.Module):
    """Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters"""

    def register(self, name, tensor, lr=None, wd=0.0):
        """Register a tensor with a configurable learning rate and 0 weight decay"""

        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {}
            if lr is not None:
                optim["lr"] = lr
            if wd is not None:
                optim["weight_decay"] = wd
            setattr(getattr(self, name), "_optim", optim)


class Sin(nn.Module):
    """The Sin activation function for the Hyena Filter function."""

    def __init__(self, dim, w=10, train_freq=True):
        super().__init__()
        self.freq = nn.Parameter(w * torch.ones(1, dim)) if train_freq else w * torch.ones(1, dim)

    def forward(self, x):
        return torch.sin(self.freq * x)


class PositionalEmbedding(OptimModule):
    def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-5, **kwargs):
        """Complex exponential positional embeddings for Hyena filters."""
        super().__init__()

        self.seq_len = seq_len
        # The time embedding fed to the filteres is normalized so that t_f = 1
        t = torch.linspace(0, 1, self.seq_len)[None, :, None]  # 1, L, 1

        if emb_dim > 1:
            bands = (emb_dim - 1) // 2
        # To compute the right embeddings we use the "proper" linspace
        t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
        w = 2 * math.pi * t_rescaled / seq_len  # 1, L, 1

        f = torch.linspace(1e-4, bands - 1, bands)[None, None]
        z = torch.exp(-1j * f * w)
        z = torch.cat([t, z.real, z.imag], dim=-1)
        self.register("z", z, lr=lr_pos_emb)
        self.register("t", t, lr=0.0)

    def forward(self, L):
        return self.z[:, :L], self.t[:, :L]


class ExponentialModulation(OptimModule):
    """The window function applied to the output of the (MLP) filter function."""

    def __init__(
        self,
        d_model,
        fast_decay_pct=0.3,
        slow_decay_pct=1.5,
        target=1e-2,
        modulation_lr=0.0,
        modulate: bool = True,
        shift: float = 0.05,
        **kwargs,
    ):
        super().__init__()
        self.modulate = modulate
        self.shift = shift
        max_decay = math.log(target) / fast_decay_pct
        min_decay = math.log(target) / slow_decay_pct
        deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
        self.register("deltas", deltas, lr=modulation_lr)

    def forward(self, t, x):
        if self.modulate:
            decay = torch.exp(-t * self.deltas.abs())
            x = x * (decay + self.shift)
        return x


class HyenaFilter(OptimModule):
    def __init__(
        self,
        d_model,
        emb_dim=3,  # dim of input to MLP, augments with positional encoding
        order=16,  # width of the implicit MLP
        fused_fft_conv=False,
        seq_len=1024,
        lr=1e-3,
        lr_pos_emb=1e-5,
        dropout=0.0,
        w=1,  # frequency of periodic activations
        wd=0,  # weight decay of kernel parameters
        bias=True,
        num_inner_mlps=2,
        normalized=False,
        **kwargs,
    ):
        """
        Implicit long filter with modulation.

        Args:
            d_model: number of channels in the input
            emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands
            order: width of the FFN
            num_inner_mlps: number of inner linear layers inside filter MLP

        Note:
            filter_dropout is not implemented
        """
        super().__init__()

        self.d_model = d_model
        self.use_bias = bias
        self.fused_fft_conv = fused_fft_conv
        self.bias = nn.Parameter(torch.randn(self.d_model))
        self.dropout = nn.Dropout(dropout)

        act = Sin(dim=order, w=w)
        self.emb_dim = emb_dim
        assert (
            emb_dim % 2 != 0 and emb_dim >= 3
        ), "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
        self.seq_len = seq_len

        self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)

        self.implicit_filter = nn.Sequential(
            nn.Linear(emb_dim, order),
            act,
        )
        for i in range(num_inner_mlps):
            self.implicit_filter.append(nn.Linear(order, order))
            self.implicit_filter.append(act)

        self.implicit_filter.append(nn.Linear(order, d_model, bias=False))

        self.modulation = ExponentialModulation(d_model, **kwargs)

        self.normalized = normalized
        for c in self.implicit_filter.children():
            for name, v in c.state_dict().items():
                optim = {"weight_decay": wd, "lr": lr}
                setattr(getattr(c, name), "_optim", optim)

    def filter(self, L, *args, **kwargs):
        z, t = self.pos_emb(L)
        h = self.implicit_filter(z)
        h = self.modulation(t, h)
        return h

    def forward(self, x, L, k=None, bias=None, *args, **kwargs):
        if k is None:
            k = self.filter(L)

        # Ensure compatibility with filters that return a tuple
        k = k[0] if type(k) is tuple else k

        y = fftconv(x, k, bias)
        return y


class HyenaOperator(nn.Module):
    def __init__(
        self,
        d_model,
        l_max,
        order=2,
        filter_order=64,
        dropout=0.0,
        filter_dropout=0.0,
        **filter_args,
    ):
        r"""
        Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf

        Args:
            d_model (int): Dimension of the input and output embeddings (width of the layer)
            l_max: (int): Maximum input sequence length. Defaults to None
            order: (int): Depth of the Hyena recurrence. Defaults to 2
            dropout: (float): Dropout probability. Defaults to 0.0
            filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0
        """
        super().__init__()

        self.d_model = d_model
        self.l_max = l_max
        self.order = order
        inner_width = d_model * (order + 1)
        self.dropout = nn.Dropout(dropout)
        self.in_proj = nn.Linear(d_model, inner_width)
        self.out_proj = nn.Linear(d_model, d_model)

        self.short_filter = nn.Conv1d(inner_width, inner_width, 3, padding=2, groups=inner_width)
        self.filter_fn = HyenaFilter(
            d_model * (order - 1),
            order=filter_order,
            seq_len=l_max,
            channels=1,
            dropout=filter_dropout,
            **filter_args,
        )

    def forward(self, u, *args, **kwargs):
        l = u.size(-2)
        l_filter = min(l, self.l_max)
        u = self.in_proj(u)
        u = rearrange(u, "b l d -> b d l")

        uc = self.short_filter(u)[..., :l_filter]
        *x, v = uc.split(self.d_model, dim=1)

        k = self.filter_fn.filter(l_filter)[0]
        k = rearrange(k, "l (o d) -> o d l", o=self.order - 1)
        bias = rearrange(self.filter_fn.bias, "(o d) -> o d", o=self.order - 1)

        for o, x_i in enumerate(reversed(x[1:])):
            v = self.dropout(v * x_i)
            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])

        y = rearrange(v * x[0], "b d l -> b l d")

        y = self.out_proj(y)
        return y

In [None]:
# @title MLP layer
"""
The MLP layer after the mixer layer (HyenaOperator).
"""


class Mlp(nn.Module):

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        """
        From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)

In [None]:
# @title Block layer (Hyena + MLP layers)


"""
A block consists of a Mixer layer (Hyena or attention), and a MLP layer.

"""


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input), input


class Block(nn.Module):

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        return_residual=False,
        residual_in_fp32=False,
    ):
        """
        From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reason: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.
        """
        super().__init__()
        self.prenorm = prenorm
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls()
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

    def forward(self, hidden_states, residual=None, mixer_subset=None, mixer_kwargs=None):
        r"""Pass the input through the encoder layer.
        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
        """
        if self.prenorm:
            dropped = self.drop_path1(self.dropout1(hidden_states))
            residual = (dropped + residual) if residual is not None else dropped
            hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                dropped = self.drop_path2(self.dropout2(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)

                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out

            hidden_states = self.norm1(
                (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                    dtype=self.norm1.weight.dtype
                )
            )

            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out

                hidden_states = self.norm2(
                    (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                        dtype=self.norm2.weight.dtype
                    )
                )

            return hidden_states


def create_mixer_cls(
    layer=None, attn_layer_idx=None, attn_cfg=None, layer_idx=None, device=None, dtype=None
):
    factory_kwargs = {"device": device, "dtype": dtype}
    if attn_layer_idx is not None and layer_idx in attn_layer_idx:
        causal = True if attn_cfg is None else attn_cfg.pop("causal", True)

        mha_cls = MHA

        mixer_cls = partial(
            mha_cls,
            causal=causal,
            layer_idx=layer_idx,
            **(attn_cfg if attn_cfg is not None else {}),
            **factory_kwargs,
        )
    else:
        # mixer_cls = instantiate(registry.layer, layer, partial=True, layer_idx=layer_idx, **factory_kwargs)

        mixer_cls = partial(HyenaOperator, **layer)

    return mixer_cls


def create_mlp_cls(d_model, d_inner=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    inner_dim = d_inner if d_inner is not None else 4 * d_model

    mlp_cls = partial(
        Mlp,
        hidden_features=inner_dim,
        activation=partial(F.gelu, approximate="tanh"),
        **factory_kwargs,
    )

    return mlp_cls


def create_block(
    d_model,
    d_inner=None,
    layer=None,
    attn_layer_idx=None,
    attn_cfg=None,
    layer_norm_epsilon=1e-5,
    resid_dropout1=0.0,
    resid_dropout2=0.0,
    residual_in_fp32=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    factory_kwargs = {"device": device, "dtype": dtype}
    mixer_cls = create_mixer_cls(
        layer=layer,
        attn_layer_idx=attn_layer_idx,
        attn_cfg=attn_cfg,
        layer_idx=layer_idx,
        **factory_kwargs,
    )
    mlp_cls = create_mlp_cls(d_model, d_inner=d_inner, **factory_kwargs)
    norm_cls = partial(nn.LayerNorm, eps=layer_norm_epsilon, **factory_kwargs)
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        prenorm=True,
        resid_dropout1=resid_dropout1,
        resid_dropout2=resid_dropout2,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True, glu_act=False
):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
            # If using GLU activation for now, we scale the std by 2
            elif name in ["output_linear.0.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                if not glu_act:
                    nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
                else:
                    out_features = p.shape[0]
                    # Multiplying the first half of the matrix by 2 since sigmoid scales it down by 0.5
                    # on average.
                    nn.init.normal_(
                        p[: out_features // 2],
                        mean=0.0,
                        std=initializer_range / math.sqrt(2 * n_layer) * 2,
                    )

In [None]:
class LMBackbone(nn.Module):

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_inner: int,
        vocab_size: int,
        process_group=None,
        layer=None,
        attn_layer_idx=None,
        attn_cfg=None,
        max_position_embeddings=0,
        resid_dropout: float = 0.0,
        embed_dropout: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        initializer_cfg=None,
        residual_in_fp32=False,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.residual_in_fp32 = residual_in_fp32
        # note max_position_embeddings is 0 for Hyena, and therefore isn't used
        self.embeddings = GPT2Embeddings(
            d_model, vocab_size, max_position_embeddings, **factory_kwargs
        )

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_inner=d_inner,
                    layer=layer,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    layer_norm_epsilon=layer_norm_epsilon,
                    resid_dropout1=embed_dropout if i == 0 else resid_dropout,
                    resid_dropout2=resid_dropout,
                    residual_in_fp32=residual_in_fp32,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.drop_f = nn.Dropout(resid_dropout)
        self.ln_f = nn.LayerNorm(d_model, eps=layer_norm_epsilon, **factory_kwargs)

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def forward(self, input_ids, position_ids=None):
        hidden_states = self.embeddings(
            input_ids,
            position_ids=position_ids,
        )
        residual = None

        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual)

        dropped = self.drop_f(hidden_states)
        residual = (dropped + residual) if residual is not None else dropped
        hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))

        return hidden_states

In [None]:
class TokenClassificationHead(nn.Module):
    def __init__(
        self,
        embedding_size: int,
        lin1_size: int,
        lin2_size: int,
        num_class: int,
    ):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, lin1_size),
            # nn.BatchNorm1d(lin1_size),
            nn.ReLU(),
            nn.Linear(lin1_size, lin2_size),
            # nn.BatchNorm1d(lin2_size),
            nn.ReLU(),
            nn.Linear(lin2_size, num_class),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

In [None]:
# @title Model (backbone + head)


"""
Putting it all together, the model consists of a backbone model
and a decoder head (you can turn off head for embeddings only too).

Here we use a simple head to do multi-classification, but
can also swap the head to do next token prediction too.  We defer to the main
HyenaDNA for that code, since pretraining with next token prediction isn't quite
feasible on colab.

"""


class HyenaDNAModel(nn.Module):

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_inner: int,
        vocab_size: int,
        layer=None,
        attn_layer_idx=None,
        attn_cfg=None,
        max_position_embeddings=0,
        resid_dropout: float = 0.0,
        embed_dropout: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        initializer_cfg=None,
        residual_in_fp32=False,
        pad_vocab_size_multiple: int = 1,
        n_classes: int = 2,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)

        # check if layer (config) has d_model (HF code differs from main Safari code)
        if "d_model" not in layer:
            layer["d_model"] = d_model

        self.backbone = LMBackbone(
            d_model=d_model,
            n_layer=n_layer,
            d_inner=d_inner,
            vocab_size=vocab_size,
            layer=layer,
            attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg,
            max_position_embeddings=max_position_embeddings,
            resid_dropout=resid_dropout,
            embed_dropout=embed_dropout,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_cfg=initializer_cfg,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
            **kwargs,
        )

        # we only need a head if doing classification, otherwise we'll use the
        # hidden states as embeddings

        self.head = SequenceDecoder(d_model=d_model, d_output=n_classes, l_output=0, mode="pool")

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

        # if self.use_head:
        #     self.tie_weights()

    # def tie_weights(self):
    #     self.head.weight = self.backbone.embeddings.word_embeddings.weight

    def forward(self, input_ids, position_ids=None, state=None):  # state for the repo interface
        hidden_states = self.backbone(input_ids, position_ids=position_ids)

        if self.use_head:
            return self.head(hidden_states)
        else:
            return hidden_states