In [1]:
import jax
# Global flag to set a specific platform, must be used at startup.
#jax.config.update('jax_platform_name', 'cpu')

import os

# XLA client memory/allocator settings, passed through the process environment.
os.environ.update({
    'XLA_PYTHON_CLIENT_PREALLOCATE': 'False',
    'XLA_PYTHON_CLIENT_MEM_FRACTION': '1',
    'XLA_PYTHON_CLIENT_ALLOCATOR': 'platform',
})


In [2]:
import numpy as np
import pandas as pd
import csv
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from jax_unirep import get_reps
  
from jax_unirep import evotune, fit
from jax_unirep.utils import dump_params

In [3]:
import gc
import jax 
def clear_jax_caches():
  """Utility to clear all the function caches in jax.

  Uses the public ``jax.clear_caches()`` when the installed JAX provides it.
  Otherwise falls back to clearing the individual private caches by hand,
  which only works on JAX releases that still expose these internals.
  """
  public_clear = getattr(jax, "clear_caches", None)
  if callable(public_clear):
    public_clear()
    return

  # Legacy fallback: reach into private JAX internals.
  # main jit/pmap lu wrapped function caches - have to grab from closures
  jax.xla._xla_callable.__closure__[1].cell_contents.clear()
  jax.pxla.parallel_callable.__closure__[1].cell_contents.clear()
  # primitive callable caches
  jax.xla.xla_primitive_callable.cache_clear()
  jax.xla.primitive_computation.cache_clear()
  # jaxpr caches for control flow and reductions
  jax.lax.lax_control_flow._initial_style_jaxpr.cache_clear()
  jax.lax.lax_control_flow._fori_body_fun.cache_clear()
  jax.lax.lax._reduction_jaxpr.cache_clear()
  # these are trivial and only included for completeness sake
  jax.lax.lax.broadcast_shapes.cache_clear()
  jax.xla.xb.get_backend.cache_clear()
  jax.xla.xb.dtype_to_etype.cache_clear()
  jax.xla.xb.supported_numpy_dtypes.cache_clear()
    
def reset_device_memory(delete_objs=True):
    """Free the device buffers of all live JAX arrays.

    Uses the public ``jax.live_arrays()`` when the installed JAX provides it.
    Otherwise falls back to scanning the garbage collector for
    ``jax.xla.DeviceArray`` instances, as older JAX releases require.

    Args:
      delete_objs: bool: when true, the loop's local reference to each array
        is dropped after its buffer has been handled. Other references held
        elsewhere are not affected.

    Returns:
      number of arrays whose device buffers were manually freed.
    """
    n_deleted = 0
    live_arrays = getattr(jax, "live_arrays", None)
    if callable(live_arrays):
        for arr in live_arrays():
            try:
                if not arr.is_deleted():
                    arr.delete()
                    n_deleted += 1
            except Exception:
                # Best effort: an array that cannot be freed is skipped, but
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                pass
            if delete_objs:
                del arr
    else:
        dvals = (x for x in gc.get_objects() if isinstance(x, jax.xla.DeviceArray))
        for dv in dvals:
            if not isinstance(dv, jax.xla.DeviceConstant):
                try:
                    # Raises for already-deleted arrays, which are then skipped.
                    dv._check_if_deleted()  # pylint: disable=protected-access
                    dv.device_buffer.delete()
                    n_deleted += 1
                except Exception:
                    pass
            if delete_objs:
                del dv
        del dvals
    gc.collect()
    return n_deleted

In [4]:
# Location of the input table (earlier alternative: "AMPS_NonAMPs.ready.csv").
db_path = "/mnt/vdb/thesis/AMP_NonAMPs.recal.ready.csv"

In [5]:
# Read the fully-quoted, comma-separated table; the first row is the header.
# Class 0 = AMPs, 1 = NonAMPs.
df = pd.read_csv(
    db_path,
    header=0,
    sep=',',
    quoting=csv.QUOTE_ALL,
)
df

Unnamed: 0,ID,Sequence,length,class
0,EN92515250|C|B3FJD7|phage,MAKKSVPLRKPAGSDGQGNIKVPGGPVVLDLGDFDDIFGPMESESP...,2337,0
1,EN54061055|C|F8SJ56|phage,MASKKTTLPKPKGINPQGSIVQLDLDDFDDLFDEDFGLPKKNSPYT...,2319,0
2,EN4815120|C|Q8SCY1|phage,MAKKVTLPKGQTGATGTTLGQAGNILDLSDVDDIFGDTPKAKKGSP...,2237,0
3,EN3016141|D|D2J8A7|bacteriocin,MAETIKGLRIDLSLKDMGVGRSITELKRSFRTLNSDLKVSSKNFEY...,1619,0
4,EN9175723|BD|Q93IM3|bacteriocin,MAKKKNTYKVPSIIALTLAGTALTTHHAQAADKTQDQSTNKNILND...,1564,0
...,...,...,...,...
242048,UniRef50_U4KZ92,MHLEFNRARAEYQSTPCATKPFLWL,25,1
242049,UniRef50_F6KAZ0,MQFFIFVSILCLFLENVGAFYMYSS,25,1
242050,UniRef50_A0A2S4WHR4,MGNMFLIRPMEASHSSASGVPSLCG,25,1
242051,UniRef50_A0A0G4NLI8,HHALAPLPRNRHHVARRPRRRRPPR,25,1


