In [3]:
import jax
# Global flag to set a specific platform, must be used at startup.
#jax.config.update('jax_platform_name', 'cpu')
import os

# XLA client memory settings, handed over through environment variables.
# NOTE(review): these are assigned after the `import jax` at the top of this
# cell -- confirm the backend still picks them up, or move them above that import.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='False'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.90'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'

import numpy as np
import pandas as pd
import csv
import matplotlib.pyplot as plt
import gc
import jax 

from Bio import SeqIO
from jax_unirep import get_reps
from jax_unirep import evotune, fit
from jax_unirep.utils import dump_params



In [4]:

def clear_jax_caches():
  """Utility to clear all the function caches in jax.

  Every cache is reached through jax-internal attributes (closure cells and
  underscore-prefixed names), so the function is tied to the jax release in use.
  NOTE(review): confirm these attributes exist in the installed jax version.
  """
  # main jit/pmap lu wrapped function caches - have to grab from closures
  jax.xla._xla_callable.__closure__[1].cell_contents.clear()
  jax.pxla.parallel_callable.__closure__[1].cell_contents.clear()
  # primitive callable caches
  jax.xla.xla_primitive_callable.cache_clear()
  jax.xla.primitive_computation.cache_clear()
  # jaxpr caches for control flow and reductions
  jax.lax.lax_control_flow._initial_style_jaxpr.cache_clear()
  jax.lax.lax_control_flow._fori_body_fun.cache_clear()
  jax.lax.lax._reduction_jaxpr.cache_clear()
  # these are trivial and only included for completeness sake
  jax.lax.lax.broadcast_shapes.cache_clear()
  jax.xla.xb.get_backend.cache_clear()
  jax.xla.xb.dtype_to_etype.cache_clear()
  jax.xla.xb.supported_numpy_dtypes.cache_clear()
    
def reset_device_memory(delete_objs=True):
    """Free all tracked DeviceArray memory and delete objects.

    Scans every object known to the garbage collector, deletes the device
    buffer behind each ``jax.xla.DeviceArray`` that is not a
    ``jax.xla.DeviceConstant``, then runs a collection pass.

    Args:
      delete_objs: bool: when true, the loop's own reference to each array is
        dropped as soon as the array has been handled.

    Returns:
      number of DeviceArrays that were manually freed.
    """
    live_arrays = (
        obj for obj in gc.get_objects() if isinstance(obj, jax.xla.DeviceArray)
    )
    n_deleted = 0
    for dv in live_arrays:
        if not isinstance(dv, jax.xla.DeviceConstant):
            try:
                dv._check_if_deleted()  # pylint: disable=protected-access
                dv.device_buffer.delete()
            except Exception:
                # Best effort: an array that is already deleted (or cannot be
                # freed) is skipped. Only Exception is caught so Ctrl-C and
                # SystemExit still stop a long clean-up.
                pass
            else:
                n_deleted += 1
        if delete_objs:
            del dv
    del live_arrays
    gc.collect()
    return n_deleted

In [5]:
def createREPs(df, filename):
    """Compute representations for the sequences in ``df`` and pickle them.

    The saved frame is ``df`` without the 'length' and 'index' columns plus a
    'reps' column holding, per row, the first array returned by ``get_reps``
    converted to a list. The caller's ``df`` is not modified.

    Args:
      df: DataFrame with 'Sequence', 'length' and 'index' columns.
      filename: destination path passed to ``DataFrame.to_pickle``.
    """
    # Only the first return value is stored; the other two are unused here.
    h_avg, _h_final, _c_final = get_reps(df['Sequence'].to_list())
    # drop() without inplace returns a new frame, leaving the argument untouched.
    reps_df = df.drop(columns=['length', 'index'])
    # Assigning a plain list is positional, so row order (not index labels)
    # pairs each representation with its sequence.
    reps_df['reps'] = h_avg.tolist()
    reps_df.to_pickle(filename)

