# We need to override the predict method for this to be usable
A country ID such as 64 is not going to help much on its own; we want the country name. Thankfully, after doing some research, it appears that the prediction inputs do not need to be scaled by hand, because the pipeline's StandardScaler takes care of that for us.

In [74]:
import os

import lakefs_client
import mlflow
import numpy as np
import pandas as pd
import sklearn
from lakefs_client.client import LakeFSClient
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

In [75]:
# lakeFS credentials and endpoint. They are read from the environment so secrets never
# get saved into the notebook; unset variables fall back to '' (the old placeholders).
configuration = lakefs_client.Configuration()
configuration.username = os.environ.get("LAKEFS_ACCESS_KEY_ID", "")
configuration.password = os.environ.get("LAKEFS_SECRET_ACCESS_KEY", "")
# Private IP of the lakeFS EC2 instance. Even when both EC2 instances are in the same
# security group, the SG still needs an inbound rule (a "hole") for the lakeFS port.
# Because these are private IPs, the traffic presumably stays inside the VPC rather
# than going over the public internet.
configuration.host = os.environ.get("LAKEFS_HOST", "")
client = LakeFSClient(configuration)

In [76]:
# Pull data.csv from the `main` branch of the `countries` lakeFS repository
file = client.objects.get_object(
    'countries',  # repository
    'main',       # branch
    'data.csv',   # object path
)
df = pd.read_csv(file)

In [77]:
# Inputs are (Longitude, Latitude) pairs; the label is the country id (CID), kept as an
# (n, 1) column here and flattened when splitting.
features = df.loc[:, ["Longitude", "Latitude"]].to_numpy()
target = df.loc[:, ["CID"]].to_numpy()

In [78]:
# Hold out 20% of the rows for evaluation; ravel() turns the (n, 1) CID column into
# the 1-D label vector sklearn expects.
X_train, X_test, y_train, y_test = train_test_split(
    features,
    target.ravel(),
    test_size=0.2,
    random_state=42,
)

In [79]:
# Standardize lon/lat, then fit a linear SVM over the country ids.
# NOTE(review): test accuracy is ~0.07 and most predictions are CID 252 ('Ocean');
# a linear decision function over raw lon/lat likely can't carve out country-shaped
# regions. A non-linear model may do much better if accuracy matters.
clf = make_pipeline(
    StandardScaler(),
    LinearSVC(random_state=0, tol=1e-4, dual="auto"),
)

In [80]:
# Fit scaler + classifier together; the cell displays the fitted pipeline
clf.fit(X=X_train, y=y_train)

In [81]:
# Summary of the pipeline's hyper-parameters (mirrors the make_pipeline call above)
params = {
    "pipeline": True,
    "scaler": "standard",
    "dual": "auto",
    "random_state": 0,
    "tol": 1e-4,
}

In [82]:
# Predicted CIDs for the held-out rows
y_pred = clf.predict(X=X_test)

In [83]:
# Fraction of test rows whose predicted CID equals the true CID
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
accuracy

0.07256704980842912

In [84]:
predictions = clf.predict(X_test)
feature_names = ['lon', 'lat']

# Side-by-side view of each validation point with its true and predicted CID
result = pd.DataFrame(X_test, columns=feature_names).assign(
    actual_class=y_test,
    predicted_class=predictions,
)
result.head(10)

Unnamed: 0,lon,lat,actual_class,predicted_class
0,177.754059,-17.431265,77,234
1,159.882172,-9.700796,208,158
2,-78.123383,24.624023,17,252
3,27.555864,61.573971,78,78
4,126.664238,38.245316,165,114
5,-73.44484,-46.934589,47,252
6,-8.826954,7.314955,127,252
7,-62.989243,17.494589,189,252
8,69.982071,30.833944,170,252
9,-13.343016,10.354385,97,252


I experimented with the LinearSVC parameters a bit, but accuracy stays stuck at around 0.07. I don't think this can be improved much, but I don't really know what I am doing. Let's see if we can override the predict method instead.
https://mlflow.org/docs/latest/traditional-ml/creating-custom-pyfunc/notebooks/override-predict.html

