In [None]:
!pip install scikit-plot

In [None]:
!pip install jax-unirep

Try to deal with GPU memory 
accroading to https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

In [1]:
import jax
# Global flag to set a specific platform, must be used at startup.
#jax.config.update('jax_platform_name', 'cpu')

import os

os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='False'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.95'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'


In [2]:
import gc

def clear_jax_caches():
  """Utility to clear all the function caches in jax."""
  # main jit/pmap lu wrapped function caches - have to grab from closures
  jax.xla._xla_callable.__closure__[1].cell_contents.clear()
  jax.pxla.parallel_callable.__closure__[1].cell_contents.clear()
  # primitive callable caches
  jax.xla.xla_primitive_callable.cache_clear()
  jax.xla.primitive_computation.cache_clear()
  # jaxpr caches for control flow and reductions
  jax.lax.lax_control_flow._initial_style_jaxpr.cache_clear()
  jax.lax.lax_control_flow._fori_body_fun.cache_clear()
  jax.lax.lax._reduction_jaxpr.cache_clear()
  # these are trivial and only included for completeness sake
  jax.lax.lax.broadcast_shapes.cache_clear()
  jax.xla.xb.get_backend.cache_clear()
  jax.xla.xb.dtype_to_etype.cache_clear()
  jax.xla.xb.supported_numpy_dtypes.cache_clear()
    
def reset_device_memory(delete_objs=True):
    """Free all tracked DeviceArray memory and delete objects.

  Args:
    delete_objs: bool: whether to delete all live DeviceValues or just free.

  Returns:
    number of DeviceArrays that were manually freed.
  """
    dvals = (x for x in gc.get_objects() if isinstance(x, jax.xla.DeviceArray))
    n_deleted = 0
    for dv in dvals:
    
        if not isinstance(dv, jax.xla.DeviceConstant):
            try: 
                dv._check_if_deleted()  # pylint: disable=protected-access
                dv.device_buffer.delete()
                n_deleted += 1
            except:
                pass
        if delete_objs:

            del dv
    del dvals
    gc.collect()
    return n_deleted

In [3]:
import numpy as np
import pandas as pd
import csv
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from jax_unirep import get_reps
  
from jax_unirep import evotune, fit
from jax_unirep.utils import dump_params

In [67]:
#db_path="AMPS_NonAMPs.ready.csv"
db_path="../../datasets/truthset.csv"

In [68]:
AMPs_df=pd.read_csv(db_path,sep=',',header=0)
AMPs_df # Class 0= AMPs , 1=NonAMps

Unnamed: 0,ID,paper,seq
0,ISGCock_Contig04_0915,PMC4864078,ALQICTRNMIDDRLPYVADNVRPGTFIKQQRKQKQQRHHTSGTRKR...
1,ISGCock_Contig13_4610,PMC4864078,HLYPCKLNLKLGKVPFHFLNLNHKGKSIMVNQQTCLYYIICQTR
2,ISGCock_Contig16_2060,PMC4864078,ISHNHLTAASITHVKNRGKYIYMHLKFRKTNVLI
3,ISGCock_Contig16_4974,PMC4864078,RKKVWFIFHVCPKLKQRILSDTHAKNKCRLSPLLIKSTKIKNET
4,ISGCock_Contig07_3736,PMC4864078,CNYISFFRKCKNSQSTMYGCHRMNKCVFSSY
...,...,...,...
57,TR42258|c1_g1_i1,hal-02965337,MRYIVCFVFFLFFFLLFLWLVPARTASSFLTPRLSSLGKRSWAV
58,TR27534|c0_g1_i1,hal-02965337,MLFKVIIVIWISVCRECTRGGFCNFMHLKPISRELRRELYGRTRRRRK
59,P3,S2162-2531(20)30132-3,FWELWKFLKSLWSIFPRRRP
60,P10,S2162-2531(20)30132-3,ICTTLNWMVKLTCLTHVTLTTRWC


In [69]:
AMPs_df.drop_duplicates(subset=['seq'],inplace=True)
# remove white space
AMPs_df['seq'] = AMPs_df['seq'].str.strip()
AMPs_df

