<a href="https://colab.research.google.com/github/xlopez-ml/DL-AMR/blob/master/Examples/DeepAMR_transfer_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0 - Libraries

In [11]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import itertools
import pandas as pd
from sklearn.model_selection import train_test_split
import sklearn
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix,classification_report,ConfusionMatrixDisplay,matthews_corrcoef
from sklearn.preprocessing import Normalizer
from sklearn.compose import ColumnTransformer
import tensorflow as tf
from tensorflow import keras
from keras import regularizers
from keras.backend import expand_dims
from keras.models import load_model
from keras.optimizers import Adam
from keras.callbacks import ReduceLROnPlateau, EarlyStopping
from keras.models import Sequential
from keras.constraints import MaxNorm
from keras.layers import Activation, Dense, Conv1D, Flatten, MaxPooling1D, Dropout, BatchNormalization

# Metrics passed to model.compile(): the four confusion-matrix counts, binary
# accuracy, precision, recall and two AUC variants. Every metric is given an
# explicit short name.
METRICS = [
    metric_class(name=metric_name)
    for metric_class, metric_name in (
        (keras.metrics.TruePositives, 'tp'),
        (keras.metrics.FalsePositives, 'fp'),
        (keras.metrics.TrueNegatives, 'tn'),
        (keras.metrics.FalseNegatives, 'fn'),
        (keras.metrics.BinaryAccuracy, 'accuracy'),
        (keras.metrics.Precision, 'precision'),
        (keras.metrics.Recall, 'recall'),
        (keras.metrics.AUC, 'auc'),  # curve left at the Keras default
    )
]
# The last metric takes an extra argument, so it is appended on its own.
METRICS.append(keras.metrics.AUC(name='prc', curve='PR'))  # precision-recall curve

# 1 - Load data and model

In [3]:
# Colab-only step: attach Google Drive at /content/drive so the files read by
# the following cells (paths under /content/drive/MyDrive) are reachable.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [8]:
# Load the previously trained network stored in s_aureus_oxacillin.h5.
# It is bound to a dedicated name because a later cell rebinds `model` to a
# freshly built Sequential network; without this second reference the loaded
# weights would be unreachable for the transfer-learning sections.
pretrained_model = load_model('/content/drive/MyDrive/Colab Notebooks/DRIAMS/s_aureus_oxacillin.h5')
model = pretrained_model  # original name kept for any cell that still reads `model`
model.summary()

Model: "Modelo_s_aureus_oxacillin"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 Conv_1 (Conv1D)             (None, 5985, 64)          1152      
                                                                 
 batch_normalization_8 (Batc  (None, 5985, 64)         256       
 hNormalization)                                                 
                                                                 
 activation_8 (Activation)   (None, 5985, 64)          0         
                                                                 
 MaxPooling1D (MaxPooling1D)  (None, 2992, 64)         0         
                                                                 
 Conv_2 (Conv1D)             (None, 2984, 128)         73856     
                                                                 
 batch_normalization_9 (Batc  (None, 2984, 128)        512       
 hNormalization)                         

In [9]:
# Read the DRIAMS-B S. aureus table from Drive. Judging by the file name the
# spectra are binned every 3 Da between 2000 and 20000 Da (TODO confirm).
s_aureus_driams_b = pd.read_csv(
    filepath_or_buffer='/content/drive/MyDrive/New driams databae/Datasets Driams con espectro de masa/Driams_b/s_aureus_driams_b_bin3_2000_20000Da.csv'
)
# Show the first rows as a quick sanity check of the columns.
s_aureus_driams_b.head()

Unnamed: 0,2000,2003,2006,2009,2012,2015,2018,2021,2024,2027,...,19991,19994,19997,code,species,Oxacillin,Clindamycin,Ceftriaxone,Ciprofloxacin,Fusidic acid
0,3894.285714,4288.428571,3771.714286,5134.714286,3902.142857,3062.571429,3026.0,3078.857143,3751.875,3582.142857,...,19.666667,18.0,16.2,379e3abe-c5b2-4f92-8f2f-0c9dd0a2c7b0,Staphylococcus aureus,0.0,0.0,,0.0,0.0
1,7327.714286,7367.0,9050.714286,9410.285714,8567.571429,9221.0,7407.857143,7006.857143,6694.142857,6969.714286,...,246.5,226.0,241.820755,eed06320-c82a-43a2-ad35-139e4e082044,Staphylococcus aureus,0.0,0.0,,0.0,0.0
2,5981.142857,6145.0,7768.75,6982.142857,6709.428571,6847.857143,5945.285714,5704.428571,6554.25,6829.0,...,178.0,186.0,189.74359,1b1e94b9-f2cc-42ec-91e1-e5c3bef4adc7,Staphylococcus aureus,0.0,1.0,,0.0,0.0
3,3470.142857,3477.0,2912.714286,3249.714286,2469.142857,2462.714286,2484.714286,2528.0,2918.375,2667.0,...,74.666667,90.5,96.5,e6cf028f-0960-4751-9ca6-d94f90e07ae6,Staphylococcus aureus,0.0,0.0,,0.0,0.0
4,1564.625,1984.857143,1563.0,1842.0,1406.714286,1411.428571,1319.0,1277.857143,1445.571429,1616.0,...,15.5,23.5,21.529412,5ea281ba-f7c8-43a7-a17f-43ac77ed7f68,Staphylococcus aureus,0.0,0.0,,0.0,0.0


# 2 - Training a model with only external data

