<a href="https://colab.research.google.com/github/sherna90/inteligencia_artificial/blob/master/7.-cnn_transfer_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transfer learning and fine-tuning

This tutorial walks you through building, training, and validating a convolutional neural network (CNN) with JAX and Flax to classify images from the MNIST dataset.

1. Examine and understand the data
1. Preprocess the data: normalize the images and one-hot encode the labels
1. Compose the model
   * Stack two convolutional blocks
   * Add the classification layer on top
1. Train the model
1. Evaluate the model


In [None]:
!pip install jax jaxlib flax optax matplotlib


In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist

### CNN architecture

### Model

We use Flax Linen to define the CNN. This example has two convolutional layers, each followed by a ReLU and 2×2 max pooling, and then a dense layer for classification.



In [None]:
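class CNN(nn.Module):
    """Two conv/ReLU/max-pool blocks followed by a dense classifier."""

    @nn.compact
    def __call__(self, x):
        # Block 1: 32 filters of size 3x3, then halve the spatial resolution
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Block 2: 64 filters of size 3x3, then halve the spatial resolution again
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # Flatten the output for the dense layer
        x = nn.Dense(features=10)(x)  # One logit per digit class
        return x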

Load and preprocess the MNIST dataset:

In [None]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixel values to [0, 1] and add a channel dimension
x_train = x_train.astype(jnp.float32) / 255.0
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.astype(jnp.float32) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1)

y_train = jax.nn.one_hot(y_train, num_classes=10)
y_test = jax.nn.one_hot(y_test, num_classes=10)
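
To make the effect of this preprocessing concrete, here is a minimal sketch on a tiny, made-up batch. The arrays `fake_images` and `fake_labels` are hypothetical stand-ins for MNIST data (2×2 "images" instead of 28×28), chosen only so the result is easy to inspect.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for two uint8 images (values 0..255) and their labels
fake_images = np.array([[[0, 255], [128, 64]],
                        [[255, 255], [0, 0]]], dtype=np.uint8)
fake_labels = np.array([3, 0])

# Scale to [0, 1] and append a trailing channel axis, as done for MNIST above
images = fake_images.astype(jnp.float32) / 255.0
images = images.reshape(-1, 2, 2, 1)

# One-hot encode the integer labels; argmax recovers the original class index
one_hot = jax.nn.one_hot(fake_labels, num_classes=10)
recovered = jnp.argmax(one_hot, axis=-1)

print(images.shape, one_hot.shape, recovered)
```

Each row of `one_hot` contains a single 1 at the position of the class index, which is why `jnp.argmax` is enough to get the integer label back.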



In [None]:
x_train.shape,y_train.shape

In [None]:
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

In [None]:
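plt.figure(figsize=(10, 10))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    # Drop the channel axis so imshow receives a plain 28x28 array
    plt.imshow(x_train[i].squeeze(), cmap='gray')
    # Recover the integer class from the one-hot label
    plt.title(class_names[int(jnp.argmax(y_train[i]))])
    plt.axis("off")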

Define the loss as the mean softmax cross-entropy between the model's logits and the labels. Because the labels were one-hot encoded above, `jnp.argmax` converts them back to integer class indices, which is the input that `optax.softmax_cross_entropy_with_integer_labels` expects.

In [None]:
def loss_fn(params, images, labels):
    logits = model.apply({'params': params}, images)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=jnp.argmax(labels, axis=-1)).mean()
    return loss


In [None]:
key = jax.random.PRNGKey(0)
model = CNN()

dummy_input = jnp.zeros((1, 28, 28, 1))
params = model.init(key, dummy_input)['params']

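optim = optax.sgd(1e-3)
# Bundle the apply function, parameters, and optimizer state into a TrainState,
# which provides the `.params` and `.apply_gradients` used by `train_step`
train_state_ = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optim
)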

In [None]:
train_state_

In [None]:
@jax.jit
def train_step(state, images, labels):
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params, images, labels)
    state = state.apply_gradients(grads=grads)
    return state, loss

In [None]:
epochs = 5
batch_size = 32

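for epoch in range(epochs):
    # Walk through the training set in fixed-order mini-batches (no shuffling)
    for i in range(0, len(x_train), batch_size):
        batch_images = x_train[i:i+batch_size]
        batch_labels = y_train[i:i+batch_size]
        train_state_, loss = train_step(train_state_, batch_images, batch_labels)
    # `loss` is the loss of the last mini-batch, not an average over the epoch
    print(f"Epoch {epoch+1}, Loss: {float(loss):.4f}")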
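
The outline at the top ends with an evaluation step. One possible way to measure classification accuracy is sketched below. The helper `accuracy` and the toy arrays are hypothetical and not part of the original notebook. The helper assumes one-hot labels, matching the preprocessing used here.

```python
import jax
import jax.numpy as jnp


def accuracy(logits, one_hot_labels):
    """Fraction of examples whose highest-scoring class matches the label."""
    predictions = jnp.argmax(logits, axis=-1)
    targets = jnp.argmax(one_hot_labels, axis=-1)
    return jnp.mean(predictions == targets)


# Made-up logits for 4 examples and 3 classes; the last one is misclassified
toy_logits = jnp.array([[2.0, 0.1, 0.3],
                        [0.2, 1.5, 0.1],
                        [0.3, 0.2, 0.9],
                        [1.0, 2.0, 0.5]])
toy_labels = jax.nn.one_hot(jnp.array([0, 1, 2, 0]), num_classes=3)

acc = accuracy(toy_logits, toy_labels)
print(float(acc))  # 3 of 4 predictions match, so 0.75
```

On the notebook's data, the same helper would take the logits from `model.apply({'params': trained_params}, x_test)` together with `y_test`, where `trained_params` is a placeholder for the parameters produced by the training loop.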