In [2]:
import matplotlib.pyplot as plt
import json
from dbCon import PGCON
import numpy as np
from datetime import datetime
import time

## Create plots for all entries in the database and save them

In [66]:
# Turn off inline printing of the graphs
%matplotlib agg
%matplotlib agg

def fetchData(sequenceNbr, number_of_records_returned=16):
    """Fetch one page of rows from ``californium.traininglogs``.

    Only rows whose ``name``, ``data`` and ``hyperparameters`` are all
    non-null are selected, ordered by ``id`` ascending. Pages are 1-based:
    page ``sequenceNbr`` starts at offset
    ``number_of_records_returned * (sequenceNbr - 1)``.

    Args:
        sequenceNbr: 1-based page number.
        number_of_records_returned: page size (LIMIT). Defaults to 16.

    Returns:
        The rows returned by ``PGCON().fetchAll`` for the query. The columns
        are selected in the order ``id, name, data, comments, created,
        hyperparameters``. Returns ``False`` when the page is empty
        (or ``fetchAll`` returned nothing).

    Raises:
        ValueError: if ``sequenceNbr`` or ``number_of_records_returned`` is
            less than 1. Either would otherwise produce a negative OFFSET or
            a useless LIMIT in the query.
    """
    # Validate before opening a connection. Page 0 would give OFFSET -16,
    # which PostgreSQL rejects.
    if sequenceNbr < 1:
        raise ValueError("sequenceNbr must be >= 1, got %r" % (sequenceNbr,))
    if number_of_records_returned < 1:
        raise ValueError(
            "number_of_records_returned must be >= 1, got %r"
            % (number_of_records_returned,)
        )
    offset = number_of_records_returned * (sequenceNbr - 1)

    # LIMIT/OFFSET are interpolated as text. They are numbers computed above,
    # not user-supplied strings.
    # NOTE(review): switch to a parameterized query if PGCON.fetchAll
    # supports bind parameters.
    sql = (
        'select id, name, data, comments, created, hyperparameters'
        ' from californium.traininglogs'
        ' where name is not null and data is not null and hyperparameters is not null'
        ' order by id asc'
        ' limit ' + str(number_of_records_returned) + ' offset ' + str(offset)
    )

    # Fetch the actual data
    dbread = PGCON()
    rows = dbread.fetchAll(sql)
    # Callers treat False as "no more pages". `not rows` covers both an
    # empty result and a None result.
    if not rows:
        return False
    return rows
                           
def _hyperparameter_text(hyperparameters):
    """Build the multi-line hyperparameter summary shown in a run's third panel.

    ``hyperparameters`` is the ``row[5]`` mapping. The layout depends on
    which keys are present:

    * ``numberOfConvBlocks`` present: conv settings followed by dense settings;
    * otherwise, ``numberOfDenseLayerUnitsPerBlock`` present: dense settings
      including units per dense layer;
    * otherwise: dense settings only.

    ``numberOfDenseLayerBlocks``, ``dropOutInDenseLayer`` and ``learningRate``
    are read in every case, so a missing one raises ``KeyError``.
    """
    hp = hyperparameters
    parts = []
    if 'numberOfConvBlocks' in hp:
        parts.append(
            "Conv: Blocks: " + str(hp['numberOfConvBlocks'])
            + ", \n  Layers:" + str(hp['numberOfConvLayersPerBlock'])
            + ", \n  Filters:" + str(hp['numberOfConvFiltersInFirstLayer'])
        )
    parts.append("\nDense: Blocks:" + str(hp['numberOfDenseLayerBlocks']))
    # Units per dense layer are only shown for models without conv blocks.
    if 'numberOfConvBlocks' not in hp and 'numberOfDenseLayerUnitsPerBlock' in hp:
        parts.append(
            ", \n Units/Dense layer:" + str(hp['numberOfDenseLayerUnitsPerBlock'])
        )
    parts.append(", \n  Dropout: " + str(hp['dropOutInDenseLayer']))
    parts.append("\nLearning rate: " + str(hp['learningRate']))
    return "".join(parts)