In [10]:
# Build the oxacillin-only table in one expression: remove the identifier
# columns and the other antibiotic labels, then discard every row that still
# contains a missing value (e.g. samples without an Oxacillin result).
s_aureus_oxacillin_driams_b = (
    s_aureus_driams_b
    .drop(columns=['code', 'species', 'Clindamycin', 'Ceftriaxone', 'Ciprofloxacin', 'Fusidic acid'])
    .dropna(axis=0, how="any")
)

In [12]:
# Guard the positional feature slice below: the label column must not fall
# inside the first 6000 columns, otherwise the target would leak into X.
assert 'Oxacillin' not in s_aureus_oxacillin_driams_b.columns[0:6000], (
    "Oxacillin label found inside the feature slice; check the column layout"
)

# Independent variables: the first 6000 columns (mass-spectrum intensities).
X = s_aureus_oxacillin_driams_b.iloc[:, 0:6000].values
# Dependent variable: oxacillin resistance. Selected by column name rather
# than by position so a change in column layout cannot silently swap the
# label for a spectrum bin.
y = s_aureus_oxacillin_driams_b['Oxacillin'].values

# Keras works in float32; cast both arrays explicitly.
X = np.asarray(X).astype(np.float32)
y = np.asarray(y).astype(np.float32)

In [13]:
# Hold out 20% of the samples for testing. The split is stratified on the
# label and uses a fixed random_state so it is repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=0,
    stratify=y,
)

In [14]:
# scikit-learn's Normalizer rescales each sample (row) independently; with
# norm='max' the divisor is the row's maximum absolute value. Both names are
# bound to the same object, as in the rest of the notebook.
sc_X = scaler = Normalizer(norm='max')

# Fit on the training split, then apply the same transformer to both splits.
X_train = sc_X.fit(X_train).transform(X_train)
X_test = sc_X.transform(X_test)

In [15]:
# Conv1D input is built as (samples, steps, channels), so a trailing axis of
# length 1 is appended to the 2-D feature matrices.
sample_size, time_steps = X_train.shape[:2]  # number of samples, number of features per sample
input_dimension = 1                          # single channel per step

X_train_reshaped = X_train.reshape(sample_size, time_steps, input_dimension)
X_test_reshaped = X_test.reshape(*X_test.shape[:2], 1)

In [16]:
# Training callbacks, both driven by the validation loss:
#  - reduce_lr multiplies the learning rate by 0.1 after 3 epochs without
#    improvement, never going below 1e-6;
#  - early_st stops after 4 epochs without improvement and restores the best
#    weights seen.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=3,
    min_lr=1e-6,
)
early_st = EarlyStopping(
    monitor='val_loss',
    patience=4,
    restore_best_weights=True,
)

# Input geometry for the first Conv1D layer, read from the reshaped training tensor.
n_timesteps, n_features = X_train_reshaped.shape[1], X_train_reshaped.shape[2]

## Create and fit the DeepAMR model

In [None]:
# Build the 1-D CNN trained from scratch on the external (DRIAMS-B) data.
# The network is named after the antibiotic whose labels it is trained on
# (Oxacillin); the previous name referred to ciprofloxacin.
model = Sequential(name="Modelo_s_aureus_oxacillin")
init_mode = 'normal'

# Block 1: Conv -> BatchNorm -> ReLU -> MaxPool. This first convolution uses
# the Keras default initializer and no weight regularisation.
model.add(Conv1D(filters=64, kernel_size=17, input_shape=(n_timesteps, n_features), name='Conv_1'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling1D(pool_size=2, name="MaxPooling1D_1"))

# Blocks 2-4: same layout with more filters, smaller kernels, the `init_mode`
# initializer and L2 weight decay.
model.add(Conv1D(filters=128, kernel_size=9, kernel_initializer=init_mode, kernel_regularizer=regularizers.l2(0.0001), name='Conv_2'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling1D(pool_size=2, name="MaxPooling1D_2"))

model.add(Conv1D(filters=256, kernel_size=5, kernel_initializer=init_mode, kernel_regularizer=regularizers.l2(0.0001), name='Conv_3'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling1D(pool_size=2, name="MaxPooling1D_3"))

model.add(Conv1D(filters=256, kernel_size=5, kernel_initializer=init_mode, kernel_regularizer=regularizers.l2(0.0001), name='Conv_4'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling1D(pool_size=2, name="MaxPooling1D_4"))

# Classifier head: flatten, dropout, three regularised dense layers.
model.add(Flatten())
model.add(Dropout(0.65))
model.add(Dense(256, activation='relu', kernel_initializer=init_mode, kernel_regularizer=regularizers.l2(0.0001), name="fully_connected_0"))
model.add(Dense(64, activation='relu', kernel_initializer=init_mode, kernel_regularizer=regularizers.l2(0.0001), name="fully_connected_1"))
model.add(Dense(64, activation='relu', kernel_initializer=init_mode, kernel_regularizer=regularizers.l2(0.0001), name="fully_connected_2"))
# Binary classification needs exactly one sigmoid unit. The unit count is
# written explicitly instead of reusing n_features (the number of input
# channels), which only matched because the input happens to have one channel.
model.add(Dense(1, activation='sigmoid', name="OUT_Layer"))

model.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=METRICS)
model.summary()

In [18]:
# Train for at most 100 epochs in mini-batches of 10 samples. A tenth of the
# training data is set aside by Keras for validation; that validation loss is
# what the two callbacks monitor.
history = model.fit(
    X_train_reshaped,
    y_train,
    epochs=100,
    batch_size=10,
    verbose=1,
    validation_split=0.1,
    callbacks=[reduce_lr, early_st],
)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

# 3 - Test external data on the DRIAMS-A pretrained model

# 4 - Test external data applying transfer learning, freezing the convolutional layers

# 5 - Test external data applying transfer learning, retraining all layers