In [None]:
!pip install scikit-plot
!pip install jax-unirep

Configure JAX GPU memory allocation (disable preallocation of GPU memory)
according to https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

In [1]:
import os

# XLA GPU allocator settings; set before `jax` is imported below so the GPU
# backend picks them up. 'platform' makes JAX allocate exactly what it needs
# on demand and free it afterwards (slower, see the JAX GPU memory docs).
os.environ.update({
    'XLA_PYTHON_CLIENT_PREALLOCATE': 'False',
    'XLA_PYTHON_CLIENT_MEM_FRACTION': '.90',
    'XLA_PYTHON_CLIENT_ALLOCATOR': 'platform',
})

import numpy as np
import pandas as pd
import csv
import matplotlib.pyplot as plt
import gc
import jax 

from Bio import SeqIO
from jax_unirep import get_reps
from jax_unirep import evotune, fit
from jax_unirep.utils import dump_params



In [12]:
# reset_device_memory / clear_jax_caches are defined in a later cell, so on a
# fresh kernel ("Run All") they do not exist yet; only call them once defined.
if 'reset_device_memory' in globals() and 'clear_jax_caches' in globals():
    reset_device_memory(True)
    clear_jax_caches()

In [2]:
import gc
import jax 
def clear_jax_caches():
  """Utility to clear all the function caches in jax.

  Uses the public ``jax.clear_caches()`` when the installed JAX provides it.
  Otherwise falls back to clearing the caches of older JAX releases by
  reaching into private internals; that fallback depends on internal
  attribute names and raises AttributeError when they do not exist.
  """
  if hasattr(jax, "clear_caches"):
    jax.clear_caches()
    return
  # main jit/pmap lu wrapped function caches - have to grab from closures
  jax.xla._xla_callable.__closure__[1].cell_contents.clear()
  jax.pxla.parallel_callable.__closure__[1].cell_contents.clear()
  # primitive callable caches
  jax.xla.xla_primitive_callable.cache_clear()
  jax.xla.primitive_computation.cache_clear()
  # jaxpr caches for control flow and reductions
  jax.lax.lax_control_flow._initial_style_jaxpr.cache_clear()
  jax.lax.lax_control_flow._fori_body_fun.cache_clear()
  jax.lax.lax._reduction_jaxpr.cache_clear()
  # these are trivial and only included for completeness sake
  jax.lax.lax.broadcast_shapes.cache_clear()
  jax.xla.xb.get_backend.cache_clear()
  jax.xla.xb.dtype_to_etype.cache_clear()
  jax.xla.xb.supported_numpy_dtypes.cache_clear()
    
def reset_device_memory(delete_objs=True):
    """Free the device buffers of all live JAX arrays.

    On JAX versions that provide ``jax.live_arrays()``, the live arrays are
    taken from there and freed with ``Array.delete()``. On older versions the
    garbage collector's object list is scanned for ``jax.xla.DeviceArray``
    instances (skipping ``jax.xla.DeviceConstant``), and each one's
    ``device_buffer`` is deleted. Arrays that were already deleted are skipped.

    Args:
      delete_objs: bool: kept for backward compatibility. Python array objects
        cannot be deleted from inside this function (the former ``del dv``
        only unbound the loop variable), so this flag has no effect.

    Returns:
      number of device arrays whose buffers were manually freed.
    """
    n_deleted = 0
    if hasattr(jax, "live_arrays"):
        for arr in jax.live_arrays():
            if not arr.is_deleted():
                arr.delete()
                n_deleted += 1
    else:
        dvals = (x for x in gc.get_objects() if isinstance(x, jax.xla.DeviceArray))
        for dv in dvals:
            if isinstance(dv, jax.xla.DeviceConstant):
                continue
            try:
                # Raises if the buffer was already deleted; such arrays are skipped.
                dv._check_if_deleted()  # pylint: disable=protected-access
                dv.device_buffer.delete()
                n_deleted += 1
            except Exception:  # best effort, but never swallow KeyboardInterrupt
                pass
        del dvals
    gc.collect()
    return n_deleted

In [2]:
# Input FASTA file with the assembled sequences.
fastas = os.path.join("/home/ubuntu/data/bk_fasta", "SRR11234331.assembly.len15.fasta")
# plk = "/home/kongkitimanonk/SCRATCH_NOBAK/phase3/PoisonFrog.len15.pkl"

