In [1]:
from dotenv import load_dotenv
import torch
import sys
import os


# Load variables from a .env file into the process environment before anything
# below reads os.environ (load_dotenv() with no arguments uses python-dotenv's
# default .env lookup).
load_dotenv()

# Show which CUDA toolkit location, if any, this kernel sees.
cuda_home = os.environ.get("CUDA_HOME", "Not set")
print(f"CUDA_HOME: {cuda_home}")


def check_cuda():
    """Print a short report of the PyTorch / Python / CUDA runtime.

    Always prints the PyTorch version, the Python version, and whether
    torch reports CUDA as available. When it is available, also prints the
    current device's name, its total memory (bytes / 1024**3, labelled "GB"),
    its compute capability, and the number of CUDA devices torch can see.

    Returns:
        None. The report is written to stdout.
    """
    print(f"PyTorch version: {torch.__version__}")
    print(f"Python version: {sys.version}")

    has_cuda = torch.cuda.is_available()
    print(f"\nCUDA available: {has_cuda}")

    # Guard clause: without CUDA there are no device details to report.
    if not has_cuda:
        print("\nNo CUDA devices available")
        return

    device_index = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device_index)
    total_memory_gib = props.total_memory / 1024**3

    print("\nCUDA Device Details:")
    print(f"  Device: {torch.cuda.get_device_name(device_index)}")
    print(f"  Total memory: {total_memory_gib:.2f} GB")
    print(f"  CUDA capability: {props.major}.{props.minor}")
    print(f"  Number of CUDA devices: {torch.cuda.device_count()}")


# Run the report once so the runtime this notebook was executed on shows up in the cell output.
check_cuda()

CUDA_HOME: /run/current-system/sw
PyTorch version: 2.6.0+cu124
Python version: 3.12.9 (main, Feb 12 2025, 14:50:50) [Clang 19.1.6 ]

CUDA available: True

CUDA Device Details:
  Device: NVIDIA GeForce RTX 4060 Ti
  Total memory: 15.60 GB
  CUDA capability: 8.9
  Number of CUDA devices: 1


In [2]:
# Load model
# The checkpoint can be downloaded from
# https://huggingface.co/lingxusb/megaDNA_updated/resolve/main/megaDNA_phage_145M.pt
model_path = "../checkpoints/megaDNA_phage_145M.pt"  # path to the checkpoint file

# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# SECURITY: weights_only=False unpickles arbitrary Python objects, which can run
# code embedded in the file. Only load checkpoints obtained from a trusted source.
model = torch.load(model_path, map_location=torch.device(device), weights_only=False)

# The model is only used for inference here: eval mode switches off dropout and
# other train-time behaviour so repeated runs produce the same embeddings.
# The trailing semicolon keeps the (large) module repr out of the cell output.
model.eval();

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from Bio import SeqIO
import numpy as np
import pandas as pd
import gc
import random

In [6]:
# Parse the full genome collection from FASTA and hold every record in memory.
fasta_file_path = "../dataset/1Jan2025_genomes.fa"
record_iter = SeqIO.parse(fasta_file_path, "fasta")
sequences = list(record_iter)

In [7]:
# Display the first parsed record as a quick check that the FASTA file loaded.
sequences[0]

SeqRecord(seq=Seq('TTTGGTGGAGCTGGCGGGAGTTGAACCCGCGTCCGAAATTCCTACATACCATTT...CAT'), id='AY319521', name='AY319521', description='AY319521 Salmonella phage SopEPhi, complete sequence.', dbxrefs=[])

In [8]:
# Summary statistics of genome length across the full dataset.
# len(record.seq) reads the length directly; going through str(record.seq) would
# build a full string copy of every genome just to count its characters.
seq_lengths = [len(record.seq) for record in sequences]
summary = pd.Series(seq_lengths).describe()
summary

count     32043.000000
mean      60215.966607
std       55524.968903
min        1761.000000
25%       33532.500000
50%       44866.000000
75%       67557.500000
max      735411.000000
dtype: float64

In [11]:
# Draw a reproducible random subset of genomes shorter than the model's length limit.
MAX_SEQ_LENGTH = 131072  # same value get_embeddings() uses as its default max_seq_length
SUBSET_SIZE = 1000
SUBSET_SEED = 42  # fixed so Restart & Run All selects the same genomes every time

short_sequences = [record for record in sequences if len(record.seq) < MAX_SEQ_LENGTH]

# A dedicated Random instance keeps the seed from affecting other users of the
# global `random` state. random.sample raises ValueError when asked for more
# items than exist, so the request is capped at the number of candidates.
subset_rng = random.Random(SUBSET_SEED)
seq_subset = subset_rng.sample(short_sequences, min(SUBSET_SIZE, len(short_sequences)))

In [13]:
# Persist the sampled genomes as FASTA: one ">description" header line per record,
# followed by the entire sequence on a single line.
subset_fasta_path = "../dataset/1Jan2025_genomes_subset.fa"
with open(subset_fasta_path, "w") as fasta_out:
    fasta_out.writelines(
        f">{record.description}\n{str(record.seq)}\n" for record in seq_subset
    )

In [4]:
# Read the previously saved subset back from disk so later cells can start from
# the file instead of re-sampling the full dataset.
subset_records = SeqIO.parse("../dataset/1Jan2025_genomes_subset.fa", "fasta")
seq_subset = list(subset_records)

