<a href="https://colab.research.google.com/github/sherna90/inteligencia_artificial/blob/master/7.-cnn_mnist_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Redes Neuronales Convolucionales

Este tutorial te guiará a través del proceso de creación, entrenamiento y validación de una red neuronal convolucional (CNN) usando JAX y Flax para clasificar imágenes del dataset MNIST.

1. Examine and understand the data
1. Build the input pipeline: scale the MNIST images, add a channel axis, and one-hot encode the labels
1. Compose the model
   * Stack two convolution + ReLU + max-pooling stages
   * Add a dense classification layer on top
1. Train the model with mini-batches and the Adam optimizer
1. Evaluate the model on the test set


In [None]:
!pip install jax jaxlib flax optax matplotlib


In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist

### Arquitectura de la CNN

### Modelo

Usaremos Flax Linen para definir la CNN. Este ejemplo incluye dos capas convolucionales seguidas de una capa densa para la clasificación.



In [None]:
class CNN(nn.Module):
    """Small convolutional classifier.

    Two Conv(3x3) -> ReLU -> MaxPool(2x2) stages are followed by a single
    dense layer that produces unnormalised class logits.
    """

    # Number of output logits; the default of 10 matches the MNIST digit classes.
    num_classes: int = 10

    @nn.compact
    def __call__(self, x):
        """Return logits of shape (x.shape[0], num_classes).

        NOTE(review): assumes `x` is a channels-last image batch such as
        (batch, 28, 28, 1), as prepared elsewhere in this notebook; confirm
        for other inputs.
        """
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the batch axis for the dense layer.
        x = x.reshape((x.shape[0], -1))
        return nn.Dense(features=self.num_classes)(x)

Cargar y preprocesar el dataset MNIST:

In [None]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel intensities by 1/255 and append a trailing channel axis of size 1.
x_train = (x_train.astype(jnp.float32) / 255.0).reshape(-1, 28, 28, 1)
x_test = (x_test.astype(jnp.float32) / 255.0).reshape(-1, 28, 28, 1)

# One-hot encode the integer digit labels into 10 classes.
y_train, y_test = (jax.nn.one_hot(labels, num_classes=10) for labels in (y_train, y_test))



In [None]:
# Expect (N, 28, 28, 1) images and (N, 10) one-hot labels.
(x_train.shape, y_train.shape)

In [None]:
# Human-readable label for each class index: the digit itself as a string.
class_names = [str(digit) for digit in range(10)]

In [None]:
# Show the first nine training images in a 3x3 grid, titled with their labels.
plt.figure(figsize=(10, 10))
for idx in range(9):
  plt.subplot(3, 3, idx + 1)
  label = jnp.argmax(y_train[idx])
  plt.imshow(x_train[idx], cmap='gray')
  plt.title(class_names[label])
  plt.axis("off")

Definir las funciones de predicción y de pérdida, e inicializar el modelo, sus parámetros y el optimizador:

In [None]:
net = CNN()
optimizer = optax.adam(1e-3)


@jax.jit
def predict(params, inputs):
  """Run the module-level `net` on `inputs` with `params` and return its logits."""
  variables = {"params": params}
  return net.apply(variables, inputs)


@jax.jit
def loss_fun(params, data):
  """Mean softmax cross-entropy of the model's logits against one-hot labels.

  `data` is an `(inputs, labels)` pair.
  """
  images, targets = data
  per_example = optax.softmax_cross_entropy(logits=predict(params, images), labels=targets)
  return per_example.mean()


rng = jax.random.PRNGKey(0)
dummy_data = jnp.zeros((1, 28, 28, 1))
# net.init runs the model on a dummy batch to create its parameter tree.
params = net.init({"params": rng}, dummy_data)["params"]

Definir la función y el ciclo de entrenamiento usando mini-batches.

In [None]:
def _evaluate(params, x, y, batch_size):
  """Return (mean loss, accuracy) of `params` on `(x, y)`, computed in mini-batches."""
  total = len(x)
  if total == 0:
    return jnp.nan, jnp.nan
  loss_sum = 0.0
  correct = 0
  for start in range(0, total, batch_size):
    xb = x[start:start + batch_size]
    yb = y[start:start + batch_size]
    # loss_fun returns a per-batch mean, so weight it by the batch size to get an exact overall mean.
    loss_sum += loss_fun(params, (xb, yb)) * len(xb)
    logits = predict(params, xb)
    correct += jnp.sum(jnp.argmax(logits, axis=-1) == jnp.argmax(yb, axis=-1))
  return loss_sum / total, correct / total


