In [1]:
import pandas as pd
import numpy as np
from sklearn import linear_model
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import cross_val_score

In [2]:
df = pd.read_csv("data/movies_ratings.csv")
df.columns

Index(['Unnamed: 0', 'rating_avg', 'timestamp', 'belongs_to_collection',
       'popularity', 'Fantasy', 'TV Movie', 'Adventure', 'Thriller', 'Western',
       ...
       'zu_lan', 'num_languages', 'Canceled', 'In Production', 'Planned',
       'Post Production', 'Released', 'Rumored', 'tagline_len', 'video_int'],
      dtype='object', length=384)

In [3]:
alphas = [0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 10]
res = pd.DataFrame()
res['alpha'] = alphas
mae = []
cv_score = []
num = []

for alpha in alphas:
    # lasso feature selection
    clf = linear_model.Lasso(alpha=alpha)
    X = df.drop(columns=['rating_avg', 'Unnamed: 0'])
    y = df.rating_avg
    mdl = clf.fit(X, y)
    coef = clf.coef_
    features = X.columns[coef!=0]
    X = X[features]
    num.append(len(features))
    
    # linear regression
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    linreg = LinearRegression()
    linreg.fit(X_train, y_train)
    y_hat_test = linreg.predict(X_test)
    
    # mean absolute error
    test_mse = mean_squared_error(y_test, y_hat_test)
    test_rmse = np.sqrt(test_mse)
    test_mae = mean_absolute_error(y_test, y_hat_test)
    mae.append(test_mae)
    
    # cross validation scores
    cv_5_results = cross_val_score(linreg, 
                                   X, 
                                   y, 
                                   cv=5, 
                                   scoring="neg_mean_squared_error")
    cv_score.append(np.mean(cv_5_results))
    
res['cv_score'] = cv_score
res['mae'] = mae
res['num of features'] = num
res

  positive)


Unnamed: 0,alpha,cv_score,mae,num of features
0,1e-05,-0.636412,0.529617,267
1,0.0001,-0.62456,0.558974,185
2,0.001,-0.605183,0.533316,62
3,0.01,-0.607309,0.538329,15
4,0.1,-0.606047,0.542659,8
5,1.0,-0.742966,0.66193,5
6,10.0,-0.748317,0.655494,4