In [6]:
# Number of rows written to each chunk file.
chunk_size = 500
# Positional chunk id per row (integer division). Kept outside `df` so the
# shared frame is not modified, and independent of the index labels.
chunk_ids = np.arange(len(df)) // chunk_size
# Guard against a leftover 'chunk' helper column from an earlier run.
cols = [col for col in df.columns if col != 'chunk']
# Export each group of rows to its own CSV; groups are visited in ascending chunk id.
for i, (_, chunk) in enumerate(df.groupby(chunk_ids)):
    chunk[cols].to_csv(f'/mnt/vdb/thesis/jax/chunk{i}.csv',sep=",", quotechar='"',index=False, quoting=csv.QUOTE_ALL) # <<-- change this line
print("complete")

complete


In [7]:
def createREPs(df, filename):
    """Compute representations for ``df['Sequence']`` and pickle the result.

    The frame is modified in place: the 'Sequence' column is dropped and a
    'reps' column holding the first array returned by ``get_reps`` (one list
    per row) is added. The modified frame is then written to ``filename``.
    """
    sequences = df['Sequence'].to_list()
    # get_reps returns three values; only the first one is stored.
    h_avg, _h_final, _c_final = get_reps(sequences)
    df.drop(columns=['Sequence'], inplace=True)
    # Assigned as a plain list, i.e. by position; revisit here if rows ever
    # look misaligned (reindexing may be needed).
    df['reps'] = h_avg.tolist()
    # Remember to adjust the output file name per run.
    df.to_pickle(filename)

In [8]:
import glob
# Paths of the representation pickles that already exist on disk.
appended_reps = glob.glob("/mnt/vdb/thesis/jax/*.pkl") # <<-- change this line
print(len(appended_reps))

0


In [9]:
# Compute and pickle representations for every chunk CSV not already listed
# in `appended_reps`.
for infile in glob.glob("/mnt/vdb/thesis/jax/*.csv"): # <<-- change this line
    # Swap only the extension; str.replace("csv", "pkl") would also rewrite
    # any "csv" that occurs elsewhere in the file name.
    stem, _ext = os.path.splitext(os.path.basename(infile))
    result = "/mnt/vdb/thesis/jax/" + stem + ".pkl" # <<-- change this line
    if result in appended_reps:
        continue
    # Loop-local name so the full table held in `df` is not overwritten.
    chunk_df = pd.read_csv(infile)
    createREPs(chunk_df, result)
    print("Save:"+result)
    # Free device memory and JAX caches between chunks.
    reset_device_memory()
    clear_jax_caches()
print("complete")

Save:/mnt/vdb/thesis/jax/chunk470.pkl
Save:/mnt/vdb/thesis/jax/chunk124.pkl
Save:/mnt/vdb/thesis/jax/chunk50.pkl
Save:/mnt/vdb/thesis/jax/chunk155.pkl
Save:/mnt/vdb/thesis/jax/chunk214.pkl
Save:/mnt/vdb/thesis/jax/chunk10.pkl
Save:/mnt/vdb/thesis/jax/chunk396.pkl
Save:/mnt/vdb/thesis/jax/chunk251.pkl
Save:/mnt/vdb/thesis/jax/chunk352.pkl
Save:/mnt/vdb/thesis/jax/chunk171.pkl
Save:/mnt/vdb/thesis/jax/chunk415.pkl
Save:/mnt/vdb/thesis/jax/chunk479.pkl
Save:/mnt/vdb/thesis/jax/chunk349.pkl
Save:/mnt/vdb/thesis/jax/chunk395.pkl
Save:/mnt/vdb/thesis/jax/chunk130.pkl
Save:/mnt/vdb/thesis/jax/chunk300.pkl
Save:/mnt/vdb/thesis/jax/chunk484.pkl
Save:/mnt/vdb/thesis/jax/chunk147.pkl
Save:/mnt/vdb/thesis/jax/chunk135.pkl
Save:/mnt/vdb/thesis/jax/chunk245.pkl
Save:/mnt/vdb/thesis/jax/chunk73.pkl
Save:/mnt/vdb/thesis/jax/chunk123.pkl
Save:/mnt/vdb/thesis/jax/chunk133.pkl
Save:/mnt/vdb/thesis/jax/chunk49.pkl
Save:/mnt/vdb/thesis/jax/chunk421.pkl
Save:/mnt/vdb/thesis/jax/chunk168.pkl
Save:/mnt/vdb/th

In [10]:
# Free device array memory and clear JAX function caches, using the helpers
# defined earlier in this notebook.
reset_device_memory()
clear_jax_caches()

