In [5]:
import tensorflow as tf
import apache_beam as beam
import json
from oauth2client.client import GoogleCredentials
import requests

In [2]:
class ModelPredict:
    """Callable that requests online predictions for image filenames.

    An instance can be used directly as the function of ``beam.FlatMap``: it
    accepts either one filename or a batch of filenames, POSTs them to the
    ``aiplatform.googleapis.com`` ``:predict`` endpoint and yields one result
    per filename.
    """

    def __init__(self, project, region, endpoint_id, timeout=60):
        """Build the prediction URL.

        Args:
            project: Project segment of the endpoint URL.
            region: Region, used for both the API hostname and the
                ``locations/`` segment of the URL.
            endpoint_id: Endpoint segment of the endpoint URL.
            timeout: Seconds to wait for the HTTP request before giving up
                (``None`` waits forever, which is what ``requests`` does when
                no timeout is passed).
        """
        self._api = (
            "https://{region}-aiplatform.googleapis.com/v1/projects/{project}"
            "/locations/{region}/endpoints/{endpoint_id}:predict"
        ).format(region=region, project=project, endpoint_id=endpoint_id)
        self._timeout = timeout

    def __call__(self, filenames):
        """Predict for one filename or an iterable of filenames.

        Args:
            filenames: A single filename (``str``) or an iterable of them.

        Yields:
            Each entry of the response's ``"predictions"`` list, with a
            ``"filename"`` key added holding the corresponding input.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
            KeyError: If the response body has no ``"predictions"`` entry.
            ValueError: If the number of predictions differs from the number
                of filenames sent.
        """
        # A bare string is a batch of one. Anything else is materialized once,
        # because it is needed twice (request body and result pairing) and a
        # one-shot iterator would be exhausted after the first pass.
        if isinstance(filenames, str):
            batch = [filenames]
        else:
            batch = list(filenames)

        # Nothing to predict: skip the token fetch and the network round trip.
        if not batch:
            return

        token = GoogleCredentials.get_application_default().get_access_token().access_token
        response = requests.post(
            self._api,
            json={"instances": [{"filenames": name} for name in batch]},
            headers={"Authorization": "Bearer " + token},
            timeout=self._timeout,
        )
        # Surface HTTP failures as such instead of a confusing KeyError below.
        response.raise_for_status()
        predictions = response.json()["predictions"]

        # zip() would silently drop inputs if the counts disagreed.
        if len(predictions) != len(batch):
            raise ValueError(
                "Expected {} predictions but received {}".format(
                    len(batch), len(predictions)
                )
            )

        for name, prediction in zip(batch, predictions):
            prediction["filename"] = name
            yield prediction

In [3]:
# Coordinates of the aiplatform.googleapis.com prediction endpoint; these are
# passed to ModelPredict when the pipeline is built.
PROJECT = "kubeflow-1-0-2"  # project segment of the endpoint URL
REGION = "us-central1"  # used for both the API hostname and the locations/ segment
ENDPOINT_ID = "2613644692339818496"  # endpoint segment of the endpoint URL

In [6]:
# List every object under the test_images/ prefix. The wildcard is not limited
# to images, so the listing also contains non-image objects (e.g. .jsonl files)
# that have to be excluded before prediction.
filenames = tf.io.gfile.glob("gs://fire_detection_anurag/test_images/*")

# Display the listing to check what was matched.
filenames

['gs://fire_detection_anurag/test_images/batchinputs.jsonl',
 'gs://fire_detection_anurag/test_images/fire1.jpg',
 'gs://fire_detection_anurag/test_images/fire2.jpg',
 'gs://fire_detection_anurag/test_images/fire3.jpg',
 'gs://fire_detection_anurag/test_images/fire4.jpg',
 'gs://fire_detection_anurag/test_images/fire5.jpg',
 'gs://fire_detection_anurag/test_images/no_fire1.jpg',
 'gs://fire_detection_anurag/test_images/no_fire2.jpg',
 'gs://fire_detection_anurag/test_images/no_fire3.jpg',
 'gs://fire_detection_anurag/test_images/no_fire4.jpg',
 'gs://fire_detection_anurag/test_images/no_fire5.jpg',
 'gs://fire_detection_anurag/test_images/prediction-fire20211105140403-2021-11-05T15:04:29.256648Z',
 'gs://fire_detection_anurag/test_images/test.jsonl']

In [7]:
# Preview the prediction inputs. Select by file extension rather than by list
# position: a positional slice depends on the exact set and sort order of the
# objects in the bucket and silently picks the wrong files when that changes.
[f for f in filenames if f.lower().endswith((".jpg", ".jpeg"))]

['gs://fire_detection_anurag/test_images/fire1.jpg',
 'gs://fire_detection_anurag/test_images/fire2.jpg',
 'gs://fire_detection_anurag/test_images/fire3.jpg',
 'gs://fire_detection_anurag/test_images/fire4.jpg',
 'gs://fire_detection_anurag/test_images/fire5.jpg',
 'gs://fire_detection_anurag/test_images/no_fire1.jpg',
 'gs://fire_detection_anurag/test_images/no_fire2.jpg',
 'gs://fire_detection_anurag/test_images/no_fire3.jpg',
 'gs://fire_detection_anurag/test_images/no_fire4.jpg',
 'gs://fire_detection_anurag/test_images/no_fire5.jpg']

In [8]:
# Keep only the image objects. The bucket listing also contains .jsonl files
# and a batch-prediction output entry; selecting by extension (instead of a
# positional slice) keeps working when objects are added to or removed from
# the bucket.
image_files = [f for f in filenames if f.lower().endswith((".jpg", ".jpeg"))]

with beam.Pipeline() as p:
    (p
     | "getinput" >> beam.Create(image_files)
     # Group filenames so that each HTTP request carries 2-3 instances.
     | "batch" >> beam.BatchElements(min_batch_size=2,
                                     max_batch_size=3)
     # ModelPredict yields one result per filename; FlatMap flattens the
     # per-batch results back into individual elements.
     | "getpred" >> beam.FlatMap(ModelPredict(PROJECT,
                                              REGION,
                                              ENDPOINT_ID))
     | "write" >> beam.Map(print)
    )



{'image_type_str': 'Fire', 'image_type_int': 0, 'probability': 0.994521737, 'filename': 'gs://fire_detection_anurag/test_images/fire1.jpg'}
{'image_type_int': 0, 'image_type_str': 'Fire', 'probability': 0.711659372, 'filename': 'gs://fire_detection_anurag/test_images/fire2.jpg'}
{'probability': 0.987360537, 'image_type_int': 0, 'image_type_str': 'Fire', 'filename': 'gs://fire_detection_anurag/test_images/fire3.jpg'}
{'probability': 0.983968496, 'image_type_int': 0, 'image_type_str': 'Fire', 'filename': 'gs://fire_detection_anurag/test_images/fire4.jpg'}
{'probability': 0.995875299, 'image_type_str': 'Fire', 'image_type_int': 0, 'filename': 'gs://fire_detection_anurag/test_images/fire5.jpg'}
{'probability': 0.827532887, 'image_type_str': 'No-Fire', 'image_type_int': 1, 'filename': 'gs://fire_detection_anurag/test_images/no_fire1.jpg'}
{'image_type_str': 'No-Fire', 'probability': 0.990039051, 'image_type_int': 1, 'filename': 'gs://fire_detection_anurag/test_images/no_fire2.jpg'}
{'probab