In [2]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

In [3]:
def explore_numeric_hist(df, col_name, axis, display_labels=True):
    """Draw a histogram of the numeric column ``df[col_name]`` on ``axis``.

    Missing values (NaN / ``pd.NA``) are dropped and the remaining values are
    converted to float before binning. Without this, an all-NaN column makes
    the automatic bin range non-finite. Nullable extension dtypes holding
    ``pd.NA`` become object arrays, and ``bool`` arrays cannot be binned by
    numpy. All of these would otherwise fail inside ``axis.hist``.

    Args:
        df: DataFrame holding the column.
        col_name: name of the column to plot.
        axis: matplotlib Axes to draw on.
        display_labels: when True, write each bin's count above its bar.
    """
    values = df[col_name].dropna().to_numpy(dtype=float)
    _, _, patches = axis.hist(values)
    axis.set_title(col_name)
    axis.set_ylabel("count")

    if display_labels:
        add_labels(patches, axis)


def add_labels(patches, axis):
    """Write each bar's height, as an integer, just above the bar.

    Labels are centred horizontally on each bar. They are lifted by 1% of the
    largest absolute y-tick value. Afterwards the top of the y-axis is raised
    by 5% so the labels have room.

    Args:
        patches: iterable of bar rectangles (objects exposing ``get_xy``,
            ``get_height`` and ``get_width``).
        axis: matplotlib Axes the bars were drawn on.
    """
    # the vertical offset is the same for every bar, so compute it once
    offset = max(np.abs(axis.get_yticks())) * 0.01
    for bar in patches:
        left = bar.get_xy()[0]
        height = bar.get_height()
        centre = left + bar.get_width() / 2
        axis.text(centre, height + offset, str(int(height)), ha='center')

    # increase the y-axis limit to accommodate the text
    bottom, top = axis.get_ylim()
    axis.set_ylim(bottom, top * 1.05)
    
def explore_bar(df, col_name, axis, display_labels=True):
    """Draw a bar chart of how often each value of ``df[col_name]`` occurs.

    Args:
        df: DataFrame holding the column.
        col_name: name of the column to plot.
        axis: matplotlib Axes to draw on.
        display_labels: when True, write each bar's count above it.
    """
    counts = df[col_name].value_counts()
    counts.plot(kind='bar', ax=axis)
    axis.set_title(col_name)
    axis.set_ylabel("count")

    if not display_labels:
        return
    # annotate every bar patch currently on the axes
    add_labels(axis.patches, axis)

def explore_col(df, col_name, axis, display_labels=True):
    """Plot ``df[col_name]`` on ``axis`` with a chart suited to its dtype.

    - bool, categorical and string columns -> bar chart of value counts
    - other numeric columns -> histogram
    - anything else -> the subplot is left empty, and its title names the
      unsupported dtype so the blank panel is not mysterious

    Args:
        df: DataFrame holding the column.
        col_name: name of the column to plot.
        axis: matplotlib Axes to draw on.
        display_labels: forwarded to the plotting helper.
    """
    column = df[col_name]
    types = pd.api.types

    # pandas treats bool as numeric, but a two-value histogram is not useful
    # and numpy's histogram cannot subtract booleans, so check it first.
    # Categorical columns match neither the string nor the numeric check
    # and would otherwise be silently skipped.
    if (types.is_bool_dtype(column)
            or isinstance(column.dtype, pd.CategoricalDtype)
            or types.is_string_dtype(column)):
        explore_bar(df, col_name, axis, display_labels)
        return
    if types.is_numeric_dtype(column):
        explore_numeric_hist(df, col_name, axis, display_labels)
        return
    axis.set_title(f"{col_name} (unsupported dtype: {column.dtype})")

def explore_df(df, num_display_cols=4, single_chart_width=6, single_chart_height=4, display_labels=True):
    """Plot one chart per column of ``df`` in a grid of subplots.

    Columns fill the grid row by row, ``num_display_cols`` per row. Grid
    cells left over in the last row are removed from the figure.

    Args:
        df: DataFrame whose columns are plotted.
        num_display_cols: number of subplots per grid row.
        single_chart_width: width of each subplot, in inches.
        single_chart_height: height of each subplot, in inches.
        display_labels: forwarded to ``explore_col`` for each column.
    """
    num_cols_total = len(df.columns)
    # Ceiling division: an exactly full last row no longer adds an empty
    # extra row, or its height. Keep at least one row so an empty frame
    # still produces a valid (empty) figure.
    num_display_rows = max(
        1, (num_cols_total + num_display_cols - 1) // num_display_cols)
    # squeeze=False always returns a 2-D array of Axes, so one row, one
    # column, or a 1x1 grid are all handled the same way.
    fig, axs = plt.subplots(num_display_rows, num_display_cols, squeeze=False)
    fig.set_size_inches((single_chart_width * num_display_cols,
                         single_chart_height * num_display_rows))
    fig.subplots_adjust(hspace=1)

    # row-major flattening matches the fill order (index // cols, index % cols)
    flat_axes = axs.ravel()
    for col, curr_axs in zip(df.columns, flat_axes):
        explore_col(df, col, curr_axs, display_labels)

    # remove excess plots
    for curr_axs in flat_axes[num_cols_total:]:
        fig.delaxes(curr_axs)