In [1]:
from DataCollector import DataCollectorv2
from Dataset import DatasetHPs
from NNModel import NNModelHPs
from covit import CovitProject

In [2]:
# Instantiate the project's data collector; it is handed to CovitProject below.
# The constructor prints progress while it builds a data frame plus its
# remote and local dicts (see the captured output under this cell).
dc = DataCollectorv2()

Building Data frame
Done building Data frame
Building remote dicts
Done building remote dicts
Building local dicts
Done building local dicts


In [17]:
# Instantiate the CovitProject named "375Lins", passing it the data collector
# `dc` created in an earlier cell.
covit = CovitProject(project_name="375Lins", data_collector=dc)

In [18]:
# --- Run configuration -------------------------------------------------------
# These flags decide how the model named `new_model` is obtained by the
# model-selection cell that follows:
#   load_existing     -> load `new_model` directly; checked first, so it wins
#                        over every other flag.
#   transfer_learning -> when not loading, derive `new_model` from `old_model`
#                        instead of creating it from `nnmodel_hps`.
#   deepen            -> (transfer learning only) call deepenNN `new_layers`
#                        times, then makeTrainable.
#   change_pred_head  -> (transfer learning only, and only when `deepen` is not
#                        set) call changeNumClasses.
load_existing = True
transfer_learning = False
deepen = False
new_layers = 2
change_pred_head = False

# Names under which the source and target models are registered in the project.
old_model = "nn.old"
new_model = "nn1.4"

# Hyper-parameters used only when a brand-new model is created.
# `classes` is the number of lineages reported by the project's dataset.
nnmodel_hps = NNModelHPs(
    encoder_repeats=2,
    classes=len(covit.dataset.getLineages()),
    d_model=256,
    d_val=96,
    d_key=96,
    d_ff=1536,
    heads=18,
    dropout_rate=0.1,
)

In [19]:
# Make `new_model` available in the project in one of three ways:
#   1. load it as-is,
#   2. derive it from `old_model` (transfer learning), or
#   3. create it from `nnmodel_hps`.
if load_existing:
    covit.loadNNModel(new_model)
elif transfer_learning is not False:
    # Transfer learning: load the source model, then register the new model
    # based on it.
    covit.loadNNModel(old_model)
    covit.addNNModel(name=new_model, other=old_model)

    # The two adaptations are mutually exclusive; deepening takes priority.
    if deepen is True:
        for _layer in range(new_layers):
            covit.deepenNN(name=new_model, trainable=True)
        covit.makeTrainable(name=new_model)
    elif change_pred_head is True:
        covit.changeNumClasses(
            name=new_model,
            classes=len(covit.dataset.getLineages()),
        )
else:
    # Fresh model built from the hyper-parameters configured above.
    covit.addNNModel(name=new_model, nnmodel_hps=nnmodel_hps)

In [20]:
# Train `new_model` for one epoch.
# NOTE(review): batch_size=256 with mini_batch_size=64 presumably means each
# batch is processed as 4 mini-batches -- confirm against CovitProject.train.
covit.train(new_model, epochs=1, batch_size=256, mini_batch_size=64)



In [None]:
# Models whose recorded training history is summarised and plotted below.
models = [
    new_model
         ]

# Per-model summary: report the validation top-1 accuracy at the epoch where
# that model's validation loss was lowest.
for model in models:
    covit.loadNNModel(model)
    print("Model is {}".format(model))
    model_perf = covit.getResults(name=model).getPerf()
    val_min_loss = min(model_perf["val_loss"])
    val_min_idx = model_perf["val_loss"].index(val_min_loss)
    # Despite the name, this is the accuracy at the min-val-loss epoch, not
    # necessarily the maximum accuracy.
    val_top1_max_acc = model_perf["val_top1_accuracy"][val_min_idx]
    print("===> val top1 accuracy = {}".format(val_top1_max_acc))

