## Detailed Article Explaination

The detailed code explanation for this article is available at the following link:

https://www.daniweb.com/programming/computer-science/tutorials/541308/tensorflow-keras-sequence-data-generator-for-multimodal-classification

For my other articles for Daniweb.com, please see this link:

https://www.daniweb.com/members/1235222/usmanmalik57/posts

In [None]:
from google.colab import files

! pip install -q kaggle

files.upload()

! mkdir ~/.kaggle

! cp kaggle.json ~/.kaggle/

! chmod 600 ~/.kaggle/kaggle.json

In [2]:
! kaggle datasets download -d hammadjavaid/6992-labeled-meme-images-dataset

Downloading 6992-labeled-meme-images-dataset.zip to /content
100% 693M/693M [00:25<00:00, 33.7MB/s]
100% 693M/693M [00:25<00:00, 29.0MB/s]


In [3]:
! unzip /content/6992-labeled-meme-images-dataset.zip -d multimodal-memes

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: multimodal-memes/images/images/image_2793.jpg  
  inflating: multimodal-memes/images/images/image_2794.png  
  inflating: multimodal-memes/images/images/image_2795.png  
  inflating: multimodal-memes/images/images/image_2796.png  
  inflating: multimodal-memes/images/images/image_2797.jpg  
  inflating: multimodal-memes/images/images/image_2798.png  
  inflating: multimodal-memes/images/images/image_2799.png  
  inflating: multimodal-memes/images/images/image_28.jpg  
  inflating: multimodal-memes/images/images/image_280.jpg  
  inflating: multimodal-memes/images/images/image_2800.png  
  inflating: multimodal-memes/images/images/image_2801.png  
  inflating: multimodal-memes/images/images/image_2802.png  
  inflating: multimodal-memes/images/images/image_2803.png  
  inflating: multimodal-memes/images/images/image_2804.png  
  inflating: multimodal-memes/images/images/image_2805.jpg  
  inflating: multimodal

## Importing Required Libraries

In [4]:
! pip install accelerate -U
! pip install datasets transformers[sentencepiece]

