In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import functools
from os import path
from collections import Counter

In [None]:
# Load the survey responses into the module-level DataFrame `df`.
# The plotting helpers in later cells read `df` as a global, so this cell
# must run before them. The path is relative to the working directory.
filename = 'data/Teaching_Assistantship_Survey.csv'
df = pd.read_csv(filename)

In [None]:
# utility functions

def convert_string_num_hours(key):
    """Convert an hours-per-week range label to a representative float.

    Args:
        key (str): One of '<5', '5-10', '10-15', '15-20' or '>20'.

    Returns:
        float: The midpoint of the range. The open-ended labels '<5' and
        '>20' map to 2.5 and 22.5 respectively.

    Raises:
        ValueError: If key is not one of the recognised labels (for example
            a missing value or a label with stray whitespace).
    """
    hours_by_label = {
        '<5': 2.5,
        '5-10': 7.5,
        '10-15': 12.5,
        '15-20': 17.5,
        '>20': 22.5,
    }
    try:
        return hours_by_label[key]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs such as lists.
        raise ValueError(
            f'Unrecognised hours range {key!r}; expected one of {list(hours_by_label)}'
        ) from None

def single_value(bar_data):
    """Return the count-weighted average number of hours for a group.

    bar_data maps an hours-range label to how many respondents chose it.
    Each label is turned into hours with convert_string_num_hours and
    weighted by its count. Returns 0 when the total count is not positive.
    """
    weighted_hours = 0
    respondents = 0
    for label, count in bar_data.items():
        weighted_hours += convert_string_num_hours(label) * count
        respondents += count
    # Guard clause: nothing to average over.
    if not respondents > 0:
        return 0
    return weighted_hours / respondents
            
def df_data_processing(df, conditions_dict, target):
    """Filter a DataFrame by column conditions and return the target column(s).

    Args:
        df (pandas.DataFrame): Table to filter.
        conditions_dict (dict): Maps a column name to a list of accepted
            values. A row is kept only if, for every listed column, its value
            is in that column's list. The list [''] is a wildcard meaning
            "do not filter on this column". An empty dict keeps every row.
        target (str or list of str): Column name, or list of column names,
            to extract from the rows that pass the filter.

    Returns:
        numpy.ndarray: The selected values with missing entries dropped.
        1-D when target is a str; 2-D (one row per kept record) when target
        is a list, in which case a record is dropped if any target column is
        missing.

    Raises:
        TypeError: If target is neither a str nor a list.
    """
    if not isinstance(target, (str, list)):
        raise TypeError(
            f'target must be a column name or a list of column names, got {type(target).__name__}'
        )

    # Start from "keep everything" so an empty conditions_dict is valid.
    keep = np.ones(len(df), dtype=bool)
    for column, accepted in conditions_dict.items():
        if accepted == ['']:
            continue  # wildcard: no restriction on this column
        keep &= df[column].isin(accepted).to_numpy()

    return df.loc[keep, target].dropna().values


def generate_bar_plots(input_data, output_dir, title, campus, degree):
    """
    Draw and save a bar chart of hours-per-week responses for one group.

    Args:
        input_data: Iterable of hours-range labels, one per respondent.
        output_dir (str): Directory the figure is written to.
        title (str): Group name; used in the plot title and the file name.
        campus (str): Campus name; used in the file name.
        degree (str or None): Degree level; appended to the file name when truthy.

    Returns:
        The average number of hours for the group, as computed by single_value.
    """
    counts = Counter(input_data)
    # The average is taken over every response, before the counts are
    # restricted to the known labels for plotting.
    average_hours = single_value(dict(counts))

    # Keep only labels that actually occur, in ascending order of hours.
    ordered_counts = {}
    for label in ('<5', '5-10', '10-15', '15-20', '>20'):
        if label in counts:
            ordered_counts[label] = counts[label]
    total_respondents = sum(ordered_counts.values())

    fig, ax = plt.subplots()
    ax.bar(ordered_counts.keys(), ordered_counts.values())
    ax.set_xlabel('Number of hours', size=18)
    ax.tick_params(axis='both', labelsize=18, size=10)
    ax.set_title(f'{title} ({total_respondents} respondents)', size=20)

    if degree:
        file_stem = f'num_hours_{title}_{campus}_{degree}'
    else:
        file_stem = f'num_hours_{title}_{campus}'
    fig.savefig(path.join(output_dir, file_stem), bbox_inches='tight')
    return average_hours