In [4]:
fastas = "/home/ubuntu/data/uniprot/uniprot_sprot.fasta"
identifiers = []
lengths = []
seqs = []
with open(fastas) as fasta_file:  # Will close handle cleanly
    for seq_record in SeqIO.parse(fasta_file, 'fasta'):  # (generator)
        # Remove leading and trailing '*' characters from the sequence
        sequence = str(seq_record.seq.strip('*'))
        identifiers.append(seq_record.id)
        seqs.append(sequence)
        # Measure the stored (stripped) sequence so 'length' always matches
        # the text kept in 'Sequence'.
        lengths.append(len(sequence))

In [5]:
# Column name -> list of per-record values. Named `records` rather than `dict`
# so the builtin `dict` stays usable in later cells.
records = {'ID': identifiers, 'Sequence': seqs, 'length': lengths}
df = pd.DataFrame(records)
df

Unnamed: 0,ID,Sequence,length
0,sp|Q6GZX4|001R_FRG3G,MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQV...,256
1,sp|Q6GZX3|002L_FRG3G,MSIIGATRLQNDKSDTYSAGPCYAGGCSAFTPRGTCGKDWDLGEQT...,320
2,sp|Q197F8|002R_IIV3,MASNTVSAQGGSNRPVRDFSNIQDVAQFLLFDPIWNEQPGSIVPWK...,458
3,sp|Q197F7|003L_IIV3,MYQAINPCPQSWYGSPQLEREIVCKMSGAPHYPNYYPVHPNALGGA...,156
4,sp|Q6GZX2|003R_FRG3G,MARPLLGKTSSVRRRLESLSACSIFFFLRKFCQKMASLVFLNSPVY...,438
...,...,...,...
563547,sp|Q6UY62|Z_SABVB,MGNSKSKSKLSANQYEQQTVNSTKQVAILKRQAEPSLYGRHNCRCC...,100
563548,sp|P08105|Z_SHEEP,MSSSLEITSFYSFIWTPHIGPLLFGIGLWFSMFKEPSHFCPCQHPH...,79
563549,sp|Q88470|Z_TACVF,MGNCNRTQKPSSSSNNLEKPPQAAEFRRTAEPSLYGRYNCKCCWFA...,95
563550,sp|A9JR22|Z_TAMVU,MGLRYSKEVRDRHGDKDPEGRIPITQTMPQTLYGRYNCKSCWFANK...,95


In [6]:
# Store each row's index label as an ordinary column.
df["index"] = df.index.to_numpy()
df

Unnamed: 0,ID,Sequence,length,index
0,sp|Q6GZX4|001R_FRG3G,MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQV...,256,0
1,sp|Q6GZX3|002L_FRG3G,MSIIGATRLQNDKSDTYSAGPCYAGGCSAFTPRGTCGKDWDLGEQT...,320,1
2,sp|Q197F8|002R_IIV3,MASNTVSAQGGSNRPVRDFSNIQDVAQFLLFDPIWNEQPGSIVPWK...,458,2
3,sp|Q197F7|003L_IIV3,MYQAINPCPQSWYGSPQLEREIVCKMSGAPHYPNYYPVHPNALGGA...,156,3
4,sp|Q6GZX2|003R_FRG3G,MARPLLGKTSSVRRRLESLSACSIFFFLRKFCQKMASLVFLNSPVY...,438,4
...,...,...,...,...
563547,sp|Q6UY62|Z_SABVB,MGNSKSKSKLSANQYEQQTVNSTKQVAILKRQAEPSLYGRHNCRCC...,100,563547
563548,sp|P08105|Z_SHEEP,MSSSLEITSFYSFIWTPHIGPLLFGIGLWFSMFKEPSHFCPCQHPH...,79,563548
563549,sp|Q88470|Z_TACVF,MGNCNRTQKPSSSSNNLEKPPQAAEFRRTAEPSLYGRYNCKCCWFA...,95,563549
563550,sp|A9JR22|Z_TAMVU,MGLRYSKEVRDRHGDKDPEGRIPITQTMPQTLYGRYNCKSCWFANK...,95,563550


