In [None]:
from gpn.data import GenomeMSA #, Tokenizer
import gpn.model

from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder, StandardScaler
from transformers import AutoModel #, AutoModelForMaskedLM
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import csv
import warnings
warnings.filterwarnings('ignore')

%run preprocess_utility.py

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression: the notebook shows the selected device as the cell output.
device

In [None]:
# Model identifier handed to AutoModel.from_pretrained below.
model_path = "songlab/gpn-msa-sapiens"
# NOTE(review): absolute, user-specific path to the zipped MSA store; it has to
# be edited before this notebook can run on another machine.
msa_path = "zip:///::/home/sunhuaikuan/ondemand/blue_gpn/examples/msa/89.zarr.zip"
# Alignment accessor; Genosome2Embedding queries it through get_msa().
genome_msa = GenomeMSA(msa_path)
# Load the pretrained model and move it to the selected device.
model = AutoModel.from_pretrained(model_path).to(device)
# The model is only used for inference, so switch it to evaluation mode.
# The trailing ';' suppresses the cell's output.
model.eval();

### Main Function to get Embedding

In [None]:
# Nucleotide letter -> integer code. Not referenced anywhere in the visible
# part of this notebook.
comp = {'A':1, 'C':2, 'G':3, 'T':4}

# Maximum interval length; the data-loading cell truncates longer intervals
# so that END = START + max_seqlen and SIZE = max_seqlen.
max_seqlen=128

def Genosome2Embedding(chrom, pos_start, pos_end, rowid, y):
    """Embed one genomic interval and tag the result with its row id and label.

    The tokenized alignment for the interval is fetched (forward strand) from
    the notebook-level ``genome_msa``, run through the notebook-level ``model``
    on ``device`` without gradient tracking, and the model's last hidden state
    is averaged over dim 1.

    Parameters
    ----------
    chrom
        Chromosome identifier; converted with ``str()`` before the lookup.
    pos_start, pos_end
        Interval bounds, passed unchanged to ``genome_msa.get_msa``.
    rowid, y
        Values appended after the pooled embedding so each output row can be
        joined back to its source row and label.

    Returns
    -------
    numpy.ndarray
        Flat array holding the pooled embedding values followed by ``rowid``
        and ``y``.
    """
    tokens = genome_msa.get_msa(str(chrom), pos_start, pos_end, strand="+", tokenize=True)

    # Prepend a batch axis of size 1 and convert the token ids to int64.
    batch = torch.tensor(np.expand_dims(tokens, 0).astype(np.int64))

    # Index 0 on the last axis is the human sequence; every other index on
    # that axis belongs to the remaining species and is fed as aux features.
    human_ids = batch[:, :, 0].to(device)
    other_species = batch[:, :, 1:].to(device)

    with torch.no_grad():
        outputs = model(input_ids=human_ids, aux_features=other_species)
        # Mean pooling across the sequence-length axis (dim=1).
        pooled = outputs.last_hidden_state.mean(dim=1)

    pooled_np = pooled.cpu().numpy()
    # np.append without an axis flattens both operands before joining them.
    return np.append(pooled_np, [rowid, y])

### Output CSV File

