In [None]:
! pip install pandas ollama datetime Bio google transformers



In [None]:
import pandas as pd
import os
from Bio import Entrez, SeqIO
import urllib.request
import gzip
from io import BytesIO
import requests
from bs4 import BeautifulSoup
from googlesearch import search
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
import torch.nn as nn
from torch.utils.data import DataLoader
from itertools import product
import difflib
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
import transformers_modules


In [None]:
# Reads sequence database into a dataframe

Entrez.email = "your_email@example.com"

SEQUENCE_PATH = 'op/sequences_2.csv'

df = pd.read_csv(SEQUENCE_PATH)

print(df.head())

print(len(df))

          Assembly               Organism_Name                     Species  \
0  GCF_029888495.1                  Yezo virus    Orthonairovirus yezoense   
1  GCF_029888155.1            Beiji nairovirus            Beiji nairovirus   
2  GCF_029888075.1              Songling virus              Songling virus   
3  GCF_018595055.1  Dar es Salaam virus TZ-189  Tanzavirus daressalaamense   
4  GCF_013086615.1                 Cocle virus       Phlebovirus cocleense   

             Genus         Family  
0  Orthonairovirus   Nairoviridae  
1  Orthonairovirus   Nairoviridae  
2  Orthonairovirus   Nairoviridae  
3       Tanzavirus  Phenuiviridae  
4      Phlebovirus  Phenuiviridae  
70145


In [None]:
# Filter to only 1 result per species

filtered_df = df.groupby('Species', group_keys=False).head(1)

print(filtered_df.head())

print(len(filtered_df))

          Assembly               Organism_Name                     Species  \
0  GCF_029888495.1                  Yezo virus    Orthonairovirus yezoense   
1  GCF_029888155.1            Beiji nairovirus            Beiji nairovirus   
2  GCF_029888075.1              Songling virus              Songling virus   
3  GCF_018595055.1  Dar es Salaam virus TZ-189  Tanzavirus daressalaamense   
4  GCF_013086615.1                 Cocle virus       Phlebovirus cocleense   

             Genus         Family  
0  Orthonairovirus   Nairoviridae  
1  Orthonairovirus   Nairoviridae  
2  Orthonairovirus   Nairoviridae  
3       Tanzavirus  Phenuiviridae  
4      Phlebovirus  Phenuiviridae  
203


In [None]:
# Function for getting viral data from the accession number

Entrez.email = "sir.peepster@gmail.com"

def get_viral_data(assembly, genome_dir="op/sequences", annotation_dir="op/annotations"):
    '''
    Given an assembly accession (e.g. GCF_029888495.1),
    fetches and parses the genome and annotation data from NCBI,
    and saves them into separate folders.

    Returns a dict with:
    - genome: str (raw sequence)
    - annotations: list of dicts
    - paths to saved files
    '''
    os.makedirs(genome_dir, exist_ok=True)
    os.makedirs(annotation_dir, exist_ok=True)

    genome_txt = os.path.join(genome_dir, f"{assembly}.txt")
    annotation_csv = os.path.join(annotation_dir, f"{assembly}.csv")

    # Skip if both files already exist
    if os.path.exists(genome_txt) and os.path.exists(annotation_csv):
        print(f"[SKIP] {assembly} already processed.")
        return {
            "genome_file": genome_txt,
            "annotation_file": annotation_csv,
            "skipped": True
        }

    # Search assembly UID
    search_handle = Entrez.esearch(db="assembly", term=assembly, retmode="xml")
    search_results = Entrez.read(search_handle)
    search_handle.close()

    if not search_results['IdList']:
        raise ValueError(f"Assembly {assembly} not found.")

    uid = search_results['IdList'][0]

    # Get FTP path
    summary_handle = Entrez.esummary(db="assembly", id=uid, retmode="xml")
    summary = Entrez.read(summary_handle)
    summary_handle.close()

    doc = summary['DocumentSummarySet']['DocumentSummary'][0]
    ftp_path = doc['FtpPath_RefSeq'] or doc['FtpPath_GenBank']
    if not ftp_path:
        raise ValueError(f"No FTP path found for {assembly}")

    base = ftp_path.split("/")[-1]
    gb_url = f"{ftp_path}/{base}_genomic.gbff.gz"
    print(f"[INFO] Downloading: {gb_url}")

    # Download and parse GenBank
    with urllib.request.urlopen(gb_url) as response:
        with gzip.open(BytesIO(response.read()), "rt") as handle:
            records = list(SeqIO.parse(handle, "genbank"))

    if not records:
        raise ValueError("No GenBank records found.")

    record = records[0]

    # Save genome
    with open(genome_txt, "w") as f:
        f.write(f">{record.id}\n{record.seq}")

    # Parse annotations
    annotations = []
    for feature in record.features:
        feat = {
            "type": feature.type,
            "location": str(feature.location),
        }
        for key, value in feature.qualifiers.items():
            feat[key] = "; ".join(value)
        annotations.append(feat)

    annotation_df = pd.DataFrame(annotations)
    annotation_df.to_csv(annotation_csv, index=False)

    return {
        "genome": str(record.seq),
        "annotations": annotations,
        "genome_file": genome_txt,
        "annotation_file": annotation_csv,
        "genbank_url": gb_url,
        "skipped": False
    }


