In [1]:
from borzoi_pytorch import Borzoi
import torch


# Load replicate 0 on the CPU purely to inspect the architecture.
device = "cpu"
model = Borzoi.from_pretrained('johahi/borzoi-replicate-0')
model = model.to(device)

# List the name of every top-level submodule, one per line.
print("Model Layer Names:")
for child_name, _child in model.named_children():
    print(" - " + child_name)

# The full module repr is very long; only show its first 500 characters.
print("\nFull Structure (Start):")
structure_text = str(model)
print(structure_text[:500])

  from .autonotebook import tqdm as notebook_tqdm


Model Layer Names:
 - conv_dna
 - _max_pool
 - res_tower
 - unet1
 - horizontal_conv0
 - horizontal_conv1
 - upsample
 - transformer
 - upsampling_unet1
 - separable1
 - upsampling_unet0
 - separable0
 - crop
 - final_joined_convs
 - human_head
 - final_softplus

Full Structure (Start):
Borzoi(
  (conv_dna): ConvDna(
    (conv_layer): Conv1d(4, 512, kernel_size=(15,), stride=(1,), padding=same)
    (max_pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (_max_pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (res_tower): Sequential(
    (0): ConvBlock(
      (norm): BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (activation): GELU(approximate='tanh')
      (conv_layer)


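Running on Google Colab: the cells below mount Google Drive, install the dependencies, download the hg38 reference genome, and extract Borzoi transformer embeddings around each gene's TSS.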

In [None]:
from google.colab import drive
import os

# Project folder on Google Drive; relative paths used later in the notebook
# (hg38.fa, common_genes.csv) resolve against it after the chdir below.
gdrive_path = '/content/gdrive/MyDrive/systems_genetics'

# force_remount=True asks Colab to remount even if Drive is already attached.
drive.mount('/content/gdrive', force_remount=True)
os.chdir(gdrive_path)

In [None]:
print("ðŸš€ Installing dependencies...")
!pip install -q borzoi-pytorch pyfaidx pybiomart

print("Genome Download: Starting...")
!wget -q -O hg38.fa.gz https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz
!gunzip -f hg38.fa.gz
print("Genome ready.")


import torch
import pandas as pd
import numpy as np
from borzoi_pytorch import Borzoi
from pyfaidx import Fasta
from tqdm import tqdm
from pybiomart import Server
import gc 


# --- run configuration -------------------------------------------------------
GENOME_PATH = "hg38.fa"  # the file produced by the gunzip step above
SEQ_LEN = 2 ** 19        # 524,288 bp input window per gene
# NOTE(review): BATCH_SIZE is not read anywhere in this cell; the loop always
# feeds a single sequence (unsqueeze(0)). Kept for reference only.
BATCH_SIZE = 1
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Running on: {DEVICE}")

# Load pretrained Borzoi replicate 0, move it to DEVICE and switch to eval mode.
print("Loading Borzoi Model...")
model = Borzoi.from_pretrained('johahi/borzoi-replicate-0')
model = model.to(DEVICE)
model.eval()

class StopExecution(Exception):
    """Raised by `transformer_hook` to abort the forward pass once the
    transformer output has been captured."""


# Latest transformer output (detached, on CPU); written by the hook below.
captured_embedding = None


def transformer_hook(module, inputs, output):
    """Forward hook on `model.transformer`.

    Stores a detached CPU copy of the transformer's output in the global
    `captured_embedding`, then raises StopExecution so the caller's forward
    pass stops here instead of running the remaining layers.
    """
    global captured_embedding
    captured_embedding = output.detach().to("cpu")
    raise StopExecution()


# Keep the handle so the hook can be detached later via handle.remove().
handle = model.transformer.register_forward_hook(transformer_hook)


print("Loading your gene list...")

gene_list_path = "common_genes.csv"
try:
    my_genes = pd.read_csv(gene_list_path, header=None)[0].tolist()
    print(f"Loaded {len(my_genes)} genes from file.")
except FileNotFoundError:
    # Name the file that was actually looked for.
    print(f"Error: Could not find '{gene_list_path}'. Did you upload it?")
    raise

print("Fetching gene locations from BioMart...")
server = Server(host='http://www.ensembl.org')
dataset = server.marts['ENSEMBL_MART_ENSEMBL'].datasets['hsapiens_gene_ensembl']

# query locations. 'transcription_start_site' is a transcript-level BioMart
# attribute, so the result has one row per transcript; 'strand' is fetched so a
# single TSS per gene can be chosen below.
gene_locs = dataset.query(
    attributes=['ensembl_gene_id', 'chromosome_name', 'transcription_start_site', 'strand']
)
gene_locs.columns = ['ensembl_id', 'chrom', 'tss', 'strand']

# keep only genes from HPO dataset (common_genes.csv)
gene_locs = gene_locs[gene_locs['ensembl_id'].isin(my_genes)]

# filter standard chromosomes; cast to str first because pandas may parse the
# chromosome column as a mix of ints ("1") and strings ("X")
valid_chroms = [str(i) for i in range(1, 23)] + ['X', 'Y']
gene_locs = gene_locs.assign(chrom=gene_locs['chrom'].astype(str))
gene_locs = gene_locs[gene_locs['chrom'].isin(valid_chroms)]

# Collapse to one row per gene, keeping its most 5' TSS: tss * strand is
# smallest for the lowest coordinate on the + strand (strand == 1) and for the
# highest coordinate on the - strand (strand == -1). Without this, every
# transcript would yield its own duplicate-ID embedding row.
gene_locs = (
    gene_locs
    .assign(_upstream_key=gene_locs['tss'] * gene_locs['strand'])
    .sort_values(['ensembl_id', '_upstream_key'])
    .drop_duplicates(subset='ensembl_id', keep='first')
    .drop(columns='_upstream_key')
    .reset_index(drop=True)
)

print(f"Final Processing List: {len(gene_locs)} genes.")



# inference loop

# Reference genome, opened once; get_one_hot reads windows from it via fasta[...].
fasta = Fasta(GENOME_PATH)
# Filled in lockstep: embeddings[k] belongs to valid_gene_ids[k].
embeddings, valid_gene_ids = [], []

def get_one_hot(chrom, tss):
    """One-hot encode the SEQ_LEN-bp window of `fasta` centred on `tss`.

    Args:
        chrom: chromosome name without the "chr" prefix (e.g. "1", "X").
        tss: coordinate used as the window centre; cast to int.

    Returns:
        float32 array of shape (4, SEQ_LEN) whose rows are A, C, G, T, or
        None if the chromosome is not in the FASTA or the window runs off
        either end of it. Bases other than A/C/G/T (e.g. N) become
        all-zero columns; lowercase (soft-masked) bases are upper-cased first.
    """
    # NOTE(review): if `tss` is 1-based (as Ensembl coordinates usually are) while
    # the slice below is 0-based, the window centre is offset by one base — confirm.
    chrom_str = f"chr{chrom}"
    start = int(tss) - (SEQ_LEN // 2)
    end = start + SEQ_LEN
    if chrom_str not in fasta.keys():
        return None
    if start < 0 or end > len(fasta[chrom_str]):
        return None
    seq = fasta[chrom_str][start:end].seq.upper()
    if len(seq) != SEQ_LEN:
        return None

    # Vectorised encoding A:0, C:1, G:2, T:3 — compare the byte view of the
    # sequence against each base instead of looping over ~500k characters in
    # Python. errors="replace" keeps one byte per character, so any non-ASCII
    # character maps to '?' and stays an all-zero column.
    codes = np.frombuffer(seq.encode("ascii", errors="replace"), dtype=np.uint8)
    arr = np.zeros((4, SEQ_LEN), dtype=np.float32)
    for row, base in enumerate(b"ACGT"):
        arr[row, codes == base] = 1.0
    return arr

print("Starting Inference...")

# save checkpoint every 5000 genes
checkpoint_interval = 5000
# Release cached GPU memory and run the garbage collector every N loop iterations.
cleanup_interval = 100
# Checkpoints and the final matrix are written to the project folder on Google Drive.
output_dir = "/content/gdrive/MyDrive/systems_genetics"
# fp16 autocast is only enabled on GPU; on CPU it stays off.
use_autocast = DEVICE == "cuda"
failed_gene_ids = []

for step, (_, row) in enumerate(tqdm(gene_locs.iterrows(), total=len(gene_locs))):
    # Use the loop counter rather than the DataFrame index label, which need not
    # be contiguous (e.g. after boolean filtering).
    if step % cleanup_interval == 0:
        torch.cuda.empty_cache()
        gc.collect()

    gene_id = row['ensembl_id']
    try:
        x_arr = get_one_hot(row['chrom'], row['tss'])
        if x_arr is None:
            continue  # unknown chromosome or window runs off the chromosome

        x_tensor = torch.from_numpy(x_arr).unsqueeze(0).to(DEVICE)

        # Reset before the forward pass so a pass that never reaches the hook
        # cannot reuse the previous gene's embedding.
        captured_embedding = None
        # Mixed precision (autocast) to save memory. The forward hook raises
        # StopExecution right after the transformer, so later layers never run.
        with torch.no_grad(), torch.autocast(device_type=DEVICE, enabled=use_autocast):
            try:
                model(x_tensor)
            except StopExecution:
                pass

        if captured_embedding is None:
            continue

        # Average the 3 bins around the centre of the window (i.e. around the TSS).
        # NOTE(review): assumes captured_embedding is (batch, length, channels) — confirm.
        mid = captured_embedding.shape[1] // 2
        emb = captured_embedding[:, mid-1:mid+2, :].float().mean(dim=1)
        emb_vec = emb.squeeze(0).numpy()
        captured_embedding = None
    except Exception as e:
        # Record and report the failure instead of silently dropping the gene.
        failed_gene_ids.append(gene_id)
        tqdm.write(f"Skipping {gene_id}: {type(e).__name__}: {e}")
        torch.cuda.empty_cache()
        continue

    embeddings.append(emb_vec)
    valid_gene_ids.append(gene_id)

    # Checked only right after an append, so each checkpoint is written once.
    # Saving is best-effort: a failed write must not abort the whole run.
    if len(embeddings) % checkpoint_interval == 0:
        checkpoint_path = f"{output_dir}/borzoi_checkpoint_{len(embeddings)}.csv"
        try:
            pd.DataFrame(np.vstack(embeddings), index=valid_gene_ids).to_csv(checkpoint_path)
            tqdm.write(f"Saved checkpoint at {len(embeddings)} genes.")
        except OSError as e:
            tqdm.write(f"Warning: could not save checkpoint {checkpoint_path}: {e}")


print(f"Saving {len(embeddings)} embeddings...")
if failed_gene_ids:
    print(f"{len(failed_gene_ids)} genes failed during inference (see messages above).")
if not embeddings:
    raise RuntimeError("No embeddings were produced; nothing to save.")
X_borzoi = pd.DataFrame(np.vstack(embeddings), index=valid_gene_ids)
X_borzoi.to_csv(f"{output_dir}/borzoi_embeddings.csv")