Import packages:

In [None]:
import os
import scipy.io
import numpy
import scipy.stats as ss
import matplotlib.pyplot as plt
import matplotlib.collections as mcoll
import torch
from torch.utils.data.sampler import SubsetRandomSampler
from deep_signature.training import DeepSignatureDataset
from IPython.display import display, HTML

Build the train/validation data loaders and plot a random batch of positive and negative examples:

In [None]:
# Experiment configuration.
epochs = 10
batch_size = 16
validation_split = .2  # fraction of the dataset held out for validation
shuffle_dataset = True
random_seed = 42  # seeds the NumPy shuffle that decides the train/validation split

dataset = DeepSignatureDataset(dir_path='./dataset2')
dataset_size = len(dataset)
indices = list(range(dataset_size))

# Number of samples assigned to the validation subset (rounded down).
split = int(numpy.floor(validation_split * dataset_size))

# Shuffle the indices before splitting so that the validation subset is not
# simply the first `split` samples of the dataset.
if shuffle_dataset:
    numpy.random.seed(random_seed)
    numpy.random.shuffle(indices)

# The first `split` indices go to validation, the remainder to training.
train_indices, validation_indices = indices[split:], indices[:split]

# NOTE: the samplers draw their per-epoch ordering from torch's RNG, which is
# not seeded here; `random_seed` only makes the split itself reproducible.
train_sampler = SubsetRandomSampler(train_indices)
validation_sampler = SubsetRandomSampler(validation_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=validation_sampler)

# Render the curve pairs of the first training batch(es), one figure per
# sample, with the two curves of each pair shown side by side.
plt.style.use("dark_background")
display(HTML('<h1>Random samples of positive and negative examples:</h1>'))

max_batches = 1  # number of batches to visualise
tick_font_size = 30

# zip() with the range first stops before fetching a batch beyond
# `max_batches`, so no extra batch is loaded just to be discarded.
for batch_index, data in zip(range(max_batches), train_loader):

    display(HTML(f'<h2>Batch #{batch_index}:</h2>'))

    # Use the actual number of samples in this batch: it can be smaller than
    # `batch_size` when the subset size is not a multiple of it.
    samples_in_batch = len(data['labels'])

    for sample_index in range(samples_in_batch):

        sample_label = data['labels'][sample_index]
        sample_type = 'Positive' if sample_label == 1 else 'Negative'

        display(HTML(f'<h3>{sample_type} sample #{sample_index}:</h3>'))

        fig, ax = plt.subplots(1, 2, figsize=(80, 40))

        # data['curves'][k] holds the k-th curve of every pair in the batch;
        # columns 0 and 1 of the selected curve are plotted as x and y.
        for pair_index in range(2):
            curve = data['curves'][pair_index][sample_index][0]
            x = curve[:, 0].cpu().numpy()
            y = curve[:, 1].cpu().numpy()

            axis = ax[pair_index]
            axis.axis('equal')
            axis.scatter(x=x, y=y, s=10)
            axis.tick_params(axis='both', labelsize=tick_font_size)

        plt.show()