In [2]:
# # https://pypi.org/project/quickdraw
# from quickdraw import QuickDrawData
import numpy as np
import urllib.request
import os
import glob as gb
import pathlib
import random
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import random_split, TensorDataset, DataLoader

In [3]:
# Fix the random seeds
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Args:
        seed: Integer seed passed to torch (CPU and CUDA), NumPy and the
            standard-library ``random`` module. Its string form is also
            written to the ``PYTHONHASHSEED`` environment variable.
    """
    # PYTHONHASHSEED controls hash randomization (a value of 0 disables it).
    # NOTE(review): Python reads this variable at interpreter start-up, so
    # setting it here only influences processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seeding calls as before, applied in the same order.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Save the current figure
def save_figure(figure_name, figure_base_path = '../figure/', figure_extension='.png', resolution=300):
    """Save the current matplotlib figure as an image file.

    Args:
        figure_name: File name without extension.
        figure_base_path: Directory the file is written to; created if it
            does not exist. A trailing slash is optional.
        figure_extension: File extension, with or without the leading dot
            (``'.png'`` and ``'png'`` are equivalent). It also selects the
            output format passed to ``plt.savefig``.
        resolution: Output resolution in DPI.
    """
    # exist_ok replaces the old check-then-create wrapped in a bare ``except``,
    # which silently hid real failures such as permission errors.
    # An empty base path means "current directory", so nothing is created.
    if figure_base_path:
        os.makedirs(figure_base_path, exist_ok=True)

    # Normalise the extension so a missing leading dot cannot corrupt the format.
    extension = figure_extension.lstrip('.')
    figure_path = os.path.join(figure_base_path, figure_name + '.' + extension)
    print('save figure: ', figure_name)

    plt.savefig(figure_path, bbox_inches='tight', format=extension, dpi=resolution)

In [4]:
# Get image labels
def get_labels():
    """Read the class labels listed in ``data/30_labels.txt``.

    The file is resolved relative to the current working directory and is
    expected to hold one label per line.

    Returns:
        list[str]: The labels in file order, with surrounding whitespace
        removed and blank lines skipped.

    Raises:
        OSError: If the label file cannot be opened.
    """
    labels_path = os.path.join(os.getcwd(), 'data/30_labels.txt')
    # The context manager closes the file even if reading fails.
    with open(labels_path, 'r', encoding='utf-8') as label_file:
        stripped_lines = (line.strip() for line in label_file)
        # Blank lines would otherwise become empty labels (and bad download URLs).
        return [label for label in stripped_lines if label]

In [5]:
def quickdraw_npy(overwrite=False):
    """Download the QuickDraw ``numpy_bitmap`` file for every label.

    Files are stored as ``../dataset/<label>.npy``. Labels come from
    ``get_labels()``; underscores in a label stand for spaces in the remote
    file name.

    Args:
        overwrite: When False (default), labels whose ``.npy`` file already
            exists are skipped so that re-running the notebook does not
            download the whole dataset again. Pass True to refresh the files.
    """
    # Imported here because the notebook's import cell only pulls in urllib.request.
    from urllib.parse import quote

    # make directory (exist_ok replaces the old bare ``except`` that hid real errors)
    dataset_path = '../dataset/'
    os.makedirs(dataset_path, exist_ok=True)

    # get data from web
    base_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'
    for label in get_labels():
        target_path = os.path.join(dataset_path, label + '.npy')
        if not overwrite and os.path.exists(target_path):
            print('skip (already downloaded): ', target_path)
            continue

        # quote() percent-encodes the space (and any other unsafe character).
        npy_url = base_url + quote(label.replace('_', ' ')) + '.npy'
        print(npy_url)

        # Download to a temporary name first so an interrupted transfer never
        # leaves a truncated ``<label>.npy`` behind.
        partial_path = target_path + '.part'
        urllib.request.urlretrieve(npy_url, partial_path)
        os.replace(partial_path, target_path)

    print('Done!')

quickdraw_npy()

https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/airplane.npy
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/apple.npy
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/banana.npy
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/baseball.npy
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/bear.npy
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/bicycle.npy
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/bird.npy
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/bus.npy
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/cat.npy
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/cup.npy
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/dog.npy
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/duck.npy
https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/fish.npy
ht

In [None]:
# https://stackoverflow.com/questions/44429199/how-to-load-a-list-of-numpy-arrays-to-pytorch-dataset-loader

def prepare_npy_data(test_ratio=0.2, max_items_per_class=10000, data_dir='../dataset/'):
    """Load the per-class ``.npy`` bitmaps and split them into train/test sets.

    Every ``*.npy`` file in ``data_dir`` is treated as one class; the class
    name is the file name without its extension and the class index is the
    file's position in the alphabetically sorted file list.

    Args:
        test_ratio: Fraction of all samples placed in the test split.
        max_items_per_class: Maximum number of rows taken from each file.
        data_dir: Directory that holds the ``.npy`` files.

    Returns:
        tuple: ``(train_dataset, test_dataset, classes)`` where the datasets
        are random subsets of a ``TensorDataset`` of float32 images shaped
        ``(1, 28, 28)`` with float32 class indices, and ``classes`` is the
        list of class names indexed by class index.
    """
    # Sorted because glob returns files in arbitrary order, which would make
    # the class-index mapping differ between runs and machines.
    npy_files = sorted(gb.glob(os.path.join(data_dir, '*.npy')))

    # Collect per-class chunks and concatenate once at the end instead of
    # growing one big array inside the loop (quadratic copying).
    feature_chunks = []
    target_chunks = []
    classes = []

    # load a subset of the data to memory
    for idx, npy_file in enumerate(npy_files):
        # copy() so that only the kept rows stay alive, not the whole file.
        data = np.load(npy_file)[:max_items_per_class, :].copy()
        feature_chunks.append(data)
        target_chunks.append(np.full(data.shape[0], idx))

        label, _ = os.path.splitext(os.path.basename(npy_file))
        classes.append(label)

    if feature_chunks:
        x = np.concatenate(feature_chunks, axis=0)
        y = np.concatenate(target_chunks)
    else:
        # No files found: keep returning empty datasets, as before.
        x = np.empty([0, 784])  # 28*28 = 784
        y = np.empty([0])

    # transform to torch tensor (float32, matching the previous torch.Tensor output)
    tensor_x = torch.Tensor(x.astype(np.float32))
    tensor_x = tensor_x.reshape(tensor_x.shape[0], 1, 28, 28)
    # NOTE(review): labels are kept as float32 for backward compatibility;
    # losses such as CrossEntropyLoss need them cast to long.
    tensor_y = torch.Tensor(y.astype(np.float32))

    # create dataset
    dataset = TensorDataset(tensor_x, tensor_y)

    # separate into train data and test data
    # The train length is derived from the test length so the two always sum
    # to len(dataset); truncating both independently could drop a sample and
    # make random_split raise.
    test_length = int(len(dataset) * test_ratio)
    train_length = len(dataset) - test_length

    train_dataset, test_dataset = random_split(dataset=dataset, lengths=[train_length, test_length])

    return train_dataset, test_dataset, classes

In [None]:
# Build the train/test datasets and the class-name list with prepare_npy_data's default arguments.
train_dataset, test_dataset, classes = prepare_npy_data()

In [None]:
# Wrap both splits in mini-batch loaders: 32 samples per batch, shuffling enabled.
# NOTE(review): shuffling the test loader is usually unnecessary - confirm it is intended.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=32)

In [None]:
# Take the first training sample and print the shape of its image tensor.
X, y = next(iter(train_dataset))
print(X.shape)

In [None]:
def imshow(img, label):
    """Display one image tensor and save the figure under its label.

    Args:
        img: Image tensor with the channel axis first; the axes are reordered
            so that the channel axis comes last before plotting.
        label: Name passed to ``save_figure`` as the figure's file name.
    """
    # NOTE(review): this rescaling ("unnormalize") presumes inputs normalised
    # to [-1, 1]; no such normalisation is visible in this notebook - confirm.
    rescaled = img / 2 + 0.5
    pixels = rescaled.numpy().transpose(1, 2, 0)

    plt.figure(figsize=(5, 5))
    plt.imshow(pixels)
    plt.axis('off')
    # Save before show(); this call order is kept from the original.
    save_figure(label)
    plt.show()

In [None]:
# Create an iterator over the training loader and pull one mini-batch of images and labels from it.
data_iter = iter(train_loader)
images, labels = next(data_iter)

In [None]:
# Show the batch sample at index 8; its class name becomes the saved figure's file name.
imshow(images[8], classes[int(labels[8])])