# Creating and training a custom model

In this section of the tutorial, you will:

* Download the data from GCS and process it into a form suitable for training a model
* Create a model that uses BERT as its base
* Train the model on the processed data
* Save the model and upload it to GCS

<table align="left">
    <td>
        <a target="_blank" href="https://colab.research.google.com/github/thushv89/gcp-tf-review-classification/blob/master/training_custom_model.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png"/>Run in Google Colab</a>
    </td>
</table>

In [None]:
!pip3 install --upgrade  pydantic google-cloud-aiplatform google-cloud-storage "shapely<2" tensorflow-text==2.9.0

In [None]:
import os
from typing import Tuple

In [None]:
from google.colab import auth
# Sign in the Colab user so that gcloud and the Google Cloud client libraries
# used in the following cells run with that user's credentials.
auth.authenticate_user()

In [None]:
# Notebook settings. BUCKET_URI and MODEL_BUCKET_URI are empty placeholders and
# must be set to "gs://..." URIs before the later cells are run.
PROJECT_ID = "gdsc-tensorflow-workshop"
# Bucket holding the raw review files (read from its train/ and test/ prefixes)
# and receiving the Vertex AI dataset export.
BUCKET_URI = "" # e.g. gs://imdb-movie-review-dataset-thga
# Bucket that the trained SavedModel directory is copied to at the end.
MODEL_BUCKET_URI = "" # e.g. "gs://imdb-movie-review-models-thga"
REGION = "us-central1"

In [None]:
# Mirror the notebook settings into environment variables so that shell cells
# (e.g. `!gcloud ... $PROJECT_ID`, `!gsutil ... $MODEL_BUCKET_URI`) can expand them.
os.environ.update(
    {
        "PROJECT_ID": PROJECT_ID,
        "REGION": REGION,
        "BUCKET_URI": BUCKET_URI,
        "MODEL_BUCKET_URI": MODEL_BUCKET_URI,
    }
)

In [None]:
!gcloud config set project $PROJECT_ID

In [None]:
from google.cloud import aiplatform

# Display name of the Vertex AI text dataset this notebook looks up.
DATASET_NAME = "imdb-review-dataset"
# Set the default project, region and staging bucket for subsequent aiplatform calls.
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

In [None]:
# Find the full resource name of the text dataset whose display name is
# DATASET_NAME among all text datasets visible to aiplatform.
datasets = aiplatform.TextDataset.list()

DATASET_NAME = "imdb-review-dataset"

for dataset in datasets:
    if dataset.display_name != DATASET_NAME:
        continue
    DATASET_RESOURCE_NAME = dataset.resource_name
    break
else:
    # No dataset matched: leave the resource name empty.
    DATASET_RESOURCE_NAME = ""
print(f"Dataset resource name: {DATASET_RESOURCE_NAME}")

In [None]:
# Open the dataset that was looked up by display name, not whichever text
# dataset happens to come first in the project listing.
if not DATASET_RESOURCE_NAME:
    raise ValueError(f"No Vertex AI text dataset with display name {DATASET_NAME!r} was found")
dataset = aiplatform.TextDataset(DATASET_RESOURCE_NAME)

# export_data creates an exported_data_<datetime> directory under output_dir by
# itself, so only the bucket URI needs to be given.
exported_files = dataset.export_data(output_dir=BUCKET_URI)

print("Following files were exported")
print(exported_files)

In [None]:
import pydantic 
from typing import Any, Dict,List,Literal

class ClassificationAnnotation(pydantic.BaseModel):
    """Sentiment label attached to a single review."""
    # Only the two sentiment classes used in this notebook are accepted.
    displayName: Literal["positive", "negative"]