Unnamed: 0,ID,paper,seq
0,ISGCock_Contig04_0915,PMC4864078,ALQICTRNMIDDRLPYVADNVRPGTFIKQQRKQKQQRHHTSGTRKR...
1,ISGCock_Contig13_4610,PMC4864078,HLYPCKLNLKLGKVPFHFLNLNHKGKSIMVNQQTCLYYIICQTR
2,ISGCock_Contig16_2060,PMC4864078,ISHNHLTAASITHVKNRGKYIYMHLKFRKTNVLI
3,ISGCock_Contig16_4974,PMC4864078,RKKVWFIFHVCPKLKQRILSDTHAKNKCRLSPLLIKSTKIKNET
4,ISGCock_Contig07_3736,PMC4864078,CNYISFFRKCKNSQSTMYGCHRMNKCVFSSY
...,...,...,...
57,TR42258|c1_g1_i1,hal-02965337,MRYIVCFVFFLFFFLLFLWLVPARTASSFLTPRLSSLGKRSWAV
58,TR27534|c0_g1_i1,hal-02965337,MLFKVIIVIWISVCRECTRGGFCNFMHLKPISRELRRELYGRTRRRRK
59,P3,S2162-2531(20)30132-3,FWELWKFLKSLWSIFPRRRP
60,P10,S2162-2531(20)30132-3,ICTTLNWMVKLTCLTHVTLTTRWC


In [70]:
AMPs_df['length']  = AMPs_df['seq'].str.len()
AMPs_df['length'].values

array([ 50,  44,  34,  44,  31,  37,  47,  18,  30,  46,  15,  14,  13,
        15,  13,  12,  12,   9,   9,  12,  11,  13,  13,  74,  74,  74,
        74,  74,  74,  71,  71,  81,  75,  75,  63,  64,  64,  66,  66,
        66,  66,  71,  66,  66,  66,  66,  66,  66,  66, 118, 119, 118,
       253, 242, 267, 245,  47,  44,  48,  20,  24,  29])

In [37]:
AMPs_df.iloc[0]['seq']

'ALQICTRNMIDDRLPYVADNVRPGTFIKQQRKQKQQRHHTSGTRKRMAKG'

In [38]:
_h_avg, h_final, c_final= get_reps(AMPs_df['seq'].to_list())
AMPs_df['reps']=_h_avg.tolist()
AMPs_df

Unnamed: 0,ID,paper,seq,length,reps
0,ISGCock_Contig04_0915,PMC4864078,ALQICTRNMIDDRLPYVADNVRPGTFIKQQRKQKQQRHHTSGTRKR...,50,"[0.019923143088817596, -0.04708784073591232, 0..."
1,ISGCock_Contig13_4610,PMC4864078,HLYPCKLNLKLGKVPFHFLNLNHKGKSIMVNQQTCLYYIICQTR,44,"[0.014724492095410824, -0.0309316273778677, 0...."
2,ISGCock_Contig16_2060,PMC4864078,ISHNHLTAASITHVKNRGKYIYMHLKFRKTNVLI,34,"[0.018701443448662758, -0.05508643761277199, -..."
3,ISGCock_Contig16_4974,PMC4864078,RKKVWFIFHVCPKLKQRILSDTHAKNKCRLSPLLIKSTKIKNET,44,"[0.019436540082097054, -0.03149893507361412, 0..."
4,ISGCock_Contig07_3736,PMC4864078,CNYISFFRKCKNSQSTMYGCHRMNKCVFSSY,31,"[0.03331663832068443, -0.09785227477550507, 0...."
...,...,...,...,...,...
57,TR42258|c1_g1_i1,hal-02965337,MRYIVCFVFFLFFFLLFLWLVPARTASSFLTPRLSSLGKRSWAV,44,"[0.016170913353562355, 0.00010415514407213777,..."
58,TR27534|c0_g1_i1,hal-02965337,MLFKVIIVIWISVCRECTRGGFCNFMHLKPISRELRRELYGRTRRRRK,48,"[0.015017224475741386, -0.03606594353914261, 0..."
59,P3,S2162-2531(20)30132-3,FWELWKFLKSLWSIFPRRRP,20,"[0.032129984349012375, -0.0004883870133198798,..."
60,P10,S2162-2531(20)30132-3,ICTTLNWMVKLTCLTHVTLTTRWC,24,"[0.02543189749121666, 0.006170747801661491, -0..."