In [None]:
# Playground for viral data

demo_result = get_viral_data('GCF_008711635.1')

[SKIP] GCF_008711635.1 already processed.


In [None]:
testing_csv = pd.read_csv('op/annotations/GCF_008711635.1.csv')

print(testing_csv.head())

          type     location       organism     mol_type  \
0       source  [0:7525](+)  Norovirus GII  genomic RNA   
1        5'UTR     [0:4](+)            NaN          NaN   
2         gene  [4:5095](+)            NaN          NaN   
3          CDS  [4:5095](+)            NaN          NaN   
4  mat_peptide  [4:1000](+)            NaN          NaN   

                                 isolate isolation_source          host  \
0  Hu/GII.PNA4-GII.NA4/PNV06929/2008/PER            stool  Homo sapiens   
1                                    NaN              NaN           NaN   
2                                    NaN              NaN           NaN   
3                                    NaN              NaN           NaN   
4                                    NaN              NaN           NaN   

           db_xref country collection_date       genotype  gene  locus_tag  \
0     taxon:122929    Peru     27-May-2008  GII.NA2[PNA2]   NaN        NaN   
1              NaN     NaN            

In [None]:
for index, row in filtered_df.iterrows():
    print(row['Assembly'])
    get_viral_data(row['Assembly'])

GCF_029888495.1
[SKIP] GCF_029888495.1 already processed.
GCF_029888155.1
[SKIP] GCF_029888155.1 already processed.
GCF_029888075.1
[SKIP] GCF_029888075.1 already processed.
GCF_018595055.1
[SKIP] GCF_018595055.1 already processed.
GCF_013086615.1
[SKIP] GCF_013086615.1 already processed.
GCF_013086545.1
[SKIP] GCF_013086545.1 already processed.
GCF_006298385.1
[SKIP] GCF_006298385.1 already processed.
GCF_029887105.1
[SKIP] GCF_029887105.1 already processed.
GCF_024749945.1
[SKIP] GCF_024749945.1 already processed.
GCF_023122845.1
[SKIP] GCF_023122845.1 already processed.
GCF_018591295.1
[SKIP] GCF_018591295.1 already processed.
GCF_018584815.1
[SKIP] GCF_018584815.1 already processed.
GCF_018583805.1
[SKIP] GCF_018583805.1 already processed.
GCF_018583635.1
[SKIP] GCF_018583635.1 already processed.
GCF_018583605.1
[SKIP] GCF_018583605.1 already processed.
GCF_018580895.1
[SKIP] GCF_018580895.1 already processed.
GCF_018580825.1
[SKIP] GCF_018580825.1 already processed.
GCF_013087445.

In [None]:
import ollama
import re
import inspect
from ollama import ChatResponse
from datetime import datetime

def get_time():
    return datetime.utcnow().isoformat(timespec='microseconds') + 'Z'

def timestamp_to_datetime(timestamp: str):
    # Convert ISO 8601 timestamp with 7-digit precision
    return datetime.strptime(timestamp[:-1], "%Y-%m-%dT%H:%M:%S.%f")

def time_difference_in_ns(timestamp1: str, timestamp2: str):
    dt1 = timestamp_to_datetime(timestamp1)
    dt2 = timestamp_to_datetime(timestamp2)

    # Calculate the difference in seconds and convert to nanoseconds
    delta_ns = int((dt2 - dt1).total_seconds() * 1e9)
    return delta_ns

