Multimodal entailment involves analysis of text, images, audio, and video data sources to determine whether one piece of information contradicts another or whether a given piece of information implies the other. This is applied in social media content moderation, where platform operators audit and moderate content.

In [1]:
!pip install -q tensorflow_text

In [2]:
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import random
import math
from skimage.io import imread
from skimage.transform import resize
from PIL import Image
import os
import matplotlib.pyplot as plt

os.environ['KERAS_BACKEND'] = 'jax'

import keras
import keras_hub
from keras.utils import PyDataset

In [6]:
# Map every textual entailment label to the integer class index used as the
# training target; the enumeration order fixes the indices 0, 1, 2.
label_map = {
    label_name: class_index
    for class_index, label_name in enumerate(("Contradictory", "Implies", "NoEntailment"))
}

In [4]:
# Collect dataset. With `untar=True` Keras extracts the archive; the returned
# path is used below as the directory holding `<tweet id>.<ext>` image files.
image_base_path = keras.utils.get_file(
    "tweet_images",
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz",
    untar=True,
)

# Only a subset of the CSV is used to keep the example fast.
NUM_SAMPLES = 1000

# Read dataset and keep the first NUM_SAMPLES rows for preprocessing.
df = pd.read_csv(
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv"
).iloc[0:NUM_SAMPLES]

# Fixed random_state so the preview is identical on every re-run.
df.sample(10, random_state=42)