def train_step(params, optimizer, x_train, y_train, x_test, y_test, num_epochs=10, batch_size=64,
               eval_batch_size=1000):
  """Train `params` with `optimizer` over mini-batches and periodically evaluate on the test set.

  Despite its name, this runs the full training loop over all epochs, not a single step.

  Args:
    params: Flax parameter pytree, e.g. `net.init(...)["params"]`.
    optimizer: optax gradient transformation; its state is initialised here.
    x_train, y_train: Training images and one-hot labels, visited in order (no shuffling).
    x_test, y_test: Test images and one-hot labels used for periodic evaluation.
    num_epochs: Number of passes over the training set.
    batch_size: Training mini-batch size; the last batch of an epoch may be smaller.
    eval_batch_size: Mini-batch size used when evaluating on the test set.

  Returns:
    `(params, (train_loss, train_accuracy), (test_loss, test_accuracy))`. The training
    lists hold one value per mini-batch; the test lists hold one value per evaluation.

  Raises:
    ValueError: If `x_train` is empty.
  """
  if len(x_train) == 0:
    raise ValueError("x_train is empty")

  @jax.jit
  def _update(params, opt_state, batch_images, batch_labels):
    # One optimisation step, compiled once and reused for every mini-batch.
    loss, grads = jax.value_and_grad(loss_fun)(params, (batch_images, batch_labels))
    # Accuracy uses the same (pre-update) params as the loss, so both curves
    # describe the same model.
    logits = predict(params, batch_images)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(batch_labels, axis=-1))
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, accuracy

  opt_state = optimizer.init(params)
  train_accuracy, train_loss = [], []
  test_accuracy, test_loss = [], []
  # Evaluate about 10 times over training. max(1, ...) avoids a modulo by zero
  # when num_epochs < 10.
  eval_every = max(1, num_epochs // 10)
  for epoch in range(num_epochs):
    epoch_start = len(train_loss)
    for i in range(0, len(x_train), batch_size):
      params, opt_state, loss, accuracy = _update(
          params, opt_state, x_train[i:i + batch_size], y_train[i:i + batch_size])
      train_loss.append(loss)
      train_accuracy.append(accuracy)
    if epoch % eval_every == 0:
      epoch_test_loss, epoch_test_acc = _evaluate(params, x_test, y_test, eval_batch_size)
      test_loss.append(epoch_test_loss)
      test_accuracy.append(epoch_test_acc)
      # Report the epoch-average training metrics rather than only the last mini-batch.
      mean_loss = float(jnp.mean(jnp.stack(train_loss[epoch_start:])))
      mean_acc = float(jnp.mean(jnp.stack(train_accuracy[epoch_start:])))
      print(f"Epoch {epoch}, Train loss: {mean_loss:.4f}, Train accuracy: {mean_acc:.4f}, "
            f"Test loss: {float(epoch_test_loss):.4f}, Test accuracy: {float(epoch_test_acc):.4f}")
  return params, (train_loss, train_accuracy), (test_loss, test_accuracy)

In [None]:
# Run the full training loop with 10 epochs and mini-batches of 64.
updated_params, train_history, test_history = train_step(
    params, optimizer, x_train, y_train, x_test, y_test,
    num_epochs=10, batch_size=64,
)

In [None]:
# Ground-truth class index for the last 10 training images.
y_train[-10:].argmax(axis=-1)

In [None]:
# Predicted class index for the same 10 images, using the trained parameters.
predictions = predict(updated_params, x_train[-10:])
predictions.argmax(axis=-1)

In [None]:
import matplotlib.pyplot as plt

# Training curves: one point per mini-batch.
fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=False, figsize=(10, 5))
fig.suptitle('Train')
line_style = dict(linewidth=1, linestyle='--')
ax1.plot(train_history[0], **line_style)
ax1.set_title('Loss')
ax2.plot(train_history[1], 'tab:orange', **line_style)
ax2.set_title('Accuracy')
plt.show()

In [None]:
# Test curves: one point per evaluation performed during training.
fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=False, figsize=(10, 5))
fig.suptitle('Test')
line_style = dict(linewidth=1, linestyle='--')
ax1.plot(test_history[0], **line_style)
ax1.set_title('Loss')
ax2.plot(test_history[1], 'tab:orange', **line_style)
ax2.set_title('Accuracy')
plt.show()