In [4]:
import pandas as pd
from dataclasses import dataclass
import torch
import os

@dataclass
class ModelConfig:
    """Describe the ESM model/layer whose saved representations this notebook ingests.

    `name` and `embedding_layer` together form the name of the directory the
    representation batches are read from and of the zarr store written to.
    `embed_dim` and `truncation_seq_length` size the output array, and
    `truncation_seq_length` also caps how many positions are copied per sequence.
    """

    name: str  # model identifier; prefix of the input directory and output zarr store names
    embedding_layer: int  # presumably the layer index the representations were taken from; part of the directory/store names
    embed_dim: int  # length of each per-position vector (last axis of the output array)
    tokens_per_batch: int  # not read in the visible cells; presumably the extraction batch size in tokens — TODO confirm
    truncation_seq_length: int  # max sequence length kept; the output array holds this + 1 positions per gene

# --- Paths -------------------------------------------------------------------
# Defaults are the original EC2 locations; export GENE2PROTEIN_PATH /
# PROTEIN_EMBED_PATH to point the notebook elsewhere without editing it.
GENE2PROTEIN_PATH = os.environ.get(
    "GENE2PROTEIN_PATH", "/home/ec2-user/cytoself-data/sequences.csv"
)
PROTEIN_EMBED_PATH = os.environ.get(
    "PROTEIN_EMBED_PATH", "/home/ec2-user/cytoself-data/"
)

# Alternative dataset:
# GENE2PROTEIN_PATH = "/home/ec2-user/esm-data/protein_loc.csv"
# PROTEIN_EMBED_PATH = "/home/ec2-user/esm-data/"

# --- Model -------------------------------------------------------------------
# Alternative model configuration:
# model_config = ModelConfig(
#     name="esm2_t33_650M_UR50D",
#     embedding_layer=32, # 33
#     embed_dim=1280,
#     tokens_per_batch=1024,
#     truncation_seq_length=1024
# )

model_config = ModelConfig(
    name="esm2_t36_3B_UR50D",
    embedding_layer=34,
    embed_dim=2560,
    tokens_per_batch=1024,
    truncation_seq_length=1024,
)

# --- Inputs ------------------------------------------------------------------
# read_csv without index_col yields a default RangeIndex (0..num_genes-1); the
# copy loop maps each saved label to an output row through `gene_to_protein.index`.
gene_to_protein = pd.read_csv(GENE2PROTEIN_PATH)
num_genes = len(gene_to_protein)

# Directory holding the saved representation batches for this model/layer.
output_dir = os.path.join(
    PROTEIN_EMBED_PATH, f'{model_config.name}_{model_config.embedding_layer}'
)
# Keep only regular, non-hidden files. A sub-directory (e.g. `.ipynb_checkpoints`)
# or a hidden file here would otherwise be handed to torch.load and abort the loop.
files = sorted(
    os.path.join(output_dir, entry)
    for entry in os.listdir(output_dir)
    if not entry.startswith(".") and os.path.isfile(os.path.join(output_dir, entry))
)

In [5]:
import zarr


# Pre-allocate the on-disk output array:
#   axis 0 -> one row per row of `gene_to_protein`
#   axis 1 -> truncation_seq_length + 1 positions (the copy loop writes at most
#             `truncate_len + 1` positions; presumably BOS + residues — confirm)
#   axis 2 -> embed_dim features
# mode="w" recreates the store on every run, so re-running the notebook starts
# from a fresh array. chunks=(1, None, None) makes each gene its own chunk
# (None = full extent of that axis).
zarr_store_path = os.path.join(
    PROTEIN_EMBED_PATH,
    f'{model_config.name}_{model_config.embedding_layer}.zarr',
)
z_embedding_prot = zarr.open(
    zarr_store_path,
    mode="w",
    shape=(num_genes, model_config.truncation_seq_length + 1, model_config.embed_dim),
    chunks=(1, None, None),
    dtype="float32",
)

In [6]:
# Copy every saved batch of per-position representations into its gene's row of
# `z_embedding_prot`. Each file is unpacked as a `(labels, strs, representations)`
# tuple, as produced by the (not shown) extraction step.
written_rows = set()
for file in files:
    # map_location="cpu": batches saved from GPU tensors would otherwise need that
    # device here. Everything is converted to NumPy anyway.
    labels, strs, representations = torch.load(file, map_location="cpu")
    for i, label in enumerate(labels):
        # NOTE(review): `gene_to_protein` is read without index_col, so its index is a
        # default RangeIndex; this lookup only succeeds for integer row labels.
        index = gene_to_protein.index.get_loc(label)
        # Never copy more than the array's position axis (truncation_seq_length + 1) holds.
        truncate_len = min(model_config.truncation_seq_length, len(strs[i]))
        # .float() converts half/bfloat16 to float32 (NumPy has no bfloat16, so
        # .numpy() would fail on it); it is a no-op for float32 tensors.
        z_embedding_prot[index, : truncate_len + 1] = (
            representations[i, : truncate_len + 1].detach().cpu().float().numpy()
        )
        written_rows.add(index)

# Rows that no batch covered keep the array's fill value and would silently look
# like real data downstream, so report them.
n_missing = num_genes - len(written_rows)
if n_missing:
    print(f"warning: {n_missing} of {num_genes} rows in z_embedding_prot were never written")

: 