In [None]:
import re
from transformers import AutoTokenizer
from nltk import ngrams
from collections import Counter
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from easydict import EasyDict as edict
import zipfile

In [None]:
zip_path = '/content/extracted_all.zip'
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall("/content/dataset")

In [None]:

class Config:
    def __init__(self):
        self.tsv_path = '/content/gene_input_text.tsv'
        self.bio_model_name =  "microsoft/biogpt"
        self.cls_token = "<CLS>"
        self.batch_size = 8
        self.biogpt_model_context_length = 1024
        self.mistral7b_model_context_length = 32768
        self.bio_gpt_embedding_dim = 1024
        self.mistral7b_embedding_dim = 4096


In [None]:
config = Config()

## STEP 1

**Unknown TERMS detection Algorithm**

#### experiments

In [2]:
import re
from transformers import AutoTokenizer
from nltk import ngrams
from collections import Counter

# Load the tokenizer (BERT in this example)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text = ("The patient was diagnosed with myocarditis and cardiomyopathy. "
        "Also, pheeiong simenowon. The patient had chronic kidney disease and myocarditis symptoms "
        "along with the severe symptoms of the pheeiong simenowon.")
tokens = tokenizer.tokenize(text)
print("Tokens:", tokens)

# --- Step 1: Merge subword tokens to extract candidate unknown words ---
candidate_terms = []
i = 0
while i < len(tokens):
    token = tokens[i]
    if i < len(tokens) - 1 and tokens[i+1].startswith("##"):
        current_term = token
        while i + 1 < len(tokens) and tokens[i+1].startswith("##"):
            i += 1
            current_term += tokens[i][2:]
        candidate_terms.append(current_term)
    i += 1

print("Candidate Terms from Subwords:", candidate_terms)

# --- Step 2: Remove duplicates to get unique new words ---
unique_new_words = list(set(candidate_terms))
print("Unique New Words:", unique_new_words)

# --- Step 3: Build a sequence of new words in the order of occurrence ---
unknown_sequence = []
i = 0
while i < len(tokens):
    token = tokens[i]
    # If token is the start of a merged word (followed by subword tokens)
    if i < len(tokens) - 1 and tokens[i+1].startswith("##"):
        current_term = token
        while i + 1 < len(tokens) and tokens[i+1].startswith("##"):
            i += 1
            current_term += tokens[i][2:]
        if current_term in unique_new_words:
            unknown_sequence.append(current_term)
    else:
        if token in unique_new_words:
            unknown_sequence.append(token)
    i += 1

print("Sequence of New Words in Order:", unknown_sequence)

# --- Step 4: N-gram analysis on the sequence of new words ---
# We check for bigrams and trigrams (adjust as needed)
final_phrases = set()
for n in [2, 3]:
    n_grams = list(ngrams(unknown_sequence, n))
    ngram_counts = Counter(n_grams)
    # Consider phrases that appear at least twice
    for gram, count in ngram_counts.items():
        if count >= 2:
            final_phrases.add(" ".join(gram))

print("Frequent 2/3-grams among new words:", final_phrases)

# --- Step 5: Combine the phrases and remaining unique new words ---
# Remove individual words that are part of any detected phrase
words_in_phrases = set()
for phrase in final_phrases:
    words_in_phrases.update(phrase.split())

# Remove the component words from the unique new words
remaining_words = set(unique_new_words) - words_in_phrases

# The final result is the union of the detected phrases and the remaining words
final_detected = list(final_phrases.union(remaining_words))
print("Final detected new terms:", final_detected)


ImportError: numpy>=1.17,<2.0 is required for a normal functioning of this module, but found numpy==2.2.3.
Try: `pip install transformers -U` or `pip install -e '.[dev]'` if you're working with git main

In [None]:
# Define your model name and auth token
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
auth_token = "####"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, token=auth_token, trust_remote_code=True)

