In [None]:
import os
# Show which datasets are mounted under the Kaggle input directory.
os.listdir('/kaggle/input/')

In [None]:
import os
import json
import string
from tqdm import tqdm
import pickle
from time import time
import numpy as np
from PIL import Image
import pandas as pd

import keras.preprocessing.image
from keras.models import Sequential
from keras.layers import (
    LSTM, Embedding, TimeDistributed, Dense, RepeatVector, 
    Activation, Flatten, Reshape, concatenate, Dropout, 
    BatchNormalization, GlobalAveragePooling2D, Conv2D,
    Activation, Add
)
from tensorflow.keras.optimizers import Adam , RMSprop
from tensorflow.keras import Input, layers
from tensorflow.keras import optimizers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import add
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications.inception_v3 import InceptionV3
import tensorflow.keras.applications.inception_v3
from tensorflow.keras.applications import DenseNet169
import tensorflow.keras.applications
import tensorflow as tf

from tensorflow.keras.layers import TextVectorization
from tensorflow.keras.applications import efficientnet



import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Sentinel words placed before and after every cleaned caption.
START = "aaprincipioaa"
STOP = "zzfinzz"

# Size handed to tf.image.resize and to the EfficientNet input shape.
IMAGE_SIZE = (299, 299)
# Passed as embed_dim to the Transformer encoder and decoder blocks.
EMBEDDING_DIM = 512
# NOTE(review): not referenced in the visible code; the blocks pass literal head counts.
NUM_HEADS = 2
# Passed as dense_dim / ff_dim to the Transformer blocks.
FF_DIM = 32
# Number of epochs for the Transformer fit call and its learning-rate schedule.
EPOCHS = 40
# Batch size of the tf.data pipeline built by make_dataset.
BATCH_SIZE = 128
# Number of photos grouped into one batch by data_generator.
LSTM_BATCH_SIZE = 8
AUTOTUNE = tf.data.AUTOTUNE
# Minimum number of occurrences for a word to count towards vocab_size.
WORD_COUNT_THRESHOLD = 3


# ROOT_CAPTIONING contains flickr-image-dataset and glove6B
ROOT_CAPTIONING = os.path.join("/", "kaggle", "input")
DATASET_DIR = os.path.join(ROOT_CAPTIONING, "flickr-image-dataset", "flickr30k_images")
# NOTE(review): MODEL_DIR is not used by the visible cells, which save to "." instead.
MODEL_DIR = os.path.join("/", "kaggle", "working")

# Sequence preprocessing

In [None]:
# Translation table that deletes every character in string.punctuation.
null_punct = str.maketrans(dict.fromkeys(string.punctuation))

lookup = dict()

# One dictionary per split, filled as image path -> list of cleaned captions.
train_descriptions, valid_descriptions, test_descriptions = {}, {}, {}
all_text_captions = []

# Running statistics updated while the captions are cleaned.
max_length = max_length_caption = max_length_image_path = 0
word_counts = dict()

# Load the JSON file that lists every image with its split and sentences.
split_json_path = os.path.join(ROOT_CAPTIONING, "flickr30k-split-json", "dataset_flickr30k.json")
with open(split_json_path) as split_file:
    reference_json = json.load(split_file)
    
# Map each split name to the dictionary that collects its captions.
descriptions_by_split = {
    "train": train_descriptions,
    "val": valid_descriptions,
    "test": test_descriptions,
}

for image_dict in reference_json['images']:

    # Skip images whose split is not train/val/test. Without this guard the
    # dictionary chosen for the previous image would be reused (or the name
    # would be undefined on the very first image).
    dict_to_add = descriptions_by_split.get(image_dict['split'])
    if dict_to_add is None:
        continue

    image_id = os.path.join(DATASET_DIR, 'flickr30k_images', image_dict['filename'])
    image_captions = image_dict['sentences']

    # Only images with exactly five captions are kept.
    if len(image_captions) != 5:
        continue

    dict_to_add[image_id] = list()

    for caption_dict in image_captions:

        caption_str = caption_dict['raw']

        # Lower-case every token, strip punctuation, keep purely alphabetic words.
        clean_tokens = (
            token.lower().translate(null_punct) for token in caption_str.split()
        )
        description_tokens = [word for word in clean_tokens if word.isalpha()]

        # Remember the longest caption (before the sentinels are added).
        if len(description_tokens) > max_length:
            max_length = len(description_tokens)
            max_length_caption = caption_str
            max_length_image_path = image_id

        for word in description_tokens:
            word_counts[word] = word_counts.get(word, 0) + 1

        # Store the cleaned caption wrapped in the START/STOP sentinels.
        caption_str = ' '.join(description_tokens)
        caption_str = START + " " + caption_str + " " + STOP
        dict_to_add[image_id].append(caption_str)
        all_text_captions.append(caption_str)

        
