<a href="https://colab.research.google.com/github/pikanaeri/plm-model-comparison/blob/main/efam-performance/EFAM_Classifier_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras import backend as K
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import random
import os
from tqdm import tqdm
import tensorflow as tf
from tensorflow import keras
phrog_data_dir = 'final_embeddings'  # ensure that this aligns with the model that you use
phrog_metadata = pd.read_csv('PHROG_index.csv')

### To train the Esm2 model, some of the empty embeddings had to be removed.
# The embedding directory is scanned in place rather than os.chdir()-ing into
# it: changing the working directory made every later
# '{phrog_data_dir}/{phrog}.pkl' path resolve to
# 'final_embeddings/final_embeddings/...', so no embedding could be loaded
# (and the 'models' output directory ended up inside the data directory).
empty_phrogs = []
for fname in os.listdir(phrog_data_dir):
    if not fname.endswith(".pkl"):
        continue
    # `with` closes the handle even if unpickling raises.
    with open(os.path.join(phrog_data_dir, fname), "rb") as fh:
        embeddings = pickle.load(fh)
    if len(embeddings) == 0:
        # Strip only the trailing suffix to recover the family id.
        empty_phrogs.append(fname[:-len(".pkl")])

print(len(empty_phrogs), " empty")
# Drop every metadata row whose family has an empty embedding file, in one pass.
phrog_metadata = phrog_metadata[~phrog_metadata["#phrog"].isin(empty_phrogs)]
###

# Upper bound on embeddings sampled per family by subset_training_data
# (it takes min(this, family size)); presumably larger than any family, i.e.
# "use all" — confirm against the data.
sequence_number_per_family = 1000000
phrog_metadata['Category'].value_counts()  # notebook display only
phrog_known = phrog_metadata[~phrog_metadata['Category'].isna()]
phrog_known = phrog_known[~phrog_known['Category'].isin(['unknown function'])]
len(phrog_known)  # notebook display only
# Sorted so category (and therefore family) iteration order is reproducible;
# the iteration order of a set of strings changes between interpreter runs.
cs = sorted(set(phrog_known['Category']))
## dict for family:label -> {fl}
## dict for family:vectors -> {fv}
## dict for label:families -> {lf}
fl = {}
fv = {}
lf = {}

for c in cs:
    ps = phrog_known[phrog_known['Category'] == c]['#phrog']
    for p in ps:
        fl[p] = c
        try:
            # `with` guarantees the pickle file is closed after loading.
            with open('{0}/{1}.pkl'.format(phrog_data_dir, p), 'rb') as fh:
                fv[p] = pickle.load(fh)
        except FileNotFoundError:
            # Only a missing file is expected here; other errors (e.g. a
            # corrupt pickle) now surface instead of being mislabelled.
            print('{0} embeddings not found'.format(p))
    # Families of this category that actually have embeddings loaded.
    lf[c] = list(set(ps).intersection(fv.keys()))

from typing import List, Dict
def subset_training_data(
    vectors: Dict,
    labels: Dict,
    tr_families: List,
    num_train_seq: int):
    """Randomly draw up to ``num_train_seq`` embeddings per family and label them.

    Args:
        vectors: family id -> sized iterable of embedding rows for that family.
        labels: family id -> class label.
        tr_families: families to sample from, processed in this order.
        num_train_seq: maximum number of rows taken from each family; a family
            with fewer rows contributes all of them, in shuffled order.

    Returns:
        A tuple ``(stacked_rows, flat_labels)``: the sampled rows combined with
        ``np.vstack`` and a flat list holding one label per sampled row.
    """
    sampled_blocks = []
    flat_labels = []
    for family in tr_families:
        rows = list(vectors[family])
        take = min(num_train_seq, len(vectors[family]))
        sampled_blocks.append(random.sample(rows, take))
        # Repeat the family's label once for every row drawn from it.
        flat_labels.extend([labels[family]] * take)
    return np.vstack(sampled_blocks), flat_labels

# Sorted so the concatenation order of the training set is reproducible.
train_families = sorted(fv)
train_x, train_y = subset_training_data(
    vectors=fv,
    labels=fl,
    tr_families=train_families,
    num_train_seq=sequence_number_per_family)
np.unique(np.array(train_y), return_counts=True)  # notebook display only
# Label-binarize: convert the string category labels into one-hot vectors.
lb = LabelBinarizer()
trainY = lb.fit_transform(train_y)
trainX = train_x
# Model architecture. The input width comes from the embeddings and the output
# width from the binarized labels. Both used to be hard-coded (1024 inputs,
# 9 classes) and had to be edited by hand for each embedding model.
# NOTE: LabelBinarizer emits a single column when there are exactly two
# classes, which a softmax head cannot use meaningfully.
model = tf.keras.Sequential([keras.Input(shape=(trainX.shape[1],)),
                             keras.layers.Dense(512, activation="relu"),
                             keras.layers.Dropout(0.2),
                             keras.layers.Dense(256, activation="relu"),
                             keras.layers.Dropout(0.2),
                             keras.layers.Dense(128, activation="relu"),
                             keras.layers.Dense(trainY.shape[1], activation="softmax")])