In [18]:
def get_embeddings(sequence, model, device, max_seq_length=131072):
    """Embed one DNA sequence as a single fixed-length vector.

    The sequence is upper-cased, mapped to token ids, wrapped in a start and an
    end token, and passed to ``model(ids, return_value="embedding")``. Each of
    the first three tensors the model returns is averaged over every axis
    except its last (feature) axis, and the three averages are concatenated.

    Args:
        sequence: DNA sequence; anything supporting ``len()``, slicing and
            ``str()`` (a ``str`` or a Biopython ``Seq``, for example).
        model: MEGADNA model, called as ``model(ids, return_value="embedding")``
            and expected to return at least three tensors.
        device: torch device (or device string) the token ids are moved to.
        max_seq_length: Maximum number of tokens sent to the model, counting the
            start and end tokens (default: 131072 from model dimensions
            128*64*16). Longer sequences are truncated and a warning is printed.
            NOTE(review): this assumes the model counts the two special tokens
            against its length limit — confirm against the model's forward().

    Returns:
        combined_emb: 1-D tensor holding the three mean embeddings back to back
        (964 values for the 512 + 256 + 196 feature widths noted below). It is
        on whatever device the model's outputs are on.
    """
    # Reserve two positions for the start/end tokens appended below so the
    # encoded input as a whole stays within max_seq_length.
    max_nucleotides = max(max_seq_length - 2, 0)
    if len(sequence) > max_nucleotides:
        print(f"Warning: Sequence length {len(sequence)} exceeds max length {max_nucleotides}")
        sequence = sequence[:max_nucleotides]

    # Token ids are the positions in nt_vocab; a dict gives O(1) lookup per base
    # instead of a list scan.
    nt_vocab = ["**", "A", "T", "C", "G", "#"]
    token_of = {symbol: index for index, symbol in enumerate(nt_vocab)}
    start_token = token_of["**"]
    end_token = token_of["#"]
    unknown_token = token_of["A"]  # symbols outside the vocabulary (e.g. N) are encoded as "A"

    # Upper-case first: soft-masked FASTA uses lower-case a/c/g/t, which would
    # otherwise all miss the vocabulary and be encoded as "A".
    encoded = [start_token]
    encoded.extend(token_of.get(nucleotide, unknown_token) for nucleotide in str(sequence).upper())
    encoded.append(end_token)

    input_seq = torch.tensor(encoded).unsqueeze(0).to(device)

    with torch.no_grad():
        embeddings = model(input_seq, return_value="embedding")
        # Flatten every leading axis and average over it, keeping the feature
        # axis. A bare .squeeze() drops *any* size-1 axis, so an input that
        # yields a single patch would lose an axis and the following mean would
        # collapse the features instead of the positions.
        local_emb, middle_emb, global_emb = (
            level.reshape(-1, level.shape[-1]).mean(dim=0)
            for level in (embeddings[0], embeddings[1], embeddings[2])
        )  # (512,), (256,), (196,) for this checkpoint
        combined_emb = torch.cat([local_emb, middle_emb, global_emb])  # (964,)

    return combined_emb


def process_sequences(sequences, model, device, batch_size=50, max_seq_length=131072):
    """Embed every record and collect the results in one tensor.

    Records are embedded one at a time with get_embeddings(); ``batch_size``
    only controls how often a progress line is printed. If embedding a record
    raises, the error is printed, that record's row is left as zeros, and
    processing continues. A summary of all failed records is printed at the end
    so zero rows are not mistaken for real embeddings.

    Args:
        sequences: Sized iterable of records exposing ``.seq`` and ``.id``
            (Biopython ``SeqRecord`` objects, for example).
        model: MEGADNA model, passed through to get_embeddings().
        device: torch device (CPU/GPU) for the inputs and the result tensor.
        batch_size: Number of records between progress messages; must be >= 1.
        max_seq_length: Maximum sequence length, passed through to
            get_embeddings() (default: 131072 from model dimensions 128*64*16).

    Returns:
        all_embeddings: Tensor of shape (n_sequences, 964) on ``device``. Rows
        belonging to records that failed are all zeros.

    Raises:
        ValueError: If ``batch_size`` is less than 1. Previously a negative
            value skipped all work and silently returned an all-zero tensor.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    total = len(sequences)
    all_embeddings = torch.zeros((total, 964), device=device)
    failed_ids = []

    for index, record in enumerate(sequences):
        try:
            all_embeddings[index] = get_embeddings(record.seq, model, device, max_seq_length)
        except Exception as e:
            # Deliberately best-effort: one bad record must not abort a long run.
            print(f"Error processing sequence {record.id}: {e}")
            failed_ids.append(record.id)

        # Report at the end of every full batch and once more after the last record.
        done = index + 1
        if done % batch_size == 0 or done == total:
            print(f"Processed {done}/{total} sequences")

    if failed_ids:
        print(f"Failed to embed {len(failed_ids)}/{total} sequences (rows left as zeros): {failed_ids}")

    return all_embeddings

Processed 50/1000 sequences
Processed 100/1000 sequences
Processed 150/1000 sequences
Processed 200/1000 sequences
Processed 250/1000 sequences
Processed 300/1000 sequences
Processed 350/1000 sequences
Processed 400/1000 sequences
Processed 450/1000 sequences
Processed 500/1000 sequences
Processed 550/1000 sequences
Processed 600/1000 sequences
Processed 650/1000 sequences
Processed 700/1000 sequences
Processed 750/1000 sequences
Processed 800/1000 sequences
Processed 850/1000 sequences
Processed 900/1000 sequences
Processed 950/1000 sequences
Processed 1000/1000 sequences


ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (1000,) + inhomogeneous part.