In [26]:
import scanpy as sc
from omnicell.data.loader import DataLoader, DatasetDetails
import torch 
from transformers import AutoTokenizer, AutoModel
import transformers

print(torch.cuda.is_available())


True


In [27]:
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
model = AutoModel.from_pretrained("dmis-lab/biobert-v1.1").to("cuda")

In [31]:
from omnicell.data.catalogue import Catalogue, DatasetDetails
import numpy as np
catalogue = Catalogue("./configs/catalogue")

#Getting the dataset details from the data_catalogue.json

ds_details = catalogue.get_dataset_details("satija_IFNB_raw")
pert_key = ds_details.pert_key
control_pert = ds_details.control


adata = sc.read(ds_details.path, backed ="r+")

gene_names = adata.var["gene"]


gene_names_idx = gene_names.index.to_numpy().astype(np.int32) - 1

gene_names = list(gene_names)
tokenizer.pad_token = tokenizer.eos_token

gene_names_idx

tokenizer.pad_token = "[PAD]"

In [32]:
batch_size = 32 
tokenized_batches = []
for i in range(0, len(gene_names), batch_size):
    indices = gene_names_idx[i:i+batch_size]
    genes = gene_names[i:i+batch_size]
    tokenized_gene_names_batch = tokenizer(genes, return_tensors="pt", padding=True, truncation=True, max_length=30)
    tokenized_gene_names_batch["idx"] = indices
    tokenized_batches.append(tokenized_gene_names_batch)


In [33]:
import gc

#Clearing cuda cache and gc to free up memory
torch.cuda.empty_cache()
gc.collect()

22

In [38]:
gene_representations = torch.zeros((len(gene_names), 768))

for batch in tokenized_batches:
    # Move batch to GPU
    batch_gpu = {k: v.to("cuda") for k, v in batch.items() if k != "idx"}
    
    # Forward pass
    with torch.no_grad():  # Prevent gradient computation if not needed
        outputs = model(**batch_gpu)
    
    # Move results back to CPU immediately
    mean_logits = torch.mean(outputs.last_hidden_state, dim=1).cpu()
    
    # Store results
    gene_representations[batch["idx"]] = mean_logits
    
    # Clear GPU cache periodically
    if batch["idx"][0] % 100 == 0:  # Every 100 batches
        torch.cuda.empty_cache()
        
    # Delete unnecessary tensors
    del outputs, batch_gpu



In [41]:
torch.save({"repr" : gene_representations, "gene_names" : gene_names} , "notebooks/gene_embeddings/gene_representations_bioBERT.pt")

In [40]:
gene_representations.shape

gene_representations

tensor([[ 0.2309, -0.1624, -0.1522,  ...,  0.3417, -0.0391, -0.2453],
        [ 0.0601, -0.1340, -0.0610,  ...,  0.1344, -0.0038, -0.1334],
        [ 0.2713, -0.0962, -0.2993,  ..., -0.0051,  0.2950, -0.0486],
        ...,
        [ 0.1667, -0.2254, -0.1171,  ...,  0.1285,  0.0304, -0.1970],
        [ 0.1888, -0.3543, -0.0702,  ...,  0.2456, -0.0742, -0.2947],
        [-0.0017, -0.2802, -0.1396,  ..., -0.0196, -0.0011, -0.3765]])

In [11]:
import torch
import umap
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns

# Convert to numpy and standardize
data = gene_representations.cpu().numpy()
scaler = StandardScaler()
data_scaled = scaler.fit_transform(data)

# PCA first - reduce to 100 dimensions
print("Running PCA...")
pca = PCA(n_components=100)
data_pca = pca.fit_transform(data_scaled)

# Calculate variance explained
var_explained = pca.explained_variance_ratio_.cumsum()
print(f"Variance explained by 100 PCs: {var_explained[-1]:.3f}")

# UMAP on PCA results
print("Running UMAP...")
reducer = umap.UMAP(
    n_neighbors=30,
    min_dist=0.1,
    random_state=42
)
embedding = reducer.fit_transform(data_pca)

# Plot
plt.figure(figsize=(12, 10))
plt.scatter(
    embedding[:, 0],
    embedding[:, 1],
    s=1,
    alpha=0.5,
    c='blue'
)

plt.title('UMAP visualization of gene representations (PCA -> UMAP)')
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
sns.despine()
plt.tight_layout()
plt.show()

: 

: 

: 

In [None]:
model.