In [1]:
# using Kernel PyTorch-1.10
# using Kernel PyTorch-1.10
# using Kernel PyTorch-1.10

from gpn.data import GenomeMSA #, Tokenizer
import gpn.model

from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder, StandardScaler
from transformers import AutoModel #, AutoModelForMaskedLM
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# import pickle
# import re
import os
import csv
import warnings
warnings.filterwarnings('ignore')

%run preprocess_utility.py

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
model_path = "songlab/gpn-msa-sapiens"
msa_path = "zip:///::/home/sunhuaikuan/ondemand/blue_gpn/examples/msa/89.zarr.zip"
genome_msa = GenomeMSA(msa_path)
model = AutoModel.from_pretrained(model_path).to(device)
model.eval();

Loading MSA...
Loading MSA... Done


Some weights of the model checkpoint at songlab/gpn-msa-sapiens were not used when initializing GPNRoFormerModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing GPNRoFormerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPNRoFormerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Main Function to get Embedding

In [3]:
comp = {'A':1, 'C':2, 'G':3, 'T':4}

max_seqlen=128

def Genosome2Embedding(chrom, pos_start, pos_end, rowid, y):
    # tokenized msa
    # pos_start = pos - 65
    # pos_end=pos_start + 128
    msa = genome_msa.get_msa(str(chrom), pos_start, pos_end, strand="+", tokenize=True)
    # print(msa.shape)

    msa = torch.tensor(np.expand_dims(msa, 0).astype(np.int64))
    # msa

    # separating human from rest of species
    input_ids, aux_features = msa[:, :, 0], msa[:, :, 1:]
    
    input_ids = input_ids.to(device)
    aux_features = aux_features.to(device)


    with torch.no_grad():
        last_hidden_state = model(input_ids=input_ids, aux_features=aux_features).last_hidden_state
        
        # Mean Pooling: Compute the mean across the sequence length (dim=1)
        mean_pooled = last_hidden_state.mean(dim=1)  # Shape: (batch_size, embedding_dim)

    feature=np.append(mean_pooled.cpu().numpy(),  [rowid, y]) # chrom, pos_end-pos_start,
    
    return feature

### Output CSV File

In [4]:
def output2CSV(df, csv_Filename):

    if os.path.exists(csv_Filename):
        os.remove(csv_Filename)

    rows=[]
    # progress_bar = tqdm(total=df.shape[0], desc="Processing")
    for index, row in df.iterrows():
        
        chrom=row['CHROM']
        pos_start=row['START']
        pos_end=row['END']
        rowid=row['ROWID']
        y=row['y']
        try:
            embedding  =  Genosome2Embedding(chrom, pos_start, pos_end, rowid, y)
            rows.append(embedding)
    
        except Exception as e:
            print(f"exception caught: {e}"+str(row['CHROM'])+'-'+str(row['START']))
    
        # progress_bar.update(1)


        if ((index % 1000) ==0):
            with open(csv_Filename, mode='a', newline='') as file:
                writer = csv.writer(file)
                # writer.writerow(rows)
                for row in rows:
                    writer.writerow(row)
            rows=[]
            # progress_bar.update(1)
            print(f"complete index={index}")

    
    with open(csv_Filename, mode='a', newline='') as file:
        writer = csv.writer(file)
        # writer.writerow(rows)
        for row in rows:
            writer.writerow(row)

    print(f"Create File: "+csv_Filename)
    # progress_bar.reset()
    # progress_bar.close()

### Load Homo_Sapiens data

In [6]:
datafile_path = '/blue/xiaofan/chenyuanhan/data/Homo_sapiens.GRCh38.109.txt.gz'  
df = preprocess_home_sapiens_datafile(datafile_path)

# Assuming df is your DataFrame and max_seqlen is defined
df.loc[df['SIZE'] > max_seqlen, 'END'] = df['START'] + max_seqlen
df.loc[df['SIZE'] > max_seqlen, 'SIZE'] = max_seqlen

df=df.drop(columns=['TYPE','CLUSTER'])
df

