Mount Google Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Install Necessary Libraries

In [3]:
!pip install pandas torch transformers scikit-learn hdbscan



Upload and Prepare Your Dataset

In [2]:
import pandas as pd
data_path = '/content/drive/MyDrive/Sih/synthetic_edna_2.csv'
df = pd.read_csv(data_path)

Data Preparation in Colab (Recap)

In [4]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer

# 1. Load the dataset
data_path = '/content/drive/MyDrive/Sih/synthetic_edna_2.csv'
df = pd.read_csv(data_path)

# 2. Filter for 'known' taxa for fine-tuning
known_taxa_df = df[df['Taxonomic_Level'] == 'species'].copy()

# 3. Define the k-mer conversion function
def seq_to_kmer(sequence, k=6):
    kmers = []
    for i in range(len(sequence) - k + 1):
        kmers.append(sequence[i:i+k])
    return ' '.join(kmers)

# 4. Apply k-mer conversion
known_taxa_df['kmer_sequence'] = known_taxa_df['Sequence'].apply(seq_to_kmer)

# 5. Create a mapping from species names to numerical labels
species_list = known_taxa_df['Taxonomy'].unique()
species_to_label = {species: i for i, species in enumerate(species_list)}
label_to_species = {i: species for i, species in enumerate(species_list)}
known_taxa_df['label'] = known_taxa_df['Taxonomy'].map(species_to_label)

# 6. Split the known data for training and evaluation
train_df, eval_df = train_test_split(known_taxa_df, test_size=0.2, random_state=42, stratify=known_taxa_df['label'])

print(f"Training on {len(train_df)} sequences, evaluating on {len(eval_df)} sequences.")
print(f"Number of unique species (classes): {len(species_list)}")

Training on 5179 sequences, evaluating on 1295 sequences.
Number of unique species (classes): 56


Define a Custom PyTorch Dataset

In [6]:
import torch
from torch.utils.data import Dataset

class DnaDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.texts = df['kmer_sequence'].tolist()
        self.labels = df['label'].tolist()
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        # Tokenize the k-mer sequence
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors='pt'
        )

        # The Trainer expects input_ids and attention_mask
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

Load Model, Configure Trainer, and Start Fine-tuning

In [7]:
from transformers import BertForSequenceClassification, Trainer, TrainingArguments, BertTokenizer
from sklearn.metrics import accuracy_score, f1_score, classification_report
import numpy as np

# A. Load the pre-trained DNABERT model and tokenizer
# The tokenizer must match the k-mer size (k=6 in this case)
tokenizer = BertTokenizer.from_pretrained('zhihan1996/DNA_bert_6')

# The model is loaded with a classification head on top.
# num_labels tells it how many species to predict.
num_labels = len(species_list)
model = BertForSequenceClassification.from_pretrained('zhihan1996/DNA_bert_6', num_labels=num_labels)

# B. Create the custom datasets
train_dataset = DnaDataset(train_df, tokenizer)
eval_dataset = DnaDataset(eval_df, tokenizer)

# C. Define a function to compute evaluation metrics
def compute_metrics(p):
    predictions = np.argmax(p.predictions, axis=1)
    # The classification report provides all the metrics you mentioned (F1-score, accuracy)
    print(classification_report(p.label_ids, predictions, target_names=species_list))
    return {
        'accuracy': accuracy_score(p.label_ids, predictions),
        'f1_score_macro': f1_score(p.label_ids, predictions, average='macro'),
    }

# D. Set up the Training Arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=15,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    report_to="none",  # <-- Add this line
)

# E. Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics
)

# F. Start the fine-tuning process!
print("Starting fine-tuning...")
trainer.train()

# G. Save the final model to your Google Drive for later use
trainer.save_model('/content/drive/MyDrive/Edna_Project/fine_tuned_model_2')
tokenizer.save_pretrained('/content/drive/MyDrive/Edna_Project/fine_tuned_model_2')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/40.0 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/359M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at zhihan1996/DNA_bert_6 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/359M [00:00<?, ?B/s]

Starting fine-tuning...