In [3]:
# Parse the FASTA file into three parallel lists: record IDs, sequences with
# leading/trailing '*' removed, and raw record lengths.
with open(fastas) as fasta_file:  # the handle is closed when the block exits
    identifiers, lengths, seqs = [], [], []
    for seq_record in SeqIO.parse(fasta_file, 'fasta'):
        record_seq = seq_record.seq
        identifiers.append(seq_record.id)
        seqs.append(str(record_seq.strip('*')))
        # NOTE: measured on the raw record, so it also counts any '*' that
        # was stripped from the sequence above.
        lengths.append(len(record_seq))

In [4]:
# dictionary of lists  
# Assemble the parsed FASTA records into a DataFrame. The column mapping is
# not called `dict`, so the builtin `dict` stays usable in later cells.
records = {'ID': identifiers, 'Sequence': seqs, 'length': lengths}
df = pd.DataFrame(records)
df

Unnamed: 0,ID,Sequence,length
0,0,VAQHRPLIPVLPERPQYSGRSLHSPAAVSMPLSDLDLLAVTDLSLS...,50
1,1,MERIELRHANSEEGRSQKTSCRYKFF,28
2,2,MQTLKRGGHRRLAVDTNFSEKPED,26
3,3,MPQLNSLHTSVSLSVTPL,20
4,4,MYCLAMLHPTQVRGR,17
...,...,...,...
12753725,12753725,MFFSSPSRRVFSLSNESS,20
12753726,12753726,MNPITPICTVHSTSIAFLFFILAV,26
12753727,12753727,CGLWGGQSVRTQRQLVMSDAQHVPSPAAGSELVSVLWAVGRQLLLL...,163
12753728,12753728,MPRAMRKRHRSRKRES,18


Remove duplicate sequences

In [5]:
# Keep only the first row for each distinct sequence (pandas default keep='first').
df = df.drop_duplicates(subset=['Sequence'])
df

Unnamed: 0,ID,Sequence,length
0,0,VAQHRPLIPVLPERPQYSGRSLHSPAAVSMPLSDLDLLAVTDLSLS...,50
1,1,MERIELRHANSEEGRSQKTSCRYKFF,28
2,2,MQTLKRGGHRRLAVDTNFSEKPED,26
3,3,MPQLNSLHTSVSLSVTPL,20
4,4,MYCLAMLHPTQVRGR,17
...,...,...,...
12753717,12753717,EPQRRSARLSAKPAPPKAEPKPKKPPAAKKADKAQKRKKGKADSGKDA,48
12753718,12753718,HLFQNLPFLFFFFELCQPSLLLVAFWALALLLEGRV,37
12753720,12753720,PEAKVGGLCSRWKDSVFVRMVLELKVTVRHPRLRIWSLWWIFNSAG...,97
12753721,12753721,MKKVAILQENALLLIN,18


Split into chunks and write each chunk to CSV

In [6]:
# Setting my chunk size
# Split `df` into chunks of `chunk_size` consecutive numeric IDs
# (chunk = ID // chunk_size) and write each chunk to its own CSV.
chunk_size = 1000
chunks_dir = "/mnt/vdb/PoisonFrog/chunks"
os.makedirs(chunks_dir, exist_ok=True)  # to_csv does not create directories
# Exact integer division (no float round-trip).
df['chunk'] = df['ID'].astype(int) // chunk_size
# Export every column except the helper 'chunk' column.
cols = [col for col in df.columns if col != 'chunk']
# Files are numbered by group position (0, 1, 2, ...), not by the chunk key;
# the two coincide only while no block of IDs is completely empty.
for i, (_, chunk) in enumerate(df.groupby('chunk')):
    chunk[cols].to_csv(f'{chunks_dir}/chunk{i}.csv', index=False)
print("complete")

complete


In [9]:
# Row-position boundaries of the batches used in earlier manual runs, where
# each run kept one slice:
#   0, 613834, 1277668, 1841502, 2455336, 3069170, 3683004, 4296838, 4910674
# This run keeps only the first batch. Note that this rebinds `df`, so the
# remaining rows are no longer reachable through it.
df = df.iloc[:613834]

