# GPN-MSA Testing with Datasets

## Setup environment 

In [None]:
# only need once for GPN-MSA use
# !pip install git+https://github.com/songlab-cal/gpn.git

In [None]:
# !ldconfig /usr/lib64-nvidia

In [1]:
from gpn.data import GenomeMSA, Tokenizer
import gpn.model
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.preprocessing import StandardScaler
import torch
from transformers import AutoModel, AutoModelForMaskedLM



In [2]:
from datasets import load_dataset, disable_caching
from gpn.data import load_dataset_from_file_or_dir

In [3]:
# Directory prefix prepended to every VCF part filename loaded in this notebook.
dataset_prefix = "/expanse/lustre/projects/nji102/sgriesmer/gpn/Datasets/"

## CAD Datasets

### Load Genome MSA

In [4]:
# Location of the zipped zarr MSA archive hosted on the Hugging Face hub.
msa_path = "zip:///::https://huggingface.co/datasets/songlab/multiz100way/resolve/main/89.zarr.zip"
# Open the alignment; the resulting object is handed to VEPInference later on.
genome_msa = GenomeMSA(msa_path)  # can take a minute or two

Loading MSA...
Loading MSA... Done


### Load Inference Model

In [5]:
from gpn.msa.vep import VEPInference

# Build the inference helper used throughout the notebook: later cells rely on
# vep_inf.tokenize_function, vep_inf.model and vep_inf.postprocess.
model_path = "songlab/gpn-msa-sapiens"  # presumably a Hugging Face model id -- confirm
window_size = 128  # NOTE(review): presumably the sequence window width fed to the model -- confirm
vep_inf = VEPInference(model_path, genome_msa, window_size, disable_aux_features=False)

### Subset Dataset to be Scored

In [6]:
# Score every part of the CAD "gt5" dataset with GPN-MSA.
#
# For each part this cell loads the part's VCF, runs prediction, writes the
# scored rows to a per-part CSV, and appends one line to the results file:
#   "<part>,<score>"   on success
#   "<part>,no score"  on failure

import sys

import sre_yield
from transformers import Trainer, TrainingArguments

output_dir = "/expanse/lustre/projects/nji102/sgriesmer/gpn/output"
results_file = output_dir + "/cad-gt5-results-check.csv"

# create parts list: every suffix matched by the pattern, in sorted order
parts = sorted(sre_yield.AllStrings(r'aaaaaaa[a-j]'))

# parts not attempted yet; inspect this after an interruption to see where
# to resume
parts_left = parts.copy()

# set arguments for testing; neither the arguments nor the Trainer depend on
# the part, so they are built once instead of once per iteration
batch_size = 500
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_eval_batch_size=batch_size,
    dataloader_num_workers=0,
    remove_unused_columns=False,
    # torch_compile=True,
    # fp16=True,
)
trainer = Trainer(
    model=vep_inf.model,
    args=training_args,
)

# NOTE: the former `!TORCH_LOGS="+dynamo"` / `!TORCHDYNAMO_VERBOSE=1` lines were
# IPython shell escapes; they only set variables in a short-lived subshell and
# never reached this process. For TorchDynamo debug output, export those
# variables in the environment before starting the kernel.

# `with` guarantees the results file is closed even if a part raises outside
# the try block below or the run is interrupted.
with open(results_file, 'a') as f:
    for part in parts:
        # keep track of parts left to process in case of error
        parts_left.remove(part)

        dataset_cad_filename = dataset_prefix + "random_sampling_gt5-hg38-part-" + part + ".vcf"
        dataset_cad = load_dataset_from_file_or_dir(dataset_cad_filename, split="test", is_file=True)

        # create dataset and dataframe covering the whole part
        set_start = 0
        set_end = len(dataset_cad)
        dataset_cad_set = dataset_cad.select(range(set_start, set_end))
        df_cad_set = pd.DataFrame(dataset_cad_set)

        # tokenize dataset lazily, on access
        dataset_cad_set.set_transform(vep_inf.tokenize_function)

        # Test and score dataset subset
        try:
            pred = trainer.predict(test_dataset=dataset_cad_set).predictions
            gpn_score = vep_inf.postprocess(pred)
            df_cad_set["gpn_score"] = gpn_score

            # drop stub if down to 1-result (removes the row labelled 1)
            df_cad_set.drop([1], axis=0, inplace=True)

            # Write success to terminal and results file
            result_line = part + ',' + str(df_cad_set["gpn_score"][0])
            print(result_line)
            f.write(result_line + "\n")

            # write score output to file
            output_file = (
                output_dir + "/cad_set_gt5_hg38_part_" + part
                + "_" + str(set_start) + "_" + str(set_end) + ".csv"
            )
            df_cad_set.to_csv(output_file, index=False, sep=',')
        except Exception as exc:
            # Write failure to terminal and results file. Only ordinary errors
            # are caught, so Ctrl-C still stops the run; the reason goes to
            # stderr so the results-file format is unchanged.
            print(part + ",no score")
            print("  " + part + " failed: " + repr(exc), file=sys.stderr)
            f.write(part + ",no score" + "\n")

        # make the progress log durable after every part
        f.flush()


aaaaaaaa,no score


aaaaaaab,0.89399433


aaaaaaac,1.8475147


aaaaaaad,-2.8691573


aaaaaaae,-1.3439262


aaaaaaaf,-1.3763776
aaaaaaag,no score
aaaaaaah,no score
aaaaaaai,no score
aaaaaaaj,no score