text = (
    "To assess the efficacy of our gene name extraction tool, we will analyze a text chunk containing a variety of gene symbols and full names. This includes well-known genes like TP53, BRCA1, and EGFR, alongside developmental genes such as SHH, WNT3A, and HOXD13. We will also test its ability to identify less common genes like FOXP2, STAT3, and VEGF, as well as genes involved in metabolism such as APOE, LDLR, and INS. Furthermore, we will include genes with numerical suffixes like CDK4 and ERBB2, and genes with hyphenated names such as HLA-DRB1 and TNF-alpha. Finally, we will incorporate gene families like the KRAS family and the MYC family to see if the function can handle these broader references."
)

text = re.sub(r"[^A-Za-z0-9\s-]", "", text)


# 1) Tokenize
tokens = tokenizer.tokenize(text)
print("Tokens:", tokens)

# 2) Group subtokens into words at each '▁'
words_subtokens = []
current = []
for tok in tokens:
    if tok.startswith("▁"):
        if current:
            words_subtokens.append(current)
        current = [tok]
    else:
        current.append(tok)
if current:
    words_subtokens.append(current)

# 3) Reconstruct the actual words (strip '▁' and concat)
reconstructed = []
for subtoks in words_subtokens:
    w = subtoks[0].lstrip("▁") + "".join(subtoks[1:])
    w = re.sub(r"[^\w]+$", "", w)  # drop trailing punctuation
    reconstructed.append(w)

# 4) Build the sequence of “unknown” words (those split into >1 subtoken)
unknown_sequence = [
    reconstructed[i]
    for i, subtoks in enumerate(words_subtokens)
    if len(subtoks) > 1
]
print("Candidate Terms from Subwords:", unknown_sequence)

# 5) Find frequent 2‑ and 3‑grams in that sequence (threshold ≥2)
freq_2grams = []
freq_3grams = []
for n in (2, 3):
    counts = Counter(ngrams(unknown_sequence, n))
    freq = [" ".join(gram) for gram, cnt in counts.items() if cnt >= 2]
    if n == 2:
        freq_2grams = freq
    else:
        freq_3grams = freq

print("Frequent 2-grams among new words:", freq_2grams)
print("Frequent 3-grams among new words:", freq_3grams)

# 6) Assemble final_unknown_terms, collapsing any repeated phrases
freq_phrases = set(tuple(ng.split()) for ng in (freq_2grams + freq_3grams))
final_unknown_terms = []
i = 0
while i < len(unknown_sequence):
    matched = False
    # try longer phrases first
    for n in (3, 2):
        if i + n <= len(unknown_sequence) and tuple(unknown_sequence[i:i+n]) in freq_phrases:
            phrase = " ".join(unknown_sequence[i:i+n])
            if phrase not in final_unknown_terms:
                final_unknown_terms.append(phrase)
            i += n
            matched = True
            break
    if not matched:
        w = unknown_sequence[i]
        if w not in final_unknown_terms:
            final_unknown_terms.append(w)
        i += 1

print("final_unknown_terms:", final_unknown_terms)


Tokens: ['▁To', '▁assess', '▁the', '▁eff', 'ic', 'acy', '▁of', '▁our', '▁gene', '▁name', '▁extr', 'action', '▁tool', '▁we', '▁will', '▁analyze', '▁a', '▁text', '▁chunk', '▁containing', '▁a', '▁variety', '▁of', '▁gene', '▁symbols', '▁and', '▁full', '▁names', '▁This', '▁includes', '▁well', '-', 'known', '▁genes', '▁like', '▁T', 'P', '5', '3', '▁BR', 'CA', '1', '▁and', '▁E', 'G', 'FR', '▁alongside', '▁development', 'al', '▁genes', '▁such', '▁as', '▁SH', 'H', '▁W', 'NT', '3', 'A', '▁and', '▁HO', 'X', 'D', '1', '3', '▁We', '▁will', '▁also', '▁test', '▁its', '▁ability', '▁to', '▁identify', '▁less', '▁common', '▁genes', '▁like', '▁FO', 'X', 'P', '2', '▁STAT', '3', '▁and', '▁V', 'E', 'GF', '▁as', '▁well', '▁as', '▁genes', '▁involved', '▁in', '▁met', 'abol', 'ism', '▁such', '▁as', '▁A', 'PO', 'E', '▁L', 'DL', 'R', '▁and', '▁IN', 'S', '▁Furthermore', '▁we', '▁will', '▁include', '▁genes', '▁with', '▁numerical', '▁suffix', 'es', '▁like', '▁CD', 'K', '4', '▁and', '▁ER', 'BB', '2', '▁and', '▁genes',

### NEW Terms DETECTION Algorithm [step 1]

In [None]:
# Define your model name and auth token
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
auth_token = "####"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=auth_token, trust_remote_code=True)