In [39]:
AMPs_df.to_pickle('../../datasets/AMPs_truthset.reps.plk')
AMPs_df = pd.read_pickle('../../datasets/AMPs_truthset.reps.plk')
AMPs_df

Unnamed: 0,ID,paper,seq,length,reps
0,ISGCock_Contig04_0915,PMC4864078,ALQICTRNMIDDRLPYVADNVRPGTFIKQQRKQKQQRHHTSGTRKR...,50,"[0.019923143088817596, -0.04708784073591232, 0..."
1,ISGCock_Contig13_4610,PMC4864078,HLYPCKLNLKLGKVPFHFLNLNHKGKSIMVNQQTCLYYIICQTR,44,"[0.014724492095410824, -0.0309316273778677, 0...."
2,ISGCock_Contig16_2060,PMC4864078,ISHNHLTAASITHVKNRGKYIYMHLKFRKTNVLI,34,"[0.018701443448662758, -0.05508643761277199, -..."
3,ISGCock_Contig16_4974,PMC4864078,RKKVWFIFHVCPKLKQRILSDTHAKNKCRLSPLLIKSTKIKNET,44,"[0.019436540082097054, -0.03149893507361412, 0..."
4,ISGCock_Contig07_3736,PMC4864078,CNYISFFRKCKNSQSTMYGCHRMNKCVFSSY,31,"[0.03331663832068443, -0.09785227477550507, 0...."
...,...,...,...,...,...
57,TR42258|c1_g1_i1,hal-02965337,MRYIVCFVFFLFFFLLFLWLVPARTASSFLTPRLSSLGKRSWAV,44,"[0.016170913353562355, 0.00010415514407213777,..."
58,TR27534|c0_g1_i1,hal-02965337,MLFKVIIVIWISVCRECTRGGFCNFMHLKPISRELRRELYGRTRRRRK,48,"[0.015017224475741386, -0.03606594353914261, 0..."
59,P3,S2162-2531(20)30132-3,FWELWKFLKSLWSIFPRRRP,20,"[0.032129984349012375, -0.0004883870133198798,..."
60,P10,S2162-2531(20)30132-3,ICTTLNWMVKLTCLTHVTLTTRWC,24,"[0.02543189749121666, 0.006170747801661491, -0..."


In [45]:
reset_device_memory()
clear_jax_caches()

## Shuffle 

In [75]:
import random

def rand(stri,seed):
    random.seed(seed)
    return ''.join(random.sample(stri,len(stri)))

def start_shuffle(AMPs_df, i):
    seed = 42 + i
    print(seed)
    for i in range(9):
        AMPs_df['seq'] = AMPs_df['seq'].apply(lambda x : rand(x,i * seed+1))
    AMPs_df
    
    AMPs_df_1 = AMPs_df
    AMPs_df_1['seq'] = AMPs_df['seq'].apply(lambda x : rand(x,seed))
    _h_avg, h_final, c_final= get_reps(AMPs_df_1['seq'].to_list())
    AMPs_df_1['reps']=_h_avg.tolist()
    AMPs_df_1.to_pickle('../../datasets/AMPs_truthset_'+str(seed)+'.reps.plk')
    reset_device_memory()
    clear_jax_caches()
    

In [76]:
for x in range(10):
    start_shuffle(AMPs_df, x)

42
43
44
45
46
47
48
49
50
51


In [77]:
pd.read_pickle('../../datasets/AMPs_truthset_42.reps.plk')