### Load CAD datasets for P < 0.01 and P > 0.5

In [None]:
# Read the per-part result tables for both CAD subsets (P < 0.01 and P > 0.5)
# from the shared output directory.
output_dir = "/expanse/lustre/projects/nji102/sgriesmer/gpn/output"

cad_lt01_csv_file = f"{output_dir}/cad-lt01-results.csv"
cad_gt5_csv_file = f"{output_dir}/cad-gt5-results.csv"

df_cad_set_lt01 = pd.read_csv(cad_lt01_csv_file, sep=',')
df_cad_set_gt5 = pd.read_csv(cad_gt5_csv_file, sep=',')

In [None]:
df_cad_set_gt5

In [None]:
df_cad_set_lt01

### Add label column

In [None]:
# Tag every row with its subset name; the plot further down uses this column as its hue.
df_cad_set_lt01['label'] = 'lt01'

In [None]:
# Tag every row with its subset name; the plot further down uses this column as its hue.
df_cad_set_gt5['label'] = 'gt5'

In [None]:
df_cad_set_lt01

### Drop rows with "no score"

In [None]:
# Keep only rows that carry a real score: drop the "no score" failure marker
# and missing values. Casting to str first keeps this working when read_csv
# parsed the column as numeric (no failures in the file), where the `.str`
# accessor alone would raise AttributeError.
lt01_scores = df_cad_set_lt01["gpn_score"]
lt01_has_score = lt01_scores.notna() & ~lt01_scores.astype(str).str.contains("no score", regex=False)
df_cad_set_lt01 = df_cad_set_lt01[lt01_has_score]
df_cad_set_lt01

In [None]:
# Keep only rows that carry a real score: drop the "no score" failure marker
# and missing values. Casting to str first keeps this working when read_csv
# parsed the column as numeric (no failures in the file), where the `.str`
# accessor alone would raise AttributeError.
gt5_scores = df_cad_set_gt5["gpn_score"]
gt5_has_score = gt5_scores.notna() & ~gt5_scores.astype(str).str.contains("no score", regex=False)
df_cad_set_gt5 = df_cad_set_gt5[gt5_has_score]
df_cad_set_gt5

### Combine the datasets

In [None]:
# Stack the two labelled subsets (lt01 first, then gt5) under a fresh 0..n-1
# index, then convert the score column to floats.
df_combined = pd.concat(
    [df_cad_set_lt01, df_cad_set_gt5],
    axis=0,
    ignore_index=True,
)
df_combined["gpn_score"] = df_combined["gpn_score"].astype(float)
df_combined

### Plot distributions

In [None]:
# Distribution of GPN scores, colored by subset label (lt01 vs. gt5).
sns.displot(data=df_combined, x="gpn_score", hue="label")

In [None]:
# Manual single-part run: choose one part suffix of the "lt01" VCF files and
# load it as the "test" split.
part = "afajajag"
dataset_cad_filename = dataset_prefix + "random_sampling_lt01-hg38-part-" + part +".vcf"
dataset_cad = load_dataset_from_file_or_dir(dataset_cad_filename, split="test", is_file=True)
dataset_cad

In [None]:
# Select rows [set_start, set_end) of the loaded part and mirror them in a
# DataFrame that will later receive the scores.
set_start = 0
# Use set_end = len(dataset_cad) to take the whole file instead of the first two rows.
set_end = 2
dataset_cad_set = dataset_cad.select(range(set_start, set_end))
df_cad_set = pd.DataFrame(dataset_cad_set)
# Show both objects as the cell output.
df_cad_set, dataset_cad_set

In [None]:
# Register the VEP tokenizer as the dataset's transform, then show its features.
dataset_cad_set.set_transform(vep_inf.tokenize_function)
dataset_cad_set.features

In [None]:
from transformers import Trainer, TrainingArguments

# Directory passed to TrainingArguments and reused for the per-part score CSV.
output_dir = "/expanse/lustre/projects/nji102/sgriesmer/gpn/output"

# Evaluation-only settings for the Trainer built in the next cell.
batch_size = 500
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_eval_batch_size=batch_size,
    dataloader_num_workers=0,
    remove_unused_columns=False,
#   optional switches, currently disabled:
#   torch_compile=True,
#   fp16=True,
)


In [None]:
# Wrap the inference model in a Trainer; only its predict() method is used in this notebook.
trainer = Trainer(
    model=vep_inf.model,
    args=training_args
)

In [None]:
# Run prediction over the selected rows and keep only the predictions output.
pred=trainer.predict(test_dataset=dataset_cad_set).predictions

In [None]:
# Convert the predictions into GPN scores and attach them as a new column.
# NOTE(review): assumes postprocess returns one score per selected row -- confirm.
gpn_score = vep_inf.postprocess(pred)
df_cad_set["gpn_score"] = gpn_score

In [None]:
df_cad_set

### Cut out stub from set

In [None]:
# Remove the row labelled 1 (the stub) in place; with the two rows selected
# earlier this leaves only row 0.
df_cad_set.drop([1], axis=0, inplace=True)
df_cad_set

In [None]:
# Report "<part>,<score>" for the remaining row (label 0).
print(part + ',' + str(df_cad_set["gpn_score"][0]))

In [None]:
# Save the scored rows as CSV; the filename encodes the part and the selected row range.
output_file = output_dir + "/cad_set_hg38_part_" + part + "_" + str(set_start) + "_" + str(set_end) + ".csv"
df_cad_set.to_csv(output_file, index=False, sep=',') 