In [None]:
def output2CSV(df, csv_Filename, flush_every=1000):
    """Embed every row of ``df`` and write the results to a header-less CSV file.

    Each row is passed to ``Genosome2Embedding``; the returned array becomes
    one CSV line. Rows whose embedding raises are reported on stdout and
    skipped, so the file may contain fewer lines than ``df`` has rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``CHROM``, ``START``, ``END``, ``ROWID``
        and ``y``.
    csv_Filename : str
        Output path. An existing file at this path is deleted first.
    flush_every : int, optional
        Number of processed rows between appends to disk (default 1000).
        Counting processed rows, rather than testing the DataFrame index
        label, keeps batches evenly sized on filtered or non-integer indexes.

    Raises
    ------
    ValueError
        If ``flush_every`` is smaller than 1.
    """
    if flush_every < 1:
        raise ValueError("flush_every must be a positive integer")

    def _append_rows(pending):
        # Append mode keeps earlier batches on disk; opening the file also
        # creates it when `pending` is empty, so an output file always exists.
        with open(csv_Filename, mode='a', newline='') as file:
            csv.writer(file).writerows(pending)

    if os.path.exists(csv_Filename):
        os.remove(csv_Filename)

    pending = []
    for processed, (index, record) in enumerate(df.iterrows(), start=1):
        chrom = record['CHROM']
        pos_start = record['START']
        pos_end = record['END']
        rowid = record['ROWID']
        y = record['y']

        try:
            pending.append(Genosome2Embedding(chrom, pos_start, pos_end, rowid, y))
        except Exception as e:
            # Deliberate best effort: one bad interval must not abort a long run.
            print(f"exception caught: {e} at {chrom}-{pos_start}")

        if processed % flush_every == 0:
            _append_rows(pending)
            pending = []
            print(f"complete index={index}")

    # Write whatever is left after the last full batch.
    _append_rows(pending)

    print(f"Create File: "+csv_Filename)


### Load Homo_Sapiens data

In [None]:
# Gzipped annotation table (GRCh38, release 109 per the file name), given
# relative to the notebook's working directory.
datafile_path = '../../datasets/task03-homo-sapiens/Homo_sapiens.GRCh38.109.txt.gz'  
# preprocess_home_sapiens_datafile is not defined in this notebook; presumably
# it comes from `%run preprocess_utility.py` (TODO confirm).
df = preprocess_home_sapiens_datafile(datafile_path)

# Truncate intervals longer than max_seqlen: END becomes START + max_seqlen and
# SIZE becomes max_seqlen. END has to be updated first, because both masks are
# computed from the still-unmodified SIZE column.
df.loc[df['SIZE'] > max_seqlen, 'END'] = df['START'] + max_seqlen
df.loc[df['SIZE'] > max_seqlen, 'SIZE'] = max_seqlen

# TYPE and CLUSTER are not read by output2CSV, which uses only CHROM, START,
# END, ROWID and y.
df=df.drop(columns=['TYPE','CLUSTER'])
# Display the prepared frame.
df

In [None]:
%%time
# Embeds every row of df. output2CSV deletes any existing file at this path
# before writing, so re-running the cell regenerates the CSV from scratch.
output2CSV(df,'./homo_sapiens_gpn_embedding.csv')

### Load CSV File

In [None]:
def load_embedding_file(csv_filename):
    """Load an embedding CSV written by ``output2CSV`` into a labelled DataFrame.

    The file has no header line: every line is one embedding followed by the
    row id and the label. It is therefore read with ``header=None``; with the
    pandas default the first embedding would be consumed as column names and
    silently dropped from the data.

    Parameters
    ----------
    csv_filename : str
        Path of the header-less CSV file.

    Returns
    -------
    pandas.DataFrame
        One row per CSV line. Embedding columns are named ``'1'`` .. ``'N'``;
        the last two columns are named ``'ROWID'`` and ``'y'``.

    Raises
    ------
    ValueError
        If the file has fewer than two columns, so that ``ROWID`` and ``y``
        cannot both be present.
    """
    df = pd.read_csv(csv_filename, header=None)

    n_columns = df.shape[1]
    if n_columns < 2:
        raise ValueError(
            f"{csv_filename} has {n_columns} column(s); expected at least ROWID and y"
        )

    # Every column except the trailing two is an embedding dimension.
    column_names = [f'{i}' for i in range(1, n_columns - 1)]
    column_names.extend(['ROWID', 'y'])

    df.columns = column_names
    return df

# Rebinds `df`: from here on it holds the embedding table read back from disk,
# not the interval frame that was passed to output2CSV.
df = load_embedding_file('./homo_sapiens_gpn_embedding.csv')
# Display the loaded embeddings.
df