class DataItemResourceLabels(pydantic.BaseModel):
    """Which split ("training", "validation" or "test") a review belongs to.

    The field's alias is the long label key ``aiplatform.googleapis.com/ml_use``.
    """

    ml_use: Literal["training", "validation", "test"] = pydantic.Field(alias="aiplatform.googleapis.com/ml_use")

    # Enables us to use ml_use=<x> instead of the long alias.
    # ``allow_population_by_field_name`` is the pydantic v1 spelling. Pydantic v2
    # (which the install cell's ``--upgrade pydantic`` pulls in) renamed it to
    # ``populate_by_name`` and only warns about the old key, so without the new
    # key ``DataItemResourceLabels(ml_use=...)`` fails validation. Setting both
    # keeps the model working under either major version.
    class Config:
        allow_population_by_field_name = True
        populate_by_name = True

class TextClassificationSample(pydantic.BaseModel):
    """One labelled review: its text, sentiment label and dataset split."""
    # Raw review text as read from GCS.
    textContent: str
    classificationAnnotation: ClassificationAnnotation
    dataItemResourceLabels: DataItemResourceLabels

In [None]:
import random
from google.cloud import storage 
# Fixed seed so that the random validation/test assignment made in
# create_instances is reproducible across runs.
random.seed(946021)

# TODO: rename to read_from_gcs
def read_gcs_with_full_path(storage_client, bucket_name, blob_name):
    """Return the contents of ``gs://<bucket_name>/<blob_name>`` as a string.

    Args:
        storage_client: ``google.cloud.storage.Client`` used for the request.
        bucket_name: Bucket name without the ``gs://`` scheme.
        blob_name: Full object path inside the bucket.

    Returns:
        The whole object, decoded as UTF-8.
    """
    blob = storage_client.bucket(bucket_name).blob(blob_name)

    # Decode explicitly as UTF-8 (assumes the review files are UTF-8 encoded):
    # in text mode Blob.open otherwise uses the platform's locale encoding,
    # which can mis-decode or reject non-ASCII characters in reviews.
    with blob.open("r", encoding="utf-8") as f:
        return f.read()


def generate_single_instance(bucket_name, blob_name, ml_use, storage_client):
    """Build a labelled sample from one GCS blob, or return None if it is not a review.

    The label comes from the blob's parent directories: a ``.txt`` blob inside
    a ``pos`` directory is "positive", one inside a ``neg`` directory is
    "negative". Everything else is skipped. This includes blobs in other
    directories and index files whose *file name* merely contains
    ``pos``/``neg`` (the standard aclImdb layout ships e.g.
    ``train/urls_pos.txt``).

    Args:
        bucket_name: Bucket name without the ``gs://`` scheme.
        blob_name: Full object path inside the bucket, e.g. ``train/pos/0_9.txt``.
        ml_use: Split to record on the sample: "training", "validation" or "test".
        storage_client: ``google.cloud.storage.Client`` used to read the blob.

    Returns:
        A ``TextClassificationSample``, or ``None`` when the blob is skipped.
    """
    if not blob_name.endswith(".txt"):
        return None

    # Compare whole directory names only. A substring test on the full path
    # would also pick up files such as "train/urls_pos.txt" (a list of URLs)
    # and treat them as reviews.
    directories = blob_name.split("/")[:-1]
    if "pos" in directories:
        label = "positive"
    elif "neg" in directories:
        label = "negative"
    else:
        return None

    return TextClassificationSample(
        textContent=read_gcs_with_full_path(
            storage_client=storage_client,
            bucket_name=bucket_name,
            blob_name=blob_name,
        ),
        classificationAnnotation=ClassificationAnnotation(displayName=label),
        dataItemResourceLabels=DataItemResourceLabels(ml_use=ml_use),
    )