# Concatenate every metric's history across all models, in `models` order, so
# the plotting cells can draw one continuous curve per metric.
perf = {}
for model in models:
    # NOTE(review): each model was already loaded by the loop above; this
    # second load is kept in case loadNNModel has side effects -- confirm
    # whether it is redundant.
    covit.loadNNModel(model)
    model_perf = covit.getResults(name=model).getPerf()
    for metric in model_perf:
        # Always extend a list owned by `perf`. Storing model_perf[metric]
        # itself (as the previous code did for the first model) aliases the
        # results object's list, so extending it for the next model would
        # silently modify the first model's recorded history.
        perf.setdefault(metric, []).extend(model_perf[metric])

In [None]:
import matplotlib.pyplot as plt

# Accuracy curves (validation and training, top-1/2/5), saved to acc.png.
# An explicit figure/axes pair is used instead of pyplot's implicit "current
# figure", so the saved file always contains exactly these curves no matter
# what other figures are open.
fig, ax = plt.subplots()

# Same plotting order as before (valid then train, for k = 1, 2, 5) so the
# legend entries and colour cycle are unchanged.
for k in (1, 2, 5):
    ax.plot(perf["val_top{}_accuracy".format(k)],
            label="valid top{} accuracy".format(k))
    ax.plot(perf["top{}_accuracy".format(k)],
            label="train top{} accuracy".format(k))

ax.legend()

ax.set_xlabel("Epochs")
ax.set_ylabel("Accuracy")
ax.set_title("375Lins Accuracy vs. Epochs graph")

fig.savefig("acc.png")

# Report validation accuracies at the epoch with the lowest validation loss.
# The "*_max_acc" names are historical: these are the accuracies at that
# epoch, not necessarily the maxima.
val_min_loss = min(perf["val_loss"])
val_min_idx = perf["val_loss"].index(val_min_loss)
val_top1_max_acc = perf["val_top1_accuracy"][val_min_idx]
val_top2_max_acc = perf["val_top2_accuracy"][val_min_idx]
val_top5_max_acc = perf["val_top5_accuracy"][val_min_idx]
# Number of recorded validation points across all models.
print(len(perf["val_top1_accuracy"]))
print("valid top1 accuracy {}".format(val_top1_max_acc))
print("valid top2 accuracy {}".format(val_top2_max_acc))
print("valid top5 accuracy {}".format(val_top5_max_acc))



In [None]:
# Loss curves (validation and training), saved to loss.png.
# A dedicated figure is created explicitly so the curves can never be drawn
# on top of a figure left open by an earlier cell.
fig, ax = plt.subplots()

ax.plot(perf["val_loss"], label="valid loss")
# The first training-loss entry is skipped.
# NOTE(review): because of the slice, the train curve's x=0 is the second
# recorded entry while the valid curve's x=0 is the first, so the two curves
# are offset by one step -- confirm this is intended.
ax.plot(perf["loss"][1:], label="train loss")

ax.legend()

ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")

# Title now matches the project ("375Lins") and the accuracy plot; the old
# "189Lins" text was a leftover from a different run.
ax.set_title("375Lins Loss vs. Epochs graph")

fig.savefig("loss.png")

In [None]:
# Per-model timing summary: total epochs trained, the batch size of the first
# recorded run, and the mean wall-clock time per epoch.
for model in models:
    print("Model is {}".format(model))
    model_times = covit.getResults(name=model).getTimes()

    # Each entry corresponds to one recorded training run.
    num_epochs = sum(model_times["epochs"])
    total_time = sum(model_times["time"])

    # A model with no recorded training would otherwise fail with
    # ZeroDivisionError (average) or IndexError (batch size); report it and
    # move on to the next model instead.
    if num_epochs == 0 or len(model_times["batch_size"]) == 0:
        print("===> no training runs recorded")
        continue

    # Only the first run's batch size is reported, even if later runs differ.
    batch_size = model_times["batch_size"][0]
    avg_time = total_time / num_epochs
    print("===> batch size = {}".format(batch_size))
    print("===> number of epochs = {}".format(num_epochs))
    # NOTE(review): dividing by 60 assumes the recorded times are in seconds
    # -- confirm against the results object.
    print("===> average train time per epoch = {:.2f}[min]".format(avg_time / 60))