# 🧬 ⚙️ CpGPT Quick Setup Tutorial ⚙️ 🧬

Welcome to the CpGPT Quick Setup Tutorial! 👋 

In this notebook, we'll walk you through the fastest way of using CpGPT for your research.

## Table of Contents

1. [Setup Environment](#1-setup-environment)
2. [Retrieve DNA LLM Embeddings](#2-retrieve-dna-llm-embeddings)
3. [Download and Load Model](#3-download-and-load-model)
4. [Prepare Data Objects](#4-prepare-data-objects)
5. [Run Inference](#5-run-inference)

## 1. Setup Environment

We'll start by setting the variables used throughout the notebook, then import the necessary Python packages: a mix of standard data science libraries and CpGPT-specific modules. Pay attention to the variables, as you may need to adjust them for your own setup and requirements.

In [1]:
# Random seed for reproducibility
RANDOM_SEED = 42

# Directory paths
DEPENDENCIES_DIR = "../dependencies"
LLM_DEPENDENCIES_DIR = DEPENDENCIES_DIR + "/human"
DATA_DIR = "../data"
PROCESSED_DIR = "../data/tutorials/processed/quick_setup"

MODEL_NAME = "cancer"
MODEL_CHECKPOINT_PATH = f"../dependencies/model/weights/{MODEL_NAME}.ckpt"
MODEL_CONFIG_PATH = f"../dependencies/model/config/{MODEL_NAME}.yaml"
MODEL_VOCAB_PATH = f"../dependencies/model/vocab/{MODEL_NAME}.json"

ARROW_DF_PATH = "../data/cpgcorpus/raw/GSE182215/GPL13534/betas/QCDPB.arrow"
ARROW_DF_FILTERED_PATH = "../data/tutorials/raw/toy_filtered.arrow"

# The maximum context length to give to the model
MAX_INPUT_LENGTH = 20_000 # increase if your hardware permits
MAX_ATTN_LENGTH = 1_000
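
Since these paths may need adjusting, a quick existence check can catch typos early. The snippet below is a minimal sketch; `find_missing_paths` is a hypothetical helper (not part of CpGPT), and the example paths simply mirror the variables defined above. Files reported as missing are not necessarily an error, since the model files are downloaded later in the notebook.

```python
from pathlib import Path


def find_missing_paths(paths):
    """Return the subset of `paths` that does not exist on disk."""
    return [p for p in paths if not Path(p).exists()]


# Illustrative values mirroring the variables defined above.
model_name = "cancer"
expected = [
    f"../dependencies/model/weights/{model_name}.ckpt",
    f"../dependencies/model/config/{model_name}.yaml",
    f"../dependencies/model/vocab/{model_name}.json",
]

missing = find_missing_paths(expected)
if missing:
    print("Not found yet (they may be downloaded later in the notebook):")
    for path in missing:
        print("  ", path)
```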

> **⚠️ Warning**
> 
> It is recommended to have a GPU for inference as CPU might be slow.
> 
> Reconstructing the methylome for a few hundred samples might take up to one hour on a CPU. ⌛
>
> This might be a great exercise in testing your patience.
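
To see which situation you are in before starting, you can ask PyTorch whether a GPU is visible. This is a minimal sketch; `detect_device` is a hypothetical helper, and CpGPT selects its own device internally (it reports it in the logs as "Using device: ...").

```python
def detect_device():
    """Return "cuda" if PyTorch can see a GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        return "cpu"  # PyTorch is not installed in this environment
    return "cuda" if torch.cuda.is_available() else "cpu"


device = detect_device()
print(f"Inference will run on: {device}")
```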

### 1.2 Import packages


In [2]:
# Standard library imports
import warnings
import os
import json

warnings.simplefilter(action="ignore", category=FutureWarning)

# Plotting imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyaging as pya
import seaborn as sns

# Lightning imports
from lightning.pytorch import seed_everything

# cpgpt-specific imports
from cpgpt.data.components.cpgpt_datasaver import CpGPTDataSaver
from cpgpt.data.cpgpt_datamodule import CpGPTDataModule
from cpgpt.trainer.cpgpt_trainer import CpGPTTrainer
from cpgpt.data.components.dna_llm_embedder import DNALLMEmbedder
from cpgpt.data.components.illumina_methylation_prober import IlluminaMethylationProber
from cpgpt.infer.cpgpt_inferencer import CpGPTInferencer
from cpgpt.model.cpgpt_module import m_to_beta

# Set random seed for reproducibility
seed_everything(RANDOM_SEED, workers=True)

Seed set to 42


42

## 2. Retrieve DNA LLM Embeddings

To retrieve the DNA LLM Embeddings, there are two options:
- download the dependencies with all of the sequence embeddings for the CpG sites targeted by the Illumina arrays;
- generate from scratch using the DNA LLM directly for loci outside of the ones already available for download.

In [3]:
# First let's declare the inferencer
inferencer = CpGPTInferencer(dependencies_dir=DEPENDENCIES_DIR, data_dir=DATA_DIR)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mInitializing class CpGPTInferencer.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mUsing device: cpu.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mUsing dependencies directory: ../dependencies[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mUsing data directory: ../data[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mThere are 19 CpGPT models available such as age, age_cot, average_adultweight, boa, cancer, clock_proxies, diseases, epicvmammal, hannum, hannum_cot, human_rrbs_atlas, large, mammalian, maximum_lifespan, mortality, proteins, relative_age, scimetv3, small, etc.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mThere are 2089 GSE datasets available such as GSE100184, GSE100208, GSE100209, etc.[0m


In [4]:
inferencer

<cpgpt.infer.cpgpt_inferencer.CpGPTInferencer at 0x7ff60980a190>

### 2.1 Download Dependencies


The already-processed dependencies contain the sequence embeddings for both human (`s3://cpgpt-lucascamillo-public/dependencies/human`) and several mammalian species (`s3://cpgpt-lucascamillo-public/dependencies/mammalian`). Here, let's use the human set as an example:

In [5]:
inferencer.download_dependencies(species="human")

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mDependencies for human already exist at ../dependencies/human (skipping download).[0m


### 2.2 Generate DNA LLM Embeddings


To generate genomic embeddings for loci beyond those already available for download, we can use the `DNALLMEmbedder` class. The loci must be provided as a list of strings in the Ensembl format 'chromosome:position'. Be mindful that this function can take a long time to run depending on your GPU. For instance, embedding ~1M genomic loci from the Illumina arrays takes about 12h on an RTX 4090.

In [6]:
if not os.path.exists(LLM_DEPENDENCIES_DIR):

    # List CpG genomic locations
    example_genomic_locations = ['1:100000', '1:250500', 'X:2031253']

    # Declare required class
    embedder = DNALLMEmbedder(dependencies_dir=LLM_DEPENDENCIES_DIR)

    # Parse the embeddings
    embedder.parse_dna_embeddings(
        example_genomic_locations,
        "homo_sapiens",
        dna_llm="nucleotide-transformer-v2-500m-multi-species",
        dna_context_len=2001,
    )
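
Because embedding is slow, it can be worth validating the 'chromosome:position' strings before handing them to the embedder. The following is a minimal sketch; `parse_locus` is a hypothetical helper, not a CpGPT function, and it only checks the string shape, not whether the position exists in the genome.

```python
def parse_locus(locus):
    """Split a 'chromosome:position' string into (chromosome, position)."""
    chrom, sep, pos = locus.partition(":")
    if not sep or not chrom or not pos.isdecimal():
        raise ValueError(f"Expected 'chromosome:position', got {locus!r}")
    return chrom, int(pos)


example_genomic_locations = ["1:100000", "1:250500", "X:2031253"]
parsed = [parse_locus(locus) for locus in example_genomic_locations]
print(parsed)
```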

## 3. Download and Load Model

First check the model zoo in the README.md file for the available models and their corresponding features. To load a given model, you need its configuration (the hyperparameters) and its checkpoint, both of which are handled through the `CpGPTInferencer` class.

### 3.1 Download Checkpoint and Configuration Files

In [7]:
# Download the checkpoint and configuration files
inferencer.download_model(MODEL_NAME)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mModel checkpoint already exists at ../dependencies/model/weights/cancer.ckpt (skipping download).[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mModel config already exists at ../dependencies/model/config/cancer.yaml (skipping download).[0m


[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mModel vocabulary already exists at ../dependencies/model/vocab/cancer.json (skipping download).[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mSuccessfully downloaded model 'cancer'.[0m


### 3.2 Load Model

In [8]:
# Load the model configuration
config = inferencer.load_cpgpt_config(MODEL_CONFIG_PATH)

# Load the model weights
model = inferencer.load_cpgpt_model(
    config,
    model_ckpt_path=MODEL_CHECKPOINT_PATH,
    strict_load=True,
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mLoaded CpGPT model config.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mInstantiated CpGPT model from config.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mUsing device: cpu.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mLoading checkpoint from: ../dependencies/model/weights/cancer.ckpt[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mCheckpoint loaded into the model.[0m


In [9]:
model

CpGPTLitModule(
  (net): CpGPT(
    (position_encoder): RotaryPositionalEmbeddings()
    (absolute_position_encoder): AbsolutePositionalEncoding(
      (dropout): Dropout(p=0.01, inplace=False)
    )
    (dna_encoder): MLPBlock(
      (input_norm): Identity()
      (input_adapter): Linear(in_features=1024, out_features=128, bias=True)
      (blocks): ModuleList(
        (0-2): 3 x Sequential(
          (0): RMSNorm((128,), eps=None, elementwise_affine=True)
          (1): Linear(in_features=128, out_features=512, bias=False)
          (2): SwiGLU()
          (3): Dropout(p=0.01, inplace=False)
          (4): Linear(in_features=256, out_features=128, bias=False)
        )
      )
      (output_norm): Identity()
      (output_adapter): Linear(in_features=128, out_features=128, bias=False)
    )
    (meth_encoder): MLPBlock(
      (input_norm): Identity()
      (input_adapter): Linear(in_features=1, out_features=128, bias=True)
      (blocks): ModuleList(
        (0): Sequential(
        

## 4. Prepare Data Objects

In order to perform inference, we need to prepare the data objects, which are essentially memory-mapped versions for faster loading. As an example, let's download a toy dataset from the _CpGCorpus_ database.

### 4.1 Download and Load Toy Data

In [10]:
inferencer.download_cpgcorpus_dataset("GSE182215")

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mDataset GSE182215 already exists at ../data/cpgcorpus/raw/GSE182215 (skipping download).[0m


In [12]:
df = pd.read_feather(ARROW_DF_PATH)

In [13]:
df

Unnamed: 0,GSM_ID,cg00000029,cg00000108,cg00000109,cg00000165,cg00000236,cg00000289,cg00000292,cg00000321,cg00000363,...,rs7746156,rs798149,rs845016,rs877309,rs9292570,rs9363764,rs939290,rs951295,rs966367,rs9839873
0,GSM5525203,0.324878,0.958127,0.853575,,0.854126,,0.791321,0.249139,0.291044,...,0.025733,0.462424,0.100629,0.979795,0.019506,0.947983,0.054746,0.495169,0.675477,0.782523
1,GSM5525204,0.298781,0.939107,0.897159,0.217733,0.874908,,0.828892,0.228412,0.397047,...,0.978182,0.017012,0.930166,0.568672,0.493534,0.042713,0.620214,0.511735,0.069878,0.941439
2,GSM5525205,,,,,0.472385,,,,,...,,,,,,,,,,
3,GSM5525206,0.125208,0.961126,0.822223,0.229362,0.861563,,0.877324,0.178019,0.377428,...,0.550734,0.444213,0.64401,0.56341,0.460727,0.961204,0.051867,0.487339,0.126543,0.899515
4,GSM5525207,0.278861,0.970059,0.929905,0.171255,0.907603,0.820531,0.893471,0.185116,0.377647,...,0.495878,0.020168,0.483792,0.021345,0.015444,0.968064,0.551504,0.970979,0.062037,0.764114
5,GSM5525208,0.094985,0.954513,,,,,0.804718,0.285387,0.261808,...,0.031089,0.020395,0.858414,0.018124,0.978014,0.504218,0.043601,0.528654,,0.880408
6,GSM5525209,0.158659,0.968949,0.914395,0.304291,0.892641,,0.865899,0.197152,0.388566,...,0.031904,0.979464,0.057517,0.978749,0.971558,0.563094,0.972008,0.513547,0.588638,0.763504
7,GSM5525210,0.428829,0.962406,0.853895,,0.867415,,0.822647,0.233513,0.347846,...,0.564834,0.020484,0.564318,0.977987,0.975255,0.585108,0.620957,0.553581,,0.782224
8,GSM5525211,0.353596,0.95313,,,0.841879,,0.916654,0.247867,0.242172,...,0.034661,0.022217,,0.024152,0.435065,0.502818,0.561926,0.574419,0.893579,0.881452
9,GSM5525212,0.254832,0.962125,0.913793,0.146354,0.859664,,0.85245,0.255339,0.336289,...,0.979984,0.983415,0.947795,0.583705,0.488332,0.534642,0.039321,0.521485,0.054215,0.791264


There is no need to impute the methylation data for CpGPT -- it simply ignores the missing values.
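
Even without imputation, it can still be useful to know how sparse each sample is (in the table above, one sample is almost entirely missing). The snippet below is a minimal sketch in plain Python with illustrative values; `missing_fraction` is a hypothetical helper. On the actual DataFrame, `df.isna().mean(axis=1)` gives the same per-sample fraction.

```python
import math


def missing_fraction(values):
    """Fraction of entries that are None or NaN in one sample's beta values."""
    if not values:
        return 0.0
    n_missing = sum(1 for v in values if v is None or math.isnan(v))
    return n_missing / len(values)


# Toy rows loosely modeled on the table above (values are illustrative).
samples = {
    "sample_a": [0.32, 0.96, 0.85, math.nan],
    "sample_b": [math.nan, math.nan, math.nan, 0.47],
}

for name, betas in samples.items():
    print(f"{name}: {missing_fraction(betas):.0%} missing")
```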

In [7]:
df = pd.read_feather(ARROW_DF_PATH)
df.set_index('GSM_ID', inplace=True)
df.head()

GSM_ID,cg00000029,cg00000108,cg00000109,cg00000165,cg00000236,cg00000289,cg00000292,cg00000321,cg00000363,cg00000622,...,rs7746156,rs798149,rs845016,rs877309,rs9292570,rs9363764,rs939290,rs951295,rs966367,rs9839873
GSM5525203,0.324878,0.958127,0.853575,,0.854126,,0.791321,0.249139,0.291044,0.013388,...,0.025733,0.462424,0.100629,0.979795,0.019506,0.947983,0.054746,0.495169,0.675477,0.782523
GSM5525204,0.298781,0.939107,0.897159,0.217733,0.874908,,0.828892,0.228412,0.397047,0.013511,...,0.978182,0.017012,0.930166,0.568672,0.493534,0.042713,0.620214,0.511735,0.069878,0.941439
GSM5525205,,,,,0.472385,,,,,,...,,,,,,,,,,
GSM5525206,0.125208,0.961126,0.822223,0.229362,0.861563,,0.877324,0.178019,0.377428,0.017229,...,0.550734,0.444213,0.64401,0.56341,0.460727,0.961204,0.051867,0.487339,0.126543,0.899515
GSM5525207,0.278861,0.970059,0.929905,0.171255,0.907603,0.820531,0.893471,0.185116,0.377647,0.012591,...,0.495878,0.020168,0.483792,0.021345,0.015444,0.968064,0.551504,0.970979,0.062037,0.764114


In [8]:
df.shape

(38, 485578)

### 4.2 Filter Vocab Features and Save Data

While not strictly required, filtering for the features used in finetuning gives you the best chance of achieving good performance.

In [9]:
# Load the model vocabulary (input features and output names)
with open(MODEL_VOCAB_PATH, "r") as f:
    vocab = json.load(f)

In [10]:
vocab

{'input': ['cg00000292',
  'cg00002426',
  'cg00003994',
  'cg00005847',
  'cg00008493',
  'cg00009407',
  'cg00011459',
  'cg00012199',
  'cg00012386',
  'cg00012792',
  'cg00013618',
  'cg00014085',
  'cg00014837',
  'cg00015770',
  'cg00019495',
  'cg00020533',
  'cg00021527',
  'cg00022866',
  'cg00024396',
  'cg00024812',
  'cg00025991',
  'cg00027083',
  'cg00027674',
  'cg00029826',
  'cg00031162',
  'cg00032227',
  'cg00033516',
  'cg00033773',
  'cg00034039',
  'cg00035347',
  'cg00035623',
  'cg00037763',
  'cg00037940',
  'cg00040861',
  'cg00040873',
  'cg00043004',
  'cg00043080',
  'cg00044245',
  'cg00047050',
  'cg00050312',
  'cg00051979',
  'cg00054706',
  'cg00056767',
  'cg00057593',
  'cg00058938',
  'cg00059424',
  'cg00059930',
  'cg00060762',
  'cg00061059',
  'cg00062776',
  'cg00063144',
  'cg00065385',
  'cg00065408',
  'cg00066816',
  'cg00067471',
  'cg00069261',
  'cg00071250',
  'cg00072216',
  'cg00075967',
  'cg00076645',
  'cg00077877',
  'cg00078194',

In [15]:
type(vocab.keys())

dict_keys

In [16]:
vocab.keys()

dict_keys(['input', 'output'])

In [17]:
vocab['input']

['cg00000292',
 'cg00002426',
 'cg00003994',
 'cg00005847',
 'cg00008493',
 'cg00009407',
 'cg00011459',
 'cg00012199',
 'cg00012386',
 'cg00012792',
 'cg00013618',
 'cg00014085',
 'cg00014837',
 'cg00015770',
 'cg00019495',
 'cg00020533',
 'cg00021527',
 'cg00022866',
 'cg00024396',
 'cg00024812',
 'cg00025991',
 'cg00027083',
 'cg00027674',
 'cg00029826',
 'cg00031162',
 'cg00032227',
 'cg00033516',
 'cg00033773',
 'cg00034039',
 'cg00035347',
 'cg00035623',
 'cg00037763',
 'cg00037940',
 'cg00040861',
 'cg00040873',
 'cg00043004',
 'cg00043080',
 'cg00044245',
 'cg00047050',
 'cg00050312',
 'cg00051979',
 'cg00054706',
 'cg00056767',
 'cg00057593',
 'cg00058938',
 'cg00059424',
 'cg00059930',
 'cg00060762',
 'cg00061059',
 'cg00062776',
 'cg00063144',
 'cg00065385',
 'cg00065408',
 'cg00066816',
 'cg00067471',
 'cg00069261',
 'cg00071250',
 'cg00072216',
 'cg00075967',
 'cg00076645',
 'cg00077877',
 'cg00078194',
 'cg00079056',
 'cg00079563',
 'cg00080012',
 'cg00081935',
 'cg000846

In [19]:
len(vocab['input'])

19948

In [18]:
vocab['output']

['cpgpt_cancer_logit']

In [11]:
df = df.loc[:, df.columns.isin(vocab['input'])]
df.head()

GSM_ID,cg00000292,cg00002426,cg00003994,cg00005847,cg00008493,cg00009407,cg00011459,cg00012199,cg00012386,cg00012792,...,cg27650175,cg27650434,cg27652350,cg27653134,cg27654142,cg27655905,cg27657283,cg27662379,cg27662877,cg27665659
GSM5525203,0.791321,0.905377,0.091164,0.090651,0.951774,0.048334,0.93864,0.035517,0.056877,0.08757,...,0.045768,0.072923,0.132974,0.94982,0.065028,0.063921,0.052416,0.07748,0.04327,0.05759
GSM5525204,0.828892,0.96462,0.042569,0.125086,0.945982,0.052446,0.951808,0.05204,0.063641,0.102222,...,0.053291,0.089544,0.130468,0.959386,0.066889,0.055794,0.044713,0.069569,0.039361,0.070515
GSM5525205,,,,,,,,,0.05828,,...,,,,,,,,,,
GSM5525206,0.877324,0.920315,0.042593,0.201401,0.951397,0.058045,0.947452,0.051123,0.060039,0.091149,...,0.046296,0.112071,0.096274,0.927876,0.088435,0.056817,0.050824,0.068512,0.070187,0.063018
GSM5525207,0.893471,0.943485,0.039004,0.139082,0.952827,0.045738,0.950823,0.036847,0.041898,0.082241,...,0.047588,0.105776,0.114468,0.940735,0.057654,0.04563,0.036157,0.05143,0.041078,0.082783


In [21]:
df.to_feather(ARROW_DF_FILTERED_PATH)
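
It can help to check how many of the model's input features a dataset actually contains, since a dataset from a different array may cover only part of the vocabulary. The snippet below is a minimal sketch with illustrative probe lists; `vocab_coverage` is a hypothetical helper. In the notebook, the inputs would be the unfiltered `df.columns` and `vocab['input']`.

```python
def vocab_coverage(columns, vocab_features):
    """Return (fraction of vocab features present, sorted list of absent ones)."""
    present = set(columns)
    absent = sorted(f for f in vocab_features if f not in present)
    if not vocab_features:
        return 0.0, absent
    return 1 - len(absent) / len(vocab_features), absent


# Illustrative inputs, not the real dataset or vocabulary.
columns = ["cg00000029", "cg00000292", "cg00002426"]
vocab_features = ["cg00000292", "cg00002426", "cg00003994", "cg00005847"]

coverage, absent = vocab_coverage(columns, vocab_features)
print(f"{coverage:.0%} of vocab features present; absent: {absent}")
```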

### 4.3 Memory-Map Data

In order to perform inference, we need to memory-map the data. This is done by using the `CpGPTDataSaver` class. We first need to define the `DNALLMEmbedder` and `IlluminaMethylationProber` classes, which contain the information about the DNA LLM embeddings and the conversion from Illumina array probes to genomic locations, respectively.

In [22]:
embedder = DNALLMEmbedder(dependencies_dir=LLM_DEPENDENCIES_DIR)

[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mInitializing class DNALLMEmbedder.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mGenome files will be stored under ../dependencies/human/genomes.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mDNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mEnsembl metadata dictionary loaded successfully[0m


In [23]:
prober = IlluminaMethylationProber(dependencies_dir=LLM_DEPENDENCIES_DIR, embedder=embedder)

[1m[34mcpgpt[0m[1m[0m: [36mIlluminaMethylationProber[0m: [1mInitializing class IlluminaMethylationProber.[0m
[1m[34mcpgpt[0m[1m[0m: [36mIlluminaMethylationProber[0m: [1mIllumina methylation manifest files will be stored under ../dependencies/human/manifests.[0m


[1m[34mcpgpt[0m[1m[0m: [36mIlluminaMethylationProber[0m: [1mIllumina metadata dictionary loaded successfully.[0m


In [24]:
# Define datasaver
quick_setup_datasaver = CpGPTDataSaver(data_paths=ARROW_DF_FILTERED_PATH, processed_dir=PROCESSED_DIR)

# Process the file
quick_setup_datasaver.process_files(prober, embedder)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mInitializing class CpGPTDataSaver.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mDataset folders will be stored under ../data/tutorials/processed/quick_setup.[0m


[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mNo existing dataset metrics found. Please process files.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mNo existing genomic locations found. Please process files.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mStarting file processing.[0m


Output()

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [33m[1mNo species column found. Defaulting to homo_sapiens.[0m


[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mFile processing completed.[0m


### 4.4 Declare data module

Let's define two data modules: one for the forward pass and methylation reconstruction, and another for extracting the attention weights.

In [25]:
# Define datamodule
quick_setup_datamodule = CpGPTDataModule(
    predict_dir=PROCESSED_DIR,
    dependencies_dir=LLM_DEPENDENCIES_DIR,
    batch_size=1,
    num_workers=0,
    max_length=MAX_INPUT_LENGTH,
    dna_llm=config.data.dna_llm,
    dna_context_len=config.data.dna_context_len,
    sorting_strategy=config.data.sorting_strategy,
    pin_memory=False,
)

# Define datamodule for the attention weights (shorter context)
quick_setup_datamodule_attn = CpGPTDataModule(
    predict_dir=PROCESSED_DIR,
    dependencies_dir=LLM_DEPENDENCIES_DIR,
    batch_size=1,
    num_workers=0,
    max_length=MAX_ATTN_LENGTH,
    dna_llm=config.data.dna_llm,
    dna_context_len=config.data.dna_context_len,
    sorting_strategy=config.data.sorting_strategy,
    pin_memory=False,
)

[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mInitializing class DNALLMEmbedder.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mGenome files will be stored under ../dependencies/human/genomes.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mDNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.[0m


[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mEnsembl metadata dictionary loaded successfully[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mInitializing class DNALLMEmbedder.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mGenome files will be stored under ../dependencies/human/genomes.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mDNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mEnsembl metadata dictionary loaded successfully[0m


## 5. Run Inference

There are several ways to perform inference with CpGPT. Here, we'll go through the most common ones.

### 5.1 Declare Trainer

Given that all models were trained under mixed precision, we'll use the `precision="16-mixed"` argument. If you fine-tune a model using a different precision, change this argument accordingly.

In [26]:
trainer = CpGPTTrainer(precision="16-mixed")

/grand/GeomicVar/tarak/cpgpt/cpgpt_env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
Using bfloat16 Automatic Mixed Precision (AMP)
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
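
As the log above shows, Lightning replaces `16-mixed` with `bf16-mixed` when no GPU is available. If you prefer to make that choice explicit, something like the following would work. It is a minimal sketch; `choose_precision` is a hypothetical helper, and it assumes `CpGPTTrainer` forwards the `precision` argument to Lightning unchanged.

```python
def choose_precision(gpu_available):
    """Pick a Lightning precision string matching the available hardware."""
    # fp16 AMP needs a GPU; on CPU, Lightning falls back to bf16 (see the log above).
    return "16-mixed" if gpu_available else "bf16-mixed"


print(choose_precision(gpu_available=False))
```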


### 5.2 Get Sample Embeddings

In [27]:
quick_setup_sample_embeddings = trainer.predict(
    model=model,
    datamodule=quick_setup_datamodule,
    predict_mode="forward",
    return_keys=["sample_embedding"]
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mInitializing class CpGPTDataset.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mLoaded existing dataset metrics.[0m


Output()

/grand/GeomicVar/tarak/cpgpt/cpgpt_env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


In [28]:
quick_setup_sample_embeddings

{'sample_embedding': tensor([[-2.6272e-01, -7.3696e-02, -2.6239e-02,  ..., -9.3649e-02,
          -3.6611e-02, -1.8364e-02],
         [-3.0251e-01, -7.7571e-02, -5.5078e-02,  ..., -1.1866e-01,
          -5.1626e-02, -2.7392e-03],
         [-5.5322e-02,  1.3166e-03, -7.8424e-02,  ..., -3.0264e-02,
          -1.0671e-02, -1.4953e-01],
         ...,
         [-2.0327e-01, -5.4256e-02, -7.8379e-02,  ..., -9.7128e-02,
           3.3823e-02, -4.2534e-02],
         [-2.3025e-01, -7.0735e-02, -6.5348e-02,  ..., -1.0442e-01,
           1.2583e-02, -2.7886e-02],
         [-2.0669e-01, -5.3605e-02, -1.0745e-01,  ..., -9.1940e-02,
          -2.3419e-04, -7.1221e-02]])}

In [30]:
quick_setup_sample_embeddings.keys()

dict_keys(['sample_embedding'])

In [31]:
len(quick_setup_sample_embeddings['sample_embedding'])

38

### 5.3 Predict Phenotypes

In [20]:
quick_setup_pred_conditions = trainer.predict(
    model=model,
    datamodule=quick_setup_datamodule,
    predict_mode="forward",
    return_keys=["pred_conditions"]
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mInitializing class CpGPTDataset.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mLoaded existing dataset metrics.[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [21]:
quick_setup_pred_conditions

{'pred_conditions': tensor([[ 0.3667],
         [ 0.0154],
         [ 6.9766],
         [ 0.4402],
         [-0.6748],
         [-0.2629],
         [-0.7241],
         [-2.7090],
         [-0.1346],
         [-0.4250],
         [-0.0770],
         [ 0.1849],
         [-0.0803],
         [ 0.4863],
         [ 2.5410],
         [-1.4512],
         [-3.1914],
         [ 3.1328],
         [-0.7524],
         [-1.4854],
         [-1.8594],
         [-1.4404],
         [-2.0391],
         [-1.4297],
         [-2.4863],
         [-2.0703],
         [-2.4922],
         [-2.4121],
         [-2.3438],
         [-1.7637],
         [-1.4941],
         [-2.4941],
         [-0.4998],
         [-0.5352],
         [-1.9775],
         [-3.3359],
         [-1.0107],
         [-0.5669]], dtype=torch.float16)}
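
The vocabulary names this output `cpgpt_cancer_logit`, which suggests the raw predictions are logits. Assuming a binary target, the logistic sigmoid maps them to probabilities. This is a minimal sketch under that assumption; `logit_to_probability` is a hypothetical helper, not a CpGPT function. On the tensor itself, `torch.sigmoid` would do the same.

```python
import math


def logit_to_probability(logit):
    """Numerically stable logistic sigmoid."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


# A few values copied from the output above.
for logit in [0.3667, 6.9766, -2.7090]:
    print(f"logit {logit:+.4f} -> probability {logit_to_probability(logit):.3f}")
```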

### 5.4 Reconstruct Methylation

As an example, let's get the reconstructed methylation values for some locations of interest based on the Illumina probes.

In [22]:
# First 100 probes for demonstration
probes = list(df.columns[0:100])

probes[0:5]

['cg00000292', 'cg00002426', 'cg00003994', 'cg00005847', 'cg00008493']

In [23]:
# Convert probes to genomic locations
genomic_locations = prober.locate_probes(probes, "homo_sapiens")

genomic_locations[0:5]

['16:28878778', '3:57757815', '7:15686236', '2:176164344', '14:93347430']

In [24]:
quick_setup_pred_meth = trainer.predict(
    model=model,
    datamodule=quick_setup_datamodule,
    predict_mode="reconstruct",
    genomic_locations=genomic_locations,
    species="homo_sapiens",
    return_keys=["pred_meth"],
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mInitializing class CpGPTDataset.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mLoaded existing dataset metrics.[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

Be mindful as the reconstructed values are M values, not beta values. Therefore, you need to convert them to beta values using the `m_to_beta` function.

In [25]:
quick_setup_pred_meth["pred_meth"] = m_to_beta(quick_setup_pred_meth["pred_meth"])
quick_setup_pred_meth

{'pred_meth': tensor([[0.8501, 0.8633, 0.0503,  ..., 0.0348, 0.9292, 0.7095],
         [0.8706, 0.8833, 0.0492,  ..., 0.0340, 0.9419, 0.7275],
         [0.3799, 0.4067, 0.2776,  ..., 0.3308, 0.4133, 0.2937],
         ...,
         [0.7925, 0.7881, 0.0529,  ..., 0.0367, 0.9351, 0.7075],
         [0.8247, 0.8291, 0.0523,  ..., 0.0337, 0.9351, 0.7031],
         [0.6494, 0.4927, 0.0576,  ..., 0.0337, 0.9385, 0.7085]],
        dtype=torch.float16)}
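
For reference, M values and beta values are conventionally related by M = log2(beta / (1 - beta)), so beta = 2^M / (2^M + 1). The scalar sketch below illustrates that standard relationship; it is assumed, not verified, that `m_to_beta` computes the same thing, and the `_scalar` helpers are hypothetical names.

```python
import math


def m_to_beta_scalar(m):
    """Convert an M value to a beta value: beta = 2**m / (2**m + 1)."""
    return 1.0 / (1.0 + 2.0 ** (-m))


def beta_to_m_scalar(beta):
    """Inverse mapping: M = log2(beta / (1 - beta)), for 0 < beta < 1."""
    return math.log2(beta / (1.0 - beta))


for m in [-4.0, 0.0, 4.0]:
    print(f"M = {m:+.1f} -> beta = {m_to_beta_scalar(m):.4f}")
```

An M value of 0 corresponds to a beta of 0.5, and the mapping is symmetric around that point.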

A more powerful way of reconstructing the methylation values is using chain-of-thought. With additional test-time compute, we can let the model "think harder" about the problem, which can lead to better performance. However, it also takes considerably longer depending on the number of thinking steps.

In [26]:
quick_setup_pred_meth_cot = trainer.predict(
    model=model,
    datamodule=quick_setup_datamodule,
    predict_mode="reconstruct",
    genomic_locations=genomic_locations,
    species="homo_sapiens",
    n_thinking_steps=5,
    thinking_step_size=1000,
    uncertainty_quantile=0.1,
    return_keys=["pred_meth"],
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mInitializing class CpGPTDataset.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mLoaded existing dataset metrics.[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [27]:
quick_setup_pred_meth_cot["pred_meth"] = m_to_beta(quick_setup_pred_meth_cot["pred_meth"])
quick_setup_pred_meth_cot

{'pred_meth': tensor([[0.8516, 0.8486, 0.0482,  ..., 0.0344, 0.9248, 0.7080],
         [0.8691, 0.8696, 0.0484,  ..., 0.0330, 0.9434, 0.7383],
         [0.4485, 0.4104, 0.3074,  ..., 0.3823, 0.3540, 0.2976],
         ...,
         [0.7856, 0.7690, 0.0512,  ..., 0.0355, 0.9326, 0.6919],
         [0.8271, 0.8198, 0.0520,  ..., 0.0332, 0.9331, 0.6934],
         [0.6523, 0.4868, 0.0560,  ..., 0.0331, 0.9297, 0.7061]],
        dtype=torch.float16)}

### 5.5 Analyze Attention Weights

The amount of memory required to store the attention weights is enormous. Therefore, we only use 1000 features for the demonstration. Also, remember that the first token is the CLS token.
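
To get a feel for the scale, attention memory grows with the square of the sequence length. The estimate below is a rough sketch under stated assumptions: one head-aggregated matrix per sample for a single layer, 2 bytes per value (float16), and one extra position for the CLS token. `attention_bytes` is a hypothetical helper.

```python
def attention_bytes(n_samples, n_features, bytes_per_value=2, cls_token=True):
    """Bytes needed for one head-aggregated attention matrix per sample."""
    seq_len = n_features + (1 if cls_token else 0)
    return n_samples * seq_len * seq_len * bytes_per_value


for n_features in (1_000, 20_000):
    gib = attention_bytes(38, n_features) / 1024**3
    print(f"{n_features:>6} features: {gib:.2f} GiB for 38 samples")
```

Going from 1,000 to 20,000 features multiplies the footprint by roughly 400, which is why a separate data module with a shorter context is used here.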

In [28]:
quick_setup_attn = trainer.predict(
    model=model,
    datamodule=quick_setup_datamodule_attn,
    predict_mode="attention",
    aggregate_heads="mean",
    layer_index=-1,
    return_keys=["attention_weights", "chroms", "positions", "mask_na", "meth"],
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mInitializing class CpGPTDataset.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mLoaded existing dataset metrics.[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [29]:
quick_setup_attn

{'attention_weights': tensor([[[0.0011, 0.0010, 0.0009,  ...,    nan,    nan,    nan],
          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],
          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],
          ...,
          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
          [   nan,    nan,    nan,  ...,    nan,    nan,    nan]],
 
         [[0.0011, 0.0010, 0.0010,  ...,    nan,    nan,    nan],
          [0.0010, 0.0011, 0.0010,  ...,    nan,    nan,    nan],
          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],
          ...,
          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
          [   nan,    nan,    nan,  ...,    nan,    nan,    nan]],
 
         [[0.0121, 0.0115, 0.0097,  ...,    nan,    nan,    nan],
          [0.0114, 0.0123, 0.0098,  ...,    nan,    nan,    nan],
          [0.0105, 
