# Extract training data from Reddit

This notebook uses Cloud Secret Manager to import an API key into a Vertex AI pipeline.

## Install all dependencies

In [1]:
! pip install google-cloud-secret-manager google-cloud-aiplatform kfp google-cloud-pipeline-components praw --upgrade



### Set project information

In [2]:
# Get your GCP project id from gcloud
shell_output=!gcloud config list --format 'value(core.project)' 2>/dev/null
PROJECT_ID=shell_output[0]
print("Project ID: ", PROJECT_ID)

Project ID:  fantasymaps-334622


### Set IAM permissions on your service account

`secretmanager.versions.access`

## Store your API key in Cloud Secret Manager

Although you can [create a new secret in Cloud Secret Manager programmatically](https://cloud.google.com/secret-manager/docs/creating-and-accessing-secrets#create), in this notebook you must create it using the Cloud Console.

To create a new secret in the Cloud Console, do the following:

  1. Open the [Cloud Console](https://console.cloud.google.com/security/secret-manager).
  1. Click **Create secret**.
  1. In the **Create secret** page, do the following:
     
     + Give your secret a memorable name. This notebook uses the Reddit API, so the name of the secret
       is 'reddit-api-key'.
     + Upload the credentials file. In this example, the `client_id`, `secret`, and `user_agent` credentials
       provided by Reddit are stored as JSON in a single file.
  
  1. Click **Create secret** at the bottom of the page.
  

## Access the key programmatically

In [None]:
from google.cloud import secretmanager
import json

client = secretmanager.SecretManagerServiceClient()

secret_resource_name = f"projects/{PROJECT_ID}/secrets/reddit-api-key/versions/1"

response = client.access_secret_version(request={"name": secret_resource_name})

payload = response.payload.data.decode("UTF-8")

reddit_key_json = json.loads(payload)

### Construct a request to the Reddit API

In [None]:
import praw

reddit = praw.Reddit(client_id=reddit_key_json["client_id"], 
                     client_secret=reddit_key_json["secret"],
                     user_agent=reddit_key_json["user_agent"])
print(f'Reddit is in read-only mode: {reddit.read_only}')

In [None]:
import numpy as np
import pandas as pd

nan_value = float("NaN")
sciatica_sub = "sciatica"

In [None]:
posts = reddit.subreddit(sciatica_sub).hot(limit=100)
filtered_posts = [[s.title, s.selftext, s.id] for s in posts]

filtered_posts = np.array(filtered_posts)
reddit_posts_df = pd.DataFrame(filtered_posts,
                               columns=['Title', 'Posts', 'ID'])

# Drop all the rows with empty values
reddit_posts_df.replace("", nan_value, inplace=True)
reddit_posts_df = reddit_posts_df[reddit_posts_df.Posts != nan_value]


# Print 
reddit_posts_df.head(10)

print(reddit_posts_df.iloc[8]['Title'])

In [None]:
from typing import NamedTuple
from google.cloud import secretmanager
import json

def get_google_cloud_credentials():
    from google import auth
    creds, project = auth.default()

    LocalCredentials = NamedTuple("LocalCredentials",
    [
        ("creds", str),
        ("project", str),
    ])
    return LocalCredentials(creds, project)

local_creds = get_google_cloud_credentials()

client = secretmanager.SecretManagerServiceClient(credentials=local_creds.creds)

secret_resource_name = f"projects/{local_creds.project}/secrets/reddit-api-key/versions/1"
response = client.access_secret_version(request={"name": secret_resource_name})
payload = response.payload.data.decode("UTF-8")

print(json.loads(payload))

## Troubleshoot Firestore component code

```json
{
    'imageWidth': 4620, 
    'imageHeight': 2940, 
    'cellOffsetX': 0, 
    'cellOffsetY': 0, 
    'cellWidth': 140, 
    'cellHeight': 140, 
    'path': 'Abandoned Mine Entrance [33x21] - $5 Rewards/Gridded/G_AbandonedMineEntrance_Crystal.jpg'
}

```

In [220]:
project_id = PROJECT_ID
collection_name = "FantasyMapsTest"
gcs_bucket_name = "fantasy-maps"
gcs_prefix_name = "ScrapedData"
csv_input_file = "ScrapedData/reddit-scraped-20220404223231.csv"

from datetime import datetime
import hashlib
from io import BytesIO
import json
import pandas as pd
from PIL import Image
import regex as re
import requests
import shutil

from google.cloud import firestore
from google.cloud import storage

storage_client = storage.Client(project=project_id)
bucket = storage_client.bucket(gcs_bucket_name)

firestore_client = firestore.Client(project=project_id)
collection_ref = firestore_client.collection(collection_name)

blob = bucket.blob(csv_input_file)
csv_bytes = blob.download_as_string()
csv_buffer = BytesIO(csv_bytes)

jpg_df = pd.read_csv(csv_buffer)

hashes = [None] * len(jpg_df.index)
jpg_df.insert(1, "HashId", hashes, True)
jpg_df.insert(6, "GcsURI", hashes, True)

# Concatenate string of batch prediction inputs
bp_inputs = ""

def make_nice_filename(name):
    regex = "[\s|\(|\"|\)]"
    new_name = re.sub(regex, "_", name)
    new_name = new_name.lower()[:30]
    new_name = new_name.replace("__", "_")
    return f"{new_name}.jpg"


def create_vtt_json(content, title):
    img = Image.open(BytesIO(content))
    w, h = img.size
    
    dims = re.findall("\d+x\d+", title)
    if len(dims) is 0:
        return None
    
    dims = dims[0].split("x")
    
    if len(dims) is not 2:
        return None
    
    rows = int(dims[0])
    cols = int(dims[1])
    
    cell_w = w / rows
    cell_h = h / cols
    if cell_w != cell_h:
        return None
    
    return {
        "imageWidth": w,
        "imageHeight": h,
        "cellOffsetX": 0,
        'cellOffsetY': 0, 
        'cellWidth': cell_w, 
        'cellHeight': cell_h, 
    }
    

# Iterate over JPG URIs, download them in batches, convert to sha values
for i, r in jpg_df.iterrows():
    jpg_url = r["URL"]
    title = r["Title"]
    
    req = requests.get(jpg_url, stream=True)
    if req.status_code == 200:
        req.raw.decode_content = True
        sha1 = hashlib.sha1()
        jpg_hash = sha1.update(req.content)
        jpg_hash = sha1.hexdigest()
        
        jpg_df["HashId"][i] = jpg_hash
        #print(f"Index {i}, hash {jpg_hash}")
        hashes.append(jpg_hash)
        
        # Try to fetch each document from Firestore. If it does not exist,
        # overwrite and download the image.
        doc_ref = collection_ref.document(jpg_hash)
        doc = doc_ref.get()
        if not doc.exists:
            
            file_name = make_nice_filename(title)
            img_gcs_uri = f"gs://{gcs_bucket_name}/{gcs_prefix_name}/{file_name}"
            blob_name = f"{gcs_prefix_name}/{file_name}"
            
            file_blob = bucket.blob(blob_name)
            image_buffer = BytesIO(req.content)
            
            # Get image grid metadata
            img_data = create_vtt_json(req.content, title)
            print(img_data)
            
            file_blob.upload_from_file(BytesIO(req.content))
            
            data = {
                u"filename": file_name,
                u"gcsURI": img_gcs_uri,
                u"source": gcs_prefix_name,
                u"vttData": img_data,
                u"userId": "None",
            }
            doc_ref.set(data)
            print(f"Set data: {data}")
            bp_inputs += json.dumps({ "content": img_gcs_uri, "mimeType": "image/jpeg"})
            bp_inputs += "\n"

# No fresh JPGs in this scraping; return empty string
if bp_inputs is "":
    # return ""
    print("no inputs")
            
print(f"First ten: {jpg_df.head(10)}")

# Save the batch_predict file
timestamp = datetime.now().strftime("%Y%m%d%H%M%S") 
batch_predict_file_uri = f"gs://{gcs_bucket_name}/{gcs_prefix_name}/bp_input_{timestamp}.jsonl"

bp_blob_name = f"{gcs_prefix_name}/bp_input_{timestamp}.jsonl"
bp_blob = bucket.blob(bp_blob_name)

bp_blob.upload_from_string(bp_inputs)

print(batch_predict_file_uri)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


{'imageWidth': 1610, 'imageHeight': 4060, 'cellOffsetX': 0, 'cellOffsetY': 0, 'cellWidth': 70.0, 'cellHeight': 70.0}
Set data: {'filename': 'feywild_waterfalls_+_50%_disco.jpg', 'gcsURI': 'gs://fantasy-maps/ScrapedData/feywild_waterfalls_+_50%_disco.jpg', 'source': 'ScrapedData', 'vttData': {'imageWidth': 1610, 'imageHeight': 4060, 'cellOffsetX': 0, 'cellOffsetY': 0, 'cellWidth': 70.0, 'cellHeight': 70.0}, 'userId': 'None'}
None
Set data: {'filename': 'neon_alley_-_an_experiment_usi.jpg', 'gcsURI': 'gs://fantasy-maps/ScrapedData/neon_alley_-_an_experiment_usi.jpg', 'source': 'ScrapedData', 'vttData': None, 'userId': 'None'}
{'imageWidth': 3360, 'imageHeight': 3360, 'cellOffsetX': 0, 'cellOffsetY': 0, 'cellWidth': 70.0, 'cellHeight': 70.0}
Set data: {'filename': 'a_small_castle_in_a_forest_[48.jpg', 'gcsURI': 'gs://fantasy-maps/ScrapedData/a_small_castle_in_a_forest_[48.jpg', 'source': 'ScrapedData', 'vttData': {'imageWidth': 3360, 'imageHeight': 3360, 'cellOffsetX': 0, 'cellOffsetY': 0

KeyboardInterrupt: 

## Troubleshoot BP output to FS component

In [212]:
! pip install jsonlines



In [125]:
bp_resource = "projects/733537716875/locations/us-central1/batchPredictionJobs/3414681244372303872"
collection_name = "FantasyMapsTest"
gcs_bucket_name = "fantasy-maps"
project = "fantasymaps-334622"
location = "us-central1"
minimum_confidence = 0.15

In [221]:
def save_bp_output_to_firestore(
    bp_resource=bp_resource,
    collection_name=collection_name,
    gcs_bucket_name=gcs_bucket_name,
    project=project,
    location=location,
    minimum_confidence=minimum_confidence):
    
    import json
    
    from google.cloud import aiplatform as aip
    from google.cloud import firestore as fs
    
    aip.init(project=project, location=location)
    
    bp_job = aip.BatchPredictionJob(
        batch_prediction_job_name=bp_resource)
    
    output_info = bp_job.output_info
    
    # Get the predictions out of GCS
    predictions = []
    for out in bp_job.iter_outputs():
        out_str = out.download_as_string()
        p = out_str.decode("utf-8")
        
        ps = p.split("\n")
        predictions.extend(ps)

    if len(predictions) is 0:
        return
    
    fs_client = fs.Client()
    collection_ref = fs_client.collection(collection_name)
    
    docs = []
    prediction_data = dict()
    
    # Query Firestore for all documents relevant to these predictions
    for p in predictions:
        try:
            data = json.loads(p)
            instance = data["instance"]["content"]
            prediction_data[instance] = data
            docs_ref = collection_ref.where("gcsURI", "==", instance).stream()
            
            docs_tmp = [doc for doc in docs_ref]
            docs.extend(docs_tmp)
            
        except json.JSONDecodeError as e:
            print(p)
    
    print(f"Images processed: {len(docs)}")
        
    # Update all of the Firestore documents with the predictions
    for d in docs:
        doc_dict = d.to_dict()
        gcsURI = doc_dict["gcsURI"]
        doc_predictions = prediction_data[gcsURI]
        
        # Iterate over bboxes and labels to create
        # training-ready data
        bboxes = doc_predictions["prediction"]["bboxes"]
        labels = doc_predictions["prediction"]["displayNames"]
        confidences = doc_predictions["prediction"]["confidences"]
        
        training_data = []
        
        for i, e in enumerate(bboxes):
            confidence = confidences[i]
            if confidence >= minimum_confidence:
                training_data.append({
                    "displayName": labels[i],
                    "xMin": e[0],
                    "xMax": e[1],
                    "yMin": e[2],
                    "yMax": e[3],
                })
        
        # If training_data is empty for this image, skip
        if len(training_data) is 0:
            continue
            
        d.reference.set({"predictedBBoxes": training_data}, merge=True)    

In [144]:
save_bp_output_to_firestore()


65
gs://fantasy-maps/ScrapedData/the_cat_sanctury_needs_defendi.jpg
gs://fantasy-maps/ScrapedData/fae_village_-_mirage_[30x50]_[.jpg
gs://fantasy-maps/ScrapedData/snowy_field.jpg
gs://fantasy-maps/ScrapedData/the_citadel_of_ash_[54x88].jpg
gs://fantasy-maps/ScrapedData/desert_ruins_[battlemap][2304x.jpg
gs://fantasy-maps/ScrapedData/the_rift.jpg
gs://fantasy-maps/ScrapedData/winter_battle.jpg
gs://fantasy-maps/ScrapedData/[1960x2940][28x42]_into_the_mi.jpg
gs://fantasy-maps/ScrapedData/{the_road_to_bellshire}_-enjoy.jpg
gs://fantasy-maps/ScrapedData/old_school_battle_map,_stock_d.jpg
gs://fantasy-maps/ScrapedData/sith_temple.jpg
gs://fantasy-maps/ScrapedData/[oc]_desert_monorail_heist_bat.jpg
gs://fantasy-maps/ScrapedData/3rd_time_lucky...._hi_red_spac.jpg
gs://fantasy-maps/ScrapedData/the_broken_bridge_[51x62]_[cav.jpg
gs://fantasy-maps/ScrapedData/defend_the_wall_!.jpg
gs://fantasy-maps/ScrapedData/hillside_expedition_store_[13x.jpg
gs://fantasy-maps/ScrapedData/the_forgotten_temple

## Create a custom Reddit pipelines component

In [180]:
from typing import NamedTuple

import kfp
from kfp import dsl
from kfp.v2 import compiler
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, ClassificationMetrics, Metrics, component)
from kfp.v2.google.client import AIPlatformClient

from google.cloud import aiplatform
from google_cloud_pipeline_components import aiplatform as gcc_aip

In [192]:
@component(packages_to_install=["praw",
                                "google-cloud-secret-manager",
                                "google-cloud-storage",
                                "numpy",
                                "pandas"],
           output_component_file="reddit.yaml")
def reddit(
    secret_name: str,
    subreddit_name: str,
    gcs_bucket_name: str,
    gcs_prefix_name: str,
    project_id: str,
) -> str:
    from datetime import datetime
    import numpy as np
    import pandas as pd
    import praw
    import regex as re
    
    from google.cloud import storage

    def get_reddit_credentials(project_id):
        from google.cloud import secretmanager
        import json

        client = secretmanager.SecretManagerServiceClient()

        secret_resource_name = f"projects/{project_id}/secrets/{secret_name}/versions/1"
        response = client.access_secret_version(request={"name": secret_resource_name})
        payload = response.payload.data.decode("UTF-8")

        return json.loads(payload)
    
    def get_reddit_posts(reddit_credentials):
        import praw

        reddit = praw.Reddit(client_id=reddit_credentials["client_id"], 
                     client_secret=reddit_credentials["secret"],
                     user_agent=reddit_credentials["user_agent"])
        print(f"Reddit is in read-only mode: {reddit.read_only}")
        return reddit.subreddit(subreddit_name).hot(limit=100)
    
    nan_value = float("NaN")
    
    print(f"Project ID is: {project_id}")
    
    # Get the data from Reddit
    credentials = get_reddit_credentials(project_id)
    posts = get_reddit_posts(credentials)
    
    posts = filter(lambda p: len(re.findall("\d+x\d+", p.title)) > 0, posts)
    
    # Filter the posts the data that we want and store as DataFrame
    filtered_posts = [[s.title, s.selftext, s.id, s.url] for s in posts]

    filtered_posts = np.array(filtered_posts)
    reddit_posts_df = pd.DataFrame(filtered_posts,
                               columns=['Title', 'Post', 'ID', 'URL'])

    reddit_posts_df.replace("", nan_value, inplace=True)
    reddit_posts_df = reddit_posts_df[reddit_posts_df["Post"] != nan_value]
    
    jpg_df = reddit_posts_df.loc[reddit_posts_df["URL"].str.contains("jpg")]
    jpg_df.head(10)
    
    # Save the dataframe as CSV in Storage
    csv_str = jpg_df.to_csv()
    
    storage_client = storage.Client(project=project_id)
    bucket = storage_client.bucket(gcs_bucket_name)
    
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    
    csv_file_uri = f"{gcs_prefix_name}/reddit-scraped-{subreddit_name}-{timestamp}.csv"
    
    file_blob = bucket.blob(csv_file_uri)
    file_blob.upload_from_string(csv_str)
    
    return csv_file_uri
    

## Create the Firestore component

In [193]:
from typing import NamedTuple

@component(packages_to_install=["google-cloud-firestore",
                                "google-cloud-storage",
                                "numpy",
                                "pandas",
                                "Pillow"],
           output_component_file="firestore.yaml")
def firestore(
    subreddit_name: str,
    collection_name: str,
    gcs_bucket_name: str,
    gcs_prefix_name: str,
    csv_input_file: str,
    project_id: str,
) -> NamedTuple(
    "Outputs",
    [
        ("batch_predict_file_uri", str),
        ("bp_inputs_count", int),
    ]
):
    
    from datetime import datetime
    import hashlib
    from io import BytesIO
    import json
    import pandas as pd
    from PIL import Image
    import re
    import requests
    import shutil

    from google.cloud import firestore
    from google.cloud import storage

    storage_client = storage.Client(project=project_id)
    bucket = storage_client.bucket(gcs_bucket_name)

    firestore_client = firestore.Client(project=project_id)
    collection_ref = firestore_client.collection(collection_name)

    blob = bucket.blob(csv_input_file)
    csv_bytes = blob.download_as_string()
    csv_buffer = BytesIO(csv_bytes)

    jpg_df = pd.read_csv(csv_buffer)

    hashes = [None] * len(jpg_df.index)
    jpg_df.insert(1, "HashId", hashes, True)
    jpg_df.insert(6, "GcsURI", hashes, True)

    # Concatenate string of batch prediction inputs
    bp_inputs = ""
    bp_inputs_count = 0

    def make_nice_filename(name):
        regex = "[\s|\(|\"|\)]"
        new_name = re.sub(regex, "_", name)
        new_name = new_name.lower()[:30]
        new_name = new_name.replace("__", "_")
        return f"{new_name}.jpg"

    # Iterate over JPG URIs, download them in batches, convert to sha values
    for i, r in jpg_df.iterrows():
        jpg_url = r["URL"]
        title = r["Title"]

        req = requests.get(jpg_url, stream=True)
        if req.status_code == 200:
            req.raw.decode_content = True
            sha1 = hashlib.sha1()
            jpg_hash = sha1.update(req.content)
            jpg_hash = sha1.hexdigest()

            jpg_df["HashId"][i] = jpg_hash
            #print(f"Index {i}, hash {jpg_hash}")
            hashes.append(jpg_hash)

            # Try to fetch each document from Firestore. If it does not exist,
            # overwrite and download the image.
            doc_ref = collection_ref.document(jpg_hash)
            doc = doc_ref.get()
            if not doc.exists:

                file_name = make_nice_filename(title)
                img_gcs_uri = f"gs://{gcs_bucket_name}/{gcs_prefix_name}/{file_name}"
                blob_name = f"{gcs_prefix_name}/{file_name}"

                file_blob = bucket.blob(blob_name)
                image_buffer = BytesIO(req.content)
                
                # Get the image size and width
                img = Image.open(image_buffer)
                width, height = img.size
                
                file_blob.upload_from_file(image_buffer)

                data = {
                    u"filename": file_name,
                    u"gcsURI": img_gcs_uri,
                    u"source": gcs_prefix_name,
                    u"userId": "None",
                }
                doc_ref.set(data)
                print(f"Set data: {data}")
                bp_inputs += json.dumps({ "content": img_gcs_uri, "mimeType": "image/jpeg"})
                bp_inputs += "\n"
                bp_inputs_count += 1

    print(f"bp_inputs_count={bp_inputs_count}")

    # Save the batch_predict file
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S") 

    bp_blob_name = f"{gcs_prefix_name}/bp_input_{subreddit_name}_{timestamp}.jsonl"
    batch_predict_file_uri = f"gs://{gcs_bucket_name}/{bp_blob_name}"
    bp_blob = bucket.blob(bp_blob_name)

    bp_blob.upload_from_string(bp_inputs)
    
    return (batch_predict_file_uri, bp_inputs_count)
    

## Create a (better) batch prediction component

In [194]:
from typing import List

@component(packages_to_install=["google-cloud-aiplatform"],
           output_component_file="batch_prediction.yaml")
def batch_prediction(
    gcs_bucket_name: str,
    gcs_prefix_name: str,
    input_file_1: str,
    input_file_2: str,
    project_id: str,
    location: str,
    model_id: str,
) -> str:
    
    from google.cloud import aiplatform as aip
    from datetime import datetime

    csv_input_files = [input_file_1, input_file_2]
    
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    aip.init(project=project_id, location=location)
    
    model_resource_name = f"projects/{project_id}/locations/{location}/models/{model_id}"
    model = aip.Model(model_resource_name)
    
    batch_prediction_job = model.batch_predict(
        job_display_name=f"reddit-scraping-batch-predict-{timestamp}",
        gcs_source=csv_input_files,
        gcs_destination_prefix=f"gs://{gcs_bucket_name}/{gcs_prefix_name}",
        sync=True
    )
    
    batch_prediction_job.wait()

    print(batch_prediction_job.display_name)
    print(batch_prediction_job.resource_name)
    print(batch_prediction_job.state)
    return str(batch_prediction_job.resource_name)

## Create a prediction-to-Firestore component

In [195]:
@component(packages_to_install=["google-cloud-aiplatform", "google-cloud-firestore"],
           output_component_file="prediction_to_firestore.yaml")
def save_bp_output_to_firestore(
    bp_resource: str,
    collection_name: str,
    gcs_bucket_name: str,
    project: str,
    location: str,
    minimum_confidence: float):
    
    import json
    
    from google.cloud import aiplatform as aip
    from google.cloud import firestore as fs
    
    aip.init(project=project, location=location)
    
    bp_job = aip.BatchPredictionJob(
        batch_prediction_job_name=bp_resource)
    
    output_info = bp_job.output_info
    
    # Get the predictions out of GCS
    predictions = []
    for out in bp_job.iter_outputs():
        out_str = out.download_as_string()
        p = out_str.decode("utf-8")
        
        ps = p.split("\n")
        predictions.extend(ps)

    if len(predictions) is 0:
        return
    
    fs_client = fs.Client(project=project)
    collection_ref = fs_client.collection(collection_name)
    
    docs = []
    prediction_data = dict()
    
    # Query Firestore for all documents relevant to these predictions
    for p in predictions:
        try:
            data = json.loads(p)
            instance = data["instance"]["content"]
            prediction_data[instance] = data
            docs_ref = collection_ref.where("gcsURI", "==", instance).stream()
            
            docs_tmp = [doc for doc in docs_ref]
            docs.extend(docs_tmp)
            
        except json.JSONDecodeError as e:
            print(p)
    
    print(f"Images processed: {len(docs)}")
        
    # Update all of the Firestore documents with the predictions
    for d in docs:
        doc_dict = d.to_dict()
        gcsURI = doc_dict["gcsURI"]
        doc_predictions = prediction_data[gcsURI]
        
        # Iterate over bboxes and labels to create
        # training-ready data
        bboxes = doc_predictions["prediction"]["bboxes"]
        labels = doc_predictions["prediction"]["displayNames"]
        confidences = doc_predictions["prediction"]["confidences"]
        
        training_data = []
        
        for i, e in enumerate(bboxes):
            confidence = confidences[i]
            if confidence >= minimum_confidence:
                training_data.append({
                    "displayName": labels[i],
                    "xMin": e[0],
                    "xMax": e[1],
                    "yMin": e[2],
                    "yMax": e[3],
                })
        
        # If training_data is empty for this image, skip
        if len(training_data) is 0:
            continue
            
        d.reference.set({"predictedBBoxes": training_data}, merge=True) 

## Build a simple pipeline

In [201]:
GCS_BUCKET = "fantasy-maps"
GCS_PREFIX = "ScrapedData"
MODEL_ID = "7292897899317297152"
LOCATION = "us-central1"
COLLECTION_NAME = "FantasyMaps"
print(PROJECT_ID)    

fantasymaps-334622


In [None]:
# Clear out the test collection before continuing

# NOTE: DOES NOT WORK :S
if COLLECTION_NAME.find("Test") > -1:
    
    print(f"Deleting {COLLECTION_NAME} ...")
    from google.cloud import firestore
    
    client = firestore.Client(project=PROJECT_ID)
    coll_ref = client.collection(COLLECTION_NAME)
    client.recursive_delete(coll_ref)
    
    #client.collection(COLLECTION_NAME)

In [205]:
@dsl.pipeline(
    name="reddit-scraper-pipeline",
    description="Gets data from a subreddit",
    pipeline_root=f"gs://{GCS_BUCKET}/pipeline_root",
)
def reddit_pipeline(
    collection_name: str = COLLECTION_NAME,
    secret_name: str = "reddit-api-key",
    subreddit_name_1: str = "battlemaps",
    subreddit_name_2: str = "FantasyMaps",
    gcs_bucket: str = GCS_BUCKET,
    gcs_prefix: str = GCS_PREFIX,
    project_id: str = PROJECT_ID,
    location: str = LOCATION,
    model_id: str = MODEL_ID,
):
    
    # First stream of Reddit scraping
    reddit_op_1 = reddit(
        secret_name=secret_name,
        subreddit_name=subreddit_name_1,
        gcs_bucket_name=gcs_bucket,
        gcs_prefix_name=gcs_prefix,
        project_id=project_id
    )
    
    reddit_csv_file_1 = reddit_op_1.output
    
    firestore_op_1 = firestore(
        subreddit_name=subreddit_name_1,
        collection_name=collection_name,
        gcs_bucket_name=gcs_bucket,
        gcs_prefix_name=gcs_prefix,
        csv_input_file=reddit_csv_file_1,
        project_id=project_id,
    )
    
    # Second stream of Reddit scraping
    reddit_op_2 = reddit(
        secret_name=secret_name,
        subreddit_name=subreddit_name_2,
        gcs_bucket_name=gcs_bucket,
        gcs_prefix_name=gcs_prefix,
        project_id=project_id
    )

    reddit_csv_file_2 = reddit_op_2.output
    
    firestore_op_2 = firestore(
        subreddit_name=subreddit_name_2,
        collection_name=collection_name,
        gcs_bucket_name=gcs_bucket,
        gcs_prefix_name=gcs_prefix,
        csv_input_file=reddit_csv_file_2,
        project_id=project_id,
    )
  
    #inputs_count = firestore_op_1.outputs["bp_inputs_count"] + firestore_op_2.outputs["bp_inputs_count"] 

    with dsl.Condition((firestore_op_1.outputs["bp_inputs_count"] > 0)
                      or (firestore_op_2.outputs["bp_inputs_count"] > 0), name="hasBPInputs"):
        batch_prediction_op = batch_prediction(
            gcs_bucket_name=gcs_bucket,
            gcs_prefix_name=gcs_prefix,
            input_file_1=firestore_op_1.outputs["batch_predict_file_uri"],
            input_file_2=firestore_op_2.outputs["batch_predict_file_uri"],
            project_id=project_id,
            location=location,
            model_id=model_id)

        # Set minimum_confidence as global--should be configurable
        min_confidence = 0.15

        # Update Firestore with BP results
        save_bp_output_to_firestore(
            bp_resource=batch_prediction_op.output,
            collection_name=collection_name,
            gcs_bucket_name=gcs_bucket,
            project=project_id,
            location=location,
            minimum_confidence=min_confidence
        )
    

In [206]:
compiler.Compiler().compile(
    pipeline_func=reddit_pipeline, package_path="reddit_scraper_pipeline_job.json"
)

In [207]:
api_client = AIPlatformClient(
    project_id=PROJECT_ID,
    region=LOCATION,
)

In [208]:
response = api_client.create_run_from_job_spec(
    job_spec_path="reddit_scraper_pipeline_job.json",
)