In [130]:
from joblib import dump

from mlflow.pyfunc import PythonModel
from mlflow.models import infer_signature

In [131]:
# Serialize the fitted pipeline so the pyfunc wrapper can load it as an artifact.
# NB: despite the name, this is a file path, not a directory.
model_directory = "/tmp/sklearn_model.joblib"
dump(clf, filename=model_directory)

['/tmp/sklearn_model.joblib']

In [132]:
# Grab countries.csv (the CID lookup table) from the main branch of the countries
# repository in lakeFS.
# NOTE: this rebinds `df`, replacing the training data loaded from data.csv earlier.
file = client.objects.get_object('countries', 'main', 'countries.csv')
df = pd.read_csv(file)

In [133]:
# Python dict view of countries.csv: CID -> value of the first non-CID column
# (presumably the country name).
countries_dict = df.set_index('CID').iloc[:, 0].to_dict()

In [134]:
class ModelWrapper(PythonModel):
    """MLflow pyfunc wrapper that returns country names instead of numeric country IDs.

    The wrapped scikit-learn pipeline predicts a CID for each input row; ``predict``
    translates each CID into a human-readable name through a CID -> name mapping.
    """

    # Default CID -> country name lookup. Presumably a hand copy of the mapping
    # built from countries.csv -- NOTE(review): confirm it still matches that file.
    COUNTRIES = {
        1: 'Afghanistan',
        2: 'Albania',
        3: 'Algeria',
        4: 'American Samoa',
        5: 'Andorra',
        6: 'Angola',
        7: 'Anguilla',
        8: 'Antarctica',
        9: 'Antigua and Barbuda',
        10: 'Argentina',
        11: 'Armenia',
        12: 'Aruba',
        13: 'Australia',
        14: 'Austria',
        15: 'Azerbaijan',
        16: 'Azores',
        17: 'Bahamas',
        18: 'Bahrain',
        19: 'Bangladesh',
        20: 'Barbados',
        21: 'Belarus',
        22: 'Belgium',
        23: 'Belize',
        24: 'Benin',
        25: 'Bermuda',
        26: 'Bhutan',
        27: 'Bolivia',
        28: 'Bonaire',
        29: 'Bosnia and Herzegovina',
        30: 'Botswana',
        31: 'Bouvet Island',
        32: 'Brazil',
        33: 'British Indian Ocean Territory',
        34: 'British Virgin Islands',
        35: 'Brunei Darussalam',
        36: 'Bulgaria',
        37: 'Burkina Faso',
        38: 'Burundi',
        39: 'Cabo Verde',
        40: 'Cambodia',
        41: 'Cameroon',
        42: 'Canada',
        43: 'Canarias',
        44: 'Cayman Islands',
        45: 'Central African Republic',
        46: 'Chad',
        47: 'Chile',
        48: 'China',
        49: 'Christmas Island',
        50: 'Cocos Islands',
        51: 'Colombia',
        52: 'Comoros',
        53: 'Congo',
        54: 'Congo DRC',
        55: 'Cook Islands',
        56: 'Costa Rica',
        57: "Côte d'Ivoire",
        58: 'Croatia',
        59: 'Cuba',
        60: 'Curacao',
        61: 'Cyprus',
        62: 'Czech Republic',
        63: 'Denmark',
        64: 'Djibouti',
        65: 'Dominica',
        66: 'Dominican Republic',
        67: 'Ecuador',
        68: 'Egypt',
        69: 'El Salvador',
        70: 'Equatorial Guinea',
        71: 'Eritrea',
        72: 'Estonia',
        73: 'Eswatini',
        74: 'Ethiopia',
        75: 'Falkland Islands',
        76: 'Faroe Islands',
        77: 'Fiji',
        78: 'Finland',
        79: 'France',
        80: 'French Guiana',
        81: 'French Polynesia',
        82: 'French Southern Territories',
        83: 'Gabon',
        84: 'Gambia',
        85: 'Georgia',
        86: 'Germany',
        87: 'Ghana',
        88: 'Gibraltar',
        89: 'Glorioso Islands',
        90: 'Greece',
        91: 'Greenland',
        92: 'Grenada',
        93: 'Guadeloupe',
        94: 'Guam',
        95: 'Guatemala',
        96: 'Guernsey',
        97: 'Guinea',
        98: 'Guinea-Bissau',
        99: 'Guyana',
        100: 'Haiti',
        101: 'Heard Island and McDonald Islands',
        102: 'Honduras',
        103: 'Hungary',
        104: 'Iceland',
        105: 'India',
        106: 'Indonesia',
        107: 'Iran',
        108: 'Iraq',
        109: 'Ireland',
        110: 'Isle of Man',
        111: 'Israel',
        112: 'Italy',
        113: 'Jamaica',
        114: 'Japan',
        115: 'Jersey',
        116: 'Jordan',
        117: 'Juan De Nova Island',
        118: 'Kazakhstan',
        119: 'Kenya',
        120: 'Kiribati',
        121: 'Kuwait',
        122: 'Kyrgyzstan',
        123: 'Laos',
        124: 'Latvia',
        125: 'Lebanon',
        126: 'Lesotho',
        127: 'Liberia',
        128: 'Libya',
        129: 'Liechtenstein',
        130: 'Lithuania',
        131: 'Luxembourg',
        132: 'Madagascar',
        133: 'Madeira',
        134: 'Malawi',
        135: 'Malaysia',
        136: 'Maldives',
        137: 'Mali',
        138: 'Malta',
        139: 'Marshall Islands',
        140: 'Martinique',
        141: 'Mauritania',
        142: 'Mauritius',
        143: 'Mayotte',
        144: 'Mexico',
        145: 'Micronesia',
        146: 'Moldova',
        147: 'Monaco',
        148: 'Mongolia',
        149: 'Montenegro',
        150: 'Montserrat',
        151: 'Morocco',
        152: 'Mozambique',
        153: 'Myanmar',
        154: 'Namibia',
        155: 'Nauru',
        156: 'Nepal',
        157: 'Netherlands',
        158: 'New Caledonia',
        159: 'New Zealand',
        160: 'Nicaragua',
        161: 'Niger',
        162: 'Nigeria',
        163: 'Niue',
        164: 'Norfolk Island',
        165: 'North Korea',
        166: 'North Macedonia',
        167: 'Northern Mariana Islands',
        168: 'Norway',
        169: 'Oman',
        170: 'Pakistan',
        171: 'Palau',
        172: 'Palestinian Territory',
        173: 'Panama',
        174: 'Papua New Guinea',
        175: 'Paraguay',
        176: 'Peru',
        177: 'Philippines',
        178: 'Pitcairn',
        179: 'Poland',
        180: 'Portugal',
        181: 'Puerto Rico',
        182: 'Qatar',
        183: 'Réunion',
        184: 'Romania',
        185: 'Russian Federation',
        186: 'Rwanda',
        187: 'Saba',
        188: 'Saint Barthelemy',
        189: 'Saint Eustatius',
        190: 'Saint Helena',
        191: 'Saint Kitts and Nevis',
        192: 'Saint Lucia',
        193: 'Saint Martin',
        194: 'Saint Pierre and Miquelon',
        195: 'Saint Vincent and the Grenadines',
        196: 'Samoa',
        197: 'San Marino',
        198: 'Sao Tome and Principe',
        199: 'Saudi Arabia',
        200: 'Senegal',
        201: 'Serbia',
        202: 'Seychelles',
        203: 'Sierra Leone',
        204: 'Singapore',
        205: 'Sint Maarten',
        206: 'Slovakia',
        207: 'Slovenia',
        208: 'Solomon Islands',
        209: 'Somalia',
        210: 'South Africa',
        211: 'South Georgia and South Sandwich Islands',
        212: 'South Korea',
        213: 'South Sudan',
        214: 'Spain',
        215: 'Sri Lanka',
        216: 'Sudan',
        217: 'Suriname',
        218: 'Svalbard',
        219: 'Sweden',
        220: 'Switzerland',
        221: 'Syria',
        222: 'Tajikistan',
        223: 'Tanzania',
        224: 'Thailand',
        225: 'Timor-Leste',
        226: 'Togo',
        227: 'Tokelau',
        228: 'Tonga',
        229: 'Trinidad and Tobago',
        230: 'Tunisia',
        231: 'Turkiye',
        232: 'Turkmenistan',
        233: 'Turks and Caicos Islands',
        234: 'Tuvalu',
        235: 'Uganda',
        236: 'Ukraine',
        237: 'United Arab Emirates',
        238: 'United Kingdom',
        239: 'United States',
        240: 'United States Minor Outlying Islands',
        241: 'Uruguay',
        242: 'US Virgin Islands',
        243: 'Uzbekistan',
        244: 'Vanuatu',
        245: 'Vatican City',
        246: 'Venezuela',
        247: 'Vietnam',
        248: 'Wallis and Futuna',
        249: 'Yemen',
        250: 'Zambia',
        251: 'Zimbabwe',
        252: 'Ocean'
    }

    def __init__(self, countries=None):
        """Create the wrapper.

        Args:
            countries: optional mapping of CID -> country name. When omitted, the
                class-level ``COUNTRIES`` table is used. A copy of the mapping is
                stored on the instance.
        """
        self.model = None
        self.countries = dict(countries) if countries is not None else self.COUNTRIES

    def load_context(self, context):
        """Load the joblib-serialized pipeline from the ``model_path`` artifact."""
        # Imported for its side effect only: fail fast if scikit-learn is unavailable
        # in the serving environment before trying to unpickle the pipeline.
        import sklearn  # noqa: F401
        from joblib import load

        self.model = load(context.artifacts["model_path"])

    def predict(self, context, model_input):
        """Return country names instead of the CIDs predicted by the wrapped model.

        Args:
            context: MLflow model context (unused).
            model_input: rows of (Longitude, Latitude) features, in the column order
                used for training.

        Returns:
            The country name (str) when ``model_input`` has a single row; otherwise a
            list containing one country name per input row.

        Raises:
            KeyError: if the model predicts a CID that is missing from the mapping.
        """
        cids = self.model.predict(model_input)
        names = [self.countries[cid] for cid in cids]
        # A single row keeps the original scalar return value. Multi-row input used to
        # be silently truncated to its first row; it now gets one name per row.
        return names[0] if len(names) == 1 else names