def wage_time_histograms(input_data, output_dir, title):
    """Compute the effective hourly wage for each (monthly wage, weekly hours) pair.

    Previously the computed list was discarded and the function returned
    None; it now returns the values. The conversion matches the
    'wage_unit_time' mode of generate_histograms: monthly wage divided by
    (4.1 weeks per month * hours per week).

    Args:
        input_data: Iterable of (wage, time_str) pairs, where wage is
            convertible to float and time_str is an hours-range label
            accepted by convert_string_num_hours.
        output_dir: Unused; kept so the signature is unchanged.
        title: Unused; kept so the signature is unchanged.

    Returns:
        list of float: One hourly wage per pair whose monthly wage is below
        3200; pairs at or above that cutoff are skipped.
    """
    weeks_per_month = 4.1
    max_monthly_wage = 3200  # same cutoff generate_histograms applies
    return [
        float(wage) / (weeks_per_month * convert_string_num_hours(time_str))
        for wage, time_str in input_data
        if float(wage) < max_monthly_wage
    ]


def generate_histograms(input_data, output_dir, title, campus, degree,
                        max_monthly_wage=3200, weeks_per_month=4.1):
    """Draw and save a wage histogram for one group and return the average.

    The plotting mode is inferred from the first element of input_data:
    if it is itself a sequence (a row of a 2-D array), rows are read as
    (monthly wage, hours-range label) and the histogram shows an estimated
    hourly wage ('wage_unit_time'); otherwise each element is read as a
    monthly wage ('monthly_wage').

    Args:
        input_data: Array of monthly wages, or array of
            (monthly wage, hours-range label) rows.
        output_dir (str): Directory the figure is written to.
        title (str): Group name; used in the plot title and the file name.
        campus (str): Campus name; used in the file name.
        degree (str or None): Degree level; appended to the file name when truthy.
        max_monthly_wage (float): Records with a monthly wage at or above
            this value are excluded. Defaults to 3200.
        weeks_per_month (float): Factor used to turn hours per week into
            hours per month. Defaults to 4.1.

    Returns:
        The average of the plotted values, or 0 when there is nothing to
        plot (empty input, or every record excluded by the wage cutoff). In
        the latter case no figure is created.
    """
    # len() rather than .any(): we want "is there any record", not
    # "is any record truthy".
    if len(input_data) == 0:
        return 0

    if isinstance(input_data[0], (np.ndarray, list, tuple)):
        mode = 'wage_unit_time'
        xlabel = 'Estimated wage per hour'
        values = [
            float(wage_time[0]) / (weeks_per_month * convert_string_num_hours(wage_time[1]))
            for wage_time in input_data
            if float(wage_time[0]) < max_monthly_wage
        ]
    else:
        # Any scalar (str, np.str_, float, ...) is treated as a monthly wage.
        mode = 'monthly_wage'
        xlabel = 'Monthly wage'
        values = [float(wage) for wage in input_data if float(wage) < max_monthly_wage]

    # Avoid averaging an empty list (NaN plus a RuntimeWarning) and saving
    # an empty plot.
    if not values:
        return 0

    fig, ax = plt.subplots()
    ax.hist(values, bins=6)
    ax.set_xlabel(xlabel, size=18)
    ax.tick_params(axis='both', labelsize=14)
    ax.set_title(f'{title} ({len(values)} respondents)', size=20)

    if degree:
        file_stem = f'{mode}_{title}_{campus}_{degree}'
    else:
        file_stem = f'{mode}_{title}_{campus}'
    fig.savefig(path.join(output_dir, file_stem), bbox_inches='tight')
    return np.average(values)

