# GPN-MSA Testing with Datasets

## Set up environment

In [None]:
# Install the GPN package (which provides GPN-MSA) from GitHub.
# Only needed once per environment; uncomment to run.
# !pip install git+https://github.com/songlab-cal/gpn.git

In [None]:
# Optional: refresh the dynamic-linker cache for /usr/lib64-nvidia
# (presumably only needed when the NVIDIA libraries are not found — confirm).
# !ldconfig /usr/lib64-nvidia

In [1]:
from gpn.data import GenomeMSA, Tokenizer
import gpn.model
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.preprocessing import StandardScaler
import torch
from transformers import AutoModel, AutoModelForMaskedLM



In [2]:
from datasets import load_dataset, disable_caching
from gpn.data import load_dataset_from_file_or_dir

In [3]:
# Directory holding the input and derived datasets. The trailing slash matters:
# file names are appended to this prefix by plain string concatenation.
dataset_prefix = "/expanse/lustre/projects/nji102/sgriesmer/gpn/Datasets/"

In [None]:
# Full variant table; the Pathogenic/Cancer/Regulatory subsets below are
# filtered from it.
input_file = f"{dataset_prefix}test.parquet"
dataset = load_dataset_from_file_or_dir(
    input_file,
    split="test",
    is_file=True,
)

## Create datasets

### Create Pathogenic dataset (only needed once)

In [None]:
def _keep_for_pathogenic(v):
    """Keep ClinVar variants plus Common variants whose consequence mentions missense."""
    if v["source"] == "ClinVar":
        return True
    return v["label"] == "Common" and "missense" in v["consequence"]


dataset_pathogenic = dataset.filter(_keep_for_pathogenic)
dataset_pathogenic.shape

### Create Cancer dataset (only needed once)

In [None]:
def _keep_for_cancer(v):
    """Keep COSMIC variants plus Common variants whose consequence mentions missense."""
    if v["source"] == "COSMIC":
        return True
    return v["label"] == "Common" and "missense" in v["consequence"]


dataset_cancer = dataset.filter(_keep_for_cancer)
dataset_cancer.shape

### Create Regulatory dataset (only needed once)

In [None]:
# Consequence categories that mark a Common variant as regulatory/non-coding.
cs = ["5_prime_UTR", "upstream_gene", "intergenic", "3_prime_UTR", "non_coding_transcript_exon"]

# Keep OMIM variants, plus Common variants that are not missense and whose
# consequence mentions one of the categories above. A generator (not a list)
# inside any() lets the check stop at the first match.
dataset_regulatory = dataset.filter(
    lambda v: v["source"] == "OMIM"
    or (
        v["label"] == "Common"
        and "missense" not in v["consequence"]
        and any(c in v["consequence"] for c in cs)
    )
)
dataset_regulatory.shape

## Save datasets as parquet files

In [None]:
# Persist the filtered pathogenic subset so later runs can reload it.
dataset_pathogenic_filename = f"{dataset_prefix}pathogenic.parquet"
dataset_pathogenic.to_parquet(dataset_pathogenic_filename)

In [None]:
# Persist the filtered cancer subset so later runs can reload it.
dataset_cancer_filename = f"{dataset_prefix}cancer.parquet"
dataset_cancer.to_parquet(dataset_cancer_filename)

In [None]:
# Persist the filtered regulatory subset so later runs can reload it.
dataset_regulatory_filename = f"{dataset_prefix}regulatory.parquet"
dataset_regulatory.to_parquet(dataset_regulatory_filename)

# Test and Score with GPN-MSA

In [None]:
# VCF chunk to score. `part` (presumably the suffix of a split VCF file) is
# also reused in the output CSV file name.
part = "adad"
dataset_pathogenic_filename = f"{dataset_prefix}PAT_dataset_XY-named-equal-hg38-part-{part}.vcf"
# Alternative small input: dataset_prefix + "example.vcf"
dataset_pathogenic = load_dataset_from_file_or_dir(
    dataset_pathogenic_filename,
    split="test",
    is_file=True,
)
dataset_pathogenic.shape