Epoch,Training Loss,Validation Loss,Accuracy,F1 Score Macro
1,No log,3.445904,0.035521,0.00245
2,3.590500,3.407868,0.041699,0.004452
3,3.590500,3.399958,0.035521,0.00357
4,3.418100,3.394344,0.039382,0.003955
5,3.405000,3.398983,0.033205,0.006057


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.00      0.00      0.00        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.00      0.00      0.00        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.00      0.00      0.00        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.00      0.00      0.00        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.04      1.00      0.07        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.00      0.00      0.00        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.00      0.00      0.00        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.00      0.00      0.00        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.00      0.00      0.00        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.00      0.00      0.00        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.00      0.00      0.00        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.00      0.00      0.00        40
       Foraminifera_sp2       0.04      1.00      0.07        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.00      0.00      0.00        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.00      0.00      0.00        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.00      0.00      0.00        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.00      0.00      0.00        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.00      0.00      0.00        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.08      0.12      0.10        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.00      0.00      0.00        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.00      0.00      0.00        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.00      0.00      0.00        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch,Training Loss,Validation Loss,Accuracy,F1 Score Macro
1,No log,3.445904,0.035521,0.00245
2,3.590500,3.407868,0.041699,0.004452
3,3.590500,3.399958,0.035521,0.00357
4,3.418100,3.394344,0.039382,0.003955
5,3.405000,3.398983,0.033205,0.006057
6,3.405000,3.39525,0.03861,0.009103
7,3.385800,3.405193,0.042471,0.015074
8,3.338200,3.424877,0.041699,0.017052
9,3.338200,3.444811,0.033977,0.01369
10,3.250700,3.489074,0.036293,0.020267


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.00      0.00      0.00        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.00      0.00      0.00        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.33      0.03      0.05        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.00      0.00      0.00        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.00      0.00      0.00        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.12      0.05      0.07        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.00      0.00      0.00        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.00      0.00      0.00        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.00      0.00      0.00        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.00      0.00      0.00        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.33      0.03      0.05        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.00      0.00      0.00        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.02      0.04      0.03        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.11      0.05      0.06        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.00      0.00      0.00        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.04      1.00      0.07        25
      Deep_Cephalopod_B       0.00      0.00      0.00        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.06      0.08      0.07        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.12      0.05      0.07        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.07      0.05      0.06        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.02      0.24      0.04        25
      Deep_Cephalopod_B       0.00      0.00      0.00        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.00      0.00      0.00        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.09      0.05      0.06        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.00      0.00      0.00        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.03      0.04      0.04        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.02      0.04      0.03        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.10      0.05      0.06        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.00      0.00      0.00        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.00      0.00      0.00        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.04      0.04      0.04        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.08      0.05      0.06        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.00      0.00      0.00        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.06      0.09      0.07        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.00      0.00      0.00        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.08      0.05      0.06        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.00      0.00      0.00        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.00      0.00      0.00        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                         precision    recall  f1-score   support

      Bathyal_Annelid_A       0.00      0.00      0.00        24
      ColdWater_Coral_C       0.00      0.00      0.00        20
    Bathypelagic_Fish_B       0.00      0.00      0.00        25
      ColdWater_Coral_A       0.00      0.00      0.00        19
    Bathypelagic_Fish_A       0.00      0.00      0.00        23
           Amphipod_sp1       0.09      0.05      0.06        22
      DeepSea_Protist_D       0.00      0.00      0.00        22
   Microzooplankton_sp2       0.00      0.00      0.00        23
   Echinoderm_Benthic_A       0.00      0.00      0.00        40
       Foraminifera_sp2       0.00      0.00      0.00        24
      ColdWater_Coral_B       0.00      0.00      0.00        24
      Bathyal_Annelid_C       0.00      0.00      0.00        25
      Deep_Cephalopod_B       0.04      0.04      0.04        23
         Abyssal_Crab_C       0.00      0.00      0.00        21
           SeaSnail_sp1 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


('/content/drive/MyDrive/Edna_Project/fine_tuned_model_2/tokenizer_config.json',
 '/content/drive/MyDrive/Edna_Project/fine_tuned_model_2/special_tokens_map.json',
 '/content/drive/MyDrive/Edna_Project/fine_tuned_model_2/vocab.txt',
 '/content/drive/MyDrive/Edna_Project/fine_tuned_model_2/added_tokens.json')

Generate Embeddings

In [None]:
import torch
import pandas as pd
import numpy as np
from transformers import BertModel, BertTokenizer
import hdbscan
from tqdm import tqdm