tokenizer_config.json:   0%|          | 0.00/2.10k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

In [None]:
def extract_unknown_terms(text, tokenizer):
    """
    Given a string `text` and a SentencePiece‐style tokenizer,
    returns a list of 'unknown' words and repeated unknown‐word phrases
    (2‑ or 3‑grams appearing ≥2×) in their order of first appearance.
    """
    text = re.sub(r"[^A-Za-z0-9\s-]", "", text)

    # 1) Tokenize
    tokens = tokenizer.tokenize(text)

    # 2) Group subtokens into words at each '▁'
    words_subtokens = []
    current = []
    for tok in tokens:
        if tok.startswith("▁"):
            if current:
                words_subtokens.append(current)
            current = [tok]
        else:
            current.append(tok)
    if current:
        words_subtokens.append(current)

    # 3) Reconstruct words (strip '▁', concat, drop trailing punctuation)
    reconstructed = []
    for subtoks in words_subtokens:
        w = subtoks[0].lstrip("▁") + "".join(subtoks[1:])
        w = re.sub(r"[^\w]+$", "", w)
        reconstructed.append(w)

    # 4) Build sequence of unknown words (split into >1 subtoken)
    unknown_sequence = [
        reconstructed[i]
        for i, subtoks in enumerate(words_subtokens)
        if len(subtoks) > 1
    ]

    # 5) Find repeated 2‑ and 3‑grams (threshold ≥2)
    freq_ngrams = []
    for n in (2, 3):
        for gram, cnt in Counter(ngrams(unknown_sequence, n)).items():
            if cnt >= 2:
                freq_ngrams.append(" ".join(gram))
    freq_tuples = set(tuple(ng.split()) for ng in freq_ngrams)

    # 6) Assemble final list, collapsing repeated phrases
    final_unknown_terms = []
    i = 0
    while i < len(unknown_sequence):
        matched = False
        for n in (3, 2):  # try longer first
            if i + n <= len(unknown_sequence) and tuple(unknown_sequence[i:i+n]) in freq_tuples:
                phrase = " ".join(unknown_sequence[i:i+n])
                if phrase not in final_unknown_terms:
                    final_unknown_terms.append(phrase)
                i += n
                matched = True
                break
        if not matched:
            w = unknown_sequence[i]
            if w not in final_unknown_terms:
                final_unknown_terms.append(w)
            i += 1

    return final_unknown_terms


In [None]:
text = ("To assess the efficacy of our gene name extraction tool, we will analyze a text chunk containing a variety of gene symbols and full names. This includes well known genes like TP53, BRCA1, and EGFR, alongside developmental genes such as SHH, WNT3A, and HOXD13. We will also test its ability to identify less common genes like FOXP2, STAT3, and VEGF, as well as genes involved in metabolism such as APOE, LDLR, and INS. Furthermore, we will include genes with numerical suffixes like CDK4 and ERBB2, and genes with hyphenated names such as HLA-DRB1 and TNF-alpha. Finally, we will incorporate gene families like the KRAS family and the MYC family to see if the function can handle these broader references.")

print("final_unknown_terms:", extract_unknown_terms(text, tokenizer))


