## Parameters

In [17]:
# Vocabulary-related settings.
# NOTE(review): neither name is referenced by the cells visible in this part
# of the notebook -- confirm they are still needed.
vocab_size = 5000
hide_most_frequently = 0

# Training settings passed to model.fit() further down.
epochs = 10
batch_size = 512
fit_verbosity = 1

In [18]:
# Root directory for the artifacts written by this run (model checkpoints are
# stored under it). The value keeps its trailing slash.
run_dir = 'run/'

In [19]:
import os
os.environ['KERAS_BACKEND'] = "torch"
import keras

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import train_test_split

from modules.dataset import OneHotProtDataset

In [20]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
from modules import encoding

## Retrieve data


In [22]:
# Antibody/antigen pairing table: one row per (ab, ag) pair with a binary
# `interaction` label. The file is semicolon-separated.
# NOTE(review): the first CSV column is an unnamed row index and is loaded as
# a regular "Unnamed: 0" column; consider index_col=0 if it is not needed.
df_match = pd.read_csv(
    filepath_or_buffer="../data/SAbDab/data.csv",
    sep=";",
)

# Disabled: split the "<id>|<type>" identifiers into separate columns.
# df_match[["ab_id", "ab_type"]] = df_match["ab"].str.split('|', n=1, expand=True)
# df_match[["ag_id", "ag_type"]] = df_match["ag"].str.split('|', n=1, expand=True)

df_match.head()

Unnamed: 0,ab,ag,interaction
0,5kel|ab,5kel|ag,1
1,5kel|ab,6cwt|ag,0
2,5kel|ab,4fp8|ag,0
3,5kel|ab,4yjz|ag,0
4,5kel|ab,6j15|ag,0


In [23]:
# Sequence table: one row per chain, keyed by `seq_id` ("<id>|<type>"), with
# the source species and the amino-acid sequence. Semicolon-separated file.
# NOTE(review): the first CSV column is an unnamed row index and is loaded as
# a regular "Unnamed: 0" column; consider index_col=0 if it is not needed.
df_seq = pd.read_csv(
    filepath_or_buffer="../data/SAbDab/sequences.csv",
    sep=";",
)

df_seq.head()

Unnamed: 0,seq_id,specie,sequence
0,5kel|ag,Zaire ebolavirus (strain Mayinga-76) (128952),IPLGVIHNSTLQVSDVDKLVCRDKLSSTNQLRSVGLNLEGNGVATD...
1,5kel|ag,Zaire ebolavirus (128952),EAIVNAQPKCNPNLHYWTTQDEGAAIGLAWIPYFGPAAEGIYTEGL...
2,5kel|ab,Homo sapiens (9606),EVQLQESGGGLMQPGGSMKLSCVASGFTFSNYWMNWVRQSPEKGLE...
3,5kel|ab,Homo sapiens (9606),DIQMTQSPASLSVSVGETVSITCRASENIYSSLAWYQQKQGKSPQL...
4,5kel|ab,Homo sapiens (9606),DVKLLESGGGLVQPGGSLKLSCAASGFSLSTSGVGVGWFRQPSGKG...


### Split dataset

In [24]:
# Each seq_id has the form "<id>|<type>" (e.g. "5kel|ag"): split once on '|'
# and store both parts as new columns on df_seq.
seq_id_parts = df_seq["seq_id"].str.split('|', n=1, expand=True)
df_seq[["seq_rcpb", "seq_type"]] = seq_id_parts

# Fit an ordinal encoder on the sequence type and keep the encoded column.
# NOTE(review): enc_seq_type is not used by the cells visible below -- confirm
# whether it is still required.
ordinal_encoder = OrdinalEncoder()
enc_seq_type = ordinal_encoder.fit_transform(df_seq[["seq_type"]])

In [25]:
# Load the one-hot encoded dataset from the project helper.
# Judging by the outputs further down, X is (n_samples, vector_size,
# alphabet_size) and y is a column of 0./1. labels -- TODO confirm against
# OneHotProtDataset.get_data().
(
    X,
    y,
    vector_size,
    alphabet_size,
) = OneHotProtDataset.get_data()

In [26]:
# Hold out 20% of the samples as a test set; the fixed random_state makes the
# split repeatable across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

## Model

In [27]:
# Seed Python, NumPy and the Keras backend RNGs before any weights are
# created, so weight initialisation and batch shuffling are reproducible
# across "Restart & Run All" (same seed value as the train/test split).
keras.utils.set_random_seed(42)

# Binary classifier over one-hot encoded inputs:
#   - the (vector_size, alphabet_size) input is flattened into one vector,
#   - two fully connected hidden layers of 32 ReLU units,
#   - a single sigmoid unit giving the predicted probability of the
#     positive class.
model = keras.Sequential(
    [
        keras.layers.Input(shape=(vector_size, alphabet_size)),
        keras.layers.Flatten(),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid'),
    ],
    name="abag_classifier",
)

# Binary cross-entropy matches the single sigmoid output; accuracy is tracked
# so that `val_accuracy` is available to callbacks during training.
model.compile(
    optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

model.summary()

Model: "abag_classifier"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 33000)             0         
                                                                 
 dense (Dense)               (None, 32)                1056032   
                                                                 
 dense_1 (Dense)             (None, 32)                1056      
                                                                 
 dense_2 (Dense)             (None, 1)                 33        
                                                                 
Total params: 1057121 (4.03 MB)
Trainable params: 1057121 (4.03 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


2024-03-23 09:38:16.731802: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


## Train the model

### Add callback

In [28]:
# Build the checkpoint location with os.path.join so the result is a clean
# path whether or not run_dir ends with a slash (plain string concatenation
# yields "run//models" for run_dir = "run/").
models_dir = os.path.join(run_dir, 'models')
os.makedirs(models_dir, mode=0o750, exist_ok=True)

# Despite its name, save_dir is the path of the checkpoint *file*.
save_dir = os.path.join(models_dir, 'best_model.keras')

# Save the model only when the monitored validation accuracy improves, so the
# file always holds the best epoch seen so far.
savemodel_callback = keras.callbacks.ModelCheckpoint(
    filepath=save_dir,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
)

### Train it

In [29]:
%%time

history = model.fit(X_train,
                    y_train,
                    epochs          = epochs,
                    batch_size      = batch_size,
                    validation_data = (X_test, y_test),
                    verbose         = fit_verbosity,
                    callbacks       = [savemodel_callback])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
CPU times: user 11.3 s, sys: 950 ms, total: 12.3 s
Wall time: 4.33 s


In [30]:
# Sanity check: display the shape of the training inputs.
X_train.shape

(2426, 1500, 22)

In [31]:
# Sanity check: display the training labels (NumPy abbreviates the printed array).
y_train

array([[1.],
       [1.],
       [1.],
       ...,
       [0.],
       [1.],
       [1.]])