def default_format_tool_instructions(tool_name: str, tool: callable, instructions: str):
    # Extract parameter details from the tool
    signature = inspect.signature(tool)
    param_list = []

    for name, param in signature.parameters.items():
        if param.default is inspect.Parameter.empty:
            param_list.append(name)  # Required parameter
        else:
            param_list.append(f"{name}={param.default}")  # Optional parameter with default

    # Join parameters for formatting
    param_str = ', '.join(param_list)

    # Return formatted instructions
    return f'''You have been given access to the tool "{tool_name}({param_str})". {instructions}
    Please call this function by using the following format: <call tool>{tool_name}({param_str})</call tool>'''

class Agent():
    def __init__(self, model: str):
        '''
        Initiate the model with the model used and
        blank memory and debug log.
        '''
        self.client = ollama.Client(host='http://host.docker.internal:11434')
        self.model = model
        self.memory = []
        self.log = [{
            'time': get_time(),
            'action': '__init__',
            'prompt': model,
            'other': {}
        }] # time, action, prompt, other
        self.tools = {}
        self.format_tool_instructions = default_format_tool_instructions

    def add_tool(self, tool_name: str, tool: callable, instructions: str = None):
        self.tools[tool_name] = {}
        self.tools[tool_name]['tool'] = tool
        self.tools[tool_name]['instructions'] = instructions
        self.log.append({
            'time': get_time(),
            'action': 'add_tool',
            'prompt': tool_name,
            'other': {
                'instructions': instructions
            }
        })
        if instructions is not None:
            self.sys_prompt(self.format_tool_instructions(tool_name, tool, instructions))

    def sys_prompt(self, sys_prompt: str):
        self.memory.append(
            {
                'role': 'system',
                'content': sys_prompt
            }
        )
        self.log.append({
            'time': get_time(),
            'action': 'sys_prompt',
            'prompt': sys_prompt,
            'other': {}
        })

    def chat(self, prompt: str):
        self.memory.append(
            {
                'role': 'user',
                'content': prompt
            }
        )

        start_time = get_time()
        user_content = ''

        self.log.append({
            'time': start_time,
            'action': 'chat',
            'prompt': prompt,
            'other': {}
        })

        while True:
            content = ''
            stream: ChatResponse = self.client.chat(model=self.model,messages=self.memory, stream=True)

            for chunk in stream:
                content += chunk.message.content
                tool_match = re.search(r'<call tool>(.*?)</call tool>', content)
                if tool_match:
                    tool_call = tool_match.group(1)
                    break  # Stop stream when tool use is detected

            user_content += re.sub(r"<call tool>.*?</call tool>", "", content).strip()
            self.memory.append({'role': 'assistant', 'content': content})

            if tool_match:
                tool_name, *params = self._extract_tool_call(tool_call)

                # Handle commas inside strings
                joined_params = ','.join(params)
                params = re.split(r',(?=(?:[^"]*"[^"]*")*[^"]*$)', joined_params)

                # Inject tool result back into memory
                try:
                    tool_result = f"<tool return result>{self.tools[tool_name]['tool'](*params)}</tool return result>"
                except Exception as e:
                    if tool_name not in self.tools:
                        tool_result = f"<tool return result>Error: Tool not defined.</tool return result>"
                    else:
                        tool_result = f"<tool return result>Error: {type(e)}</tool return result>"

                self.log.append({
                    'time': get_time(),
                    'action': 'tool',
                    'prompt': tool_call,
                    'other': {
                        'tool_result': tool_result,
                        'tool_name': tool_name,
                        'params': params,
                    }
                })
                self.sys_prompt(tool_result)
            else:
                break  # No more tool calls, exit loop
        end_time = get_time()

        self.log.append({
            'time': end_time,
            'action': 'chat_end',
            'prompt': prompt,
            'other': {
                'total_duration': time_difference_in_ns(start_time, end_time)
            }
        })

        return user_content

    def get_memory(self):
        return self.memory

    def get_tools(self):
        return self.tools

    def set_memory(self, memory):
        self.memory = memory

    def get_log(self):
        return self.log

    def _add_to_log(self, time, action, prompt, other):
        log_entry = {
            'time': time,
            'action': action,
            'prompt': prompt,
            'other': other
        }
        self.log.append(log_entry)

    def _extract_tool_call(self, tool_call: str):
        """
        Extract tool name and parameters.
        Example Input: 'fibonnaci(5)'
        Output: ('fibonacci', ('5',))
        """
        tool_match = re.match(r'(\w+)\((.*?)\)', tool_call.strip())
        if tool_match:
            tool_name = tool_match.group(1)
            params = tuple(map(str.strip, tool_match.group(2).split(',')))
            return tool_name, *params
        return tool_call, ()




