Import packages:

In [None]:
# Python peripherals
import os
import random

# Scipy
import scipy.io
import scipy.stats as ss

# Numpy
import numpy

# Matplotlib
import matplotlib.pyplot as plt
import matplotlib.collections as mcoll

# PyTorch
import torch
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader

# IPython
from IPython.display import display, HTML

# Deep signature
from deep_signature.data_generation import CurveDatasetGenerator
from deep_signature.data_generation import CurveDataGenerator
from deep_signature.training import DeepSignatureDataset
from deep_signature.training import DeepSignatureNet
from deep_signature.training import ContrastiveLoss
from deep_signature.training import ModelTrainer

Helper functions:

In [None]:
def chunker(seq, size):
    """Lazily split `seq` into consecutive slices of at most `size` items.

    The last slice is shorter when len(seq) is not a multiple of `size`.
    Returns a generator; nothing is sliced until it is iterated.
    """
    chunk_starts = range(0, len(seq), size)
    return (seq[start:start + size] for start in chunk_starts)

# https://stackoverflow.com/questions/36074455/python-matplotlib-with-a-line-color-gradient-and-colorbar
def colorline(ax, x, y, z=None, cmap='copper', norm=plt.Normalize(0.0, 1.0), linewidth=3, alpha=1.0):
    """
    Draw the polyline through the points (x, y) on `ax` as a multicolored line.

    Adapted from:
    http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
    http://matplotlib.org/examples/pylab_examples/multicolored_line.html

    z         -- optional color values; defaults to len(x) values evenly
                 spaced on [0, 1]. A single number is also accepted.
    cmap      -- colormap handed to the LineCollection.
    norm      -- normalization handed to the LineCollection. The default
                 instance is created once, when the function is defined.
    linewidth -- line width handed to the LineCollection.
    alpha     -- opacity handed to the LineCollection.

    Returns the LineCollection that was added to `ax`.
    """
    if z is None:
        # No colors supplied: one evenly spaced value per point.
        color_values = numpy.linspace(0.0, 1.0, len(x))
    elif hasattr(z, "__iter__"):
        color_values = z
    else:
        # Hack: anything without __iter__ is treated as a single number and
        # wrapped into a one-element array.
        color_values = numpy.array([z])

    color_values = numpy.asarray(color_values)

    line_collection = mcoll.LineCollection(
        make_segments(x, y),
        array=color_values,
        cmap=cmap,
        norm=norm,
        linewidth=linewidth,
        alpha=alpha)

    # The collection is added to the axes passed in, not to plt.gca().
    ax.add_collection(line_collection)
    return line_collection

def make_segments(x, y):
    """
    Pair up consecutive (x, y) points into line segments for LineCollection.

    The result is an array of shape (number of segments) x 2 (end points of a
    segment) x 2 (x and y); segment i runs from point i to point i + 1.
    """
    # One row per point, with a singleton middle axis so that start and end
    # points can be joined along it.
    vertices = numpy.array([x, y]).T.reshape(-1, 1, 2)
    segment_starts = vertices[:-1]
    segment_ends = vertices[1:]
    return numpy.concatenate((segment_starts, segment_ends), axis=1)

def plot_dist(ax, dist):
    """Plot `dist` against the index of each value on `ax` as an 'hsv'-colored line.

    The axis limits are set explicitly to the index range and to the value
    range of `dist` before the line is added.
    """
    positions = numpy.arange(dist.shape[0])
    ax.set_xlim(positions.min(), positions.max())
    ax.set_ylim(dist.min(), dist.max())
    colorline(ax=ax, x=positions, y=dist, cmap='hsv')

