## Big Data ~ Momento de Retroalimentación

### Estudiante

- Nombre: Carlos Salguero
- Matrícula: A00833341

### Dataset

Source: abstract of *Flickr30k Entities: Collecting Region-to-Phrase Correspondences for Richer Image-to-Sentence Models* (Plummer et al., ICCV 2015).

> The Flickr30k dataset has become a standard benchmark for sentence-based image description. This paper presents Flickr30k Entities, which augments the 158k captions from Flickr30k with 244k coreference chains, linking mentions of the same entities across different captions for the same image, and associating them with 276k manually annotated bounding boxes. Such annotations are essential for continued progress in automatic image description and grounded language understanding. They enable us to define a new benchmark for localization of textual entity mentions in an image. We present a strong baseline for this task that combines an image-text embedding, detectors for common objects, a color classifier, and a bias towards selecting larger objects. While our baseline rivals in accuracy more complex state-of-the-art models, we show that its gains cannot be easily parlayed into improvements on such tasks as image-sentence retrieval, thus underlining the limitations of current methods and the need for further research.


In [1]:
import io
import json
import os
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, concat, lit, udf
from pyspark.sql.types import *
from tqdm import tqdm

## Image Caption Model


In [2]:
class ImageCaptionModel(nn.Module):
    """
    Map images to ``embed_size``-dimensional vectors with a ResNet-50 backbone.

    The backbone is an ImageNet-pretrained ResNet-50 with its final ``fc``
    layer removed. A trainable head follows it: a linear embedding, dropout,
    and a two-layer MLP projection back to ``embed_size``.

    NOTE(review): ``forward`` runs the backbone under ``torch.no_grad()``, so
    no gradients reach it. That does not freeze its BatchNorm running
    statistics: in ``train()`` mode they still update. Put ``self.resnet`` in
    ``eval()`` if a fully frozen backbone is intended.
    """

    def __init__(self, embed_size: int = 256, hidden_size: int = 256):
        super().__init__()

        # Build the submodules in a fixed order (backbone, embed, dropout,
        # projection). This fixes parameter init order and state_dict key order.
        backbone = models.resnet50(pretrained=True)
        # Every child except the last one (the classification ``fc`` layer).
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])

        self.image_embed = nn.Linear(backbone.fc.in_features, embed_size)
        self.dropout = nn.Dropout(0.5)
        self.projection = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )

    def forward(self, images):
        """
        Embed a batch of images.

        Args:
            images: Image batch for the backbone. Presumably (N, 3, H, W)
                normalized tensors; TODO confirm against callers.

        Returns:
            Tensor of shape (N, embed_size).
        """
        with torch.no_grad():
            pooled = self.resnet(images)

        # Collapse every dimension after the batch dimension.
        flat = torch.flatten(pooled, start_dim=1)
        return self.projection(self.dropout(self.image_embed(flat)))

## Auxiliary Functions


In [3]:
def create_spark_session(app_name: str = "Flick39k_PyTorch") -> SparkSession:
    """
    Build the notebook's SparkSession, or return the active one.

    The memory, core and result-size settings below are applied to the
    builder, in order, before ``getOrCreate()``.

    NOTE(review): per the Spark configuration docs, driver settings such as
    ``spark.driver.memory`` only take effect if no SparkContext/JVM is already
    running. ``spark.driver.maxResultSize`` = "0" removes the cap on data
    collected to the driver.
    """
    settings = {
        "spark.driver.memory": "16g",
        "spark.executor.memory": "16g",
        "spark.driver.maxResultSize": "0",
        "spark.executor.cores": "4",
        "spark.python.worker.memory": "16g",
    }

    builder = SparkSession.builder.appName(app_name)
    for key, value in settings.items():
        builder = builder.config(key, value)

    return builder.getOrCreate()

In [4]:
def preprocess_image(image_path: str, transform=None) -> torch.Tensor:
    """
    Load an image file and convert it into a model-ready tensor.

    Args:
        image_path: Path of the image file to read.
        transform: Callable applied to the RGB ``PIL.Image``. Defaults to a
            256x256 resize, ``ToTensor`` and ImageNet mean/std normalization.
            NOTE(review): ``SparkPyTorchTrainer`` passes its own 224x224
            transform, so this 256x256 default only applies to direct callers.

    Returns:
        The result of ``transform``; a tensor for the default transform.

    Raises:
        Errors from opening or decoding the file (e.g. a missing or
        unreadable image), and errors from ``transform``, propagate.
    """
    if transform is None:
        transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    # Scope the file-backed image with `with` so its handle is released.
    # `convert("RGB")` returns a separate in-memory image.
    with Image.open(image_path) as img:
        return transform(img.convert("RGB"))

## Combining PySpark and PyTorch