def createPlot(rows, filename=None):
    """Render training histories for ``rows`` into one figure and save it.

    Each plotted row gets three consecutive panels in a 6-column grid, so
    two rows share one grid row:

    1. training vs. validation accuracy per epoch, titled with ``row[1]``;
    2. training vs. validation loss per epoch;
    3. a text summary of the hyperparameters in ``row[5]``.

    A row whose history mapping ``row[2]`` contains the key ``'Exception'``
    is skipped. It leaves no gap in the grid.

    Args:
        rows: sequence of positionally indexed rows. ``row[1]`` is the run
            name. ``row[2]`` is a mapping with ``loss``, ``val_loss``,
            ``accuracy`` and ``val_accuracy`` sequences of equal length.
            ``row[5]`` is the hyperparameter mapping.
        filename: path passed to ``savefig``. Defaults to
            ``"Plot_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S")``.

    Returns:
        The filename the figure was saved to, or ``None`` when ``rows`` is
        empty (nothing is drawn or saved in that case).
    """
    nbrCols = 6
    panelsPerRow = 3
    x_ticks_step = 10

    if not rows:
        # The figure height scales with len(rows). With no rows it would be
        # zero-height, so there is nothing meaningful to save.
        return None

    plots = len(rows) * panelsPerRow
    # Ceiling division: exactly as many grid rows as the panels need, with no
    # trailing blank row when plots is a multiple of nbrCols.
    nbrRows = (plots + nbrCols - 1) // nbrCols

    fig = plt.figure(figsize=(16, (1.875 * len(rows))))
    try:
        firstPanel = 1  # 1-based subplot index of the next row's first panel
        for row in rows:
            historyDict = row[2]
            if 'Exception' in historyDict:
                continue
            loss = historyDict['loss']
            valLoss = historyDict['val_loss']
            acc = historyDict['accuracy']
            valAcc = historyDict['val_accuracy']
            epochs = range(1, len(loss) + 1)
            xTicks = np.arange(0, len(epochs) + 1, step=x_ticks_step)

            # Panel 1: accuracy
            ax = fig.add_subplot(nbrRows, nbrCols, firstPanel)
            ax.set_title(row[1])
            ax.plot(epochs, acc, '--r', label='Training accuracy')
            ax.plot(epochs, valAcc, 'b', label='Validation accuracy')
            ax.set_xticks(xTicks)
            ax.set_yticks(np.arange(0.5, 0.91, step=0.1))
            ax.set_ylim(0.6, 0.9)
            ax.legend(fontsize=10, loc='upper center')

            # Panel 2: loss
            ax = fig.add_subplot(nbrRows, nbrCols, firstPanel + 1)
            ax.plot(epochs, loss, '--r', label='Training loss')
            ax.plot(epochs, valLoss, 'b', label='Validation loss')
            ax.set_xticks(xTicks)
            ax.set_yticks(np.arange(0.0, 2.0, step=0.5))
            ax.set_ylim(0.0, 2.0)
            ax.legend(fontsize=10, loc='upper center')

            # Panel 3: hyperparameter information
            ax = fig.add_subplot(nbrRows, nbrCols, firstPanel + 2)
            ax.text(0.1, 0.5, _hyperparameter_text(row[5]))
            ax.set_xticks([])
            ax.set_yticks([])

            firstPanel += panelsPerRow

        if filename is None:
            filename = "Plot_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # each page plotted in the paging loop would accumulate in memory.
        plt.close(fig)
    return filename

# Page through the training logs (1-based pages) and save one plot image
# per page, stopping at the first page fetchData reports as empty.
sequence = 1
while True:
    rows = fetchData(sequence)
    # fetchData signals "no more rows" with a falsy value (False). `not rows`
    # replaces the `== False` comparison and also stops on an empty list or
    # None instead of plotting it or looping forever.
    if not rows:
        break

    createPlot(rows)

    sequence += 1
    # NOTE(review): presumably here so consecutive default plot filenames,
    # which use one-second timestamp resolution, never collide. Confirm.
    time.sleep(1)