Unnamed: 0,ID,paper,seq,length,reps
0,ISGCock_Contig04_0915,PMC4864078,RQDRAHNQIDGQRPMGRLDKRKIAMKCTLSNQVFAKRRITYQTHKG...,50,"[0.01679776795208454, -0.07345905154943466, 0...."
1,ISGCock_Contig13_4610,PMC4864078,ICCLSTLMGNPRVKHLNKNNTYPKYIVLYLFHKLGFIQQLKHQC,44,"[0.015076708979904652, -0.050224486738443375, ..."
2,ISGCock_Contig16_2060,PMC4864078,VKHTKRTISYIYHKILRNSLMFTHINVHANGALK,34,"[0.020523976534605026, -0.03999472036957741, -..."
3,ISGCock_Contig16_4974,PMC4864078,KKIDKLIRPTVTLVHKSKFSESKAKILHTCQRFKKICPLLNRNW,44,"[0.02007628232240677, -0.05524342507123947, 0...."
4,ISGCock_Contig07_3736,PMC4864078,SKMCINSRYSTYSCCMRFKKHVYSFNNGFCQ,31,"[0.022961627691984177, -0.03958016261458397, -..."
...,...,...,...,...,...
57,TR42258|c1_g1_i1,hal-02965337,KSSVSLFFFPIVFFARLFVTALLRYRLFGLLMCSTLSPRWAFWV,44,"[0.014877736568450928, 0.00022438698215410113,..."
58,TR27534|c0_g1_i1,hal-02965337,CESFVGRRRVTIMRPFIREFKICRIRWGMSIYRRGKNVHELLRKLTCL,48,"[0.01740586943924427, -0.043963976204395294, 0..."
59,P3,S2162-2531(20)30132-3,SFSKKLRLWEPIWWPFRRLF,20,"[0.024964744225144386, -0.0190703384578228, -0..."
60,P10,S2162-2531(20)30132-3,LICWTVKLCVTLTHTTLTCTWNMR,24,"[0.02561492845416069, 0.008683688007295132, -0..."


In [78]:
pd.read_pickle('../../datasets/AMPs_truthset_46.reps.plk')

Unnamed: 0,ID,paper,seq,length,reps
0,ISGCock_Contig04_0915,PMC4864078,MLHIKQTVDIQTRYAKVKKAKHNRRQSPQQPDGQRIRARDGTTGFN...,50,"[0.013063241727650166, -0.042226146906614304, ..."
1,ISGCock_Contig13_4610,PMC4864078,YTPLLLKVKHIQQQNSPIYLMCRCGLHFNGKHYIVTFNNCKLKL,44,"[0.01589961349964142, -0.10049132257699966, 0...."
2,ISGCock_Contig16_2060,PMC4864078,TKHLNNLINHYIIVSKHTGMIKSRTARYKFAHVL,34,"[0.019583813846111298, -0.04624079167842865, -..."
3,ISGCock_Contig16_4974,PMC4864078,KEKKCHNLKHKNPLFKVCTLRITWPKRQSKVRSILLITSKFDAI,44,"[0.014719865284860134, -0.07682189345359802, 0..."
4,ISGCock_Contig07_3736,PMC4864078,SKYFRYCKSHNMRSIVGFNSTNKYSFCQCMC,31,"[0.022735636681318283, 0.010877581313252449, 0..."
...,...,...,...,...,...
57,TR42258|c1_g1_i1,hal-02965337,YALRLFAFFAKWPRVSISGWFSVVFSFLTTFMLRLLLPLSCVRF,44,"[0.014602179639041424, 0.003792224684730172, 0..."
58,TR27534|c0_g1_i1,hal-02965337,TRGERIFRMPRLTLKRLIKEYRISWESGRFNLFIVCRHIVCMGVRKRC,48,"[0.019080253317952156, -0.034618228673934937, ..."
59,P3,S2162-2531(20)30132-3,PKRFFELKLSWFSLIRWWPR,20,"[0.03329581022262573, -0.008069236762821674, -..."
60,P10,S2162-2531(20)30132-3,HLTTTVTCLLMLWTTCRCTIVWKN,24,"[0.024820439517498016, 0.00955728068947792, -0..."


