In [80]:
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_score, cross_val_predict
from sklearn.metrics import mean_squared_error
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [47]:
columns = ["mpg", "cylinders", "displacement", "horsepower", "weight", "acceleration", "model year", "origin", "car name"]
# sep=r"\s+" is the documented replacement for delim_whitespace=True, which is
# deprecated in recent pandas releases.
cars = pd.read_table("data/auto-mpg.data", sep=r"\s+", names=columns)
# Rows whose horsepower is the literal string '?' cannot be cast to float, so
# they are dropped first. reset_index() renumbers the remaining rows from 0 and
# keeps the previous labels in an 'index' column.
filtered_cars = cars[cars['horsepower'] != '?'].copy().reset_index()
filtered_cars['horsepower'] = filtered_cars['horsepower'].astype('float')

In [48]:
def train_and_test(cols):
    """Fit a linear regression on ``filtered_cars`` and score it on the same rows.

    Args:
        cols: Column labels of ``filtered_cars`` used as predictors.

    Returns:
        A ``(mse, variance)`` tuple: the mean squared error of the predictions
        against ``filtered_cars['mpg']`` and the variance of those predictions.
    """
    features = filtered_cars[cols]
    target = filtered_cars['mpg']
    # The model is evaluated on the rows it was trained on, so the error
    # reported here is an in-sample figure.
    fitted = LinearRegression().fit(features, target)
    predicted = fitted.predict(features)
    return mean_squared_error(target, predicted), predicted.var()

In [62]:
def train_and_cross_val(cols):
    """Estimate out-of-sample error of a linear regression with 10-fold CV.

    For every fold a fresh ``LinearRegression`` is fitted on the training rows
    of ``filtered_cars`` and used to predict ``mpg`` for the held-out rows.

    Args:
        cols: Column labels of ``filtered_cars`` used as predictors.

    Returns:
        A ``(mean_mse, mean_variance)`` tuple: the per-fold mean squared error
        and the per-fold variance of the predictions, each averaged over the
        folds.
    """
    features = filtered_cars[cols]
    target = filtered_cars['mpg']
    # Fixed random_state keeps the shuffled folds identical between runs.
    kf = KFold(n_splits=10, shuffle=True, random_state=3)
    mse_list = []
    var_list = []
    for train_ix, test_ix in kf.split(features, target):
        # KFold yields positional row numbers, so select with .iloc; label-based
        # .loc would only be correct while the index happens to be 0..n-1.
        lr = LinearRegression()
        lr.fit(features.iloc[train_ix], target.iloc[train_ix])
        predictions = lr.predict(features.iloc[test_ix])
        mse_list.append(mean_squared_error(target.iloc[test_ix], predictions))
        var_list.append(np.var(predictions))
    return np.mean(mse_list), np.mean(var_list)
    

In [76]:
feature_list = ['cylinders', 'displacement', 'horsepower', 'weight', 'acceleration', 'model year', 'origin']

# Evaluate nested feature subsets: the first 2 features, then the first 3, and
# so on up to the full list, recording the cross-validated MSE and prediction
# variance for each subset size.
plot_data = []
for i in range(2, len(feature_list) + 1):
    subset_features = feature_list[:i]
    num_features = len(subset_features)
    mse, var = train_and_cross_val(subset_features)
    plot_data.append([num_features, mse, var])
plot_df = pd.DataFrame(plot_data, columns=['num_features', 'mse', 'var'])

# Label both series and the axes so the figure can be read without the code.
fig, ax = plt.subplots()
ax.scatter(plot_df['num_features'], plot_df['mse'], c='red', label='mean squared error')
ax.scatter(plot_df['num_features'], plot_df['var'], c='blue', label='variance of predictions')
ax.set_xlabel('number of features')
ax.set_ylabel('MSE / variance of predictions')
ax.set_title('Cross-validated error and prediction variance vs. number of features')
ax.legend()
plt.show()