def plot_curve_sample(ax, curve, curve_sample, indices, point_size=10, alpha=1, cmap='hsv'):
    """Scatter the points of `curve_sample` on `ax`, colored by position along `curve`.

    Each point of the full curve gets a color value evenly spaced on [0, 1];
    `indices` selects the values belonging to the sampled points, so a point
    keeps the same color no matter which sample it appears in.
    """
    # One color value per point of the full curve.
    position_colors = numpy.linspace(0.0, 1.0, curve.shape[0])

    ax.scatter(
        x=curve_sample[:, 0],
        y=curve_sample[:, 1],
        c=position_colors[indices],
        s=point_size,
        cmap=cmap,
        alpha=alpha,
        norm=plt.Normalize(0.0, 1.0))

def plot_curve_section_center_point(ax, center_point_index, radius=1):
    """Draw a white circle of the given radius, centered at (0, 0), on `ax`.

    NOTE(review): `center_point_index` is accepted but never read; the circle
    is always placed at the origin. Presumably the plotted curves are
    translated so that the center point sits there -- confirm.
    """
    # zorder=10 is passed so the marker is drawn above the default-zorder artists.
    ax.add_artist(plt.Circle((0, 0), radius=radius, color='w', zorder=10))

def plot_curve(ax, curve):
    """Draw `curve` on `ax` as a line: column 0 is used as x, column 1 as y."""
    ax.plot(curve[:, 0], curve[:, 1])

def plot_curvature(ax, curvature):
    """Plot the values of `curvature` on `ax` against their index."""
    sample_indices = range(curvature.shape[0])
    ax.plot(sample_indices, curvature)

def plot_sample(ax, sample, color, point_size=10, alpha=1):
    """Scatter the points of `sample` on `ax` in a single color.

    Column 0 of `sample` is used as x and column 1 as y.
    """
    scatter_options = dict(
        x=sample[:, 0],
        y=sample[:, 1],
        s=point_size,
        color=color,
        alpha=alpha)
    ax.scatter(**scatter_options)

Global settings:

In [None]:
# --- Locations on disk -------------------------------------------------------
images_dir_path = 'C:/deep-signature-data/images'
curves_dir_path = 'C:/deep-signature-data/curves'              # curves.npy is loaded from here
dataset_dir_path = 'C:/deep-signature-data/datasets/dataset2'  # passed to DeepSignatureDataset.load_dataset
raw_data_dir_path = 'C:/raw-data'
results_base_dir_path = 'C:/deep-signature-data/results'       # passed to ModelTrainer.fit

# --- Training hyperparameters ------------------------------------------------
epochs = 100          # passed to ModelTrainer.fit
batch_size = 32       # passed to ModelTrainer.fit
validation_split = 0.1  # presumably the fraction held out for validation; not read in the visible cells
learning_rate = 1e-4  # learning rate of the Adam optimizer
mu = 1e-6             # the single argument given to ContrastiveLoss

# --- Plot appearance ---------------------------------------------------------
plt.style.use('dark_background')

Generate curves:

In [None]:
# curve_dataset_generator = CurveDatasetGenerator()
# generated_curves = curve_dataset_generator.generate_curves(dir_path='C:/deep-signature-data/images', plot_curves=False)
# curve_dataset_generator.save_curves(dir_path='C:/deep-signature-data/curves')

First, we test the curve data generation logic:

In [None]:
# Settings forwarded to CurveDataGenerator in the cells below.
rotation_factor = 1
sampling_factor = 1
multimodality_factor = 15
supporting_points_count = 6
sampling_points_count = None
sampling_points_ratio = 0.15
sectioning_points_count = None
sectioning_points_ratio = 0.1

# supporting_points_count points on each side plus one point in the middle;
# this is the value later given to DeepSignatureNet as sample_points.
sample_points = 2 * supporting_points_count + 1

# Load the previously saved curves from <curves_dir_path>/curves.npy.
curve_dataset_generator = CurveDatasetGenerator()
curves = curve_dataset_generator.load_curves(
    file_path=os.path.join(curves_dir_path, 'curves.npy'))

Now, let's plot a few positive and negative sampling pairs:

In [None]:
# For each inspected curve (only the first one here), generate negative and
# positive pairs and draw the first few of each kind.
for curve_index, curve in enumerate(curves[:1]):
    display(HTML(f'<h1>Curve #{curve_index}</h1>'))

    curve_data_generator = CurveDataGenerator(
        curve=curve,
        rotation_factor=rotation_factor,
        sampling_factor=sampling_factor,
        multimodality_factor=multimodality_factor,
        supporting_points_count=supporting_points_count,
        sampling_points_count=sampling_points_count,
        sampling_points_ratio=sampling_points_ratio,
        sectioning_points_count=sectioning_points_count,
        sectioning_points_ratio=sectioning_points_ratio)

    negative_pairs = curve_data_generator.generate_negative_pairs()
    positive_pairs = curve_data_generator.generate_positive_pairs()

    # Pairs are shown in generation order; uncomment to inspect a random subset.
    # random.shuffle(negative_pairs)
    # random.shuffle(positive_pairs)

    # NEGATIVE PAIRS
    # One figure per pair: the pair's 'dist' array on the left, the curve with
    # its sampled points on the right.
    display(HTML('<h2>Negative Pairs</h2>'))
    for negative_pair_index, negative_pair in enumerate(negative_pairs[:8]):
        fig, (dist_ax, curve_ax) = plt.subplots(1, 2, figsize=(80, 40))

        pair_curve = negative_pair['curve']
        evolved_curve = negative_pair['evolved_curve']
        indices_pool = negative_pair['indices_pool']
        supporting_points_indices = negative_pair['supporting_points_indices']
        center_point_index = negative_pair['center_point_index']

        display(HTML(f'<h3>Negative Pair #{negative_pair_index}</h3>'))

        for panel in (dist_ax, curve_ax):
            for tick_label in panel.get_xticklabels() + panel.get_yticklabels():
                tick_label.set_fontsize(30)

        plot_dist(ax=dist_ax, dist=negative_pair['dist'])

        curve_ax.axis('equal')
        plot_curve(ax=curve_ax, curve=pair_curve)

        # All points of the sampling pool, drawn faintly.
        plot_curve_sample(
            ax=curve_ax,
            curve=pair_curve,
            curve_sample=pair_curve[indices_pool],
            indices=indices_pool,
            point_size=120,
            alpha=0.3)

        # The supporting points on the original curve.
        plot_curve_sample(
            ax=curve_ax,
            curve=pair_curve,
            curve_sample=pair_curve[supporting_points_indices],
            indices=supporting_points_indices,
            point_size=200)

        # The same supporting points on the evolved curve, in another colormap.
        plot_curve_sample(
            ax=curve_ax,
            curve=evolved_curve,
            curve_sample=evolved_curve[supporting_points_indices],
            indices=supporting_points_indices,
            point_size=100,
            cmap='twilight')

        plot_curve_section_center_point(ax=curve_ax, center_point_index=center_point_index, radius=3)

        plt.show()

    # POSITIVE PAIRS
    # One figure per pair with one row per sampling of the pair: the row's
    # dist array on the left, the curve with that row's points on the right.
    display(HTML('<h2>Positive Pairs</h2>'))
    for positive_pair_index, positive_pair in enumerate(positive_pairs[:5]):
        fig, ax = plt.subplots(2, 2, figsize=(80, 40))

        pair_curve = positive_pair['curve']
        center_point_index = positive_pair['center_point_index']

        # (dist, sampling pool indices, supporting point indices) per row.
        row_data = [
            (positive_pair['dist1'],
             positive_pair['indices_pool1'],
             positive_pair['supporting_points_indices1']),
            (positive_pair['dist2'],
             positive_pair['indices_pool2'],
             positive_pair['supporting_points_indices2']),
        ]

        display(HTML(f'<h3>Positive Pair #{positive_pair_index}</h3>'))

        for panel in ax.flat:
            for tick_label in panel.get_xticklabels() + panel.get_yticklabels():
                tick_label.set_fontsize(30)

        for (dist_ax, curve_ax), (dist, indices_pool, supporting_points_indices) in zip(ax, row_data):
            plot_dist(ax=dist_ax, dist=dist)

            curve_ax.axis('equal')
            plot_curve(ax=curve_ax, curve=pair_curve)

            # All points of the sampling pool, drawn faintly.
            plot_curve_sample(
                ax=curve_ax,
                curve=pair_curve,
                curve_sample=pair_curve[indices_pool],
                indices=indices_pool,
                point_size=120,
                alpha=0.3)

            # The supporting points of this row's sample.
            plot_curve_sample(
                ax=curve_ax,
                curve=pair_curve,
                curve_sample=pair_curve[supporting_points_indices],
                indices=supporting_points_indices,
                point_size=200)

            plot_curve_section_center_point(ax=curve_ax, center_point_index=center_point_index, radius=3)

        plt.show()