In [7]:
# Number of rows written to each chunk file
chunk_size = 100
# Chunk number per row. Integer floor division is exact and vectorized; the
# 'index' values are non-negative, so it matches truncating the quotient.
df['chunk'] = df['index'].astype(int) // chunk_size
# Only the helper 'chunk' column is left out of the CSVs; 'index' and 'length'
# are written because createREPs drops them by name later.
cols = [col for col in df.columns if col != 'chunk']
# groupby iterates its keys in sorted order, so the enumerate counter numbers
# the files in chunk order.
for i, (_, chunk) in enumerate(df.groupby('chunk')):
    chunk[cols].to_csv(f'/mnt/vdb/uniprot/chunks/chunk{i}.csv', index=False)
print("complete")

complete


In [6]:
import glob
# Paths of the representation pickles that already exist on disk.
appended_reps = [rep_path for rep_path in glob.glob("/mnt/vdb/uniprot/reps/*.pkl")]
print(len(appended_reps))

3434


In [7]:
# Set lookup keeps the "already processed?" check O(1) per chunk file.
finished_reps = set(appended_reps)
for infile in glob.glob("/mnt/vdb/uniprot/chunks/*.csv"):
    # Swap only the extension; a "csv" elsewhere in the file name is left alone.
    stem, _ext = os.path.splitext(os.path.basename(infile))
    result = "/mnt/vdb/uniprot/reps/" + stem + ".pkl"
    if result in finished_reps:
        continue
    # Separate name so the full `df` built in earlier cells is not overwritten.
    chunk_df = pd.read_csv(infile)
    createREPs(chunk_df, result)
    print("Save:"+result)
    # Free device buffers and jax function caches before the next chunk.
    reset_device_memory()
    clear_jax_caches()
print("complete")

Save:/mnt/vdb/uniprot/reps/chunk1532.pkl
Save:/mnt/vdb/uniprot/reps/chunk4946.pkl
Save:/mnt/vdb/uniprot/reps/chunk656.pkl
Save:/mnt/vdb/uniprot/reps/chunk3614.pkl
Save:/mnt/vdb/uniprot/reps/chunk547.pkl
Save:/mnt/vdb/uniprot/reps/chunk1439.pkl
Save:/mnt/vdb/uniprot/reps/chunk2741.pkl
Save:/mnt/vdb/uniprot/reps/chunk3997.pkl
Save:/mnt/vdb/uniprot/reps/chunk3791.pkl
Save:/mnt/vdb/uniprot/reps/chunk3244.pkl
Save:/mnt/vdb/uniprot/reps/chunk348.pkl
Save:/mnt/vdb/uniprot/reps/chunk3748.pkl
Save:/mnt/vdb/uniprot/reps/chunk3817.pkl
Save:/mnt/vdb/uniprot/reps/chunk3777.pkl
Save:/mnt/vdb/uniprot/reps/chunk1260.pkl
Save:/mnt/vdb/uniprot/reps/chunk824.pkl
Save:/mnt/vdb/uniprot/reps/chunk4648.pkl
Save:/mnt/vdb/uniprot/reps/chunk2649.pkl
Save:/mnt/vdb/uniprot/reps/chunk4992.pkl
Save:/mnt/vdb/uniprot/reps/chunk460.pkl
Save:/mnt/vdb/uniprot/reps/chunk927.pkl
Save:/mnt/vdb/uniprot/reps/chunk1380.pkl
Save:/mnt/vdb/uniprot/reps/chunk628.pkl
Save:/mnt/vdb/uniprot/reps/chunk894.pkl
Save:/mnt/vdb/uniprot/re

In [11]:
import glob
import os
import numpy as np
import pandas as pd
def toJson(path, out_dir="/mnt/vdb/uniprot/json/"):
    """Convert every pickled DataFrame matching ``path`` into a JSON file.

    Each input file is read with ``pandas.read_pickle`` and written with
    ``DataFrame.to_json`` under the same base name with a ``.json`` extension.

    Args:
      path: glob pattern selecting the input pickle files.
      out_dir: directory receiving the JSON files; defaults to the location
        that was previously hard-coded.
    """
    for infile in glob.glob(path):
        # Replace only the extension, so "pkl" inside the base name survives.
        stem, _ext = os.path.splitext(os.path.basename(infile))
        result = os.path.join(out_dir, stem + ".json")
        print("Save:"+result)
        # NOTE: unpickling can execute arbitrary code -- only point this at
        # pickle files produced by this pipeline.
        data = pd.read_pickle(infile)
        data.to_json(result)