In [None]:
# Show the first record to check the loaded fields.
dataset_pathogenic[0]

## Load Dataset to be Scored (if previously created)

In [None]:
# Alternative to the VCF cell: reload the subset saved as pathogenic.parquet.
#dataset_pathogenic_filename = dataset_prefix + "pathogenic.parquet"
#dataset_pathogenic = load_dataset_from_file_or_dir(dataset_pathogenic_filename, split="test", is_file=True)
#dataset_pathogenic.shape

## Load MSA data

In [None]:
# Multiz 100-way alignment hosted on Hugging Face; the chained URL reads the
# zarr store out of a zip archive fetched over HTTPS.
msa_path = (
    "zip:///::"
    "https://huggingface.co/datasets/songlab/multiz100way/resolve/main/89.zarr.zip"
)
genome_msa = GenomeMSA(msa_path)  # loading can take a minute or two

## Load inference model

In [None]:
from gpn.msa.vep import VEPInference

# Checkpoint and window size handed to VEPInference together with the MSA.
model_path, window_size = "songlab/gpn-msa-sapiens", 128
vep_inf = VEPInference(
    model_path,
    genome_msa,
    window_size,
    disable_aux_features=False,
)

## Pathogenic Dataset

## Subset Dataset to be Scored

In [None]:
# Rows [set_start, set_end) of the pathogenic dataset are scored in this run.
set_start, set_end = 0, 50
subset_rows = range(set_start, set_end)
dataset_pathogenic_set = dataset_pathogenic.select(subset_rows)
df_pathogenic_set = pd.DataFrame(dataset_pathogenic_set)
df_pathogenic_set

## Tokenize Dataset

In [None]:
# Register vep_inf.tokenize_function as the dataset's on-access transform
# (Hugging Face `set_transform`). df_pathogenic_set was built before this call,
# so it keeps the raw, untokenized columns.
dataset_pathogenic_set.set_transform(vep_inf.tokenize_function)

## Set Arguments for Testing

In [None]:
from transformers import Trainer, TrainingArguments

# Arguments for prediction only (this notebook never calls trainer.train()).
output_dir = "/expanse/lustre/projects/nji102/sgriesmer/gpn/output"
batch_size = 100

eval_kwargs = dict(
    output_dir=output_dir,
    per_device_eval_batch_size=batch_size,
    dataloader_num_workers=0,
    # presumably required so the Trainer keeps the fields produced by
    # vep_inf.tokenize_function — TODO confirm
    remove_unused_columns=False,
    # torch_compile=True,
    # fp16=True,
)
training_args = TrainingArguments(**eval_kwargs)

In [None]:
# Wrap the GPN-MSA model so trainer.predict() can batch over the dataset.
trainer = Trainer(args=training_args, model=vep_inf.model)

## Test and Score Dataset Subset

In [None]:
# for debugging purposes
!TORCH_LOGS="+dynamo"
!TORCHDYNAMO_VERBOSE=1

In [None]:
# Run batched inference; keep only the raw model outputs.
prediction_output = trainer.predict(test_dataset=dataset_pathogenic_set)
pred = prediction_output.predictions

## Add Score to Dataframe

In [None]:
# Turn raw predictions into one GPN-MSA score per row of the scored subset.
gpn_score = vep_inf.postprocess(pred)
df_pathogenic_set["gpn_score"] = gpn_score
df_pathogenic_set.head(5)

## Plot GPN-MSA score distributions: Common vs. Pathogenic mutations

In [None]:
# Histogram of GPN-MSA scores, coloured by the "label" column.
sns.histplot(data=df_pathogenic_set, x="gpn_score", hue="label")

## Save dataframe as CSV

In [None]:
# File name records the VCF part and the row range that was scored.
output_file = f"{output_dir}/pathogenic_set_hg38_part_{part}_{set_start}_{set_end}.csv"
df_pathogenic_set.to_csv(output_file, index=False, sep=',')