# Print the cleaned captions of the first training image as a sanity check.
first_image_captions = next(iter(train_descriptions.values()), [])
for desc in first_image_captions:
    print(desc)
    
# Two tokens (START and STOP) were added to every caption.
max_length += 2

n_total = len(reference_json['images'])
n_train, n_valid, n_test = (
    len(train_descriptions),
    len(valid_descriptions),
    len(test_descriptions),
)

print(f"Número de imágenes: {n_total}\n")
print(f"Tamaño del vocabulario: {len(word_counts)}\n")

# Words seen at least WORD_COUNT_THRESHOLD times; two extra slots are added
# (presumably for the START/STOP sentinels, which are not in word_counts).
vocab = [word for word, count in word_counts.items() if count >= WORD_COUNT_THRESHOLD]
vocab_size = len(vocab) + 2
print(f"Tamaño del vocabulario si consideramos sólo palabras que aparecen almenos {WORD_COUNT_THRESHOLD} veces: {vocab_size}\n")
print(f"Tamaño de la secuencia más larga: {max_length}\n")
print(f"Secuencia más larga: {max_length_caption}\n")

# Display the image that owns the longest caption.
img = mpimg.imread(max_length_image_path)
imgplot = plt.imshow(img)
print("Imagen:")
plt.show()

MAX_LENGTH_SEQ = max_length

# Integer tokenizer capped at vocab_size tokens; every output is
# MAX_LENGTH_SEQ ids long.
vectorization_options = dict(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=MAX_LENGTH_SEQ,
)
vectorization = TextVectorization(**vectorization_options)

print( f"Número de descripciones: {len(all_text_captions)}\n")
# The vocabulary is learned from the captions of all three splits.
vectorization.adapt(all_text_captions)

# LSTM

### 1.  Create dataset

In [None]:
# Token -> id lookup built from the adapted vectorization layer.
vocab = vectorization.get_vocabulary()
vocab_lookup = {token: token_id for token_id, token in enumerate(vocab)}
max_decoded_sentence_length = MAX_LENGTH_SEQ - 1
valid_images = list(valid_descriptions)
END_TOKEN_ID = vocab_lookup[STOP]

# Precomputed image embeddings, indexed elsewhere in this notebook by image file name.
# NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
encodings_path = os.path.join(ROOT_CAPTIONING,'1664-img-emb-densenet169', 'images','images.pkl')
with open(encodings_path, "rb") as fp:
    all_encoding = pickle.load(fp)

print(len(all_encoding))

def decode_and_resize(img_path):
    """Read the JPEG at ``img_path`` and return it resized to IMAGE_SIZE as float32."""
    raw_bytes = tf.io.read_file(img_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, IMAGE_SIZE)
    return tf.image.convert_image_dtype(resized, tf.float32)

def data_generator(descriptions, num_photos_per_batch, img_embeds):
    """Yield next-word training batches for the LSTM captioning model, forever.

    Every caption is expanded into one sample per target token: the input is
    the caption prefix (post-padded to MAX_LENGTH_SEQ) and the target is the
    one-hot encoding of the following token. The STOP token is included as the
    final target so the model can learn when a caption ends.

    Args:
        descriptions: dict mapping an image path to its list of captions.
        num_photos_per_batch: number of photos whose samples form one batch.
        img_embeds: mapping from image file name (last path component) to the
            precomputed image embedding.

    Yields:
        ``([image_embeddings, padded_prefixes], one_hot_targets)`` as numpy arrays.
    """
    # Nothing to iterate over: stop instead of spinning forever in `while True`.
    if not descriptions:
        return

    n = 0
    x1, x2, y = [], [], []
    while True:
        for key, desc_list in descriptions.items():
            n += 1
            photo = img_embeds[key.split("/")[-1]]
            # Each photo has several descriptions
            for desc in desc_list:
                # Token ids of the caption, padded with 0 up to MAX_LENGTH_SEQ.
                seq = vectorization(desc).numpy()
                # Position 0 is the START token, so targets begin at index 1.
                for index in range(1, len(seq)):
                    token_id = seq[index]
                    # Padding reached without a STOP token: nothing left to learn.
                    if token_id == 0:
                        break
                    in_seq = pad_sequences([seq[:index]], maxlen=MAX_LENGTH_SEQ, padding='post')[0]
                    out = to_categorical([token_id], num_classes=vocab_size)[0]
                    x1.append(photo)
                    x2.append(in_seq)
                    y.append(out)
                    # STOP is emitted as the last target of the caption.
                    if token_id == END_TOKEN_ID:
                        break

            if n == num_photos_per_batch:
                yield ([np.array(x1), np.array(x2)], np.array(y))
                x1, x2, y = [], [], []
                n = 0