In [11]:
def mergeDF(path,to_dir,file_name):
    """Merge every pickled DataFrame matching ``path`` into one sorted frame.

    Args:
        path: glob pattern selecting the pickle files to merge.
        to_dir: directory in which the merged pickle is written.
        file_name: file name of the merged pickle inside ``to_dir``.

    Returns:
        The concatenated DataFrame sorted by its 'ID' column. Row labels of
        the individual frames are kept as they are (the index is not reset).

    Raises:
        ValueError: if no file matches ``path``.
    """
    # Sorted so the concatenation order does not depend on directory listing order.
    infiles = sorted(glob.glob(path))
    if not infiles:
        raise ValueError(f"No pickle files match pattern: {path}")
    # NOTE: unpickling can execute arbitrary code; only use on trusted files.
    frames = [pd.read_pickle(infile) for infile in infiles]
    result_path = os.path.join(to_dir, file_name)
    print("Save:",result_path)
    # Stable sort keeps rows with equal IDs in file order, so output is reproducible.
    merged = pd.concat(frames).sort_values(by=['ID'], kind='mergesort')
    merged.to_pickle(result_path)
    return merged

In [12]:
# Merge all per-chunk pickles, write the combined file and keep the frame.
final_df = mergeDF(
    path="/mnt/vdb/thesis/jax/*.pkl",
    to_dir="/mnt/vdb/thesis/jax",
    file_name="AMPNonAMP.final.reps.recal",
) # <<-- change this line
final_df

Save: /mnt/vdb/thesis/jax/AMPNonAMP.final.reps.recal


Unnamed: 0,ID,length,class,reps
492,0_antitbpred|antitbpred,33,0,"[0.021885788068175316, 0.06677422672510147, 0...."
457,0_peptideDB.anti|peptideDB.anti,148,0,"[0.006680202670395374, -0.09558603912591934, 0..."
194,1000_pos_train_ds3|pos_train_ds3,86,0,"[0.012824634090065956, 0.0021224257070571184, ..."
367,10023_dbaasp|dbaasp_peptides,36,0,"[0.0037219703663140535, -0.07121428847312927, ..."
469,"1003,1011,1019,1027,1035|CancerPPD_l_natural",20,0,"[0.02989775314927101, -0.004465686623007059, -..."
...,...,...,...,...
185,tagenome__1003787_1003787.scaffolds.fasta_scaf...,60,0,"[0.01130302157253027, -0.07055055350065231, 0...."
108,tagenome__1003787_1003787.scaffolds.fasta_scaf...,58,0,"[0.01459740474820137, -0.13994133472442627, 0...."
362,tagenome__1003787_1003787.scaffolds.fasta_scaf...,56,0,"[0.007430919446051121, -0.06292885541915894, 0..."
107,tagenome__1003787_1003787.scaffolds.fasta_scaf...,90,0,"[0.009985578246414661, -0.025958647951483727, ..."


## Remove duplicates

## Split train and test set

In [None]:
# Expand each 'reps' vector into one column per dimension, then append the
# label and metadata columns.
# 'reps' is read without pop(), so final_df keeps the column that the later
# cell needs when it builds X.
reps_wide = pd.DataFrame(final_df['reps'].tolist(), index=final_df.index)
df = pd.concat([reps_wide, final_df['class'], final_df['ID'], final_df['length']], axis=1)
df

In [12]:
# Feature matrix from the 'reps' vectors and label vector from 'class'.
X, y = (np.array(final_df[column].to_list()) for column in ('reps', 'class'))

In [13]:
# Inspect the dimensions of the feature matrix X.
X.shape

(254036, 1900)

In [15]:
 # Hold out 20% of the rows as the test set, with a fixed seed.
 X_train, X_test, y_train, y_test = train_test_split(
     X,
     y,
     test_size=0.2,
     random_state=1,
 )

# X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=1) # 0.25 x 0.8 = 0.2

# Implement Toy Model (RF)

In [16]:
from sklearn.ensemble import RandomForestClassifier 

# Baseline random forest with a fixed seed. n_jobs=-1 builds the trees on all
# available cores; the single-core fit on this data was interrupted before finishing.
classifier = RandomForestClassifier(random_state=0, n_jobs=-1)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)

KeyboardInterrupt: 

In [None]:
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
# Print the per-class report followed by the overall accuracy on the test set.
for metric in (classification_report, accuracy_score):
    print(metric(y_test, y_pred))

In [None]:
# plot_confusion_matrix is not available in newer scikit-learn releases, so
# prefer ConfusionMatrixDisplay.from_estimator and fall back only when the
# installed version lacks it.
from sklearn.metrics import ConfusionMatrixDisplay

# Display labels in class order: 0 = AMPs, 1 = NonAMPs.
class_names = ['AMPs', 'NonAMPs']

plot_kwargs = dict(display_labels = class_names,
                   cmap=plt.cm.Blues, xticks_rotation='vertical')
if hasattr(ConfusionMatrixDisplay, "from_estimator"):
    disp = ConfusionMatrixDisplay.from_estimator(classifier, X_test, y_test, **plot_kwargs)
else:
    from sklearn.metrics import plot_confusion_matrix
    disp = plot_confusion_matrix(classifier, X_test, y_test, **plot_kwargs)

disp.ax_.set_title(" Confusion Matrix")

print(disp.confusion_matrix)
plt.grid(False)
plt.show()