def create_instances(bucket_uri):
    """Read the labelled reviews from a GCS bucket and group them by split.

    Every labelled review under the ``train`` prefix goes to the training
    split. Every labelled review under the ``test`` prefix is assigned, with
    probability 0.5 each, to the validation or the test split. The assignment
    uses the module-level ``random`` generator, so seed it first for a
    reproducible split.

    Args:
        bucket_uri: URI of the bucket, e.g. "gs://my-bucket". A trailing "/"
            is ignored.

    Returns:
        ``{"training": ..., "validation": ..., "test": ...}``, where each value
        is ``{"inputs": [review text, ...], "labels": [1 or 0, ...]}`` with
        1 = positive and 0 = negative.
    """
    storage_client = storage.Client()

    train_gcs_bucket_prefix = "train"
    test_gcs_bucket_prefix = "test"
    # list_blobs expects the bare bucket name: drop the "gs://" scheme and any
    # trailing slash instead of blindly slicing off five characters.
    bucket_name = bucket_uri
    if bucket_name.startswith("gs://"):
        bucket_name = bucket_name[len("gs://"):]
    bucket_name = bucket_name.rstrip("/")

    # No delimiter is passed, so this lists every blob under the prefix,
    # including those in nested "directories" such as train/pos/.
    train_blobs = storage_client.list_blobs(bucket_name, prefix=train_gcs_bucket_prefix)

    print("Reading training data from the GCS bucket")
    train_instances = []
    for b in train_blobs:
        instance = generate_single_instance(
            bucket_name=bucket_name, blob_name=b.name, ml_use="training", storage_client=storage_client
        )
        if instance is not None:
            train_instances.append(instance)
    print(f"\tFound {len(train_instances)} train instances")

    test_blobs = storage_client.list_blobs(bucket_name, prefix=test_gcs_bucket_prefix)
    print("Reading test data from the GCS bucket")
    test_instances = []
    valid_count, test_count = 0, 0
    for b in test_blobs:
        # Draw a random number for every blob, labelled or not, so the split
        # depends only on the seed and the blob listing. Only blobs that turn
        # into instances are counted, so the printed totals match the data.
        ml_use = "validation" if random.uniform(0, 1.0) < 0.5 else "test"
        instance = generate_single_instance(
            bucket_name=bucket_name, blob_name=b.name, ml_use=ml_use, storage_client=storage_client
        )
        if instance is None:
            continue
        test_instances.append(instance)
        if ml_use == "validation":
            valid_count += 1
        else:
            test_count += 1

    print(f"\tFound {valid_count} validation instances and {test_count} test instances")

    label_map = {"positive": 1, "negative": 0}
    datasets = {split: {"inputs": [], "labels": []} for split in ("training", "validation", "test")}
    for ins in train_instances + test_instances:
        split = datasets[ins.dataItemResourceLabels.ml_use]
        split["inputs"].append(ins.textContent)
        split["labels"].append(label_map[ins.classificationAnnotation.displayName])

    return datasets

# Load every labelled review from BUCKET_URI. Each review file is read from GCS
# individually, so this step can take a while.
datasets = create_instances(BUCKET_URI)

In [None]:
# Display the first training review as a quick sanity check of the loaded text.
datasets["training"]["inputs"][0]

## Downloading the base model from TFHub

In [None]:
import tensorflow_hub as hub
import tensorflow as tf 
# tensorflow_text is not used directly, but importing it registers the custom
# TF ops (such as CaseFoldUTF8) that the TF Hub BERT preprocessor relies on.
# Without this import, loading the preprocessor fails with:
# Error Op type not registered 'CaseFoldUTF8' in binary running on 932fd13e3432. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.
# You may be trying to load on a different device from the computational device. Consider setting the `experimental_io_device` option in `tf.saved_model.LoadOptions` to the io_device such as '/job:localhost'.
import tensorflow_text

# Reset Keras' global state (e.g. auto-generated layer names) in case this cell is re-run.
tf.keras.backend.clear_session()

