# Convolutional Neural Network (CNN)

In this lab, you will be training a simple [Convolutional Neural Network](https://developers.google.com/machine-learning/glossary/#convolutional_neural_network) (CNN) to classify [CIFAR images](https://www.cs.toronto.edu/~kriz/cifar.html). 

The [Keras Sequential API](https://www.tensorflow.org/guide/keras/overview) will allow you to do this very easily. 

Creating and training a model will take just a few lines of code.

<img src="https://pics.me.me/the-maths-dehhddeepleihing-import-keras-say-hello-to-keras-56785635.png">




### Import Libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from IPython.display import display, Markdown, Latex

from tensorflow.keras import datasets, layers, models

Load the dataset and normalize it

## Dataset

CIFAR-10 consists of 60000 32x32 color images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. Here are the classes in the dataset, as well as 10 random images from each:

<table>
    <tbody><tr>
        <td class="cifar-class-name">airplane</td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/airplane1.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/airplane2.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/airplane3.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/airplane4.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/airplane5.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/airplane6.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/airplane7.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/airplane8.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/airplane9.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/airplane10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">automobile</td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/automobile1.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/automobile2.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/automobile3.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/automobile4.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/automobile5.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/automobile6.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/automobile7.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/automobile8.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/automobile9.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/automobile10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">bird</td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/bird1.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/bird2.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/bird3.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/bird4.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/bird5.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/bird6.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/bird7.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/bird8.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/bird9.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/bird10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">cat</td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/cat1.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/cat2.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/cat3.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/cat4.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/cat5.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/cat6.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/cat7.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/cat8.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/cat9.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/cat10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">deer</td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/deer1.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/deer2.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/deer3.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/deer4.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/deer5.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/deer6.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/deer7.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/deer8.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/deer9.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/deer10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">dog</td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog1.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog2.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog3.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog4.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog5.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog6.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog7.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog8.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog9.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/dog10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">frog</td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/frog1.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/frog2.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/frog3.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/frog4.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/frog5.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/frog6.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/frog7.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/frog8.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/frog9.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/frog10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">horse</td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/horse1.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/horse2.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/horse3.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/horse4.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/horse5.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/horse6.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/horse7.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/horse8.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/horse9.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/horse10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">ship</td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/ship1.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/ship2.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/ship3.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/ship4.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/ship5.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/ship6.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/ship7.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/ship8.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/ship9.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/ship10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">truck</td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck1.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck2.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck3.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck4.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck5.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck6.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck7.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck8.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck9.png" class="cifar-sample"></td>
        <td><img src="https://www.cs.toronto.edu/~kriz/cifar-10-sample/truck10.png" class="cifar-sample"></td>
    </tr>
</tbody></table>

In [None]:
# Load the CIFAR-10 train/test splits, each as an (images, labels) pair.
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1
# TODO(student): replace each `?` with the largest possible raw pixel value,
# so that the division maps every pixel into [0, 1]. Use the same constant
# for both splits so train and test inputs share one scale.
train_images = train_images / ?
test_images  = test_images / ?

### Verify the data

To verify that the dataset looks correct, let's plot the first 25 images from the training set and display the class name below each image:


In [None]:
# Human-readable CIFAR-10 class names; list position == integer label id.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# grid of size N x N: the first N*N training images are shown
# TODO(student): replace `?` with the side length of the grid
N = ? 

plt.figure(figsize=(10, 10))
for i in range(N*N):
    # The CIFAR labels happen to be arrays (one small array per image),
    # which is why you need the extra [0] index to get the integer class id
    image, label_index = train_images[i], train_labels[i][0]
    label_name = class_names[label_index]

    # Draw image i in cell i+1 of the N x N grid (subplot indices are 1-based),
    # hiding ticks and grid lines so only the picture and its label remain.
    plt.subplot(N,N,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image)
    plt.xlabel(label_name)
    
plt.show()

### Create the convolutional base

## Mathematical aspect

Here is a simplified depiction of a 32x32x3 image:

<img src="https://raw.githubusercontent.com/mingruimingrui/Convolution-neural-networks-made-easy-with-keras/master/imgs/input-image-dimension.JPG" width="240" >

A typical input image will be broken down into its individual pixel components. In the picture above, we have a 32x32 pixel image with an R, a G and a B value attached to each pixel, therefore a 32x32x3 input, i.e. an input with height 32, width 32 and depth 3.

<img src="https://raw.githubusercontent.com/mingruimingrui/Convolution-neural-networks-made-easy-with-keras/master/imgs/filtering.JPG" width="360" style="max-width: 100%;">

A 3x3 filter applied to a 3x3 (x1, i.e. single-channel) input:

<img src="https://raw.githubusercontent.com/mingruimingrui/Convolution-neural-networks-made-easy-with-keras/master/imgs/filtering-math.JPG">

Since we are dealing with an image of depth 3 (one channel per color), we need to imagine a 3x3x3 patch of the image being multiplied element-wise with a 3x3x3 filter and the products summed up. Adding a constant (bias) term then gives a single number as the result of this transformation for that position.

<img src="https://raw.githubusercontent.com/mingruimingrui/Convolution-neural-networks-made-easy-with-keras/master/imgs/filtering-many-to-one.gif">


### Pooling Layers

After some ReLU layers, it is customary to apply a pooling layer (aka downsampling layer).

In this category, there are also several layer options, with maxpooling being the most popular.

<img width="80%" src="https://raw.githubusercontent.com/leriomaggio/deep-learning-keras-tensorflow/b996ea1faba3ee7d01ea7bd733d2e7cd71be6bb5/imgs/MaxPool.png">

The 6 lines of code below define the convolutional base using a common pattern: a stack of [Conv2D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2D) and [MaxPooling2D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/MaxPool2D) layers.

As input, a CNN takes tensors of shape (image_height, image_width, color_channels), ignoring the batch size. If you are new to these dimensions, color_channels refers to (R,G,B). In this example, you will configure your CNN to process inputs of shape (32, 32, 3), which is the format of CIFAR images. You can do this by passing the argument `input_shape` to your first layer.

**layer 1**
- 32 x [Conv2D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2D) (3, 3)
- relu
- MaxPooling2D (2, 2)

**layer 2**
- 64 x [Conv2D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2D) (3, 3)
- relu
- MaxPooling2D (2, 2)

**layer 3**
- 64 x [Conv2D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2D) (3, 3)

Note that you can pass an `activation` parameter to Conv2D (e.g. `activation='relu'`).


In [None]:
# TODO(student): shape of ONE CIFAR image as (height, width, color_channels);
# the batch dimension is left out.
input_shape = (?, ?, ?)

model = models.Sequential()
# Declare the expected shape of a single input image (batch size excluded).
# NOTE(review): newer Keras releases name this argument `shape`; confirm
# against the installed version if a deprecation warning appears.
model.add(layers.InputLayer(input_shape=input_shape))

# layer 1
# TODO(student): number of filters, kernel size and activation of the Conv2D,
# then the pool size of the MaxPooling2D (see the layer-1 spec above).
model.add(layers.Conv2D(?, (?, ?), activation=?))
model.add(layers.MaxPooling2D((?, ?)))

# layer 2
# TODO(student): same Conv2D + MaxPooling2D pattern, with the layer-2 values.
model.add(layers.Conv2D(?, (?, ?), activation=?))
model.add(layers.MaxPooling2D((?, ?)))

# layer 3
# TODO(student): a last Conv2D; no pooling layer follows it.
model.add(layers.Conv2D(?, (?, ?), activation=?))

Let's display the architecture of your model so far:

In [None]:
# Print the layers added so far, with their output shapes and parameter counts.
model.summary()

Above, you can see that the output of every Conv2D and MaxPooling2D layer is a 3D tensor of shape (height, width, channels). The width and height dimensions tend to shrink as you go deeper in the network. The number of output channels for each Conv2D layer is controlled by the first argument (e.g., 32 or 64). Typically,  as the width and height shrink, you can afford (computationally) to add more output channels in each Conv2D layer.

### Add Dense layers on top

To complete the model, you will feed the last output tensor from the convolutional base (of shape (4, 4, 64)) into one or more Dense layers to perform classification. Dense layers take vectors as input (which are 1D), while the current output is a 3D tensor. First, you will flatten (or unroll) the 3D output to 1D,  then add one or more Dense layers on top. CIFAR has 10 output classes, so you use a final Dense layer with 10 outputs.

**layer 4**
- Flatten
- Dense (64) + relu
- Dense (n_classes)



In [None]:
# TODO(student): number of CIFAR-10 classes (one output unit per class).
n_classes = ?

# layer 4
# Unroll the 3-D feature map of the convolutional base into a 1-D vector.
model.add(layers.Flatten())
# TODO(student): hidden Dense layer size and activation, per the layer-4 spec.
model.add(layers.Dense(?, activation=?))
# Output layer without activation: it emits raw logits, matching the
# from_logits=True loss configured when the model is compiled.
model.add(layers.Dense(n_classes))

Here's the complete architecture of your model:

In [None]:
# Print the complete architecture, now including the Flatten and Dense layers.
model.summary()

The network summary shows that (4, 4, 64) outputs were flattened into vectors of shape (1024) before going through two Dense layers.

### Compile and train the model



In [None]:
# Labels are integer class ids and the last Dense layer has no activation,
# so the loss is sparse categorical cross-entropy computed from raw logits.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

# Train for 10 epochs. The test split is only passed as validation data, so
# its loss/accuracy are reported after each epoch without updating weights.
history = model.fit(
    x=train_images,
    y=train_labels,
    epochs=10,
    validation_data=(test_images, test_labels),
)

### Plot the metrics

In [None]:
# Training curves: loss on the left, accuracy on the right. Each panel shows
# the training metric next to its `val_`-prefixed validation counterpart,
# one point per epoch, as recorded in `history.history`.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 9))

# Loss usually decreases over epochs, so a legend pinned to 'lower right'
# would sit on top of the loss curves; let matplotlib choose for that panel.
for ax, metric, ylabel, legend_loc in (
    (axes[0], 'loss', 'Loss', 'best'),
    (axes[1], 'accuracy', 'Accuracy', 'lower right'),
):
    ax.plot(history.history[metric], 'x--', label=metric)
    ax.plot(history.history[f'val_{metric}'], 'x--', label=f'val_{metric}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend(loc=legend_loc)
    ax.grid(True)

plt.show()

### Evaluate your model

In [None]:
# TODO(student): pass the held-out test images and their labels.
# evaluate() returns the loss followed by the compiled metrics; the model was
# compiled with metrics=['accuracy'], hence the two values unpacked here.
test_loss, test_acc = model.evaluate(?,  ?, verbose=2)

# Show the accuracy as a percentage rounded to 2 decimals, rendered as Markdown.
display(Markdown(f"Your model achieved **{round(100 * test_acc, 2)} %** on the test set !"))

Your simple CNN should have achieved a test accuracy of over 70%.
Not bad for a few lines of code !

# Computing the representation of an image

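Let's keep only the learned representation: the output of the hidden 64-unit Dense layer, i.e. the layer just before the final classification layer.
We build a new Model that reuses the trained layers but discards the last (classification) layer of the previous model.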

In [None]:
# Feature extractor built from the trained layers of `model`: same image
# input, but the output is taken from model.layers[-2], the second-to-last
# layer, i.e. the hidden Dense layer added right before the final n_classes
# output layer. Only that final classification layer is dropped.
input_layer = model.layers[0].input
output_layer = model.layers[-2].output

model_repr = models.Model(input_layer, output_layer)

In [None]:
# Run every test image through the truncated model: one representation
# vector per image, rows in the same order as test_images.
representation = model_repr.predict(test_images)

print("shape of representation:", representation.shape)

# sparse: the layer producing this representation is meant to use ReLU (see
# the layer-4 spec), which outputs exact zeros for negative pre-activations.
# A single image gives a noisy estimate of that sparsity, so report it for
# the first image and also averaged over every entry of every test image.
print("proportion of zero valued axis: %0.3f" % np.mean(representation[0] == 0))
print("proportion of zero valued axis (all test images): %0.3f"
      % np.mean(representation == 0))

## Visual Search: finding similar images

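Find the indices of the `top_n` representations closest (in Euclidean distance) to the representation of an anchor image. The anchor is at distance 0 from itself, so it normally appears first in the results.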

In [None]:
def most_similar(tensor, idx, top_n=5, *, exclude_self=False):
    """Return the indices of the rows of *tensor* closest to row *idx*.

    Closeness is the Euclidean (L2) distance between rows, so *tensor* is a
    2-D array of shape (n_samples, n_features), one representation per row.

    Args:
        tensor: 2-D array-like, one representation per row.
        idx: row index of the anchor (negative indices are allowed).
        top_n: maximum number of indices to return.
        exclude_self: if True, the anchor row itself is removed from the
            result. By default it is kept; being at distance 0 it normally
            comes first.

    Returns:
        1-D integer array of at most *top_n* row indices, nearest first.
        Rows at equal distance are ordered by ascending index.

    Raises:
        IndexError: if *idx* is out of range for *tensor*.
    """
    tensor = np.asarray(tensor)
    # Normalise a possibly negative idx to a row number (raises IndexError
    # when out of range) so the self-exclusion below compares like with like.
    anchor = range(tensor.shape[0])[idx]
    dists = np.linalg.norm(tensor - tensor[anchor], axis=1)
    # Stable sort: deterministic tie-breaking (lowest row index first).
    order = np.argsort(dists, kind='stable')
    if exclude_self:
        order = order[order != anchor]
    return order[:top_n]

idx = 57 # choose an index (anchor) into the test set
top_n = 5

# Indices (into test_images) of the top_n representations nearest to the
# anchor's; the anchor itself, at distance 0, is normally among them.
sim_indexs = most_similar(representation, idx, top_n=top_n)
print(sim_indexs)

In [None]:
# anchor: show the query test image with its true class name
image, label_index = test_images[idx], test_labels[idx][0]
label_name = class_names[label_index]
plt.imshow(image)
plt.xlabel(label_name)


# most similar images: a new figure with one row of top_n cells, each image
# labelled with its true class name so retrieval quality can be judged
plt.figure()
for i, sim_index in enumerate(sim_indexs):
    image, label_index = test_images[sim_index], test_labels[sim_index][0]
    label_name = class_names[label_index]

    # cell i+1 of the 1 x top_n row (subplot indices are 1-based)
    plt.subplot(1, top_n, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image)
    plt.xlabel(label_name)

plt.show()

# t-SNE

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt


# Embed the test-set representations in 2-D with t-SNE. t-SNE is stochastic,
# so random_state is fixed to get the same picture on every run.
tsne = TSNE(n_components=2, perplexity=30, random_state=0)
img_emb_tsne = tsne.fit_transform(representation)

# Colour each point by its true class (rows of `representation` follow the
# order of test_images, hence of test_labels) so class clusters are visible.
# vmin/vmax centre each integer label on its own colour of the 10-colour map.
plt.figure(figsize=(10, 10))
points = plt.scatter(
    img_emb_tsne[:, 0], img_emb_tsne[:, 1],
    c=test_labels.ravel(), cmap='tab10',
    vmin=-0.5, vmax=len(class_names) - 0.5,
)
plt.legend(points.legend_elements()[0], class_names, loc='best')
plt.xticks(())
plt.yticks(())
plt.show()

In [None]:
import torch
import numpy as np

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

%matplotlib inline
import matplotlib.pyplot as plt

def dimension_reduction_2D(X, method='tsne', *, random_state=None):
    """Project *X* to 2-D with an unsupervised method and rescale into [0, 1].

    Args:
        X: array-like of shape (n_samples, n_features), as accepted by
            scikit-learn's ``fit_transform``.
        method: ``'tsne'`` (non-linear t-SNE) or ``'pca'`` (linear PCA).
        random_state: forwarded to the scikit-learn estimator; pass an int
            for a reproducible result. ``None`` keeps the estimator default.

    Returns:
        ndarray of shape (n_samples, 2). Both coordinates are shifted and
        divided by the same global min/max of the projection, so the aspect
        ratio is preserved and every value lies in [0, 1]. If all projected
        values are equal (e.g. a single sample), an all-zero array is
        returned instead of NaNs.

    Raises:
        ValueError: if *method* is neither ``'tsne'`` nor ``'pca'``.
    """
    # Only build the estimator that is actually requested.
    if method == 'tsne':
        reducer = TSNE(n_components=2, random_state=random_state)
    elif method == 'pca':
        reducer = PCA(n_components=2, random_state=random_state)
    else:
        raise ValueError(
            f"unknown dimensionality reduction technique: {method!r}")

    Xs = reducer.fit_transform(X)
    Xmin, Xmax = np.amin(Xs), np.amax(Xs)
    span = Xmax - Xmin
    if span == 0:
        # Degenerate projection: avoid 0/0 producing NaN.
        return np.zeros_like(Xs)
    # Shift by the minimum (not the maximum) so the result lands in [0, 1].
    return (Xs - Xmin) / span

def plot_tsne(X_list, label, n_classes, titles=None):
    """Scatter-plot one or more 2-D embeddings side by side, coloured by class.

    Args:
        X_list: non-empty sequence of arrays of shape (n_samples, 2), e.g.
            outputs of ``dimension_reduction_2D``; one subplot per array.
        label: integer class id per sample (any shape that flattens to
            n_samples), used to colour the points.
        n_classes: number of classes; each id in ``range(n_classes)`` gets
            its own distinct colour.
        titles: optional sequence of subplot titles, one per array.
    """
    from matplotlib.colors import ListedColormap

    n = len(X_list)
    label = np.asarray(label).ravel()

    # 'hsv' is cyclic: sampling it with the endpoint included would give the
    # first and last class the same red, so sample n_classes colours with
    # endpoint=False and freeze them into a discrete colour map.
    cmap_tsne = ListedColormap(
        plt.get_cmap('hsv')(np.linspace(0, 1, n_classes, endpoint=False)))

    # squeeze=False always returns a 2-D array of axes, so one plot and many
    # plots go through the same code path (titles and aspect included).
    _, axs = plt.subplots(1, n, figsize=(5 * n, 5), squeeze=False)
    for i, (ax, X_2D) in enumerate(zip(axs[0], X_list)):
        x, y = np.asarray(X_2D).T

        # Pin the colour normalisation to the class range so a given class id
        # always maps to the same colour, even if some classes are absent.
        ax.scatter(x, y, c=label, cmap=cmap_tsne,
                   vmin=-0.5, vmax=n_classes - 0.5,
                   s=50, edgecolors='w')
        ax.axis('off')
        if titles:
            ax.set_title(titles[i])

        # set 1:1 ratio: aspect = x-range / y-range makes the plot box square
        (x_left, x_right), (y_bot, y_top) = ax.get_xlim(), ax.get_ylim()
        ratio = abs((x_right - x_left) / (y_bot - y_top))
        ax.set_aspect(ratio)

    plt.show()