Now that we know that the dataset generation logic works correctly, let's generate a large dataset that will be used for training:

In [None]:
# dataset_generator = DatasetGenerator()
# dataset_generator.load_raw_curves(dir_path=raw_data_dir_path)
# dataset_generator.save(
#     dir_path=dataset_dir_path,
#     pairs_per_curve=25,
#     rotation_factor=12,
#     sampling_factor=15,
#     sample_points=600,
#     metadata_only=True)

Before we start training, we first have to do a sanity check by plotting a few positive and negative examples of pairs:

In [None]:
random_seed = 42

# Load the dataset from disk.
dataset = DeepSignatureDataset()
dataset.load_dataset(dir_path=dataset_dir_path)
dataset_size = len(dataset)

# Shuffle all dataset indices with a fixed NumPy seed.
# NOTE(review): SubsetRandomSampler draws from these indices in random order
# itself, presumably using torch's RNG, so the NumPy seed alone may not make
# the displayed samples reproducible -- confirm.
numpy.random.seed(random_seed)
indices = list(range(dataset_size))
numpy.random.shuffle(indices)

# One pair per batch, so every iteration of the loader yields a single example.
sampler = SubsetRandomSampler(indices)
data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

In [None]:
display(HTML('<h1>Random samples of positive and negative examples:</h1>'))
for pair_index, data in enumerate(data_loader):
    # Only the first 10 pairs drawn from the loader are shown.
    if pair_index == 10:
        break

    # Drop all singleton dimensions (the batch size is 1).
    curve1 = torch.squeeze(data['curves_channel1'])
    curve2 = torch.squeeze(data['curves_channel2'])
    label = int(torch.squeeze(data['labels']))

    # A label of 1 marks a positive pair; anything else is shown as negative.
    pair_type = 'Positive' if label == 1 else 'Negative'

    display(HTML(f'<h2>{pair_type} sample #{pair_index}:</h2>'))

    curve1 = curve1.cpu().numpy()
    curve2 = curve2.cpu().numpy()

    fig, ax = plt.subplots(1, 1, figsize=(80, 80))
    ax.axis('equal')

    # Both samples of the pair share one axes, told apart by color.
    for sample_points_xy, sample_color in ((curve1, 'lightcoral'), (curve2, 'skyblue')):
        plot_sample(
            ax=ax,
            sample=sample_points_xy,
            point_size=120,
            color=sample_color)

    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontsize(30)

    plt.show()

Train model:

In [None]:
# Use double precision for tensors and parameters created from here on.
torch.set_default_dtype(torch.float64)

# Training data.
dataset = DeepSignatureDataset()
dataset.load_dataset(dir_path=dataset_dir_path)

# Network (moved to the GPU), optimizer and loss.
model = DeepSignatureNet(sample_points=sample_points).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = ContrastiveLoss(mu)

model_trainer = ModelTrainer(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer)

# Show the network architecture.
print(model)

# curves_to_plot = 2
# configurations_per_curve = 1
# samples_per_configuration = 2
# plots_per_samples = 4
# rows = curves_to_plot *  plots_per_samples
# cols = configurations_per_curve * samples_per_configuration
# data_loader = DataLoader(dataset, batch_size=1)
# dataset_generator = DatasetGenerator()
# dataset_generator.load_raw_curves(dir_path=raw_data_dir_path, shuffle_curves=False)
# test_curves = dataset_generator.generate_curves(rotation_factor=rotation_factor, sampling_factor=sampling_factor, sample_points=sample_points, limit=10)

