## Parameters

In [138]:
# Not referenced by any other cell visible in this notebook section.
vocab_size = 5000
hide_most_frequently = 0

# Training settings forwarded to model.fit().
epochs = 10
batch_size = 512
fit_verbosity = 1

In [112]:
# Root output directory for this run; the checkpoint cell creates a "models"
# sub-directory underneath it.
run_dir = "run/"

In [134]:
import os
os.environ['KERAS_BACKEND'] = "torch"
import keras

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import train_test_split

In [21]:
%load_ext autoreload
%autoreload 2

In [22]:
from modules import encoding

## Retrieve data


In [99]:
# Load the antibody/antigen pairing table (semicolon-separated CSV).
# NOTE(review): the "Unnamed: 0" column shown in the preview looks like a saved
# row index; index_col=0 would presumably drop it - confirm before changing.
df_match = pd.read_csv("../data/SAbDab/data.csv", sep=";")
# Disabled: split the "ab" / "ag" identifiers on '|' into separate id and type columns.
# df_match[["ab_id", "ab_type"]] = df_match["ab"].str.split('|', n=1, expand=True)
# df_match[["ag_id", "ag_type"]] = df_match["ag"].str.split('|', n=1, expand=True)
df_match.head()

Unnamed: 0,ab,ag,interaction
0,5kel|ab,5kel|ag,1
1,5kel|ab,6cwt|ag,0
2,5kel|ab,4fp8|ag,0
3,5kel|ab,4yjz|ag,0
4,5kel|ab,6j15|ag,0


In [98]:
# Load the sequence table (semicolon-separated CSV); later cells read its
# "seq_id" and "sequence" columns.
# NOTE(review): as with the pairing table, the "Unnamed: 0" column looks like a
# saved row index - confirm whether index_col=0 is wanted.
df_seq = pd.read_csv("../data/SAbDab/sequences.csv", sep=";")
df_seq.head()

Unnamed: 0,seq_id,specie,sequence
0,5kel|ag,Zaire ebolavirus (strain Mayinga-76) (128952),IPLGVIHNSTLQVSDVDKLVCRDKLSSTNQLRSVGLNLEGNGVATD...
1,5kel|ag,Zaire ebolavirus (128952),EAIVNAQPKCNPNLHYWTTQDEGAAIGLAWIPYFGPAAEGIYTEGL...
2,5kel|ab,Homo sapiens (9606),EVQLQESGGGLMQPGGSMKLSCVASGFTFSNYWMNWVRQSPEKGLE...
3,5kel|ab,Homo sapiens (9606),DIQMTQSPASLSVSVGETVSITCRASENIYSSLAWYQQKQGKSPQL...
4,5kel|ab,Homo sapiens (9606),DVKLLESGGGLVQPGGSLKLSCAASGFSLSTSGVGVGWFRQPSGKG...


In [83]:
# Column of raw sequence strings, one entry per row of df_seq.
seq = df_seq["sequence"]
# Encoder object built by the project's `encoding` module from its amino-acid alphabet.
# NOTE(review): the concrete type returned by alphabet_one_hot is not visible here.
encoder = encoding.alphabet_one_hot(alphabet=encoding.AMINO_ACID_ALPHABET)

In [84]:
# One-hot encode every sequence with vector_size=1500 (presumably the fixed
# per-sequence length - confirm in modules.encoding).
# NOTE(review): `seq_encode` is not referenced by the later visible cells;
# the features fed to the model come from `seq_encoded`.
seq_encode = encoding.one_hot_encoder(seq, encoder, vector_size=1500)

In [93]:
# Scratch call of the single-sequence encoder on the toy string "MYA" with size 100;
# `encoded` is not used by the later visible cells.
encoded = encoding.one_hot_encode_sequence("MYA", encoder, 100)

In [140]:
# Fixed per-sequence length. The model's Input layer is declared with shape
# (vector_size, ALPHABET_SIZE), so the encoded data must use the same value.
vector_size = 1500
# Pass vector_size explicitly: relying on the helper's default would let the
# encoded length silently diverge from the model's input shape.
seq_encoded = encoding.one_hot_encoder(seq, encoder, vector_size=vector_size)

In [109]:
# Number of symbols in the project's amino-acid alphabet; used as the last
# dimension of the model's Input shape.
ALPHABET_SIZE = len(encoding.AMINO_ACID_ALPHABET)

### Split dataset

In [131]:
ordinal_encoder = OrdinalEncoder()

# seq_id values look like "5kel|ag": split once on '|' into the part before the
# bar (seq_rcpb) and the chain type after it (seq_type).
df_seq[["seq_rcpb", "seq_type"]] = df_seq["seq_id"].str.split(pat='|', n=1, expand=True)

# Numeric labels for the chain type, kept as a single-column 2-D input/output.
# NOTE(review): the code assigned to each category is decided by the encoder -
# inspect ordinal_encoder.categories_ to confirm which type maps to 0 and 1.
enc_seq_type = ordinal_encoder.fit_transform(df_seq.loc[:, ["seq_type"]])

In [135]:
# Features are the one-hot encoded sequences; targets are the encoded chain types.
X, y = seq_encoded, enc_seq_type

In [136]:
# Hold out 20% of the samples; random_state pins the shuffle so the split is repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

## Model

In [150]:
# Seed the Python, NumPy and backend RNGs so weight initialisation (and hence
# the training run) is repeatable; 42 matches the seed used for the data split.
keras.utils.set_random_seed(42)

# Simple fully-connected binary classifier over the flattened one-hot sequence.
model = keras.Sequential(name="abag_classifier")

# Input is one encoded sequence: vector_size positions x ALPHABET_SIZE symbols.
model.add(keras.layers.Input(shape=(vector_size, ALPHABET_SIZE)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(32, activation='relu'))
model.add(keras.layers.Dense(32, activation='relu'))
# Single sigmoid unit, paired with the binary cross-entropy loss below.
model.add(keras.layers.Dense(1, activation='sigmoid'))

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

model.summary()

Model: "abag_classifier"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 33000)             0         
                                                                 
 dense_4 (Dense)             (None, 32)                1056032   
                                                                 
 dense_5 (Dense)             (None, 32)                1056      
                                                                 
 dense_6 (Dense)             (None, 1)                 33        
                                                                 
Total params: 1057121 (4.03 MB)
Trainable params: 1057121 (4.03 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


## Train the model

### Add callback

In [148]:
# Build paths with os.path.join: run_dir may already end with a separator, and
# plain string formatting would then produce a doubled "//" in the path.
models_dir = os.path.join(run_dir, "models")
os.makedirs(models_dir, mode=0o750, exist_ok=True)

# Despite its name, save_dir is the full path of the checkpoint file.
save_dir = os.path.join(models_dir, "best_model.keras")

# Keep only the weights from the epoch with the highest validation accuracy.
savemodel_callback = keras.callbacks.ModelCheckpoint(
    filepath=save_dir,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
)

### Train it

In [151]:
%%time

# Train on the training split and checkpoint through savemodel_callback.
# NOTE(review): the held-out test split is also used as validation data, and
# the checkpoint selects on val_accuracy - consider a separate validation split
# if an unbiased test score is needed.
history = model.fit(
    X_train,
    y_train,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=(X_test, y_test),
    verbose=fit_verbosity,
    callbacks=[savemodel_callback],
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
CPU times: user 10.6 s, sys: 584 ms, total: 11.2 s
Wall time: 3.59 s
