In [1]:
from dotenv import load_dotenv
import torch
import sys
import os


load_dotenv()  # will load from .env file in the same directory

# Then add this to check the environment:

print(f"CUDA_HOME: {os.environ.get('CUDA_HOME', 'Not set')}")


def check_cuda():
    print(f"PyTorch version: {torch.__version__}")
    print(f"Python version: {sys.version}")

    # Check if CUDA is available
    cuda_available = torch.cuda.is_available()
    print(f"\nCUDA available: {cuda_available}")

    if cuda_available:
        # Get current CUDA device
        current_device = torch.cuda.current_device()
        # Get device properties
        device_props = torch.cuda.get_device_properties(current_device)

        print("\nCUDA Device Details:")
        print(f"  Device: {torch.cuda.get_device_name(current_device)}")
        print(f"  Total memory: {device_props.total_memory / 1024**3:.2f} GB")
        print(f"  CUDA capability: {device_props.major}.{device_props.minor}")
        print(f"  Number of CUDA devices: {torch.cuda.device_count()}")
    else:
        print("\nNo CUDA devices available")


check_cuda()

CUDA_HOME: /run/current-system/sw
PyTorch version: 2.6.0+cu124
Python version: 3.12.9 (main, Feb 12 2025, 14:50:50) [Clang 19.1.6 ]

CUDA available: True

CUDA Device Details:
  Device: NVIDIA GeForce RTX 4060 Ti
  Total memory: 15.60 GB
  CUDA capability: 8.9
  Number of CUDA devices: 1


In [2]:
# Load model
# model can be downloaded from https://huggingface.co/lingxusb/megaDNA_updated/resolve/main/megaDNA_phage_145M.pt
model_path = "../checkpoints/megaDNA_phage_145M.pt"  # model name
device = "cuda"  # change this to 'cuda' if you use GPU

model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
# model.eval()

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
from Bio import SeqIO
import numpy as np
import pandas as pd
import gc
import random

In [4]:
# Load dataset
fasta_file_path = "../dataset/1Jan2025_genomes.fa"
sequences = []
with open(fasta_file_path, "r") as fasta_file:
    for record in SeqIO.parse(fasta_file, "fasta"):
        sequences.append(str(record.seq))

In [21]:
print(len(sequences))
print(sequences[0])
print(len(sequences[0]))

32043
TTTGGTGGAGCTGGCGGGAGTTGAACCCGCGTCCGAAATTCCTACATACCATTTTTACCCTAACAAAAACATACATTTACTTTTTAAATCATTATGTTATCTATTTCTGTATCTGTCGGGTTTTATGGCGTTTTAAGGCTCTGCCGCCAATGTGCCGCCACTATTTTTTGGTTTTTATCCTTCCAGGATTTTGAACGTCTTGGATGCAAACTTCTAGCTTGTAAAAATCTGTCAGTTGATAAACGCTCACTATCTCCTCATCTAATTGGTCTAAGATATCCCATTTTTTACCTCTTCTTAGATATTTACTAAGTTTTTCTGTGCTTGATATTGATTGGCGTTCAGGAATGTATTCTATGGATTTAAGAACTTGCCATGATAAAGGAGGGGTGCATACAAACTGTTTGGCAAGATGTGGGCAATAATCATGGCGCTTAACATCGCTATGGTTAATTTGGGGAGGGAATACAACATTCAAGCCCTTGTAACCATAGTATCCACTATCGATTTTTGTTGTATGGTATGCAATGCCGATTATTTTATGTTCAGCTTGTCGACTTATCCACTGCATTAACAAATTAGGAATAACGTATTCTTGAATAAAAATGGCATCTTGTTGTTTGTTCAAATAATTGCAGGATAATATTAGAGGCCATAAAGCGACGAGTGATAATGCTAATTTGAAATTTTTCGGTTGATTTTTATTTCTAAAATTTGATGTGATATCAATCAGCGCCTCGATATTTAGATTCAAAATCATCTCTTCCGGATGAGTTTGCGACGTATAGAAAGCAGAAACAAAAAGTTTGTTAAAGTCAGGCCTATTCATCTCTCTCCAGCAAACGTATAAGGATGCGCCTAAATATAAGCATGGGAGTCCTGCTACTGAATATCTTTGATTTCGAACTAAATGACGTTGGCTGAAGGGGATGTGAAAAATCTCTTCTCTTTTTGTTAAGTAATTATCAGATTTTCTTACTCTGTATAAAGGGGT