def epoch_handler(epoch_index):
    """Per-epoch callback given to ModelTrainer.fit; intentionally does nothing.

    Args:
        epoch_index: Value the trainer passes to the callback (presumably the
            index of the epoch); it is not used.

    Returns:
        None.
    """
    # Explicit no-op instead of a dummy assignment.
    return None

# def epoch_handler(epoch_index):
#     fig, ax = plt.subplots(rows, cols, figsize=(80,80))
#     for curve_index, curve in enumerate(test_curves[:curves_to_plot]):
#         for configuration_index, curve_configuration in enumerate(curve.curve_configurations[:configurations_per_curve]):
#             for sample_index, curve_sample in enumerate(curve_configuration.curve_samples[:samples_per_configuration]):

#                 base_row = plots_per_samples * curve_index
#                 c = numpy.linspace(0.0, 1.0, curve_configuration.curve.shape[0])

#                 # Plot curve configuration
#                 ax[base_row, sample_index].plot(curve_configuration.curve[:,0], curve_configuration.curve[:,1])

#                 # Plot curve configuration sample
#                 x = curve_sample.sampled_curve[:,0]
#                 y = curve_sample.sampled_curve[:,1]
#                 ax[base_row + 1, sample_index].scatter(
#                     x=x,
#                     y=y,
#                     c=c[curve_sample.sorted_indices],
#                     s=10,
#                     cmap='hsv')

#                 # Plot curve's true curvature
#                 x = numpy.array(range(curve_configuration.curve.shape[0]))
#                 y = curve.curvature
#                 lc = colorline(ax[base_row + 2, sample_index], x, y, cmap='hsv')
#                 ax[base_row + 2, sample_index].set_xlim(x.min(), x.max())
#                 ax[base_row + 2, sample_index].set_ylim(y.min(), y.max())

#                 # Plut curve's curvature prediction
#                 batch_data = torch.unsqueeze(torch.unsqueeze(torch.tensor(curve_sample.sampled_curve, dtype=torch.float64), dim=0), dim=0).cuda()

#                 x = numpy.array(range(sample_points))
#                 model.eval()
#                 with torch.no_grad():
#                     y = torch.squeeze(model(batch_data), dim=0).cpu().detach().numpy()

#                 x = numpy.array(range(curve_configuration.curve.shape[0]))
#                 ax[base_row + 3, sample_index].set_xlim(x.min(), x.max())

#                 ax[base_row + 3, sample_index].scatter(
#                     x=curve_sample.sorted_indices,
#                     y=y,
#                     c=c[curve_sample.sorted_indices],
#                     s=70,
#                     cmap='hsv')
#                 ax[base_row + 3, sample_index].plot(curve_sample.sorted_indices, y, linewidth=2)

#     # Set the tick labels font
#     for row in range(rows):
#         for col in range(cols):
#             for label in (ax[row, col].get_xticklabels() + ax[row, col].get_yticklabels()):
#                 label.set_fontsize(30)

#     # plt.show()
#     plt.savefig(os.path.join(results_base_dir_path, f'epoch_{epoch_index}.png'))

# Run the training; the returned value is kept in `results`.
results = model_trainer.fit(
    dataset=dataset,
    epochs=epochs,
    batch_size=batch_size,
    results_base_dir_path=results_base_dir_path,
    epoch_handler=epoch_handler)

Load loss stats:

In [None]:
# results_file_path = os.path.normpath(os.path.join(results_base_dir_path, 'results.npy'))
# The results file stores a pickled dict inside a 0-d array, hence
# allow_pickle=True and .item() to get the dict back.
results = numpy.load(
    "C:/deep-signature-data/results/2020-10-27-16-05-50/results.npy",
    allow_pickle=True).item()

