In [ ]:
!pip install mlflow==2.13.2 sagemaker-mlflow==0.1.0 cloudpickle==2.2.1

In [ ]:
from functions import *

import os
import boto3
import pandas as pd
import io
import mlflow
from mlflow.tracking import MlflowClient
from itertools import combinations
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
from sagemaker.sklearn.model import SKLearnModel
from sagemaker import get_execution_role
from datetime import datetime

# Load the project configuration that the cells below index by key
# (e.g. 'mlflow_arn', 'bucket_name', 'project_path_s3').
# NOTE(review): read_settings is not imported explicitly; presumably it is
# supplied by the `from functions import *` star-import — confirm.
settings = read_settings()

In [ ]:
# Point MLflow at the tracking server and experiment named in the settings.
# Keep this order: the tracking URI is configured before the experiment is
# selected and before the client is created.
mlflow.set_tracking_uri(settings['mlflow_arn'])
mlflow.set_experiment(settings['mlflow_experiment_name'])
client = MlflowClient()
# Execution role that is handed to SKLearnModel further down.
role = get_execution_role()

In [ ]:
# Look up the registered model and locate the packaged artifact of its newest
# version.
registered_model = client.get_registered_model(name=settings['mlflow_model_name'])
candidate_versions = registered_model.latest_versions
if not candidate_versions:
    raise ValueError(
        f"Registered model {settings['mlflow_model_name']!r} has no versions"
    )
# latest_versions holds the latest version *per stage*, so position 0 is not
# guaranteed to be the newest one; select the highest version number instead.
latest_version = max(candidate_versions, key=lambda mv: int(mv.version))
run_id = latest_version.run_id
source_path = latest_version.source
# NOTE(review): assumes a model.tar.gz archive sits directly under the
# version's source location — confirm against the training/registration step.
model_path = os.path.join(source_path, 'model.tar.gz')

In [ ]:
# Create the local folder that receives inference.py (used as source_dir for
# the SKLearnModel); exist_ok=True makes re-running this cell harmless.
os.makedirs("inference_script", exist_ok=True)

In [ ]:
%%writefile inference_script/inference.py

import os
import pickle
import json
import pandas as pd
import io
from sagemaker_containers.beta.framework import worker

def model_fn(model_dir):
    """Load the pickled model stored as ``model.pkl`` inside ``model_dir``.

    Args:
        model_dir: Directory that contains the ``model.pkl`` artifact.

    Returns:
        The object obtained by unpickling ``model.pkl``.
    """
    # NOTE(review): a later notebook cell writes inference_script/inference.py
    # again with a joblib-based loader (model.joblib); confirm which artifact
    # format the model archive actually contains.
    pickle_file = os.path.join(model_dir, "model.pkl")
    with open(pickle_file, "rb") as handle:
        return pickle.load(handle)

def input_fn(input_data, content_type):
    """Parse a CSV request payload into a pandas DataFrame.

    Args:
        input_data: Raw request body, either ``str`` or UTF-8 encoded
            ``bytes``/``bytearray``.
        content_type: MIME type of the payload; only ``text/csv`` is accepted.

    Returns:
        The DataFrame produced by ``pandas.read_csv`` on the payload.

    Raises:
        ValueError: If ``content_type`` is anything other than ``text/csv``.
    """
    if content_type != 'text/csv':
        raise ValueError(f"Unsupported content type: {content_type}")
    # io.StringIO only accepts str, so a byte payload would otherwise fail
    # with a TypeError; decode it first.
    if isinstance(input_data, (bytes, bytearray)):
        input_data = input_data.decode('utf-8')
    # read_csv with default arguments treats the first line as the header row.
    # NOTE(review): confirm every payload sent to the endpoint/transform job
    # starts with a header line, otherwise its first record is consumed.
    return pd.read_csv(io.StringIO(input_data))

def predict_fn(input_data, model):
    """Run ``model.predict_proba`` on the parsed request data.

    Args:
        input_data: Features as returned by ``input_fn``.
        model: Object returned by ``model_fn``; must expose ``predict_proba``.

    Returns:
        Whatever ``model.predict_proba(input_data)`` returns.
    """
    return model.predict_proba(input_data)

def output_fn(prediction, content_type):
    """Serialise predictions into a ``worker.Response`` of the requested type.

    Args:
        prediction: Prediction array; ``.tolist()`` is called on it for JSON
            and it is wrapped in a DataFrame for CSV.
        content_type: Requested response MIME type, ``application/json`` or
            ``text/csv``.

    Returns:
        A ``worker.Response`` whose body is the serialised prediction and
        whose mimetype is ``content_type``.

    Raises:
        ValueError: If ``content_type`` is not one of the two supported types.
    """
    if content_type == 'application/json':
        payload = json.dumps(prediction.tolist())
    elif content_type == 'text/csv':
        payload = pd.DataFrame(prediction).to_csv(index=False)
    else:
        raise ValueError(f"Unsupported content type: {content_type}")
    return worker.Response(payload, mimetype=content_type)