In [135]:
# Artifacts shipped with the custom pyfunc; load_context reads the "model_path" entry
artifacts = dict(model_path=model_directory)

# Input-only model signature, inferred from the training feature matrix
signature = infer_signature(model_input=X_train)


In [149]:
# Point the MLflow client at the locally running tracking server
mlflow.set_tracking_uri("http://127.0.0.1:8080")

In [150]:
# Log the training params, the test accuracy and the wrapped model in one run, and
# register the model under "tracking-name".
with mlflow.start_run() as run:
    # `params` and `accuracy` were computed above but never recorded with the run
    mlflow.log_params(params)
    mlflow.log_metric("accuracy", float(accuracy))
    mlflow.pyfunc.log_model(
        python_model=ModelWrapper(),
        # One row is enough for an example; passing all of X_train stored the whole
        # training set as a model artifact.
        input_example=X_train[:1],
        signature=signature,
        artifacts=artifacts,
        # "sklearn" is the deprecated PyPI alias, which no longer installs; the real
        # package is scikit-learn. It is pinned to the version that produced the
        # joblib pickle, since pickled estimators are version-sensitive.
        pip_requirements=["joblib", f"scikit-learn=={sklearn.__version__}"],
        artifact_path="countries_name",
        registered_model_name="tracking-name",
    )


Successfully registered model 'tracking-name'.
2023/12/03 19:42:03 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: tracking-name, version 1
Created version '1' of model 'tracking-name'.


In [151]:
# Safety net only: the `with mlflow.start_run()` block above already closes its run,
# so this is a no-op unless a run was left open elsewhere.
mlflow.end_run(status="FINISHED")

This should output the country name for a single inference at a time. No batch inference, please: this quick-and-dirty wrapper is not meant for large NumPy arrays.