In [1]:
!pip install plotly



In [15]:
import pandas as pd
import plotly.express as px

# Build a treemap of the 25 production companies with the highest summed
# revenue, coloured by the mean `vote_average` of their movies, then show it
# inline and export it as a standalone HTML file.

movies_df = pd.read_csv("movies_cleaned.csv")

# cleaning & splitting production companies
# Missing values become '' so that .str.split always yields a list.
movies_df['production_companies'] = movies_df['production_companies'].fillna('')

# production_companies split into distinct company names and exploded so each company gets own row.
# A movie with no production company yields the single entry '' after the
# split; left in place it would be aggregated as a nameless "company" and
# could surface in the top 25 as a blank tile, so empty names are dropped.
# Names are stripped first so stray surrounding whitespace does not split one
# company into two groups.
movies_df_exploded = (
    movies_df
    .assign(production_company=movies_df['production_companies'].str.split(', '))
    .explode('production_company')
    .assign(production_company=lambda df: df['production_company'].str.strip())
    .loc[lambda df: df['production_company'].ne('')]
)

# aggregating revenue & voting averages
# NOTE: a movie listing several companies contributes its full revenue to
# each of them, so totals across companies are not additive.
company_revenue = (
    movies_df_exploded
    .groupby('production_company')
    .agg(total_revenue=('revenue', 'sum'), average_vote=('vote_average', 'mean'))
    .reset_index()
)

# generating top 25 companies by total revenue
top_25_comp = company_revenue.nlargest(25, 'total_revenue')

# Tile area encodes total revenue; tile colour encodes the average vote.
fig = px.treemap(
    top_25_comp,
    path=['production_company'],
    values='total_revenue',
    color='average_vote',
    color_continuous_scale='OrRd',
    title='Top 25 Production Companies by Total Revenue'
)

fig.update_layout(margin=dict(t=60, l=35, r=35, b=35))
fig.show()
# Also persist the interactive figure as a standalone HTML file.
fig.write_html("top_25_treemap.html")