## Load dataframe from CSV (if needed)

In [None]:
# Reload previously scored results instead of re-running inference.
output_dir = "/expanse/lustre/projects/nji102/sgriesmer/gpn/output"
input_csv_file = f"{output_dir}/pathogenic_set_0_39652.csv"
df_pathogenic_set = pd.read_csv(input_csv_file, sep=',')

In [None]:
# Same histogram, for the frame reloaded from CSV.
sns.histplot(data=df_pathogenic_set, x="gpn_score", hue="label")

## Calculate metrics

In [None]:
from sklearn.metrics import average_precision_score, roc_auc_score

# Positive class = "Pathogenic"; scores are negated so that lower GPN-MSA
# scores rank as more likely to be in the positive class.
y_true = df_pathogenic_set.label == "Pathogenic"
y_score = -df_pathogenic_set.gpn_score
AUROC = roc_auc_score(y_true, y_score)
AUPRC = average_precision_score(y_true, y_score)
AUROC, AUPRC

## Plot ROC curve

In [None]:
from sklearn.metrics import roc_curve

# Same convention as the metrics cell: positive = "Pathogenic", negated score.
fpr, tpr, thresholds = roc_curve(
    y_true=df_pathogenic_set.label == "Pathogenic",
    y_score=-df_pathogenic_set.gpn_score,
)

In [None]:
# ROC curve with the chance diagonal for reference.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='b', lw=2, label=f'ROC curve (AUC = {AUROC:.2f})')
ax.plot([0, 1], [0, 1], color='gray', linestyle='--')
ax.set_xlim(0.0, 1.0)
ax.set_ylim(0.0, 1.05)
ax.set_xlabel('False Positive Rate (FPR)')
ax.set_ylabel('True Positive Rate (TPR)')
ax.set_title('Receiver Operating Characteristic (ROC) Curve')
ax.legend(loc='lower right')
ax.grid(True)
plt.show()

## Cancer Dataset

## Load Dataset to be Scored (if previously created)

In [19]:
# Alternative to the VCF cell: reload the subset saved as cancer.parquet.
#dataset_cancer_filename = dataset_prefix + "cancer.parquet"
#dataset_cancer = load_dataset_from_file_or_dir(dataset_cancer_filename, split="test", is_file=True)
#dataset_cancer.shape

In [168]:
# VCF chunk to score. `part` (presumably the suffix of a split VCF file) is
# also reused in the output CSV file name.
part = "apaaabaaaaadac"
dataset_cancer_filename = f"{dataset_prefix}CAN_dataset_XY-named-hg38-part-{part}.vcf"
dataset_cancer = load_dataset_from_file_or_dir(
    dataset_cancer_filename,
    split="test",
    is_file=True,
)
dataset_cancer.shape

(2, 4)

In [169]:
# Inspect the schema; note there is no "label" column in this VCF-derived dataset.
dataset_cancer.features

{'chrom': Value(dtype='string', id=None),
 'pos': Value(dtype='int64', id=None),
 'ref': Value(dtype='string', id=None),
 'alt': Value(dtype='string', id=None)}

In [170]:
# Show the first record to check the loaded fields.
dataset_cancer[0]

{'chrom': '9', 'pos': 14926, 'ref': 'C', 'alt': 'T'}

## Load MSA data

In [7]:
# Multiz 100-way alignment hosted on Hugging Face; the chained URL reads the
# zarr store out of a zip archive fetched over HTTPS.
msa_path = (
    "zip:///::"
    "https://huggingface.co/datasets/songlab/multiz100way/resolve/main/89.zarr.zip"
)
genome_msa = GenomeMSA(msa_path)  # loading can take a minute or two

Loading MSA...
Loading MSA... Done


## Load inference model

In [8]:
from gpn.msa.vep import VEPInference

# Checkpoint and window size handed to VEPInference together with the MSA.
model_path, window_size = "songlab/gpn-msa-sapiens", 128
vep_inf = VEPInference(
    model_path,
    genome_msa,
    window_size,
    disable_aux_features=False,
)

