# AgriAutoML Model Training

This notebook creates Vertex AI datasets from files in Cloud Storage and trains AutoML image-classification and tabular-regression models on them.

In [None]:
import os
from google.cloud import aiplatform
from dotenv import load_dotenv

# Pull settings from a local .env file into the process environment so the
# os.getenv lookups below can see them.
load_dotenv()

# Default project and region for every later Vertex AI SDK call in this notebook.
# os.getenv returns None for an unset variable, and that None is passed through
# to aiplatform.init unchanged.
vertex_settings = {
    "project": os.getenv("GCP_PROJECT_ID"),
    "location": os.getenv("GCP_REGION"),
}
aiplatform.init(**vertex_settings)

## Dataset Creation

In [None]:
def create_vision_dataset(bucket_name: str, image_uri: str):
    """Import a single-label image-classification dataset into Vertex AI.

    Args:
        bucket_name: Cloud Storage bucket name, without the ``gs://`` scheme
            (the scheme is prepended here).
        image_uri: Object path inside the bucket of the import source —
            presumably an import file listing image URIs and labels; confirm
            against the bucket contents.

    Returns:
        The dataset object produced by ``aiplatform.ImageDataset.create``.
    """
    source_path = f"gs://{bucket_name}/{image_uri}"
    # One label per image, matching the schema name used here.
    label_schema = aiplatform.schema.dataset.ioformat.image.single_label_classification
    return aiplatform.ImageDataset.create(
        display_name="cropnet_vision_dataset",
        gcs_source=source_path,
        import_schema_uri=label_schema,
    )

def create_tabular_dataset(bucket_name: str, table_uri: str):
    """Register a tabular dataset in Vertex AI from an object in Cloud Storage.

    Args:
        bucket_name: Cloud Storage bucket name, without the ``gs://`` scheme
            (the scheme is prepended here).
        table_uri: Object path of the source table inside the bucket —
            presumably a CSV file; confirm against the bucket contents.

    Returns:
        The dataset object produced by ``aiplatform.TabularDataset.create``.
    """
    table_path = f"gs://{bucket_name}/{table_uri}"
    return aiplatform.TabularDataset.create(
        display_name="cropnet_tabular_dataset",
        gcs_source=table_path,
    )

## Model Training

In [None]:
def train_vision_model(vision_dataset, budget_hours=1.0):
    """Train an AutoML single-label image-classification model.

    Args:
        vision_dataset: Vertex AI image dataset to train on, e.g. the one
            returned by ``create_vision_dataset``.
        budget_hours: Training budget in node hours. It is converted to the
            integer milli-node-hours value the SDK expects (1.0 -> 1000).

    Returns:
        The model returned by ``AutoMLImageTrainingJob.run``.
    """
    job = aiplatform.AutoMLImageTrainingJob(
        display_name="cropnet_vision_training",
        # The SDK's task names are "classification" / "object_detection";
        # "image_classification" is rejected with a ValueError.
        prediction_type="classification",
    )

    # NOTE(review): the SDK docs give 8,000-800,000 milli node hours as the
    # valid range for cloud image classification, so the 1.0-hour default is
    # likely rejected by the service. Pass budget_hours >= 8 unless confirmed otherwise.
    return job.run(
        dataset=vision_dataset,
        # Image classification takes its labels from the dataset itself.
        # AutoMLImageTrainingJob.run has no target_column argument, and
        # passing one raised TypeError.
        budget_milli_node_hours=round(budget_hours * 1000),
        model_display_name="cropnet_vision_model",
        # 80/10/10 train / validation / test split by fraction.
        training_fraction_split=0.8,
        validation_fraction_split=0.1,
        test_fraction_split=0.1,
    )

def train_tabular_model(tabular_dataset, budget_hours=1.0, *, target_column="yield"):
    """Train an AutoML tabular regression model optimized for RMSE.

    Args:
        tabular_dataset: Vertex AI tabular dataset to train on, e.g. the one
            returned by ``create_tabular_dataset``.
        budget_hours: Training budget in node hours. It is converted to the
            integer milli-node-hours value the SDK expects (1.0 -> 1000).
        target_column: Name of the column to predict. Defaults to ``"yield"``,
            which was previously hard-coded.

    Returns:
        The model returned by ``AutoMLTabularTrainingJob.run``.
    """
    job = aiplatform.AutoMLTabularTrainingJob(
        display_name="cropnet_tabular_training",
        optimization_prediction_type="regression",
    )

    return job.run(
        dataset=tabular_dataset,
        target_column=target_column,
        optimization_objective="minimize-rmse",
        # The SDK declares an int. round() rather than int() stops float
        # error (e.g. 999.9999...) from truncating the budget downward.
        budget_milli_node_hours=round(budget_hours * 1000),
        model_display_name="cropnet_tabular_model",
        # 80/10/10 train / validation / test split by fraction.
        training_fraction_split=0.8,
        validation_fraction_split=0.1,
        test_fraction_split=0.1,
    )