# Data Visual 3 – Enrolment by Institutions: Scatter Plot + Line Charts

## Step 1 - Import Libraries

In [9]:
import os
from contextlib import closing
from getpass import getpass
from math import pi

import pandas as pd
import mysql.connector
from bokeh.models import HoverTool
from bokeh.io import output_notebook, show
from bokeh.layouts import gridplot
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource
from bokeh.palettes import Category20
from bokeh.transform import cumsum

## Step 2 - Connect to the MySQL Database and Retrieve the Data Table

In [10]:
# Database connection configuration.
# Credentials are read from environment variables instead of being hard-coded
# in the notebook (where they would end up in version control and shared
# copies). If DB_PASSWORD is not set, the password is prompted for.
db_config = {
    'user': os.environ.get('DB_USER', 'weaver'),
    'password': os.environ.get('DB_PASSWORD') or getpass('MySQL password: '),
    'host': os.environ.get('DB_HOST', 'localhost'),
    'database': os.environ.get('DB_NAME', 'dbsingaporepoly'),
}

# Query to load the processed data from the MySQL table
query = "SELECT * FROM enrolmentbyinstitutions_processed"

# Load the data into a pandas DataFrame. `closing` guarantees the connection
# is closed even if the query raises.
# NOTE: pandas warns (UserWarning) that DBAPI2 connections other than
# SQLAlchemy / sqlite3 are untested; the query still runs.
with closing(mysql.connector.connect(**db_config)) as conn:
    enrolment_data = pd.read_sql(query, conn)


## Step 3 - Set Up Variables

In [11]:
# Institution columns to chart, one chart each, in this order. The list is
# split by column-name pattern purely for readability.
institutions = (
    # short acronym columns
    ['nus', 'ntu', 'smu', 'sit', 'sutd', 'suss', 'nie']
    # *_polytechnic columns
    + ['singapore_polytechnic', 'ngee_ann_polytechnic', 'temasek_polytechnic',
       'nanyang_polytechnic', 'republic_polytechnic']
    # diploma/degree split columns
    + ['lasalle_diploma', 'lasalle_degree', 'nafa_diploma', 'nafa_degree']
    + ['ite']
)

# Distinct values of the `year` column, in order of first appearance
# (Series.unique does not sort).
# NOTE(review): not referenced by the plotting cell below; confirm it is still needed.
years = enrolment_data.year.unique()

## Step 4 - Set Up a Grid of Scatter Plot + Line Charts

In [12]:
# Output to notebook
output_notebook()

# Legend label and colour for each value of the `sex` column.
SEX_STYLES = {
    'M': ('Male', 'blue'),
    'F': ('Female', 'red'),
}


def make_enrolment_figure(data: pd.DataFrame, institution: str):
    """Build a male-vs-female line + scatter chart of one institution's enrolment.

    Parameters
    ----------
    data : pd.DataFrame
        Enrolment table with a ``sex`` column (values 'M' / 'F'), a ``year``
        column, and a column named by ``institution``.
    institution : str
        Column of ``data`` to plot on the y-axis; rows where it is missing
        are skipped.

    Returns
    -------
    bokeh.plotting.figure
        Figure with one line plus one set of scatter markers per sex.
    """
    p = figure(title=f'{institution.upper()} - Male vs Female Enrollment Over the Years',
               x_axis_label='Year', y_axis_label='Enrollment',
               width=400, height=300)

    point_renderers = []
    for sex, (label, colour) in SEX_STYLES.items():
        rows = data.loc[(data['sex'] == sex) & data[institution].notna(), ['year', institution]]
        # The SQL query has no ORDER BY, so sort by year; otherwise the line
        # joins points in arbitrary row order and zigzags.
        rows = rows.sort_values('year')
        # Columns are named x/y so the tooltips below can refer to @x / @y.
        source = ColumnDataSource({'x': rows['year'].to_numpy(),
                                   'y': rows[institution].to_numpy()})
        p.line('x', 'y', source=source, color=colour, legend_label=label, line_width=2)
        # Markers at each data point for better visibility
        point_renderers.append(p.scatter('x', 'y', source=source, color=colour, size=8))

    # Attach hover to the markers only, so a data point yields one tooltip
    # instead of one from the marker and another from the line through it.
    p.add_tools(HoverTool(renderers=point_renderers,
                          tooltips=[("Year", "@x"), ("Enrollment", "@y")]))
    p.legend.location = "top_left"
    return p


# One figure per institution, arranged in a grid with 2 plots per row
plots = [make_enrolment_figure(enrolment_data, institution) for institution in institutions]
grid = gridplot(plots, ncols=2)

# Show the grid of plots
show(grid)