In [2]:
import torch
import sys


def check_cuda():
    print(f"PyTorch version: {torch.__version__}")
    print(f"Python version: {sys.version}")

    # Check if CUDA is available
    cuda_available = torch.cuda.is_available()
    print(f"\nCUDA available: {cuda_available}")

    if cuda_available:
        # Get current CUDA device
        current_device = torch.cuda.current_device()
        # Get device properties
        device_props = torch.cuda.get_device_properties(current_device)

        print("\nCUDA Device Details:")
        print(f"  Device: {torch.cuda.get_device_name(current_device)}")
        print(f"  Total memory: {device_props.total_memory / 1024**3:.2f} GB")
        print(f"  CUDA capability: {device_props.major}.{device_props.minor}")
        print(f"  Number of CUDA devices: {torch.cuda.device_count()}")
    else:
        print("\nNo CUDA devices available")


check_cuda()

PyTorch version: 2.6.0+cu126
Python version: 3.12.9 (main, Feb 12 2025, 14:50:50) [Clang 19.1.6 ]

CUDA available: False

No CUDA devices available


In [3]:
import torch
import numpy as np

nucleotides = ["**", "A", "T", "C", "G", "#"]  # vocabulary


def token2nucleotide(s):
    return nucleotides[s]


PRIME_LENGTH = 4  # give the model a random DNA primer to start
num_seq = 2  # number of runs
context_length = (
    10000  # maximal length for the generated sequence (upper limit for the model is 131K)
)

# model can be downloaded from https://huggingface.co/lingxusb/megaDNA_updated/resolve/main/megaDNA_phage_145M.pt
model_path = "../checkpoints/megaDNA_phage_145M.pt"  # model name
device = "cuda"  # change this to 'cuda' if you use GPU

In [4]:
for j in range(num_seq):
    # Load the pre-trained model
    model = torch.load(model_path, map_location=torch.device(device))
    model.eval()  # Set the model to evaluation mode

    # set the random DNA primer
    primer_sequence = (
        torch.tensor(np.random.choice(np.arange(1, 5), PRIME_LENGTH)).long().to(device)[None,]
    )
    primer_DNA = "".join(map(token2nucleotide, primer_sequence[0]))
    print(f"Primer sequence: {primer_DNA}\n{'*' * 100}")

    # Generate a sequence using the model
    seq_tokenized = model.generate(
        primer_sequence, seq_len=context_length, temperature=0.95, filter_thres=0.0
    )
    generated_sequence = "".join(map(token2nucleotide, seq_tokenized.squeeze().cpu().int()))

    # Split the generated sequence into contigs at the '#' character
    contigs = generated_sequence.split("#")

    # Write the contigs to a .fna file
    output_file_path = f"generate_{1 + j}.fna"
    with open(output_file_path, "w") as file:
        for idx, contig in enumerate(contigs):
            if len(contig) > 0:
                file.write(f">contig_{idx}\n{contig}\n")

    # Clean up to free memory
    del model, primer_sequence, generated_sequence
    torch.cuda.empty_cache()

UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL megaDNA.megadna.MEGADNA was not an allowed global by default. Please use `torch.serialization.add_safe_globals([MEGADNA])` or the `torch.serialization.safe_globals([MEGADNA])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.