Check which results already exist

In [13]:
import glob

# Collect the representation pickles already on disk so that chunks
# processed in an earlier run can be skipped.
appended_reps = [path for path in glob.iglob("/mnt/vdb/PoisonFrog/reps/*.pkl")]
print(len(appended_reps))

2416


Load each chunk and compute its representations

In [4]:
def createREPs(df, filename):
    """Compute representations for ``df['Sequence']`` with ``get_reps`` and pickle them.

    The saved frame contains every column of ``df`` except ``Sequence``, plus
    a ``reps`` column holding, per row, the list form of the first array
    returned by ``get_reps`` (``h_avg``). Rows are matched by position.

    Args:
      df: DataFrame with a ``Sequence`` column of protein sequence strings.
        It is not modified.
      filename: path of the output pickle file.

    Returns:
      None.
    """
    h_avg, _h_final, _c_final = get_reps(df['Sequence'].to_list())
    # Build a new frame instead of mutating the caller's DataFrame in place.
    result = df.drop(columns=['Sequence']).assign(reps=h_avg.tolist())
    result.to_pickle(filename)

In [None]:
# Compute representations for every chunk CSV that does not yet have a
# matching pickle in the reps directory, so an interrupted run can resume.
reps_dir = "/mnt/vdb/PoisonFrog/reps/"
os.makedirs(reps_dir, exist_ok=True)  # to_pickle does not create directories
done = set(appended_reps)  # O(1) membership checks
for infile in sorted(glob.glob("/mnt/vdb/PoisonFrog/chunks/*.csv")):
    # Swap only the extension (str.replace would hit every "csv" in the name).
    stem, _ext = os.path.splitext(os.path.basename(infile))
    result = reps_dir + stem + ".pkl"
    if result in done:
        continue
    df = pd.read_csv(infile)
    createREPs(df, result)
    print("Save:" + result)
    # Release device buffers and JAX function caches between chunks.
    reset_device_memory()
    clear_jax_caches()
print("complete")

In [55]:
# Inspect one finished chunk (columns: ID, length, and the per-row `reps` list).
data = pd.read_pickle(os.path.join("/mnt/vdb/PoisonFrog/reps", "chunk4389.pkl"))
data

Unnamed: 0,ID,length,reps
0,4389000,191,"[0.012593431398272514, 0.14382264018058777, 0...."
1,4389003,77,"[0.01161829475313425, -0.08347764611244202, 0...."
2,4389004,19,"[0.03577690199017525, -0.020374510437250137, -..."
3,4389009,47,"[0.014379587024450302, -0.01683482900261879, 0..."
4,4389010,47,"[0.021044721826910973, -0.08363862335681915, 0..."
...,...,...,...
374,4389983,56,"[0.01555175893008709, -0.0975966826081276, -0...."
375,4389987,17,"[0.03833804279565811, 0.01685425080358982, -0...."
376,4389988,35,"[0.01996765471994877, -0.0177371297031641, 0.0..."
377,4389991,247,"[0.0062294285744428635, 0.11091631650924683, 0..."


## CD100

In [3]:
fastas = "/home/ubuntu/data/bk_fasta/SRR11234331.assembly.len15.cd100.fasta"
# Parse the FASTA file into record IDs, sequences with leading/trailing '*'
# removed, and raw record lengths.
with open(fastas) as fasta_file:  # the handle is closed when the block exits
    identifiers = []
    lengths = []
    seqs = []
    for seq_record in SeqIO.parse(fasta_file, 'fasta'):
        identifiers.append(seq_record.id)
        seqs.append(str(seq_record.seq.strip('*')))
        # NOTE: len() of the raw record, so it also counts any '*' stripped
        # from the sequence above.
        lengths.append(len(seq_record.seq))
# The column mapping is not called `dict`, so the builtin stays usable.
records = {'ID': identifiers, 'Sequence': seqs, 'length': lengths}
df = pd.DataFrame(records)
df