def download_base_model() -> Tuple[hub.KerasLayer, hub.KerasLayer]:
    """Load the pretrained BERT preprocessor and a small BERT encoder from TF Hub.

    Returns:
        A ``(preprocessor, encoder)`` pair of ``hub.KerasLayer`` objects. The
        preprocessor turns raw strings into BERT inputs. The encoder is the
        small uncased BERT (L-4, H-128, A-2) and is loaded with
        ``trainable=False``, so its weights stay frozen during training.
    """
    preprocessor = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
    )
    encoder = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-128_A-2/2",
        trainable=False
    )

    return preprocessor, encoder

## Creating the full TensorFlow model

In [None]:
def create_model(preprocessor: hub.KerasLayer, encoder: hub.KerasLayer) -> tf.keras.Model:
    """Stack a binary sentiment head on top of a pretrained BERT preprocessor and encoder.

    The returned model maps a batch of raw review strings to one sigmoid
    output per review (label 1 = positive). It is already compiled with Adam,
    binary cross-entropy and accuracy.
    """
    # One scalar string per example.
    reviews = tf.keras.layers.Input(shape=(), dtype=tf.string)

    # Tokenize, then encode; "pooled_output" summarises the whole sequence.
    bert_outputs = encoder(preprocessor(reviews))
    sentence_embedding = bert_outputs["pooled_output"]  # [batch_size, 128].

    # Create the hidden layer before the output layer so Keras' auto-generated
    # layer names stay the same.
    hidden = tf.keras.layers.Dense(256, activation="gelu")
    head = tf.keras.layers.Dense(1, activation="sigmoid")
    probability = head(hidden(sentence_embedding))

    classifier = tf.keras.Model(inputs=reviews, outputs=probability)
    classifier.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss="binary_crossentropy",
        metrics="accuracy",
    )
    return classifier

## Create a TensorFlow dataset to train and validate the model

In [None]:
def generate_tf_dataset(datasets: Dict[str, Any], subset: str, batch_size: int = 128, shuffle: bool = False) -> tf.data.Dataset:
    """Create a batched, prefetched tf.data.Dataset of (text, label) pairs for one split.

    Args:
        datasets: Mapping of split name -> {"inputs": [...], "labels": [...]},
            as built by create_instances.
        subset: Split to use: "training", "validation" or "test".
        batch_size: Number of examples per batch.
        shuffle: Whether to shuffle the examples (reshuffled on every epoch).

    Returns:
        A dataset yielding (texts, labels) batches.
    """
    inputs = datasets[subset]["inputs"]
    labels = datasets[subset]["labels"]
    dataset = tf.data.Dataset.from_tensor_slices((inputs, labels))
    if shuffle:
        # Shuffle with a buffer that holds the whole split. The examples arrive
        # grouped by label: GCS lists blobs in lexicographic order, so with a
        # train/neg/, train/pos/ layout every negative review precedes every
        # positive one. A small buffer (e.g. batch_size*10) would leave most
        # batches single-class. max(1, ...) keeps the buffer size valid for an
        # empty split.
        dataset = dataset.shuffle(max(1, len(inputs)))
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

## Train the model

In [None]:
batch_size = 128
epochs = 5

# Build the input pipelines; only the training split needs shuffling.
train_ds = generate_tf_dataset(datasets, "training", batch_size=batch_size, shuffle=True)
valid_ds = generate_tf_dataset(datasets, "validation", batch_size=batch_size)
test_ds = generate_tf_dataset(datasets, "test", batch_size=batch_size)

# Build the classifier on top of the TF Hub BERT preprocessor and encoder.
preprocessor, encoder = download_base_model()
model = create_model(preprocessor, encoder)

# Train, reporting validation metrics after every epoch.
model.fit(train_ds, epochs=epochs, validation_data=valid_ds)

# Report performance on the held-out test split, which training never sees.
test_metrics = model.evaluate(test_ds, return_dict=True)
print(f"Test metrics: {test_metrics}")

# Export as a SavedModel directory, uploaded to GCS in the next cell.
tf.saved_model.save(model, "./text_classifier")

## Upload the model to GCS

In [None]:
!gsutil cp -r ./text_classifier $MODEL_BUCKET_URI