In [None]:
import pandas as pd
import numpy as np
import os
from glob import glob
from matplotlib import pyplot as plt

# Figures!

In [None]:
# Root directory of one uncertainty-analysis run. Earlier runs, kept for reference:
#   /Users/nick/Documents/Gillings_work/uncertainty_analysis_data/uncertainty_analysis_2024-12-11_07-52-16-945389
#   /Users/nick/Documents/Gillings_work/uncertainty_analysis_data/uncertainty_analysis_2025-07-22_15-08-21-725343
# Set the UNCERTAINTY_BASE_DIR environment variable to point at a different run
# without editing the notebook; otherwise the latest run below is used.
base_dir = os.environ.get(
    "UNCERTAINTY_BASE_DIR",
    "/Users/nick/Documents/Gillings_work/uncertainty_analysis_data/uncertainty_analysis_2025-07-22_18-32-29-456924",
)

# Raise explicitly (an `assert` is stripped under `python -O`) and say which path is wrong.
if not os.path.isdir(base_dir):
    raise FileNotFoundError(f"base_dir is not an existing directory: {base_dir}")

output_dir = os.path.join(base_dir, "outputs")
excluded_options = {3}  # option 3 is removed from these figures
outputs_dirs = [
    os.path.join(output_dir, f"option_{i}")
    for i in range(6)
    if i not in excluded_options
]
print(outputs_dirs)

In [None]:
# Load every simulation run for every ban option into one array with axes:
#   ban option #, output (run) #, smoking state, black, pov, year
collection_list_options = []

for outputs in outputs_dirs:
    # Fail loudly instead of skipping: a skipped option would shift every later
    # option down one index, and the figures below label options by position.
    if not os.path.isdir(outputs):
        raise FileNotFoundError(f"Missing outputs directory: {outputs}")
    collection_list = []

    # Each .npy file is one simulation output (numpy output from the simulation)
    # with shape (51, 2, 2, 2, 6):
    #   axes   = year, black, pov, plus65, smoking state
    #   states = never, former, menthol, nonmenthol, ecig/dual, dead
    # and each element is the weighted count. After processing, each run is a
    # (7, 2, 2, 46) array (smoking state, black, pov, year) with states
    #   menthol, nonmenthol, menthol+nonmenthol, ecig/dual, former, never, tobacco users
    for f in sorted(glob(os.path.join(outputs, "*.npy"))):
        arr = np.load(f)
        arr = arr[:, :, :, 0, :]  # age-restrict 18-64 (plus65 == 0)
        arr = arr[:-5]  # get the years we are interested in (drop the last 5)
        arr = arr[:, :, :, :-1]  # don't need dead people (last smoking state)
        arr = arr.transpose((3, 1, 2, 0))  # -> (smoking states, black, pov, years)
        arr = np.concatenate([  # also add aggregate smoker groups
            arr[2:4],  # menthol smokers, nonmenthol smokers
            (arr[2] + arr[3])[np.newaxis, :],  # total cigarette smokers
            arr[4][np.newaxis, :],  # e-cig/dual
            arr[1][np.newaxis, :],  # former
            arr[0][np.newaxis, :],  # nonsmoker
            (arr[2] + arr[3] + arr[4])[np.newaxis, :],  # total tobacco users
        ], axis=0)
        collection_list.append(arr)

    if not collection_list:
        raise FileNotFoundError(f"No .npy outputs found in {outputs}")
    collection_list_options.append(collection_list)

# np.array can only build a regular array if every option has the same number of runs.
run_counts = [len(runs) for runs in collection_list_options]
if len(set(run_counts)) > 1:
    raise ValueError(f"Options have different numbers of runs: {run_counts}")

collection_list_options = np.array(collection_list_options)

In [None]:
# collection_list_options axes: ban option #, output #, smoking state, black, pov, year.
# Smoking-state order: menthol, nonmenthol, menthol+nonmenthol, ecig/dual,
# former, nonsmoker, tobacco user.
print(collection_list_options.shape)

# To inspect the total population by year for option 0, run 0, sum only the
# non-aggregate states (0, 1, 3, 4, 5) so smokers are not double counted:
#   totals = collection_list_options[0, 0][[0, 1, 3, 4, 5]].sum(axis=(0, 1, 2))
#   for offset, total in enumerate(totals):
#       print(2016 + offset, total)

In [None]:
def restrict_demo(arr, demo: str="all"):
    """
    Helper function that restricts our data array to a certain demographic.

    Our data comes in the dimensions:
      opt, run, smoking status, black, pov, year
    Only the last three axes (black, pov, year) are interpreted, so any number
    of leading axes is accepted (e.g. a single run: smoking status, black, pov, year).

    The "demo" param is a string that can be one of the following:
      "all"       sum over both the black and pov axes
      "black"     keep black == 1, sum over pov
      "nonblack"  keep black == 0, sum over pov
      "pov"       keep pov == 1, sum over black
      "nonpov"    keep pov == 0, sum over black

    The returned array drops the black and pov axes, i.e. it has dimensions:
      opt, run, smoking status, year

    Raises ValueError for any other value of "demo".
    """
    # Axes are counted from the end so the leading axes may vary.
    black_axis, pov_axis = -3, -2
    if demo == "all":
        return np.sum(arr, axis=(black_axis, pov_axis))

    # demographic -> (axis to select on, index to keep on that axis)
    selections = {
        "black": (black_axis, 1),
        "nonblack": (black_axis, 0),
        "pov": (pov_axis, 1),
        "nonpov": (pov_axis, 0),
    }
    if demo not in selections:
        raise ValueError(f"Not an accepted demographic: {demo!r}")

    axis, index = selections[demo]
    selected = np.take(arr, index, axis=axis)  # drops the selected axis
    # The other demographic axis now sits just before the year axis.
    return np.sum(selected, axis=-2)

## Proportion of overall smokers

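Each figure shows the status quo together with either ban option 1 (base) alone or all four ban options.

Make 4 figures for each of the demographics: (cigarette smokers or all tobacco users) × (status quo vs. base ban, or status quo vs. all ban options).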

In [None]:
# Percentage of cigarette smokers (or all tobacco users) over time, for every
# demographic, comparing the status quo to the ban option(s).
figure_dir = "ban_paper_figures"
os.makedirs(figure_dir, exist_ok=True)  # savefig raises if the directory is missing

sim_start_year = 2016  # calendar year at index 0 of the year axis
first_year, last_year = 2020, 2055  # plotted window, inclusive
year_slice = slice(first_year - sim_start_year, last_year - sim_start_year + 1)  # == 4:40

line_colors = ["blue", "red", "green", "purple", "darkorange"]
shaded_colors = ["deepskyblue", "salmon", "lightgreen", "plum", "moccasin"]
all_option_labels = [
    "Status quo",
    "Ban 1 (base)",
    "Ban 2 (liberal)",
    "Ban 3 (conservative)",
    "Ban 4 (e-cig)",
]
sq_vs_ban_labels = ["Status quo", "Menthol ban"]

for including_ecig_smokers in [False, True]:
    for all_ban_options in [False, True]:
        print("------------------------------------------")
        print("For the next five figures we are:")
        if not all_ban_options:
            print("- Just comparing status quo to ban option 1 (base)")
        else:
            print("- Comparing status quo to all ban options")

        if not including_ecig_smokers:
            print("- Looking at cigarette smokers (menthol + nonmenthol, no ecig)")
        else:
            print("- Looking at all tobacco users (menthol + nonmenthol + ecig/dual)")
        print("------------------------------------------")
        print(" ")

        # Options are plotted in index order, so option i gets option_labels[i].
        option_labels = all_option_labels if all_ban_options else sq_vs_ban_labels

        for demo in ["all", "black", "nonblack", "pov", "nonpov"]:
            fig, ax = plt.subplots(1, 1, figsize=(12, 6), dpi=100)
            x = np.arange(first_year, last_year + 1)
            title = f"Demographic group: {demo} \n"

            # dims: opt, output #, smoking status, year
            to_plot = restrict_demo(collection_list_options, demo)
            # The population total excludes the aggregate states (2 = menthol+nonmenthol,
            # 6 = tobacco users); including them would double count those smokers.
            sums = np.sum(to_plot[:, :, [0, 1, 3, 4, 5], :], axis=2)
            to_plot = to_plot / sums[:, :, np.newaxis, :]  # counts -> proportion of total
            state = 6 if including_ecig_smokers else 2  # tobacco users vs. cigarette smokers
            to_plot = to_plot[:, :, state, year_slice] * 100  # years 2020 - 2055 inclusive
            # dims: opt, output #, year

            mean = np.mean(to_plot, axis=1)
            upper = np.percentile(to_plot, 97.5, axis=1)
            lower = np.percentile(to_plot, 2.5, axis=1)
            # dims: opt, year

            for ban_opt, label in enumerate(option_labels):
                # Attach the label to the line itself. A labels-only legend() call
                # pairs labels with artists in insertion order and would label the
                # shaded bands instead.
                ax.plot(x, mean[ban_opt], color=line_colors[ban_opt], label=label)
                ax.fill_between(x, lower[ban_opt], upper[ban_opt],
                                facecolor=shaded_colors[ban_opt], alpha=0.5)

            ax.set_ylim(0, 25)
            ax.set_xlim(2019, 2056)
            ax.set_xlabel("Year", fontsize=12)
            if not including_ecig_smokers:
                ax.set_ylabel("Percentage of Cigarette Smokers", fontsize=12)
            else:
                ax.set_ylabel("Percentage of Tobacco Users", fontsize=12)
            ax.set_xticks(x[::5])
            ax.tick_params(axis="x", labelsize=10)
            y_tick_nums = np.arange(0, 25 + 1, 5)  # +1 because arange excludes "stop"
            ax.set_yticks(y_tick_nums)
            ax.set_yticklabels([f"{v}%" for v in y_tick_nums])
            ax.legend()

            ax.axvline(2024, ymin=0.1, ymax=0.9, color='k', linestyle='dashed')
            ax.text(2025, 21, "Menthol ban start", c='k')

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            print(title)
            metric = "Tobacco_Users_" if including_ecig_smokers else "Cigarette_Smokers_"
            comparison = "All_Ban_Options_" if all_ban_options else "SQ_Versus_Ban_"
            filename = f"Percentage_{metric}{comparison}Demographic_{demo}.png"

            fig.savefig(os.path.join(figure_dir, filename), dpi=300)
            plt.show()
            plt.close(fig)  # avoid accumulating 20 open figures


In [None]:


# Percentage of menthol smokers over time (all demographics), comparing the
# status quo to the ban option(s).
figure_dir = "ban_paper_figures"
os.makedirs(figure_dir, exist_ok=True)  # savefig raises if the directory is missing

sim_start_year = 2016  # calendar year at index 0 of the year axis
first_year, last_year = 2020, 2055  # plotted window, inclusive
year_slice = slice(first_year - sim_start_year, last_year - sim_start_year + 1)  # == 4:40

line_colors = ["blue", "red", "green", "purple", "darkorange"]
shaded_colors = ["deepskyblue", "salmon", "lightgreen", "plum", "moccasin"]
all_option_labels = [
    "Status quo",
    "Ban 1 (base)",
    "Ban 2 (liberal)",
    "Ban 3 (conservative)",
    "Ban 4 (e-cig)",
]
sq_vs_ban_labels = ["Status quo", "Menthol ban"]

for all_ban_options in [False, True]:
    print("------------------------------------------")
    if not all_ban_options:
        print("Just comparing status quo to ban option 1 (base)")
    else:
        print("Comparing status quo to all ban options")
    print("------------------------------------------")

    # Options are plotted in index order, so option i gets option_labels[i].
    option_labels = all_option_labels if all_ban_options else sq_vs_ban_labels

    demo = "all"
    # Menthol Smokers
    fig, ax = plt.subplots(1, 1, figsize=(12, 6), dpi=100)
    x = np.arange(first_year, last_year + 1)
    title = f"Demographic group: {demo} \n"

    # dims: opt, output #, smoking status, year
    to_plot = restrict_demo(collection_list_options, demo)
    # The population total excludes the aggregate states (2 = menthol+nonmenthol,
    # 6 = tobacco users); including them would double count those smokers.
    sums = np.sum(to_plot[:, :, [0, 1, 3, 4, 5], :], axis=2)
    to_plot = to_plot / sums[:, :, np.newaxis, :]  # counts -> proportion of total
    # smoking state 0 = menthol; years 2020 - 2055 inclusive
    to_plot = to_plot[:, :, 0, year_slice] * 100
    # dims: opt, output #, year

    mean = np.mean(to_plot, axis=1)
    upper = np.percentile(to_plot, 97.5, axis=1)
    lower = np.percentile(to_plot, 2.5, axis=1)
    # dims: opt, year

    for ban_opt, label in enumerate(option_labels):
        # Attach the label to the line itself. A labels-only legend() call pairs
        # labels with artists in insertion order and would label the shaded bands.
        ax.plot(x, mean[ban_opt], color=line_colors[ban_opt], label=label)
        ax.fill_between(x, lower[ban_opt], upper[ban_opt],
                        facecolor=shaded_colors[ban_opt], alpha=0.5)

    ax.set_ylim(0, 15)
    ax.set_xlim(2019, 2056)
    ax.set_xlabel("Year", fontsize=12)
    ax.set_ylabel("Percentage of Menthol Smokers", fontsize=12)
    ax.set_xticks(x[::5])
    ax.tick_params(axis="x", labelsize=10)
    y_tick_nums = np.arange(0, 15 + 1, 3)  # +1 because arange excludes "stop"
    ax.set_yticks(y_tick_nums)
    ax.set_yticklabels([f"{v}%" for v in y_tick_nums])
    ax.legend()

    ax.axvline(2024, ymin=0.1, ymax=0.9, color='k', linestyle='dashed')
    ax.text(2025, 13, "Menthol ban start", c='k')

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    print(title)
    comparison = "All_Ban_Options_" if all_ban_options else "SQ_Versus_Ban_"
    filename = f"Percentage_Menthol_Smokers_{comparison}Demographic_{demo}.png"

    fig.savefig(os.path.join(figure_dir, filename), dpi=300)
    plt.show()
    plt.close(fig)

In [None]:

# Percentage of non-menthol smokers over time (all demographics), comparing the
# status quo to the ban option(s).
figure_dir = "ban_paper_figures"
os.makedirs(figure_dir, exist_ok=True)  # savefig raises if the directory is missing

sim_start_year = 2016  # calendar year at index 0 of the year axis
first_year, last_year = 2020, 2055  # plotted window, inclusive
year_slice = slice(first_year - sim_start_year, last_year - sim_start_year + 1)  # == 4:40

line_colors = ["blue", "red", "green", "purple", "darkorange"]
shaded_colors = ["deepskyblue", "salmon", "lightgreen", "plum", "moccasin"]
all_option_labels = [
    "Status quo",
    "Ban 1 (base)",
    "Ban 2 (liberal)",
    "Ban 3 (conservative)",
    "Ban 4 (e-cig)",
]
sq_vs_ban_labels = ["Status quo", "Menthol ban"]

for all_ban_options in [False, True]:
    print("------------------------------------------")
    if not all_ban_options:
        print("Just comparing status quo to ban option 1 (base)")
    else:
        print("Comparing status quo to all ban options")
    print("------------------------------------------")

    # Options are plotted in index order, so option i gets option_labels[i].
    option_labels = all_option_labels if all_ban_options else sq_vs_ban_labels

    demo = "all"
    # Non-menthol Smokers
    fig, ax = plt.subplots(1, 1, figsize=(12, 6), dpi=100)
    x = np.arange(first_year, last_year + 1)
    title = f"Demographic group: {demo} \n"

    # dims: opt, output #, smoking status, year
    to_plot = restrict_demo(collection_list_options, demo)
    # The population total excludes the aggregate states (2 = menthol+nonmenthol,
    # 6 = tobacco users); including them would double count those smokers.
    sums = np.sum(to_plot[:, :, [0, 1, 3, 4, 5], :], axis=2)
    to_plot = to_plot / sums[:, :, np.newaxis, :]  # counts -> proportion of total
    # smoking state 1 = nonmenthol; years 2020 - 2055 inclusive
    to_plot = to_plot[:, :, 1, year_slice] * 100
    # dims: opt, output #, year

    mean = np.mean(to_plot, axis=1)
    upper = np.percentile(to_plot, 97.5, axis=1)
    lower = np.percentile(to_plot, 2.5, axis=1)
    # dims: opt, year

    for ban_opt, label in enumerate(option_labels):
        # Attach the label to the line itself. A labels-only legend() call pairs
        # labels with artists in insertion order and would label the shaded bands.
        ax.plot(x, mean[ban_opt], color=line_colors[ban_opt], label=label)
        ax.fill_between(x, lower[ban_opt], upper[ban_opt],
                        facecolor=shaded_colors[ban_opt], alpha=0.5)

    ax.set_ylim(0, 15)
    ax.set_xlim(2019, 2056)
    ax.set_xlabel("Year", fontsize=12)
    ax.set_ylabel("Percentage of Non-Menthol Smokers", fontsize=12)
    ax.set_xticks(x[::5])
    ax.tick_params(axis="x", labelsize=10)
    y_tick_nums = np.arange(0, 15 + 1, 3)  # +1 because arange excludes "stop"
    ax.set_yticks(y_tick_nums)
    ax.set_yticklabels([f"{v}%" for v in y_tick_nums])
    ax.legend()

    ax.axvline(2024, ymin=0.1, ymax=0.9, color='k', linestyle='dashed')
    ax.text(2025, 13, "Menthol ban start", c='k')

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    print(title)
    comparison = "All_Ban_Options_" if all_ban_options else "SQ_Versus_Ban_"
    filename = f"Percentage_Non-Menthol_Smokers_{comparison}Demographic_{demo}.png"

    fig.savefig(os.path.join(figure_dir, filename), dpi=300)
    plt.show()
    plt.close(fig)