In [1]:
from dotenv import load_dotenv
import torch
import sys
import os


# Load variables defined in a .env file into the process environment (python-dotenv).
load_dotenv()

# Show the CUDA_HOME environment variable, or 'Not set' when it is absent.

print(f"CUDA_HOME: {os.getenv('CUDA_HOME', 'Not set')}")


def check_cuda():
    """Print a short report about the PyTorch / CUDA environment.

    Always prints the PyTorch and Python versions and whether CUDA is
    available. When CUDA is available, additionally prints the name, total
    memory (bytes divided by 1024**3), compute capability of the current
    device, and the number of visible CUDA devices.

    Returns:
        None. All information is written to standard output.
    """
    print(f"PyTorch version: {torch.__version__}")
    print(f"Python version: {sys.version}")

    has_cuda = torch.cuda.is_available()
    print(f"\nCUDA available: {has_cuda}")

    # Guard clause: nothing more to report without a usable GPU.
    if not has_cuda:
        print("\nNo CUDA devices available")
        return

    # Describe the device PyTorch currently targets.
    device_index = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device_index)
    device_name = torch.cuda.get_device_name(device_index)

    print("\nCUDA Device Details:")
    print(f"  Device: {device_name}")
    print(f"  Total memory: {props.total_memory / 1024**3:.2f} GB")
    print(f"  CUDA capability: {props.major}.{props.minor}")
    print(f"  Number of CUDA devices: {torch.cuda.device_count()}")


check_cuda()

CUDA_HOME: /run/current-system/sw
PyTorch version: 2.6.0+cu124
Python version: 3.12.9 (main, Feb 12 2025, 14:50:50) [Clang 19.1.6 ]

CUDA available: True

CUDA Device Details:
  Device: NVIDIA GeForce RTX 4060 Ti
  Total memory: 15.60 GB
  CUDA capability: 8.9
  Number of CUDA devices: 1


## model loading


In [2]:
# The checkpoint can be downloaded from
# https://huggingface.co/lingxusb/megaDNA_updated/resolve/main/megaDNA_phage_145M.pt
model_path = "../checkpoints/megaDNA_phage_145M.pt"  # path to the downloaded checkpoint

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Set this to "cpu" or "cuda" explicitly to override the automatic choice.
device = "cuda" if torch.cuda.is_available() else "cpu"

# SECURITY: weights_only=False makes torch.load unpickle arbitrary Python objects.
# It is needed here because the file holds the whole model object (the result is
# used directly as the model), so only load checkpoints from a source you trust.
model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
# Switch to evaluation mode; as the last expression this also displays the model.
model.eval()

  from .autonotebook import tqdm as notebook_tqdm


MEGADNA(
  (start_tokens): ParameterList(
      (0): Parameter containing: [torch.float32 of size 512 (cuda:0)]
      (1): Parameter containing: [torch.float32 of size 256 (cuda:0)]
      (2): Parameter containing: [torch.float32 of size 196 (cuda:0)]
  )
  (token_embs): ModuleList(
    (0): Embedding(6, 196)
    (1): Sequential(
      (0): Embedding(6, 196)
      (1): Rearrange('... r d -> ... (r d)')
      (2): LayerNorm((3136,), eps=1e-05, elementwise_affine=True)
      (3): Linear(in_features=3136, out_features=256, bias=True)
      (4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (2): Sequential(
      (0): Embedding(6, 196)
      (1): Rearrange('... r d -> ... (r d)')
      (2): LayerNorm((200704,), eps=1e-05, elementwise_affine=True)
      (3): Linear(in_features=200704, out_features=512, bias=True)
      (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
  )
  (transformers): ModuleList(
    (0): Transformer(
      (layers): ModuleList(
       

## sequence generation


In [2]:
import numpy as np

# Token vocabulary: the list position of each symbol is its integer token id.
nucleotides = ["**", *"ATCG", "#"]


def token2nucleotide(s):
    """Return the vocabulary symbol stored at token id ``s``.

    ``s`` is used directly as an index into ``nucleotides``, so it may be any
    object usable as a list index.
    """
    symbol = nucleotides[s]
    return symbol


# Generation settings
PRIME_LENGTH = 4  # length of the random DNA primer handed to the model
num_seq = 2  # number of generation runs
# maximal length for the generated sequence (upper limit for the model is 131K)
context_length = 10_000

In [None]:
# Generate `num_seq` sequences. Each run starts from a fresh random primer and
# writes its result to generate_<run>.fna, one record per non-empty contig.
#
# The model loaded in the "model loading" cell is reused for every run. It is
# deliberately NOT reloaded per iteration nor deleted afterwards: later cells
# (mutagenesis, embedding and loss) still need the global `model`.
model.eval()  # Set the model to evaluation mode

for j in range(num_seq):
    # set the random DNA primer: PRIME_LENGTH tokens drawn from ids 1..4,
    # with a leading batch dimension of size 1
    primer_sequence = (
        torch.tensor(np.random.choice(np.arange(1, 5), PRIME_LENGTH)).long().to(device)[None,]
    )
    # .tolist() converts the tokens to plain ints in a single device transfer
    primer_DNA = "".join(map(token2nucleotide, primer_sequence[0].tolist()))
    print(f"Primer sequence: {primer_DNA}\n{'*' * 100}")

    # Generate a sequence using the model (no gradients are needed for sampling)
    with torch.no_grad():
        seq_tokenized = model.generate(
            primer_sequence, seq_len=context_length, temperature=0.95, filter_thres=0.0
        )
    generated_sequence = "".join(
        map(token2nucleotide, seq_tokenized.squeeze().cpu().int().tolist())
    )

    # Split the generated sequence into contigs at the '#' character
    contigs = generated_sequence.split("#")

    # Write the contigs to a .fna file, skipping empty ones
    output_file_path = f"generate_{1 + j}.fna"
    with open(output_file_path, "w") as file:
        for idx, contig in enumerate(contigs):
            if len(contig) > 0:
                file.write(f">contig_{idx}\n{contig}\n")

    # Release the per-run tensors (the model itself is kept) and return cached
    # GPU blocks to the driver
    del primer_sequence, seq_tokenized, generated_sequence
    torch.cuda.empty_cache()

## mutagenesis


In [None]:
from Bio import SeqIO
from BCBio import GFF
import random

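Please download the FASTA file (`NC_001416.1.fasta`) and the GFF3 gene annotation (`NC_001416.1.gff3`) for lambda phage from https://www.ncbi.nlm.nih.gov/nuccore/NC_001416.1 and place both files in the directory this notebook is run from.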


In [None]:
# Load every record of the FASTA file: ids and sequences are kept in parallel lists
fasta_file_path = "NC_001416.1.fasta"

with open(fasta_file_path, "r") as fasta_file:
    fasta_records = list(SeqIO.parse(fasta_file, "fasta"))

seq_ids = [record.id for record in fasta_records]
sequences = [str(record.seq) for record in fasta_records]

# Load the gene annotations, restricted to features of type CDS
gff_file_path = "NC_001416.1.gff3"
limit_info = {"gff_type": ["CDS"]}

# Parallel lists: entry i of each list describes the same CDS feature
start_position, end_position, strand_position = [], [], []

with open(gff_file_path) as in_handle:
    for rec in GFF.parse(in_handle, limit_info=limit_info):
        for feature in rec.features:
            location = feature.location
            start_position.append(location.start)
            end_position.append(location.end)
            strand_position.append(location.strand)


In [None]:
nt = ["**", "A", "T", "C", "G", "#"]  # Vocabulary
seq_id = 0  # Sequence ID


def encode_sequence(sequence, nt_vocab=nt):
    """Encode a DNA sequence to its numerical representation.

    Each character is replaced by its position in ``nt_vocab``. The lookup is
    tried on the character as given and, failing that, on its upper-case form,
    so lower-case bases (e.g. soft-masked FASTA) are encoded like their
    upper-case counterparts instead of all collapsing to one token. Characters
    that still cannot be found are encoded as 1.

    Args:
        sequence: iterable of single-character strings (typically a ``str``).
        nt_vocab: list of vocabulary symbols; the list index is the token id.

    Returns:
        A list of ints: 0, then one token id per character, then 5.
    """
    # Build symbol -> first index once, instead of scanning the list per base.
    lookup = {}
    for index, symbol in enumerate(nt_vocab):
        lookup.setdefault(symbol, index)

    codes = []
    for base in sequence:
        code = lookup.get(base)
        if code is None:
            # Case-insensitive retry; unknown symbols fall back to token 1.
            code = lookup.get(base.upper(), 1)
        codes.append(code)

    return [0] + codes + [5]


def get_loss_for_sequence(model, sequence, device):
    """Evaluate ``model`` on one encoded sequence and return its loss.

    Args:
        model: callable invoked as ``model(batch, return_value="loss")``.
        sequence: encoded sequence; it is converted with ``torch.tensor`` and
            given a leading batch dimension of size 1.
        device: device the input batch is moved to before the call.

    Returns:
        Whatever ``model`` returns for ``return_value="loss"``; the call is
        made with gradient tracking disabled.
    """
    batch = torch.tensor(sequence)[None].to(device)
    with torch.no_grad():
        return model(batch, return_value="loss")


# Get the model loss for the WT sequence
encoded_wt_sequence = encode_sequence(sequences[seq_id])
wt_loss = get_loss_for_sequence(model, encoded_wt_sequence, device)
print(wt_loss)

# Get the model loss for the mutants in the start codons.
# One mutant genome is scored per CDS; loss_start[i] belongs to the i-th CDS.
loss_start = []
random.seed(42)  # fixed seed so the random substitutions are reproducible
for start, end, strand in zip(start_position, end_position, strand_position):
    # Start from a fresh copy of the WT encoding instead of re-encoding the
    # whole genome for every CDS.
    encoded_mutant_sequence = list(encoded_wt_sequence)

    # Mutate start codon positions based on strand orientation.
    # Indices are shifted by +1 because encode_sequence prepends one token.
    # Assumes 0-based, end-exclusive feature coordinates, so the codon is the
    # first three bases on the + strand and the last three on the - strand
    # (TODO confirm against the annotation parser).
    positions = range(start + 1, start + 4) if strand == 1 else range(end - 2, end + 1)
    for i in positions:
        # NOTE: the draw ignores the current base, so a position may keep its
        # original nucleotide.
        encoded_mutant_sequence[i] = random.choice([1, 2, 3, 4])

    # Get model loss for mutated sequence
    mutant_loss = get_loss_for_sequence(model, encoded_mutant_sequence, device)
    loss_start.append(mutant_loss)

## embedding and loss


In [3]:
import numpy as np

In [5]:
# Show the per-level maximum lengths stored on the model
print(model.max_seq_len)
# max sequence length is the product of the levels: 128*64*16 = 131072

(128, 64, 16)


In [9]:
# a random input sequence
encoded_sequence = np.random.choice(np.arange(1, 5), 10000)
input_seq = torch.tensor(encoded_sequence).unsqueeze(0).to(device)
embeddings = model(input_seq, return_value="embedding")
# print(embeddings)

In [10]:
# Number of embedding tensors returned, followed by the shape of the first three
print(len(embeddings))
for level in range(3):
    print(embeddings[level].shape)

3
torch.Size([1, 11, 512])
torch.Size([10, 65, 256])
torch.Size([640, 17, 196])


In [11]:
# Model loss on the random input sequence built in the embedding cell above.
# NOTE: this call is not wrapped in torch.no_grad(), so the returned tensor
# carries a grad_fn (visible in the printed output).
loss = model(input_seq, return_value="loss")
print(loss)

tensor(1.4106, device='cuda:0', grad_fn=<NllLoss2DBackward0>)


In [None]:
# Stress test: embed a random sequence of 131072 tokens (= 128*64*16) to verify
# that the model accepts its maximum sequence length.

import gc

# Drop unreachable Python objects, then hand cached GPU blocks back to the driver
gc.collect()
torch.cuda.empty_cache()

# Random token ids in 1..4 with a leading batch dimension of size 1
encoded_sequence = np.random.choice(np.arange(1, 5), 131072)
input_seq = torch.tensor(encoded_sequence)[None].to(device)

with torch.no_grad():
    embeddings = model(input_seq, return_value="embedding")
    # Convert every returned tensor to a CPU NumPy array right away so the
    # GPU copies are no longer referenced
    embeddings = [level_embedding.cpu().numpy() for level_embedding in embeddings]

# Release the GPU memory that was just freed
torch.cuda.empty_cache()