In [5]:
class SparkPyTorchTrainer:
    """
    Train a PyTorch model on images referenced by a Spark DataFrame.

    Spark holds the annotation table. Rows are streamed to the driver with
    ``toLocalIterator()`` and grouped into batches. The referenced image files
    are loaded and transformed locally, then the model is trained on
    ``self.device``.
    """

    def __init__(
        self,
        spark: SparkSession,
        model: nn.Module,
        batch_size: int = 32,
        device: str = "cuda",
    ):
        """
        Args:
            spark: Spark session used to build DataFrames.
            model: Network to train; moved to ``device`` here.
            batch_size: Rows per training step, used by ``train()`` when it
                is not given an explicit ``batch_size``.
            device: Torch device for the model and the image batches.
        """
        self.spark = spark
        self.model = model.to(device)
        self.device = device
        self.batch_size = batch_size
        # 224x224 resize + ImageNet mean/std normalization.
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def prepare_data(self, image_dir: str, annotation_file: str):
        """
        Load the JSON annotation file into a Spark DataFrame.

        Args:
            image_dir: Directory holding the image files. A trailing path
                separator is added if missing, because the directory is
                prefixed directly onto ``image_id``.
            annotation_file: Path to a JSON file (not a directory).
                NOTE(review): assumed to hold a list of records with
                ``image_id`` (str), ``caption`` (list of str) and ``boxes``
                (list of {x, y, width, height}), matching the schema below.
                Confirm against the actual export.

        Returns:
            DataFrame with the annotation columns plus ``image_path``
            (``image_dir`` followed by ``image_id``).
        """
        with open(annotation_file, "r", encoding="utf-8") as ann_file:
            annotations = json.load(ann_file)

        box_schema = StructType(
            [
                StructField("x", FloatType(), True),
                StructField("y", FloatType(), True),
                StructField("width", FloatType(), True),
                StructField("height", FloatType(), True),
            ]
        )
        annotations_schema = StructType(
            [
                StructField("image_id", StringType(), True),
                StructField("caption", ArrayType(StringType()), True),
                StructField("boxes", ArrayType(box_schema), True),
            ]
        )

        annotations_df = self.spark.createDataFrame(
            annotations, schema=annotations_schema
        )

        # os.path.join(dir, "") appends a separator only when it is missing
        # (and leaves "" unchanged).
        path_prefix = os.path.join(image_dir, "")
        return annotations_df.withColumn(
            "image_path", concat(lit(path_prefix), col("image_id"))
        )

    def preprocess_batch(self, batch_df, *, track_grad: bool = False):
        """
        Collect ``batch_df`` to the driver and run its images through the model.

        Args:
            batch_df: DataFrame with ``image_path`` and ``caption`` columns.
            track_grad: If True, the forward pass records gradients so the
                result can be back-propagated. The default (False) keeps the
                inference-only behaviour.

        Returns:
            ``(features, captions)`` for the rows whose image loaded, or
            ``None`` if none did.
        """
        return self._featurize_rows(batch_df.collect(), track_grad=track_grad)

    def _featurize_rows(self, rows, *, track_grad: bool = False):
        """Load, stack and embed the images of ``rows``; skip rows that fail."""
        images, captions = [], []

        for row in rows:
            # Best effort: a bad or missing image is reported and skipped
            # rather than aborting the whole batch.
            try:
                img_tensor = preprocess_image(row.image_path, self.transform)
            except Exception as e:
                print(f"Error processing image: {row.image_path}")
                print(e)
                continue
            images.append(img_tensor)
            captions.append(row.caption)

        if not images:
            return None

        image_batch = torch.stack(images).to(self.device)
        with torch.set_grad_enabled(track_grad):
            image_features = self.model(image_batch)

        return image_features, captions

    @staticmethod
    def _iter_row_batches(data_df, batch_size: int):
        """Yield lists of up to ``batch_size`` rows streamed from ``data_df``."""
        batch = []
        for row in data_df.toLocalIterator():
            batch.append(row)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    def train(
        self, data_df, batch_size: Optional[int] = None, num_epochs: int = 10
    ):
        """
        Optimize the model with Adam (lr=0.001) over ``data_df``.

        Args:
            data_df: DataFrame as returned by ``prepare_data``.
            batch_size: Rows per optimization step. Defaults to the value
                given to the constructor (32 unless overridden).
            num_epochs: Number of passes over ``data_df``.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.

        NOTE(review): rows are streamed in DataFrame order each epoch, with
        no shuffling. Shuffle upstream (e.g. ``orderBy(rand())``) if needed.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

        for epoch in range(num_epochs):
            self.model.train()
            total_loss, num_batches = 0.0, 0

            for rows in self._iter_row_batches(data_df, batch_size):
                # Gradients must be tracked here, otherwise backward() fails.
                result = self._featurize_rows(rows, track_grad=True)
                if result is None:
                    continue

                features, captions = result
                loss = self.calculate_loss(features, captions)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

            if num_batches == 0:
                print(f"Epoch: {epoch}, no usable batches")
                continue

            avg_loss = total_loss / num_batches
            print(f"Epoch: {epoch}, Loss: {avg_loss}")

    def calculate_loss(self, image_features, captions):
        """
        Placeholder objective: mean squared value of ``image_features``.

        NOTE(review): ``captions`` is ignored, so this learns no image-text
        alignment; it only pushes the model's outputs toward zero. Replace it
        with a real caption or contrastive loss.
        """
        return torch.mean(image_features.pow(2))

## Loading Everything


In [8]:
# Start (or reuse) the Spark session, then show its version and effective config.
spark = create_spark_session()
spark_conf = spark.sparkContext.getConf()
print(spark.version)
print(spark_conf.getAll())

3.5.3
[('spark.app.name', 'Flick39k_PyTorch'), ('spark.driver.port', '46609'), ('spark.driver.extraJavaOptions', '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false'), ('spark.app.submitTime', '1

In [9]:
# Same as the constructor defaults, spelled out for readability.
model = ImageCaptionModel(embed_size=256, hidden_size=256)



In [10]:
# Uses the constructor defaults: batch_size=32, device="cuda".
trainer = SparkPyTorchTrainer(spark=spark, model=model)

In [None]:
# `annotation_file` must be a JSON file. "../data/" is a directory, which
# `open()` rejects (IsADirectoryError).
# NOTE(review): "annotations.json" is a placeholder name; point it at the real
# annotation export that matches the schema in `prepare_data`.
data_df = trainer.prepare_data(
    image_dir="../data/",
    annotation_file="../data/annotations.json",
)

In [None]:
# Ten epochs (the default), with the trainer's default batch size.
trainer.train(data_df, num_epochs=10)