In [4]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Creating a Vertex Pipeline to extract training data

This notebook (the second in a five-part series) creates a Vertex AI pipeline that scrapes images from an online source (e.g. Reddit) and stores the image metadata in Firestore.

This notebook covers the following steps:

1. Creating a pipeline component to collect images from Reddit
1. Creating a pipeline component to store images in Cloud Storage
1. Creating a pipeline component to store metadata about the images in Firestore.

### Set IAM permissions

When you run a notebook on Vertex Workbench, the notebook runs in a Compute Engine context that has its own service account. Before you can use Secret Manager in a pipeline, you must grant the service account IAM permissions to access it.



### Enable the Cloud resources

For this notebook, you must have a Google Cloud project with the following resources:

+ A Cloud Storage bucket
+ The following APIs enabled:
  - Cloud Firestore
  - Vertex AI
  - Storage
  - Secret Manager
  
If you completed the [first](1_firestore.ipynb) notebook in this series, you should have these APIs already enabled.

In [3]:
# Get your GCP project id from gcloud
shell_output=!gcloud config list --format 'value(core.project)' 2>/dev/null
PROJECT_ID=shell_output[0]
print("Project ID: ", PROJECT_ID)

Project ID:  fantasymaps-334622


In [5]:
BUCKET = "fantasy-maps" # Google Cloud Storage bucket
COLLECTION_NAME = "FantasyMaps2" # Firestore collection name
LOCATION = "us-west1"
GCS_PREFIX = "FantasyMapsTest"
SUBREDDIT_NAME = "battlemaps"
LIMIT=300
MODEL_ID = "4304645197347684352"
MODEL_NAME = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}"

### Install the required Python libraries

In [3]:
! rm -rfd requirements.txt

In [4]:
%%writefile requirements.txt
google-cloud-secret-manager
google-cloud-aiplatform
google-cloud-pipeline-components>=1.0.30
kfp
praw
pandas
spacy
pillow

Writing requirements.txt


In [5]:
! pip install -r requirements.txt

Collecting google-cloud-storage<3.0.0dev,>=1.32.0
  Downloading google_cloud_storage-1.44.0-py2.py3-none-any.whl (106 kB)
     |████████████████████████████████| 106 kB 4.9 MB/s            
Collecting google-auth<2,>=1.6.1
  Using cached google_auth-1.35.0-py2.py3-none-any.whl (152 kB)
Collecting google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5
  Using cached google_api_core-2.10.2-py3-none-any.whl (115 kB)
Collecting google-api-core[grpc]!=2.0.*,!=2.1.*,!=2.10.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*,<3.0.0dev,>=1.34.0
  Downloading google_api_core-1.34.0-py3-none-any.whl (120 kB)
     |████████████████████████████████| 120 kB 29.5 MB/s            
Installing collected packages: google-auth, google-api-core, google-cloud-storage
  Attempting uninstall: google-auth
    Found existing installation: google-auth 2.17.3
    Uninstalling google-auth-2.17.3:
      Successfully uninstalled google-auth-2.17.3
  Attempting uninstall: google-api-core
   

In [6]:
! python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.5.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0-py3-none-any.whl (12.8 MB)
     |████████████████████████████████| 12.8 MB 4.9 MB/s            
✔ Download and installation successful
You can now load the package via spacy.load('en_core_web_sm')


## Store your Reddit API key in Cloud Secret Manager

Although you can [create a new secret in Cloud Secret Manager programmatically](https://cloud.google.com/secret-manager/docs/creating-and-accessing-secrets#create), in this notebook you must create it using the Cloud Console.

To create a new secret in the Cloud Console, do the following:

  1. Open the [Cloud Console](https://console.cloud.google.com/security/secret-manager).
  1. Click **Create secret**.
  1. In the **Create secret** page, do the following:
     
     + Give your secret a memorable name. This notebook uses the Reddit API, so the name of the secret
       is `reddit-api-key`.
     + Upload the credentials file. In this example, the `client_id`, `secret`, and `user_agent` credentials
       provided by Reddit are stored as JSON in a single file.
  
  1. Click **Create secret** at the bottom of the page.
  

## Get Reddit API key from Secret Manager

The important bit about an API key is that it should remain _secret_. You don't want to have it embedded in a notebook where anyone can see it!

The next step is to make sure that you can access your Reddit API key programmatically from the notebook. We'll use the API key stored in Secret Manager to make calls to Reddit, both in the notebook and later from a Vertex AI pipeline.

This notebook assumes that your Reddit API key is stored as a JSON-formatted string, with the following fields:

```
{
    "secret": "YOUR_SECRET",
    "client_id": "YOUR_CLIENT_ID",
    "user_agent": "YOUR_USER_AGENT",
    "user_name": "YOUR_REDDIT_USER_NAME"
}
```
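
Before handing the payload to the Reddit client, it can help to check that the secret actually contains the fields this notebook reads. The following is a minimal sketch using only the standard library; `parse_reddit_credentials` and `REQUIRED_FIELDS` are hypothetical names introduced for illustration, and the sample payload is made up.

```python
import json

# Fields the rest of this notebook reads from the secret (assumed required).
REQUIRED_FIELDS = ("secret", "client_id", "user_agent")


def parse_reddit_credentials(payload: str) -> dict:
    """Parse the secret payload and fail early if a field is missing or empty."""
    credentials = json.loads(payload)
    missing = [name for name in REQUIRED_FIELDS if not credentials.get(name)]
    if missing:
        raise ValueError(f"Secret is missing fields: {', '.join(missing)}")
    return credentials


# Made-up payload for illustration only.
sample = '{"secret": "s", "client_id": "c", "user_agent": "ua", "user_name": "u"}'
print(sorted(parse_reddit_credentials(sample)))
```

Failing early here gives a clearer error than a rejected request from the Reddit API later in the pipeline.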

In [None]:
def get_reddit_credentials(project_id):
    """Gets the Reddit API key out of Secret Manager
    
    Arguments:
        project_id (str): the current project ID
    
    Returns:
        JSON object (dict)
    """
    from google.cloud import secretmanager
    import json

    client = secretmanager.SecretManagerServiceClient()

    secret_resource_name = f"projects/{project_id}/secrets/reddit-api-key/versions/1"
    response = client.access_secret_version(request={"name": secret_resource_name})

    payload = response.payload.data.decode("UTF-8")
    reddit_key_json = json.loads(payload)

    return reddit_key_json

In [None]:
reddit_key_json = get_reddit_credentials(PROJECT_ID)

## Create a custom Reddit pipelines component

The pipeline and all its components need to be compiled into a runnable format. We use the Kubeflow Pipelines (`kfp`) SDK to create this uploadable pipeline file.

In [31]:
from typing import NamedTuple

import kfp
from kfp import dsl
from kfp.v2 import compiler
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, ClassificationMetrics, Metrics, component)
from kfp.v2.google.client import AIPlatformClient

from google.cloud import aiplatform
from google_cloud_pipeline_components import aiplatform as gcc_aip

Now we can define the first pipeline component. For this component, we are going to store the `pandas.DataFrame` that we compose from the Reddit posts as a CSV file on Cloud Storage. We'll pass the URI of this Storage file on to the next piece of the pipeline.
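
The filtering idea in this component can be sketched without calling Reddit. The example below uses only the standard library and made-up posts: it keeps posts whose URL contains `jpg` and whose title states grid dimensions such as `30x20`, then serializes them as CSV. The helper names `filter_posts` and `posts_to_csv` and the sample data are illustrative assumptions; the component itself uses pandas.

```python
import csv
import io
import re

COLUMNS = ["Title", "Post", "ID", "URL"]
DIMENSIONS = re.compile(r"\d+x\d+")  # e.g. "30x20"


def filter_posts(posts):
    """Keep posts that link to a JPG and state grid dimensions in the title."""
    return [p for p in posts if "jpg" in p["URL"] and DIMENSIONS.search(p["Title"])]


def posts_to_csv(posts):
    """Serialize posts to a CSV string with a header row."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=COLUMNS, lineterminator="\n")
    writer.writeheader()
    writer.writerows(posts)
    return buffer.getvalue()


# Made-up sample posts for illustration only.
sample_posts = [
    {"Title": "Forest Clearing [30x20]", "Post": "", "ID": "a1",
     "URL": "https://example.com/forest.jpg"},
    {"Title": "Weekly discussion thread", "Post": "Hi", "ID": "b2",
     "URL": "https://example.com/thread"},
    {"Title": "Cave map 25x25", "Post": "", "ID": "c3",
     "URL": "https://example.com/cave.png"},
]

kept = filter_posts(sample_posts)
print(posts_to_csv(kept))
```

Only the first post survives: the second has no dimensions in its title, and the third does not link to a JPG.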

In [43]:
"""Stage 1. Identify the images on Reddit that we could scrape.

This part of the pipeline calls the Reddit API, reads `limit` number of posts on the subreddit,
and then stores metadata about the posts in a CSV file on Storage.
"""
@component(packages_to_install=["praw",
                                "google-cloud-secret-manager",
                                "google-cloud-storage",
                                "numpy",
                                "pandas",
                                "spacy"])
def reddit(
    secret_name: str,
    subreddit_name: str,
    gcs_bucket_name: str,
    gcs_prefix_name: str,
    project_id: str,
    limit: int,
) -> str:
    from datetime import datetime
    import numpy as np
    import pandas as pd
    import praw
    import re
    
    from google.cloud import storage

    def get_reddit_credentials(project_id):
        """Gets the Reddit API key out of Secret Manager
    
        Arguments:
            project_id (str): the current project ID

        Returns:
            JSON object (dict)
        """
        from google.cloud import secretmanager
        import json

        client = secretmanager.SecretManagerServiceClient()

        secret_resource_name = f"projects/{project_id}/secrets/{secret_name}/versions/1"
        response = client.access_secret_version(request={"name": secret_resource_name})
        payload = response.payload.data.decode("UTF-8")

        return json.loads(payload)
    
    def get_reddit_posts(reddit_credentials, subreddit_name, limit):
        """Gets posts from a subreddit.

        Arguments:
            reddit_credentials (dict): a dictionary with client_id, secret, and user_agent
            subreddit_name (str): the name of the subreddit to scrape posts from
            limit (int): the maximum number of posts to grab

        Returns:
            List of Reddit API objects
        """
        import praw

        reddit = praw.Reddit(client_id=reddit_credentials["client_id"], 
                     client_secret=reddit_credentials["secret"],
                     user_agent=reddit_credentials["user_agent"])

        return reddit.subreddit(subreddit_name).hot(limit=limit)

    def convert_posts_to_dataframe(posts, columns):
        """Converts a sequence of Reddit API post objects into a pandas.DataFrame.
        
        Arguments:
            posts (list(praw.Post)): the posts from Reddit
            columns (list(str)): the column headings for the Dataframe
        
        Returns:
            A pandas.Dataframe
        """
        import numpy as np
        import pandas as pd

        filtered_posts = [[s.title, s.selftext, s.id, s.url] for s in posts]
        filtered_posts = np.array(filtered_posts)
        reddit_posts_df = pd.DataFrame(filtered_posts,
                                   columns=columns)

        return reddit_posts_df
    
    COLUMNS = ['Title', 'Post', 'ID', 'URL']
    
    # Get the data from Reddit
    credentials = get_reddit_credentials(project_id=project_id)
    posts = get_reddit_posts(reddit_credentials=credentials, subreddit_name=subreddit_name,
                             limit=limit)
    
    reddit_posts_df = convert_posts_to_dataframe(posts=posts, columns=COLUMNS)
    
    # Remove all of the posts that don't meet our criteria
    # The raw string keeps "\d" from being treated as an invalid escape sequence
    jpg_df = reddit_posts_df[(reddit_posts_df["URL"].str.contains("jpg")) &
                             (reddit_posts_df["Title"].str.contains(pat=r"\d+x\d"))]
    
    # Save the dataframe as CSV in Storage
    csv_str = jpg_df.to_csv()
    
    storage_client = storage.Client(project=project_id)
    bucket = storage_client.bucket(gcs_bucket_name)
    
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    
    csv_file_uri = f"{gcs_prefix_name}/_reddit-scraped-{subreddit_name}-{timestamp}.csv"
    
    file_blob = bucket.blob(csv_file_uri)
    file_blob.upload_from_string(csv_str)
    
    return csv_file_uri
    

## Create the Cloud Storage component

In this next component, we need to store any unique map images that we have picked up from the scraping. However, we need to validate that these images are useful training data before we process them and store their metadata in Firestore.

To validate the images, we will do the following:

1. Ensure that we don't already have the image in Firestore
2. Use a pre-trained, earlier version of our model to infer the existence of gridlines on the image

We'll do the first step in validation in the `storage` component. The second step (using an existing model to validate the usefulness of the images) will require using batch predictions on Vertex AI; we'll create another component to handle that part.
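
The first validation step amounts to content-based deduplication: hash the image bytes and skip any image whose hash is already a known document ID. Here is a minimal sketch of that idea, with a plain `set` standing in for the Firestore collection and byte strings standing in for JPEG data. `select_new_images` is a hypothetical helper; unlike the component, it also skips duplicates within the same batch.

```python
import hashlib


def convert_image_to_hash(content: bytes) -> str:
    """Return the SHA-1 hex digest of the image bytes."""
    return hashlib.sha1(content).hexdigest()


def select_new_images(images, known_ids):
    """Return {hash: content} for images whose hash is not already known."""
    new_images = {}
    for content in images:
        uid = convert_image_to_hash(content)
        if uid in known_ids or uid in new_images:
            continue  # Already stored, or a duplicate within this batch
        new_images[uid] = content
    return new_images


# Stand-in data: byte strings instead of real JPEGs, a set instead of Firestore.
known_ids = {convert_image_to_hash(b"map-a")}
new_images = select_new_images([b"map-a", b"map-b", b"map-b"], known_ids)
print(len(new_images))
```

Because the ID is derived from the bytes themselves, the same image reposted under a different title or URL maps to the same document.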

In [44]:
"""Stage 2. Save images from Reddit in Cloud Storage.

This part of the pipeline reads the CSV from the previous step, downloads
the image from Reddit locally, compares a hash value of the image against the
document IDs in Firestore, and then stores any image with new hash values (IDs)
in Cloud Storage.

One interstitial step in this process is to create a new human-friendly filename
for the image. The pipeline uses spaCy to assess the post title for tokens to
be used as filenames.
"""
@component(packages_to_install=["spacy",
                                "google-cloud-firestore",
                                "google-cloud-storage",
                                "pandas",
                                "jsonlines"])
def storage(
    project_id: str,
    location_name: str,
    gcs_bucket_name: str,
    gcs_prefix_name: str,
    collection_name: str,
    csv_input_file: str,
) -> NamedTuple(
    "outputs",
    [
        ("batch_predict_file_uri", str),
        ("posts_csv_file", str),
    ]
):
    
    from google.cloud import firestore
    from google.cloud import storage
    
    import base64
    from datetime import datetime
    from io import BytesIO
    import jsonlines
    import pandas as pd
    
    import spacy
    spacy.cli.download("en_core_web_sm")
    spacy.prefer_gpu()
    nlp = spacy.load("en_core_web_sm")
    
    def make_nice_filename(name):
        """Converts Reddit post title into a meaningful(ish) filename.

        Arguments:
            name (str): title of the post

        Returns:
            String. Format is `<adj.>-<nouns>.<cols>x<rows>.jpg`
        """
        import re

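        # Use a raw string for the pattern and compare lengths with ==/!=
        # (`is` checks object identity, not numeric equality)
        dims = re.findall(r"\d+x\d+", name)
        if len(dims) == 0:
            return ""

        dims = dims[0].split("x")
        if len(dims) != 2:
            return ""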

        tokens = get_tokens(name)
        new_name = name.lower()[:30]

        if len(tokens) > 0:
            tokens = tokens[:6] # Arbitrarily keep new names to six words or less
            new_name = "_".join(tokens)

        return f"{new_name}.{dims[0]}x{dims[1]}.jpg"

    def get_tokens(title):
        """Analyzes a post for nouns, proper nouns, and adjectives.

        Arguments:
            title (str): title of the post

        Returns:
            List of string. Words to use in a filename.    
        """
        import spacy

        POS = ["PROPN", "NOUN", "ADJ"]
        words = []

        tokens = nlp(title)
        for t in tokens:
            pos = t.pos_

            if pos in POS:
                words.append(t.text.lower())

        return words

    def convert_image_to_hash(content):
        """Convert image data to hash value (str).

        Arguments:
            content (byte array): the image

        Return:
            The image hash value as a string.
        """
        import hashlib

        # Hash the bytes in one step; hexdigest() returns the hash as a string
        return hashlib.sha1(content).hexdigest()
    
    def download_image(url):
        """Download an image from the internet into memory.

        Arguments:
            url (str): the image to download

        Returns:
            Tuple of (image bytes, hash string) if the download succeeded,
            otherwise None.
        """
        import requests

        r = requests.get(url, stream=True)
        if r.status_code == 200:
            r.raw.decode_content = True
            
            hsh = convert_image_to_hash(r.content)
            return (r.content, hsh)

        return None
    
    # Begin pipeline
    storage_client = storage.Client(project=project_id)
    bucket = storage_client.bucket(gcs_bucket_name)

    firestore_client = firestore.Client(project=project_id)
    collection_ref = firestore_client.collection(collection_name)

    blob = bucket.blob(csv_input_file)
    csv_bytes = blob.download_as_string()
    csv_buffer = BytesIO(csv_bytes)

    jpg_df = pd.read_csv(csv_buffer)
    batch_prediction_inputs = []
    
    for i, row in jpg_df.iterrows():
        url = row["URL"]
        title = row["Title"]
        
        # Skip posts whose image could not be downloaded
        result = download_image(url)
        if result is None:
            continue
        content, hsh = result

        # Check whether we already have this image
        doc_ref = collection_ref.document(hsh)
        doc_ref = doc_ref.get()
        if doc_ref.exists:
            continue
        
        file_name = make_nice_filename(title)
                
        img_gcs_uri = f"gs://{gcs_bucket_name}/{gcs_prefix_name}/{file_name}"
        blob_name = f"{gcs_prefix_name}/{file_name}"

        file_blob = bucket.blob(blob_name)
        image_buffer = BytesIO(content)

        # Upload the image to Cloud Storage and record where it went
        file_blob.upload_from_file(image_buffer)
        jpg_df.at[i, "URI"] = img_gcs_uri
        jpg_df.at[i, "UID"] = hsh
        jpg_df.at[i, "Filename"] = file_name
        
        batch_prediction_inputs.append({
            "content": img_gcs_uri,
            "mimeType": "image/jpeg",
        })

    # Save the dataframe as CSV in Storage (again)
    csv_str = jpg_df.to_csv()
    file_blob = bucket.blob(csv_input_file)
    file_blob.upload_from_string(csv_str)
    
    # Create batch prediction input file
    bpi = BytesIO()
    writer = jsonlines.Writer(bpi)
    writer.write_all(batch_prediction_inputs)
    writer.close()
    
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    bpi_gcs_path = f"{gcs_prefix_name}/_batch_prediction_input_{timestamp}.jsonl"
    batch_prediction_input_file = f"gs://{gcs_bucket_name}/{bpi_gcs_path}"
    
    bpi_str = str(bpi.getvalue(), encoding="UTF8")
    bpi_blob = bucket.blob(bpi_gcs_path)
    bpi_blob.upload_from_string(bpi_str)
    
    return (batch_prediction_input_file, csv_input_file)

## Create a custom batch prediction component
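
The storage component writes the batch prediction input as a JSON Lines file, with one `content`/`mimeType` object per image. A minimal standard-library sketch of that file format follows; the function name `build_batch_prediction_input` and the `gs://` URIs are placeholders, not values from this project.

```python
import json


def build_batch_prediction_input(image_uris):
    """Return JSON Lines text with one prediction instance per image URI."""
    return "".join(
        json.dumps({"content": uri, "mimeType": "image/jpeg"}) + "\n"
        for uri in image_uris
    )


# Placeholder URIs for illustration only.
jsonl = build_batch_prediction_input([
    "gs://example-bucket/maps/forest.30x20.jpg",
    "gs://example-bucket/maps/cave.25x25.jpg",
])
print(jsonl)
```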

In [45]:
"""Stage 3. Use an earlier version of the model to identify good data candidates.

This part of the pipeline uses batch prediction to measure the usefulness of
the images from Reddit for refining the model. The batch prediction operation
returns confidence values for gridlines that are used in subsequent training runs. 
"""
@component(packages_to_install=["google-cloud-aiplatform"])
def custom_batch_prediction(
    project: str,
    location: str,
    model_resource_name: str,
    job_display_name: str,
    gcs_input_file: str,
    gcs_output_dir: str
) -> str:
    from google.cloud import aiplatform
    
    aiplatform.init(project=project, location=location)
    
    model=aiplatform.Model(model_resource_name)
    
    batch_prediction_job = model.batch_predict(
        job_display_name=job_display_name,
        gcs_source=gcs_input_file,
        gcs_destination_prefix=gcs_output_dir,
        sync=True,
    )

    batch_prediction_job.wait()
    
    return batch_prediction_job.output_info.gcs_output_directory
    

## Create the Firestore component
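
This component decides whether an image is usable by looking at a single "marginal" confidence value from its batch prediction. The sketch below illustrates that rule in isolation. It assumes the confidences are sorted from highest to lowest, and `is_usable` is a hypothetical helper that also guards against an empty list and an out-of-range index.

```python
def is_usable(confidences, threshold, percentage):
    """Check whether the prediction at the `percentage` cut-off clears `threshold`.

    Assumes `confidences` is sorted from highest to lowest.
    """
    if not confidences:
        return False
    index = min(int(len(confidences) * percentage), len(confidences) - 1)
    return confidences[index] > threshold


# With four predictions and percentage=0.5, the third value is the marginal one.
print(is_usable([0.9, 0.8, 0.7, 0.2], threshold=0.5, percentage=0.5))  # True
print(is_usable([0.9, 0.4, 0.3, 0.2], threshold=0.5, percentage=0.5))  # False
```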

Before creating this component, it helps to understand the schema for documents in the Firestore collection. The identifier of the document is a truncated hash value derived from the data in the image.

Each document has the following fields and data types (expressed as JSON):

```json
{
    "BBoxes": [{
        "xMin": 0.0,
        "xMax": 0.0,
        "yMin": 0.0,
        "yMax": 0.0,
        "displayName": "string"
    }],
    "VTT": {
        "cellsOffsetX": 0,
        "cellsOffsetY": 0,
        "imageWidth": 0,
        "imageHeight": 0,
        "cellWidth": 0,
        "cellHeight": 0
    },
    "URI": "string",
    "URL": "string",
    "Title": "string",
    "Post": "string",
    "Filename": "string",
    "Width": 0.0,
    "Height": 0.0,
    "Columns": 0,
    "Rows": 0,
    "NeedsSharding": false
}
```
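
To illustrate how the grid-related fields relate to one another, here is a small sketch that derives them for a hypothetical 1400 x 1000 pixel image with a 28 x 20 grid. `build_grid_fields` is an illustrative helper, not part of the pipeline; it mirrors the calculations in the component below, including the 500-cell cut-off for `NeedsSharding`.

```python
def build_grid_fields(width, height, columns, rows, max_cells=500):
    """Derive the VTT and sharding fields for an image with a regular grid."""
    return {
        "VTT": {
            "cellsOffsetX": 0,  # Assumes the grid starts at the image edge
            "cellsOffsetY": 0,
            "imageWidth": width,
            "imageHeight": height,
            "cellWidth": width // columns,
            "cellHeight": height // rows,
        },
        "Width": float(width),
        "Height": float(height),
        "Columns": columns,
        "Rows": rows,
        "NeedsSharding": columns * rows > max_cells,
    }


fields = build_grid_fields(width=1400, height=1000, columns=28, rows=20)
print(fields["VTT"]["cellWidth"], fields["VTT"]["cellHeight"], fields["NeedsSharding"])
```

Each cell is 50 pixels square, and with 560 cells the image is flagged for sharding.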

In [46]:
"""Stage 4. Store the metadata of useful images into Firestore.

This pipeline component reads the results of the batch prediction, filters on
a set confidence threshold for the predictions, and then stores the metadata
of the "good" maps into Firestore.
"""
@component(packages_to_install=["Pillow",
                                "google-cloud-firestore",
                                "google-cloud-storage",
                                "numpy",
                                "pandas",
                                "jsonlines"])
def firestore(
    subreddit_name: str,
    collection_name: str,
    gcs_bucket_name: str,
    gcs_prefix_name: str,
    csv_input_file: str,
    batch_prediction_uri: str,
    project_id: str,
    threshold: float,
    percentage: float
) -> NamedTuple(
    "outputs",
    [
        ("usable", int),
        ("unusable", int),
        ("stored", int),
    ]
):
    
    from datetime import datetime
    import hashlib
    from io import BytesIO
    import json
    import jsonlines
    import math
    import pandas as pd
    from PIL import Image
    import re
    import requests
    import shutil

    from google.cloud import firestore
    from google.cloud import storage

    def process_batch_predict_output(bp_jsonl, threshold, percentage):
        """Parses a set of batch predictions for highest confidence inputs.
        
        Arguments:
            bp_jsonl (bytes): the batch prediction results, as UTF-8 encoded JSON Lines
            threshold (float): the lowest confidence value to accept from 0.0 to 1.0
            percentage (float): the top percentage (by quality) of predictions to check
            
        Returns:
            Dict with "usable" and "unusable" lists of GCS URIs
        """
        predictions = bp_jsonl.decode("utf-8").split("\n")
        reader = jsonlines.Reader(predictions)

        images_uris = {
            "usable": [],
            "unusable": [],
        }

        for obj in reader.iter(type=dict, skip_invalid=True):
            confidences = obj["prediction"]["confidences"]
            image_gcs_uri = obj["instance"]["content"]
            # NOTE: look at the prediction that sits `percentage` of the
            # way down the list of confidences (assumed to be sorted from
            # highest to lowest). If that marginal prediction is above the
            # threshold, we keep the image for training data.
            # Clamp the index so that percentage=1.0 stays in range.
            top_n_images = min(int(len(confidences) * percentage),
                               len(confidences) - 1)
            marginal_result = confidences[top_n_images]
            if marginal_result > threshold:
                images_uris["usable"].append(image_gcs_uri)
            else:
                images_uris["unusable"].append(image_gcs_uri)
        
        return images_uris
        
    def get_image_width_and_height(img_bytes):
        """Open the image and get the image's height and width in pixels.

        Arguments:
            img_bytes (bytes): the encoded image data

        Returns:
            Tuple of width, height
        """
        w = h = 0
        
        f = BytesIO(img_bytes)
        with Image.open(f) as img:
            w, h = img.size

        return (int(w), int(h))

    def compute_vtt_data(width, height, columns, rows):
        """Calculate the VTT values for the image.

        Arguments:
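            width (int): width of the image in pixels
            height (int): height of the image in pixels
            columns (int): number of columns in the grid
            rows (int): number of rows in the grid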
        Returns:
            Dict.
        """

        return {
            "cellsOffsetX": 0, # Assumes no offset
            "cellsOffsetY": 0, # Assumes no offset
            "imageWidth": int(width),
            "imageHeight": int(height),
            "cellWidth": int(width / columns),
            "cellHeight": int(height / rows)
        }

    def convert_image_to_hash(content):
        """Convert image data to hash value (str).

        Arguments:
            content (byte array): the image

        Return:
            The image hash value as a string.
        """
        # Hash the bytes in one step; hexdigest() returns the hash as a string
        return hashlib.sha1(content).hexdigest()
    
    def compute_bboxes(*, width=0, height=0, columns=0, rows=0, cell_width=0, cell_height=0):
        """Determines bounding boxes for image object detection.

        Arguments:
            width (int): width of the image
            height (int): height of the image
            columns (int): number of columns in the grid
            rows (int): number of rows in the grid
            cell_width (int): width of one grid cell in pixels
            cell_height (int): height of one grid cell in pixels

        Returns:
            List of dict.
        """
        bboxes = []
        BORDER = 1 # 1px border around the outside of the cell
        LABEL = "cell"

        curr_x = cell_width
        while curr_x < width:
            curr_y = cell_height
            while curr_y < height:
                x_min = (curr_x - BORDER) / width
                y_min = (curr_y - BORDER) / height
                x_max = (curr_x + cell_width + BORDER) / width
                y_max = (curr_y + cell_height + BORDER) / height
                bboxes.append({
                    "xMin": x_min,
                    "xMax": x_max,
                    "yMin": y_min,
                    "yMax": y_max,
                    "displayName": LABEL
                })
                curr_y = curr_y + cell_height
            curr_x = curr_x + cell_width
            
        return bboxes
    
    def store_metadata_fs(*, project_id, series, collection_name, uid):
        """Upserts image metadata into a Firestore collection.

        Arguments:
            project_id (str): the Google Cloud project to store these in
            series (pd.Series): a Pandas series with the image's metadata
            collection_name (str): the Firestore collection to store the data in
            uid (str): the document ID (the image's hash value)
        """
        client = firestore.Client(project=project_id)

        series_dict = series.to_dict()

        # clean up the data a little bit before upserting
        # Compare strings with != (`is not` checks identity, not equality)
        vtt = series["VTT"]
        if vtt != "":
            vtt = json.loads(vtt)
            series_dict["VTT"] = vtt

        bboxes = series["BBoxes"]
        if bboxes != "":
            bboxes = json.loads(bboxes)["bboxes"]
            series_dict["BBoxes"] = bboxes

        # upsert the dict directly into Firestore!
        client.collection(collection_name).document(uid).set(series_dict)
    
    # BEGIN MAIN
    storage_client = storage.Client(project=project_id)
    bucket = storage_client.bucket(gcs_bucket_name)
    
    firestore_client = firestore.Client(project=project_id)
    collection_ref = firestore_client.collection(collection_name)

    # Determine how much of the scraped data is usable for training a model
    bp_path = batch_prediction_uri.replace(f"gs://{gcs_bucket_name}/", "")
    
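    # Find the first JSON Lines file in the batch prediction output directory
    blobs = bucket.list_blobs(prefix=bp_path)
    bp_jsonl_blob = None
    for b in blobs:
        if ".jsonl" in b.name:
            bp_jsonl_blob = b
            break

    if bp_jsonl_blob is None:
        raise FileNotFoundError(f"No batch prediction output found under {bp_path}")
            
    bp_jsonl = bp_jsonl_blob.download_as_string()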
    
    image_uris = process_batch_predict_output(bp_jsonl, threshold, percentage)
    
    # Open up the complete list of scraped images as a DataFrame
    blob = bucket.blob(csv_input_file)
    csv_bytes = blob.download_as_string()
    csv_buffer = BytesIO(csv_bytes)

    jpg_df = pd.read_csv(csv_buffer)
    
    # Iterate over the images, skip the unusable ones, and compute grid metadata
    for i, row in jpg_df.iterrows():
        jpg_uri = row["URI"]
        filename = row["Filename"]
        
        if jpg_uri in image_uris["unusable"]:
            continue

        jpg_uri = jpg_uri.replace(f"gs://{gcs_bucket_name}/", "")
        jpg_blob = bucket.blob(jpg_uri)
        jpg_bytes = jpg_blob.download_as_bytes()
        
        w, h = get_image_width_and_height(jpg_bytes)
    
        jpg_df.at[i, "Width"] = w
        jpg_df.at[i, "Height"] = h

        # Get columns & rows for original image, based upon the name.
        paths = filename.split(".")
        dims = paths[-2]
        cols, rows = dims.split("x")
    
        cols = int(cols)
        rows = int(rows)

        jpg_df.at[i, "Columns"] = cols
        jpg_df.at[i, "Rows"] = rows
        
        # Compute the vtt data for the image
        vtt = compute_vtt_data(width=w, height=h, columns=cols, rows=rows)
    
        # Note: pandas has issues storing a dict in a cell
        jpg_df.at[i, "VTT"] = json.dumps(vtt)  
        bboxes = compute_bboxes(width=w,
                                height=h,
                                columns=cols,
                                rows=rows,
                                cell_width=vtt["cellWidth"],
                                cell_height=vtt["cellHeight"])
        if len(bboxes) == 0:
            print(f"Error: {filename}") 
                  
        jpg_df.at[i, "BBoxes"] = json.dumps({ "bboxes": bboxes })
        
        if (cols * rows) > 500:
            jpg_df.at[i, "NeedsSharding"] = True
        else:
            jpg_df.at[i, "NeedsSharding"] = False
            
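    # Keep only the rows processed above. This assumes `complete_df` is meant
    # to hold every image that was not flagged as unusable.
    complete_df = jpg_df[~jpg_df["URI"].isin(image_uris["unusable"])].copy()
    complete_df.set_index("UID", inplace=True)
    #complete_df.fillna("", inplace=True)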
    
    for uid, row in complete_df.iterrows():
        store_metadata_fs(project_id=project_id, series=row,
                          collection_name=collection_name, uid=uid)
        
    return (len(image_uris["usable"]), len(image_uris["unusable"]), len(complete_df.index))

## Create image shards
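
The sharding component splits any map with more than 500 grid cells into shards of at most 22 x 22 cells. As a rough sketch of that arithmetic, the hypothetical helper `count_shards` below computes how many shards fit horizontally and vertically. It covers only the counting step, not the cropping, and it ignores leftover columns and rows.

```python
import math

SHARD_SIDE = 22  # 22 x 22 = 484 cells, just under the 500-cell limit


def count_shards(columns, rows, max_cells=500):
    """Return (horizontal, vertical) shard counts for a grid of cells."""
    if columns * rows <= max_cells:
        return (0, 0)  # Small enough to use as-is
    # A side shorter than SHARD_SIDE still yields one (narrower) shard.
    horizontal = max(1, math.floor(columns / SHARD_SIDE))
    vertical = max(1, math.floor(rows / SHARD_SIDE))
    return (horizontal, vertical)


print(count_shards(50, 30))  # (2, 1)
print(count_shards(20, 20))  # (0, 0)
```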

In [47]:
"""Stage 5. Create smaller training images from stored images

This pipeline component filters the Firestore collection to identify images that are
too large (e.g. have more than 500 grid cells). For those images, the component 
creates smaller images (shards) from cropped versions of the original image.
"""
@component(packages_to_install=["Pillow",
                                "google-cloud-firestore",
                                "google-cloud-storage",
                                "numpy",
                                "pandas",
                                "jsonlines"])
def create_shards(
    collection_name: str,
    gcs_bucket_name: str,
    gcs_prefix_name: str,
    project_id: str
) -> NamedTuple(
    "outputs",
    [
        ("shards_created", int),
    ]
):
    from io import BytesIO  # Used by create_shard to open the image bytes
    import math
    import pandas as pd
    from PIL import Image
    
    def create_shard_path(filename, x_min, y_min, cols, rows):
        """Convert an image path string to new string.

        Assumes the filename is of the format:
            <name>.<cols>x<rows>.jpg

        Arguments:
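            filename (str): the parent image's filename
            x_min (int): left edge of the shard, in pixels
            y_min (int): top edge of the shard, in pixels
            cols (int): the grid columns in the shard
            rows (int): the grid rows in the shard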

        Returns:
            String. New image path.
        """
        paths = filename.split(".")
        paths[-2] = f"{math.floor(x_min)}_{math.floor(y_min)}.{cols}x{rows}"
        s_path = ".".join(paths)
        return s_path
    
    def create_shard(x_min, y_min, x_max, y_max, cols, rows, filename, content, parent_id):
        """Crops and saves an image.

        Arguments:
            x_min (int): the left-most point to crop, relative to the parent image
            y_min (int): the top-most point to crop, relative to the parent image
            x_max (int): the right-most point, relative to the parent image
            y_max (int): the bottom-most point, relative to the parent image
            cols (int): the grid columns in this shard
            rows (int): the grid rows in this shard
            filename (str): the parent image's local path
            content (bytes): the original image
            parent_id (str): the parent image's UID

        Returns:
            DataFrame with local path, UID, width, height, columns, and rows

        """
        try:

            f = BytesIO(content)
            with Image.open(f) as img:
                shard = img.crop((int(x_min), int(y_min), int(x_max), int(y_max)))

                # Get new filepath name
                s_path = create_shard_path(filename, x_min, y_min, cols, rows)

                # Get new UID
                hash_ = convert_image_to_hash(shard.tobytes())

                shard.save(s_path)

                d = {
                    "Width": int(x_max - x_min),
                    "Height": int(y_max - y_min),
                    "Columns": cols,
                    "Rows": rows,
                    "UID": hash_,
                    "Path": s_path,
                    "IsShard": True,
                    "Parent": parent_id
                }

        except SystemError:
            # Report the parent image and the crop bounds that failed
            print(f"Error: {filename}, bounds: {x_max},{y_max}")
            return None

        return pd.DataFrame(data=d, index=[0])

    def compute_shard_coordinates(width, height, cell_width,
                                  cell_height, columns, rows):
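        """Converts image data into one or more shards.

        Arguments:
            width (int): the image width, in pixels
            height (int): the image height, in pixels
            cell_width (int): the width of one grid cell, in pixels
            cell_height (int): the height of one grid cell, in pixels
            columns (int): the grid columns in the image
            rows (int): the grid rows in the image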

        Returns:
            List of tuples of (xMin, yMin, xMax, yMax, columns, rows)
        """
        total_cells = columns * rows
        if total_cells <= 500:
            # Small enough to keep whole: no shards are needed
            return []

        # Assume that a perfectly square map that approaches 500 cells is 22 cols by 22 rows.
        # Cut an image into as many 22x22 shards as possible
        SQRT = 22

        h_shards = math.floor(columns / SQRT)
        h_rem = columns % SQRT
        v_shards = math.floor(rows / SQRT)
        v_rem = rows % SQRT
        shard_columns = shard_rows = SQRT
        shards = []
    
        # Edge case 1: we have a narrow width (portrait-oriented) map
        if h_shards == 0:
            h_shards = 1
            h_rem = 0
            shard_columns = columns

        # Edge case 2: we have a short height (landscape-oriented) map
        if v_shards == 0:
            v_shards = 1
            v_rem = 0
            shard_rows = rows

        # Cut the full-size shards, column by column
        curr_min_x = 0
        curr_min_y = 0
        for _ in range(h_shards):
            max_x = (cell_width * shard_columns) + curr_min_x
            if max_x > width:
                max_x = width
            for _ in range(v_shards):
                max_y = (cell_height * shard_rows) + curr_min_y
                if max_y > height:
                    max_y = height

                shards.append((curr_min_x, curr_min_y, max_x, max_y, shard_columns, shard_rows))
                curr_min_y = max_y

            curr_min_y = 0
            curr_min_x = max_x
    
        # Get the right-side remainder
        curr_min_x = width - (h_rem * cell_width)
        curr_min_y = 0
        for _ in range(v_shards):
            max_y = (cell_height * shard_rows) + curr_min_y
            if max_y > height:
                max_y = height
            shards.append((curr_min_x, curr_min_y, width, max_y, h_rem, shard_rows))
            curr_min_y = max_y

        # Get the bottom-side remainder
        curr_min_y = height - (v_rem * cell_height)
        curr_min_x = 0
        for _ in range(h_shards):
            max_x = (cell_width * shard_columns) + curr_min_x
            if max_x > width:
                max_x = width
            shards.append((curr_min_x, curr_min_y, max_x, height, shard_columns, v_rem))
            curr_min_x = max_x

        return shards
    
    # BEGIN MAIN
    shards = compute_shard_coordinates(width=w, height=h, columns=cols, rows=rows,
                                       cell_width=vtt["cellWidth"], cell_height=vtt["cellHeight"])
            
    for shard in shards:
        shard_df = create_shard(x_min=shard[0], y_min=shard[1], x_max=shard[2],
                                   y_max=shard[3], cols=shard[4], rows=shard[5],
                                   filename=filename, content=jpg_bytes, parent_id=row["UID"])

        if shard_df is None:
            continue

        s_vtt = dict(vtt)  # Copy so the parent's VTT metadata is not overwritten
        s_vtt["width"] = int(shard_df.iloc[0]["Width"])
        s_vtt["height"] = int(shard_df.iloc[0]["Height"])
        shard_df.at[0, "VTT"] = json.dumps(s_vtt)

        bboxes = compute_bboxes(dataframe=shard_df,
                                cell_width=vtt["cellWidth"],
                                cell_height=vtt["cellHeight"])

        shard_df.at[0, "BBoxes"] = json.dumps({ "bboxes": bboxes })
        shards_df = pd.concat([shards_df, shard_df])
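
The shard arithmetic in `compute_shard_coordinates` can be checked by hand on a small example. The sketch below is only an illustration, not part of the pipeline component: the helper names `plan_shards` and `full_shard_boxes`, the 50 x 30 grid, and the 70-pixel cell size are all hypothetical. It assumes the same 22-cell shard side as the code above.

```python
SHARD_SIDE = 22  # 22 x 22 = 484 cells, just under the 500-cell limit


def plan_shards(columns, rows, side=SHARD_SIDE):
    """Return (h_shards, h_rem, v_shards, v_rem) for a grid of cells."""
    h_shards, h_rem = divmod(columns, side)
    v_shards, v_rem = divmod(rows, side)
    return h_shards, h_rem, v_shards, v_rem


def full_shard_boxes(columns, rows, cell_width, cell_height, side=SHARD_SIDE):
    """Return pixel boxes (x_min, y_min, x_max, y_max) of the full-size shards."""
    h_shards, v_shards = columns // side, rows // side
    step_x, step_y = side * cell_width, side * cell_height
    return [
        (i * step_x, j * step_y, (i + 1) * step_x, (j + 1) * step_y)
        for i in range(h_shards)
        for j in range(v_shards)
    ]


# Hypothetical map: 50 columns by 30 rows (1,500 cells) with 70-pixel cells
print(plan_shards(50, 30))               # (2, 6, 1, 8)
print(full_shard_boxes(50, 30, 70, 70))  # [(0, 0, 1540, 1540), (1540, 0, 3080, 1540)]
```

For this grid, two full 22 x 22 shards fit side by side, leaving a right-hand strip 6 columns wide and a bottom strip 8 rows tall for the remainder loops to handle.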

## Build a simple pipeline

In [59]:
@dsl.pipeline(
    name="reddit-scraper-pipeline",
    description="Gets data from a subreddit",
    pipeline_root=f"gs://{BUCKET}/pipeline_root",
)
def reddit_pipeline(
    collection_name: str = COLLECTION_NAME,
    secret_name: str = "reddit-api-key",
    subreddit_name: str = SUBREDDIT_NAME,
    gcs_bucket: str = BUCKET,
    gcs_prefix: str = GCS_PREFIX,
    gcs_bp_output: str = "gs://fantasy-maps/ScrapedData",
    project_id: str = PROJECT_ID,
    location: str = LOCATION,
    limit: int = LIMIT,
    threshold: float = 0.3, # confidence value of 0.30 or better
    percentage: float = 0.1, # top 10%
    model_name: str = MODEL_NAME
):
    
    # Get the images from Reddit
    reddit_op = reddit(
        secret_name=secret_name,
        subreddit_name=subreddit_name,
        gcs_bucket_name=gcs_bucket,
        gcs_prefix_name=gcs_prefix,
        project_id=project_id,
        limit=limit,
    )
    
    reddit_csv_file = reddit_op.output
    
    # Store the new images on Cloud Storage
    storage_op = storage(
        project_id=project_id,
        location_name=location,
        gcs_bucket_name=gcs_bucket,
        gcs_prefix_name=gcs_prefix,
        collection_name=collection_name,
        csv_input_file=reddit_csv_file
    )
    
    # Run batch prediction on the newly stored images
    batch_prediction_op = custom_batch_prediction(
        project=project_id,
        location=location,
        model_resource_name=model_name,
        job_display_name="test-custom-bp",
        gcs_input_file=storage_op.outputs["batch_predict_file_uri"],
        gcs_output_dir=gcs_bp_output
    )
    
    # Store training data in Firestore!
    firestore(
        subreddit_name=subreddit_name,
        collection_name=collection_name,
        gcs_bucket_name=gcs_bucket,
        gcs_prefix_name=gcs_prefix,
        csv_input_file=storage_op.outputs["posts_csv_file"],
        batch_prediction_uri=batch_prediction_op.output,
        project_id=project_id,
        threshold=threshold,
        percentage=percentage
    )

In [60]:
compiler.Compiler().compile(
    pipeline_func=reddit_pipeline, package_path="artifacts/reddit_scraper_pipeline_job.json"
)

In [61]:
api_client = AIPlatformClient(
    project_id=PROJECT_ID,
    region=LOCATION,
)

When we run the pipeline, we disable caching. With caching enabled, Vertex AI Pipelines can reuse the outputs of a previous run whose inputs match, so the pipeline would likely return the exact same results instead of scraping new posts.

In [62]:
response = api_client.create_run_from_job_spec(
    job_spec_path="artifacts/reddit_scraper_pipeline_job.json",
    enable_caching=False  # Keep caching off so that each job run generates new values
)
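
To see why caching matters for a scraper, consider a toy model of a cached step. This is a simplified sketch and not how Vertex AI Pipelines implements caching: the `scrape` and `run_step` functions, the cache key, and the `battlemaps` subreddit name are all hypothetical. It assumes only that a cached step is looked up by its inputs.

```python
calls = {"count": 0}


def scrape(subreddit_name, limit):
    """Stand-in for a scraping step: every real call sees a newer batch."""
    calls["count"] += 1
    return f"{subreddit_name}: batch {calls['count']} ({limit} posts)"


def run_step(subreddit_name, limit, enable_caching, cache):
    """Run the step, reusing an earlier result when the inputs match."""
    key = (subreddit_name, limit)
    if enable_caching and key in cache:
        return cache[key]  # Same inputs: the old output is returned as-is
    result = scrape(subreddit_name, limit)
    cache[key] = result
    return result


cache = {}
first = run_step("battlemaps", 10, True, cache)
second = run_step("battlemaps", 10, True, cache)   # cache hit
third = run_step("battlemaps", 10, False, cache)   # caching disabled

print(first == second)  # True
print(second == third)  # False
```

Because the pipeline parameters rarely change between runs, a cached scraping step would keep returning the first batch, which is why `enable_caching` is set to `False` here.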