# 🧬 ⚙️ CpGPT Quick Setup Tutorial ⚙️ 🧬

Welcome to the CpGPT Quick Setup Tutorial! 👋 

In this notebook, we'll walk you through the fastest way of using CpGPT for your research.

## Table of Contents

1. [Setup Environment](#1-setup-environment)
2. [Retrieve DNA LLM Embeddings](#2-retrieve-dna-llm-embeddings)
3. [Download and Load Model](#3-download-and-load-model)
4. [Prepare Data Objects](#4-prepare-data-objects)
5. [Run Inference](#5-run-inference)

## 1. Setup Environment

We'll import the necessary Python packages and set up our environment for CpGPT. We'll be using a mix of standard data science libraries and CpGPT-specific modules. We'll also set some important variables that will be used throughout the notebook. Pay attention to these as you may need to adjust them based on your specific setup and requirements.

In [18]:
# Random seed for reproducibility
RANDOM_SEED = 42

# Directory paths
DEPENDENCIES_DIR = "../dependencies"
LLM_DEPENDENCIES_DIR = DEPENDENCIES_DIR + "/human"
DATA_DIR = "../data"
PROCESSED_DIR = "../data/tutorials/processed/fhs_setup"

# Model to load; swap in another model from the zoo (e.g. "cancer") as needed
MODEL_NAME = "age"
MODEL_CHECKPOINT_PATH = f"{DEPENDENCIES_DIR}/model/weights/{MODEL_NAME}.ckpt"
MODEL_CONFIG_PATH = f"{DEPENDENCIES_DIR}/model/config/{MODEL_NAME}.yaml"
MODEL_VOCAB_PATH = f"{DEPENDENCIES_DIR}/model/vocab/{MODEL_NAME}.json"

# ARROW_DF_PATH = "../data/cpgcorpus/raw/GSE182215/GPL13534/betas/QCDPB.arrow"
ARROW_DF_FILTERED_PATH = "../data/tutorials/raw/fhs_filtered.arrow"

# The maximum context length to give to the model
MAX_INPUT_LENGTH = 20_000  # you may want to go higher if your hardware permits
MAX_ATTN_LENGTH = 1_000

> **⚠️ Warning**
> 
> A GPU is recommended for inference, since running on a CPU can be slow.
> 
> Reconstructing the methylome for a few hundred samples might take up to one hour on a CPU. ⌛
>
> This might be a great exercise in testing your patience.
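The inferencer reports the device it picked (the `Using device: cpu.` log line in section 2). To check ahead of a long run whether this kernel can see a GPU at all, a minimal sketch like the following can help. It assumes PyTorch is the backend (Lightning is imported below) and is not part of the CpGPT API.

```python
def pick_device() -> str:
    """Return "cuda" if PyTorch can see a CUDA GPU, otherwise "cpu"."""
    try:
        import torch  # imported lazily so the check also runs where PyTorch is missing
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


print(f"Inference would run on: {pick_device()}")
```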

### 1.1 Import Packages


In [19]:
# Standard library imports
import warnings
import os
import json

warnings.simplefilter(action="ignore", category=FutureWarning)

# Data analysis and plotting imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyaging as pya
import seaborn as sns

# Lightning imports
from lightning.pytorch import seed_everything

# cpgpt-specific imports
from cpgpt.data.components.cpgpt_datasaver import CpGPTDataSaver
from cpgpt.data.cpgpt_datamodule import CpGPTDataModule
from cpgpt.trainer.cpgpt_trainer import CpGPTTrainer
from cpgpt.data.components.dna_llm_embedder import DNALLMEmbedder
from cpgpt.data.components.illumina_methylation_prober import IlluminaMethylationProber
from cpgpt.infer.cpgpt_inferencer import CpGPTInferencer
from cpgpt.model.cpgpt_module import m_to_beta

# Set random seed for reproducibility
seed_everything(RANDOM_SEED, workers=True)

Seed set to 42


42

## 2. Retrieve DNA LLM Embeddings

To retrieve the DNA LLM Embeddings, there are two options:
- download the dependencies with all of the sequence embeddings for the CpG sites targeted by the Illumina arrays;
- generate from scratch using the DNA LLM directly for loci outside of the ones already available for download.

In [20]:
# First let's declare the inferencer
inferencer = CpGPTInferencer(dependencies_dir=DEPENDENCIES_DIR, data_dir=DATA_DIR)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mInitializing class CpGPTInferencer.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mUsing device: cpu.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mUsing dependencies directory: ../dependencies[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mUsing data directory: ../data[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mThere are 19 CpGPT models available such as age, age_cot, average_adultweight, boa, cancer, clock_proxies, diseases, epicvmammal, hannum, hannum_cot, human_rrbs_atlas, large, mammalian, maximum_lifespan, mortality, proteins, relative_age, scimetv3, small, etc.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mThere are 2089 GSE datasets available such as GSE100184, GSE100208, GSE100209, etc.[0m


In [21]:
inferencer

<cpgpt.infer.cpgpt_inferencer.CpGPTInferencer at 0x7f503271d9d0>

### 2.1 Download Dependencies


The pre-processed dependencies contain the sequence embeddings for both human (`s3://cpgpt-lucascamillo-public/dependencies/human`) and several mammalian species (`s3://cpgpt-lucascamillo-public/dependencies/mammalian`). Here, we use the human dependencies as an example:

In [22]:
inferencer.download_dependencies(species="human")

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mDependencies for human already exist at ../dependencies/human (skipping download).[0m


### 2.2 Generate DNA LLM Embeddings


To generate genomic embeddings for loci beyond those available for download, use the `DNALLMEmbedder` class. The loci must be given as a list of strings in the Ensembl 'chromosome:position' format (e.g. '1:100000'). Be mindful that this can take a long time depending on your GPU: for instance, embedding the ~1M genomic loci from the Illumina arrays takes about 12 hours on an RTX 4090.

In [23]:
if not os.path.exists(LLM_DEPENDENCIES_DIR):

    # List CpG genomic locations
    example_genomic_locations = ['1:100000', '1:250500', 'X:2031253']

    # Declare required class
    embedder = DNALLMEmbedder(dependencies_dir=LLM_DEPENDENCIES_DIR)

    # Parse the embeddings
    embedder.parse_dna_embeddings(
        example_genomic_locations,
        "homo_sapiens",
        dna_llm="nucleotide-transformer-v2-500m-multi-species",
        dna_context_len=2001,
    )
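The embedder expects Ensembl-style chromosome names ('1', 'X'), whereas many annotation files use UCSC-style names ('chr1', 'chrX'). The helper below, `to_ensembl_loci`, is a hypothetical convenience and not part of CpGPT: it builds the 'chromosome:position' strings and rejects anything that does not look like a human autosome or sex chromosome (restricting to those names is an assumption of this sketch).

```python
import re

# Assumed human chromosome names in Ensembl style: 1-22, X, Y, no "chr" prefix
LOCUS_PATTERN = re.compile(r"^(?:[1-9]|1[0-9]|2[0-2]|X|Y):\d+$")


def to_ensembl_loci(chromosomes, positions):
    """Build 'chromosome:position' strings, stripping any UCSC-style 'chr' prefix."""
    loci = []
    for chrom, pos in zip(chromosomes, positions):
        chrom = str(chrom)
        if chrom.startswith("chr"):
            chrom = chrom[3:]
        locus = f"{chrom}:{int(pos)}"
        if not LOCUS_PATTERN.match(locus):
            raise ValueError(f"Unexpected locus format: {locus!r}")
        loci.append(locus)
    return loci


# Reproduces the example list used above from mixed-style inputs
print(to_ensembl_loci(["chr1", "1", "chrX"], [100000, 250500, 2031253]))
```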

## 3. Download and Load Model

First, check the model zoo in the README.md file for the available models and their features. To load a model, download its checkpoint and configuration files, load the configuration (which holds the hyperparameters), and then instantiate the model with the `CpGPTInferencer` class.

### 3.1 Download Checkpoint and Configuration Files

In [24]:
# Download the checkpoint and configuration files
inferencer.download_model(MODEL_NAME)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mModel checkpoint already exists at ../dependencies/model/weights/age.ckpt (skipping download).[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mModel config already exists at ../dependencies/model/config/age.yaml (skipping download).[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mModel vocabulary already exists at ../dependencies/model/vocab/age.json (skipping download).[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mSuccessfully downloaded model 'age'.[0m


### 3.2 Load Model

In [25]:
# Load the model configuration
config = inferencer.load_cpgpt_config(MODEL_CONFIG_PATH)

# Load the model weights
model = inferencer.load_cpgpt_model(
    config,
    model_ckpt_path=MODEL_CHECKPOINT_PATH,
    strict_load=True,
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mLoaded CpGPT model config.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mInstantiated CpGPT model from config.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mUsing device: cpu.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mLoading checkpoint from: ../dependencies/model/weights/age.ckpt[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTInferencer[0m: [1mCheckpoint loaded into the model.[0m


In [26]:
model

CpGPTLitModule(
  (net): CpGPT(
    (position_encoder): RotaryPositionalEmbeddings()
    (absolute_position_encoder): AbsolutePositionalEncoding(
      (dropout): Dropout(p=0.01, inplace=False)
    )
    (dna_encoder): MLPBlock(
      (input_norm): Identity()
      (input_adapter): Linear(in_features=1024, out_features=128, bias=True)
      (blocks): ModuleList(
        (0-2): 3 x Sequential(
          (0): RMSNorm((128,), eps=None, elementwise_affine=True)
          (1): Linear(in_features=128, out_features=512, bias=False)
          (2): SwiGLU()
          (3): Dropout(p=0.01, inplace=False)
          (4): Linear(in_features=256, out_features=128, bias=False)
        )
      )
      (output_norm): Identity()
      (output_adapter): Linear(in_features=128, out_features=128, bias=False)
    )
    (meth_encoder): MLPBlock(
      (input_norm): Identity()
      (input_adapter): Linear(in_features=1, out_features=128, bias=True)
      (blocks): ModuleList(
        (0): Sequential(
        

## 4. Prepare Data Objects

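To perform inference, we first need to prepare the data objects, which are memory-mapped versions of the data for faster loading. Here we use the FHS (Framingham Heart Study) methylation data; alternatively, you can download a toy dataset from the _CpGCorpus_ database (see the commented-out `download_cpgcorpus_dataset` cell below).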
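As a rough illustration of what memory-mapping means, the generic NumPy sketch below writes a toy beta-value matrix to disk once and then reopens it so that rows are read from disk only when accessed. It is not CpGPT's on-disk format; in CpGPT the saving step is presumably handled by `CpGPTDataSaver` (imported in section 1), whose API is not shown in this part of the notebook.

```python
import os
import tempfile

import numpy as np

# Hypothetical toy matrix: 4 samples x 6 CpG sites of beta values, one value missing
betas = np.random.default_rng(42).uniform(0, 1, size=(4, 6)).astype(np.float32)
betas[1, 2] = np.nan

path = os.path.join(tempfile.mkdtemp(), "betas.dat")

# Write the matrix to disk once as a raw float32 memory map
out = np.memmap(path, dtype=np.float32, mode="w+", shape=betas.shape)
out[:] = betas
out.flush()
del out

# Reopen read-only: the OS pages data in from disk only when it is accessed.
# dtype and shape are not stored in the raw file, so they must be known here.
view = np.memmap(path, dtype=np.float32, mode="r", shape=betas.shape)
print(view[1])  # touches just one sample's row
```

The need to supply `dtype` and `shape` on reopening is why real storage layouts keep that metadata alongside the raw array.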

### 4.1 Load FHS Data

In [None]:
root_dir = "/grand/GeomicVar/tarak/cpgpt/CpGPT/data_kirmani"
data_dir = os.path.join(root_dir, "phg001091.v5.FHS_DNAMethylation.methylation-data-matrixfmt.c1")

# Load the parquet file
# The CSV file is too large to load into memory, so we will use a parquet file instead (converted from CSV using convert_csv_to_parquet.py)
parquet_file = os.path.join(data_dir, "gen3_methylation_c1.parquet")
df = pd.read_parquet(parquet_file)

In [28]:
# load the metadata from Yash
metadata_df = pd.read_csv("/grand/GeomicVar/tarak/methylGPT/data_kirmani/fhs_chip_metadata_yp_05092025.tsv", sep="\t")

In [29]:
metadata_df

Unnamed: 0,Sample,subject_id,AgeAtBloodDraw,sex,PC1,PC2,PC3,PC4,PC5,PC6,PC7,PC8,PC9,PC10,PC11,haschip,Gene,ExonicFunc,VAF,chip_binary
0,NWD101503,11465,63.503015,F,0.001995,-0.000976,-0.000891,-0.005666,-0.011618,-0.007526,0.000804,-0.003980,-0.002093,0.004659,-0.001234,1,DNMT3A,splicing,0.170,1.0
1,NWD122068,17510,71.623647,F,0.002355,-0.001631,-0.000852,0.001525,0.002503,-0.000561,0.000437,-0.000893,0.000159,0.001559,-0.000292,1,DNMT3A,nonsynonymous SNV,0.121,1.0
2,NWD125867,3253,56.318747,F,0.002322,-0.001669,-0.000809,0.000380,0.001179,0.000148,-0.000927,-0.000341,0.001978,-0.000106,0.000994,1,DNMT3A,frameshift insertion,0.094,1.0
3,NWD126946,5657,73.918013,F,0.002306,-0.001581,-0.000887,0.001144,0.001952,-0.001457,0.001603,-0.000156,-0.008303,-0.004131,0.000488,1,DNMT3A,frameshift deletion,0.161,1.0
4,NWD143985,9868,76.798292,F,0.002263,-0.001670,-0.000842,-0.000322,-0.000775,0.000953,-0.001219,0.000528,0.003866,-0.000707,0.001694,1,DNMT3A,splicing,0.174,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1939,NWD995511,22904,76.888642,F,0.002305,-0.001642,-0.000772,0.000911,0.001931,0.000218,-0.000435,-0.000511,0.001944,0.001706,-0.000293,0,,,,
1940,NWD995734,24905,64.064286,M,0.002358,-0.001689,-0.000774,0.001496,0.002349,-0.000174,-0.000610,0.000027,0.003265,0.003716,0.001137,0,,,,
1941,NWD998499,10703,64.028693,F,0.002350,-0.001710,-0.000661,0.001111,0.002059,0.000170,-0.000626,-0.000338,0.003814,0.004609,-0.000103,0,,,,
1942,NWD998833,21202,69.394991,M,0.002340,-0.001616,-0.000801,0.001095,0.002020,0.000033,-0.000615,-0.000132,0.003422,0.003995,0.000849,0,,,,


In [30]:
df

probe.id,3630,10226,22854,5641,13515,26098,4354,4892,8567,393,...,9791,24702,5833,1692,11354,5197,18255,2077,19268,1891
cg00455876,0.340616,0.656693,0.279148,0.363048,0.373903,0.341001,0.586655,0.472144,0.354304,0.406437,...,0.318181,0.574336,0.402250,0.542247,0.358957,0.407751,0.272390,0.430805,0.655632,0.391442
cg01707559,0.341025,0.128761,0.393184,0.331995,0.352611,0.371087,0.094050,0.317272,0.285407,0.301778,...,0.399217,0.087735,0.296028,0.104332,0.377702,0.334039,0.341866,0.333112,0.075540,0.355332
cg03244189,0.318389,0.107887,0.239077,0.223002,0.285023,0.264701,0.079138,0.345102,0.346678,0.288474,...,0.266094,0.102858,0.248124,0.151180,0.295379,0.364659,0.369742,0.369642,0.106677,0.292839
cg03695421,0.431295,0.707087,0.413813,0.387546,0.343865,,0.573828,0.435571,0.314586,0.372958,...,0.373995,0.757837,0.400924,0.761376,,0.386185,0.363767,0.432547,0.733518,0.411337
cg04689676,0.195909,0.068059,0.188884,0.200552,0.209481,0.218308,0.055082,0.301970,0.290017,0.348960,...,0.234856,0.081985,0.264703,0.049703,0.195470,0.342306,0.276520,0.297719,0.077993,0.252178
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
cg27553637,0.100750,0.092585,0.085468,0.085335,0.103806,0.071816,0.080868,0.147395,0.105346,0.084725,...,0.076533,0.086947,0.076649,0.081364,0.099946,0.077326,0.081617,0.123232,0.086865,0.090997
cg27575890,0.433371,0.380250,0.374776,0.381071,0.377906,0.337933,0.408589,0.427982,0.419254,0.376659,...,0.406881,0.333335,0.334248,0.480800,0.385733,0.395870,0.397042,0.408952,0.354646,0.417783
cg27585287,0.035770,0.043603,0.043326,0.044939,0.046019,0.038746,0.041655,0.055870,0.047275,0.062313,...,0.028468,0.045368,0.052548,0.068112,0.042444,0.043853,0.048230,0.048466,0.040517,0.054833
cg27592453,0.829706,0.826126,0.843467,0.798517,0.801194,0.823062,0.834816,0.830896,0.839264,0.835588,...,0.840031,0.874013,0.844692,0.820715,0.807588,0.864127,0.854933,0.855472,0.853569,0.787941


In [31]:
df = df.T  # transpose so samples are rows and CpG sites are columns
df.index.name = 'sample_id'
df.columns.name = None  # drop the leftover 'probe.id' axis name from the transpose

In [32]:
df.head()

sample_id,cg00455876,cg01707559,cg03244189,cg03695421,cg04689676,cg04792227,cg04964672,cg13851368,cg14180491,cg14210405,...,cg27532867,cg27534599,cg27536559,cg27545494,cg27552198,cg27553637,cg27575890,cg27585287,cg27592453,cg27598806
3630,0.340616,0.341025,0.318389,0.431295,0.195909,0.320329,0.700558,0.372461,0.519657,0.389533,...,0.052985,0.802725,0.178046,0.045268,0.840077,0.10075,0.433371,0.03577,0.829706,0.890113
10226,0.656693,0.128761,0.107887,0.707087,0.068059,0.19481,0.908372,0.82885,0.093967,0.400479,...,0.052807,0.793719,0.180809,0.037093,0.859426,0.092585,0.38025,0.043603,0.826126,0.879804
22854,0.279148,0.393184,0.239077,0.413813,0.188884,0.33169,0.58079,0.383639,0.52526,,...,0.043946,0.781412,0.18076,0.046037,0.876665,0.085468,0.374776,0.043326,0.843467,0.893957
5641,0.363048,0.331995,0.223002,0.387546,0.200552,0.35245,0.627094,0.340513,0.545357,0.349725,...,0.061374,0.76362,0.140669,0.047586,0.862237,0.085335,0.381071,0.044939,0.798517,0.894712
13515,0.373903,0.352611,0.285023,0.343865,0.209481,0.317307,0.654633,0.35459,0.478328,0.361801,...,0.059121,0.777402,0.140125,0.049228,0.854127,0.103806,0.377906,0.046019,0.801194,0.87672


In [33]:
# # Convert subject_id to string since sample_id is string type
# metadata_df['subject_id'] = metadata_df['subject_id'].astype(str)

# # Find common elements
# common_ids = set(df.index).intersection(set(metadata_df['subject_id']))
# print(f"Number of common IDs: {len(common_ids)}")

In [34]:
# common_ids

In [None]:
# Convert both identifiers to same type (string) for comparison
metadata_df['subject_id'] = metadata_df['subject_id'].astype(str)



In [36]:
df.index

Index(['3630', '10226', '22854', '5641', '13515', '26098', '4354', '4892',
       '8567', '393',
       ...
       '9791', '24702', '5833', '1692', '11354', '5197', '18255', '2077',
       '19268', '1891'],
      dtype='object', name='sample_id', length=1425)

In [37]:
metadata_df['subject_id']

0       11465
1       17510
2        3253
3        5657
4        9868
        ...  
1939    22904
1940    24905
1941    10703
1942    21202
1943     7114
Name: subject_id, Length: 1944, dtype: int64

In [38]:
# Get common IDs
common_ids = set(df.index.astype(str)).intersection(
    set(metadata_df['subject_id'].astype(str))
)
print(f"Number of common IDs: {len(common_ids)}")


Number of common IDs: 474
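The cast to `str` on both sides matters because `df.index` holds strings (`dtype='object'`) while `metadata_df['subject_id']` is `int64`, as the two outputs above show. A toy reproduction with made-up IDs:

```python
import pandas as pd

# Toy IDs: the methylation index holds strings, the metadata column holds integers
methylation_ids = pd.Index(["1", "2", "3"], name="sample_id")
metadata_ids = pd.Series([2, 3, 4], name="subject_id")

# '2' (str) and 2 (int) never compare equal, so the naive intersection is empty
naive = set(methylation_ids).intersection(set(metadata_ids))

# Casting both sides to str before intersecting recovers the shared IDs
matched = set(methylation_ids.astype(str)).intersection(set(metadata_ids.astype(str)))

print(len(naive), sorted(matched))  # 0 ['2', '3']
```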


In [42]:
common_ids

{'10001',
 '10006',
 '10008',
 '10022',
 '10026',
 '10057',
 '10073',
 '10077',
 '10082',
 '10121',
 '10274',
 '10358',
 '1037',
 '10391',
 '10393',
 '10439',
 '10496',
 '10534',
 '10699',
 '107',
 '10709',
 '10729',
 '10734',
 '10793',
 '10794',
 '10817',
 '10819',
 '1083',
 '10943',
 '11016',
 '11147',
 '11179',
 '11285',
 '11316',
 '11354',
 '11436',
 '11656',
 '117',
 '1170',
 '11800',
 '11853',
 '11945',
 '12076',
 '1210',
 '12264',
 '12276',
 '12436',
 '1246',
 '12535',
 '1265',
 '12707',
 '12814',
 '12848',
 '12912',
 '12989',
 '13004',
 '13025',
 '1308',
 '13094',
 '13192',
 '13227',
 '13272',
 '13299',
 '13425',
 '13454',
 '1348',
 '13515',
 '13519',
 '13618',
 '13647',
 '13780',
 '13794',
 '13823',
 '1387',
 '13998',
 '1423',
 '14251',
 '14303',
 '14341',
 '14405',
 '14417',
 '14419',
 '14467',
 '14533',
 '14554',
 '14594',
 '14614',
 '14640',
 '14674',
 '14676',
 '14678',
 '14767',
 '14829',
 '15050',
 '15085',
 '15120',
 '15184',
 '15437',
 '15458',
 '15543',
 '15570',
 '15

In [43]:
# Filter both dataframes to keep only common IDs
filtered_df = df[df.index.astype(str).isin(common_ids)]
filtered_metadata = metadata_df[metadata_df['subject_id'].astype(str).isin(common_ids)]

# Reset the metadata's row index (note: this does not reorder rows to match filtered_df)
filtered_metadata = filtered_metadata.reset_index(drop=True)

print("\nShape of filtered methylation data:", filtered_df.shape)
print("Shape of filtered metadata:", filtered_metadata.shape)


Shape of filtered methylation data: (474, 443206)
Shape of filtered metadata: (478, 20)


In [45]:
filtered_df

sample_id,cg00455876,cg01707559,cg03244189,cg03695421,cg04689676,cg04792227,cg04964672,cg13851368,cg14180491,cg14210405,...,cg27532867,cg27534599,cg27536559,cg27545494,cg27552198,cg27553637,cg27575890,cg27585287,cg27592453,cg27598806
13515,0.373903,0.352611,0.285023,0.343865,0.209481,0.317307,0.654633,0.354590,0.478328,0.361801,...,0.059121,0.777402,0.140125,0.049228,0.854127,0.103806,0.377906,0.046019,0.801194,0.876720
4892,0.472144,0.317272,0.345102,0.435571,0.301970,0.346044,0.686608,0.455047,0.517205,0.399917,...,0.061810,0.758235,0.141635,0.057578,0.838716,0.147395,0.427982,0.055870,0.830896,0.831732
24654,0.364213,0.373167,0.383677,0.371633,0.400758,0.417855,0.709025,0.408174,0.521777,0.438885,...,0.064607,0.765001,0.236049,0.049031,0.848972,0.097279,0.441205,0.053650,0.817943,0.848297
8276,0.645996,0.118774,0.109488,0.650313,0.073410,0.202634,0.907011,0.842083,0.103456,0.517147,...,0.053871,0.727399,0.181261,0.041696,0.846080,0.077593,0.353020,0.047327,0.868281,0.872010
23495,0.393172,0.351914,0.389313,0.432429,0.389742,0.377637,0.760646,0.465659,0.482143,0.486697,...,0.055474,0.791954,0.221617,0.032351,0.859742,0.078798,0.421793,0.055254,0.915611,0.863411
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7786,0.607018,0.097955,0.099898,0.712607,0.087751,0.176667,0.914319,0.820917,0.084572,0.525990,...,0.056571,0.800210,0.183646,0.046256,0.862467,0.079173,0.367596,0.049488,0.825740,0.884708
24702,0.574336,0.087735,0.102858,0.757837,0.081985,0.140933,0.939900,0.802047,0.075081,0.469135,...,0.045415,0.786400,0.170066,0.038829,0.881907,0.086947,0.333335,0.045368,0.874013,0.907831
11354,0.358957,0.377702,0.295379,,0.195470,0.293281,0.684217,,0.553361,,...,0.063208,0.775111,0.186350,0.055446,0.837358,0.099946,0.385733,0.042444,0.807588,0.869527
5197,0.407751,0.334039,0.364659,0.386185,0.342306,0.415866,0.753173,0.414093,0.482597,0.372852,...,0.044422,0.770602,0.199192,0.038365,0.851601,0.077326,0.395870,0.043853,0.864127,0.883756


In [44]:
filtered_metadata

Unnamed: 0,Sample,subject_id,AgeAtBloodDraw,sex,PC1,PC2,PC3,PC4,PC5,PC6,PC7,PC8,PC9,PC10,PC11,haschip,Gene,ExonicFunc,VAF,chip_binary
0,NWD290865,5360,36.709857,M,0.002307,-0.001598,-0.000748,0.000724,0.001564,0.000178,-0.000338,0.000049,0.000842,0.000176,0.000036,1,NF1,stopgain,0.204,1.0
1,NWD757156,10022,59.817792,F,0.002145,-0.000661,-0.001230,-0.000578,-0.002286,0.000284,0.000642,0.002438,0.001718,-0.002682,0.001535,1,DNMT3A,nonsynonymous SNV,0.118,1.0
2,NWD925538,26014,60.461200,F,0.002308,-0.001293,-0.000612,0.000759,0.001243,-0.001005,0.000830,0.000047,-0.004544,-0.001970,0.001056,1,DNMT3A,nonsynonymous SNV,0.289,1.0
3,NWD100436,7420,37.320410,F,0.002331,-0.001710,-0.000732,0.001101,0.001999,0.000338,-0.000963,-0.000289,0.002154,0.002027,-0.000200,0,,,,
4,NWD102395,5048,54.355668,F,0.002153,-0.001422,-0.000712,-0.002858,-0.005667,0.000936,-0.000473,0.003260,0.004069,-0.008498,0.004219,0,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
473,NWD983190,2482,52.132487,M,0.002320,-0.001663,-0.000785,0.000941,0.001791,0.000223,-0.000184,-0.000754,0.002305,0.002894,0.000044,0,,,,
474,NWD985467,9387,52.800537,F,0.002239,-0.001547,-0.000839,-0.000648,-0.001384,-0.001557,0.000079,-0.001703,0.001573,0.002124,-0.002207,0,,,,
475,NWD985989,16928,53.476800,F,0.002172,-0.001444,-0.000717,-0.002625,-0.005104,0.000367,-0.000662,0.002821,0.002219,-0.009627,0.003023,0,,,,
476,NWD986461,7658,51.341232,M,0.002356,-0.001783,-0.000793,0.000784,0.001350,0.000012,-0.000920,-0.000942,0.002105,0.001099,-0.000106,0,,,,


In [47]:
filtered_metadata['subject_id'].nunique()

474

In [48]:
repeated_subject_ids = metadata_df[metadata_df['subject_id'].duplicated(keep=False)]

In [49]:
repeated_subject_ids

Unnamed: 0,Sample,subject_id,AgeAtBloodDraw,sex,PC1,PC2,PC3,PC4,PC5,PC6,PC7,PC8,PC9,PC10,PC11,haschip,Gene,ExonicFunc,VAF,chip_binary
204,NWD143123,15960,58.128504,M,,,,,,,,,,,,0,,,,
210,NWD145350,18324,37.279342,F,,,,,,,,,,,,0,,,,
249,NWD167155,18324,37.279342,F,0.002343,-0.001701,-0.000684,0.001081,0.002007,0.000348,-0.000267,-2.6e-05,0.002223,0.002199,0.000453,0,,,,
489,NWD284313,15960,58.128504,M,,,,,,,,,,,,0,,,,
566,NWD321439,13823,32.008871,F,0.002324,-0.001627,-0.0007,0.000817,0.001509,0.000217,-0.000138,-0.000703,0.002087,0.002442,0.00091,0,,,,
783,NWD433184,13823,32.008871,F,,,,,,,,,,,,0,,,,
837,NWD465832,13823,32.008871,F,,,,,,,,,,,,0,,,,
873,NWD481739,13823,32.008871,F,,,,,,,,,,,,0,,,,
926,NWD507659,15960,58.128504,M,,,,,,,,,,,,0,,,,
1212,NWD641266,20156,56.912873,F,,,,,,,,,,,,0,,,,


In [50]:
# keep only the first occurrence of each subject_id
filtered_metadata = filtered_metadata.drop_duplicates(subset='subject_id')

In [51]:
filtered_metadata

Unnamed: 0,Sample,subject_id,AgeAtBloodDraw,sex,PC1,PC2,PC3,PC4,PC5,PC6,PC7,PC8,PC9,PC10,PC11,haschip,Gene,ExonicFunc,VAF,chip_binary
0,NWD290865,5360,36.709857,M,0.002307,-0.001598,-0.000748,0.000724,0.001564,0.000178,-0.000338,0.000049,0.000842,0.000176,0.000036,1,NF1,stopgain,0.204,1.0
1,NWD757156,10022,59.817792,F,0.002145,-0.000661,-0.001230,-0.000578,-0.002286,0.000284,0.000642,0.002438,0.001718,-0.002682,0.001535,1,DNMT3A,nonsynonymous SNV,0.118,1.0
2,NWD925538,26014,60.461200,F,0.002308,-0.001293,-0.000612,0.000759,0.001243,-0.001005,0.000830,0.000047,-0.004544,-0.001970,0.001056,1,DNMT3A,nonsynonymous SNV,0.289,1.0
3,NWD100436,7420,37.320410,F,0.002331,-0.001710,-0.000732,0.001101,0.001999,0.000338,-0.000963,-0.000289,0.002154,0.002027,-0.000200,0,,,,
4,NWD102395,5048,54.355668,F,0.002153,-0.001422,-0.000712,-0.002858,-0.005667,0.000936,-0.000473,0.003260,0.004069,-0.008498,0.004219,0,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
473,NWD983190,2482,52.132487,M,0.002320,-0.001663,-0.000785,0.000941,0.001791,0.000223,-0.000184,-0.000754,0.002305,0.002894,0.000044,0,,,,
474,NWD985467,9387,52.800537,F,0.002239,-0.001547,-0.000839,-0.000648,-0.001384,-0.001557,0.000079,-0.001703,0.001573,0.002124,-0.002207,0,,,,
475,NWD985989,16928,53.476800,F,0.002172,-0.001444,-0.000717,-0.002625,-0.005104,0.000367,-0.000662,0.002821,0.002219,-0.009627,0.003023,0,,,,
476,NWD986461,7658,51.341232,M,0.002356,-0.001783,-0.000793,0.000784,0.001350,0.000012,-0.000920,-0.000942,0.002105,0.001099,-0.000106,0,,,,
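`drop_duplicates(subset='subject_id')` keeps whichever copy of a subject comes first. In the duplicates listed above, `AgeAtBloodDraw` is identical across copies, but PC1 to PC11 are filled in for only some of them (for subject 18324, only the second row, NWD167155, has PCs). If you intend to use the PCs, for example as covariates, a sketch like this one keeps the most complete copy instead; the toy data and the helper column `_n_present` are illustrative, and on the real data it would be applied to `filtered_metadata` in place of the plain `drop_duplicates` call.

```python
import numpy as np
import pandas as pd

# Toy metadata: subject "s1" appears twice, and only its second row has a PC value
meta = pd.DataFrame(
    {
        "Sample": ["A1", "A2", "B1"],
        "subject_id": ["s1", "s1", "s2"],
        "AgeAtBloodDraw": [37.3, 37.3, 36.7],
        "PC1": [np.nan, 0.0023, 0.0023],
    }
)

# Count non-missing fields per row and put the most complete rows first.
# A stable sort keeps the original order among rows with equal counts.
deduped = (
    meta.assign(_n_present=meta.notna().sum(axis=1))
    .sort_values("_n_present", ascending=False, kind="stable")
    .drop_duplicates(subset="subject_id", keep="first")
    .drop(columns="_n_present")
    .sort_index()  # restore the original row order
)
print(deduped)
```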


In [None]:
# inferencer.download_cpgcorpus_dataset("GSE182215")

There is no need to impute the methylation data for CpGPT: the model simply ignores missing values.
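Even without imputation, it is worth knowing how much is missing, since a sample with many missing sites gives the model fewer observed CpGs to work with. A minimal pandas sketch on a toy matrix; on the real data the same calls would be `filtered_df.isna().mean(axis=1)` (per sample) and `filtered_df.isna().mean(axis=0)` (per CpG site).

```python
import numpy as np
import pandas as pd

# Toy beta-value matrix (samples x CpG sites) with a few missing measurements
betas = pd.DataFrame(
    [[0.34, 0.65, np.nan], [0.36, np.nan, np.nan], [0.41, 0.70, 0.28]],
    index=pd.Index(["s1", "s2", "s3"], name="sample_id"),
    columns=["cg_a", "cg_b", "cg_c"],
)

# Fraction of missing measurements per sample (rows) and per CpG site (columns)
missing_per_sample = betas.isna().mean(axis=1)
missing_per_site = betas.isna().mean(axis=0)

print(missing_per_sample.round(3).to_dict())
print(missing_per_site.round(3).to_dict())
```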

In [None]:
# df = pd.read_feather(ARROW_DF_PATH)
# df.set_index('GSM_ID', inplace=True)
# df.head()

In [58]:
filtered_df.shape

(474, 443206)
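At this point `filtered_df` and the de-duplicated `filtered_metadata` describe the same subjects, but in different orders: the methylation rows begin 13515, 4892, 24654, while the metadata rows begin 5360, 10022, 26014. If later steps pair the two tables by row position, align them first. A minimal sketch on toy data; on the real objects the equivalent would be `filtered_metadata.assign(subject_id=filtered_metadata['subject_id'].astype(str)).set_index('subject_id').reindex(filtered_df.index)`, casting to `str` because `subject_id` may still be `int64`.

```python
import pandas as pd

# Toy stand-ins: the same three subjects, listed in different orders
methylation = pd.DataFrame(
    {"cg_a": [0.1, 0.2, 0.3]},
    index=pd.Index(["3", "1", "2"], name="sample_id"),
)
metadata = pd.DataFrame({"subject_id": [1, 2, 3], "age": [40.0, 55.0, 62.0]})

# Cast IDs to str to match the methylation index, then reorder rows to follow it.
# reindex requires unique subject IDs, hence the de-duplication step beforehand.
aligned = (
    metadata.assign(subject_id=metadata["subject_id"].astype(str))
    .set_index("subject_id")
    .reindex(methylation.index)
)

# reindex fills unknown IDs with NaN, so check explicitly rather than mismatch silently
assert not aligned["age"].isna().any(), "some methylation samples have no metadata"
assert aligned.index.equals(methylation.index)
print(aligned)
```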

### 4.2 Filter Vocab Features and Save Data

While not strictly required, filtering for the features used in fine-tuning gives you the best chance of achieving good performance.

In [52]:
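# Load the model vocabulary: 'input' holds the CpG sites used in fine-tuning, 'output' the prediction target(s)
with open(MODEL_VOCAB_PATH, 'r') as f:
    vocab = json.load(f)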

In [53]:
vocab

{'input': ['cg00000292',
  'cg00002426',
  'cg00003994',
  'cg00005847',
  'cg00007981',
  'cg00008493',
  'cg00008713',
  'cg00009407',
  'cg00011459',
  'cg00012199',
  'cg00012386',
  'cg00012792',
  'cg00013618',
  'cg00014085',
  'cg00014837',
  'cg00015770',
  'cg00019495',
  'cg00020533',
  'cg00021527',
  'cg00022866',
  'cg00024396',
  'cg00024812',
  'cg00025991',
  'cg00027083',
  'cg00027674',
  'cg00029826',
  'cg00031162',
  'cg00032227',
  'cg00033516',
  'cg00033773',
  'cg00034039',
  'cg00035347',
  'cg00035623',
  'cg00037763',
  'cg00037940',
  'cg00040861',
  'cg00040873',
  'cg00043004',
  'cg00043080',
  'cg00044245',
  'cg00047050',
  'cg00047469',
  'cg00050312',
  'cg00051979',
  'cg00054706',
  'cg00056767',
  'cg00057593',
  'cg00058938',
  'cg00059424',
  'cg00059930',
  'cg00060762',
  'cg00061059',
  'cg00062776',
  'cg00063144',
  'cg00065385',
  'cg00065408',
  'cg00066816',
  'cg00067471',
  'cg00069261',
  'cg00071250',
  'cg00072216',
  'cg00075967',

In [96]:
vocab.keys()

dict_keys(['input', 'output'])

In [97]:
vocab['input']

['cg00000292',
 'cg00002426',
 'cg00003994',
 'cg00005847',
 'cg00007981',
 'cg00008493',
 'cg00008713',
 'cg00009407',
 'cg00011459',
 'cg00012199',
 'cg00012386',
 'cg00012792',
 'cg00013618',
 'cg00014085',
 'cg00014837',
 'cg00015770',
 'cg00019495',
 'cg00020533',
 'cg00021527',
 'cg00022866',
 'cg00024396',
 'cg00024812',
 'cg00025991',
 'cg00027083',
 'cg00027674',
 'cg00029826',
 'cg00031162',
 'cg00032227',
 'cg00033516',
 'cg00033773',
 'cg00034039',
 'cg00035347',
 'cg00035623',
 'cg00037763',
 'cg00037940',
 'cg00040861',
 'cg00040873',
 'cg00043004',
 'cg00043080',
 'cg00044245',
 'cg00047050',
 'cg00047469',
 'cg00050312',
 'cg00051979',
 'cg00054706',
 'cg00056767',
 'cg00057593',
 'cg00058938',
 'cg00059424',
 'cg00059930',
 'cg00060762',
 'cg00061059',
 'cg00062776',
 'cg00063144',
 'cg00065385',
 'cg00065408',
 'cg00066816',
 'cg00067471',
 'cg00069261',
 'cg00071250',
 'cg00072216',
 'cg00075967',
 'cg00076645',
 'cg00077877',
 'cg00078194',
 'cg00079056',
 'cg000795

In [54]:
len(vocab['input'])

21368

In [55]:
vocab['output']

['cpgpt_age']

In [59]:
filtered_df.columns

Index(['cg00455876', 'cg01707559', 'cg03244189', 'cg03695421', 'cg04689676',
       'cg04792227', 'cg04964672', 'cg13851368', 'cg14180491', 'cg14210405',
       ...
       'cg27532867', 'cg27534599', 'cg27536559', 'cg27545494', 'cg27552198',
       'cg27553637', 'cg27575890', 'cg27585287', 'cg27592453', 'cg27598806'],
      dtype='object', length=443206)


In [61]:
filtered_df = filtered_df.loc[:, filtered_df.columns.isin(vocab['input'])]
filtered_df.head()

sample_id,cg00105470,cg00290506,cg00476580,cg00565688,cg00630164,cg00650762,cg00712898,cg00930078,cg00941229,cg01419479,...,cg27118809,cg27158143,cg27187881,cg27195224,cg27281093,cg27324619,cg27378424,cg27416437,cg27501458,cg27532722
13515,0.027791,0.036983,0.066271,0.168158,0.06954,0.033915,0.043883,0.04343,0.049154,0.081303,...,0.099306,0.080736,0.096607,0.824138,0.325251,0.751577,0.060099,0.0709,0.202823,0.796264
4892,0.02441,0.042328,0.104483,0.100738,0.068307,0.042724,0.046198,0.061989,0.048411,0.058582,...,0.104458,0.085971,0.176423,0.778738,0.350453,0.728809,0.079687,0.098191,0.205745,0.709666
24654,0.039425,0.041783,0.07773,0.093721,0.06277,0.040048,0.053491,0.049444,0.043913,0.053572,...,0.082414,0.060605,0.16613,0.81254,0.307506,0.713916,0.079062,0.059865,0.182273,0.775308
8276,0.037447,0.044743,0.085925,0.122843,0.065978,0.046319,0.046198,0.049748,0.050062,0.063004,...,0.070244,0.156625,0.134149,0.803444,0.290696,0.736893,0.05213,0.052951,0.168628,0.835507
23495,0.02951,0.035749,0.065447,0.129075,0.067116,0.043318,0.051262,0.030788,0.037853,0.040459,...,0.075782,0.065043,0.101919,0.782445,0.324285,0.775781,0.05184,0.054367,0.178012,0.821515
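The `age` vocabulary lists 21,368 input CpGs (`len(vocab['input'])` above), and the FHS matrix does not necessarily cover all of them. A quick coverage check makes the overlap explicit. The sketch below uses toy stand-ins; on the real objects you would pass `filtered_df.columns` and `vocab['input']`.

```python
import pandas as pd

# Toy stand-ins for vocab['input'] and the methylation matrix columns
vocab_input = ["cg01", "cg02", "cg03", "cg04"]
betas = pd.DataFrame([[0.1, 0.5, 0.9]], columns=["cg02", "cg04", "cg99"])

# CpGs the model knows that the data measures, and those the data lacks
present = betas.columns.intersection(vocab_input)
absent = pd.Index(vocab_input).difference(betas.columns)

coverage = len(present) / len(vocab_input)
print(f"{len(present)}/{len(vocab_input)} vocabulary CpGs present ({coverage:.1%})")
print("absent:", list(absent))
```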


In [62]:
filtered_df

sample_id,cg00105470,cg00290506,cg00476580,cg00565688,cg00630164,cg00650762,cg00712898,cg00930078,cg00941229,cg01419479,...,cg27118809,cg27158143,cg27187881,cg27195224,cg27281093,cg27324619,cg27378424,cg27416437,cg27501458,cg27532722
13515,0.027791,0.036983,0.066271,0.168158,0.069540,0.033915,0.043883,0.043430,0.049154,0.081303,...,0.099306,0.080736,0.096607,0.824138,0.325251,0.751577,0.060099,0.070900,0.202823,0.796264
4892,0.024410,0.042328,0.104483,0.100738,0.068307,0.042724,0.046198,0.061989,0.048411,0.058582,...,0.104458,0.085971,0.176423,0.778738,0.350453,0.728809,0.079687,0.098191,0.205745,0.709666
24654,0.039425,0.041783,0.077730,0.093721,0.062770,0.040048,0.053491,0.049444,0.043913,0.053572,...,0.082414,0.060605,0.166130,0.812540,0.307506,0.713916,0.079062,0.059865,0.182273,0.775308
8276,0.037447,0.044743,0.085925,0.122843,0.065978,0.046319,0.046198,0.049748,0.050062,0.063004,...,0.070244,0.156625,0.134149,0.803444,0.290696,0.736893,0.052130,0.052951,0.168628,0.835507
23495,0.029510,0.035749,0.065447,0.129075,0.067116,0.043318,0.051262,0.030788,0.037853,0.040459,...,0.075782,0.065043,0.101919,0.782445,0.324285,0.775781,0.051840,0.054367,0.178012,0.821515
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7786,0.035650,0.041770,0.065078,0.139177,0.074692,0.050740,0.048392,0.045705,0.059791,0.059666,...,0.075886,0.070754,0.105584,0.805771,0.227886,0.724508,0.058492,0.069246,0.160121,0.787051
24702,0.035340,0.045327,0.063408,0.115242,0.072566,0.037346,0.065168,0.039956,0.056086,0.079971,...,0.070279,0.099391,0.097797,0.786389,0.239924,0.771349,0.051317,0.054724,0.115677,0.810453
11354,0.031466,0.036925,0.094944,0.149997,0.052393,0.032636,0.058376,0.052702,0.042360,0.068223,...,0.068429,0.108781,0.120557,0.819198,0.228437,0.767523,0.060011,0.061483,0.155548,0.787109
5197,0.032759,0.053247,0.078296,0.126129,0.042851,0.046702,0.053191,0.041976,0.044154,0.087022,...,0.087618,0.078737,0.097086,0.800205,0.315281,0.732184,0.062989,0.063164,0.218542,0.788707


In [63]:
ARROW_DF_FILTERED_PATH

'../data/tutorials/raw/fhs_filtered.arrow'

In [64]:
filtered_df.to_feather(ARROW_DF_FILTERED_PATH)


### 4.3 Memory-Map Data

Before running inference, we need to memory-map the data, which is done with the `CpGPTDataSaver` class. First, we define the `DNALLMEmbedder` and `IlluminaMethylationProber` classes: the former holds the DNA LLM embeddings, and the latter converts Illumina array probes to genomic locations.
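If memory-mapping is new to you, the idea is that the data lives on disk and only the slices you actually access are read into RAM, so datasets larger than memory can still be iterated over. The sketch below is a generic illustration using NumPy's `np.load(..., mmap_mode="r")` on a toy array; it is an assumption-free demo of the technique only and does not reflect the file layout that `CpGPTDataSaver` writes, which this tutorial does not describe.

```python
import os
import tempfile

import numpy as np

# Toy beta-value matrix: 4 samples x 5 CpG sites (float32 keeps it small).
betas = np.random.default_rng(42).random((4, 5), dtype=np.float32)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "betas.npy")
    # Write the array to disk once, in .npy format.
    np.save(path, betas)

    # Re-open it memory-mapped: data is read lazily, only when a slice is accessed.
    mm = np.load(path, mmap_mode="r")
    first_sample = np.array(mm[0])  # copies only row 0 into memory

    print(type(mm).__name__, mm.shape)
    print(np.allclose(first_sample, betas[0]))
    del mm  # release the file handle before the temporary directory is removed
```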

In [65]:
embedder = DNALLMEmbedder(dependencies_dir=LLM_DEPENDENCIES_DIR)

[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mInitializing class DNALLMEmbedder.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mGenome files will be stored under ../dependencies/human/genomes.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mDNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mEnsembl metadata dictionary loaded successfully[0m


In [66]:
prober = IlluminaMethylationProber(dependencies_dir=LLM_DEPENDENCIES_DIR, embedder=embedder)

[1m[34mcpgpt[0m[1m[0m: [36mIlluminaMethylationProber[0m: [1mInitializing class IlluminaMethylationProber.[0m


[1m[34mcpgpt[0m[1m[0m: [36mIlluminaMethylationProber[0m: [1mIllumina methylation manifest files will be stored under ../dependencies/human/manifests.[0m
[1m[34mcpgpt[0m[1m[0m: [36mIlluminaMethylationProber[0m: [1mIllumina metadata dictionary loaded successfully.[0m


In [108]:
prober

<cpgpt.data.components.illumina_methylation_prober.IlluminaMethylationProber at 0x7f9a49e28a90>

In [68]:
# Define datasaver
quick_setup_datasaver = CpGPTDataSaver(data_paths=ARROW_DF_FILTERED_PATH, processed_dir=PROCESSED_DIR)

# Process the file
quick_setup_datasaver.process_files(prober, embedder)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mInitializing class CpGPTDataSaver.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mDataset folders will be stored under ../data/tutorials/processed/fhs_setup.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mNo existing dataset metrics found. Please process files.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mNo existing genomic locations found. Please process files.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mStarting file processing.[0m


Output()

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [33m[1mNo species column found. Defaulting to homo_sapiens.[0m


[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataSaver[0m: [1mFile processing completed.[0m


### 4.4 Declare data module

Let's define two data modules: one for the forward pass and methylation reconstruction, and another for the attention weights, which uses the much shorter `MAX_ATTN_LENGTH` context.
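Why a separate, shorter data module for attention? A full attention matrix grows quadratically with sequence length. The back-of-the-envelope sketch below assumes one matrix per sample (heads already averaged, as with `aggregate_heads="mean"` later on), 2 bytes per value (fp16/bf16), and one extra CLS token; `attention_matrix_bytes` is a hypothetical helper, not part of CpGPT. Under these assumptions, 20,000 sites need about 763 MiB per sample versus about 2 MiB for 1,000 sites, and keeping per-head or per-layer weights would multiply that further.

```python
def attention_matrix_bytes(n_sites: int, bytes_per_value: int = 2, add_cls: bool = True) -> int:
    """Bytes needed to hold one (L x L) attention matrix for a single sample.

    Assumes heads are already aggregated into one matrix and, by default,
    that one CLS token is prepended to the CpG sites.
    """
    length = n_sites + (1 if add_cls else 0)
    return length * length * bytes_per_value


for n in (1_000, 20_000):  # MAX_ATTN_LENGTH vs MAX_INPUT_LENGTH in this tutorial
    mib = attention_matrix_bytes(n) / 2**20
    print(f"{n:>6} sites -> {mib:,.1f} MiB per sample")
```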

In [69]:
# Data module for forward passes and methylation reconstruction
quick_setup_datamodule = CpGPTDataModule(
    predict_dir=PROCESSED_DIR,
    dependencies_dir=LLM_DEPENDENCIES_DIR,
    batch_size=1,
    num_workers=0,
    max_length=MAX_INPUT_LENGTH,
    dna_llm=config.data.dna_llm,
    dna_context_len=config.data.dna_context_len,
    sorting_strategy=config.data.sorting_strategy,
    pin_memory=False,
)

# Data module for attention analysis (shorter context keeps attention weights manageable)
quick_setup_datamodule_attn = CpGPTDataModule(
    predict_dir=PROCESSED_DIR,
    dependencies_dir=LLM_DEPENDENCIES_DIR,
    batch_size=1,
    num_workers=0,
    max_length=MAX_ATTN_LENGTH,
    dna_llm=config.data.dna_llm,
    dna_context_len=config.data.dna_context_len,
    sorting_strategy=config.data.sorting_strategy,
    pin_memory=False,
)

[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mInitializing class DNALLMEmbedder.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mGenome files will be stored under ../dependencies/human/genomes.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mDNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.[0m


[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mEnsembl metadata dictionary loaded successfully[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mInitializing class DNALLMEmbedder.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mGenome files will be stored under ../dependencies/human/genomes.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mDNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.[0m
[1m[34mcpgpt[0m[1m[0m: [36mDNALLMEmbedder[0m: [1mEnsembl metadata dictionary loaded successfully[0m


## 5. Run Inference

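There are several ways to run inference with CpGPT. Here, we'll go through the most common ones: extracting sample embeddings, predicting phenotypes, reconstructing methylation, and analyzing attention weights.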

### 5.1 Declare Trainer

Since all models were trained with mixed precision, we'll pass `precision="16-mixed"`. If you fine-tune a model at a different precision, change this argument accordingly. On a CPU-only machine, Lightning falls back to `bf16-mixed`, as the warning below shows.

In [70]:
trainer = CpGPTTrainer(precision="16-mixed")

/grand/GeomicVar/tarak/cpgpt/cpgpt_env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
Using bfloat16 Automatic Mixed Precision (AMP)
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
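To avoid the fp16-on-CPU warning, you can pick the precision string from the available hardware before creating the trainer. This is a minimal sketch that simply mirrors the fallback Lightning applies in the warning above; `pick_precision` is a hypothetical helper, and this tutorial does not say whether CpGPT checkpoints give identical results under bf16 and fp16.

```python
def pick_precision(cuda_available: bool) -> str:
    """Choose a Lightning mixed-precision setting for the available hardware.

    fp16 AMP requires a CUDA GPU; on CPU, Lightning substitutes bf16 instead.
    """
    return "16-mixed" if cuda_available else "bf16-mixed"


try:
    import torch
    cuda_available = torch.cuda.is_available()
except ImportError:  # lets this sketch run even without PyTorch installed
    cuda_available = False

precision = pick_precision(cuda_available)
print(precision)
```

You would then create the trainer with `trainer = CpGPTTrainer(precision=precision)`.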


### 5.2 Get Sample Embeddings

In [71]:
quick_setup_sample_embeddings = trainer.predict(
    model=model,
    datamodule=quick_setup_datamodule,
    predict_mode="forward",
    return_keys=["sample_embedding"]
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mInitializing class CpGPTDataset.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mLoaded existing dataset metrics.[0m


Output()

/grand/GeomicVar/tarak/cpgpt/cpgpt_env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


In [28]:
quick_setup_sample_embeddings

{'sample_embedding': tensor([[-2.6272e-01, -7.3696e-02, -2.6239e-02,  ..., -9.3649e-02,
          -3.6611e-02, -1.8364e-02],
         [-3.0251e-01, -7.7571e-02, -5.5078e-02,  ..., -1.1866e-01,
          -5.1626e-02, -2.7392e-03],
         [-5.5322e-02,  1.3166e-03, -7.8424e-02,  ..., -3.0264e-02,
          -1.0671e-02, -1.4953e-01],
         ...,
         [-2.0327e-01, -5.4256e-02, -7.8379e-02,  ..., -9.7128e-02,
           3.3823e-02, -4.2534e-02],
         [-2.3025e-01, -7.0735e-02, -6.5348e-02,  ..., -1.0442e-01,
           1.2583e-02, -2.7886e-02],
         [-2.0669e-01, -5.3605e-02, -1.0745e-01,  ..., -9.1940e-02,
          -2.3419e-04, -7.1221e-02]])}

In [30]:
quick_setup_sample_embeddings.keys()

dict_keys(['sample_embedding'])

In [31]:
len(quick_setup_sample_embeddings['sample_embedding'])

38
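Each of the 38 samples gets one embedding vector (one row of the tensor above). A common downstream use is comparing samples to each other. The sketch below computes pairwise cosine similarity with NumPy on a toy matrix; with the real output you would first convert the tensor, e.g. `quick_setup_sample_embeddings["sample_embedding"].float().numpy()`. The helper name `cosine_similarity_matrix` is hypothetical, not part of CpGPT.

```python
import numpy as np


def cosine_similarity_matrix(emb: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows (samples) of an embedding matrix."""
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    unit = emb / np.clip(norms, 1e-12, None)  # guard against zero-length vectors
    return unit @ unit.T


# Toy stand-in for the (n_samples x embedding_dim) sample embeddings.
emb = np.array([[1.0, 0.0, 0.0],
                [0.0, 2.0, 0.0],
                [3.0, 0.0, 0.0]])
sim = cosine_similarity_matrix(emb)
print(np.round(sim, 3))
```

Rows 1 and 3 point in the same direction, so their similarity is 1 regardless of vector length; row 2 is orthogonal to both.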

### 5.3 Predict Phenotypes

In [20]:
quick_setup_pred_conditions = trainer.predict(
    model=model,
    datamodule=quick_setup_datamodule,
    predict_mode="forward",
    return_keys=["pred_conditions"]
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mInitializing class CpGPTDataset.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mLoaded existing dataset metrics.[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [21]:
quick_setup_pred_conditions

{'pred_conditions': tensor([[ 0.3667],
         [ 0.0154],
         [ 6.9766],
         [ 0.4402],
         [-0.6748],
         [-0.2629],
         [-0.7241],
         [-2.7090],
         [-0.1346],
         [-0.4250],
         [-0.0770],
         [ 0.1849],
         [-0.0803],
         [ 0.4863],
         [ 2.5410],
         [-1.4512],
         [-3.1914],
         [ 3.1328],
         [-0.7524],
         [-1.4854],
         [-1.8594],
         [-1.4404],
         [-2.0391],
         [-1.4297],
         [-2.4863],
         [-2.0703],
         [-2.4922],
         [-2.4121],
         [-2.3438],
         [-1.7637],
         [-1.4941],
         [-2.4941],
         [-0.4998],
         [-0.5352],
         [-1.9775],
         [-3.3359],
         [-1.0107],
         [-0.5669]], dtype=torch.float16)}

### 5.4 Reconstruct Methylation

As an example, let's reconstruct the methylation values at a set of locations of interest, defined here by Illumina probe IDs.

In [22]:
# Take the first 100 probes of the input data for demonstration
probes = list(df.columns[0:100])

probes[0:5]

['cg00000292', 'cg00002426', 'cg00003994', 'cg00005847', 'cg00008493']

In [23]:
# Convert probes to genomic locations
genomic_locations = prober.locate_probes(probes, "homo_sapiens")

genomic_locations[0:5]

['16:28878778', '3:57757815', '7:15686236', '2:176164344', '14:93347430']
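As the output shows, `locate_probes` returns locations as `chromosome:position` strings. If you need them as separate fields (for sorting, or for joining with an annotation table), a minimal sketch is below; `split_locations` is a hypothetical helper that only relies on the string format shown above.

```python
def split_locations(locations: list[str]) -> list[tuple[str, int]]:
    """Split 'chrom:position' strings into (chrom, position) pairs."""
    pairs = []
    for loc in locations:
        chrom, pos = loc.rsplit(":", 1)  # split on the last ':' only
        pairs.append((chrom, int(pos)))
    return pairs


example = ['16:28878778', '3:57757815', '7:15686236']
print(split_locations(example))  # [('16', 28878778), ('3', 57757815), ('7', 15686236)]
```

Chromosome names are kept as strings so that non-numeric chromosomes (e.g. `X`) work too.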

In [24]:
quick_setup_pred_meth = trainer.predict(
    model=model,
    datamodule=quick_setup_datamodule,
    predict_mode="reconstruct",
    genomic_locations=genomic_locations,
    species="homo_sapiens",
    return_keys=["pred_meth"],
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mInitializing class CpGPTDataset.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mLoaded existing dataset metrics.[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

Note that the reconstructed values are M-values, not beta values, so you need to convert them to beta values with the `m_to_beta` function.
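For reference, M-values are conventionally defined as `M = log2(beta / (1 - beta))`, so the inverse is `beta = 2**M / (2**M + 1)`. The sketch below implements both directions with NumPy, assuming CpGPT follows this standard log2 definition (check the `m_to_beta` source to confirm). The `*_sketch` function names are hypothetical, not CpGPT's API.

```python
import numpy as np


def m_to_beta_sketch(m: np.ndarray) -> np.ndarray:
    """Convert M-values to beta values, assuming M = log2(beta / (1 - beta))."""
    # 1 / (1 + 2**-M) equals 2**M / (2**M + 1) but avoids inf/inf for large M.
    return 1.0 / (1.0 + np.exp2(-m))


def beta_to_m_sketch(beta: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Inverse transform; clips beta away from 0 and 1 to avoid infinite M-values."""
    beta = np.clip(beta, eps, 1 - eps)
    return np.log2(beta / (1 - beta))


m = np.array([-2.0, 0.0, 2.0])
print(m_to_beta_sketch(m))  # [0.2 0.5 0.8]
```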

In [25]:
quick_setup_pred_meth["pred_meth"] = m_to_beta(quick_setup_pred_meth["pred_meth"])
quick_setup_pred_meth

{'pred_meth': tensor([[0.8501, 0.8633, 0.0503,  ..., 0.0348, 0.9292, 0.7095],
         [0.8706, 0.8833, 0.0492,  ..., 0.0340, 0.9419, 0.7275],
         [0.3799, 0.4067, 0.2776,  ..., 0.3308, 0.4133, 0.2937],
         ...,
         [0.7925, 0.7881, 0.0529,  ..., 0.0367, 0.9351, 0.7075],
         [0.8247, 0.8291, 0.0523,  ..., 0.0337, 0.9351, 0.7031],
         [0.6494, 0.4927, 0.0576,  ..., 0.0337, 0.9385, 0.7085]],
        dtype=torch.float16)}

A more powerful way to reconstruct methylation values is chain-of-thought. By spending additional test-time compute, we let the model "think harder" about the problem, which can improve performance. However, it also takes considerably longer, depending on the number of thinking steps.

In [26]:
quick_setup_pred_meth_cot = trainer.predict(
    model=model,
    datamodule=quick_setup_datamodule,
    predict_mode="reconstruct",
    genomic_locations=genomic_locations,
    species="homo_sapiens",
    n_thinking_steps=5,
    thinking_step_size=1000,
    uncertainty_quantile=0.1,
    return_keys=["pred_meth"],
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mInitializing class CpGPTDataset.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mLoaded existing dataset metrics.[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [27]:
quick_setup_pred_meth_cot["pred_meth"] = m_to_beta(quick_setup_pred_meth_cot["pred_meth"])
quick_setup_pred_meth_cot

{'pred_meth': tensor([[0.8516, 0.8486, 0.0482,  ..., 0.0344, 0.9248, 0.7080],
         [0.8691, 0.8696, 0.0484,  ..., 0.0330, 0.9434, 0.7383],
         [0.4485, 0.4104, 0.3074,  ..., 0.3823, 0.3540, 0.2976],
         ...,
         [0.7856, 0.7690, 0.0512,  ..., 0.0355, 0.9326, 0.6919],
         [0.8271, 0.8198, 0.0520,  ..., 0.0332, 0.9331, 0.6934],
         [0.6523, 0.4868, 0.0560,  ..., 0.0331, 0.9297, 0.7061]],
        dtype=torch.float16)}
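To see how much chain-of-thought changes the predictions, you can compare the two beta matrices element-wise. The sketch below uses, as a toy stand-in, the first three values of the first three rows printed above, and assumes rows are samples and columns are the requested locations; for the full tensors you would convert them first, e.g. `quick_setup_pred_meth["pred_meth"].float().numpy()`. Keep in mind this measures how far the two sets of predictions diverge, not which is more accurate; that would require measured values at those locations.

```python
import numpy as np

# Toy stand-ins copied from the printed outputs (standard vs chain-of-thought).
standard = np.array([[0.8501, 0.8633, 0.0503],
                     [0.8706, 0.8833, 0.0492],
                     [0.3799, 0.4067, 0.2776]])
cot = np.array([[0.8516, 0.8486, 0.0482],
                [0.8691, 0.8696, 0.0484],
                [0.4485, 0.4104, 0.3074]])

abs_diff = np.abs(cot - standard)
per_sample = abs_diff.mean(axis=1)  # mean |difference| for each sample (row)
per_site = abs_diff.mean(axis=0)    # mean |difference| for each location (column)

print("per sample:", np.round(per_sample, 4))
print("per site:  ", np.round(per_site, 4))
```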

### 5.5 Analyze Attention Weights

Storing the attention weights takes an enormous amount of memory, so for this demonstration we limit the input to 1,000 features (`MAX_ATTN_LENGTH`). Also remember that the first token is the CLS token.
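In the output below, positions beyond each sample's sequence length are `nan`, i.e. padding. A typical first step is to take the CLS token's attention over the CpG sites for one sample. The sketch below assumes each per-sample matrix is square with queries as rows and keys as columns (the usual PyTorch convention, not confirmed by this tutorial), the CLS token at index 0, and `nan` in padded positions; `cls_attention_to_sites` is a hypothetical helper. With the real output you would pass e.g. `quick_setup_attn["attention_weights"][0].float().numpy()`.

```python
import numpy as np


def cls_attention_to_sites(attn: np.ndarray) -> np.ndarray:
    """Return the CLS token's attention over CpG sites for one sample.

    Assumes a square (L x L) matrix, queries as rows, keys as columns,
    the CLS token at index 0, and NaN in padded positions.
    """
    cls_row = attn[0, 1:]               # CLS query row, skipping the CLS key itself
    return cls_row[~np.isnan(cls_row)]  # drop padded positions


# Toy 4 x 4 matrix: CLS + 2 real CpG sites + 1 padded position.
attn = np.array([
    [0.5, 0.3, 0.2, np.nan],
    [0.4, 0.4, 0.2, np.nan],
    [0.1, 0.6, 0.3, np.nan],
    [np.nan, np.nan, np.nan, np.nan],
])
print(cls_attention_to_sites(attn))  # [0.3 0.2]
```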

In [28]:
quick_setup_attn = trainer.predict(
    model=model,
    datamodule=quick_setup_datamodule_attn,
    predict_mode="attention",
    aggregate_heads="mean",
    layer_index=-1,
    return_keys=["attention_weights", "chroms", "positions", "mask_na", "meth"],
)

[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mInitializing class CpGPTDataset.[0m
[1m[34mcpgpt[0m[1m[0m: [36mCpGPTDataset[0m: [1mLoaded existing dataset metrics.[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [29]:
quick_setup_attn

{'attention_weights': tensor([[[0.0011, 0.0010, 0.0009,  ...,    nan,    nan,    nan],
          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],
          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],
          ...,
          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
          [   nan,    nan,    nan,  ...,    nan,    nan,    nan]],
 
         [[0.0011, 0.0010, 0.0010,  ...,    nan,    nan,    nan],
          [0.0010, 0.0011, 0.0010,  ...,    nan,    nan,    nan],
          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],
          ...,
          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
          [   nan,    nan,    nan,  ...,    nan,    nan,    nan]],
 
         [[0.0121, 0.0115, 0.0097,  ...,    nan,    nan,    nan],
          [0.0114, 0.0123, 0.0098,  ...,    nan,    nan,    nan],
          [0.0105, 