In [None]:
def get_symptom_info(species: str, limit: int = 5000) -> str:
    '''
    Retrieve symptom information for a given virus species by searching:
    1. Wikipedia
    2. ICTV (International Committee on Taxonomy of Viruses)
    3. Scientific literature (e.g., PubMed)
    If no information is found, indicate that no symptom data is available.
    '''
    # Define search queries for different sources
    queries = [
        f"{species} symptoms site:en.wikipedia.org",
        f"{species} symptoms site:ictv.global",
        f"{species} symptoms site:pubmed.ncbi.nlm.nih.gov"
    ]

    for query in queries:
        try:
            # Perform a Google search and retrieve the first result
            search_results = search(query, num=1, stop=1, pause=2)
            for url in search_results:
                # Fetch the content of the URL
                response = requests.get(url)
                if response.status_code == 200:
                    # Parse the webpage content
                    soup = BeautifulSoup(response.text, 'html.parser')
                    paragraphs = soup.find_all('p')
                    content = ' '.join([para.get_text() for para in paragraphs])
                    return content[1:2500] + """Remember that you are an AI agent designed to determine disease symptoms. Make sure you respond in the following format: ```
- symptom1: severity1
- symptom2: severity2
...

Use a 0 to 1 severity scale:
- 0 = symptom not present
- 0.33 ≈ mild presence
- 0.67 ≈ moderate presence
- 1 = severe symptom
Only use the number. Do not include the description.

If a symptom is very specific, ignore it. If a symptom is more than a few words or describes in great detail, just use a simple one-word description. USE ONE WORD DESCRIPTIONS WHENEVER POSSIBLE
```"""
        except Exception as e:
            print(f"An error occurred while searching with query '{query}': {e}")
            continue

    # If no information is found
    return f"No symptom information found for '{species}'. Treat this virus as having no known symptoms."

def extract_symptom_dict(text: str) -> dict:
    symptom_dict = {}

    # Find lines with a colon, typically symptoms
    lines = text.splitlines()
    for line in lines:
        if ':' in line:
            try:
                # Split at the first colon
                symptom, severity = line.split(':', 1)
                symptom = symptom.strip().lstrip("- ").strip()

                # Extract the first float number from severity string
                match = re.search(r"\d*\.?\d+", severity)
                if match:
                    severity_score = float(match.group())
                    symptom_dict[symptom] = severity_score
            except Exception as e:
                continue

    return symptom_dict


In [None]:
disease_classifier = Agent('phi4')

disease_classifier.add_tool("get_symptom_info", get_symptom_info, "Retrieves the symptom information from various websites for a given virus species.")
disease_classifier.sys_prompt("""
You are an expert virus-to-symptom assistant. Given a virus species name, your task is to extract a list of symptoms and their severity based on available information.

You should call the tool `get_symptom_info` to search for it before coming to a conclusion.

The user will give you the name of a virus. When responding, always return your output in this format:

```
- symptom1: severity1
- symptom2: severity2
...
```

Use a 0 to 1 severity scale:
- 0 = symptom not present
- 0.33 ≈ mild presence
- 0.67 ≈ moderate presence
- 1 = severe symptom
Only use the number. Do not include the description.

If no symptoms are found or the virus is asymptomatic, return:
```
```

Be concise and do NOT include unnecessary explanations. Focus only on the symptoms and severity and use ONLY common medical terms to describe symptoms.

ALWAYS return your output in this format. You MUST use the following format:
```
- symptom1: severity1
- symptom2: severity2
...
```

If a symptom is very specific, ignore it. If a symptom is more than a few words or describes in great detail, just use a simple one-word description (a precise medical term is preferable).
Make sure you use the tool!""")

classifier_memory = disease_classifier.get_memory()
classifier_tools = disease_classifier.get_tools()

