In [14]:
import os
import numpy as np

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')


def rolling(N, i, loss, err):
    i_ = i[N-1:]
    K = np.full(N, 1./N)
    loss_ = np.convolve(loss, K, 'valid')
    err_ = np.convolve(err, K, 'valid')
    return i_, loss_, err_


expDir = './snapshots'


def create_plots(folder):
    trainP = os.path.join(expDir, folder, 'train.csv')
    trainData = np.loadtxt(trainP, delimiter=',').reshape(-1, 5)[:, :3]
    testP = os.path.join(expDir, folder, 'test.csv')
    testData = np.loadtxt(testP, delimiter=',').reshape(-1, 5)[:, :3]

    N = 392*2 # Rolling loss over the past epoch.

    trainI, trainLoss, trainErr = np.split(trainData, [1, 2], axis=1)
    #trainI, trainLoss, trainErr = [x.ravel() for x in
    #                               (trainI, trainLoss, trainErr)]
    #trainI_, trainLoss_, trainErr_ = rolling(N, trainI, trainLoss, trainErr)

    testI, testLoss, testErr = np.split(testData, [1, 2], axis=1)
    
    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    plt.plot(trainI, trainLoss, label='Train')
    # plt.plot(trainI_, trainLoss_, label='Train')
    plt.plot(testI, testLoss, label='Test')
    plt.xlabel('Epoch')
    plt.ylabel('Cross-Entropy Loss')
    plt.legend()
    ax.set_yscale('log')
    loss_fname = os.path.join(expDir,  folder, 'loss.png')
    plt.savefig(loss_fname)
    print('Created {}'.format(loss_fname))

    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    print(trainI)
    print(trainErr)
    plt.plot(trainI, trainErr, label='Train')
    # plt.plot(trainI_, trainErr_, label='Train')
    plt.plot(testI, testErr, label='Test')
    plt.xlabel('Epoch')
    plt.ylabel('Error')
    ax.set_yscale('log')
    plt.legend()
    err_fname = os.path.join(expDir, folder, 'error.png')
    plt.savefig(err_fname)
    print('Created {}'.format(err_fname))

    loss_err_fname = os.path.join(expDir, folder, 'loss-error.png')
    os.system('convert +append {} {} {}'.format(loss_fname, err_fname, loss_err_fname))
    print('Created {}'.format(loss_err_fname))

create_plots('cnn/2017-03-11_21-01-14')

Created ./snapshots/cnn/2017-03-11_21-01-14/loss.png
[[  100.]
 [  200.]
 [  300.]
 [  400.]
 [  500.]
 [  600.]
 [  700.]
 [  800.]
 [  900.]
 [ 1000.]
 [ 1100.]
 [ 1200.]
 [ 1300.]
 [ 1400.]
 [ 1500.]
 [ 1600.]
 [ 1700.]
 [ 1800.]
 [ 1900.]
 [ 2000.]
 [ 2100.]
 [ 2200.]
 [ 2300.]
 [ 2400.]
 [ 2500.]]
[[ 53.42105263]
 [ 57.89473684]
 [ 66.57894737]
 [ 72.63157895]
 [ 79.47368421]
 [ 79.21052632]
 [ 86.84210526]
 [ 84.47368421]
 [ 91.05263158]
 [ 92.63157895]
 [ 94.21052632]
 [ 93.68421053]
 [ 95.26315789]
 [ 96.05263158]
 [ 96.31578947]
 [ 96.05263158]
 [ 96.84210526]
 [ 97.36842105]
 [ 97.10526316]
 [ 97.63157895]
 [ 98.15789474]
 [ 97.63157895]
 [ 98.94736842]
 [ 97.89473684]
 [ 99.21052632]]


Created ./snapshots/cnn/2017-03-11_21-01-14/error.png
Created ./snapshots/cnn/2017-03-11_21-01-14/loss-error.png