# Whole batches of LSTM_BATCH_SIZE photos available in each split.
train_steps, valid_steps = (
    len(split_descriptions) // LSTM_BATCH_SIZE
    for split_descriptions in (train_descriptions, valid_descriptions)
)

# Endless batch generators backed by the precomputed image embeddings.
train_generator = data_generator(
    descriptions=train_descriptions,
    num_photos_per_batch=LSTM_BATCH_SIZE,
    img_embeds=all_encoding,
)
valid_generator = data_generator(
    descriptions=valid_descriptions,
    num_photos_per_batch=LSTM_BATCH_SIZE,
    img_embeds=all_encoding,
)

### Error debugging

In [None]:
# Collect the training images at positions 2913..2921 (1-based) and print their paths.
problematic = {}
for position, (key, desc_list) in enumerate(train_descriptions.items(), start=1):
    if position > 2912:
        problematic[key] = desc_list
        print(key)
    if position > 2920:
        break

# Build one batch per collected image and print it for inspection.
problematic_generator = data_generator(problematic, 1, all_encoding)
for _batch_number in range(len(problematic)):
    print(next(problematic_generator))

In [None]:
# Expected second dimension of each array in a batch. The sequence and
# one-hot widths come from the values data_generator pads/encodes with,
# so they no longer need to be edited by hand when the vocabulary changes.
img_size = 1664  # image embedding width; presumably DenseNet169 features -- confirm
phrase_size = MAX_LENGTH_SEQ
word_size = vocab_size

for i in range(train_steps):
    if i % 20 == 0:
        print(i)
    b = next(train_generator)
    b_0_0, b_0_1, b_1 = b[0][0].shape, b[0][1].shape, b[1].shape
    print(b_0_0, b_0_1, b_1)
    # An empty batch produces 1-D shapes; treat that as an error instead of
    # failing with an IndexError on shape[1].
    shapes_ok = (
        len(b_0_0) == 2 and len(b_0_1) == 2 and len(b_1) == 2
        and b_0_0[1] == img_size
        and b_0_1[1] == phrase_size
        and b_1[1] == word_size
    )
    if not shapes_ok:
        print("ERROR")
        break

print(b_0_0, b_0_1, b_1)

### 2.  Create architecture

In [None]:
# Load the 200-dimensional GloVe vectors into a word -> vector dictionary.
glove_dir = os.path.join(ROOT_CAPTIONING,'glove6b')
embeddings_index = {}

# The context manager closes the file even if parsing a line fails.
with open(os.path.join(glove_dir, 'glove.6B.200d.txt'), encoding="utf-8") as f:
    for line in tqdm(f):
        values = line.split()
        # Ignore blank lines instead of failing on values[0].
        if not values:
            continue
        word = values[0]
        coefs = np.asarray(values[1:], dtype='float32')
        embeddings_index[word] = coefs

print(f'Found {len(embeddings_index)} word vectors.')

In [None]:
# id -> token lookup taken from the vectorization layer.
vocab = vectorization.get_vocabulary()
index_lookup = dict(enumerate(vocab))

# Row i holds the GloVe vector of token i; tokens without a vector stay all-zero.
embedding_dim = 200
embedding_matrix = np.zeros((vocab_size, embedding_dim))
for token_id, token in enumerate(vocab):
    glove_vector = embeddings_index.get(token)
    if glove_vector is None:
        continue
    embedding_matrix[token_id] = glove_vector
