## Import Requirements

In [4]:
import numpy as np
import ndjson
import os
import matplotlib.pyplot as plt
import itertools
import math

## Import Data using ndjson

In [5]:
# Directory holding the data files, given relative to the current working directory.
# (os.path.join with a single argument returns that argument unchanged.)
data_dir = os.path.join("../data/")

In [38]:
def load_data(data_dir):
    """Load every NDJSON file found directly inside ``data_dir``.

    Files are read in sorted filename order. The position of a file's records
    in the returned list is used downstream as the class index, and
    ``os.listdir`` on its own returns entries in arbitrary order, so sorting
    keeps the class indices identical on every run and every machine.
    Entries that are not regular files (sub-directories) are skipped.

    Args:
        data_dir: Path of the directory that holds the data files.

    Returns:
        A list with one element per file, in sorted filename order. Each
        element is whatever ``ndjson.load`` returns for that file.
    """
    loaded = []
    for fname in sorted(os.listdir(data_dir)):
        full_path = os.path.join(data_dir, fname)

        # Skip anything that is not a regular file (e.g. checkpoint folders),
        # which open() could not read anyway.
        if not os.path.isfile(full_path):
            continue

        # The context manager closes the file; no explicit close() is needed.
        with open(full_path, "rb") as f:
            loaded.append(ndjson.load(f))
    return loaded

## Visualize Strokes
We'll use matplotlib to visualize a series of strokes.

In [39]:
def display_strokes(drawing):
    """Plot the cumulative build-up of a drawing, one subplot per stroke.

    Subplot ``j`` (filled row by row in a square grid) shows the points of
    strokes ``0..j`` concatenated into a single line, so the last used subplot
    contains the complete drawing.

    Args:
        drawing: Sequence of strokes; ``stroke[0]`` holds the x coordinates
            and ``stroke[1]`` the y coordinates of one stroke.

    Returns:
        None. The figure is shown with ``plt.show()``. An empty drawing
        produces no figure.
    """
    # Nothing to draw; a 0x0 subplot grid is not a valid request.
    if len(drawing) == 0:
        return

    sqrt_N = math.ceil(np.sqrt(len(drawing)))

    # squeeze=False keeps `axs` two-dimensional even for a 1x1 grid, so the
    # [row, col] indexing below also works for single-stroke drawings.
    fig, axs = plt.subplots(sqrt_N, sqrt_N, sharex=True, sharey=True, squeeze=False)

    # Running lists of every point seen so far.
    xs = []
    ys = []
    for j, stroke in enumerate(drawing):
        xs.extend(stroke[0])
        ys.extend(stroke[1])

        # Row-major position of the j-th subplot in the grid.
        row, col = divmod(j, sqrt_N)
        axs[row, col].plot(xs, ys)
        axs[row, col].set_title("Drawing {}".format(j + 1))

    plt.subplots_adjust(top=1.5)
    plt.show()


## Get Drawings and Labels
This is where we'll start to build our supervised learning dataset.

In [40]:
def get_drawings_and_labels(data):
    """Flatten per-class records into parallel drawing and label lists.

    Args:
        data: Sequence with one entry per class. Each entry is a non-empty
            sequence of records (dicts) carrying a ``'drawing'`` and a
            ``'word'`` key.

    Returns:
        A tuple ``(drawings, classes, class2index)``:
        ``drawings[i]`` is the ``'drawing'`` value of the i-th record,
        ``classes[i]`` is the integer index of the class it came from, and
        ``class2index`` maps each class index to the ``'word'`` of that
        class's first record (despite its name, the mapping goes from index
        to word).
    """
    drawings = []
    classes = []
    class2index = {}

    # Iterate over the `data` argument itself rather than a notebook global.
    for class_index, records in enumerate(data):
        class2index[class_index] = records[0]['word']
        for record in records:
            drawings.append(record['drawing'])
            classes.append(class_index)
    return drawings, classes, class2index

In [41]:
# Load one record list per file in data_dir, then flatten them into parallel
# drawing / label lists plus an index -> word lookup.
data = load_data(data_dir)
drawings, classes, class2index = get_drawings_and_labels(data)
# Show which integer label was assigned to which word.
print("Class indices to classes mapping: {}".format(class2index))


Class indices to classes mapping: {0: 'cat', 1: 'table'}


## Create a PyTorch DataLoader Using Drawings
Let's transform our dataset into a PyTorch DataLoader so we can begin training!

In [42]:
# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Now let's create a Dataset class
class QuickDrawDataset(Dataset):
    """Map-style dataset that pairs each drawing with its class label.

    Args:
        data: Indexable collection of drawings.
        labels: Indexable collection of labels, aligned with ``data`` so that
            ``labels[i]`` belongs to ``data[i]``.

    Raises:
        ValueError: If ``data`` and ``labels`` differ in length.
    """

    def __init__(self, data, labels):
        # Misaligned inputs would silently pair drawings with wrong labels.
        if len(data) != len(labels):
            raise ValueError(
                "data and labels must have the same length, got {} and {}".format(
                    len(data), len(labels)
                )
            )

        self.data = data
        self.labels = labels

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(drawing, label)`` pair stored at ``index``."""
        return self.data[index], self.labels[index]

## Prepare Data for Training
### Training Hyperparameters

In [43]:
# Fraction of the shuffled samples assigned to the training set.
TRAIN_TEST_SPLIT = 0.8
# Number of samples per batch passed to each DataLoader.
BATCH_SIZE = 12

### Split Data into Training, Testing, Validation

In [53]:
# Fixed seed so the train/validation/test partition is reproducible across
# kernel restarts.
SPLIT_SEED = 42

# Draw a random permutation of the sample indices
N = len(drawings)
indices = np.random.default_rng(SPLIT_SEED).permutation(N)

# Store each drawing as one element of a 1-D object array. Filling the array
# element by element stops NumPy from trying to build a rectangular array out
# of the nested stroke lists, which raises for ragged input on NumPy >= 1.24.
drawing_array = np.empty(N, dtype=object)
for i, drawing in enumerate(drawings):
    drawing_array[i] = drawing

# Randomly shuffle drawings and labels with the same permutation.
# NOTE: this rebinds `drawings` / `classes`, so re-running this cell permutes
# the already-shuffled data again; restart and run all for a clean split.
drawings = drawing_array[indices]
classes = np.asarray(classes)[indices]

# Training gets TRAIN_TEST_SPLIT of the samples; whatever remains is divided
# evenly between validation and test (10% each for a 0.8 split).
split = int(N * TRAIN_TEST_SPLIT)
val_end = split + (N - split) // 2

# Now split data
train_data = drawings[:split]
train_labels = classes[:split]

val_data = drawings[split:val_end]
val_labels = classes[split:val_end]

test_data = drawings[val_end:]
test_labels = classes[val_end:]

# The three partitions must cover every sample exactly once.
assert len(train_data) + len(val_data) + len(test_data) == N

# Now create Datasets - TODO: Format input arguments correctly
train_dataset = QuickDrawDataset(train_data, train_labels)
val_dataset = QuickDrawDataset(val_data, val_labels)
test_dataset = QuickDrawDataset(test_data, test_labels)

# Now create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [55]:
# Sanity check: the shuffled drawings form a 1-D array with one entry per sample.
# NOTE(review): the original expression was missing its operand; `drawings`
# reproduces the recorded output shape - confirm this was the intended array.
print(drawings.shape)

(251223,)