final_unknown_terms: ['efficacy', 'extraction', 'TP53', 'BRCA1', 'EGFR', 'developmental', 'SHH', 'WNT3A', 'HOXD13', 'FOXP2', 'STAT3', 'VEGF', 'metabolism', 'APOE', 'LDLR', 'INS', 'suffixes', 'CDK4', 'ERBB2', 'hyphenated', 'HLA-DRB1', 'TNF-alpha', 'KRAS', 'MYC']


## STEP 2

**Data Processing**

#### data observation

In [None]:
input_file = "/content/gene_alias_description.txt"

with open(input_file, "r", encoding="utf-8") as f:
    for i in range(5):  # first 5 lines
        line = f.readline().strip()
        print(f"Line {i+1}:", line)


Line 1: APOC4-APOC2	APOC4-APOC2 Readthrough (NMD Candidate)	This locus represents naturally occurring read-through transcription between the neighboring apolipoprotein C-IV (APOC4) and apolipoprotein C-II (APOC2) genes on chromosome 19. The read-through transcript is a candidate for nonsense-mediated mRNA decay (NMD), and is thus unlikely to produce a protein product.
Line 2: NME4	NME/NM23 Nucleoside Diphosphate Kinase 4	The nucleoside diphosphate (NDP) kinases (EC 2.7.4.6) are ubiquitous enzymes that catalyze transfer of gamma-phosphates, via a phosphohistidine intermediate, between nucleoside and dioxynucleoside tri- and diphosphates. The enzymes are products of the nm23 gene family, which includes NME4 (Milon et al., 1997
Line 3: APOL1	Apolipoprotein L1	This gene encodes a secreted high density lipoprotein which binds to apolipoprotein A-I. Apolipoprotein A-I is a relatively abundant plasma protein and is the major apoprotein of HDL. It is involved in the formation of most cholester

In [None]:
# Checking if all lines have exactly 3 parts (ID, alias, description)
bad_lines = []
with open(input_file, "r", encoding="utf-8") as f:
    for idx, line in enumerate(f):
        parts = line.strip().split("\t")
        if len(parts) != 3:
            bad_lines.append((idx, line.strip()))

print(f"Total malformed lines: {len(bad_lines)}")
if bad_lines:
    print("Sample malformed lines:")
    for idx, l in bad_lines[:3]:  # first 3 bad lines
        print(f"Line {idx}: {l}")



Total malformed lines: 0


### Processing code to get final dataset


here we got the data set like this - <br>
<br>
File **gene_alias_description.txt** contains gene ID, gene alias and gene description,
<br>
Format: [gene ID] \t [gene alias] \t [gene description]
<br>
<br>
we want the format like -
Format: [gene ID] \t  Alias ... | Description: ...
<br> <br>
here seperation tokens have added because we want the seperation between the context of the gene description and the alias while creating the [cls] token for the gene

In [None]:
input_file = "gene_alias_description.txt"
output_file = "gene_input_text.tsv"


In [None]:
with open(input_file, "r", encoding="utf-8") as fin, open(output_file, "w", encoding="utf-8") as fout:
    for line in fin:
        parts = line.strip().split("\t")
        if len(parts) != 3:
            continue  # Skip malformed lines altho there are none!

        gene_id, alias, description = parts
        combined_text = f"Alias {alias} | Description {description}"
        fout.write(f"{gene_id}\t{combined_text}\n")

print(f"Processed gene descriptions written to {output_file} .")


FileNotFoundError: [Errno 2] No such file or directory: 'gene_alias_description.txt'

In [None]:
with open(output_file, "r", encoding="utf-8") as f:
    for i in range(5):
        print(f.readline().strip())


APOC4-APOC2	Alias APOC4-APOC2 Readthrough (NMD Candidate) | Description This locus represents naturally occurring read-through transcription between the neighboring apolipoprotein C-IV (APOC4) and apolipoprotein C-II (APOC2) genes on chromosome 19. The read-through transcript is a candidate for nonsense-mediated mRNA decay (NMD), and is thus unlikely to produce a protein product.
NME4	Alias NME/NM23 Nucleoside Diphosphate Kinase 4 | Description The nucleoside diphosphate (NDP) kinases (EC 2.7.4.6) are ubiquitous enzymes that catalyze transfer of gamma-phosphates, via a phosphohistidine intermediate, between nucleoside and dioxynucleoside tri- and diphosphates. The enzymes are products of the nm23 gene family, which includes NME4 (Milon et al., 1997
APOL1	Alias Apolipoprotein L1 | Description This gene encodes a secreted high density lipoprotein which binds to apolipoprotein A-I. Apolipoprotein A-I is a relatively abundant plasma protein and is the major apoprotein of HDL. It is involve

In [None]:
from typing import Union, List, Optional

def get_gene_text(
    gene_ids: Union[str, List[str]],
    tsv_path: str = '/content/gene_input_text.tsv'
) -> Union[Optional[str], List[Optional[str]]]:
    """
    Retrieve the alias and description text(s) for given gene ID(s) from a TSV file.

    Parameters:
        gene_ids: A single gene ID (str) or a list of gene IDs (List[str]).
        tsv_path: Path to the TSV file containing gene information.

    Returns:
        If a single gene ID is provided, returns a single gene text (or None if not found).
        If a list is provided, returns a list of gene texts corresponding to each gene ID.
        For any gene ID that is not found, the corresponding entry in the output will be None.
    """
    # Normalize input to a list if it is a single gene ID.
    single_input = False
    if isinstance(gene_ids, str):
        gene_ids = [gene_ids]
        single_input = True

    # Prepare a dictionary to store the mapping from gene_id to gene_text.
    gene_text_dict = {}

    # Read the TSV file once, storing entries that match the requested gene_ids.
    with open(tsv_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) != 2:
                continue  # Skip lines that don't have exactly two columns.
            current_gene_id, gene_text = parts
            if current_gene_id in gene_ids:
                gene_text_dict[current_gene_id] = gene_text
                # Early exit if all gene_ids have been found.
                if len(gene_text_dict) == len(gene_ids):
                    break

    # Construct the output in the same order as the input gene_ids.
    results = [gene_text_dict.get(gene_id, None) for gene_id in gene_ids]

    # Return a single string if the input was a single gene ID.
    return results[0] if single_input else results

In [None]:
if get_gene_text(gene_ids='ZNF638') == None:
    print("Gene not found")
else:
    print(get_gene_text(gene_ids='ZNF638'))

Alias Zinc Finger Protein 638 | Description The protein encoded by this gene is a nucleoplasmic protein. It binds cytidine-rich sequences in double-stranded DNA. This protein has three types of domains: MH1, MH2 (repeated three times) and MH3. It is associated with packaging, transferring, or processing transcripts. Multiple alternatively spliced transcript variants have been found for this gene, but the biological validity of some variants has not been determined.


In [None]:
if get_gene_text(gene_ids='AC') == None:
    print("Gene not found")
else:
    print(get_gene_text(gene_ids='AC'))

Gene not found


In [None]:
if get_gene_text(gene_ids=['ZNF638', 'APOC4-APOC2', 'PLEC']) == None:
    print("Gene not found")
else:
    print(get_gene_text(gene_ids=['ZNF638', 'APOC4-APOC2', 'PLEC']))
    print(type(get_gene_text(gene_ids=['ZNF638', 'APOC4-APOC2', 'PLEC'])))

['Alias Zinc Finger Protein 638 | Description The protein encoded by this gene is a nucleoplasmic protein. It binds cytidine-rich sequences in double-stranded DNA. This protein has three types of domains: MH1, MH2 (repeated three times) and MH3. It is associated with packaging, transferring, or processing transcripts. Multiple alternatively spliced transcript variants have been found for this gene, but the biological validity of some variants has not been determined.', 'Alias APOC4-APOC2 Readthrough (NMD Candidate) | Description This locus represents naturally occurring read-through transcription between the neighboring apolipoprotein C-IV (APOC4) and apolipoprotein C-II (APOC2) genes on chromosome 19. The read-through transcript is a candidate for nonsense-mediated mRNA decay (NMD), and is thus unlikely to produce a protein product.', 'Alias Plectin | Description Plectin is a prominent member of an important family of structurally and in part functionally related proteins, termed plak

## STEP 3

**Generating Initial Embeddings**

#### experiments

In [None]:
!pip install sacremoses

Collecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl.metadata (8.3 kB)
Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sacremoses
Successfully installed sacremoses-0.1.1


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

bio_tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt",  trust_remote_code=True)
bio_model = AutoModelForCausalLM.from_pretrained("microsoft/biogpt")
bio_model.eval()

BioGptForCausalLM(
  (biogpt): BioGptModel(
    (embed_tokens): BioGptScaledWordEmbedding(42384, 1024, padding_idx=1)
    (embed_positions): BioGptLearnedPositionalEmbedding(1026, 1024)
    (layers): ModuleList(
      (0-23): 24 x BioGptDecoderLayer(
        (self_attn): BioGptSdpaAttention(
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (activation_fn): GELUActivation()
        (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layer_norm): LayerNorm((

In [None]:
print("Bio gpt config:", bio_model.config)

Bio gpt config: BioGptConfig {
  "_attn_implementation_autoset": true,
  "activation_dropout": 0.0,
  "architectures": [
    "BioGptForCausalLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "layerdrop": 0.0,
  "max_position_embeddings": 1024,
  "model_type": "biogpt",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "scale_embedding": true,
  "torch_dtype": "float32",
  "transformers_version": "4.51.1",
  "use_cache": true,
  "vocab_size": 42384
}



In [None]:
inputs = bio_tokenizer(get_gene_text(gene_id='PLEC'), return_tensors="pt")
inputs['input_ids'].shape

torch.Size([1, 820])

In [None]:
with torch.no_grad():
    outputs = bio_model.biogpt(**inputs)       # hidden_states: (1, seq_len, d')
    hidden_states = outputs.last_hidden_state       # shape: [1, L, d']

In [None]:
hidden_states.shape

torch.Size([1, 164, 1024])

### DataLoader

In [141]:
from torch.utils.data import Dataset, DataLoader, Subset

class GeneDataset(Dataset):
    """
    Dataset that only returns gene IDs.
    """
    def __init__(self, config: Config):
        self.gene_ids = []

        with open(config.tsv_path, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split("\t")
                if len(parts) != 2:
                    continue  # Skip invalid lines.
                gene_id, _ = parts
                self.gene_ids.append(gene_id)

    def __len__(self):
        return len(self.gene_ids)

    def __getitem__(self, idx):
        return self.gene_ids[idx]



def create_dataloaders(config: Config):
    """
    Create train and test dataloaders using the first 10,000 as training and the rest as test.
    """
    full_dataset = GeneDataset(config)

    train_size = 10000
    train_indices = list(range(train_size))
    test_indices = list(range(train_size, len(full_dataset)))

    train_dataset = Subset(full_dataset, train_indices)
    test_dataset = Subset(full_dataset, test_indices)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

    return train_loader, test_loader

In [142]:
train_loader, test_loader = create_dataloaders(config)

for train_batch in train_loader:
    print("Train batch gene IDs:", train_batch)
    break

for test_batch in test_loader:
    print("Test batch gene IDs:", test_batch)
    break


Train batch gene IDs: ['TRIM10', 'MYH10', 'PTPRN', 'GCSAM', 'DYTN', 'ZFP91', 'GLDN', 'BNIP2']
Test batch gene IDs: ['HSPBAP1', 'ELMOD3', 'NEIL1', 'KIF12', 'IFT43', 'MASTL', 'TONSL', 'DDRGK1']


### step3 Implementation of the embedding initialization Method

In [143]:
class get_initial_embedding(nn.Module):
    """Get intial embeddings for the new word from bio-gpt"""
    def __init__(self, model_name=config.bio_model_name, cls_token=config.cls_token):
        super().__init__()
        self.bio_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.bio_model = AutoModelForCausalLM.from_pretrained(model_name)
        self.bio_model.eval()
        for p in self.bio_model.parameters():
            p.requires_grad = False

        ## HERE Adding the cls token
        self.cls_token = cls_token
        if self.cls_token not in self.bio_tokenizer.get_vocab():
            self.bio_tokenizer.add_tokens([self.cls_token])
            self.bio_model.resize_token_embeddings(len(self.bio_tokenizer))

        self.max_length = self.bio_model.config.max_position_embeddings

    def forward(self, gene_ids: List[str]) -> torch.Tensor:
        # print(len(gene_ids))
        descriptions = get_gene_text(gene_ids)
        # print(len(descriptions))

        ## Adding the <cls> token at the start
        input_texts = [f"{self.cls_token} {desc}" for desc in descriptions]
        inputs = self.bio_tokenizer(input_texts, return_tensors="pt", padding="longest", truncation=True) # Pad to the length of the longest sequence in this batch.

        # (B, L=No. of embedding in the alias+description)
        input_ids = inputs["input_ids"]
        # print(input_ids.shape)

        # (B, L)
        attention_mask = inputs["attention_mask"]


        with torch.no_grad():
            outputs = self.bio_model.biogpt(input_ids=input_ids, attention_mask=attention_mask)
            # (B, L, d'=embd_dimof_biogpt)
            hidden_states = outputs.last_hidden_state
        # print(hidden_states.shape)

        # Get the CLS embedding = the first token (<cls>)
        cls_embedding = hidden_states[:, 0, :]  # (B, d')
        return cls_embedding



In [144]:
initial_embedding_model = get_initial_embedding()

In [None]:
es = initial_embedding_model(["ZNF638", "APOC4-APOC2"])

torch.Size([2, 96, 1024])


In [None]:
es.shape

torch.Size([3, 1024])

## STEP 4

**dim mapping and Alignment by the f and W**

In [None]:
## W Linear Layer to map the embedding dim of the Biogpt to the embedding dim of the mistral 7b
class WLayer(nn.Module):
    def __init__(self, input_dim: int=config.bio_gpt_embedding_dim, output_dim: int=config.mistral7b_embedding_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

In [None]:
## fθ [resnet style block]
class ResNetBlock(nn.Module):
    def __init__(self, dim: int=config.mistral7b_embedding_dim):
        super().__init__()
        self.W_layer = WLayer()
        self.linear1 = nn.Linear(dim, dim)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, x):
        ## (B, d') => (B, d)
        x = self.W_layer(x)

        residual = x
        out = self.linear1(x)
        out = self.activation(out)
        out = self.linear2(out)

        ## (B, d)
        return out + residual

## STEP 5

**Keeping all the things TOGETHER...**

In [1]:
!huggingface-cli login
!pip install -U bitsandbytes
!pip install flash-attn --no-build-isolation



    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) y
Token is valid (permission: read).
The token `LLM_model_token` has been saved to /root/.cache/huggingface/stored_tokens
[1m[31mCannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-aut

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(model_name,  token="####")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token="####",
    load_in_8bit=True, # This line requires bitsandbytes
    use_flash_attention_2=True,
    torch_dtype=torch.float16,
    )

tokenizer_config.json:   0%|          | 0.00/2.10k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [3]:
model.config

MistralConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "quantization_config": {
    "_load_in_4bit": false,
    "_load_in_8bit": true,
    "bnb_4bit_compute_dtype": "float32",
    "bnb_4bit_quant_storage": "uint8",
    "bnb_4bit_quant_type": "fp4",
    "bnb_4bit_use_double_quant": false,
    "llm_int8_enable_fp32_cpu_offload": false,
    "llm_int8_has_fp16_weight": false,
    "llm_int8_skip_modules": null,
    "llm_int8_threshold": 6.0,
    "load_in_4bit": false,
    "load_in_8bit": true,
    "quant_method": "bitsandbytes"
  },
  "rms_norm_eps": 1e-05,
  "rope_theta": 1000000.0,
  "sliding_window"