print(embedding_matrix.shape)

In [None]:
# Disabled alternative kept for reference: compute the image embedding inside
# the model with a frozen DenseNet169 instead of feeding precomputed features.
"""
inputs1 = Input(shape=(299,299,3,))
encode_model = DenseNet169(weights='imagenet', include_top=False, input_tensor=inputs1)
encode_model.trainable = False
img_emb = GlobalAveragePooling2D()(encode_model.output)

fe1 = Dropout(0.5)(img_emb)
fe2 = Dense(256, activation='relu')(fe1)
"""

# Image branch: precomputed 1664-d embedding -> dropout -> 256-d projection.
inputs1 = Input(shape=(1664,))
fe2 = Dense(256, activation='relu')(Dropout(0.5)(inputs1))

# Text branch: token ids -> embedding -> dropout -> LSTM state.
inputs2 = Input(shape=(MAX_LENGTH_SEQ,))
token_embeddings = Embedding(
    vocab_size, embedding_dim, mask_zero=True, name="pretrained_glove"
)(inputs2)
se3 = LSTM(256)(Dropout(0.5)(token_embeddings))

# Merge both branches by addition and predict a distribution over the vocabulary.
merged = add([fe2, se3])
hidden = Dense(256, activation='relu')(merged)
outputs = Dense(vocab_size, activation='softmax')(hidden)
caption_model = Model(inputs=[inputs1, inputs2], outputs=outputs)
#caption_model.summary()

# Load the GloVe matrix into the embedding layer and freeze it before compiling.
glove_layer = caption_model.get_layer("pretrained_glove")
glove_layer.set_weights([embedding_matrix])
glove_layer.trainable = False
caption_model.compile(loss='categorical_crossentropy', optimizer='adam')

### 3. Train

In [None]:
# Where the weights are written after this training run.
model_path = os.path.join(".", 'caption-model-lstm-7-epochs.hdf5')
# Checkpoint from a previous run that training resumes from, when present.
pretrained_weights_path = "./caption-model-lstm.hdf5"

# Resume only if the checkpoint exists, so a fresh kernel / fresh working
# directory can still run this cell end to end.
if os.path.exists(pretrained_weights_path):
    caption_model.load_weights(pretrained_weights_path)
else:
    print(f"No checkpoint found at {pretrained_weights_path}; training from scratch.")

caption_model.fit(
    train_generator,
    epochs=6,
    steps_per_epoch=train_steps,
    verbose=1,
    validation_data=valid_generator,
    validation_steps=valid_steps
)

caption_model.save_weights(model_path)

### 4. Evaluate

### 5. Generate captions

In [None]:
vocab = vectorization.get_vocabulary()
index_lookup = dict(zip(range(len(vocab)), vocab))
max_decoded_sentence_length = MAX_LENGTH_SEQ - 1
valid_images = list(valid_descriptions.keys())

def generate_caption():
    """Show a random validation image and print the LSTM model's greedy caption."""
    # Pick a random validation image and report which one it is.
    sample_img = np.random.choice(valid_images)
    print(sample_img)

    # Display the image read from disk.
    pixels = decode_and_resize(sample_img).numpy()
    plt.imshow(pixels.clip(0, 255).astype(np.uint8))
    plt.show()

    # Precomputed embedding, looked up by file name, with a leading batch axis.
    image_name = sample_img.split("/")[-1]
    img_embed = np.expand_dims(all_encoding[image_name], axis=0)

    # Greedy decoding: append the most likely next token until STOP shows up
    # or the step limit is reached.
    decoded_caption = START
    for _step in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])
        predictions = caption_model.predict([img_embed, tokenized_caption])
        sampled_token = index_lookup[np.argmax(predictions[0, :])]
        decoded_caption = decoded_caption + " " + sampled_token
        if sampled_token == STOP:
            break

    # Drop the sentinel words before printing.
    cleaned_caption = decoded_caption.replace(START, "").replace(STOP, "").strip()
    print("Predicted Caption: ", cleaned_caption)

# Check predictions for three random samples
for _sample_number in range(3):
    generate_caption()

# TRANSFORMER

### 1.  Create tensorflow dataset