In [79]:
pd.read_pickle('../../datasets/AMPs_truthset_48.reps.plk')

Unnamed: 0,ID,paper,seq,length,reps
0,ISGCock_Contig04_0915,PMC4864078,TRQQPLDAKTIHVRSYKAKGTMQRLRFDKIDQRMRQHRGNQPGTNI...,50,"[0.014190083369612694, -0.045038964599370956, ..."
1,ISGCock_Contig13_4610,PMC4864078,FTLHVNGTKKQCYFINMCPILKLKCNNLSQGVIKQLYHLRYHLP,44,"[0.01409909874200821, -0.04996402561664581, 0...."
2,ISGCock_Contig16_2060,PMC4864078,TIHMNGNIHHLVTFYKYLRAIVRKTLKHAKISSN,34,"[0.017020177096128464, -0.047774046659469604, ..."
3,ISGCock_Contig16_4974,PMC4864078,QLIRLSPEAVLITICFRKKKLKHFWSTKKNKLINPCKHDTSRKV,44,"[0.012386562302708626, -0.04997082054615021, 0..."
4,ISGCock_Contig07_3736,PMC4864078,SFINVYNKHKCSSCFMRGMRSYCNQKFTYSC,31,"[0.028336286544799805, -0.13328006863594055, 0..."
...,...,...,...,...,...
57,TR42258|c1_g1_i1,hal-02965337,LLFFLTFARFRSGLSVFSLKWFFCVLPRSWTFRAPLYAVVLMSI,44,"[0.01515231467783451, 0.0020381335634738207, -..."
58,TR27534|c0_g1_i1,hal-02965337,FRGETRCMIGKRHRLRINRRWSESGVPCRIYVEFMLLVITFLKKRRIC,48,"[0.01888762228190899, -0.05116211995482445, 0...."
59,P3,S2162-2531(20)30132-3,FLEPWWPRSKSWRFLFIKLR,20,"[0.030161423608660698, -0.035769350826740265, ..."
60,P10,S2162-2531(20)30132-3,TLCCLMWTTWKCVIVHLNRTTLTT,24,"[0.024908849969506264, 0.008124792017042637, -..."


Shuffle ten times

In [71]:
for i in range(10):
    AMPs_df['seq'] = AMPs_df['seq'].apply(lambda x : rand(x,i))
AMPs_df

Unnamed: 0,ID,paper,seq,length
0,ISGCock_Contig04_0915,PMC4864078,KRRAQCQVDPDQIPRLKTLTQRDGQHKKRNNHATKQMRIAISRGFY...,50
1,ISGCock_Contig13_4610,PMC4864078,LLLKQNYLHVMPKCLHQTRVNCHKFNPLQKLNGSYGYIICKITF,44
2,ISGCock_Contig16_2060,PMC4864078,SVLHHYANTNKIHIKRLTKIGALKYIFTHNSMRV,34
3,ISGCock_Contig16_4974,PMC4864078,IDKFPSTCRLRKKWHRLLTLFIHAQSVLNVKTPKSKKKCKNIEI,44
4,ISGCock_Contig07_3736,PMC4864078,SCTMYKYSVQFNFSCKSHKCNFGNIRYMRCS,31
...,...,...,...,...
57,TR42258|c1_g1_i1,hal-02965337,FVSCPLGLMLFLFVFFRLVFVSARLTIWWFRPFSLTYKSSARAL,44
58,TR27534|c0_g1_i1,hal-02965337,VFRRLELSNPRKFELYEIGVKVWIRRCICKFISLMHIRGTCRTMRRGR,48
59,P3,S2162-2531(20)30132-3,RSWSEKIRRWPWLFLKFLFP,20
60,P10,S2162-2531(20)30132-3,TVVLLLKMWHTTCCITTTWRTCLN,24


In [72]:
_h_avg, h_final, c_final= get_reps(AMPs_df['seq'].to_list())
AMPs_df['reps']=_h_avg.tolist()
AMPs_df.to_pickle('../../datasets/AMPs_truthset_10times.reps.plk')
reset_device_memory()
clear_jax_caches()