In [25]:
seq_lengths = [len(seq) for seq in sequences]
summary = pd.Series(seq_lengths).describe()
summary

count     32043.000000
mean      60215.966607
std       55524.968903
min        1761.000000
25%       33532.500000
50%       44866.000000
75%       67557.500000
max      735411.000000
dtype: float64

In [9]:
# make a small subset of sequences with length less than 131072, randomly sample 1000
seq_subset = [seq for seq in sequences if len(seq) < 131072]
seq_subset = random.sample(seq_subset, 1000)

In [None]:
def get_embeddings(sequence, model, device, max_length=131072):
    """Get embeddings with proper length handling"""
    # Encode sequence
    nt_vocab = ["**", "A", "T", "C", "G", "#"]
    encoded = [0]  # Start token
    for nucleotide in sequence:
        if nucleotide in nt_vocab:
            encoded.append(nt_vocab.index(nucleotide))
        else:
            encoded.append(1)
    encoded.append(5)  # End token

    input_seq = torch.tensor(encoded).unsqueeze(0).to(device)

    with torch.no_grad():
        embeddings = model(input_seq, return_value="embedding")
        # Move to CPU and convert to numpy immediately
        embeddings = [e.cpu().numpy() for e in embeddings]

    # Clear GPU memory
    del input_seq
    torch.cuda.empty_cache()

    return embeddings


# Store embeddings in a list first to handle variable sequence lengths
local_embeddings_list = []
middle_embeddings_list = []
global_embeddings_list = []
sequence_lengths = []  # To keep track of sequence lengths

# Use your subset instead of full sequences
for i, seq in enumerate(seq_subset):
    try:
        # Get embeddings
        embeddings = get_embeddings(seq, model, device)

        # Store embeddings (taking mean across sequence length)
        local_embeddings_list.append(np.mean(embeddings[0].squeeze(), axis=0))
        middle_embeddings_list.append(np.mean(embeddings[1].squeeze(), axis=0))
        global_embeddings_list.append(np.mean(embeddings[2].squeeze(), axis=0))

        sequence_lengths.append(len(seq))

        if (i + 1) % 10 == 0:
            print(f"Processed {i + 1}/{len(seq_subset)} sequences")
            gc.collect()
            torch.cuda.empty_cache()

    except Exception as e:
        print(f"Error processing sequence {i}: {str(e)}")

# Convert lists to numpy arrays
local_embeddings = np.array(local_embeddings_list)
middle_embeddings = np.array(middle_embeddings_list)
global_embeddings = np.array(global_embeddings_list)

print("\nFinal shapes:")
print(f"Local embeddings: {local_embeddings.shape}")
print(f"Middle embeddings: {middle_embeddings.shape}")
print(f"Global embeddings: {global_embeddings.shape}")

# Save the embeddings
np.save("../results/local_embeddings_subset.npy", local_embeddings)
np.save("../results/middle_embeddings_subset.npy", middle_embeddings)
np.save("../results/global_embeddings_subset.npy", global_embeddings)
np.save("../results/sequence_lengths_subset.npy", np.array(sequence_lengths))

  self.gen = func(*args, **kwds)


Processed 10/1000 sequences
Processed 20/1000 sequences
Processed 30/1000 sequences
Processed 40/1000 sequences
Processed 50/1000 sequences
Processed 60/1000 sequences
Processed 70/1000 sequences
Processed 80/1000 sequences
Processed 90/1000 sequences
Processed 100/1000 sequences
Processed 110/1000 sequences
Processed 120/1000 sequences
Processed 130/1000 sequences
Processed 140/1000 sequences
Processed 150/1000 sequences
Processed 160/1000 sequences
Processed 170/1000 sequences
Processed 180/1000 sequences
Processed 190/1000 sequences
Processed 200/1000 sequences
Processed 210/1000 sequences
Processed 220/1000 sequences
Processed 230/1000 sequences
Processed 240/1000 sequences
Processed 250/1000 sequences
Processed 260/1000 sequences
Processed 270/1000 sequences
Processed 280/1000 sequences
Processed 290/1000 sequences
Processed 300/1000 sequences
Processed 310/1000 sequences
Processed 320/1000 sequences
Processed 330/1000 sequences
Processed 340/1000 sequences
Processed 350/1000 sequ