## Subset Dataset to be Scored

In [171]:
# Rows [set_start, set_end) of the cancer dataset are scored in this run.
set_start, set_end = 0, 2
subset_rows = range(set_start, set_end)
dataset_cancer_set = dataset_cancer.select(subset_rows)
df_cancer_set = pd.DataFrame(dataset_cancer_set)
df_cancer_set

Unnamed: 0,chrom,pos,ref,alt
0,9,14926,C,T
1,1,4606584,C,T


## Tokenize Dataset

In [172]:
# Register vep_inf.tokenize_function as the dataset's on-access transform
# (Hugging Face `set_transform`). df_cancer_set was built before this call,
# so it keeps the raw, untokenized columns.
dataset_cancer_set.set_transform(vep_inf.tokenize_function)

## Set Arguments for Testing

In [173]:
from transformers import Trainer, TrainingArguments

# Arguments for prediction only (this notebook never calls trainer.train()).
output_dir = "/expanse/lustre/projects/nji102/sgriesmer/gpn/output"
batch_size = 100

eval_kwargs = dict(
    output_dir=output_dir,
    per_device_eval_batch_size=batch_size,
    dataloader_num_workers=0,
    # presumably required so the Trainer keeps the fields produced by
    # vep_inf.tokenize_function — TODO confirm
    remove_unused_columns=False,
    # torch_compile=True,
    # fp16=True,
)
training_args = TrainingArguments(**eval_kwargs)

In [174]:
# Wrap the GPN-MSA model so trainer.predict() can batch over the dataset.
trainer = Trainer(args=training_args, model=vep_inf.model)

## Test and Score Dataset Subset

In [175]:
# for debugging purposes
!TORCH_LOGS="+dynamo"
!TORCHDYNAMO_VERBOSE=1

In [176]:
# Run batched inference; keep only the raw model outputs.
prediction_output = trainer.predict(test_dataset=dataset_cancer_set)
pred = prediction_output.predictions

## Add Score to Dataframe

In [177]:
# Turn raw predictions into one GPN-MSA score per row of the scored subset.
gpn_score = vep_inf.postprocess(pred)
df_cancer_set["gpn_score"] = gpn_score
df_cancer_set

Unnamed: 0,chrom,pos,ref,alt,gpn_score
0,9,14926,C,T,-5.328727
1,1,4606584,C,T,-0.286613


In [178]:
# Remove the row with index label 1 (chrom 1, pos 4606584 in the output above)
# before saving.
# NOTE(review): the reason for dropping it is not stated — confirm it is
# intentional. Afterwards df_cancer_set is one row shorter than the
# `gpn_score` array, so metrics must use df_cancer_set.gpn_score.
df_cancer_set.drop([1], axis=0, inplace=True)
df_cancer_set

Unnamed: 0,chrom,pos,ref,alt,gpn_score
0,9,14926,C,T,-5.328727


## Plot GPN-MSA score distributions: Common vs. COSMIC/Frequent mutations

In [179]:
# Disabled — presumably because the VCF-loaded cancer dataset has no "label"
# column (see dataset_cancer.features) to use as `hue`.
#sns.histplot(data=df_cancer_set, x="gpn_score", hue="label")

## Save dataframe as CSV

In [180]:
# File name records the VCF part and the row range that was scored.
output_file = f"{output_dir}/cancer_set_hg38_part_{part}_{set_start}_{set_end}.csv"
df_cancer_set.to_csv(output_file, index=False, sep=',')

## Calculate metrics

In [None]:
# Debugging aid: inspect the slice [17585:17587] of the scored frame.
# NOTE(review): with the small subset scored above (set_end = 2) this is empty.
df_cancer_set[17585:17587]

In [None]:
from sklearn.metrics import roc_auc_score, average_precision_score

