# ECBM 4040 Fall '21 Project - BIOM Group

In [None]:
import numpy as np
import tensorflow as tf
import pickle
from matplotlib import pyplot as plt

import shutil
import os


from model.models_cstm import get_embedding_model
from model.train_model import train_siamese_model

# Distance metric identifiers, keyed by their upper-case label. The value is
# the lower-case string handed to the training routine as the distance choice.
_METRIC_NAMES = ('euclidean', 'hyperbolic', 'manhattan', 'square', 'cosine')
DISTANCE_METRICS = {name.upper(): name for name in _METRIC_NAMES}


## Get Qiita Data

In [None]:
# Download the pickled Qiita dataset (qiita_numpy.pkl) into the current working directory.
!wget https://www.dropbox.com/s/mv546rx259tgwaz/qiita_numpy.pkl

In [None]:
# Relocate the downloaded pickle from the working directory into data/qiita/.
cwd = os.getcwd()
# Create the destination directory first; shutil.move fails if it is missing.
os.makedirs(f"{cwd}/data/qiita", exist_ok=True)
shutil.move(f"{cwd}/qiita_numpy.pkl", f"{cwd}/data/qiita/qiita_numpy.pkl")

## Load Qiita Dataset

In [None]:
# Load QIITA dataset: the pickle holds two 3-tuples, unpacked as
# (train, test, val) splits for X and for y.
# NOTE(security): unpickling can execute arbitrary code, and this file comes
# from an external download; only load it if the source is trusted.
# The context manager closes the file handle once loading finishes.
with open(f"{cwd}/data/qiita/qiita_numpy.pkl", "rb") as dataset_file:
    ((X_train, X_test, X_val), (y_train, y_test, y_val)) = pickle.load(dataset_file)

## Train Siamese Model

In [None]:
# Train and test the Siamese model on a reduced subset of the data.
# Only the first `subset_size` train/val samples are used; the test split is
# passed through whole. The y arrays are sliced on both axes, which suggests
# pairwise (sample x sample) targets -- TODO confirm against the data loader.
subset_size = 1000
inputs = (X_train[:subset_size], X_test, X_val[:subset_size])
targets = (
    y_train[:subset_size, :subset_size],
    y_test,
    y_val[:subset_size, :subset_size],
)

embedding = get_embedding_model()
data = (inputs, targets)
dist = DISTANCE_METRICS['EUCLIDEAN']

model, score, history = train_siamese_model(
    data,
    embedding,
    dist,
    batch_size=256,
    epochs=5,
)

In [None]:
# Report the score returned by train_siamese_model together with the metric used.
score_summary = f'Score for Siamese Model using {dist} distance: {score}'
print(score_summary)

## Visualize Loss 

In [None]:
# Plot the training and validation loss curves stored in `history.history`
# (assumes a Keras-style History object with one value per epoch -- confirm
# against train_siamese_model).
plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label='val')
# The curves are losses, so the title must not say "accuracy".
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(loc='upper left')
plt.show()