Unnamed: 0,CHROM,START,END,SIZE,ROWID,y
0,7,116953541,116953669,128,0,0
1,12,54241755,54241883,128,1,0
2,20,38033948,38034076,128,2,0
3,15,82659560,82659688,128,3,0
4,9,131579497,131579625,128,4,0
...,...,...,...,...,...,...
34832,17,19189665,19189793,128,34832,6
34833,17,19188016,19188144,128,34833,6
34834,17,19557211,19557339,128,34834,6
34835,17,19061912,19062040,128,34835,6


In [7]:
%%time
output2CSV(df,'./homo_sapiens_gpn_embedding.csv')

complete index=0
complete index=1000
complete index=2000
complete index=3000
complete index=4000
complete index=5000
complete index=6000
complete index=7000
complete index=8000
complete index=9000
complete index=10000
complete index=11000
complete index=12000
complete index=13000
complete index=14000
complete index=15000
complete index=16000
complete index=17000
complete index=18000
complete index=19000
complete index=20000
complete index=21000
complete index=22000
complete index=23000
complete index=24000
complete index=25000
complete index=26000
complete index=27000
complete index=28000
complete index=29000
complete index=30000
complete index=31000
complete index=32000
complete index=33000
complete index=34000
Create File: ./homo_sapiens_gpn_embedding.csv
CPU times: user 7min 11s, sys: 1min 18s, total: 8min 30s
Wall time: 30min 22s


### Load CSV File

In [9]:
def load_embedding_file(csv_filename):

    df=pd.read_csv(csv_filename)
    
    column_names = [f'{i}' for i in range(1, df.shape[1]-1)]
    column_names.extend(['ROWID', 'y']) # 'CHROM', 'SIZE',
    
    df.columns = column_names
    return df

df = load_embedding_file('./homo_sapiens_gpn_embedding.csv')
df

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,761,762,763,764,765,766,767,768,ROWID,y
0,-0.444453,0.829066,0.096076,-0.693920,1.323465,-0.858230,-0.294933,-0.434608,-0.119355,-0.269585,...,-0.447579,-0.130508,0.456571,-0.614223,-0.021212,-0.397703,0.116457,-0.461298,1.0,0.0
1,0.272819,0.866934,0.527317,-0.374678,0.796706,-0.783266,-0.135167,-0.243605,0.309559,0.478524,...,-0.467666,-0.484659,0.188427,-0.111583,-0.280430,-0.540670,-0.369964,0.163323,2.0,0.0
2,0.002897,0.300540,0.363745,0.083066,0.097976,-0.705333,-0.345298,-0.162139,0.272691,0.111771,...,-0.375416,-0.182703,-0.326266,-0.158091,-0.380751,-0.772700,0.031157,0.344509,3.0,0.0
3,-0.497240,0.093710,0.335565,-0.105232,1.134319,-0.539970,-0.541573,-0.415931,-0.286506,0.133382,...,-0.464082,-0.365336,-0.058795,-0.306530,-0.313155,-0.506973,0.349965,-0.236048,4.0,0.0
4,0.285708,0.275770,0.598883,0.461520,0.081254,-1.124011,-0.944545,-0.015475,0.940282,0.396197,...,-0.210003,-0.503278,-0.084990,0.244842,-0.982839,-0.535660,-0.367225,0.574554,5.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
34785,1.563133,-0.323035,0.308644,0.509430,-1.153867,-0.625784,-0.907656,-0.554260,0.927047,-1.098697,...,-0.255241,0.468471,-0.366710,0.713182,0.678800,-1.091625,-1.163554,0.292450,34832.0,6.0
34786,0.382414,0.100565,-0.460504,-0.205118,-0.021366,-0.848140,-0.920970,-0.118495,0.095080,-0.405723,...,0.510821,-0.204055,0.207660,0.138362,-0.322580,-0.945742,0.485976,0.299117,34833.0,6.0
34787,0.274222,-0.483887,-1.149753,0.645911,-0.647589,-0.166167,0.047552,-0.063885,-0.347718,-0.856082,...,-0.792973,1.050195,0.788367,-0.599289,-0.246968,-1.378417,0.941640,0.635049,34834.0,6.0
34788,1.150788,-0.794538,0.021731,-0.153305,-1.474532,-0.613830,-0.815617,-0.786006,0.084275,-1.344010,...,-0.287960,0.827012,-0.163107,0.470947,1.085012,-1.099518,-0.934582,-0.303798,34835.0,6.0


In [34]:
min_value = df['size'].min()
print(min_value)

40
