[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rsinghlab/pyaging/blob/main/tutorials/tutorial_cpgptgrimage3.ipynb) [![Open In nbviewer](https://img.shields.io/badge/View%20in-nbviewer-orange)](https://nbviewer.jupyter.org/github/rsinghlab/pyaging/blob/main/tutorials/tutorial_cpgptgrimage3.ipynb)

# 🧬 ⚙️ CpGPT Reference Map Tutorial ⚙️ 🧬

Welcome to the CpGPT Reference Map Tutorial! 👋 

In this notebook, we'll walk you through how to map your data to a reference dataset for zero-shot label transfer.

## Table of Contents

0. [Read Quick Setup Tutorial](#0-read-quick-setup-tutorial)
1. [Setup Environment](#1-setup-environment)
2. [Download Target Data](#2-download-target-data)
3. [Download Reference Data](#3-download-reference-data)
4. [Load Model and Dependencies](#4-load-model-and-dependencies)
5. [Prepare Data Objects](#5-prepare-data-objects)
6. [Compute Sample Embeddings](#6-compute-sample-embeddings)
7. [Map Target to Reference](#7-map-target-to-reference)

## 0. Read Quick Setup Tutorial

Before going through this tutorial, please familiarize yourself with the [quick setup tutorial](https://github.com/lcamillo/CpGPT/blob/main/tutorials/quick_setup.ipynb).

## 1. Setup Environment

We'll import the necessary Python packages and set up our environment for CpGPT. We'll be using a mix of standard data science libraries and CpGPT-specific modules. We'll also set some important variables that will be used throughout the notebook. Pay attention to these as you may need to adjust them based on your specific setup and requirements.

### 1.1 Set Variables

In [1]:
# Random seed for reproducibility
RANDOM_SEED = 42

# Directory paths
DEPENDENCIES_DIR = "../dependencies"
LLM_DEPENDENCIES_DIR = DEPENDENCIES_DIR + "/human"
DATA_DIR = "../data"
TARGET_PROCESSED_DIR = "../data/tutorials/processed/target_reference_map"
REFERENCE_PROCESSED_DIR = "../data/tutorials/processed/reference_reference_map"

MODEL_NAME = "small"
MODEL_CHECKPOINT_PATH = f"../dependencies/model/weights/{MODEL_NAME}.ckpt"
MODEL_CONFIG_PATH = f"../dependencies/model/config/{MODEL_NAME}.yaml"
MODEL_VOCAB_PATH = f"../dependencies/model/vocab/{MODEL_NAME}.json"

TARGET_BETAS_PATH = "../data/cpgcorpus/raw/GSE52238/GPL8490/betas/gse_betas.arrow"
TARGET_METADATA_PATH = "../data/cpgcorpus/raw/GSE52238/GPL8490/metadata/metadata.arrow"

REFERENCE_BETAS_PATH = "../data/altumage/raw/betas/betas.arrow"
REFERENCE_METADATA_PATH = "../data/altumage/raw/metadata/metadata.arrow"

# The maximum context length to give to the model
MAX_INPUT_LENGTH = 10_000  # increase this if your hardware permits

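Since these paths may need adjusting for your setup, a quick check of which ones already exist can save some confusion later. The snippet below is a minimal sketch: `find_missing_paths` is a hypothetical helper (not part of CpGPT), and the listed paths are copied from the variables above. Files reported as missing are not necessarily a problem, because the later sections download them.

```python
import os


def find_missing_paths(paths):
    """Return the entries of `paths` that do not exist on disk."""
    return [p for p in paths if not os.path.exists(p)]


# Paths copied from the variables cell above; adjust them to your layout
paths_to_check = [
    "../dependencies",
    "../data",
    "../dependencies/model/weights/small.ckpt",
    "../data/cpgcorpus/raw/GSE52238/GPL8490/betas/gse_betas.arrow",
    "../data/altumage/raw/betas/betas.arrow",
]

for path in find_missing_paths(paths_to_check):
    print(f"Not found yet: {path}")
```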
> **⚠️ Warning**
> 
> It is recommended to have a GPU for inference as CPU might be slow.
> 
> Reconstructing the methylome for a few hundred samples might take up to one hour on a CPU. ⌛
>
> This might be a great exercise in testing your patience.

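To find out ahead of time whether a GPU is visible, a small check like the one below can help. It is only a sketch: `pick_device` is a hypothetical helper, it relies solely on `torch.cuda.is_available()`, and it does not change which device CpGPT ends up using.

```python
def pick_device():
    """Return "cuda" when PyTorch can see a GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        # Without PyTorch installed there is nothing to query
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Inference would run on: {device}")
```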
### 1.2 Import Packages


In [None]:
# Standard library imports
import warnings
import os
import json

warnings.simplefilter(action="ignore", category=FutureWarning)

# Third-party imports (downloads, tensors, data handling, plotting, nearest-neighbor search)
import gdown
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyaging as pya
import seaborn as sns
from tqdm.rich import tqdm
import faiss

# Lightning imports
from lightning.pytorch import seed_everything

# cpgpt-specific imports
from cpgpt.data.components.cpgpt_datasaver import CpGPTDataSaver
from cpgpt.data.cpgpt_datamodule import CpGPTDataModule
from cpgpt.trainer.cpgpt_trainer import CpGPTTrainer
from cpgpt.data.components.dna_llm_embedder import DNALLMEmbedder
from cpgpt.data.components.illumina_methylation_prober import IlluminaMethylationProber
from cpgpt.infer.cpgpt_inferencer import CpGPTInferencer
from cpgpt.model.cpgpt_module import m_to_beta

# Set random seed for reproducibility
seed_everything(RANDOM_SEED, workers=True)
try:
    # Favor deterministic cuDNN kernels over autotuned ones
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
except Exception:
    pass

Seed set to 42


42

## 2. Download Target Data

If you have your own data, feel free to skip this step, but make sure the data is saved in `.arrow` format. As an example target dataset, we'll use GSE52238, which contains methylation profiling data from a study on human somatic cell reprogramming. The dataset examines global DNA methylation changes during the reprogramming of different somatic cell types (endoderm, mesoderm, and parthenogenetic germ cells) into iPSCs.

In [3]:
# First let's declare the inferencer
inferencer = CpGPTInferencer(dependencies_dir=DEPENDENCIES_DIR, data_dir=DATA_DIR)

inferencer.download_cpgcorpus_dataset("GSE52238")

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mInitializing class CpGPTInferencer.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mUsing device: cuda.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mUsing dependencies directory: ../dependencies[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mUsing data directory: ../data[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mThere are 19 CpGPT models available such as age, age_cot, average_adultweight, etc.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mThere are 2088 GSE datasets available such as GSE100184, GSE100208, GSE100209, etc.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mDataset GSE52238 already exists at ../data/cpgcorpus/raw/GSE52238 (skipping download).[0m


## 3. Download Reference Data

As an example, we will use the AltumAge dataset as the reference, but feel free to use another atlas.

### 3.1 Download AltumAge Reference



In [4]:
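# Download AltumAge data if not already downloaded
os.makedirs(os.path.dirname(REFERENCE_BETAS_PATH), exist_ok=True)  # make sure the output folder exists
if not os.path.exists(REFERENCE_BETAS_PATH):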
    url = (
        "https://drive.google.com/file/d/17iScHhjasi1qCL3JTte5bNaqlHV2gPQG/view?usp=share_link"
    )
    gdown.download(url, output=REFERENCE_BETAS_PATH, fuzzy=True)

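# Download metadata if not already downloaded
os.makedirs(os.path.dirname(REFERENCE_METADATA_PATH), exist_ok=True)  # make sure the output folder exists
if not os.path.exists(REFERENCE_METADATA_PATH):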
    url = (
        "https://drive.google.com/file/d/1eyYzNNW6hhLZonkbV7xlhVefF23TSBJQ/view?usp=share_link"
    )
    gdown.download(url, output=REFERENCE_METADATA_PATH, fuzzy=True)

## 4. Load Model and Dependencies

For reference mapping, we recommend using the `small` model. If compute resources permit, then it might be worth trying with a bigger context window or the `large` model.

### 4.1 Download Checkpoint and Configuration Files

In [5]:
# Download the checkpoint and configuration files
inferencer.download_model(MODEL_NAME)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mModel checkpoint already exists at ../dependencies/model/weights/small.ckpt (skipping download).[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mModel config already exists at ../dependencies/model/config/small.yaml (skipping download).[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [33m[1mNo vocabulary file found for model 'small'.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mSuccessfully downloaded model 'small'.[0m


### 4.2 Load Model

In [6]:
# Load the model configuration
config = inferencer.load_cpgpt_config(MODEL_CONFIG_PATH)

# Load the model weights
model = inferencer.load_cpgpt_model(
    config,
    model_ckpt_path=MODEL_CHECKPOINT_PATH,
    strict_load=True,
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mLoaded CpGPT model config.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mInstantiated CpGPT model from config.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mUsing device: cuda.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mLoading checkpoint from: ../dependencies/model/weights/small.ckpt[0m


[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mCheckpoint loaded into the model.[0m


### 4.3 Download Dependencies

In [None]:
inferencer.download_dependencies(species="human")

## 5. Prepare Data Objects

### 5.1 Declare Embedder and Prober

To retrieve the sample embeddings, we need to memory-map the data, which is done with the `CpGPTDataSaver` class. We first need to define the `DNALLMEmbedder` and `IlluminaMethylationProber` classes, which hold the DNA LLM embeddings and the mapping from Illumina array probes to genomic locations, respectively.

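For readers unfamiliar with memory-mapping, the idea is that values are read from a file on demand instead of loading the whole file into RAM. The sketch below illustrates the concept with the standard library only. It is not how `CpGPTDataSaver` stores data internally; the raw float32 file layout, the helper names (`write_matrix`, `read_row`), and the toy beta values are all assumptions made for illustration.

```python
import mmap
import os
import struct
import tempfile
from array import array

FLOAT_SIZE = 4  # bytes per float32 value


def write_matrix(path, rows):
    """Write equal-length rows of floats as raw float32 values."""
    with open(path, "wb") as fh:
        for row in rows:
            array("f", row).tofile(fh)


def read_row(path, row_index, n_cols):
    """Read one row through a memory map, without loading the whole file."""
    with open(path, "rb") as fh:
        with mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            offset = row_index * n_cols * FLOAT_SIZE
            return list(struct.unpack_from(f"{n_cols}f", mm, offset))


# Toy matrix: 2 samples x 3 CpG sites (made-up beta values)
betas = [[0.0, 0.25, 0.5], [0.5, 0.75, 1.0]]

with tempfile.TemporaryDirectory() as tmp_dir:
    matrix_path = os.path.join(tmp_dir, "betas.bin")
    write_matrix(matrix_path, betas)
    second_sample = read_row(matrix_path, 1, 3)

print(second_sample)  # [0.5, 0.75, 1.0]
```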
In [7]:
embedder = DNALLMEmbedder(dependencies_dir=LLM_DEPENDENCIES_DIR)

[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mInitializing class DNALLMEmbedder.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mGenome files will be stored under ../dependencies/human/genomes.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mDNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mEnsembl metadata dictionary loaded successfully[0m


In [8]:
prober = IlluminaMethylationProber(dependencies_dir=LLM_DEPENDENCIES_DIR, embedder=embedder)

[1m[34mcpgpt[0m[1m[0m: [36mIlluminaMethylationProber[0m: [1mInitializing class IlluminaMethylationProber.[0m
[1m[34mcpgpt[0m[1m[0m: [36mIlluminaMethylationProber[0m: [1mIllumina methylation manifest files will be stored under ../dependencies/human/manifests.[0m
[1m[34mcpgpt[0m[1m[0m: [36mIlluminaMethylationProber[0m: [1mIllumina metadata dictionary loaded successfully.[0m


### 5.2 Memory-Map Data

In [9]:
# Define datasaver
target_datasaver = CpGPTDataSaver(data_paths=TARGET_BETAS_PATH, processed_dir=TARGET_PROCESSED_DIR)

# Process the file
target_datasaver.process_files(prober, embedder)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mInitializing class CpGPTDataSaver.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mDataset folders will be stored under ../data/tutorials/processed/target_reference_map.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mLoaded existing dataset metrics.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mLoaded existing genomic locations.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mStarting file processing.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1m1 files already processed. Skipping those.[0m


In [10]:
# Define datasaver
reference_datasaver = CpGPTDataSaver(data_paths=REFERENCE_BETAS_PATH, processed_dir=REFERENCE_PROCESSED_DIR)

# Process the file
reference_datasaver.process_files(prober, embedder)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mInitializing class CpGPTDataSaver.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mDataset folders will be stored under ../data/tutorials/processed/reference_reference_map.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mLoaded existing dataset metrics.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mLoaded existing genomic locations.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mStarting file processing.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1m1 files already processed. Skipping those.[0m


### 5.3 Declare Data Modules

Let's define two data modules: one for the target data and one for the reference data.

In [11]:
# Define datamodule
target_datamodule = CpGPTDataModule(
    predict_dir=TARGET_PROCESSED_DIR,
    dependencies_dir=LLM_DEPENDENCIES_DIR,
    batch_size=1,
    num_workers=0,
    max_length=MAX_INPUT_LENGTH,
    dna_llm=config.data.dna_llm,
    dna_context_len=config.data.dna_context_len,
    sorting_strategy=config.data.sorting_strategy,
    pin_memory=False,
)

[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mInitializing class DNALLMEmbedder.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mGenome files will be stored under ../dependencies/human/genomes.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mDNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mEnsembl metadata dictionary loaded successfully[0m


In [12]:
# Define datamodule
reference_datamodule = CpGPTDataModule(
    predict_dir=REFERENCE_PROCESSED_DIR,
    dependencies_dir=LLM_DEPENDENCIES_DIR,
    batch_size=1,
    num_workers=0,
    max_length=MAX_INPUT_LENGTH,
    dna_llm=config.data.dna_llm,
    dna_context_len=config.data.dna_context_len,
    sorting_strategy=config.data.sorting_strategy,
    pin_memory=False,
)

[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mInitializing class DNALLMEmbedder.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mGenome files will be stored under ../dependencies/human/genomes.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mDNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mEnsembl metadata dictionary loaded successfully[0m


## 6. Compute Sample Embeddings

### 6.1 Declare Trainer

Given all models were trained under mixed precision, we'll use the `precision="16-mixed"` argument.

In [13]:
trainer = CpGPTTrainer(precision="16-mixed")

Using 16bit Automatic Mixed Precision (AMP)
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


### 6.2 Get Sample Embeddings

In [14]:
# Get the target sample embeddings
target_sample_embeddings = trainer.predict(
    model=model,
    datamodule=target_datamodule,
    predict_mode="forward",
    return_keys=["sample_embedding"]
)

target_sample_embeddings = target_sample_embeddings['sample_embedding']

You are using a CUDA device ('NVIDIA A10G') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mInitializing class CpGPTDataset.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mLoaded existing dataset metrics.[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

/data/miniforge3/envs/cpgpt/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Depending on how large the reference dataset is, the following cell might take a while to run. Please be patient.

In [15]:
reference_embeddings_path = REFERENCE_PROCESSED_DIR + "/reference_sample_embeddings.pt"

if os.path.exists(reference_embeddings_path):
    reference_sample_embeddings = torch.load(reference_embeddings_path)
else:
    reference_sample_embeddings = trainer.predict(
        model=model,
        datamodule=reference_datamodule,
        predict_mode="forward",
        return_keys=["sample_embedding"]
    )

    reference_sample_embeddings = reference_sample_embeddings['sample_embedding']

    # Save the embeddings for future use
    torch.save(reference_sample_embeddings, reference_embeddings_path)

## 7. Map Target to Reference

### 7.1 Load Metadata

In [41]:
target_metadata_df = pd.read_feather(TARGET_METADATA_PATH)
target_metadata_df.index = target_metadata_df['title']

target_metadata_df.head()

| title (index) | GSM_ID | title | geo_accession | status | submission_date | last_update_date | type | channel_count | source_name_ch1 | organism_ch1 | ... | contact_institute | contact_address | contact_city | contact_zip/postal_code | contact_country | supplementary_file | data_row_count | cell type:ch1 | sample id:ch1 | source of cell:ch1 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Human embryonic stem cell H9 | GSM1261682 | Human embryonic stem cell H9 | GSM1261682 | Public on May 01 2014 | Nov 08 2013 | May 01 2014 | genomic | 1 | Embryonic stem cells | Homo sapiens | ... | The Hebrew University of Jerusalem | Givat Ram | Jerusalem | 91904 | Israel | NONE | 27578 | embryonic stem cell (H9) | SAMPLE 1 | ICM |
| Beta cell from human islet 1 | GSM1261683 | Beta cell from human islet 1 | GSM1261683 | Public on May 01 2014 | Nov 08 2013 | May 01 2014 | genomic | 1 | human islets | Homo sapiens | ... | The Hebrew University of Jerusalem | Givat Ram | Jerusalem | 91904 | Israel | NONE | 27578 | Beta cell | SAMPLE 2 | Pancreatic islet |
| Beta cell from human islet 2 | GSM1261684 | Beta cell from human islet 2 | GSM1261684 | Public on May 01 2014 | Nov 08 2013 | May 01 2014 | genomic | 1 | human islets | Homo sapiens | ... | The Hebrew University of Jerusalem | Givat Ram | Jerusalem | 91904 | Israel | NONE | 27578 | Beta cell | SAMPLE 3 | Pancreatic islet |
| Reprogrammed beta cell 1 | GSM1261685 | Reprogrammed beta cell 1 | GSM1261685 | Public on May 01 2014 | Nov 08 2013 | May 01 2014 | genomic | 1 | Reprogrammed beta cell | Homo sapiens | ... | The Hebrew University of Jerusalem | Givat Ram | Jerusalem | 91904 | Israel | NONE | 27578 | reprogrammed beta cell | SAMPLE 4 | Reprogrammed from SAMPLE 2 |
| Reprogrammed beta cell 2 | GSM1261686 | Reprogrammed beta cell 2 | GSM1261686 | Public on May 01 2014 | Nov 08 2013 | May 01 2014 | genomic | 1 | Reprogrammed beta cell | Homo sapiens | ... | The Hebrew University of Jerusalem | Givat Ram | Jerusalem | 91904 | Israel | NONE | 27578 | reprogrammed beta cell | SAMPLE 5 | Reprogrammed from SAMPLE 3 |


In [17]:
reference_metadata_df = pd.read_feather(REFERENCE_METADATA_PATH)

reference_metadata_df.head()

| | dataset | tissue_type | age | gender |
|---|---|---|---|---|
| 8363800025_R04C02 | E-MTAB-2372 | blood wbc | 37.0 | F |
| 8363800056_R01C02 | E-MTAB-2372 | blood wbc | 50.0 | M |
| 8363800057_R02C02 | E-MTAB-2372 | blood wbc | 68.0 | F |
| 8359018119_R06C01 | E-MTAB-2372 | blood wbc | 43.0 | F |
| 8363800025_R06C01 | E-MTAB-2372 | blood wbc | 43.0 | F |


### 7.2 Find Nearest Neighbors

In [82]:
# Define parameters
k = 5  # Number of nearest neighbors to consider

# Create FAISS index for efficient similarity search
index = faiss.IndexFlatL2(reference_sample_embeddings.shape[1])
index.add(reference_sample_embeddings)

# Find k nearest neighbors for each target sample
distances, indices = index.search(target_sample_embeddings, k)

# Prepare for label mapping
labels_to_predict = reference_metadata_df.columns.tolist()
mapped_labels = {label: [] for label in labels_to_predict}

# Process each target sample
for i in tqdm(range(target_sample_embeddings.shape[0])):

    for label in labels_to_predict:
        # Extract values from nearest neighbors
        neighbor_values = reference_metadata_df[label].iloc[indices[i]]
        
        if pd.api.types.is_float_dtype(neighbor_values):
            # Numerical values: use mean
            mapped_labels[label].append(neighbor_values.mean())
        else:
            # Categorical values: use majority vote
            if neighbor_values.isna().all() or len(neighbor_values) == 0:
                # If all values are None or empty, append None
                mapped_labels[label].append(None)
            else:
                value_counts = neighbor_values.value_counts()
                if len(value_counts) > 0:
                    most_common_value = value_counts.index[0]  # First value in case of ties
                    mapped_labels[label].append(most_common_value)
                else:
                    mapped_labels[label].append(None)

# Create final DataFrame with mapped labels
mapped_labels_df = pd.DataFrame(mapped_labels, index=target_metadata_df.index)

Output()

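The mapping logic above can be illustrated with a small, dependency-free sketch. It assumes toy two-dimensional embeddings and made-up labels (`reference`, `tissue`, `age`, and `query` are hypothetical), ranks reference samples by squared L2 distance (the quantity that `faiss.IndexFlatL2` reports), and transfers labels with the same rule as the cell above: the mean for float labels and a majority vote otherwise. It is meant to show the logic on a handful of points, not to replace FAISS on real data.

```python
from collections import Counter


def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def k_nearest(query, reference, k):
    """Indices of the k reference vectors closest to `query`."""
    order = sorted(
        range(len(reference)),
        key=lambda j: squared_l2(query, reference[j]),
    )
    return order[:k]


def transfer_label(values):
    """Mean for float labels, majority vote otherwise, None if all are missing."""
    present = [v for v in values if v is not None]
    if not present:
        return None
    if all(isinstance(v, float) for v in present):
        return sum(present) / len(present)
    # On ties, Counter keeps the value that was seen first
    return Counter(present).most_common(1)[0][0]


# Toy reference "embeddings" with made-up labels
reference = [[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0], [5.1, 5.0]]
tissue = ["blood", "blood", "skin", "liver", "liver"]
age = [30.0, 40.0, 50.0, 60.0, 70.0]

query = [0.05, 0.05]
neighbors = k_nearest(query, reference, k=3)

predicted_tissue = transfer_label([tissue[j] for j in neighbors])
predicted_age = transfer_label([age[j] for j in neighbors])
print(predicted_tissue, predicted_age)
```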
In [84]:
mapped_labels_df.head()

| title | dataset | tissue_type | age | gender |
|---|---|---|---|---|
| Human embryonic stem cell H9 | GSE36642 | blood cord | -0.09589 | F |
| Beta cell from human islet 1 | GSE36642 | blood cord | 8.146301 | M |
| Beta cell from human islet 2 | GSE22595 | dermis | 40.0 | M |
| Reprogrammed beta cell 1 | GSE36642 | blood cord | -0.09589 | M |
| Reprogrammed beta cell 2 | GSE36642 | blood cord | -0.076712 | M |
