In [None]:
"""The file is used for training automl models for forecasting. 
The file requires a config file (e.g. model_run.json) that has arguments
corresponding to the different member variables of the below classes.
The config file is specified in the variable filepath (refer end of the program)
When the notebook begins running, if training succeeds the cell output will be of the form
"PIPELINE_STATE_RUNNING". You will also be able to verify if the model training through the cloud GUI
in the section Vertex AI -> Training.
The final trained model is stored at Vertex AI -> Models.
For additional docs specifically for forecasting: https://drive.google.com/drive/folders/1Uj0mFA-u9_g1jHYxzy8zqUVDtdQuNxPw?usp=sharing
For docs related to API for the functions used below: https://googleapis.dev/python/aiplatform/latest/aiplatform.html
"""
import os
import shutil
import subprocess
from pathlib import Path
from typing import Dict, Generic, List, Optional, TypeVar

from google.cloud import aiplatform
from pydantic import BaseModel, parse_file_as


class Setup(BaseModel):
    """Project-level settings used to initialise the Vertex AI SDK."""

    project_id: str
    # The unique project ID assigned when the project was created -- not the project "NAME".

    region: Optional[str] = None
    # Optional location passed to aiplatform.init (e.g. "us-central1"); when None the SDK
    # falls back to its own default location. At the time of writing, AutoML forecasting
    # was only available in us-central1 and europe-west4.

    staging_bucket_name: str
    # Any Cloud Storage bucket that Vertex AI can use for staging during training
    # (presumably given as a "gs://..." URI -- confirm against your config).

    def init_ai(self):
        """Initialise the global ``aiplatform`` SDK state from these settings."""
        aiplatform.init(
            project=self.project_id,
            staging_bucket=self.staging_bucket_name,
            location=self.region,
        )

class Dataset(BaseModel):
    """Where the training data lives, and how to turn it into a Vertex AI time-series dataset.

    Preparing a brand-new dataset has several subtleties; see
    https://cloud.google.com/vertex-ai/docs/datasets/prepare-tabular#import-source
    for best practices and data preparation.
    """

    # True when the Vertex AI "DATASET" already exists (Vertex AI first converts
    # a CSV/BigQuery source into a dataset); False to create it from `source`.
    exists: bool

    # exists=True : the dataset ID, e.g.
    #   projects/242781379053/locations/us-central1/datasets/8236001799018905600
    #   or just 8236001799018905600.
    # exists=False: the URL (GCS or BigQuery) where the raw data is stored.
    source: str

    # Only consulted when exists is False: truthy -> `source` is a GCS location,
    # falsy (including None) -> `source` is a BigQuery location.
    gcs: Optional[bool] = True

    # Only used when exists is False: the name shown for the new dataset in the
    # Vertex AI console.
    display_name: Optional[str]

    def init_dataset(self):
        """Return a time-series dataset handle for training.

        Looks up the existing dataset named by ``source`` when ``exists`` is set;
        otherwise creates a new one from ``source``, importing it as GCS data when
        ``gcs`` is truthy and as BigQuery data otherwise.
        """
        if self.exists:
            return aiplatform.TimeSeriesDataset(dataset_name=self.source)

        # Pick the keyword the SDK expects for this kind of import location.
        source_keyword = "gcs_source" if self.gcs else "bq_source"
        return aiplatform.datasets.TimeSeriesDataset.create(
            display_name=self.display_name,
            **{source_keyword: self.source},
        )

