In [12]:
from pathlib import Path
import json
import tensorflow as tf
import numpy as np
from google.cloud import storage
from google.cloud import aiplatform

In [7]:
# GCP settings shared by every cell below: the Vertex AI region, the Cloud Storage
# bucket holding batch inputs/outputs, and the owning project.
location = "us-central1"
bucket_name = "awesome_mnist_bucket"
project_id = "nifty-quanta-390607"

In [10]:
# Cloud Storage client scoped to the project, plus a handle to the working bucket.
# get_bucket looks the bucket up through the API (unlike client.bucket, which makes
# no request), so a wrong bucket name fails here rather than at upload time.
storage_client = storage.Client(project=project_id)
bucket = storage_client.get_bucket(bucket_or_name=bucket_name)

In [3]:
# Load MNIST through Keras (downloaded on first use). The cells below only use the
# test split (X_test / y_test) to build the batch request and score the results.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

In [6]:
def upload_directory(bucket, dirpath):
    """Recursively upload every file under ``dirpath`` to ``bucket``.

    Blob names are the file paths relative to the *parent* of ``dirpath``, in
    POSIX form. The directory's own name therefore becomes the top-level prefix,
    e.g. ``my_mnist_batch/my_mnist_batch.jsonl``.

    Args:
        bucket: object exposing ``blob(name)``; each returned blob must support
            ``upload_from_filename(path)`` (e.g. a google-cloud-storage Bucket).
        dirpath: directory to upload (str or path-like).
    """
    root = Path(dirpath)
    name_base = root.parent  # strip everything above the directory itself
    for path in root.rglob("*"):
        if not path.is_file():
            continue  # directories are implied by blob names; nothing to upload
        blob_name = path.relative_to(name_base).as_posix()
        bucket.blob(blob_name).upload_from_filename(path)

In [5]:
# Build the JSON Lines request file: the first 100 test images, one nested-list image
# per line, each given a trailing channel axis and scaled by 1/255.
batch_path = Path("my_mnist_batch")
batch_path.mkdir(exist_ok=True)
X_test_100 = X_test[:100][..., np.newaxis] / 255
with (batch_path / "my_mnist_batch.jsonl").open("w") as jsonl_file:
    for image in X_test_100.tolist():
        jsonl_file.write(json.dumps(image) + "\n")

In [11]:
# Copy the local request directory into the bucket (blobs land under my_mnist_batch/).
upload_directory(bucket=bucket, dirpath=batch_path)

In [13]:
# Handle to an existing Vertex AI model, addressed by its full resource name.
model_id = "4606893247241912320"
mnist_model = aiplatform.Model(
    model_name=(
        f"projects/{project_id}"
        f"/locations/{location}"
        f"/models/{model_id}"
    )
)

### Batch prediction with the model (a single job instead of many online-prediction API calls)

In [14]:
# Launch one Vertex AI batch prediction job over the uploaded JSONL request file.
batch_prediction_job = mnist_model.batch_predict(
    job_display_name="my_batch_prediction_job",
    # I/O: the JSONL request file in, results written under my_mnist_predictions/.
    gcs_source=[f"gs://{bucket_name}/{batch_path.name}/my_mnist_batch.jsonl"],
    gcs_destination_prefix=f"gs://{bucket_name}/my_mnist_predictions/",
    # Compute: a single n1-standard-4 replica with one NVIDIA Tesla P4 GPU.
    machine_type="n1-standard-4",
    accelerator_type="NVIDIA_TESLA_P4",
    accelerator_count=1,
    starting_replica_count=1,
    sync=True,  # wait for the job to finish; pass False to return immediately
)

Creating BatchPredictionJob
BatchPredictionJob created. Resource name: projects/83375718718/locations/us-central1/batchPredictionJobs/929614094619639808
To use this BatchPredictionJob in another session:
bpj = aiplatform.BatchPredictionJob('projects/83375718718/locations/us-central1/batchPredictionJobs/929614094619639808')
View Batch Prediction Job:
https://console.cloud.google.com/ai/platform/locations/us-central1/batch-predictions/929614094619639808?project=83375718718
BatchPredictionJob projects/83375718718/locations/us-central1/batchPredictionJobs/929614094619639808 current state:
JobState.JOB_STATE_PENDING
BatchPredictionJob projects/83375718718/locations/us-central1/batchPredictionJobs/929614094619639808 current state:
JobState.JOB_STATE_PENDING
BatchPredictionJob projects/83375718718/locations/us-central1/batchPredictionJobs/929614094619639808 current state:
JobState.JOB_STATE_PENDING
BatchPredictionJob projects/83375718718/locations/us-central1/batchPredictionJobs/9296140946196

In [15]:
def _pixel_key(image):
    """Hashable key for an image whose pixels are k/255; tolerant of float round-off."""
    return np.rint(np.asarray(image) * 255).astype(np.uint8).tobytes()


# Vertex AI does not promise that output lines follow input order, and results may be
# split across several `prediction.results-*` shards. Each output record echoes its
# input under "instance", so map every prediction back to its position in X_test_100.
# The next cell compares predictions to labels by position.
input_positions = {}
for i, image in enumerate(X_test_100):
    input_positions.setdefault(_pixel_key(image), []).append(i)

y_probas = [None] * len(X_test_100)
for blob in batch_prediction_job.iter_outputs():
    if "prediction.results" in blob.name:  # skip error/stats files
        for line in blob.download_as_text().splitlines():
            record = json.loads(line)
            # KeyError/IndexError here means the job returned an instance we never sent.
            position = input_positions[_pixel_key(record["instance"])].pop(0)
            y_probas[position] = record["prediction"]

missing = [i for i, proba in enumerate(y_probas) if proba is None]
assert not missing, f"no prediction returned for inputs {missing}"

In [16]:
# Predicted class per image, scored against the labels of the images that were sent.
y_pred = np.argmax(y_probas, axis=1)
y_true = y_test[:len(X_test_100)]
# Check up front that one prediction came back per image. Otherwise a mismatch shows
# up as a confusing comparison error (or a silent 0.0 score on older NumPy).
assert len(y_pred) == len(y_true), f"expected {len(y_true)} predictions, got {len(y_pred)}"
accuracy = np.mean(y_pred == y_true)
accuracy

1.0

### Deleting resources to avoid further billing

In [17]:
# Delete every object under this project's prefixes. (my_mnist_model/ is not written by
# the cells shown; presumably it holds the model exported in an earlier step.)
for prefix in ["my_mnist_model/", "my_mnist_batch/", "my_mnist_predictions/"]:
    for blob in bucket.list_blobs(prefix=prefix):
        blob.delete()
# Cloud Storage refuses to delete a bucket that still holds objects (it raises a
# Conflict error), which would abort this cell before the job is deleted. Only
# delete the bucket when nothing else is left in it.
if next(iter(bucket.list_blobs(max_results=1)), None) is None:
    bucket.delete()
else:
    print(f"Bucket {bucket_name} still contains other objects; not deleting it.")
batch_prediction_job.delete()

Deleting BatchPredictionJob : projects/83375718718/locations/us-central1/batchPredictionJobs/929614094619639808
BatchPredictionJob deleted. . Resource name: projects/83375718718/locations/us-central1/batchPredictionJobs/929614094619639808
Deleting BatchPredictionJob resource: projects/83375718718/locations/us-central1/batchPredictionJobs/929614094619639808
Delete BatchPredictionJob backing LRO: projects/83375718718/locations/us-central1/operations/5200372412318744576
BatchPredictionJob resource projects/83375718718/locations/us-central1/batchPredictionJobs/929614094619639808 deleted.
