In [9]:
from sklearn.preprocessing import OneHotEncoder
import pandas as pd


def build_x():
    """Build the one-hot-encoded feature frame from ``simple.csv``.

    Ten categorical columns are selected from the CSV and run through a
    ``OneHotEncoder``. The encoder ignores unknown categories, produces a
    dense array and emits integer indicators. The returned DataFrame is
    labelled with the encoder's own feature names.

    NOTE(review): the original note says the result has 54 columns; that
    depends on the categories actually present in the CSV - confirm.
    """
    feature_columns = [
        'cap-shape',
        'cap-surface',
        'bruises',
        'odor',
        'gill-size',
        'stalk-root',
        'stalk-surface-above-ring',
        'stalk-surface-below-ring',
        'stalk-color-above-ring',
        'spore-print-color',
    ]
    raw = pd.read_csv("simple.csv")
    selected = raw[feature_columns]
    encoder = OneHotEncoder(handle_unknown='ignore', sparse_output=False, dtype='int')
    # Fit before asking for feature names: the names only exist once fitted.
    encoded = encoder.fit_transform(selected)
    names = encoder.get_feature_names_out()
    return pd.DataFrame(data=encoded, columns=names)

# Build the encoded feature frame once at module level; it is later passed
# to fix_columns().
X = build_x()
# Preview the first rows of the encoded frame.
X.head()


def fix_cap_shape(cap_shape):
    """Translate a cap-shape code into its one-hot-encoded column name.

    The recognised codes are b, c, f, k, s and x; each maps to
    ``cap-shape_<code>``. Any other value is handed back unchanged.
    """
    for code in ("b", "c", "f", "k", "s", "x"):
        if cap_shape == code:
            return "cap-shape_" + code
    return cap_shape

def fix_cap_surface(cap_surface):
    """Translate a cap-surface code into its one-hot-encoded column name.

    The recognised codes are f, g, s and y; each maps to
    ``cap-surface_<code>``. Any other value is handed back unchanged.
    """
    for code in ("f", "g", "s", "y"):
        if cap_surface == code:
            return "cap-surface_" + code
    return cap_surface
        
def fix_bruises(bruises):
    """Translate a Yes/No bruises answer into its one-hot-encoded column name.

    ``"No"`` maps to ``bruises_f`` and ``"Yes"`` maps to ``bruises_t``.
    Every other value is handed back unchanged - including the raw codes
    ``"f"`` and ``"t"``, which this function does not translate.
    """
    if bruises == "No":
        return "bruises_f"
    if bruises == "Yes":
        return "bruises_t"
    return bruises
        
def fix_odor(odor):
    """Translate an odor code into its one-hot-encoded column name.

    The recognised codes are a, c, f, l, m, n, p, s and y; each maps to
    ``odor_<code>``. Any other value is handed back unchanged.
    """
    for code in ("a", "c", "f", "l", "m", "n", "p", "s", "y"):
        if odor == code:
            return "odor_" + code
    return odor
        
def fix_gill_size(gill_size):
    """Translate a gill-size code into its one-hot-encoded column name.

    ``"b"`` maps to ``gill-size_b`` and ``"n"`` maps to ``gill-size_n``.
    Any other value is handed back unchanged.
    """
    if gill_size == "b":
        return "gill-size_b"
    if gill_size == "n":
        return "gill-size_n"
    return gill_size
        
def fix_stalk_root(stalk_root):
    """Translate a stalk-root code into its one-hot-encoded column name.

    The recognised codes are ?, b, c, e and r; each maps to
    ``stalk-root_<code>``. Any other value is handed back unchanged.
    """
    for code in ("?", "b", "c", "e", "r"):
        if stalk_root == code:
            return "stalk-root_" + code
    return stalk_root
        
def fix_stalk_surface_above_ring(stalk_surface_above_ring):
    """Translate a stalk-surface-above-ring code into its encoded column name.

    The recognised codes are f, k, s and y; each maps to
    ``stalk-surface-above-ring_<code>``. Any other value is handed back
    unchanged.
    """
    for code in ("f", "k", "s", "y"):
        if stalk_surface_above_ring == code:
            return "stalk-surface-above-ring_" + code
    return stalk_surface_above_ring
        
def fix_stalk_surface_below_ring(stalk_surface_below_ring):
    """Translate a stalk-surface-below-ring code into its encoded column name.

    The recognised codes are f, k, s and y; each maps to
    ``stalk-surface-below-ring_<code>``. Any other value is handed back
    unchanged.
    """
    for code in ("f", "k", "s", "y"):
        if stalk_surface_below_ring == code:
            return "stalk-surface-below-ring_" + code
    return stalk_surface_below_ring
        
def fix_stalk_color_above_ring(stalk_color_above_ring):
    """Translate a stalk-color-above-ring code into its encoded column name.

    The recognised codes are b, c, e, g, n, o, p, w and y; each maps to
    ``stalk-color-above-ring_<code>``. Any other value is handed back
    unchanged.
    """
    for code in ("b", "c", "e", "g", "n", "o", "p", "w", "y"):
        if stalk_color_above_ring == code:
            return "stalk-color-above-ring_" + code
    return stalk_color_above_ring
        
def fix_spore_print_color(spore_print_color):
    """Translate a spore-print-color code into its encoded column name.

    The recognised codes are b, h, k, n, o, r, u, w and y; each maps to
    ``spore-print-color_<code>``. Any other value is handed back unchanged.
    """
    for code in ("b", "h", "k", "n", "o", "r", "u", "w", "y"):
        if spore_print_color == code:
            return "spore-print-color_" + code
    return spore_print_color

# Sample bruises answer; fix_columns() passes it through fix_bruises() to pick
# which indicator column to set.
slot_bruises = 'Yes'

def fix_columns(X):
    """Load ``simple1.csv`` and set the selected bruises indicator column to 1.

    The column to set is chosen by passing the module-level ``slot_bruises``
    value through ``fix_bruises``. The whole column is assigned 1 and the
    resulting DataFrame is returned.

    NOTE(review): the ``X`` argument is accepted but never read here.
    """
    frame = pd.read_csv("simple1.csv")
    target_column = fix_bruises(slot_bruises)
    frame[target_column] = 1
    return frame

# Load the template frame with the bruises indicator set from slot_bruises.
df = fix_columns(X)    
# Inspect just the two bruises indicator columns.
df[["bruises_f", "bruises_t"]].head()

Unnamed: 0,bruises_f,bruises_t
0,0,1
