# In[ ]:
from tensorflow import keras
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
# Default location of the MNIST archive used by the dataset helpers in this file.
MNIST_PATH = "./mnist.npz"

def load_mnist(path):
    """Load MNIST as ``(x_train, y_train), (x_test, y_test)``.

    If *path* names an existing ``.npz`` archive containing the keys
    ``x_train``, ``y_train``, ``x_test`` and ``y_test``, the arrays are read
    from that file. Otherwise the call falls back to
    ``keras.datasets.mnist.load_data``, passing the same *path* (Keras
    interprets it as the cache location, relative to its dataset cache
    directory).

    Args:
        path: Location of the ``mnist.npz`` archive.

    Returns:
        Two ``(images, labels)`` tuples of NumPy arrays: train split first,
        then test split.
    """
    if os.path.isfile(path):
        # Indexing the NpzFile reads each array into memory, so the archive
        # can be closed as soon as the ``with`` block ends.
        # NOTE(review): allow_pickle=True is not needed for plain numeric
        # arrays; only point this at trusted files.
        with np.load(path, allow_pickle=True) as f:
            x_train, y_train = f['x_train'], f['y_train']
            x_test, y_test = f['x_test'], f['y_test']
        return (x_train, y_train), (x_test, y_test)
    # Forward the caller's path rather than the module default, so the
    # fallback honours the same location that was checked above.
    return keras.datasets.mnist.load_data(path)

def _process_x(x):
    """Rescale images to float32 in [-1, 1] and append a channel axis.

    Assumes raw pixel values in [0, 255]: ``v / 255 * 2 - 1`` maps that range
    onto [-1, 1]. A new axis is inserted at position 3; for a rank-3 batch
    (presumably ``(N, H, W)`` — confirm with callers) this yields
    ``(N, H, W, 1)``.
    """
    scaled = tf.cast(x, tf.float32) / 255. * 2 - 1
    return tf.expand_dims(scaled, axis=3)
def _process_y(y):
    """Convert a label array into an int32 TensorFlow tensor."""
    return tf.convert_to_tensor(y, dtype=tf.int32)

def get_69_ds():
    """Return the preprocessed training images of digits 6 and 9.

    Returns:
        ``(sixes, nines)``, each the output of ``_process_x`` applied to the
        training images with that label.
    """
    (images, labels), _ = load_mnist(MNIST_PATH)
    sixes = _process_x(images[labels == 6])
    nines = _process_x(images[labels == 9])
    return sixes, nines

def get_test_69():
    """Return the preprocessed test images of digits 6 and 9.

    Returns:
        ``(sixes, nines)``, each the output of ``_process_x`` applied to the
        test images with that label.
    """
    _, test_split = load_mnist(MNIST_PATH)
    images, labels = test_split
    return tuple(_process_x(images[labels == digit]) for digit in (6, 9))

# Visualization
def visual_Mnist(x, num_of_sample_to_plot):
    """Display one image from a batch in a new grayscale matplotlib figure.

    Args:
        x: Indexable batch of images, for example the raw arrays returned by
            ``load_mnist`` or the tensors produced by ``_process_x``.
        num_of_sample_to_plot: Index of the sample in *x* to display.
    """
    # Select the sample first, then drop size-1 axes (e.g. a trailing channel
    # axis). Squeezing the whole batch instead would copy every image and,
    # for a single-sample batch, would also remove the batch axis so that
    # indexing picked a pixel row rather than an image.
    image = np.squeeze(np.asarray(x[num_of_sample_to_plot]))
    plt.figure()  # actually create the figure (``plt.figure`` must be called)
    plt.imshow(image, cmap='gray')
    plt.show()

def get_0_ds():
    """Return the preprocessed train/test images and labels of digit 0.

    Returns:
        ``(x_train, y_train, x_test, y_test)``: images passed through
        ``_process_x`` and labels through ``_process_y``.
    """
    train_split, test_split = load_mnist(MNIST_PATH)
    outputs = []
    for images, labels in (train_split, test_split):
        is_zero = labels == 0
        outputs += [_process_x(images[is_zero]), _process_y(labels[is_zero])]
    return tuple(outputs)

def get_index_ds(index):
    """Return the raw train/test images and labels for a single digit.

    Unlike ``get_0_ds``, the arrays are returned exactly as loaded, without
    ``_process_x`` or ``_process_y``.

    Args:
        index: Digit label to keep.

    Returns:
        ``(x_train, y_train, x_test, y_test)`` filtered to ``label == index``.
    """
    (train_x, train_y), (test_x, test_y) = load_mnist(MNIST_PATH)
    train_mask = train_y == index
    test_mask = test_y == index
    return (train_x[train_mask], train_y[train_mask],
            test_x[test_mask], test_y[test_mask])
def get_1to9_ds():
    """Return preprocessed train/test splits containing only digits 1-9.

    Samples are grouped by digit in ascending order (all 1s, then all 2s,
    ...), and keep their original dataset order within each digit.

    Returns:
        ``(x_train, y_train, x_test, y_test)``: images passed through
        ``_process_x`` and labels through ``_process_y``.
    """
    digits = range(1, 10)
    # Read the archive once instead of once per digit, and gather rows with
    # index arrays instead of copying every image into a Python list.
    (x, y), (x_t, y_t) = load_mnist(MNIST_PATH)

    def _select(images, labels):
        # Row indices for digit 1, then digit 2, ...; np.flatnonzero returns
        # each digit's indices in ascending (original) order.
        order = np.concatenate([np.flatnonzero(labels == d) for d in digits])
        return _process_x(images[order]), _process_y(labels[order])

    ds_x, ds_y = _select(x, y)
    ds_x_T, ds_y_T = _select(x_t, y_t)
    return ds_x, ds_y, ds_x_T, ds_y_T