# Positive class = "Frequent"; scores are negated so that lower GPN-MSA scores
# rank as more likely to be in the positive class.
# Scores come from the dataframe column so they stay aligned with the labels:
# the raw `gpn_score` array is not updated when rows are dropped from
# df_cancer_set (the lengths then differ and sklearn raises). This also
# matches the pathogenic and regulatory sections.
# NOTE(review): the VCF-loaded cancer dataset has no "label" column (see
# dataset_cancer.features), so this cell needs a labelled frame.
AUROC = roc_auc_score(df_cancer_set.label == "Frequent", -df_cancer_set.gpn_score)
AUPRC = average_precision_score(df_cancer_set.label == "Frequent", -df_cancer_set.gpn_score)
AUROC, AUPRC

## Plot ROC curve

In [None]:
from sklearn.metrics import roc_curve

# Positive = "Frequent", negated score. Use the dataframe column, which stays
# aligned with the labels, rather than the raw `gpn_score` array, which is not
# updated when rows are dropped from df_cancer_set.
fpr, tpr, thresholds = roc_curve(df_cancer_set.label == "Frequent", -df_cancer_set.gpn_score)

In [None]:
# ROC curve with the chance diagonal for reference.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='b', lw=2, label=f'ROC curve (AUC = {AUROC:.2f})')
ax.plot([0, 1], [0, 1], color='gray', linestyle='--')
ax.set_xlim(0.0, 1.0)
ax.set_ylim(0.0, 1.05)
ax.set_xlabel('False Positive Rate (FPR)')
ax.set_ylabel('True Positive Rate (TPR)')
ax.set_title('Receiver Operating Characteristic (ROC) Curve')
ax.legend(loc='lower right')
ax.grid(True)
plt.show()

In [None]:
## Regulatory Dataset

## Load Dataset to be Scored (if previously created)

In [None]:
# Reload the regulatory subset saved earlier as regulatory.parquet.
dataset_regulatory_filename = f"{dataset_prefix}regulatory.parquet"
dataset_regulatory = load_dataset_from_file_or_dir(
    dataset_regulatory_filename,
    split="test",
    is_file=True,
)
dataset_regulatory.shape

## Load MSA data

In [None]:
# Multiz 100-way alignment hosted on Hugging Face; the chained URL reads the
# zarr store out of a zip archive fetched over HTTPS.
msa_path = (
    "zip:///::"
    "https://huggingface.co/datasets/songlab/multiz100way/resolve/main/89.zarr.zip"
)
genome_msa = GenomeMSA(msa_path)  # loading can take a minute or two

## Load inference model

In [None]:
from gpn.msa.vep import VEPInference

# Checkpoint and window size handed to VEPInference together with the MSA.
model_path, window_size = "songlab/gpn-msa-sapiens", 128
vep_inf = VEPInference(
    model_path,
    genome_msa,
    window_size,
    disable_aux_features=False,
)

## Subset Dataset to be Scored

In [None]:
# Rows [set_start, set_end) of the regulatory dataset are scored in this run.
set_start, set_end = 600000, 800000
subset_rows = range(set_start, set_end)
dataset_regulatory_set = dataset_regulatory.select(subset_rows)
df_regulatory_set = pd.DataFrame(dataset_regulatory_set)
df_regulatory_set

## Tokenize Dataset

In [None]:
# Register vep_inf.tokenize_function as the dataset's on-access transform
# (Hugging Face `set_transform`). df_regulatory_set was built before this
# call, so it keeps the raw, untokenized columns.
dataset_regulatory_set.set_transform(vep_inf.tokenize_function)

## Set Arguments for Testing

In [None]:
from transformers import Trainer, TrainingArguments

# Arguments for prediction only (this notebook never calls trainer.train()).
output_dir = "/expanse/lustre/projects/nji102/sgriesmer/gpn/output"
batch_size = 500

eval_kwargs = dict(
    output_dir=output_dir,
    per_device_eval_batch_size=batch_size,
    dataloader_num_workers=0,
    # presumably required so the Trainer keeps the fields produced by
    # vep_inf.tokenize_function — TODO confirm
    remove_unused_columns=False,
    # torch_compile=True,
    # fp16=True,
)
training_args = TrainingArguments(**eval_kwargs)