# NOTE: this overwrites the notebook-level `epochs` and `batch_size` with the
# values stored alongside the loaded run.
epochs = results['epochs']
batch_size = results['batch_size']
train_loss_array = results['train_loss_array']
validation_loss_array = results['validation_loss_array']
epochs_list = numpy.array(range(epochs))

fig, ax = plt.subplots(1, 1, figsize=(80, 40))

for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
    tick_label.set_fontsize(50)

# Train and validation loss on the same axes, one value per epoch.
loss_series = (
    (train_loss_array, 'Train Loss'),
    (validation_loss_array, 'Validation Loss'),
)
for loss_values, legend_name in loss_series:
    ax.plot(epochs_list, loss_values, label=legend_name, linewidth=7.0)

plt.legend(fontsize=50, title_fontsize=50)

plt.show()

Test model:

In [None]:
# Evaluate the trained model: for every curve, plot the curve, the curvature
# reported by CurveDataGenerator, and the curvature predicted by the network.
torch.set_default_dtype(torch.float64)

# Load the curves and the trained weights once (the original cell repeated
# this whole setup twice).
curve_dataset_generator = CurveDatasetGenerator()
curves = curve_dataset_generator.load_curves(file_path="C:/deep-signature-data/curves/curves.npy")

device = torch.device('cuda')
model = DeepSignatureNet(sample_points=sample_points).cuda()
model.load_state_dict(torch.load(results['model_file_path'], map_location=device))
model.eval()

# Settings forwarded to CurveDataGenerator; restated here so this cell does
# not depend on the data-generation cell having been run.
rotation_factor = 1
sampling_factor = 1
multimodality_factor = 15
supporting_points_count = 6
sampling_points_count = None
sampling_points_ratio = 0.15
sectioning_points_count = None
sectioning_points_ratio = 0.1

# Offsets of the window around a point: supporting_points_count neighbours on
# each side plus the point itself. Loop-invariant, so built once.
window_offsets = numpy.arange(-supporting_points_count, supporting_points_count + 1)

for curve_index, curve in enumerate(curves[:35]):
    curve_data_generator = CurveDataGenerator(
        curve=curve,
        rotation_factor=rotation_factor,
        sampling_factor=sampling_factor,
        multimodality_factor=multimodality_factor,
        supporting_points_count=supporting_points_count,
        sampling_points_count=sampling_points_count,
        sectioning_points_count=sectioning_points_count,
        sampling_points_ratio=sampling_points_ratio,
        sectioning_points_ratio=sectioning_points_ratio)

    fig, ax = plt.subplots(3, 1, figsize=(80, 40))

    display(HTML(f'<h3>Curve #{curve_index}</h3>'))

    for panel in ax:
        for tick_label in panel.get_xticklabels() + panel.get_yticklabels():
            tick_label.set_fontsize(30)

    ax[0].axis('equal')

    plot_curve(ax=ax[0], curve=curve)
    plot_curvature(ax=ax[1], curvature=curve_data_generator.curvature)

    predicted_curvature = numpy.zeros_like(curve_data_generator.curvature)
    curve_points_count = curve.shape[0]

    # Inference only: no gradients are needed for any of the predictions.
    with torch.no_grad():
        # A separate loop variable is used so the curve index is not clobbered.
        for point_index in range(curve_points_count):
            # Window of neighbouring points, wrapping around the curve's ends.
            indices = numpy.mod(point_index + window_offsets, curve_points_count)
            sample = curve[indices]

            # Translate the window so its middle point (position
            # supporting_points_count, previously hard-coded as 6) is the origin.
            sample = sample - sample[supporting_points_count]

            # Add two leading singleton dimensions before handing the sample to the model.
            batch_data = torch.unsqueeze(torch.unsqueeze(torch.from_numpy(sample).double(), dim=0), dim=0).cuda()
            predicted_curvature[point_index] = torch.squeeze(model(batch_data), dim=0).cpu().detach().numpy()

    plot_curvature(ax=ax[2], curvature=predicted_curvature)
    plt.show()