In [None]:
def plot_entities(mode, target, campus="Atlanta", degree=None):
    """
    Plot a per-entity summary (colleges or schools) of one survey quantity.

    For every entity in the chosen group, the module-level DataFrame `df` is
    filtered and handed to a callback that saves that entity's own plot and
    returns its average. The averages are then drawn as one aggregate bar
    chart, with the number of responses per entity overlaid on a second
    y-axis, and saved under './figures'.

    Args:

    mode (str): Either 'college', 'school_CoE' or 'school_CoC' to pick out a specific group of interest

    target (str): Target question to plot data for. Either 'monthly_wage', 'num_hours' or 'wage_unit_time'

    campus (str): Either 'Atlanta' or 'Online'

    degree (str): Either "PhD" or "Master's" for degree level of respondent.
                  When None, respondents are not filtered by degree.

    Raises:
        AssertionError: If mode, target or (a truthy) degree is not an allowed value.
        ValueError: If campus is not an allowed value.
    """
    entities_by_mode = {
        'college': ['College of Engineering', 'College of Computing', 'College of Sciences', 'Ivan Allen College', 'College of Design', 'Scheller College of Business'],
        'school_CoE': ['School of Aerospace Engg', 'School of Civil & Environmental Engg', 'School of Electrical & Computer Engg', 'School of Mechanical Engg', 'School of Industrial & Systems Engg', 'School of Chemical & Biomolecular Engg', 'School of Materials Sci & Engg', 'School of Biomedical Engg'],
        'school_CoC': ['School of Computer Science', 'School of Interactive Computing', 'School of Computational Science & Engg', 'School of Cybersecurity'],
    }
    # Survey column that identifies the entity within each group.
    entity_column_by_mode = {'college': 'Q3', 'school_CoE': 'Q4_CoE', 'school_CoC': 'Q4_CoC'}
    # target -> (base filter, per-entity callback, file prefix, survey column(s), y-axis label)
    settings_by_target = {
        'monthly_wage': ({'Q14': ['Monthly']}, generate_histograms, 'monthly_wage', 'Q15', 'Average monthly wage'),
        'wage_unit_time': ({'Q14': ['Monthly']}, generate_histograms, 'wage_unit_time', ['Q15', 'Q11'], 'Average wage per hour'),
        'num_hours': ({}, generate_bar_plots, 'num_hours', 'Q11', 'Average number of hours'),
    }
    # Fixed upper y-limits for the wage plots; other targets autoscale.
    ylim_by_target = {'monthly_wage': 3000, 'wage_unit_time': 40}

    assert mode in ['college', 'school_CoE', 'school_CoC']
    assert target in ['monthly_wage', 'num_hours', 'wage_unit_time']

    if target not in settings_by_target:
        raise ValueError
    cond_dict, callback, figure_suffix, target_q, ylabel = settings_by_target[target]

    if campus not in ("Atlanta", "Online"):
        raise ValueError
    cond_dict['Q2'] = [campus]

    if degree:
        assert degree in ["Master's", "PhD"]
        cond_dict['Q5'] = [degree]

    output_dir = './figures'
    entity_column = entity_column_by_mode[mode]
    averages = {}
    response_counts = {}
    for entity in entities_by_mode[mode]:
        # Re-point the entity filter; the other conditions stay fixed.
        cond_dict[entity_column] = [entity]
        entity_data = df_data_processing(df, cond_dict, target=target_q)
        averages[entity] = callback(entity_data, output_dir, entity, campus, degree)
        response_counts[entity] = len(entity_data)

    accent = '#af8dc3'
    fig, ax = plt.subplots()
    ax.bar(averages.keys(), averages.values(), color=accent)
    ax.set_ylabel(ylabel, size=20, color=accent)
    if target in ylim_by_target:
        ax.set_ylim(0, ylim_by_target[target])

    ax.tick_params(axis='x', labelsize=18, size=6, rotation=90)
    ax.tick_params(axis='y', labelsize=18, size=6, colors=accent)
    ax.spines['left'].set_color(accent)

    # Second y-axis: number of responses per entity.
    ax2 = ax.twinx()
    ax2.scatter(response_counts.keys(), response_counts.values(), color='black', marker='x')
    ax2.tick_params(axis='y', labelsize=18)
    ax2.set_ylabel('Number of respondents', size=20, color='black')

    if degree:
        file_name = f'{figure_suffix}_{campus}_{mode}_{degree}.png'
    else:
        file_name = f'{figure_suffix}_{campus}_{mode}.png'
    fig.savefig(path.join(output_dir, file_name), bbox_inches='tight')