print(classifier_memory)
print(classifier_tools)

[{'role': 'system', 'content': 'You have been given access to the tool "get_symptom_info(species, limit=5000)". Retrieves the symptom information from various websites for a given virus species.\n    Please call this function by using the following format: <call tool>get_symptom_info(species, limit=5000)</call tool>'}, {'role': 'system', 'content': '\nYou are an expert virus-to-symptom assistant. Given a virus species name, your task is to extract a list of symptoms and their severity based on available information.\n\nYou should call the tool `get_symptom_info` to search for it before coming to a conclusion.\n\nThe user will give you the name of a virus. When responding, always return your output in this format:\n\n```\n- symptom1: severity1\n- symptom2: severity2\n...\n```\n\nUse a 0 to 1 severity scale:\n- 0 = symptom not present\n- 0.33 ≈ mild presence\n- 0.67 ≈ moderate presence\n- 1 = severe symptom\nOnly use the number. Do not include the description.\n\nIf no symptoms are found

In [None]:
testing = Agent('phi4')

testing.set_memory(classifier_memory)
testing.tools = classifier_tools

chat_response = testing.chat("Ebola")
print(chat_response)
print(extract_symptom_dict(chat_response))
[print(item) for item in testing.log]

- fever: 1
- headache: 0.67
- muscle pain: 0.67
- sore throat: 0.33
- vomiting: 0.67
- diarrhea: 0.67
- rash: 0.67
- decreased liver function: 0.67
- decreased kidney function: 0.67
- bleeding (internal/external): 1
{'fever': 1.0, 'headache': 0.67, 'muscle pain': 0.67, 'sore throat': 0.33, 'vomiting': 0.67, 'diarrhea': 0.67, 'rash': 0.67, 'decreased liver function': 0.67, 'decreased kidney function': 0.67, 'bleeding (internal/external)': 1.0}
{'time': '2025-03-29T22:03:39.877710Z', 'action': '__init__', 'prompt': 'phi4', 'other': {}}
{'time': '2025-03-29T22:03:39.877821Z', 'action': 'chat', 'prompt': 'Ebola', 'other': {}}
{'time': '2025-03-29T22:03:49.773022Z', 'action': 'tool', 'prompt': 'get_symptom_info(Ebola, limit=5000)', 'other': {'tool_result': '<tool return result> Ebola, also known as Ebola virus disease (EVD) and Ebola hemorrhagic fever (EHF), is a viral hemorrhagic fever in humans and other primates, caused by ebolaviruses.[1] Symptoms typically start anywhere between two da

[None, None, None, None, None]

In [None]:
training_df = filtered_df.copy()

print(training_df.head())

symptom_data = []
all_symptoms = set()

for _, row in tqdm(training_df.iterrows(), total=len(training_df)):
    species = row['Organism_Name']

    # Use the agent to get a response
    try:
        disease_classifier_copy = Agent('phi4')
        disease_classifier_copy.set_memory(classifier_memory)
        disease_classifier_copy.tools = classifier_tools

        response = disease_classifier_copy.chat(species)

        symptom_dict = extract_symptom_dict(response)
    except Exception as e:
        print(f"[WARN] Failed on {species}: {e}")
        symptom_dict = {}

    # Track all symptoms seen so far
    all_symptoms.update(symptom_dict.keys())

    # Add organism + its symptom dictionary
    symptom_data.append({
        "Organism_Name": species,
        **symptom_dict
    })

# Create the symptom DataFrame
symptom_df = pd.DataFrame(symptom_data)

# Fill missing symptoms with 0 (not mentioned)
for symptom in all_symptoms:
    if symptom not in symptom_df.columns:
        symptom_df[symptom] = 0

# Optional: reorder columns
symptom_columns = sorted(list(all_symptoms))
final_df = symptom_df[['Organism_Name'] + symptom_columns]


          Assembly               Organism_Name                     Species  \
0  GCF_029888495.1                  Yezo virus    Orthonairovirus yezoense   
1  GCF_029888155.1            Beiji nairovirus            Beiji nairovirus   
2  GCF_029888075.1              Songling virus              Songling virus   
3  GCF_018595055.1  Dar es Salaam virus TZ-189  Tanzavirus daressalaamense   
4  GCF_013086615.1                 Cocle virus       Phlebovirus cocleense   

             Genus         Family  
0  Orthonairovirus   Nairoviridae  
1  Orthonairovirus   Nairoviridae  
2  Orthonairovirus   Nairoviridae  
3       Tanzavirus  Phenuiviridae  
4      Phlebovirus  Phenuiviridae  


100%|██████████| 203/203 [07:59<00:00,  2.36s/it]


In [None]:
# print(final_df.head())

# display influenza symptoms (testcase)

final_df.fillna(0, inplace=True)

row = final_df.loc[final_df['Organism_Name'] == 'Bundibugyo ebolavirus']

if not row.empty:
    for column, value in row.iloc[0].items():
        if value != 0:
          print(f"{column}: {value}")
else:
    print("No match found.")



Organism_Name: Bundibugyo ebolavirus
diarrhea: 1.0
fever: 1.0
headache: 1.0
muscle pain: 1.0
rash: 0.67
vomiting: 1.0
weakness: 1.0


In [None]:
final_df = symptom_df.merge(
    training_df[["Organism_Name", "Assembly"]],
    on="Organism_Name",
    how="left"  # or "inner" if you only want overlapping ones
)

final_df.fillna(0, inplace=True)
print(final_df)

                  Organism_Name  fever  headache  dizziness  blurred vision  \
0                    Yezo virus   0.67      0.33       0.33            0.33   
1              Beiji nairovirus   0.33      0.00       0.00            0.00   
2                Songling virus   0.67      0.33       0.00            0.00   
3    Dar es Salaam virus TZ-189   0.00      0.00       0.00            0.00   
4                   Cocle virus   0.33      0.00       0.00            0.00   
..                          ...    ...       ...        ...             ...   
198      Human mastadenovirus C   0.33      0.00       0.00            0.00   
199       Human erythrovirus V9   0.33      0.00       0.00            0.00   
200    Human gammaherpesvirus 8   0.00      0.00       0.00            0.00   
201           Influenza B virus   0.00      0.00       0.00            0.00   
202           Influenza C virus   0.00      0.00       0.00            0.00   

     shortness of breath  fatigue  arthralgia  thro

In [334]:
import os
import torch
import pandas as pd
import torch.nn.functional as F
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForMaskedLM


# ------------------------------------------------
# 0) Make sure final_df is defined or loaded here.
# final_df = ...
# ------------------------------------------------

non_numeric = ['Organism_Name', 'Assembly']
numeric_cols = [col for col in final_df.columns if col not in non_numeric]

lowercase_map = {}
for col in numeric_cols:
    key = col.lower()
    lowercase_map.setdefault(key, []).append(col)

agg_data = {}
for key, cols in lowercase_map.items():
    agg_data[key] = final_df[cols].sum(axis=1)

df_agg = pd.concat([final_df[non_numeric], pd.DataFrame(agg_data)], axis=1)

def load_genome_from_assembly(assembly):
    filepath = os.path.join("op", "sequences", assembly + ".txt")
    try:
        with open(filepath, 'r') as f:
            lines = f.readlines()
        return lines[1].strip() if len(lines) > 1 else ""
    except Exception as e:
        print(f"Error reading {filepath}: {e}")
        return ""

df_agg["Genome"] = df_agg["Assembly"].apply(load_genome_from_assembly)

print("Aggregated DataFrame with Genome column:")
print(df_agg.head())

symptom_cols = list(agg_data.keys())
symptom_tensor = torch.tensor(df_agg[symptom_cols].values, dtype=torch.float)

# ------------------------------------------------
# 1) Load BOTH the tokenizer and model with trust_remote_code=True
#    so DNABERT's custom code + config are used properly.
# ------------------------------------------------
MODEL_NAME = "armheb/DNA_bert_6"

# Load configuration, tokenizer, and model (this model does not require trust_remote_code)
config = AutoConfig.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, config=config)
model.eval()

def compute_transformer_embedding(sequence, max_length=512):
    """
    Convert a (nucleotide) sequence into a transformer embedding.
    For very long sequences, consider chunking or specialized handling
    because BERT-like models typically have a 512-token limit.
    """
    inputs = tokenizer(
        sequence,
        return_tensors='pt',
        max_length=max_length,
        truncation=True,
        padding='max_length'
    )
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        # Use the last hidden state from hidden_states tuple:
        last_hidden_state = outputs.hidden_states[-1]  # [batch_size, seq_len, hidden_size]
        embedding = last_hidden_state.mean(dim=1).squeeze(0)  # [hidden_size]
    return embedding

# ------------------------------------------------
# 2) Precompute embeddings for all known genomes
# ------------------------------------------------
embeddings = []
for seq in df_agg["Genome"]:
    seq = seq.strip()
    if seq:
        emb = compute_transformer_embedding(seq)
    else:
        emb = torch.zeros(model.config.hidden_size)
    embeddings.append(emb)

genome_embeddings = torch.stack(embeddings, dim=0)  # [num_genomes, hidden_size]

# ------------------------------------------------
# 3) Lookup function for the nearest embedding
# ------------------------------------------------
def get_closest_organism(input_genome):
    # Embed user input
    input_embedding = compute_transformer_embedding(input_genome)
    input_embedding_2d = input_embedding.unsqueeze(0)  # [1, hidden_size]

    # Cosine similarity
    similarities = F.cosine_similarity(input_embedding_2d, genome_embeddings, dim=1)
    best_idx = torch.argmax(similarities).item()

    organism = df_agg.loc[best_idx, 'Organism_Name']
    guess_tensor = symptom_tensor[best_idx]
    return organism, guess_tensor

Aggregated DataFrame with Genome column:
                Organism_Name         Assembly  fever  headache  dizziness  \
0                  Yezo virus  GCF_029888495.1   0.67      0.33       0.33   
1            Beiji nairovirus  GCF_029888155.1   0.33      0.00       0.00   
2              Songling virus  GCF_029888075.1   0.67      0.33       0.00   
3  Dar es Salaam virus TZ-189  GCF_018595055.1   0.67      0.33       0.00   
4                 Cocle virus  GCF_013086615.1   0.33      0.00       0.00   

   blurred vision  shortness of breath  fatigue  arthralgia  thrombocytopenia  \
0            0.33                 0.33     0.33        0.33              0.67   
1            0.00                 0.67     0.00        0.00              0.00   
2            0.00                 0.00     0.33        0.00              0.00   
3            0.00                 0.00     0.33        0.00              0.00   
4            0.00                 0.33     0.00        0.00              0.00   

   

Some weights of the model checkpoint at armheb/DNA_bert_6 were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [344]:
# ------------------------------------------------
# 4) Example usage
# ------------------------------------------------
input_genome = (
    "ATGAATATAAATCCTTATTTTCTCTTCATAGATGTGCCAGTACAGGCAGCAATTTCAACAACATTCCCATACACTGGTGTTCCCCCTTATTCCCATGGAACAGGAACAGGTTACACAATAGACACCGTGATCAGAACGCATGAGTACTCAAACAAGGGGAAACAGTACATTTCTGATGTTACAGGATGCACAATGGTGGATCCAACAAATGGACCATTACCCGAAGATAATGAGCCGAGTGCCTATGCGCAATTAGATTGCGTTTTAGAGGCTTTGGATAGAATGGATGAAGAACACCCAGGTCTTTTTCAAGCAGCCTCACAGAATGCTATGGAGGCCCTAATGGTCACAACTGTAGACAAATTAACCCAGGGGAGACAAACTTTTGATTGGACAGTATGCAGAAACCAACCTGCTGCAACGGCACTGAATACAACAATAACCTCTTTTAGGTTGAATGATTTAAATGGAGCCGACAAAGGTGGATTAATACCTTTTTGCCAGGATATCATTGATTCATTAGACAGACCTGAAATGACTTTCTTCTCAGTAAAGAATATAAAGAAAAAATTGCCTGCCAAAAACAGAAAGGGTTTCCTCATAAAGAGGATACCAATGAAGGTAAAAGACAAAATAACCAAAGTGGAATACATCAAAAGAGCATTATCATTAAACACAATGACAAAAGACGCTGAAAGAGGCAAACTGAAAAGAAGAGCGATTGCCACTGCTGGAATACAAATAAGAGGGTTTGTATTAGTAGTTGAAAACTTGGCTAAAAATATATGTGAAAATCTAGAACAAAGTGGTTTACCAGTAGGTGGAAACGAGAAGAAAGCCAAACTGTCAAATGCAGTGGCCAAAATGCTCAGTAACTGCCCACCAGGAGGGATTAGCATGACAGTAACAGGAGACAATACAAAATGGAATGAATGTTTAAACCCAAGGATCTTTTTGGCCATGACCGAAAGAATAACCAGAGACAGCCCAGTTTGGTTCAGGGATTTTTGTAGTATAGCACCGGTCCTGTTCTCCAATAAGATAGCAAGATTGGGGAAAGGATTCATGATAACAAGCAAAACAAAAAGACTAAAGGCCCAAATACCTTGTCCTGATCTGTTTAGTATACCATTAGAAAGATATAATGAAGAAACAAGGGCAAAATTGAAGAAGCTAAAACCATTCTTCAATGAAGAAGGAACTGCATCTTTGTCACCTGGGATGATGATGGGAATGTTTAATATGCTATCTACCGTATTGGGAGTAGCTGCACTAGGTATCAAGAACATTGGAAACAAAGAATACCTATGGGATGGACTGCAATCTTCTGATGATTTTGCTCTATTTGTTAATGCAAAGGATGAAGAAACATGTATGGAAGGAATAAACGACTTTTACCGAACATGTAAATTATTGGGAATAAACATGAGCAAAAAGAAAAGTTACTGTAATGAGACTGGAATGTTTGAATTTACAAGCATGTTCTACAGAGATGGATTTGTATCTAATTTTGCAATGGAACTCCCTTCGTTTGGGGTTGCTGGAGTAAATGAATCAGCAGATATGGCAATAGGAATGACAATAATAAAGAACAACATGATCAACAATGGAATGGGTCCAGCAACAGCACAAACAGCCATACAGCTATTCATAGCTGATTATAGATACACCTACAAATGCCACAGGGGAGATTCCAAAGTAGAAGGAAAGAGAATGAAAATCATAAAGGAGTTATGGGAAAACACTAAAGGAAGAGATGGTCTATTAGTAGCAGATGGTGGGCCCAACATTTACAATTTGAGAAACTTGCATATCCCAGAAATAGTATTGAAGTATAATCTAATGGACCCTGAATACAAAGGGCGATTACTTCATCCTCAAAATCCCTTTGTGGGACATTTGTCTATTGAGGGCATCAAAGAGGCAGACATAACTCCAGCACATGGTCCAGTAAAGAAAATGGACTACGATGCAGTGTCTGGAACTCATAGTTGGAGAACCAAAAGAAACAGATCTATACTAAACACTGATCAGAGGAACATGATTCTTGAGGAACAATGCTACGCTAAATGTTGCAACCTATTTGAGGCCTGTTTTAACAGTGCATCATACAGGAAGCCAGTGGGTCAACATAGCATGCTTGAGGCTATGGCCCACAGATTAAGAATGGATGCACGATTAGATTATGAATCAGGGAGAATGTCAAAAGATGATTTTGAGAAAGCAATGGCTCACCTTGGTGAGATTGGGTACATATAA"
)

try:
    organism, guess = get_closest_organism(input_genome)
    for column in df_agg.columns:
      print(column, end=", ")
    print()
    print("Symptom Tensor Row:", guess.tolist())
except ValueError as e:
    print(e)

Organism_Name, Assembly, fever, headache, dizziness, blurred vision, shortness of breath, fatigue, arthralgia, thrombocytopenia, leukopenia, lymphocytopenia, coagulation disorder, increased liver enzymes, cough, runny nose, stuffy nose, sore throat, wheezing, rash, muscle pain, (note, nausea, nasal congestion, malaise, myalgia, joint pain, swollen lymph nodes, runny or stuffy nose, diarrhea, respiratory symptoms, swelling, muscle ache, backache, chills, respiratory issues, symptom1, symptom2, myalgia (muscle pain), loss of taste or smell, runny/stuffy nose, vomiting, abdominal pain, thrombocytopenia (low platelet count), swollen glands, bleeding, organ failure, jaundice, dark urine, **cough**, **fever**, **runny nose**, **sore throat**, **fatigue**, wart, dysplasia, cancer, itching, warts, note, cognitive impairment, gastrointestinal distress, respiratory distress, lymphadenopathy, *note, neurological symptoms, conjunctivitis, fecal, viral, viral infection, nausea/vomiting, lesion, res