In [None]:
def decode_and_resize(img_path):
    """Read the JPEG at ``img_path`` and return it resized to IMAGE_SIZE as float32.

    NOTE(review): same definition as the one in the LSTM section.
    """
    file_contents = tf.io.read_file(img_path)
    rgb_image = tf.image.decode_jpeg(file_contents, channels=3)
    rgb_image = tf.image.resize(rgb_image, IMAGE_SIZE)
    return tf.image.convert_image_dtype(rgb_image, tf.float32)


def process_input(img_path, captions):
    """Turn an (image path, captions) pair into (image tensor, token ids)."""
    image_tensor = decode_and_resize(img_path)
    caption_ids = vectorization(captions)
    return image_tensor, caption_ids


def make_dataset(images, captions):
    """Build a shuffled, batched and prefetched tf.data pipeline.

    Args:
        images: list of image paths.
        captions: list with the captions of each image, aligned with ``images``.

    Returns:
        A ``tf.data.Dataset`` of ``process_input`` outputs in batches of BATCH_SIZE.
    """
    return (
        tf.data.Dataset.from_tensor_slices((images, captions))
        .shuffle(len(images))
        .map(process_input, num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )


# Each dataset pairs the image paths (dict keys) with their caption lists (dict values).
train_image_paths = list(train_descriptions.keys())
train_caption_lists = list(train_descriptions.values())
train_dataset = make_dataset(train_image_paths, train_caption_lists)

valid_image_paths = list(valid_descriptions.keys())
valid_caption_lists = list(valid_descriptions.values())
valid_dataset = make_dataset(valid_image_paths, valid_caption_lists)

### 2.  Create architecture

In [None]:
def get_cnn_model():
    """Return a frozen EfficientNetB0 whose feature map is reshaped to (-1, channels)."""
    backbone = efficientnet.EfficientNetB0(
        input_shape=(*IMAGE_SIZE, 3), include_top=False, weights="imagenet",
    )
    # The feature extractor is not trained.
    backbone.trainable = False
    feature_map = backbone.output
    # Collapse the spatial axes into one, keeping the channel axis last.
    flattened = layers.Reshape((-1, feature_map.shape[-1]))(feature_map)
    return keras.models.Model(backbone.input, flattened)


class TransformerEncoderBlock(layers.Layer):
    """Encoder block: layer norm -> dense projection -> self-attention -> layer norm."""

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        """Create the sub-layers.

        Args:
            embed_dim: width of the dense projection and attention key size.
            dense_dim: stored on the instance; not used by ``call``.
            num_heads: number of attention heads.
        """
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        # Sub-layers are created in this order on purpose; keep it stable.
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.0
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.dense_1 = layers.Dense(embed_dim, activation="relu")

    def call(self, inputs, training, mask=None):
        """Encode ``inputs``; ``mask`` is accepted but not used."""
        projected = self.dense_1(self.layernorm_1(inputs))
        attended = self.attention_1(
            query=projected,
            value=projected,
            key=projected,
            attention_mask=None,
            training=training,
        )
        # Residual connection followed by the second normalisation.
        return self.layernorm_2(projected + attended)

class PositionalEmbedding(layers.Layer):
    """Sum of scaled token embeddings and learned position embeddings."""

    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        """Create one embedding table for tokens and one for positions."""
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        # Token embeddings are multiplied by sqrt(embed_dim) in call().
        self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))

    def call(self, inputs):
        """Embed the token ids in ``inputs`` and add their position embeddings."""
        seq_len = tf.shape(inputs)[-1]
        position_ids = tf.range(start=0, limit=seq_len, delta=1)
        scaled_tokens = self.token_embeddings(inputs) * self.embed_scale
        return scaled_tokens + self.position_embeddings(position_ids)

    def compute_mask(self, inputs, mask=None):
        """Mark every non-zero token id as a valid (non-padding) position."""
        return tf.math.not_equal(inputs, 0)