Unnamed: 0,ID,Sequence,length
0,0,VAQHRPLIPVLPERPQYSGRSLHSPAAVSMPLSDLDLLAVTDLSLS...,50
1,8,CLLLFPFLHFARRGAAADTMQLFVRAQNLHTLEVSGQETVSQIKAH...,116
2,10,HMYQLEGLNWLRFSWAQGTDTILADEMGLGKTVQTIVFLYSLYKEG...,160
3,15,GTLYTYPDNWRAYKPLIAAQYSGFPITVASSAPEFQFGVTNKTPEF...,70
4,18,MYLFVFSILIIQYIFIGIAIISALLSCFCCENSCCAKICRYFESCL...,134
...,...,...,...
2592015,12753702,LLPCLNTSVVRNDGPRFGGGRVCRITALPSSYLLPSLAQSHTRVTM...,133
2592016,12753704,KLLWFFLELTLSLLFKPSEWRHMDDVPASLTSHSLIMPGCNLRAVY...,52
2592017,12753708,CLPFSQFALRVAALFAPSARGGSAAAAMSDYPGAQTGNRKYAFADA...,481
2592018,12753717,EPQRRSARLSAKPAPPKAEPKPKKPPAAKKADKAQKRKKGKADSGKDA,48


In [4]:
# Setting my chunk size
chunk_size = 2000
# Assigning chunk numbers to rows
df['chunk'] = df['ID'].apply(lambda x: int(int(x)/ chunk_size))
# We don't want the 'chunk' and 'index' columns in the output
cols = [col for col in df.columns if col not in ['chunk']]
# groupby chunk and export each chunk to a different csv.
i = 0
for _, chunk in df.groupby('chunk'):
    chunk[cols].to_csv(f'/mnt/vdb/PoisonFrog/cd100/chunks/chunk{i}.csv', index=False)
    i += 1
print("complete")

complete


In [6]:
import glob

# Collect the cd100 representation pickles already on disk so that chunks
# processed in an earlier run can be skipped.
appended_reps = [path for path in glob.iglob("/mnt/vdb/PoisonFrog/cd100/reps/*.pkl")]
print(len(appended_reps))

6377


In [None]:
# Compute representations for every cd100 chunk CSV that does not yet have a
# matching pickle in the reps directory, so an interrupted run can resume.
reps_dir = "/mnt/vdb/PoisonFrog/cd100/reps/"
os.makedirs(reps_dir, exist_ok=True)  # to_pickle does not create directories
done = set(appended_reps)  # O(1) membership checks
for infile in sorted(glob.glob("/mnt/vdb/PoisonFrog/cd100/chunks/*.csv")):
    # Swap only the extension (str.replace would hit every "csv" in the name).
    stem, _ext = os.path.splitext(os.path.basename(infile))
    result = reps_dir + stem + ".pkl"
    if result in done:
        continue
    df = pd.read_csv(infile)
    createREPs(df, result)
    print("Save:" + result)
    # Release device buffers and JAX function caches between chunks.
    reset_device_memory()
    clear_jax_caches()
print("complete")

Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk3141.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk3659.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk249.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk5028.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk2107.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk2100.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk3999.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk6114.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk759.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk2327.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk3944.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk3464.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk4877.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk6160.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk3789.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk3737.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk3980.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk3019.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk288.pkl
Save:/mnt/vdb/PoisonFrog/cd100/reps/chunk292.pkl
Save

In [None]:
# Legacy single-batch run. `df3` is only defined if a batch slice of `df` has
# been assigned to it by hand, so skip instead of failing with NameError on a
# fresh kernel.
df2 = ""  # presumably drops the reference to an earlier batch frame
if 'df3' in globals():
    createREPs(df3, "/home/kongkitimanonk/SCRATCH_NOBAK/phase3/PoisonFrog.len15.613834_1277668.pkl")
else:
    print("df3 is not defined; skipping legacy batch")

----- base: compute representations for the whole `df` at once ---

In [None]:
# Compute representations for the whole of `df` in one pass; only the first
# output of get_reps (_h_avg) is stored, as one list per row.
sequences = df['Sequence'].to_list()
_h_avg, h_final, c_final = get_reps(sequences)
df['reps'] = _h_avg.tolist()
df

In [None]:
# Save the whole-DataFrame result. The output path `plk` must be assigned
# beforehand; report it instead of raising NameError on a fresh kernel.
if 'plk' in globals():
    df.to_pickle(plk)
else:
    print("plk is not defined; set the output pickle path to save df")
