# Testing STATE and CAPTAIN Embeddings

This notebook shows how to:
1. Load Perturb-CITE-seq data
2. Get STATE embeddings from RNA data
3. Get CAPTAIN predictions from RNA embeddings
4. Verify all formats are correct

## STATE Requirements

Based on official [STATE documentation](https://pypi.org/project/arc-state/):

**Command**: `state emb transform --input <input.h5ad> --output <output.h5ad>`

**Input AnnData Requirements**:
- ✓ CSR sparse matrix format for `.X`
- ✓ `gene_name` column in `.var` DataFrame

**Installation**:
```bash
uv tool install arc-state
```

In [1]:
import torch
import pandas as pd
import numpy as np
import anndata as ad
import subprocess
import os
import tempfile
from scipy.spatial.distance import cosine

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.9.0+cu128
CUDA available: True
GPU: NVIDIA H100 80GB HBM3


## 1. Load Perturb-CITE-seq Data

In [2]:
# Define paths
data_dir = "/home/nebius/cellian/data/perturb-cite-seq/SCP1064"
meta_path = f"{data_dir}/metadata/RNA_metadata.csv"
rna_csv = f"{data_dir}/expression/RNA_expression_subset1a.csv"
#rna_parquet = f"{data_dir}/other/RNA_expression.parquet"
protein_csv = f"{data_dir}/expression/Protein_expression.csv"

# Load metadata (small, always use CSV)
meta_df = pd.read_csv(meta_path, index_col=0)
rna_df = pd.read_csv(rna_csv, index_col=0)
protein_df = pd.read_csv(protein_csv, index_col=0)
# # Load RNA data - prefer Parquet (10-50x faster!)
# import time
# if os.path.exists(rna_parquet):
#     print("Loading RNA from Parquet (fast)...")
#     start = time.time()
#     rna_df = pd.read_parquet(rna_parquet)
#     print(f"✓ Loaded in {time.time() - start:.2f} seconds")
# else:
#     print("⚠ Parquet file not found, loading from CSV (slow)...")
#     print("  Consider running: python convert_to_parquet.py")
#     start = time.time()
#     rna_df = pd.read_csv(rna_csv, index_col=0)
#     print(f"✓ Loaded in {time.time() - start:.1f} seconds")

# # Load protein data - prefer Parquet
# if os.path.exists(protein_parquet):
#     print("Loading Protein from Parquet (fast)...")
#     protein_df = pd.read_parquet(protein_parquet)
# else:
#     print("Loading Protein from CSV...")
#     protein_df = pd.read_csv(protein_csv, index_col=0)

print(f"\nMetadata shape: {meta_df.shape}")
print(f"RNA expression shape: {rna_df.shape}")
print(f"Protein expression shape: {protein_df.shape}")
print(f"\nFirst few metadata rows:")
meta_df.head()

  meta_df = pd.read_csv(meta_path, index_col=0)



Metadata shape: (218332, 5)
RNA expression shape: (23712, 27291)
Protein expression shape: (20, 218331)

First few metadata rows:


Unnamed: 0_level_0,library_preparation_protocol,condition,MOI,sgRNA,UMI_count
NAME,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
TYPE,group,group,numeric,group,numeric
CELL_1,10X 3' v3 sequencing,Control,1,HLA-B_2,10832.0
CELL_2,10X 3' v3 sequencing,Control,2,,10731.0
CELL_3,10X 3' v3 sequencing,Control,1,HLA-B_2,28821.0
CELL_4,10X 3' v3 sequencing,Control,2,,15322.0


In [3]:
meta_df["sgRNA"] = meta_df['sgRNA'].fillna('None')
meta_df["target_gene"] = meta_df["sgRNA"].str.replace(r"(_\d+)$", "", regex=True)
meta_df["sgRNA"].value_counts()

sgRNA
None           91365
IFNGR2_2         358
NO_SITE_47       333
NO_SITE_913      317
HLA-DRB5_2       315
               ...  
PSMA7_1            2
group              1
TUBB_2             1
EIF2S3_3           1
UBC_2              1
Name: count, Length: 820, dtype: int64

In [4]:
cells = list(rna_df.columns)
protein_df = protein_df[cells]
meta_df = meta_df[meta_df.index.isin(cells)]
meta_df

Unnamed: 0_level_0,library_preparation_protocol,condition,MOI,sgRNA,UMI_count,target_gene
NAME,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
CELL_1,10X 3' v3 sequencing,Control,1,HLA-B_2,10832.0,HLA-B
CELL_2,10X 3' v3 sequencing,Control,2,,10731.0,
CELL_3,10X 3' v3 sequencing,Control,1,HLA-B_2,28821.0,HLA-B
CELL_4,10X 3' v3 sequencing,Control,2,,15322.0,
CELL_5,10X 3' v3 sequencing,Control,0,,10314.0,
...,...,...,...,...,...,...
CELL_27287,10X 3' v3 sequencing,Control,1,SLC25A13_3,27548.0,SLC25A13
CELL_27288,10X 3' v3 sequencing,Control,1,ONE_NON-GENE_SITE_658,18509.0,ONE_NON-GENE_SITE
CELL_27289,10X 3' v3 sequencing,Control,2,,15669.0,
CELL_27290,10X 3' v3 sequencing,Control,2,,20478.0,


In [26]:
rna_df

Unnamed: 0_level_0,CELL_1,CELL_2,CELL_3,CELL_4,CELL_5,CELL_6,CELL_7,CELL_8,CELL_9,CELL_10,...,CELL_27282,CELL_27283,CELL_27284,CELL_27285,CELL_27286,CELL_27287,CELL_27288,CELL_27289,CELL_27290,CELL_27291
GENE,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
A1BG,0.000000,0.000000,0.000000,4.879245,0.0,0.000000,3.700203,4.777204,0.00000,3.353191,...,4.314464,0.000000,0.000000,0.000000,4.933156,4.298653,0.000000,4.171619,0.000000,0.0
A1BG-AS1,0.000000,0.000000,0.000000,0.000000,0.0,4.740639,0.000000,0.000000,0.00000,3.353191,...,0.000000,0.000000,0.000000,3.835987,0.000000,0.000000,0.000000,4.171619,0.000000,0.0
A1CF,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.00000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
A2M,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.00000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
A2M-AS1,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.00000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZXDC,0.000000,0.000000,0.000000,4.193671,0.0,0.000000,0.000000,0.000000,0.00000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
ZYG11A,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.00000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0
ZYG11B,0.000000,0.000000,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.00000,3.353191,...,0.000000,0.000000,4.387442,0.000000,0.000000,0.000000,0.000000,4.171619,0.000000,0.0
ZYX,5.223799,4.545292,3.575064,0.000000,0.0,0.000000,4.380914,5.466132,4.90324,4.428212,...,4.314464,5.430164,0.000000,4.518285,5.622695,5.792106,5.380409,4.171619,3.908675,0.0


## 2. Extract Control Cells and Compute Baseline

In [6]:
# Find control cells (empty sgRNA)
sgRNA_col = 'sgRNA' if 'sgRNA' in meta_df.columns else 'sgRNA_target'
control_mask = (
    (meta_df[sgRNA_col] == 'None') |
    (meta_df[sgRNA_col].str.lower().str.contains('non-targeting|control', na=False))
)
control_cells = meta_df[control_mask].index.tolist()

print(f"Found {len(control_cells)} control cells")
print(f"Example control cells: {control_cells[:5]}")

# Compute average control RNA profile (baseline)
control_rna_data = rna_df[control_cells]
average_control_rna_profile = control_rna_data.mean(axis=1)

print(f"\nAverage control RNA profile shape: {average_control_rna_profile.shape}")
print(f"Number of genes: {len(average_control_rna_profile)}")
print(f"\nFirst 10 genes and their average expression:")
average_control_rna_profile.head(10)

Found 12594 control cells
Example control cells: ['CELL_2', 'CELL_4', 'CELL_5', 'CELL_7', 'CELL_12']



Average control RNA profile shape: (23712,)
Number of genes: 23712

First 10 genes and their average expression:


GENE
A1BG        0.946926
A1BG-AS1    0.351152
A1CF        0.002074
A2M         0.000608
A2M-AS1     0.003502
A4GALT      0.031894
AAAS        0.813001
AACS        0.555429
AADAC       0.015059
AADACL4     0.000000
dtype: float64

In [33]:
protein_df_control = protein_df[control_cells]

In [7]:
# STATE expects AnnData format (.h5ad) with specific requirements:
# 1. CSR sparse matrix format for .X
# 2. gene_name column in var DataFrame

from scipy.sparse import csr_matrix

X = control_rna_data.T.values  # Shape: (n_cells, n_genes)

# Convert to CSR sparse matrix (required by STATE)
X_sparse = csr_matrix(X)

obs = pd.DataFrame({'cell_type': ['control']}, index=control_rna_data.columns)

# Add gene_name to var (required by STATE)
var = pd.DataFrame({'gene_name': control_rna_data.index},
                   index=control_rna_data.index)

adata_input = ad.AnnData(X=X_sparse, obs=obs, var=var)

print(f"AnnData object created:")
print(f"  Shape: {adata_input.shape} (n_obs x n_vars)")
print(f"  Number of cells: {adata_input.n_obs}")
print(f"  Number of genes: {adata_input.n_vars}")
print(f"  Matrix format: {type(adata_input.X)}")
print(f"  Sparsity: {1 - adata_input.X.nnz / (adata_input.n_obs * adata_input.n_vars):.2%}")
print(f"\nvar columns: {list(adata_input.var.columns)}")
print(f"\nAnnData summary:")
print(adata_input)

AnnData object created:
  Shape: (12594, 23712) (n_obs x n_vars)
  Number of cells: 12594
  Number of genes: 23712
  Matrix format: <class 'scipy.sparse._csr.csr_matrix'>
  Sparsity: 83.24%

var columns: ['gene_name']

AnnData summary:
AnnData object with n_obs × n_vars = 12594 × 23712
    obs: 'cell_type'
    var: 'gene_name'


In [39]:
adata_input.obs

Unnamed: 0,cell_type
CELL_2,control
CELL_4,control
CELL_5,control
CELL_7,control
CELL_12,control
...,...
CELL_27284,control
CELL_27285,control
CELL_27286,control
CELL_27289,control


## 3. Create AnnData Object for STATE

In [22]:
# Create temporary directory for input/output files
tmpdir = tempfile.mkdtemp()
input_h5ad = os.path.join(tmpdir, "input.h5ad")
output_h5ad = os.path.join(tmpdir, "output.h5ad")

print(f"Temporary directory: {tmpdir}")

# Save AnnData to h5ad file
adata_input.write_h5ad(input_h5ad)
print(f"✓ Saved input AnnData to: {input_h5ad}")

# Run STATE embedding via CLI using official command
# Official: state emb transform --input <input.h5ad> --output <output.h5ad>
cmd = ["state", "emb", "transform", "--model-folder", "/home/nebius/SE-600M", "--config", "/home/nebius/SE-600M/config.yaml", "--checkpoint", "/home/nebius/SE-600M/se600m_epoch16.ckpt", "--input", input_h5ad, "--output", output_h5ad]

# Optional: Add model folder if you have a specific checkpoint
# cmd.extend(["--model-folder", "/path/to/SE-600M"])

print(f"\nRunning STATE command: {' '.join(cmd)}")
print("This may take a minute...\n")

result = subprocess.run(cmd, capture_output=True, text=True)

if result.returncode == 0:
    print("✓ STATE embedding completed successfully")
    if result.stdout:
        print(f"\nSTATE output:\n{result.stdout}")
else:
    print(f"✗ STATE embedding failed")
    print(f"Error: {result.stderr}")
    print(f"\nMake sure:")
    print("  1. STATE is installed: uv tool install arc-state")
    print("  2. Input h5ad has CSR matrix format")
    print("  3. Input h5ad has 'gene_name' in var")

Temporary directory: /tmp/tmpakdnoztl


✓ Saved input AnnData to: /tmp/tmpakdnoztl/input.h5ad

Running STATE command: state emb transform --model-folder /home/nebius/SE-600M --config /home/nebius/SE-600M/config.yaml --checkpoint /home/nebius/SE-600M/se600m_epoch16.ckpt --input /tmp/tmpakdnoztl/input.h5ad --output /tmp/tmpakdnoztl/output.h5ad
This may take a minute...

✓ STATE embedding completed successfully

STATE output:
!!! 16275 genes mapped to embedding file (out of 23712)



In [8]:
input_h5ad = ad.read_h5ad("/tmp/tmpakdnoztl/input_mapped.h5ad")
input_h5ad

AnnData object with n_obs × n_vars = 12594 × 18598
    obs: 'cell_type'
    var: 'gene_name', 'feature', 'my_Id'

In [28]:
input_h5ad.X = input_h5ad.X.toarray()
input_h5ad.write_h5ad("/tmp/tmpakdnoztl/input_dense.h5ad")

In [34]:
protein_lst = [prot.split(" ")[0] for prot in protein_df_control.index]
protein_df_control.index = protein_lst
# creating the protein adata object
protein_adata = ad.AnnData(X=protein_df_control.T.values,
                           obs=pd.DataFrame(index=protein_df_control.columns),
                           var=pd.DataFrame({'protein_name': protein_df_control.index},
                                            index=protein_df.index))
protein_adata

AnnData object with n_obs × n_vars = 12594 × 20
    var: 'protein_name'

In [24]:
protein_adata.write_h5ad("/tmp/tmpakdnoztl/protein_input.h5ad")

In [35]:
protein_adata.var

Unnamed: 0,protein_name
CD117,CD117
CD119,CD119
CD140a,CD140a
CD140b,CD140b
CD172a,CD172a
CD184,CD184
CD202b,CD202b
CD274,CD274
CD29,CD29
CD309,CD309


In [16]:
input_sctranslator_rna = ad.read_h5ad("/home/nebius/scTranslator/dataset/test/dataset1/GSM5008737_RNA_finetune_withcelltype.h5ad")
input_sctranslator_rna

AnnData object with n_obs × n_vars = 16177 × 23385
    obs: 'nCount_ADT', 'nFeature_ADT', 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'lane', 'donor', 'time', 'celltype.l1', 'celltype.l2', 'celltype.l3', 'Phase', 'nCount_SCT', 'nFeature_SCT'
    var: 'Gene symbol', 'Gene Expression', 'Entrez_Gene_Id', 'my_Id'

In [18]:
input_sctranslator_rna.X

array([[ 0.,  0.,  0., ...,  4., 33.,  0.],
       [ 0.,  0.,  0., ...,  1., 57.,  0.],
       [ 0.,  0.,  0., ...,  0., 62.,  0.],
       ...,
       [ 0.,  0.,  0., ...,  0.,  1.,  0.],
       [ 0.,  0.,  0., ...,  0., 19.,  0.],
       [ 0.,  0.,  0., ...,  2., 59.,  0.]],
      shape=(16177, 23385), dtype=float32)

In [34]:
input = ad.read_h5ad("~/scTranslator/dataset/test/dataset1/GSM5008737_RNA_finetune_withcelltype.h5ad")
input

FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = '~/scTranslator/dataset/test/dataset1/GSM5008737_RNA_finetune_withcelltype.h5ad', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

## 5. Load and Inspect STATE Embeddings

In [23]:
# Load the output AnnData with embeddings
adata_output = ad.read_h5ad(output_h5ad)

print(f"Output AnnData loaded:")
print(f"  Shape: {adata_output.shape}")
print(f"\nAvailable slots in AnnData:")
print(f"  .X shape: {adata_output.X.shape}")
print(f"  .obs keys: {list(adata_output.obs.columns)}")
print(f"  .var keys: {list(adata_output.var.columns)}")
print(f"  .obsm keys: {list(adata_output.obsm.keys())}")

# Extract STATE embedding
if 'X_state' in adata_output.obsm:
    state_embedding = adata_output.obsm['X_state'][0]
    print(f"\n✓ Found STATE embedding in .obsm['X_state']")
elif 'state' in adata_output.obsm:
    state_embedding = adata_output.obsm['state'][0]
    print(f"\n✓ Found STATE embedding in .obsm['state']")
else:
    state_embedding = adata_output.X[0]
    print(f"\n✓ Using embedding from .X")

print(f"\nSTATE Embedding Format:")
print(f"  Type: {type(state_embedding)}")
print(f"  Shape: {state_embedding.shape}")
print(f"  Dtype: {state_embedding.dtype}")
print(f"  Min value: {np.min(state_embedding):.4f}")
print(f"  Max value: {np.max(state_embedding):.4f}")
print(f"  Mean value: {np.mean(state_embedding):.4f}")
print(f"\nFirst 10 values:")
print(state_embedding[:10])

Output AnnData loaded:
  Shape: (12594, 23712)

Available slots in AnnData:
  .X shape: (12594, 23712)
  .obs keys: ['cell_type']
  .var keys: ['gene_name']
  .obsm keys: ['X_state']

✓ Found STATE embedding in .obsm['X_state']

STATE Embedding Format:
  Type: <class 'numpy.ndarray'>
  Shape: (2058,)
  Dtype: float32
  Min value: -0.4004
  Max value: 0.3633
  Mean value: -0.0001

First 10 values:
[ 0.01511088 -0.01780077  0.04651935 -0.02088624 -0.03892435 -0.04113956
  0.01376593 -0.01787988  0.01526911 -0.00636872]


In [25]:
adata_output.obsm["X_state"]

array([[ 0.01511088, -0.01780077,  0.04651935, ...,  0.36328125,
        -0.12158203, -0.04760742],
       [-0.00102659, -0.02491051,  0.02724587, ...,  0.41015625,
        -0.20019531, -0.0480957 ],
       [ 0.01197206, -0.02222265,  0.02269214, ...,  0.33398438,
        -0.1484375 , -0.04443359],
       ...,
       [ 0.02298459, -0.01794687,  0.01015415, ...,  0.3359375 ,
        -0.15625   , -0.11425781],
       [ 0.00589762,  0.00504662,  0.01852406, ...,  0.44335938,
        -0.17480469, -0.07666016],
       [-0.01367594, -0.04118502,  0.03945588, ...,  0.43945312,
        -0.20507812, -0.08740234]], shape=(12594, 2058), dtype=float32)

In [29]:
adata_input.var

Unnamed: 0_level_0,gene_name
GENE,Unnamed: 1_level_1
A1BG,A1BG
A1BG-AS1,A1BG-AS1
A1CF,A1CF
A2M,A2M
A2M-AS1,A2M-AS1
...,...
ZXDC,ZXDC
ZYG11A,ZYG11A
ZYG11B,ZYG11B
ZYX,ZYX


## 6. Load CAPTAIN Model

In [26]:
# Load CAPTAIN model
captain_path = "/home/nebius/cellian/foundation_models/CAPTAIN_Base/CAPTAIN_Base.pt"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Loading CAPTAIN model to {device}...")

captain_model = torch.load(captain_path, map_location=device)
if hasattr(captain_model, 'eval'):
    captain_model.eval()

print(f"✓ CAPTAIN model loaded")
print(f"\nModel type: {type(captain_model)}")
print(f"Model device: {device}")

# Check if it's a PyTorch model with parameters
if hasattr(captain_model, 'parameters'):
    total_params = sum(p.numel() for p in captain_model.parameters())
    print(f"Total parameters: {total_params:,}")

Loading CAPTAIN model to cuda...
✓ CAPTAIN model loaded

Model type: <class 'collections.OrderedDict'>
Model device: cuda


## 7. Run CAPTAIN Inference

In [27]:
# Convert STATE embedding to tensor
state_tensor = torch.tensor(state_embedding, dtype=torch.float32).unsqueeze(0)
print(f"Input tensor shape: {state_tensor.shape}")

# Move to correct device
if hasattr(captain_model, 'parameters'):
    model_device = next(captain_model.parameters()).device
else:
    model_device = torch.device('cpu')

state_tensor = state_tensor.to(model_device)
print(f"Tensor device: {state_tensor.device}")

# Run inference
with torch.no_grad():
    if callable(captain_model):
        protein_output = captain_model(state_tensor)
        print(f"\n✓ CAPTAIN inference completed")
    else:
        print(f"\n⚠ Warning: CAPTAIN model is not callable (might be state_dict)")
        protein_output = state_tensor  # Fallback

# Convert to numpy
protein_prediction = protein_output.squeeze().cpu().numpy()

print(f"\nCAPTAIN Output Format:")
print(f"  Type: {type(protein_prediction)}")
print(f"  Shape: {protein_prediction.shape}")
print(f"  Dtype: {protein_prediction.dtype}")
print(f"  Min value: {np.min(protein_prediction):.4f}")
print(f"  Max value: {np.max(protein_prediction):.4f}")
print(f"  Mean value: {np.mean(protein_prediction):.4f}")
print(f"\nFirst 10 values:")
print(protein_prediction[:10])

Input tensor shape: torch.Size([1, 2058])
Tensor device: cpu


CAPTAIN Output Format:
  Type: <class 'numpy.ndarray'>
  Shape: (2058,)
  Dtype: float32
  Min value: -0.4004
  Max value: 0.3633
  Mean value: -0.0001

First 10 values:
[ 0.01511088 -0.01780077  0.04651935 -0.02088624 -0.03892435 -0.04113956
  0.01376593 -0.01787988  0.01526911 -0.00636872]


## 8. Test Full Pipeline with a Perturbation

In [None]:
# Test with CD58 perturbation
perturbation_name = "CD58"

# Find cells with this perturbation
perturb_mask = meta_df[sgRNA_col].str.contains(perturbation_name, case=False, na=False)
perturb_cells = meta_df[perturb_mask].index.tolist()

print(f"Testing perturbation: {perturbation_name}")
print(f"Found {len(perturb_cells)} cells with this perturbation")

if len(perturb_cells) > 0:
    # Get real RNA and protein profiles
    real_rna = rna_df[perturb_cells].mean(axis=1).values
    real_protein = protein_df[perturb_cells].mean(axis=1).values
    
    print(f"\nGround Truth Shapes:")
    print(f"  Real RNA shape: {real_rna.shape}")
    print(f"  Real protein shape: {real_protein.shape}")
    
    # Compare STATE embedding to real RNA
    # Note: STATE embedding is lower-dimensional than raw RNA
    print(f"\nDimensionality:")
    print(f"  STATE embedding: {state_embedding.shape[0]}D")
    print(f"  Real RNA: {real_rna.shape[0]}D (genes)")
    print(f"  CAPTAIN output: {protein_prediction.shape[0]}D")
    print(f"  Real protein: {real_protein.shape[0]}D (proteins)")
    
    # Calculate similarity for matching dimensions
    if protein_prediction.shape[0] == real_protein.shape[0]:
        protein_similarity = 1 - cosine(protein_prediction, real_protein)
        print(f"\n✓ Protein prediction similarity: {protein_similarity:.4f}")
    else:
        # Align dimensions
        min_len = min(len(protein_prediction), len(real_protein))
        protein_similarity = 1 - cosine(protein_prediction[:min_len], real_protein[:min_len])
        print(f"\n⚠ Dimension mismatch - using first {min_len} features")
        print(f"  Protein prediction similarity: {protein_similarity:.4f}")
else:
    print(f"\n⚠ No cells found with perturbation '{perturbation_name}'")
    print(f"Available perturbations: {meta_df[sgRNA_col].dropna().unique()[:20]}")

## 9. Summary of Formats

In [None]:
print("="*60)
print("SUMMARY OF DATA FORMATS")
print("="*60)
print(f"\n1. INPUT: Average Control RNA Profile")
print(f"   Shape: {average_control_rna_profile.shape}")
print(f"   Type: pandas Series")
print(f"   Contains: Raw gene expression values")

print(f"\n2. STATE INPUT: AnnData Object")
print(f"   Shape: (1, {adata_input.n_vars})")
print(f"   Format: .h5ad file")
print(f"   .X contains: Gene expression matrix")

print(f"\n3. STATE OUTPUT: Embedding")
print(f"   Shape: {state_embedding.shape}")
print(f"   Type: numpy array")
print(f"   Contains: Compressed RNA representation")

print(f"\n4. CAPTAIN INPUT: RNA Embedding Tensor")
print(f"   Shape: (1, {state_embedding.shape[0]})")
print(f"   Type: torch.Tensor")
print(f"   Device: {device}")

print(f"\n5. CAPTAIN OUTPUT: Protein Prediction")
print(f"   Shape: {protein_prediction.shape}")
print(f"   Type: numpy array")
print(f"   Contains: Predicted protein expression")

print(f"\n6. GROUND TRUTH: Real Protein Expression")
if len(perturb_cells) > 0:
    print(f"   Shape: {real_protein.shape}")
    print(f"   Type: numpy array")
    print(f"   Contains: Measured protein levels")
print("="*60)

## 10. Save Embeddings for Later Use

In [None]:
# Create results directory
results_dir = "/home/nebius/cellian/results"
os.makedirs(results_dir, exist_ok=True)

# Save embeddings
np.save(f"{results_dir}/state_embedding_test.npy", state_embedding)
np.save(f"{results_dir}/captain_prediction_test.npy", protein_prediction)

print(f"✓ Embeddings saved to {results_dir}/")
print(f"  - state_embedding_test.npy")
print(f"  - captain_prediction_test.npy")

# Load them back to verify
loaded_state = np.load(f"{results_dir}/state_embedding_test.npy")
loaded_captain = np.load(f"{results_dir}/captain_prediction_test.npy")

print(f"\n✓ Verified: Loaded embeddings match original shapes")
print(f"  STATE: {loaded_state.shape}")
print(f"  CAPTAIN: {loaded_captain.shape}")

## Cleanup

In [None]:
# Clean up temporary files
import shutil
if os.path.exists(tmpdir):
    shutil.rmtree(tmpdir)
    print(f"✓ Cleaned up temporary directory: {tmpdir}")