In [None]:
!pip install -U sagemaker

In [None]:
import os
import time
import boto3
import numpy as np
import pandas as pd
import sagemaker
from sagemaker import get_execution_role
from sagemaker.workflow.pipeline_context import PipelineSession

In [None]:
# AWS / SageMaker handles shared by the rest of the notebook.
sess = boto3.Session()
sm = sess.client("sagemaker")  # low-level SageMaker API client
role = get_execution_role()  # IAM role passed to the training estimator
sagemaker_session = sagemaker.Session(boto_session=sess)
bucket = sagemaker_session.default_bucket()
# May be None when no default S3 prefix is configured for this session.
default_bucket_prefix = sagemaker_session.default_bucket_prefix
# Read the region from the same boto3 session so it always matches the clients
# above (a freshly constructed Session could resolve a different region).
region = sess.region_name


In [None]:
# Show the S3 bucket and (possibly None) default prefix used for artifacts.
(bucket, default_bucket_prefix)

In [None]:

!pip install -U sagemaker xgboost


In [None]:

from sagemaker.xgboost.estimator import XGBoost

# Directory holding the training script. The %%writefile cell writes
# code/train.py, and %%writefile does not create missing parent directories,
# so create it up front to keep Restart & Run All working.
SOURCE_DIR = "code"
os.makedirs(SOURCE_DIR, exist_ok=True)

# S3 key prefix for training outputs. default_bucket_prefix may be None (no
# prefix configured), in which case the project folder sits at the bucket root.
PROJECT_PREFIX = "xgboost-pipeline"
prefix = (
    f"{default_bucket_prefix}/{PROJECT_PREFIX}"
    if default_bucket_prefix
    else PROJECT_PREFIX
)

# Script-mode XGBoost estimator; entry_point is resolved relative to source_dir.
xgb_estimator = XGBoost(
    entry_point="train.py",
    source_dir=SOURCE_DIR,
    role=role,
    instance_count=1,
    instance_type="ml.m5.large",
    framework_version="1.7-1",
    py_version="py3",
    base_job_name="xgboost-pipeline-model",
    output_path=f"s3://{bucket}/{prefix}/training-jobs",
)


In [None]:

%%writefile code/train.py

import argparse
import os
import pandas as pd
import xgboost as xgb
import joblib

# Column in train.csv / test.csv that holds the regression target.
label_column = "actual_productivity"


def parse_args():
    """Parse command-line paths for the training job.

    Each of ``--train``, ``--test`` and ``--sm-model-dir`` falls back to the
    matching SageMaker environment variable (SM_CHANNEL_TRAIN,
    SM_CHANNEL_TEST, SM_MODEL_DIR), read at call time; unset ones are None.

    Returns:
        argparse.Namespace with attributes ``train``, ``test``, ``sm_model_dir``.
    """
    cli = argparse.ArgumentParser()
    path_options = (
        ("--train", "SM_CHANNEL_TRAIN"),
        ("--test", "SM_CHANNEL_TEST"),
        ("--sm-model-dir", "SM_MODEL_DIR"),
    )
    for flag, env_var in path_options:
        cli.add_argument(flag, type=str, default=os.environ.get(env_var))
    return cli.parse_args()

if __name__ == "__main__":
    args = parse_args()

    # Fail with a clear message instead of a TypeError from os.path.join(None, ...)
    # when a path is neither passed on the command line nor set in the environment.
    missing = [
        name for name in ("train", "test", "sm_model_dir") if getattr(args, name) is None
    ]
    if missing:
        raise SystemExit(
            f"Missing required paths: {missing}; pass them as arguments "
            "or run inside a SageMaker training job."
        )

    train_df = pd.read_csv(os.path.join(args.train, "train.csv"))
    test_df = pd.read_csv(os.path.join(args.test, "test.csv"))

    # Split features from the regression target.
    X_train = train_df.drop(columns=[label_column])
    y_train = train_df[label_column]
    X_test = test_df.drop(columns=[label_column])
    y_test = test_df[label_column]

    model = xgb.XGBRegressor(objective="reg:squarederror", n_estimators=100, max_depth=4)
    model.fit(X_train, y_train)

    # Score the held-out test channel so the job log reports generalization error.
    predictions = model.predict(X_test)
    test_rmse = float(((y_test - predictions) ** 2).mean() ** 0.5)
    print(f"test_rmse: {test_rmse:.6f}")

    # Harmless inside SageMaker (dir already exists); needed for local runs.
    os.makedirs(args.sm_model_dir, exist_ok=True)
    joblib.dump(model, os.path.join(args.sm_model_dir, "xgboost-model.joblib"))
