In [None]:
from gpn.data import GenomeMSA #, Tokenizer
import gpn.model

from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder, StandardScaler
from transformers import AutoModel #, AutoModelForMaskedLM
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import re
import os
import csv
import warnings
warnings.filterwarnings('ignore')

from datetime import datetime, timedelta

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Hugging Face model id of the pretrained GPN-MSA model.
model_path = "songlab/gpn-msa-sapiens"
# Location of the multiple-sequence-alignment archive. The default is a
# machine-specific absolute path; set the GPN_MSA_PATH environment variable to
# point at a local copy instead.
msa_path = os.environ.get(
    "GPN_MSA_PATH",
    "zip:///::/home/sunhuaikuan/ondemand/blue_gpn/examples/msa/89.zarr.zip",
)
genome_msa = GenomeMSA(msa_path)
model = AutoModel.from_pretrained(model_path).to(device)
# Inference only: switch the model to eval mode (trailing ';' suppresses the repr).
model.eval();

### Main function: compute a GPN-MSA embedding for a genomic window

In [None]:
# Nucleotide -> integer code. Not used by the visible code; it is only referenced
# by the commented-out ref/alt features.
comp = {'A': 1, 'C': 2, 'G': 3, 'T': 4}

# Window length in bp (also redefined in the data-loading cell).
max_seqlen = 128


def Genosome2Embedding(chrom, pos_start, pos_end, y):  # ref,alt,
    """Return the mean-pooled GPN-MSA embedding of a genomic window, with the label appended.

    Args:
        chrom: chromosome name (converted with str() before the MSA lookup).
        pos_start, pos_end: window bounds, passed straight to genome_msa.get_msa.
        y: label appended as the final element.

    Returns:
        1-D numpy array: the pooled embedding values followed by y.

    Relies on the module-level genome_msa, model and device.
    """
    tokens = genome_msa.get_msa(str(chrom), pos_start, pos_end, strand="+", tokenize=True)
    # Add a leading batch axis of size 1 so the model receives a batch.
    batch = torch.tensor(np.expand_dims(tokens, 0).astype(np.int64))

    # Column 0 of the last axis is the human sequence; the rest are the other species.
    human_ids = batch[:, :, 0].to(device)
    other_species = batch[:, :, 1:].to(device)

    with torch.no_grad():
        hidden = model(input_ids=human_ids, aux_features=other_species).last_hidden_state
        # Average over the sequence axis (dim=1) -> (batch_size, embedding_dim).
        pooled = hidden.mean(dim=1)

    # np.append flattens, so the result is a single row: embedding values, then the label.
    return np.append(pooled.cpu().numpy(), [y])


### Write embeddings to a CSV file

In [None]:
def output2CSV(df, csv_Filename, flush_every=5000):
    """Embed every row of ``df`` with Genosome2Embedding and write the results to a CSV file.

    Each output line is the sequence returned by Genosome2Embedding for one input
    row. No header row is written. Any existing file at ``csv_Filename`` is
    overwritten. Rows whose embedding raises are reported and skipped.

    Args:
        df: DataFrame with columns 'CHROM', 'START', 'END' and 'y'.
        csv_Filename: path of the CSV file to create.
        flush_every: number of processed input rows between writes to disk, so
            partial results survive an interrupted run. A falsy value writes
            everything at the end.
    """
    buffered = []
    # Mode 'w' truncates any previous output (replaces the old os.remove step).
    with open(csv_Filename, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Count positionally: the index labels of a filtered frame are not
        # contiguous, so testing the label itself could skip flushes.
        for position, (index, record) in enumerate(df.iterrows()):
            chrom = record['CHROM']
            pos_start = record['START']
            pos_end = record['END']
            try:
                buffered.append(Genosome2Embedding(chrom, pos_start, pos_end, record['y']))
            except Exception as e:
                # Best effort: log the failing window and keep going.
                print(f"exception caught: {e} at {chrom}-{pos_start}")

            if flush_every and (position + 1) % flush_every == 0:
                writer.writerows(buffered)
                file.flush()
                buffered = []
                print(f"complete index={index}")

        # Write whatever is left after the last full chunk.
        writer.writerows(buffered)

    print(f"Create File: {csv_Filename}")

### Load Homo sapiens ClinVar variants

In [None]:
# Window length (bp) around each variant; keep in sync with the embedding cell,
# which defines the same constant.
max_seqlen = 128

# Which ClinVar subset to embed; swap in the commented pair for the missense set.
pathogenecity_type = 'noncoding'
datafile = '/home/sunhuaikuan/ondemand/blue_papers/DNA_LLM_REVIEW/datasets/task04-pathogenecity/clinvar_20240805.' + pathogenecity_type + '.txt'

# pathogenecity_type='missense'
# datafile='/home/sunhuaikuan/ondemand/blue_papers/DNA_LLM_REVIEW/datasets/task04-pathogenecity/clinvar_20240805.'+pathogenecity_type+'_matched.txt'

df = pd.read_csv(datafile, delimiter='\t')

columns_to_keep = ['CHROM', 'POS', 'Pathogenicity']  # 'ID', 'REF','ALT',
df = df[columns_to_keep].copy()

# pandas may parse CHROM as a mix of ints (9) and strings ('9', 'X'); merge them as strings.
for i in range(1, 23):
    df.loc[df['CHROM'] == i, 'CHROM'] = str(i)

# Drop missing chromosomes, then force every label to str. Any non-string value
# would give NaN from .str.contains and break the boolean mask below.
df = df[df['CHROM'].notna()].astype({'CHROM': str})
# Drop contigs whose names contain 'KI' or 'GL'.
df = df[~df['CHROM'].str.contains('KI|GL')].copy()

# Window [START, END) of max_seqlen positions around the variant.
# NOTE(review): assumes POS is 1-based and get_msa takes 0-based half-open
# bounds, which places the variant at offset max_seqlen // 2 -- confirm.
df['START'] = df['POS'] - max_seqlen // 2 - 1
df['END'] = df['START'] + max_seqlen

Pathogenicity_dict = {'B': 0, 'P': 1}
# NOTE(review): labels other than 'B'/'P' become NaN in y and are not filtered here.
df['y'] = df['Pathogenicity'].map(Pathogenicity_dict)

# Contiguous 0..n-1 index, so index-based progress reporting downstream is regular.
df = df.drop(columns=['POS', 'Pathogenicity']).reset_index(drop=True)
df

In [None]:
%%time

# Timestamp the output name so repeated runs never overwrite each other.
now = datetime.now()
formatted_time = now.strftime("%y-%m-%d-%H-%M-%S")
csv_filename = f"./pathogenecity_gpn_{pathogenecity_type}_{formatted_time}.csv"

output2CSV(df, csv_filename)

### Reload the embedding CSV file

In [None]:
import pandas as pd

def load_embedding_file(csv_filename):
    """Read a header-less embedding CSV (as written by output2CSV) into a DataFrame.

    Every row holds the embedding values followed by the label. The columns are
    renamed '1'..'D' for the D embedding dimensions and 'y' for the last column.

    Args:
        csv_filename: path of the CSV file to read.

    Returns:
        DataFrame with one row per line of the file.
    """
    # header=None: output2CSV writes no header, so the default (header=0) would
    # silently consume the first sample as column names.
    df = pd.read_csv(csv_filename, header=None)
    df.columns = [str(i) for i in range(1, df.shape[1])] + ['y']
    return df

# Reload the embeddings written to csv_filename by the export cell and display them.
df = load_embedding_file(csv_filename=csv_filename)
df