n_epoch = 5
opt = Adam(0.0001)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
H = model.fit(trainX, trainY, epochs=n_epoch, batch_size=60)
# Plot the per-epoch training loss and accuracy of the known-category model.
epochs_axis = np.arange(0, n_epoch)
plt.style.use("ggplot")
plt.figure()
for metric, legend_name in (("loss", "train_loss"), ("accuracy", "train_acc")):
    plt.plot(epochs_axis, H.history[metric], label=legend_name)
plt.title("Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()

# Families curated as 'unknown function'. NaN categories never compare equal to
# a string, so this single comparison matches the former isna()+isin() filter.
phrog_unknown = phrog_metadata[phrog_metadata['Category'] == 'unknown function']
ufv = {}
for p in phrog_unknown['#phrog']:
    try:
        # `with` guarantees the pickle file is closed after loading.
        with open('{0}/{1}.pkl'.format(phrog_data_dir, p), 'rb') as fh:
            ufv[p] = pickle.load(fh)
    except FileNotFoundError:
        # Only a missing file is expected; other failures now surface.
        print('{0} embeddings not found'.format(p))

# A family counts as "confident" when some class's mean predicted probability
# over all of its embeddings exceeds this threshold.
confidence = 0.8
confident_unknown = []
unconfident_unknown = []
for f in tqdm(ufv.keys()):
    family_vectors = ufv[f]
    if len(family_vectors) == 0:
        # An empty batch cannot be scored. Skip it explicitly rather than via
        # a bare except, which also hid real errors (e.g. a width mismatch).
        print('empty embedding ', f)
        continue
    pred_f = np.mean(model.predict(family_vectors, verbose=0), axis=0)
    if np.any(pred_f > confidence):
        confident_unknown.append(f)
    else:
        unconfident_unknown.append(f)

len(unconfident_unknown)  # notebook display only
# Draw (up to) sequence_number_per_family embeddings from every low-confidence
# 'unknown function' family; they form the extra 'unknown' training class.
ufv_vectors = np.vstack([
    random.sample(list(ufv[fam]), min(sequence_number_per_family, len(ufv[fam])))
    for fam in unconfident_unknown
])
ufv_label = ['unknown'] * len(ufv_vectors)
len(ufv_vectors)  # notebook display only

# Append the 'unknown' rows after the known-category training rows.
vectors = np.concatenate((train_x, ufv_vectors))
label = np.concatenate((train_y, ufv_label))
np.unique(np.array(label), return_counts=True)  # notebook display only
trainX, trainY = vectors, label
# Label-binarize: convert the string labels (known categories plus 'unknown')
# into one-hot vectors.
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
# Model architecture. The input width comes from the embeddings and the output
# width from the binarized labels (known categories + 'unknown'). Both used to
# be hard-coded (1024 inputs, 10 classes) and had to be edited by hand for
# each embedding model or category set.
model2 = tf.keras.Sequential([keras.Input(shape=(trainX.shape[1],)),
                              keras.layers.Dense(512, activation="relu"),
                              keras.layers.Dropout(0.2),
                              keras.layers.Dense(256, activation="relu"),
                              keras.layers.Dropout(0.2),
                              keras.layers.Dense(128, activation="relu"),
                              keras.layers.Dense(trainY.shape[1], activation="softmax")])
n_epoch = 5
opt = Adam(0.0001)
model2.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
H2 = model2.fit(trainX, trainY, epochs=n_epoch, batch_size=60)
# Plot the per-epoch training loss and accuracy of the model with the
# 'unknown' class, on an 8x8-inch figure.
plt.rcParams["figure.figsize"] = (8, 8)
epochs_axis = np.arange(0, n_epoch)
plt.style.use("ggplot")
plt.figure()
for metric, legend_name in (("loss", "train_loss"), ("accuracy", "train_acc")):
    plt.plot(epochs_axis, H2.history[metric], label=legend_name)
plt.title("Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()

# Persist the classifier and its fitted LabelBinarizer (needed to map softmax
# outputs back to category names). exist_ok lets the notebook be re-run
# without a FileExistsError.
os.makedirs('models', exist_ok=True)
model2.save('models/model_unknown.keras')
# The context manager guarantees the pickle is flushed and the file closed.
with open('models/model_unknown_lb.pkl', 'wb') as fh:
    pickle.dump(lb, fh)