In [12]:

# Convert every pickle matching the reps/ pattern to JSON.
toJson("/mnt/vdb/uniprot/reps/*.pkl")


Save:/mnt/vdb/uniprot/json/chunk505.json
Save:/mnt/vdb/uniprot/json/chunk2495.json
Save:/mnt/vdb/uniprot/json/chunk2288.json
Save:/mnt/vdb/uniprot/json/chunk2634.json
Save:/mnt/vdb/uniprot/json/chunk5558.json
Save:/mnt/vdb/uniprot/json/chunk1047.json
Save:/mnt/vdb/uniprot/json/chunk3084.json
Save:/mnt/vdb/uniprot/json/chunk5629.json
Save:/mnt/vdb/uniprot/json/chunk2570.json
Save:/mnt/vdb/uniprot/json/chunk2580.json
Save:/mnt/vdb/uniprot/json/chunk1003.json
Save:/mnt/vdb/uniprot/json/chunk1284.json
Save:/mnt/vdb/uniprot/json/chunk3551.json
Save:/mnt/vdb/uniprot/json/chunk1653.json
Save:/mnt/vdb/uniprot/json/chunk4332.json
Save:/mnt/vdb/uniprot/json/chunk2610.json
Save:/mnt/vdb/uniprot/json/chunk3025.json
Save:/mnt/vdb/uniprot/json/chunk4829.json
Save:/mnt/vdb/uniprot/json/chunk3116.json
Save:/mnt/vdb/uniprot/json/chunk356.json
Save:/mnt/vdb/uniprot/json/chunk4017.json
Save:/mnt/vdb/uniprot/json/chunk2666.json
Save:/mnt/vdb/uniprot/json/chunk1852.json
Save:/mnt/vdb/uniprot/json/chunk5094

In [14]:
# Read one converted chunk back from JSON to inspect the result.
data_df = pd.read_json("/mnt/vdb/uniprot/json/chunk0.json")

In [17]:
# Display the frame read from chunk0.json.
data_df

Unnamed: 0,ID,Sequence,reps
0,sp|Q6GZX4|001R_FRG3G,MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQV...,"[0.0037700436, -0.0257038362, 0.0359647982, -0..."
1,sp|Q6GZX3|002L_FRG3G,MSIIGATRLQNDKSDTYSAGPCYAGGCSAFTPRGTCGKDWDLGEQT...,"[0.0051520769, -0.0848760009, 0.0712415203, -0..."
2,sp|Q197F8|002R_IIV3,MASNTVSAQGGSNRPVRDFSNIQDVAQFLLFDPIWNEQPGSIVPWK...,"[0.0043888474000000005, -0.0694381371, 0.03061..."
3,sp|Q197F7|003L_IIV3,MYQAINPCPQSWYGSPQLEREIVCKMSGAPHYPNYYPVHPNALGGA...,"[0.006500421100000001, -0.037640102200000004, ..."
4,sp|Q6GZX2|003R_FRG3G,MARPLLGKTSSVRRRLESLSACSIFFFLRKFCQKMASLVFLNSPVY...,"[0.0035510727, 0.0003123366, 0.0461201183, -0...."
...,...,...,...
95,sp|Q91G40|065R_IIV6,MNIPKTCFQIHNKIQVQNYLIRINLNIFLIYHFSPIYCPYLFLFTV...,"[0.013552869700000001, -0.0287670381, -0.00065..."
96,sp|Q6GZQ9|066L_FRG3G,MPFYICSDPDPKRTVRGPRFTVPDPKPPPDPAHPLDDTDNVMTAFP...,"[0.009182518300000001, -0.0449988656, 0.005595..."
97,sp|Q6GZQ7|068R_FRG3G,MWIHFPVRNYIHLSHTILHYPCPDNAGHVCGHRNVSRNGLKDGLGV...,"[0.008656034200000001, -0.0397028476, 0.006880..."
98,sp|O55709|069L_IIV6,MSDKIDNQIVKVENTNNGGLRAIFNLDGVTLDTPIMGTWDKPVFFG...,"[0.0034565877, -0.0163974799, 0.0419957228, -0..."