class TransformerDecoderBlock(layers.Layer):
    """Decoder block: embedding, causal self-attention, cross-attention over the
    encoder output, feed-forward network and a softmax over the vocabulary."""

    def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
        """Create the sub-layers.

        Args:
            embed_dim: embedding width and attention key size.
            ff_dim: width of the hidden feed-forward layer.
            num_heads: number of heads of both attention layers.
        """
        super().__init__(**kwargs)

        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
        self.ffn_layer_2 = layers.Dense(embed_dim)

        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()

        self.embedding = PositionalEmbedding(
            embed_dim=embed_dim, sequence_length=MAX_LENGTH_SEQ, vocab_size=vocab_size
        )
        self.out = layers.Dense(vocab_size, activation="softmax")

        self.dropout_1 = layers.Dropout(0.3)
        self.dropout_2 = layers.Dropout(0.5)
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, training, mask=None):
        """Predict a vocabulary distribution for every position of ``inputs``.

        Args:
            inputs: token ids of the caption prefix.
            encoder_outputs: output of the encoder block for the image.
            training: forwarded to the attention and dropout layers.
            mask: optional boolean padding mask for ``inputs``. When omitted,
                only the causal mask is applied.

        Returns:
            Softmax probabilities over the vocabulary for each position.
        """
        inputs = self.embedding(inputs)
        causal_mask = self.get_causal_attention_mask(inputs)

        # Defaults for the mask-less case; previously both names were unbound
        # when `mask` was None and the attention calls raised an error.
        padding_mask = None
        combined_mask = causal_mask
        if mask is not None:
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            # A position is visible only if it is both non-padding and not in the future.
            combined_mask = tf.minimum(combined_mask, causal_mask)

        # Masked self-attention over the caption prefix.
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=combined_mask,
            training=training,
        )

        out_1 = self.layernorm_1(inputs + attention_output_1)

        # Cross-attention: caption positions attend to the encoded image.
        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
            training=training,
        )

        out_2 = self.layernorm_2(out_1 + attention_output_2)

        # Position-wise feed-forward network with a residual connection.
        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
        ffn_out = self.dropout_2(ffn_out, training=training)

        preds = self.out(ffn_out)

        return preds

    def get_causal_attention_mask(self, inputs):
        """Return a (batch, seq, seq) int32 lower-triangular mask for ``inputs``."""
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        # Entry (i, j) is 1 when position i may attend to position j (j <= i).
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        # Repeat the single mask once per batch element.
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0,
        )
        return tf.tile(mask, mult)

class ImageCaptioningModel(keras.Model):
    """Captioning model made of a CNN feature extractor, an encoder and a decoder.

    Training and evaluation are customised: every image comes with
    ``num_captions_per_image`` captions and each caption is processed in turn.
    """

    def __init__(
        self, cnn_model, encoder, decoder, num_captions_per_image=5, image_aug=None,
    ):
        """Store the sub-models, the metric trackers and the optional augmentation."""
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.acc_tracker = keras.metrics.Mean(name="accuracy")
        self.num_captions_per_image = num_captions_per_image
        self.image_aug = image_aug

    def calculate_loss(self, y_true, y_pred, mask):
        """Average the compiled loss over the positions selected by ``mask``."""
        per_token_loss = self.loss(y_true, y_pred)
        weights = tf.cast(mask, dtype=per_token_loss.dtype)
        masked_loss = per_token_loss * weights
        return tf.reduce_sum(masked_loss) / tf.reduce_sum(weights)

    def calculate_accuracy(self, y_true, y_pred, mask):
        """Fraction of masked positions whose arg-max prediction equals the target."""
        hits = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        hits = tf.cast(tf.math.logical_and(mask, hits), dtype=tf.float32)
        valid_positions = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(hits) / tf.reduce_sum(valid_positions)

    def _compute_caption_loss_and_acc(self, img_embed, batch_seq, training=True):
        """Run encoder and decoder for one caption per image; return (loss, accuracy)."""
        encoder_out = self.encoder(img_embed, training=training)
        # Teacher forcing: the decoder reads tokens [0, n-1) and predicts tokens [1, n).
        decoder_input = batch_seq[:, :-1]
        targets = batch_seq[:, 1:]
        # Token id 0 is excluded from loss and accuracy.
        mask = tf.math.not_equal(targets, 0)
        predictions = self.decoder(
            decoder_input, encoder_out, training=training, mask=mask
        )
        return (
            self.calculate_loss(targets, predictions, mask),
            self.calculate_accuracy(targets, predictions, mask),
        )

    def train_step(self, batch_data):
        """One optimisation pass over a batch, with one weight update per caption."""
        batch_img, batch_seq = batch_data
        total_loss = 0
        total_acc = 0

        if self.image_aug:
            batch_img = self.image_aug(batch_img)

        # Image features are computed once and shared by all captions.
        img_embed = self.cnn_model(batch_img)

        for caption_idx in range(self.num_captions_per_image):
            with tf.GradientTape() as tape:
                loss, acc = self._compute_caption_loss_and_acc(
                    img_embed, batch_seq[:, caption_idx, :], training=True
                )

            total_loss += loss
            total_acc += acc

            # Kept after the forward pass, exactly as in the original ordering.
            train_vars = (
                self.encoder.trainable_variables + self.decoder.trainable_variables
            )
            grads = tape.gradient(loss, train_vars)
            self.optimizer.apply_gradients(zip(grads, train_vars))

        # The loss is tracked as a sum over captions, the accuracy as their mean.
        total_acc /= float(self.num_captions_per_image)
        self.loss_tracker.update_state(total_loss)
        self.acc_tracker.update_state(total_acc)

        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    def test_step(self, batch_data):
        """Evaluate a batch: same computation as train_step without weight updates."""
        batch_img, batch_seq = batch_data
        total_loss = 0
        total_acc = 0

        # Image features are computed once and shared by all captions.
        img_embed = self.cnn_model(batch_img)

        for caption_idx in range(self.num_captions_per_image):
            loss, acc = self._compute_caption_loss_and_acc(
                img_embed, batch_seq[:, caption_idx, :], training=False
            )
            total_loss += loss
            total_acc += acc

        total_acc /= float(self.num_captions_per_image)

        self.loss_tracker.update_state(total_loss)
        self.acc_tracker.update_state(total_acc)

        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    @property
    def metrics(self):
        """Trackers listed here so that `reset_states()` can be called automatically."""
        return [self.loss_tracker, self.acc_tracker]

    
