In [None]:
from gpn.data import GenomeMSA
import gpn.model

# from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder, StandardScaler
from transformers import AutoModel #, AutoModelForMaskedLM
import torch
import numpy as np
import pandas as pd
import re
import os
import csv
import warnings
warnings.filterwarnings('ignore')

# Execute the helper script so that the names it defines become available
# to the cells below.
%run preprocess_utility.py

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression: displays the selected device as the cell output.
device

In [None]:
# Identifier handed to AutoModel.from_pretrained to fetch the pretrained weights.
model_path = "songlab/gpn-msa-sapiens"
# msa_path = "zip:///::89.zarr.zip"
# Location of the zipped alignment store opened by GenomeMSA. The default is
# the original author's local copy; set the GPN_MSA_PATH environment variable
# to point at a different copy without editing the notebook.
msa_path = os.environ.get(
    "GPN_MSA_PATH",
    "zip:///::/home/sunhuaikuan/ondemand/blue_gpn/examples/msa/89.zarr.zip",
)
genome_msa = GenomeMSA(msa_path)
model = AutoModel.from_pretrained(model_path).to(device)
# Switch to evaluation mode; the trailing ';' suppresses the cell's repr output.
model.eval();

### Main Function to get Embedding

In [None]:
# Integer code for each nucleotide letter.
# NOTE(review): not referenced in any of the visible cells — confirm whether
# it is still needed.
comp = {'A':1, 'C':2, 'G':3, 'T':4}

# NOTE(review): not referenced in the visible cells; the data-loading cell
# defines its own `max_length` (also 128) — confirm which one is authoritative.
max_seqlen=128

def Genosome2Embedding(chrom, pos_start, pos_end, y):
    """Return the mean-pooled model embedding of a genomic window, with ``y`` appended.

    The window is fetched from the module-level ``genome_msa`` on the "+"
    strand, run through the module-level ``model`` on ``device`` without
    gradient tracking, and averaged over the sequence axis (dim=1).

    Parameters
    ----------
    chrom : chromosome name; converted with ``str()`` before the lookup.
    pos_start, pos_end : window bounds passed straight to ``genome_msa.get_msa``.
    y : target value appended after the embedding values.

    Returns
    -------
    numpy.ndarray
        Flat 1-D array: the pooled embedding values followed by ``y``.
    """
    tokens = genome_msa.get_msa(str(chrom), pos_start, pos_end, strand="+", tokenize=True)

    # Add a leading batch axis of size 1 and convert to an int64 tensor.
    batch = torch.tensor(np.expand_dims(tokens, axis=0).astype(np.int64))

    with torch.no_grad():
        outputs = model(
            input_ids=batch[:, :, 0].to(device),      # first alignment column
            aux_features=batch[:, :, 1:].to(device),  # all remaining columns
        )
        # Average the hidden states over the sequence axis.
        pooled = outputs.last_hidden_state.mean(dim=1)

    # np.append flattens its inputs, so the result is one flat vector ending in y.
    return np.append(pooled.cpu().numpy(), [y])

### Output CSV File

In [None]:
def output2CSV(df, csv_Filename, flush_every=5000):
    """Compute one embedding per row of ``df`` and write them to a CSV file.

    Each row must provide the columns ``CHROM``, ``START``, ``END`` and ``y``;
    they are passed to ``Genosome2Embedding`` (expected to be defined in the
    notebook namespace) and the returned vector becomes one CSV line. Rows
    whose embedding raises are reported on stdout and skipped.

    Parameters
    ----------
    df : pandas.DataFrame
        Table of windows to embed. Its index may be of any type and need not
        be contiguous.
    csv_Filename : str
        Output path. An existing file at this path is deleted first.
    flush_every : int, default 5000
        Buffered rows are appended to the file whenever the row *position*
        (0-based count of rows visited) is a multiple of this value, and once
        more after the last row.

    Raises
    ------
    ValueError
        If ``flush_every`` is smaller than 1.
    """
    if flush_every < 1:
        raise ValueError("flush_every must be a positive integer")

    # Start from a clean file so a re-run never appends to stale output.
    if os.path.exists(csv_Filename):
        os.remove(csv_Filename)

    def _append_rows(pending):
        # Opening in append mode also creates the file when nothing is pending.
        with open(csv_Filename, mode='a', newline='') as file:
            csv.writer(file).writerows(pending)

    pending = []

    # Checkpoints are driven by the row position rather than the index label:
    # labels can be non-numeric or have gaps after filtering, which would
    # either crash the modulo test or make flushes irregular.
    for position, (index, record) in enumerate(df.iterrows()):
        chrom = record['CHROM']
        pos_start = record['START']
        pos_end = record['END']
        y = record['y']

        try:
            pending.append(Genosome2Embedding(chrom, pos_start, pos_end, y))
        except Exception as e:
            # Best effort: report the failing window and carry on with the rest.
            print(f"exception caught: {e}"+str(chrom)+'-'+str(pos_start))

        if position % flush_every == 0:
            _append_rows(pending)
            pending = []
            print(f"complete index={index}")

    # Write whatever is left after the last checkpoint.
    _append_rows(pending)

    print(f"Create File: "+csv_Filename)

### Load methylation data

In [None]:
import gzip
import pandas as pd

# Prefix used for the names of the files derived from this dataset.
datafile = 'methylation'

# Width of the window built around each site.
max_length = 128

data_filename = '../../datasets/task05-methylation/GSM6637962_CpG_coverage20_GRCh38.bed.gz'

# Read the gzipped, tab-separated table; the header row comes from the file.
with gzip.open(data_filename, 'rt') as f:
    df = pd.read_csv(f, sep='\t')

# Strip the 'chr' prefix from the chromosome names.
df['CHROM'] = df['CHROM'].str.replace('chr', '', regex=False)

# Window of max_length positions starting max_length // 2 + 1 before FROM.
df['START'] = df['FROM'] - max_length // 2 - 1
df['END'] = df['START'] + max_length

# 'Percentage' becomes the target column 'y'; the raw coordinate and
# coverage columns are no longer needed.
df = (
    df.rename(columns={'Percentage': 'y'})
      .drop(['FROM', 'TO', 'Coverage'], axis=1)
)

# Move the 3rd column to the last position, then reorder the DataFrame.
cols = df.columns.tolist()
cols.append(cols.pop(2))
df = df[cols]

# Drop every row whose chromosome name contains one of these substrings.
for contig_tag in ('KI', 'GL', 'M'):
    df = df[~df['CHROM'].str.contains(contig_tag, na=False)]

df

In [None]:
%%time

# Compute an embedding for every row of df and write them to
# '<datafile>_gpn_embedding.csv'; %%time reports how long the cell took.
output2CSV(df,datafile+'_gpn_embedding.csv')

### Load CSV File

In [None]:
# Read the embedding CSV written above back in. This rebinds `df`, replacing
# the methylation table loaded earlier.
# NOTE(review): load_embedding_file is not defined in the visible cells —
# presumably it comes from preprocess_utility.py; confirm.
df = load_embedding_file(datafile+'_gpn_embedding.csv')
# Bare expression: displays the loaded table as the cell output.
df