Collecting accelerate
  Downloading accelerate-0.26.1-py3-none-any.whl (270 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/270.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━[0m [32m194.6/270.9 kB[0m [31m5.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m270.9/270.9 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
Successfully installed accelerate-0.26.1
Collecting datasets
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  D

In [19]:
import pandas as pd
import os
import numpy as np

import tensorflow as tf

from transformers import AutoTokenizer, TFBertModel
from transformers import AutoImageProcessor, TFViTModel


from keras.utils import Sequence
from keras.preprocessing.sequence import pad_sequences
from keras.layers import Input, Dense, Dropout, Concatenate
from keras.callbacks import ModelCheckpoint
from keras.models import load_model, Model
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

## Importing and Preprocessing the Dataset

In [20]:
# dataset download link
# https://www.kaggle.com/datasets/hammadjavaid/6992-labeled-meme-images-dataset

labels_df = pd.read_csv("/content/multimodal-memes/labels.csv")
labels_df.head()

Unnamed: 0.1,Unnamed: 0,image_name,text_ocr,text_corrected,overall_sentiment
0,0,image_1.jpg,LOOK THERE MY FRIEND LIGHTYEAR NOW ALL SOHALIK...,LOOK THERE MY FRIEND LIGHTYEAR NOW ALL SOHALIK...,very_positive
1,1,image_2.jpeg,The best of #10 YearChallenge! Completed in le...,The best of #10 YearChallenge! Completed in le...,very_positive
2,2,image_3.JPG,Sam Thorne @Strippin ( Follow Follow Saw every...,Sam Thorne @Strippin ( Follow Follow Saw every...,positive
3,3,image_4.png,10 Year Challenge - Sweet Dee Edition,10 Year Challenge - Sweet Dee Edition,positive
4,4,image_5.png,10 YEAR CHALLENGE WITH NO FILTER 47 Hilarious ...,10 YEAR CHALLENGE WITH NO FILTER 47 Hilarious ...,neutral


In [21]:
image_folder_path = '/content/multimodal-memes/images/images'
labels_df['image_path'] = labels_df['image_name'].apply(lambda x: os.path.join(image_folder_path, x))
labels_df = labels_df[labels_df['text_corrected'].notna() & (labels_df['text_corrected'] != '')]
labels_df = labels_df.filter(["text_corrected", "image_path", "overall_sentiment"])

print("====================================")
print(f'Dataset Shape: {labels_df.shape}')
print("====================================")
print(f'Sentiments % Per Category:\n{labels_df.overall_sentiment.value_counts(normalize=True) * 100}%')
print("====================================")

labels_df.head()

Dataset Shape: (6987, 3)
Sentiments % Per Category:
positive         44.697295
neutral          31.487047
very_positive    14.784600
negative          6.869901
very_negative     2.161156
Name: overall_sentiment, dtype: float64%


Unnamed: 0,text_corrected,image_path,overall_sentiment
0,LOOK THERE MY FRIEND LIGHTYEAR NOW ALL SOHALIK...,/content/multimodal-memes/images/images/image_...,very_positive
1,The best of #10 YearChallenge! Completed in le...,/content/multimodal-memes/images/images/image_...,very_positive
2,Sam Thorne @Strippin ( Follow Follow Saw every...,/content/multimodal-memes/images/images/image_...,positive
3,10 Year Challenge - Sweet Dee Edition,/content/multimodal-memes/images/images/image_...,positive
4,10 YEAR CHALLENGE WITH NO FILTER 47 Hilarious ...,/content/multimodal-memes/images/images/image_...,neutral


In [22]:
labels_df['overall_sentiment'] = labels_df['overall_sentiment'].replace({'very_positive': 'positive', 'very_negative': 'negative'})
labels_df = labels_df.sample(frac=1).reset_index(drop=True)

In [23]:
X = labels_df.drop('overall_sentiment', axis=1)
y = labels_df["overall_sentiment"]

# convert labels to one-hot encoded vectors
y = pd.get_dummies(y)

In [24]:
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.2, random_state=42)
X_test, X_val, y_test, y_val = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)


## Defining Transformer Models for Text and Image Data

In [25]:
## importing text model and tokenizer

bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert_model = TFBertModel.from_pretrained("bert-base-uncased")

for layer in bert_model.layers[:-4]:
    layer.trainable = False

## importing image model and tokenizer

image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
vit_model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

for layer in vit_model.layers[:-4]:
    layer.trainable = False

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

## Creating Keras Sequence Data Generator for Batch Processing

In [26]:
class MultiModalDataGenerator(Sequence):

    def __init__(self, df, labels, tokenizer, image_processor, batch_size=32, max_length=128):
        self.df = df
        self.labels_df = labels
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.batch_size = batch_size
        self.max_length = max_length

    def __len__(self):
        # Number of batches per epoch
        return int(np.ceil(len(self.df) / float(self.batch_size)))

    def __getitem__(self, idx):
        # Batch indices
        batch_indices = self.df.index[idx * self.batch_size:(idx + 1) * self.batch_size]

        # Initialize lists to store data
        batch_texts = []
        batch_images = []
        batch_labels = []

        # Loop over each index in the batch
        for i in batch_indices:
            # Append text
            batch_texts.append(self.df.at[i, 'text_corrected'])  # Replace 'text_column' with the name of your text column
            # Append Image paths

            batch_images.append(Image.open(self.df.at[i, 'image_path']).convert("RGB"))

            # Fetch labels
            label_values = self.labels_df.loc[i].values
            batch_labels.append(label_values)

        # Tokenize text data in the batch
        tokenized_data = self.tokenizer(batch_texts, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="tf")

        # Process images

        processed_images = [self.image_processor(images=image, return_tensors="tf") for image in batch_images]
        image_tensors = tf.concat([img['pixel_values'] for img in processed_images], axis=0)


        # Convert labels to numpy array
        batch_labels = np.array(batch_labels, dtype='float32')

        final_features = {'input_ids': tokenized_data['input_ids'],
                          'attention_mask': tokenized_data['attention_mask'],
                          'image_input': image_tensors}
        return final_features, batch_labels


In [27]:
max_text_length = 128
batch_size = 8

train_generator = MultiModalDataGenerator(X_train,
                                y_train,
                                bert_tokenizer,
                                image_processor,
                                batch_size,
                                max_text_length)

test_generator = MultiModalDataGenerator(X_test,
                                y_test,
                                bert_tokenizer,
                                image_processor,
                                batch_size,
                                max_text_length)

val_generator = MultiModalDataGenerator(X_val,
                              y_val,
                              bert_tokenizer,
                              image_processor,
                              batch_size,
                              max_text_length)

In [28]:
# results = next(iter(train_generator))

# results[0]['image_input']

In [29]:
# output = vit_model(results[0]['images'])
# outputs.pooler_output.shape

## Defining the Keras Multimodal Classifier

In [30]:
# Define input layers for text
input_ids = Input(shape=(None,), dtype=tf.int32, name="input_ids")
attention_mask = Input(shape=(None,), dtype=tf.int32, name="attention_mask")

# Define input layer for images
image_input = Input(shape=(3, 224, 224), dtype=tf.float32, name="image_input")

# Get the output of BERT model
bert_outputs = bert_model(input_ids, attention_mask=attention_mask)
pooled_output = bert_outputs.pooler_output

# Get the output of ViT model
vit_outputs = vit_model(image_input)
vit_pooled_output = vit_outputs.pooler_output

# Concatenate the outputs from BERT and ViT
concatenated_outputs = Concatenate()([pooled_output, vit_pooled_output])


# Add additional layers for fine-tuning
x = Dense(512, activation='relu')(concatenated_outputs)
x = Dropout(0.1)(x)
final_output = tf.keras.layers.Dense(3, activation='softmax')(x)

# Create the model
model = Model(inputs=[input_ids, attention_mask, image_input], outputs=final_output)

adam_optimizer = Adam(learning_rate=2e-5)

# Compile the model
model.compile(optimizer = adam_optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])


In [31]:
# Define the checkpoint callback
checkpoint = ModelCheckpoint(
    'best_model.h5',
    monitor='val_accuracy',
    verbose=1,
    save_best_only=True,
    mode='max',
    save_weights_only=False
)

# Train the model
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=5,
    callbacks=[checkpoint],
    verbose=1
)

Epoch 1/5



Epoch 1: val_accuracy improved from -inf to 0.58083, saving model to best_model.h5


  saving_api.save_model(


Epoch 2/5
Epoch 2: val_accuracy did not improve from 0.58083
Epoch 3/5
Epoch 3: val_accuracy did not improve from 0.58083
Epoch 4/5
Epoch 4: val_accuracy did not improve from 0.58083
Epoch 5/5
Epoch 5: val_accuracy did not improve from 0.58083


## Making Predictions and Evaluating Model Performance

In [32]:
# Load the model, including the custom TFBertModel and TFViTModel layers
custom_objects = {"TFBertModel": TFBertModel, "TFViTModel": TFViTModel}
best_model = load_model('best_model.h5', custom_objects=custom_objects)


predictions = best_model.predict(test_generator)

# convert predicitons to binary values
predictions = (predictions == predictions.max(axis=1)[:, None]).astype(int)

# printing results
print(classification_report(y_test, predictions))
print(f"Accuracy score: {accuracy_score(y_test, predictions)}")

              precision    recall  f1-score   support

           0       0.00      0.00      0.00        65
           1       0.00      0.00      0.00       231
           2       0.58      1.00      0.73       403

   micro avg       0.58      0.58      0.58       699
   macro avg       0.19      0.33      0.24       699
weighted avg       0.33      0.58      0.42       699
 samples avg       0.58      0.58      0.58       699

Accuracy score: 0.5765379113018598


  _warn_prf(average, modifier, msg_start, len(result))
