In [None]:
# Notebook-wide Python packages.
import pandas as pd
import streamlit as st

# Snowpark is available for the analyses too: attach to the session that the
# notebook runtime has already opened instead of building a new connection.
from snowflake.snowpark.context import get_active_session

session = get_active_session()


In [None]:
# Snowpark for Python
import snowflake.snowpark.functions as F
from snowflake.snowpark.types import DecimalType
# Hugging Face libraries read their cache location from environment variables,
# not from a Python global, so the directory must be exported. This has to run
# before transformers / huggingface_hub are first imported (sentence_transformers
# is imported further down in this cell).
import os
TRANSFORMERS_CACHE = '/tmp'
os.environ["TRANSFORMERS_CACHE"] = TRANSFORMERS_CACHE  # read by transformers
os.environ["HF_HOME"] = TRANSFORMERS_CACHE  # read by huggingface_hub (model downloads)
# Snowpark ML
import snowflake.ml.modeling.preprocessing as snowml
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.metrics.correlation import correlation
#from snowflake.ml import classifier as clf

# Data Science Libs
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from snowflake.ml.model.model_signature import FeatureSpec, DataType, ModelSignature
import streamlit as st
from sentence_transformers import SentenceTransformer


In [None]:
from snowflake.ml.registry import Registry

# Where the sentence-transformer model is registered; point these at your own
# database/schema instead of editing the Registry call.
REGISTRY_DATABASE = "BMUTHUKRISHNAN_DB"
REGISTRY_SCHEMA = "REGISTRY"

reg = Registry(session=session, database_name=REGISTRY_DATABASE, schema_name=REGISTRY_SCHEMA)

In [None]:

# Small sample (10 rows, AGE_BIN only) of the forecast training table.
df = session.table("QUICKSTART.ML_FUNCTIONS.FORECAST_TRAINING_v1").select("AGE_BIN").limit(10).to_pandas()

st.write("Available models in the registry:")
st.write(reg.models())

st.write("Getting model from HuggingFace")
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

# all-mpnet-base-v2 produces 768-dimensional embeddings, so a hard-coded 384
# would describe the wrong output shape. Read the width from the loaded model.
embedding_dim = model.get_sentence_embedding_dimension()
if embedding_dim is None:
    raise ValueError("Could not determine the embedding dimension of the SentenceTransformer model")

# Signature: one STRING column `TEXT` in, one fixed-length `EMBEDDING` vector out.
model_sig = ModelSignature(
    inputs=[FeatureSpec(dtype=DataType.STRING, name='TEXT')],
    outputs=[FeatureSpec(dtype=DataType.DOUBLE, name='EMBEDDING', shape=(embedding_dim,))]
)

In [None]:
# Register the SentenceTransformer held in `model`. The registry needs the model
# object itself, not a placeholder string. The explicit `model_sig`
# (TEXT -> EMBEDDING) replaces the AGE_BIN list previously passed as sample
# input, which did not match the signature's `TEXT` input column.
# NOTE(review): model_name, comment, metrics and conda_dependencies look like
# placeholders copied from a classifier example (an embedding model has no
# "score" of 96 and does not obviously need scikit-learn). Confirm before
# relying on this registry entry.
mv = reg.log_model(model=model,
                   model_name="bank_classifier",
                   version_name="v1",
                   conda_dependencies=["scikit-learn"],
                   comment="My awesome ML model",
                   metrics={"score": 96},
                   signatures={"encode": model_sig},
)

In [None]:
from pycaret.classification import ClassificationExperiment
from pycaret.classification import predict_model, load_model
from pycaret.datasets import get_data

from snowflake.ml.model import custom_model
from snowflake.ml.model import model_signature
from snowflake.ml.registry import Registry

import os
import shutil
import pandas as pd

# PyCaret's sample "juice" dataset; `Purchase` is the classification target.
data = get_data('juice')

# Let PyCaret train and rank candidate classifiers (seeded via session_id).
cl_exp = ClassificationExperiment()
cl_exp.setup(data, target='Purchase', session_id=123)
best_model = cl_exp.compare_models()

# Persist the winner; the move below picks up the resulting
# 'juice_best_model.pkl' from the working directory.
cl_exp.save_model(best_model, "juice_best_model")

# Stage the pickle in a scratch directory so it can be shipped as a
# model-registry artifact.
ARTIFACTS_DIR = "/tmp/pycaret/"
model_artifact_dir = os.path.join(ARTIFACTS_DIR, "model")
os.makedirs(model_artifact_dir, exist_ok=True)
shutil.move('juice_best_model.pkl', os.path.join(model_artifact_dir, 'juice_best_model.pkl'))

class PyCaretModel(custom_model.CustomModel):
    """Snowflake ML custom model that serves a pickled PyCaret pipeline.

    The pipeline is read from the ``model_file`` artifact of the supplied
    ``ModelContext``: a path to the ``.pkl`` file written by PyCaret's
    ``save_model``.
    """

    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        # PyCaret's load_model takes the name without the ".pkl" suffix (the
        # same form save_model was given), so strip the extension robustly
        # instead of slicing off a fixed number of characters.
        model_path_stem, _ = os.path.splitext(self.context.path("model_file"))
        self.model = load_model(model_path_stem, verbose=False)
        # Default is /var/, which is not accessible inside a warehouse.
        self.model.memory = '/tmp/'

    @custom_model.inference_api
    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        """Score ``X`` with the PyCaret pipeline.

        Args:
            X: Feature rows in the column layout the pipeline was trained on.

        Returns:
            DataFrame with only the ``prediction_label`` and
            ``prediction_score`` columns taken from ``predict_model``'s output.
        """
        model_output = predict_model(self.model, data=X)
        return pd.DataFrame({
            "prediction_label": model_output['prediction_label'],
            "prediction_score": model_output['prediction_score'],
        })

pycaret_mc = custom_model.ModelContext(
    # Reserved for model types the Model Registry can serialize natively;
    # the PyCaret pipeline is not one of them, so this stays empty.
    models={},
    # Everything the registry cannot handle itself is shipped as a file artifact.
    artifacts={
        'model_file': os.path.join(ARTIFACTS_DIR, "model", 'juice_best_model.pkl'),
    },
)

# Instantiate the wrapper locally; this loads the staged pickle via the context.
model = PyCaretModel(pycaret_mc)

model_name = "pycaret_juice"
version_name = "v1"

# NOTE(review): unlike the first Registry in this notebook, no database/schema
# is passed here, so this presumably uses the session's current ones. Confirm
# that this is intended.
reg = Registry(session)

# Infer the `predict` signature from real examples. Input is the training frame
# without its `Purchase` target (the features the pipeline scores). Output
# mirrors the two columns PyCaretModel.predict returns. An empty input list
# cannot yield a usable input signature.
input_data_pd = data.drop(columns=['Purchase'])
output_data_pd = pd.DataFrame([['CH', 0.876], ['CH', 0.876], ['CH', 0.876]], columns=['prediction_label', 'prediction_score'])
predict_sign = model_signature.infer_signature(input_data=input_data_pd, output_data=output_data_pd)

mv = reg.log_model(
    model,
    model_name=model_name,
    version_name=version_name,
    conda_dependencies=["pycaret"],
    signatures={
        "predict": predict_sign
    },
)