In [None]:
# Change the input parameters to obtain visualizations for various groups of students at Georgia Tech

plot_entities(mode='college', degree='Master\'s', target='monthly_wage')

In [None]:
# Bar plots for TAship benefits and responsibilities

def process_benefits(entity_data):
    """Cast every perceived-benefit value in entity_data to int and return them as a list"""
    return list(map(int, entity_data))
    
def process_responsibilities(entity_data):
    """Flatten the comma-separated responsibilities of every respondent into one list"""
    return [item for entry in entity_data for item in entry.split(',')]

def bar_chart(mode, cond_dict, remove_tick_labels=False, output_dir='./figures'):
    """
    Plot a horizontal bar chart of TAship benefits or TAship responsibilities

    The module-level DataFrame `df` is filtered with cond_dict, the answers
    are processed by the function matching the mode, counted, and the chart
    is saved in output_dir.

    Args:

    mode (str): Either 'TAship benefits' or 'TAship responsibilities'

    cond_dict (dict): Dictionary of conditions to use for generating a view of the DataFrame

    remove_tick_labels (bool): If True, remove Y-axis labels from plot. Used to generate a plot that would fit on the
                               page of the report

    output_dir (str): Directory the figure is written to. Defaults to './figures', the directory plot_entities
                      uses, rather than an absolute path on one machine.

    Raises:
        ValueError: If mode is not one of the two supported values.
    """
    if mode == 'TAship benefits':
        target_q = 'Q18'
        callback = process_benefits
        key_list = None
    elif mode == 'TAship responsibilities':
        target_q = 'Q7'
        callback = process_responsibilities
        key_list = ['Grading assignments', 'Grading exams', 'Office hours',
                    'Responding to questions on Piazza/Canvas',
                    'Conducting lab sessions',
                    'Holding recitations', 'Teaching classes',
                    'Other (Please specify)']
    else:
        raise ValueError(
            f"mode must be 'TAship benefits' or 'TAship responsibilities', got {mode!r}"
        )

    entity_data = df_data_processing(df, cond_dict, target=target_q)
    # Use the processor selected for this mode; the responsibilities
    # processor was previously applied in both modes.
    counts = Counter(callback(entity_data))
    if mode == 'TAship benefits':
        # Order bars by numeric rating.
        data = dict(sorted(counts.items(), key=lambda item: int(item[0])))
    else:
        # Fixed category order, reversed so the first category is drawn at
        # the top of the horizontal chart; categories nobody chose get 0.
        data = {key: counts.get(key, 0) for key in reversed(key_list)}

    fig, ax = plt.subplots()
    ax.barh(list(data.keys()), list(data.values()))
    ax.set_xlabel('Number of responses', size=18)
    ax.tick_params(axis='both', labelsize=18, size=10)
    # Title with the college only when a college filter was supplied.
    if mode == 'TAship responsibilities' and cond_dict.get('Q3'):
        ax.set_title(cond_dict['Q3'][0], size=20)
    if remove_tick_labels:
        ax.axes.yaxis.set_ticklabels([])

    suffix = ''
    if cond_dict.get('Q3', ''):
        suffix += '_'.join(cond_dict['Q3'])
    if cond_dict.get('Q4_CoE', ''):
        suffix += '_'.join(cond_dict['Q4_CoE'])
    fig.savefig(path.join(output_dir, f'{mode}_{suffix}'), bbox_inches='tight')


In [None]:
# Example: responsibilities chart for a single college.
# Change the 'Q3' entry to another college name to plot that college instead.

cond_dict = {'Q1': ['Yes'],'Q2': ['Atlanta'], 'Q3': ['College of Engineering']}
bar_chart(mode='TAship responsibilities', cond_dict=cond_dict, remove_tick_labels=True)