In [94]:
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error

import pickle

In [88]:
class FixDatafarme:
    """Load a CSV file into a DataFrame and apply simple cleaning steps.

    Every cleaning method replaces or updates ``self.df`` and returns the
    resulting frame, so the current state is always available via
    ``return_df()``.

    Attributes:
        path: The location that was passed to the constructor.
        df: The working DataFrame.
        cols: Column index captured by ``impute_missing`` (only set once
            that method has been called).
    """

    def __init__(self, path: str):
        """Read the CSV located at ``path`` using ``pd.read_csv``."""
        self.path = path
        self.df = pd.read_csv(path)

    def remove_columns(self, col_names: list):
        """Drop the columns listed in ``col_names`` and return the frame.

        pandas raises ``KeyError`` if any of the names is not a column.
        """
        self.df = self.df.drop(columns=col_names)
        return self.df

    def fix_col_names(self):
        """Lower-case column names and replace spaces with underscores.

        Labels are passed through ``str`` first so that non-string labels
        (e.g. integers) do not raise ``AttributeError``.
        """
        self.df.columns = [
            str(col).lower().replace(" ", "_") for col in self.df.columns
        ]
        return self.df

    def impute_missing(self):
        """Coerce all columns to numeric and fill NaNs with column medians.

        Values that cannot be parsed as numbers become NaN
        (``errors="coerce"``) and are then imputed like any other missing
        value. A column with no numeric values at all has a NaN median and
        is therefore left as NaN.

        Returns:
            The imputed DataFrame (also stored on ``self.df``).
        """
        self.cols = self.df.columns
        numeric = self.df.apply(pd.to_numeric, errors="coerce")
        # Fill by assignment rather than a chained ``fillna(inplace=True)``:
        # the chained form acts on a temporary Series and does not update the
        # frame when pandas Copy-on-Write is active.
        medians = numeric.median(numeric_only=True)
        self.df = numeric.fillna(medians)
        return self.df

    def return_df(self):
        """Return the current working DataFrame."""
        return self.df

In [89]:
# Load the raw CSV and clean it step by step; each call below updates fd.df.
fd = FixDatafarme("./auto-mpg.csv")
# Columns are dropped by their original header text (note the space in
# "car name"), so this has to run before the rename in fix_col_names().
fd.remove_columns(["Unnamed: 9", "car name"])
# Lower-case the headers and replace spaces with underscores.
fd.fix_col_names()
# Coerce every column to numeric and fill missing values with the column median.
fd.impute_missing()
# Cleaned frame that is handed to PredictionModel further down.
df = fd.return_df()


In [102]:
class PredictionModel:
    """Fit a Ridge regression on a DataFrame and pickle the fitted model.

    Args:
        data: DataFrame holding the target column plus the feature columns.
        alpha: Regularisation strength passed to ``Ridge``.
        target: Name of the column to predict. Defaults to ``"mpg"``.
    """

    def __init__(self, data, alpha, target="mpg"):
        self.data = data
        self.alpha = alpha
        self.target = target

    def model_fit(self):
        """Fit the model on all rows, save it to disk, and return the error.

        The model is fitted and scored on the same data, so the returned
        value is an in-sample (training) error. Despite the attribute and
        file names saying "mse", the value is the *root* mean squared error.

        Side effects:
            Sets ``model``, ``X_train``, ``Y_train``, ``pred_op`` and ``mse``
            on the instance, and writes the pickled model to
            ``ridge_auto_mse-<rounded error>.bin`` in the working directory.

        Returns:
            The training-set RMSE.
        """
        self.model = Ridge(alpha=self.alpha)
        self.X_train = self.data.drop(columns=[self.target])
        self.Y_train = self.data[self.target]
        self.model.fit(self.X_train, self.Y_train)
        self.pred_op = self.model.predict(self.X_train)

        # Take the square root explicitly instead of passing ``squared=False``:
        # that keyword was deprecated and later removed from
        # ``mean_squared_error``; the explicit form gives the same value.
        self.mse = np.sqrt(mean_squared_error(self.Y_train, self.pred_op))

        with open(f"ridge_auto_mse-{np.round(self.mse, 3)}.bin", "wb") as f_out:
            pickle.dump(self.model, f_out)
        return self.mse

In [103]:
# Ridge model over the cleaned frame, with regularisation strength alpha = 0.5.
model = PredictionModel(df, 0.5)

# Fit on the full frame, pickle the fitted model to disk, and show the
# returned training-set RMSE as the cell output.
model.model_fit()

3.2933011409318067

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 398 entries, 0 to 397
Data columns (total 9 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   mpg           398 non-null    float64
 1   cylinders     398 non-null    int64  
 2   displacement  398 non-null    float64
 3   horsepower    398 non-null    object 
 4   weight        398 non-null    float64
 5   acceleration  398 non-null    float64
 6   model_year    398 non-null    float64
 7   origin        398 non-null    int64  
 8   car_name      398 non-null    object 
dtypes: float64(5), int64(2), object(2)
memory usage: 28.1+ KB