class TrainingJob(BaseModel):
    """Settings for an AutoML forecasting training job, plus the code that launches it.

    Field names mirror the keyword arguments of ``aiplatform.AutoMLForecastingTrainingJob``
    and of its ``run`` method, to which they are forwarded by ``create_and_run``.
    """

    display_name: str
    # Required: name of the training job shown in the console.

    time_column: str
    # Timestamp column giving the date/time of each data point.

    target_column: str
    # Column of the dataset to be forecast.

    time_series_identifier_column: str
    # Column that distinguishes the individual time series, e.g. a user ID for BP prediction.

    available_at_forecast_columns: List[str]
    # Columns whose values are known at forecast time (e.g. time_column).

    unavailable_at_forecast_columns: List[str]
    # Columns whose values are not known at forecast time (e.g. the target column).

    data_granularity_unit: str
    # One of: minute, hour, day, week, month, year -- the period between consecutive points.

    data_granularity_count: int = 1
    # Number of data_granularity_unit periods between points; 1 unless set in the JSON.
    # Values other than 1 (allowed: 1, 5, 10, 15 or 30) only apply when
    # data_granularity_unit is "minute", e.g. 5 -> 5 minutes between data points.

    forecast_horizon: int
    # Number of future points to predict.

    context_window: int
    # Number of past points used for each prediction. Recommended range:
    # forecast_horizon to 10 * forecast_horizon.

    time_series_attribute_columns: Optional[List[str]] = None
    # Attributes that stay fixed for a given time series identifier, e.g. a user's age
    # may be treated as fixed and belongs here.

    predefined_split_column: Optional[str] = None
    # Vertex AI default split: https://cloud.google.com/vertex-ai/docs/datasets/prepare-tabular#import-source
    # For a custom split, label every row as TRAIN, VALIDATE or TEST in this column and take
    # care to avoid information leakage between the splits.

    optimization_objective: Optional[str] = None
    # Options: "minimize-rmse", "minimize-mae", "minimize-rmspe", "minimize-wape-mae",
    # "minimize-quantile-loss". Ref: https://googleapis.dev/python/aiplatform/latest/aiplatform.html

    column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None
    # Per-column transformations: numeric, text, categorical, timestamp,
    # e.g. [{"numeric": {"column_name": "sales"}}].

    weight_column: Optional[str] = None
    # Column holding a per-row weight: rows with higher weight matter more during training
    # (similar to weighted linear regression).

    quantiles: Optional[List[float]] = None
    # Only needed when optimization_objective is "minimize-quantile-loss".

    budget_milli_node_hours: Optional[int] = None
    # Training budget in milli node hours (1000 = 1 node hour); a larger budget lets more
    # models be explored. When None it is not forwarded, so the SDK's own default applies.

    model_display_name: Optional[str] = None
    # Optional name of the model produced by training.

    export_evaluated_data_items: bool = False
    # If False, test results are not exported. If True, they are exported to BigQuery.

    export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None
    # Only used when export_evaluated_data_items is True. Format: bq://<project_id>:<dataset_id>:<table>
    # If not set (and exporting is enabled), a new table of the form
    # <project_id>:export_evaluated_examples_<model_name>_<yyyy_MM_dd'T'HH_mm_ss_SSS'Z'>.evaluated_examples
    # is created.

    def create_and_run(self, ds):
        """Create the AutoML forecasting job and run it on ``ds``.

        Args:
            ds: the Vertex AI time-series dataset to train on.

        Returns:
            The model object returned by ``job.run``.
        """
        job = aiplatform.AutoMLForecastingTrainingJob(
            display_name=self.display_name,
            optimization_objective=self.optimization_objective,
            column_transformations=self.column_transformations,
        )

        optional_run_kwargs = {}
        # Passing budget_milli_node_hours=None explicitly would override the SDK's default
        # budget; forward it only when the config actually sets it.
        if self.budget_milli_node_hours is not None:
            optional_run_kwargs["budget_milli_node_hours"] = self.budget_milli_node_hours

        return job.run(
            dataset=ds,
            target_column=self.target_column,
            time_column=self.time_column,
            time_series_identifier_column=self.time_series_identifier_column,
            available_at_forecast_columns=self.available_at_forecast_columns,
            unavailable_at_forecast_columns=self.unavailable_at_forecast_columns,
            time_series_attribute_columns=self.time_series_attribute_columns,
            forecast_horizon=self.forecast_horizon,
            context_window=self.context_window,
            data_granularity_unit=self.data_granularity_unit,
            data_granularity_count=self.data_granularity_count,
            weight_column=self.weight_column,
            model_display_name=self.model_display_name,
            predefined_split_column_name=self.predefined_split_column,
            # Previously declared but never forwarded, so quantile-loss configs were ignored.
            quantiles=self.quantiles,
            export_evaluated_data_items=self.export_evaluated_data_items,
            export_evaluated_data_items_bigquery_destination_uri=self.export_evaluated_data_items_bigquery_destination_uri,
            **optional_run_kwargs,
        )
    
class Main(BaseModel):
    """Top-level config (the three sections of the JSON file) and the training entry point."""

    setup: Setup
    dataset: Dataset
    training_job: TrainingJob

    def run(self):
        """Point gcloud at the project, initialise the SDK, prepare the dataset, then train.

        Returns:
            Whatever ``self.training_job.create_and_run`` returns.
        """
        # Keep the gcloud CLI's default project in sync. An argument list (no shell)
        # means a project_id read from the config file cannot inject shell syntax.
        # A missing or failing gcloud is deliberately not fatal: the SDK itself is
        # configured explicitly by init_ai below.
        gcloud = shutil.which("gcloud")
        if gcloud is None:
            print("gcloud CLI not found; skipping 'gcloud config set project'.")
        else:
            subprocess.run(
                [gcloud, "config", "set", "project", self.setup.project_id],
                check=False,
            )
        self.setup.init_ai()
        return self.training_job.create_and_run(self.dataset.init_dataset())

filepath = "training_automl_config.json" # Filepath to config file for training
main_func = parse_file_as(Main, filepath)
main_func.run()