# Random flip / rotation / contrast applied to the images during training.
augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.2),
    layers.RandomContrast(0.3),
]
image_augmentation = keras.Sequential(augmentation_layers)

# Assemble the captioning model from its three parts.
cnn_model = get_cnn_model()
encoder = TransformerEncoderBlock(
    embed_dim=EMBEDDING_DIM, dense_dim=FF_DIM, num_heads=1
)
decoder = TransformerDecoderBlock(
    embed_dim=EMBEDDING_DIM, ff_dim=FF_DIM, num_heads=2
)
transformer_caption_model = ImageCaptioningModel(
    cnn_model=cnn_model,
    encoder=encoder,
    decoder=decoder,
    image_aug=image_augmentation,
)

### 3.  Train

In [None]:
# Per-token loss on probabilities (not logits); reduction is left to the
# model, which applies its own mask.
loss_options = dict(from_logits=False, reduction="none")
cross_entropy = keras.losses.SparseCategoricalCrossentropy(**loss_options)

# Stop after 3 epochs without improvement and restore the best weights.
early_stopping = keras.callbacks.EarlyStopping(
    patience=3,
    restore_best_weights=True,
)


# Learning Rate Scheduler for the optimizer
class LRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warm-up to a fixed learning rate, constant afterwards."""

    def __init__(self, post_warmup_learning_rate, warmup_steps):
        """Store the target learning rate and the length of the warm-up."""
        super().__init__()
        self.post_warmup_learning_rate = post_warmup_learning_rate
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate to use at optimizer step ``step``."""
        step_f = tf.cast(step, tf.float32)
        warmup_f = tf.cast(self.warmup_steps, tf.float32)
        # Fraction of the target rate reached so far during warm-up.
        ramped_rate = self.post_warmup_learning_rate * (step_f / warmup_f)
        return tf.cond(
            step_f < warmup_f,
            lambda: ramped_rate,
            lambda: self.post_warmup_learning_rate,
        )


# Create a learning rate schedule.
# Steps are counted in batches (len(train_dataset)), not in images: counting
# images made the warm-up longer than the whole training run, so the target
# learning rate was never reached.
num_train_steps = len(train_dataset) * EPOCHS
num_warmup_steps = num_train_steps // 15
lr_schedule = LRSchedule(post_warmup_learning_rate=1e-4, warmup_steps=num_warmup_steps)

# Compile the model
transformer_caption_model.compile(optimizer=tf.keras.optimizers.Adam(lr_schedule), loss=cross_entropy)

# Destination of the trained weights.
model_path = os.path.join(".", './caption-model-transformer.hdf5')