# --- Part 1: Load and prepare the data ---
# Load the provided CSV file into a DataFrame
df = pd.read_csv('/content/drive/MyDrive/Sih/synthetic_edna_2.csv')

# Define a custom k-mer conversion function
def seq_to_kmer(sequence, k=6):
    kmers = []
    if not isinstance(sequence, str):
        return ""
    for i in range(len(sequence) - k + 1):
        kmers.append(sequence[i:i+k])
    return ' '.join(kmers)

# Apply k-mer conversion to the entire DataFrame
df['kmer_sequence'] = df['Sequence'].apply(seq_to_kmer)

# --- Part 2: Load the fine-tuned model and tokenizer ---
model_save_path = '/content/drive/MyDrive/Edna_Project/fine_tuned_model'

print("Loading fine-tuned model and tokenizer for embedding generation...")
model_for_embeddings = BertModel.from_pretrained(model_save_path)
tokenizer_for_embeddings = BertTokenizer.from_pretrained(model_save_path)
model_for_embeddings.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_for_embeddings.to(device)

# --- Part 3: Define a function to generate embeddings ---
def get_embeddings_batched(df, model, tokenizer, batch_size=32):
    model.eval()
    embeddings = []
    sequences_list = df['kmer_sequence'].tolist()

    with torch.no_grad():
        for i in tqdm(range(0, len(sequences_list), batch_size)):
            batch_sequences = sequences_list[i:i+batch_size]

            encoded_inputs = tokenizer(
                batch_sequences,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=512
            )

            encoded_inputs = {key: val.to(device) for key, val in encoded_inputs.items()}

            outputs = model(**encoded_inputs)

            batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            embeddings.append(batch_embeddings)

    return np.vstack(embeddings)

# --- Part 4: Run the function and store the embeddings ---
print("\nGenerating embeddings for all sequences...")
all_embeddings = get_embeddings_batched(df, model_for_embeddings, tokenizer_for_embeddings)
df['embeddings'] = list(all_embeddings)

print("\nEmbeddings generated and stored in the DataFrame.")
print("Shape of the embeddings array:", all_embeddings.shape)

# --- Step 2: Unsupervised Clustering (HDBSCAN) ---
print("\n--- Performing HDBSCAN clustering on unassigned sequences ---")
unassigned_df = df[df['Taxonomy'] == 'Unassigned'].copy()

# Extract the embeddings for only the unassigned sequences
unassigned_embeddings_array = np.vstack(unassigned_df['embeddings'].tolist())

# HDBSCAN parameters can be tuned. min_cluster_size=10 is a good start.
clusterer = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=5)
unassigned_df['novel_cluster'] = clusterer.fit_predict(unassigned_embeddings_array)

print(f"\nHDBSCAN found {unassigned_df['novel_cluster'].nunique()} novel clusters.")
print("The 'unassigned' sequences are now labeled with cluster IDs.")

# --- Step 3: Calculate Biodiversity Metrics ---
# Combine the known and novel taxa for final biodiversity analysis
final_df = df.copy()

# Map the novel cluster IDs back to the original DataFrame
final_df.loc[final_df['Taxonomy'] == 'Unassigned', 'Taxonomy'] = 'Novel_Taxa_' + unassigned_df['novel_cluster'].astype(str)

# Remove reference sequences
final_df = final_df[final_df['Taxonomic_Level'] != 'reference']

print("\n--- Calculating Biodiversity Metrics (Shannon & Simpson Indices) ---")

# Calculate abundance for each location
abundance_by_location = final_df.groupby(['Sample_ID', 'Taxonomy'])['Read_Count'].sum().unstack(fill_value=0)

# Shannon Index (measures richness and evenness)
def shannon_index(counts):
    p = counts / counts.sum()
    p = p[p > 0] # Remove zero probabilities
    return -np.sum(p * np.log(p))

# Simpson Index (measures dominance)
def simpson_index(counts):
    p = counts / counts.sum()
    return np.sum(p**2)

# Apply the functions to each location's abundance counts
diversity_metrics = pd.DataFrame({
    'Shannon_Index': abundance_by_location.apply(shannon_index, axis=1),
    'Simpson_Index': abundance_by_location.apply(simpson_index, axis=1),
})

print(diversity_metrics.to_markdown(numalign="left", stralign="left"))

print("\nAll technical steps complete. The data is now ready for dashboard visualization!")