Downloading data from https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz
[1m344273442/344273442[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 0us/step


Unnamed: 0,id_1,text_1,image_1,id_2,text_2,image_2,label
927,1334686373266198529,"Solar-powered holiday lights are convenient, m...",http://pbs.twimg.com/media/EoXCLslXIAMBFt5.jpg,1339307380450988034,"Solar-powered holiday lights are convenient, m...",http://pbs.twimg.com/media/EpYs9BaWEAAa2gI.jpg,NoEntailment
924,1378332375923499012,I just added Invincible (2021) to my library! ...,http://pbs.twimg.com/media/EyDR_bjXEAQDjeb.jpg,1380680917585383425,I've just watched episode S01 | E05 of Invinci...,http://pbs.twimg.com/media/Eykp-lFWUAU5PHD.jpg,NoEntailment
247,1353726537741312001,Total #COVID19 cases by age group (change from...,http://pbs.twimg.com/media/EslnEuJW8AESfdu.jpg,1360249815418945536,Total #COVID19 cases by age group (change from...,http://pbs.twimg.com/media/EuCT42TXIAAeWh8.jpg,NoEntailment
449,1373696320612003843,Heard Island and McDonald Islands still numero...,http://pbs.twimg.com/media/ExBZhhlW8AQBw6u.jpg,1374421735206887425,Heard Island and McDonald Islands still leads!...,http://pbs.twimg.com/media/ExLtSJfUcAwW9E8.jpg,NoEntailment
938,1364042142243504131,#NHL Impact Card for New York Islanders on 202...,http://pbs.twimg.com/media/Eu4NGaTWQAEeQ_f.png,1374174084821614602,#NHL Impact Card for New York Islanders on 202...,http://pbs.twimg.com/media/ExIMDFjXAAIufEh.png,NoEntailment
485,1342913748680519688,Lineups 🇸🇪 vs 🇨🇿\n#worldjuniors https://t.co/j...,http://pbs.twimg.com/media/EqLmNwdW8AQ8W-1.jpg,1346232678744526848,Line-ups for the semi-final between Canada and...,http://pbs.twimg.com/media/Eq62ZlcW4AUZ4uH.jpg,NoEntailment
160,1359907916489318405,"$SPY Today (8:30 CST), equities higher, FI mix...",http://pbs.twimg.com/media/Et88sRHXMAEXg-T.jpg,1360234816134905862,"$SPY Today (8:30 CST), equities mixed, FI lowe...",http://pbs.twimg.com/media/EuCGWzfXYAM7Cmg.jpg,Contradictory
469,1347963434550308867,WE'RE HYPED AND READY TO ENJOY #SUPERWILDCARDW...,http://pbs.twimg.com/media/ErTtmZFWMAM6zir.jpg,1350554189416583172,WE'RE HYPED AND READY FOR #DIVISIONALROUND 🏈 ...,http://pbs.twimg.com/media/Er4h4JaXUAA2zUN.jpg,NoEntailment
643,1363989251629740033,Mon 17:00: Sunny; Temp 2 C; Wind W 25 km/h; Hu...,http://pbs.twimg.com/media/Eu3c_uZXAAg82fz.png,1371071703888101381,Sun 08:00: Mainly Sunny; Temp -15.8 C; Windchi...,http://pbs.twimg.com/media/EwcGctnWgAQqUIA.png,NoEntailment
299,1362935097079521283,1st quarter of the Boys Varsity Basketball gam...,http://pbs.twimg.com/media/EuoePpoXYAEfv-m.png,1364366166282633218,1st quarter of the Girls Varsity Basketball ga...,http://pbs.twimg.com/media/Eu8zzEjUcAAjHU1.jpg,NoEntailment


In [7]:
# We formulate the entailment task as - given pairs of {text_1, image_1} and
# {text_2, image_2}, predict (contradict, entail, or neutral).

# Local image file for each side of the pair: <image_base_path>/<tweet id>.<ext>,
# where <ext> is the extension of the original media URL.
for side in ("1", "2"):
    url_extensions = df[f"image_{side}"].str.rsplit(".", n=1).str[-1]
    df[f"image_{side}_path"] = [
        os.path.join(image_base_path, f"{tweet_id}.{extension}")
        for tweet_id, extension in zip(df[f"id_{side}"], url_extensions)
    ]

# Add label column; fail loudly instead of silently producing NaN targets.
unknown_labels = set(df["label"]) - set(label_map)
if unknown_labels:
    raise ValueError(f"Unexpected labels in dataset: {sorted(unknown_labels)}")
df["label_idx"] = df["label"].map(label_map)

In [None]:
# Visualize sample dataset
def visualize(idx, dataframe=None):
    """Show both images of one pair side by side with their tweet texts and label.

    Args:
        idx: Positional (``iloc``) index of the row to display.
        dataframe: DataFrame with image_1_path, image_2_path, text_1, text_2 and
            label columns. Defaults to the global ``df``.
    """
    import textwrap

    if dataframe is None:
        dataframe = df
    row = dataframe.iloc[idx]

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    for side, ax in enumerate(axes, start=1):
        ax.imshow(plt.imread(row[f"image_{side}_path"]))
        ax.axis("off")
        tweet_text = row[f"text_{side}"]
        # Tweets are long; wrap them so the two titles stay readable and don't overlap.
        ax.set_title(textwrap.fill(f"Image {side}: {tweet_text}", width=50), fontsize=9)
    plt.show()
    print(f"Label: {row['label']}")


for random_idx in random.sample(range(len(df)), k=2):
    visualize(random_idx)


In [None]:
# Train-Test split, stratified so every split keeps the label proportions.
train_df, test_df = train_test_split(
    df, test_size=0.2, stratify=df["label"].values, random_state=42
)

# Validation set carved out of the training portion.
train_df, val_df = train_test_split(
    train_df, test_size=0.05, stratify=train_df["label"].values, random_state=42
)

print(f"Total train examples: {len(train_df)}")
print(f"Total val examples: {len(val_df)}")
print(f"Total test examples: {len(test_df)}")

In [None]:
# Data input pipeline.
# BERT preprocessor producing `token_ids`, `segment_ids` and `padding_mask`.
# `sequence_length` must equal the `shape=(256,)` text inputs declared in the
# model-building cells, otherwise the batches will not fit the model.
text_preprocessor = keras_hub.models.BertTextClassifierPreprocessor.from_preset(
    "bert_base_en_uncased", sequence_length=256
)

In [None]:
# Run the preprocessor on a sample input.
idx = random.choice(range(len(train_df)))
sample_input = train_df.iloc[idx]
sample_text_1 = sample_input["text_1"]
sample_text_2 = sample_input["text_2"]

print(f"Text 1: {sample_text_1}")
print(f"Text 2: {sample_text_2}")

# A tuple of segments is packed into ONE sequence (segment_ids 0 then 1); a plain
# list would be treated as a batch of two unrelated texts. Each segment here is a
# batch of size 1, so every output has shape (1, sequence_length).
processed_text = text_preprocessor(([sample_text_1], [sample_text_2]))

print("Keys                : ", list(processed_text.keys()))
print("Shape Token Ids     : ", processed_text["token_ids"].shape)
print("Token Ids           : ", processed_text["token_ids"][0, :16])
print("Shape Padding Masks : ", processed_text["padding_mask"].shape)
print("Padding Masks       : ", processed_text["padding_mask"][0, :16])
print("Shape Segment Ids   : ", processed_text["segment_ids"].shape)
print("Segment Ids         : ", processed_text["segment_ids"][0, :16])

In [None]:
# Create Keras PyDataset objects from dataframes.
def dataframe_to_dataset(dataframe, batch_size=32, workers=4):
    """Wrap ``dataframe`` in a ``UnifiedPyDataset`` that yields (inputs, labels) batches.

    Args:
        dataframe: DataFrame with image_1_path, image_2_path, text_1, text_2 and
            label_idx columns.
        batch_size: Number of examples per batch.
        workers: Number of parallel loading workers handed to the dataset.

    Returns:
        A ``UnifiedPyDataset`` over ``dataframe``.
    """
    return UnifiedPyDataset(dataframe, batch_size=batch_size, workers=workers)

# Preprocessing utilities
bert_input_features = ['padding_mask', 'segment_ids', 'token_ids']
def preprocess_text(text_1, text_2):
    """Tokenize one (text_1, text_2) pair into flat BERT input arrays.

    The texts are passed as a tuple of two segments so the preprocessor packs
    them into a single sequence (a list would be treated as a batch of two
    separate texts). Each segment is a batch of size 1, hence the final flatten.

    Returns:
        dict mapping each name in ``bert_input_features`` to a 1-D array of
        length ``sequence_length``.
    """
    output = text_preprocessor(([text_1], [text_2]))
    return {feature: keras.ops.reshape(output[feature], [-1]) for feature in bert_input_features}

In [None]:
class UnifiedPyDataset(PyDataset):
    """A Keras-compatible dataset that processes a DataFrame for TensorFlow, JAX, and PyTorch.

    Each item is an ``(inputs, labels)`` batch where ``inputs`` is a dict with:
      * ``image_1`` / ``image_2``: float32 arrays of shape (batch, H, W, 3) with
        values in [0, 255] (the range ``resnet_v2.preprocess_input`` rescales);
      * ``padding_mask`` / ``segment_ids`` / ``token_ids``: int32 arrays produced
        by the global ``text_preprocessor`` from each (text_1, text_2) pair;
    and ``labels`` holds the ``label_idx`` values of the batch.
    """

    def __init__(
        self,
        df,
        batch_size=32,
        workers=4,
        use_multiprocessing=False,
        max_queue_size=10,
        image_size=(128, 128),
        **kwargs,
    ):
        """
        Args:
            df: pandas DataFrame with image_1_path, image_2_path, text_1, text_2
                and label_idx columns.
            batch_size: Batch size for dataset.
            workers: Number of workers to use for parallel loading (Keras).
            use_multiprocessing: Whether to use multiprocessing.
            max_queue_size: Maximum size of the data queue for parallel loading.
            image_size: (height, width) every image is resized to.
        """
        # Forward the parallel-loading settings so PyDataset actually applies them.
        super().__init__(
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
            **kwargs,
        )
        self.dataframe = df
        self.batch_size = batch_size
        self.image_size = tuple(image_size)

        # Positional numpy views: splits from train_test_split carry shuffled
        # index labels, so batches must be sliced by position, not by label.
        self.image_x_1 = df["image_1_path"].to_numpy()
        self.image_x_2 = df["image_2_path"].to_numpy()
        self.image_y = df["label_idx"].to_numpy()

        self.text_x_1 = df["text_1"].to_numpy()
        self.text_x_2 = df["text_2"].to_numpy()
        self.text_y = self.image_y

    def _load_image(self, file_name):
        """Read one image as a float32 RGB array of shape (*image_size, 3) in [0, 255]."""
        # convert("RGB") normalizes grayscale, palette and RGBA files to 3 channels
        # without rescaling pixel values.
        with Image.open(file_name) as img:
            rgb = img.convert("RGB").resize(self.image_size[::-1])  # PIL wants (w, h)
        return np.asarray(rgb, dtype=np.float32)

    def __getitem__(self, index):
        """
        Fetches a batch of data from the dataset at the given index.
        """
        low = index * self.batch_size
        # Cap upper bound at the dataset length; the last batch may be smaller
        # if the total number of items is not a multiple of batch size.
        high = min(low + self.batch_size, len(self.dataframe))
        batch = slice(low, high)

        image_1 = np.stack([self._load_image(path) for path in self.image_x_1[batch]])
        image_2 = np.stack([self._load_image(path) for path in self.image_x_2[batch]])

        # A tuple of two batched segments packs each (text_1, text_2) pair into one
        # sequence; the result is a dict of (batch, sequence_length) arrays.
        text = text_preprocessor(
            (list(self.text_x_1[batch]), list(self.text_x_2[batch]))
        )

        return (
            {
                "image_1": image_1,
                "image_2": image_2,
                # The model declares these inputs with dtype="int32".
                "padding_mask": np.asarray(text["padding_mask"], dtype=np.int32),
                "segment_ids": np.asarray(text["segment_ids"], dtype=np.int32),
                "token_ids": np.asarray(text["token_ids"], dtype=np.int32),
            },
            # Target labels
            self.image_y[batch],
        )

    def __len__(self):
        """
        Returns the number of batches in the dataset.
        """
        return math.ceil(len(self.dataframe) / self.batch_size)

In [None]:
# Build one batched dataset per split (train, validation, test).
def prepare_dataset(dataframe):
    """Return the Keras dataset object wrapping one split of the data."""
    return dataframe_to_dataset(dataframe)


train_ds, validation_ds, test_ds = (
    prepare_dataset(split_df) for split_df in (train_df, val_df, test_df)
)

In [None]:
# Model Building
# The model takes 2 images along with their corresponding texts.
# The images can be fed directly to the model, while the texts have to be preprocessed first.
# The model consists of a pretrained image encoder (ResNet50V2) and a pretrained text encoder (BERT).

def project_embeddings(embeddings, num_projection_layers, projection_dims, dropout_rate):
    """Project ``embeddings`` to ``projection_dims`` and refine them with residual blocks.

    A Dense layer first maps the input to ``projection_dims``. Then, repeated
    ``num_projection_layers`` times, a GELU -> Dense -> Dropout branch is added
    back onto the running projection and the sum is layer-normalized.

    Returns:
        A tensor whose last dimension is ``projection_dims``.
    """
    running = keras.layers.Dense(units=projection_dims)(embeddings)
    for _block in range(num_projection_layers):
        branch = keras.ops.nn.gelu(running)
        branch = keras.layers.Dense(projection_dims)(branch)
        branch = keras.layers.Dropout(dropout_rate)(branch)
        residual_sum = keras.layers.Add()([running, branch])
        running = keras.layers.LayerNormalization()(residual_sum)
    return running

In [None]:
# Visual encoder utilities
def create_vision_encoder(
        num_projection_layers, projection_dims, dropout_rate, trainable=False
        ):
    """Build a model that maps two 128x128 RGB images to one projected embedding.

    Both images go through the same ImageNet-pretrained ResNet50V2 (global
    average pooling, shared weights). The two pooled vectors are concatenated
    and passed through ``project_embeddings``.

    Args:
        num_projection_layers: Residual refinement blocks in the projection head.
        projection_dims: Output embedding size.
        dropout_rate: Dropout rate inside the projection head.
        trainable: Whether the ResNet50V2 layers are updated during training.

    Returns:
        ``keras.Model`` named "vision_encoder" taking ``[image_1, image_2]``.
    """
    backbone = keras.applications.ResNet50V2(
        include_top=False, weights="imagenet", pooling="avg"
    )
    # Freeze (or unfreeze) every layer of the pretrained backbone.
    for backbone_layer in backbone.layers:
        backbone_layer.trainable = trainable

    image_inputs = [
        keras.Input(shape=(128, 128, 3), name="image_1"),
        keras.Input(shape=(128, 128, 3), name="image_2"),
    ]

    # ResNetV2-specific input scaling, then one shared backbone for both images.
    scaled_images = [
        keras.applications.resnet_v2.preprocess_input(image) for image in image_inputs
    ]
    pooled_features = [backbone(image) for image in scaled_images]
    joint_features = keras.layers.Concatenate()(pooled_features)

    projected = project_embeddings(
        joint_features, num_projection_layers, projection_dims, dropout_rate
    )
    return keras.Model(image_inputs, projected, name="vision_encoder")

In [None]:
def create_text_encoder(
        num_projection_layers, projection_dims, dropout_rate, trainable=False,
        sequence_length=256,
        ):
    """Build a model that maps BERT input features to one projected text embedding.

    Args:
        num_projection_layers: Residual refinement blocks in the projection head.
        projection_dims: Output embedding size.
        dropout_rate: Dropout rate inside the projection head.
        trainable: Whether the pretrained BERT weights are updated during training.
        sequence_length: Length of each int32 input feature; must match the
            text preprocessor's ``sequence_length``.

    Returns:
        ``keras.Model`` named "text_encoder" taking a dict with keys
        padding_mask, segment_ids and token_ids.
    """
    # Load the pre-trained BERT backbone using KerasHub. The backbone only
    # produces features; the class count belongs to the classification head
    # built later, so no `num_classes` argument is passed here.
    bert = keras_hub.models.BertBackbone.from_preset("bert_base_en_uncased")

    # Set the trainability of the base encoder.
    bert.trainable = trainable

    # Receive the preprocessed text as inputs, keyed by feature name.
    inputs = {
        feature: keras.Input(shape=(sequence_length,), dtype="int32", name=feature)
        for feature in ("padding_mask", "segment_ids", "token_ids")
    }

    # Sequence-level embedding from the backbone's "pooled_output".
    embeddings = bert(inputs)["pooled_output"]

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the text encoder model.
    return keras.Model(inputs, outputs, name="text_encoder")

In [None]:
# Create multimodal model
def create_multimodal_model(
        num_projection_layers=1,
        projection_dims=256,
        dropout_rate=0.1,
        vision_trainable=False,
        text_trainable=False,
        num_classes=3,
        ):
    """Build the entailment classifier over two images and a packed text pair.

    Args:
        num_projection_layers: Residual refinement blocks in each projection head.
        projection_dims: Embedding size produced by each encoder.
        dropout_rate: Dropout rate inside the projection heads.
        vision_trainable: Whether the ResNet50V2 backbone is fine-tuned.
        text_trainable: Whether the BERT backbone is fine-tuned.
        num_classes: Number of entailment classes for the softmax output.

    Returns:
        ``keras.Model`` with inputs ``[image_1, image_2, padding_mask,
        segment_ids, token_ids]`` and a softmax over ``num_classes`` classes.
    """
    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Receive the text as inputs. Keep them keyed by feature name: the text
    # encoder declares its inputs as a dict with these same keys.
    text_inputs = {
        feature: keras.Input(shape=(256,), dtype="int32", name=feature)
        for feature in ("padding_mask", "segment_ids", "token_ids")
    }

    # Create the encoders.
    vision_encoder = create_vision_encoder(
        num_projection_layers, projection_dims, dropout_rate, vision_trainable
    )
    text_encoder = create_text_encoder(
        num_projection_layers, projection_dims, dropout_rate, text_trainable
    )

    # Fetch the embedding projections.
    vision_projections = vision_encoder([image_1, image_2])
    text_projections = text_encoder(text_inputs)

    # Concatenate the projections and pass through the classification layer.
    concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
    outputs = keras.layers.Dense(num_classes, activation="softmax")(concatenated)
    return keras.Model([image_1, image_2, *text_inputs.values()], outputs)


multimodal_model = create_multimodal_model()
keras.utils.plot_model(multimodal_model, show_shapes=True)

In [None]:
# Compile and train: integer class targets -> sparse categorical cross-entropy.
multimodal_model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

history = multimodal_model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=1,
)

In [None]:
# Evaluate the model on the held-out test split; with the compile() metrics
# above, evaluate() yields [loss, accuracy].
_, acc = multimodal_model.evaluate(test_ds)
test_accuracy_pct = round(acc * 100, 2)
print(f"Accuracy on the test set: {test_accuracy_pct}%.")

### Appendix

In [None]:
# We can introduce cross-attention so the model focuses on the part(s) of the
# image representation that relate to the corresponding textual input.
def create_cross_attention_model(
        num_projection_layers=1,
        projection_dims=256,
        dropout_rate=0.1,
        vision_trainable=False,
        text_trainable=False,
        num_classes=3,
        ):
    """Variant of the multimodal classifier with a Luong-style cross-attention block.

    Builds its own inputs and encoders, so it does not depend on variables that
    only exist inside ``create_multimodal_model``.

    NOTE(review): both encoders emit one pooled vector per example, so attention
    runs over length-1 sequences and the attended output simply equals the text
    projection (modulo attention dropout). Meaningful cross-attention needs
    token-level / patch-level sequences from the encoders.

    Returns:
        ``keras.Model`` with inputs ``[image_1, image_2, padding_mask,
        segment_ids, token_ids]`` and a softmax over ``num_classes`` classes.
    """
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")
    text_inputs = {
        feature: keras.Input(shape=(256,), dtype="int32", name=feature)
        for feature in ("padding_mask", "segment_ids", "token_ids")
    }

    vision_encoder = create_vision_encoder(
        num_projection_layers, projection_dims, dropout_rate, vision_trainable
    )
    text_encoder = create_text_encoder(
        num_projection_layers, projection_dims, dropout_rate, text_trainable
    )

    # Embeddings.
    vision_projections = vision_encoder([image_1, image_2])
    text_projections = text_encoder(text_inputs)

    # Cross-attention (Luong-style). Attention expects (batch, seq, dim), so the
    # pooled (batch, dim) projections are viewed as length-1 sequences.
    vision_seq = keras.layers.Reshape((1, projection_dims))(vision_projections)
    text_seq = keras.layers.Reshape((1, projection_dims))(text_projections)
    attended = keras.layers.Attention(use_scale=True, dropout=0.2)([vision_seq, text_seq])
    attended = keras.layers.Flatten()(attended)

    # Concatenate.
    concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
    contextual = keras.layers.Concatenate()([concatenated, attended])
    outputs = keras.layers.Dense(num_classes, activation="softmax")(contextual)
    return keras.Model([image_1, image_2, *text_inputs.values()], outputs)