# Train with early stopping on the validation set, then persist the weights.
fit_options = dict(
    epochs=EPOCHS,
    validation_data=valid_dataset,
    callbacks=[early_stopping],
)
transformer_caption_model.fit(train_dataset, **fit_options)

transformer_caption_model.save_weights(model_path)

### 4.  Evaluate

In [None]:
import nltk

# Number of test images to evaluate; set to None to use the whole test split.
max_eval_samples = 10

# Decoding lookups, rebuilt here so this cell does not rely on the LSTM section.
vocab = vectorization.get_vocabulary()
index_lookup = dict(zip(range(len(vocab)), vocab))
max_decoded_sentence_length = MAX_LENGTH_SEQ - 1

# Sentinel words are removed from hypotheses and references before scoring.
special_tokens = {START, STOP}

corpus_hypothesis = list()
corpus_references = list()

for idx, (image, desc) in enumerate(tqdm(test_descriptions.items())):
    if max_eval_samples is not None and idx >= max_eval_samples:
        break

    # Read the image from the disk
    sample_img = decode_and_resize(image)

    # Pass the image to the CNN
    img = tf.expand_dims(sample_img, 0)
    img = transformer_caption_model.cnn_model(img)

    # Pass the image features to the Transformer encoder
    encoded_img = transformer_caption_model.encoder(img, training=False)

    # Generate the caption greedily with the Transformer decoder
    decoded_caption = START
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = transformer_caption_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        if sampled_token == STOP:
            break
        decoded_caption += " " + sampled_token

    corpus_hypothesis.append(
        [token for token in decoded_caption.split() if token not in special_tokens]
    )
    corpus_references.append(
        [[token for token in d.split() if token not in special_tokens] for d in desc]
    )

n_samples = len(corpus_hypothesis)

# corpus_bleu already aggregates over all samples, so it is called once per
# n-gram order after the loop instead of being summed inside it.
corpus_bleu = nltk.translate.bleu_score.corpus_bleu
B_1_score = corpus_bleu(corpus_references, corpus_hypothesis, weights=(1.0, 0.0, 0.0, 0.0))
B_2_score = corpus_bleu(corpus_references, corpus_hypothesis, weights=(0.5, 0.5, 0.0, 0.0))
B_3_score = corpus_bleu(corpus_references, corpus_hypothesis, weights=(1.0/3.0, 1.0/3.0, 1.0/3.0, 0.0))
B_4_score = corpus_bleu(corpus_references, corpus_hypothesis, weights=(0.25, 0.25, 0.25, 0.25))

print(f"Samples evaluated: {n_samples}")
print(f"BLEU-1: {B_1_score}")
print(f"BLEU-2: {B_2_score}")
print(f"BLEU-3: {B_3_score}")
print(f"BLEU-4: {B_4_score}")

### 5. Generate captions

In [None]:
vocab = vectorization.get_vocabulary()
index_lookup = dict(zip(range(len(vocab)), vocab))
max_decoded_sentence_length = MAX_LENGTH_SEQ - 1
valid_images = list(valid_descriptions.keys())

def generate_caption():
    """Show a random validation image and print the Transformer's greedy caption.

    NOTE(review): this replaces the LSTM ``generate_caption`` defined earlier.
    """
    # Pick a random validation image and report which one it is.
    image_path = np.random.choice(valid_images)
    print(image_path)

    # Read the image from disk and display it.
    image_tensor = decode_and_resize(image_path)
    plt.imshow(image_tensor.numpy().clip(0, 255).astype(np.uint8))
    plt.show()

    # CNN features of the single image, then the Transformer encoder.
    features = transformer_caption_model.cnn_model(tf.expand_dims(image_tensor, 0))
    encoded_img = transformer_caption_model.encoder(features, training=False)

    # Greedy decoding: at step i take the most likely token at position i.
    decoded_caption = START
    for step in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = transformer_caption_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token = index_lookup[np.argmax(predictions[0, step, :])]
        decoded_caption = decoded_caption + " " + sampled_token
        if sampled_token == STOP:
            break

    # Drop the sentinel words before printing.
    cleaned_caption = decoded_caption.replace(START, "").replace(STOP, "")
    print("Predicted Caption:", cleaned_caption.strip())


# Check the prediction for one random validation sample
# (run the call again to see further samples).
generate_caption()