In [5]:
import copy
from keypress_recognition.models import black_key_model as bmodel, white_key_model as wmodel
from keypress_recognition import dataset
import numpy as np
import matplotlib.pyplot as plt

In [4]:
# Positions (0-based) of the black keys among 88 consecutive keys: 36 entries.
# (k + 9) % 12 turns a key position into a pitch class with 0 = C, which treats
# position 0 as an A (NOTE(review): i.e. a standard A0..C8 piano — confirm);
# pitch classes 1, 3, 6, 8, 10 are C#, D#, F#, G#, A#.
black_indices = np.array([k for k in range(88) if (k + 9) % 12 in (1, 3, 6, 8, 10)])
def visualize_keys(note):
    """Turn a key-press flag into a human-readable label.

    note: value tested for truthiness (e.g. a bool); truthy means pressed.
    return: str, "Pressed" for a truthy note, otherwise "Not Pressed".
    """
    if not note:
        return "Not Pressed"
    return "Pressed"
visualize_keys(True)

Pressed


In [None]:
# Print several samples per split: white-key image (left) and black-key image (right),
# each titled with its label via visualize_keys.

SAMPLES_PER_SPLIT = 2  # number of rows (one dataset sample per row) in each figure

for _type in ['train', 'test', 'val']:

    print(f'Samples from X_{_type}: ')

    # squeeze=False keeps axarr 2-D even when SAMPLES_PER_SPLIT == 1, so the
    # axarr[row, col] indexing below works for any row count.
    fig, axarr = plt.subplots(SAMPLES_PER_SPLIT, 2,
                              figsize=[16, 8 * SAMPLES_PER_SPLIT],
                              squeeze=False)

    for i in range(SAMPLES_PER_SPLIT):
        imgw, imgb, notew, noteb = dataset.get_sample(_type, mode=0)
        axarr[i, 0].imshow(imgw)
        axarr[i, 0].set_title(visualize_keys(notew))
        axarr[i, 1].imshow(imgb)
        axarr[i, 1].set_title(visualize_keys(noteb))

    plt.show()
    # Release the figure explicitly; this cell creates one figure per split.
    plt.close(fig)

In [None]:
# Load the saved "best" checkpoints into the black-key and white-key models.
# NOTE(review): this cell runs before the training cells below; if these .tar
# files do not exist yet (fresh checkout), load_model may fail — confirm intended order.
for _key_model, _ckpt_path in (
    (bmodel, 'keypress_recognition/best_black_model.tar'),
    (wmodel, 'keypress_recognition/best_white_model.tar'),
):
    _key_model.load_model(_ckpt_path)

In [None]:
# Load the 'train' and 'val' image splits into memory before the training cells.
# ('spliter' is the keyword name dataset.load_to_memory expects.)

dataset.load_to_memory(
    spliter=['train', 'val'],
)

In [None]:
# Train the black-key classifier.
# best_path / current_path are checkpoint destinations — presumably the best-so-far
# and the most recent model respectively (TODO confirm in black_key_model.train).
# NOTE(review): meaning of dirs=[0] is not visible here — document it.
bmodel.train(
    batch_size=32,
    num_epochs=20,
    learning_rate=1e-3,
    dirs=[0],
    best_path='keypress_recognition/best_black_model.tar',
    current_path='keypress_recognition/current_black.tar',
)

In [None]:
# Train the white-key classifier with the same hyperparameters as the black-key one.
# best_path / current_path are checkpoint destinations — presumably the best-so-far
# and the most recent model respectively (TODO confirm in white_key_model.train).
# NOTE(review): meaning of dirs=[0] is not visible here — document it.
wmodel.train(
    batch_size=32,
    num_epochs=20,
    learning_rate=1e-3,
    dirs=[0],
    best_path='keypress_recognition/best_white_model.tar',
    current_path='keypress_recognition/current_white.tar',
)

In [None]:
# Run both classifiers on one sample from the 'train' split and show each
# prediction next to the ground-truth label returned by get_sample.
Xw, Xb, notew_true, noteb_true = dataset.get_sample('train', mode=0)
fig, axarr = plt.subplots(1, 2, figsize=[16, 16])
axarr[0].imshow(Xw)
axarr[1].imshow(Xb)

# Move the last axis to the front and add a leading batch axis of size 1 before
# calling evaluate. NOTE(review): assumes get_sample returns channels-last images
# and evaluate expects (batch, channels, H, W) — confirm.
Xw = np.transpose(Xw, [2, 0, 1])[None]
Xb = np.transpose(Xb, [2, 0, 1])[None]
yw = wmodel.evaluate(Xw)[0]  # result for the single image in the batch
yb = bmodel.evaluate(Xb)[0]
axarr[0].set_title(f'Predicted: {visualize_keys(yw)} | Actual: {visualize_keys(notew_true)}')
axarr[1].set_title(f'Predicted: {visualize_keys(yb)} | Actual: {visualize_keys(noteb_true)}')

plt.show()