In [ ]:
%%writefile inference_script/inference.py

import os
import pickle
import json
import pandas as pd
import io
from sagemaker_containers.beta.framework import worker
import joblib

def model_fn(model_dir):
    """Load the joblib-serialised model stored as ``model.joblib``.

    Args:
        model_dir: Directory that contains the ``model.joblib`` artifact.

    Returns:
        The object returned by ``joblib.load`` for that file.
    """
    return joblib.load(os.path.join(model_dir, 'model.joblib'))

def input_fn(input_data, content_type):
    """Parse a CSV request payload into a pandas DataFrame.

    Args:
        input_data: Raw request body, either ``str`` or UTF-8 encoded
            ``bytes``/``bytearray``.
        content_type: MIME type of the payload; only ``text/csv`` is accepted.

    Returns:
        The DataFrame produced by ``pandas.read_csv`` on the payload.

    Raises:
        ValueError: If ``content_type`` is anything other than ``text/csv``.
    """
    if content_type != 'text/csv':
        raise ValueError(f"Unsupported content type: {content_type}")
    # io.StringIO only accepts str, so a byte payload would otherwise fail
    # with a TypeError; decode it first.
    if isinstance(input_data, (bytes, bytearray)):
        input_data = input_data.decode('utf-8')
    # read_csv with default arguments treats the first line as the header row.
    # NOTE(review): confirm every payload sent to the endpoint/transform job
    # starts with a header line, otherwise its first record is consumed.
    return pd.read_csv(io.StringIO(input_data))

def predict_fn(input_data, model):
    """Run ``model.predict_proba`` on a DataFrame of features.

    Args:
        input_data: Features as returned by ``input_fn``; must be a
            ``pandas.DataFrame``.
        model: Object returned by ``model_fn``; must expose ``predict_proba``.

    Returns:
        Whatever ``model.predict_proba(input_data)`` returns.

    Raises:
        ValueError: If ``input_data`` is not a ``pandas.DataFrame``.
    """
    # Reject anything that did not come through the DataFrame-producing parser.
    if not isinstance(input_data, pd.DataFrame):
        raise ValueError("Input data should be a pandas DataFrame")
    return model.predict_proba(input_data)

def output_fn(prediction, accept='application/json'):
    """Serialise predictions as a JSON or CSV string.

    Args:
        prediction: Prediction array; ``.tolist()`` is called on it for JSON
            and it is wrapped in a DataFrame for CSV.
        accept: Requested response MIME type, ``application/json`` (default)
            or ``text/csv``.

    Returns:
        A JSON string of the nested prediction list, or CSV text written
        without the index column.

    Raises:
        ValueError: If ``accept`` is not one of the two supported types.
    """
    if accept == 'text/csv':
        return pd.DataFrame(prediction).to_csv(index=False)
    if accept == 'application/json':
        return json.dumps(prediction.tolist())
    raise ValueError(f"Unsupported accept type: {accept}")

In [ ]:
# Describe the SageMaker scikit-learn model: the packaged artifact, the
# execution role, and the inference script written to inference_script/.
model_config = {
    'model_data': model_path,
    'role': role,
    'entry_point': 'inference.py',
    'source_dir': 'inference_script',
    'framework_version': '1.2-1',
}
sklearn_model = SKLearnModel(**model_config)

In [ ]:
# S3 prefix that receives the batch-transform predictions.
output_segments = (
    "s3://",
    settings['bucket_name'],
    settings['project_path_s3'],
    "output",
    "job_pred",
)
output_path = os.path.join(*output_segments)

# Single ml.m5.large instance; results are assembled line by line.
transformer = sklearn_model.transformer(
    output_path=output_path,
    instance_type='ml.m5.large',
    instance_count=1,
    assemble_with='Line',
)

In [ ]:
# S3 prefix holding the CSV data to score.
input_segments = (
    "s3://",
    settings['bucket_name'],
    settings['project_path_s3'],
    "data",
    "inference_train_job",
)
input_path = os.path.join(*input_segments)

# Append a second-resolution timestamp to the configured job name.
timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
transform_job_name_with_timestamp = "{}-{}".format(
    settings['transform_job_name'], timestamp
)

# Start the batch transform: CSV input, split into records per line.
transformer.transform(
    data=input_path,
    job_name=transform_job_name_with_timestamp,
    content_type='text/csv',
    split_type='Line',
)

In [ ]:
# Object key (no bucket, no s3:// scheme) of the prediction file under the
# same output/job_pred prefix used as the transformer's output_path.
# NOTE(review): presumably the job input is a single file named inference.csv,
# which Batch Transform would write back as inference.csv.out — confirm.
inference_job_s3_key = os.path.join(settings['project_path_s3'], "output", "job_pred", "inference.csv.out")