In [1]:
import numpy as np
import pandas as pd
from gensim.parsing.preprocessing import remove_stopwords
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

In [2]:
# loading datasets and preliminary cleaning
rotten=pd.read_csv('../../data/rotten_tomatoes_movies.csv')
rotten=rotten.drop(['rotten_tomatoes_link', 'content_rating', 'actors', 'streaming_release_date', 
                    'movie_info', 'critics_consensus', 'genres', 'directors', 'authors',# 'original_release_date',  
                    'production_company', 'tomatometer_fresh_critics_count', 'tomatometer_rotten_critics_count'], axis=1)


meta=pd.read_csv('../../data/movies_meta.csv')
meta=meta.loc[meta['original_language']=='en']  # only english movies
meta=meta.drop(['production_countries', 'overview', 'tagline', 'belongs_to_collection', 'homepage', 'revenue', 'spoken_languages', 'video', 'homepage', "poster_path", 'production_companies'], axis=1)
#meta['release_date']=pd.to_datetime(meta['release_date'])
print(rotten)

                                             movie_title  \
0      Percy Jackson & the Olympians: The Lightning T...   
1                                            Please Give   
2                                                     10   
3                        12 Angry Men (Twelve Angry Men)   
4                           20,000 Leagues Under The Sea   
...                                                  ...   
17707                                          Zoot Suit   
17708                                           Zootopia   
17709                                    Zorba the Greek   
17710                                               Zulu   
17711                                          Zulu Dawn   

      original_release_date  runtime tomatometer_status  tomatometer_rating  \
0                2010-02-12    119.0             Rotten                49.0   
1                2010-04-30     90.0    Certified-Fresh                87.0   
2                1979-10-05    122.0      

In [4]:
#clean titles
meta_drama=meta
def clean_title(x):
    '''
    Method to normalize the movie titles
    '''
    for char in [",", "'", ".", ":", ";", "(", ")", "/", "!", "?", "%", "-", "_", "="]:
        x=x.replace(char, "")        
    return x.lower()
 
rotten['movie_title']=rotten['movie_title'].apply(clean_title)
meta_drama['title']=meta_drama['title'].apply(clean_title)



In [18]:
# combine rotten and imdb data

def combine(df1, df2):
    result = pd.merge(df1, df2, how="inner", left_on='movie_title', right_on='title')
    result = result.drop(['movie_title'], axis=1)
    return result

combined=combine(rotten, meta_drama)
print(combined.columns)

Index(['original_release_date', 'runtime_x', 'tomatometer_status',
       'tomatometer_rating', 'tomatometer_count', 'audience_status',
       'audience_rating', 'audience_count', 'tomatometer_top_critics_count',
       'adult', 'budget', 'genres', 'id', 'imdb_id', 'original_language',
       'original_title', 'popularity', 'release_date', 'runtime_y', 'status',
       'title', 'vote_average', 'vote_count'],
      dtype='object')


In [19]:
print(rotten['tomatometer_status'].value_counts())
print(combined['genres']) ### extract genres

Rotten             7565
Fresh              6844
Certified-Fresh    3259
Name: tomatometer_status, dtype: int64
0                         [{'id': 37, 'name': 'Western'}]
1                         [{'id': 37, 'name': 'Western'}]
2                           [{'id': 18, 'name': 'Drama'}]
3       [{'id': 28, 'name': 'Action'}, {'id': 12, 'nam...
4       [{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam...
                              ...                        
2548    [{'id': 10751, 'name': 'Family'}, {'id': 14, '...
2549    [{'id': 18, 'name': 'Drama'}, {'id': 10752, 'n...
2550    [{'id': 28, 'name': 'Action'}, {'id': 12, 'nam...
2551    [{'id': 35, 'name': 'Comedy'}, {'id': 80, 'nam...
2552    [{'id': 27, 'name': 'Horror'}, {'id': 28, 'nam...
Name: genres, Length: 2553, dtype: object


In [7]:
print(combined.columns)
print(combined.__len__())

Index(['original_release_date', 'runtime_x', 'tomatometer_status',
       'tomatometer_rating', 'tomatometer_count', 'audience_status',
       'audience_rating', 'audience_count', 'tomatometer_top_critics_count',
       'adult', 'budget', 'genres', 'id', 'imdb_id', 'original_language',
       'original_title', 'popularity', 'release_date', 'runtime_y', 'status',
       'title', 'vote_average', 'vote_count'],
      dtype='object')
2553


In [9]:
le = LabelEncoder()
combined = combined[['original_release_date', 'runtime_x', 'tomatometer_status',
       'tomatometer_rating', 'tomatometer_count', 'audience_status',
       'audience_rating', 'audience_count', 'tomatometer_top_critics_count',
       'adult', 'budget', 'genres', 'id', 'imdb_id', 'original_language',
       'original_title', 'popularity', 'release_date', 'runtime_y', 'status',
       'title', 'vote_average', 'vote_count']].apply(le.fit_transform)
    

In [10]:
X=combined.drop(['tomatometer_status', 'tomatometer_rating', 'tomatometer_count', 'vote_count', 
                 'popularity', 'vote_average',  'audience_count', 'audience_rating', 'audience_status'], axis=1)
y=combined['tomatometer_status']

X_train, X_test, y_train, y_test=train_test_split(X, y)

In [14]:
model=RandomForestClassifier(n_estimators=100, max_depth=10)
model.fit(X_train, y_train)

y_pred=model.predict(X_test)

y_train_pred = model.predict(X_train) 

print(accuracy_score(y_train, y_train_pred) )
print(accuracy_score(y_test, y_pred))

0.8610240334378265
0.6338028169014085


In [12]:
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       0.69      0.50      0.58       135
           1       0.65      0.40      0.50       204
           2       0.61      0.85      0.71       298
           3       0.00      0.00      0.00         2

    accuracy                           0.63       639
   macro avg       0.49      0.44      0.45       639
weighted avg       0.64      0.63      0.61       639



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