In [None]:
# Wrap the GPN-MSA model so trainer.predict() can batch over the dataset.
trainer = Trainer(args=training_args, model=vep_inf.model)

## Test and Score Dataset Subset

In [None]:
# for debugging purposes
!TORCH_LOGS="+dynamo"
!TORCHDYNAMO_VERBOSE=1

In [None]:
# Run batched inference; keep only the raw model outputs.
prediction_output = trainer.predict(test_dataset=dataset_regulatory_set)
pred = prediction_output.predictions

## Add Score to Dataframe

In [None]:
# Turn raw predictions into one GPN-MSA score per row of the scored subset.
gpn_score = vep_inf.postprocess(pred)
df_regulatory_set["gpn_score"] = gpn_score
df_regulatory_set.head(5)

## Plot GPN-MSA score distributions: Common vs. OMIM/Pathogenic mutations

In [None]:
# Histogram of GPN-MSA scores, coloured by the "label" column.
sns.histplot(data=df_regulatory_set, x="gpn_score", hue="label")

## Save dataframe as CSV

In [None]:
# File name records the row range that was scored.
output_file = f"{output_dir}/regulatory_set_{set_start}_{set_end}.csv"
df_regulatory_set.to_csv(output_file, index=False, sep=',')

In [None]:
## Load dataframe from CSV (if needed)

In [None]:
output_dir = "/expanse/lustre/projects/nji102/sgriesmer/gpn/output"

# (start, end) row ranges of the regulatory subsets that were scored and saved
# as regulatory_set_<start>_<end>.csv.
regulatory_ranges = [
    (0, 100000),
    (100000, 200000),
    (200000, 300000),
    (300000, 400000),
    (400000, 600000),
    (600000, 800000),
]
input_csv_file = [
    f"{output_dir}/regulatory_set_{start}_{end}.csv" for start, end in regulatory_ranges
]

df_regulatory_subset = [pd.read_csv(f, sep=',') for f in input_csv_file]

# Each part was saved with index=False and is read back with its own 0-based
# index. ignore_index=True gives the combined frame a single unique index
# instead of repeated labels (0..n) from every part.
df_regulatory_set = pd.concat(df_regulatory_subset, axis=0, ignore_index=True)

In [None]:
# Display the combined regulatory results.
df_regulatory_set

In [None]:
# Same histogram, for the frame combined from the saved CSV parts.
sns.histplot(data=df_regulatory_set, x="gpn_score", hue="label")

## Calculate metrics

In [None]:
# Display the frame used for the metrics below.
df_regulatory_set

In [None]:
from sklearn.metrics import average_precision_score, roc_auc_score

# Positive class = "Pathogenic"; scores are negated so that lower GPN-MSA
# scores rank as more likely to be in the positive class.
y_true = df_regulatory_set.label == "Pathogenic"
y_score = -df_regulatory_set.gpn_score
AUROC = roc_auc_score(y_true, y_score)
AUPRC = average_precision_score(y_true, y_score)
AUROC, AUPRC

## Plot ROC curve

In [None]:
from sklearn.metrics import roc_curve

# Same convention as the metrics cell: positive = "Pathogenic", negated score.
fpr, tpr, thresholds = roc_curve(
    y_true=df_regulatory_set.label == "Pathogenic",
    y_score=-df_regulatory_set.gpn_score,
)

In [None]:
# ROC curve with the chance diagonal for reference.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='b', lw=2, label=f'ROC curve (AUC = {AUROC:.2f})')
ax.plot([0, 1], [0, 1], color='gray', linestyle='--')
ax.set_xlim(0.0, 1.0)
ax.set_ylim(0.0, 1.05)
ax.set_xlabel('False Positive Rate (FPR)')
ax.set_ylabel('True Positive Rate (TPR)')
ax.set_title('Receiver Operating Characteristic (ROC) Curve')
ax.legend(loc='lower right')
ax.grid(True)
plt.show()