<a href="https://colab.research.google.com/github/samlucas28/Thesis-python-scripts/blob/main/Decision_Tree_BA_OWSA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -----------------------------------------------------------
# DECISION TREE CODE FOR BASE-CASE/ONE-WAY SENSITIVITY ANALYSES
# -----------------------------------------------------------

from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Callable, Union
import numpy as np
from scipy.stats import beta, gamma, dirichlet, lognorm
import pandas as pd


# 1. DEFINING THE COHORT
# ----------------------------------------------
# Seeded generator so the synthetic cohort is identical on every run.
rng = np.random.default_rng(seed=1)

N = 50000  # Number of women in the simulated cohort

# The two draws below consume the generator in this order (ages first, then
# IMD); swapping them would change the cohort produced by the fixed seed.
start_ages = rng.integers(low=38, high=46, size=N)  # uniform integers 38..45 (upper bound is exclusive)
imd_deciles = rng.integers(low=1, high=11, size=N)  # uniform integers 1..10 (not normally distributed)

cohort = pd.DataFrame({"start_age": start_ages, "imd": imd_deciles})

# 2. DEFINING THE DISTRIBUTIONS
# -----------------------------------------------
def dist_lognormal(mu_l: float, sigma_l: float):
    """Return a frozen scipy lognormal whose logarithm has mean *mu_l* and SD *sigma_l*.

    scipy parameterises the lognormal by shape ``s`` (the log-scale SD) and
    ``scale`` (the median, i.e. ``exp`` of the log-scale mean).
    """
    median = np.exp(mu_l)
    return lognorm(s=sigma_l, scale=median)

def dist_beta(a: float, b: float):
    """Return a frozen scipy Beta(a, b) distribution (mean ``a / (a + b)``)."""
    frozen = beta(a, b)
    return frozen

def dist_gamma(k: float, rate: float):
    """Return a frozen scipy Gamma with shape *k* and rate *rate* (mean ``k / rate``).

    scipy takes a scale rather than a rate, so the rate is inverted here.
    """
    scale = 1 / rate
    return gamma(k, scale=scale)

def dist_dirichlet(alphas: List[float]):
    """Return a frozen scipy Dirichlet with concentration parameters *alphas*."""
    frozen = dirichlet(alphas)
    return frozen

# 3. ASSIGNING THE DISTRIBUTION TO PARAMETERS
# -----------------------------------------------
# Model inputs. Each value is either a fixed number or a frozen scipy
# distribution built by the dist_* helpers; "IMD_REFERRAL" is a nested dict of
# Beta distributions keyed by IMD decile (1-10).
# NOTE(review): the annotation below does not mention frozen distributions or
# the nested dict, so it is looser than the actual contents.
PARAMS: Dict[str, Union[float, Tuple, list, Callable]] = {
    # --- Costs (£) ---
    "COST_FHQS_TEXT":        dist_lognormal(3.091, 0.618), # £ FHQS development / sending the text; lognormal(log-mean, log-SD)
    "COST_UPTAKE":           dist_gamma(32.83, 3.29), # £ risk assessment cost; gamma(shape, rate), mean ~ 9.98
    "COST_SYMPTOM":          dist_gamma(894.8, 0.780), # £ symptomatic presentation cost; gamma(shape, rate), mean ~ 1147.18
    "COST_HIGH_RISK":        351.88, # £ clinical genetics cost (fixed)
    "COST_REFERRAL":         66.93, # £ referral cost (fixed)

    # --- Intervention arm ---
    "P_FHQS_UPTAKE":         dist_beta(203, 1277), # FHQS uptake probability; Beta mean ~ 0.137
    "P_RISK_SPLIT_FHQS":     dist_dirichlet([88.7, 10.3, 1.0]), # BC risk split after FHQS, ordered (low, moderate, high)

    # --- Symptomatic detection (used by both arms) ---
    "P_SYMPT_LOWRISK":       dist_beta(35008, 2413112), # Probability of symptomatic detection for women classified low risk
    "P_SYMPT_UNCLASSIFIED":  dist_beta(44712, 2715288), # Probability of symptomatic detection for women never risk-classified

    # --- Control arm ---
    "IMD_REFERRAL": { # Referral probability by IMD decile (key = decile)
        1:  dist_beta(5, 19224), 2:  dist_beta(17, 56094), 3:  dist_beta(24, 67475),
        4:  dist_beta(27, 67777), 5:  dist_beta(40, 89535), 6:  dist_beta(60, 119543),
        7:  dist_beta(58, 107788), 8:  dist_beta(65, 111408), 9:  dist_beta(71, 112480),
        10: dist_beta(60, 89261),
    },

    "P_RISK_SPLIT_REFERRAL": dist_dirichlet([50.3, 40.4, 9.3]), # BC risk split after referral, ordered (low, moderate, high)
}

# 4. DEFINING THE DECISION TREE NODES
# ---------------------------------------------------
@dataclass
class Node:
    """Abstract base class for every node in the decision tree.

    Subclasses implement three operations: the expected cost of the subtree,
    the expected outcome counts of the subtree, and a single simulated walk
    through the subtree.
    """

    def expected_cost(self) -> float:
        """Return the probability-weighted average cost of the subtree."""
        raise NotImplementedError

    def expected_outcomes(self) -> dict:
        """Return the probability-weighted average outcomes of the subtree."""
        raise NotImplementedError

    def simulate(self, deterministic=False) -> Tuple[float, dict, str]:
        """Walk one path through the subtree.

        Returns a ``(cost, outcomes, label)`` triple, matching what every
        concrete node in this file returns. *deterministic* selects mean costs
        instead of sampled costs.
        """
        raise NotImplementedError

@dataclass
class Terminal(Node):
    """Leaf of the tree: carries a final cost and a label naming the end state."""

    # A fixed amount, a zero-argument callable returning an amount, or a
    # distribution-like object exposing ``mean()`` / ``rvs()`` (e.g. a frozen
    # scipy distribution).
    cost: Union[float, Callable[[], float]]
    label: str = "TERMINAL"

    def get_cost(self, deterministic=False) -> float:
        """Resolve ``self.cost`` to a number.

        Deterministic mode prefers ``cost.mean()`` when available; stochastic
        mode calls a callable cost, or draws with ``cost.rvs()``. Plain numbers
        are returned unchanged in both modes.
        """
        cost = self.cost
        # Frozen scipy distributions are not callable, so the mean must be
        # looked up without requiring callability first.
        if deterministic and hasattr(cost, "mean"):
            return cost.mean()
        if callable(cost):
            return cost()
        if not deterministic and hasattr(cost, "rvs"):
            return cost.rvs()
        return cost

    def expected_cost(self) -> float:
        """Return the deterministic (mean or fixed) cost of this leaf."""
        return self.get_cost(deterministic=True)

    def expected_outcomes(self) -> dict:
        """Return 0/1 indicators telling which tracked outcome this leaf is."""
        return {
            'symptomatic': 1 if self.label == "SYMPT_PRESENT" else 0,
            'moderate': 1 if self.label == "MODERATE_RISK" else 0,
            'high': 1 if self.label == "HIGH_RISK" else 0,
        }

    def simulate(self, deterministic=False) -> Tuple[float, dict, str]:
        """Return ``(cost, outcome indicators, label)`` for this leaf."""
        return self.get_cost(deterministic=deterministic), self.expected_outcomes(), self.label

@dataclass
class Chance(Node):
    """Probabilistic node: a list of ``(probability callable, child node)`` branches."""

    branches: List[Tuple[Callable[[], float], Node]]

    def expected_cost(self) -> float:
        """Return the sum of each child's expected cost weighted by its branch probability."""
        return sum(p_fn() * node.expected_cost() for p_fn, node in self.branches)

    def expected_outcomes(self) -> dict:
        """Return each tracked outcome summed over children, weighted by branch probability."""
        outcomes = {'symptomatic': 0, 'moderate': 0, 'high': 0}
        for p_fn, node in self.branches:
            p = p_fn()
            child_outcomes = node.expected_outcomes()
            for k in outcomes:
                outcomes[k] += p * child_outcomes[k]
        return outcomes

    def simulate(self, deterministic=False) -> Tuple[float, dict, str]:
        """Pick one branch at random and simulate its child.

        Branch probabilities are normalised to sum to 1 before sampling.

        Raises:
            ValueError: if the probabilities do not sum to a positive, finite
                number (this includes an empty branch list).

        NOTE: sampling uses ``np.random.choice``, i.e. NumPy's global random
        state, not the Generator created with ``np.random.default_rng``
        elsewhere in this file.
        """
        # dtype=float: an integer array (e.g. probabilities 1 and 0) cannot be
        # true-divided in place and would otherwise raise a casting error.
        probs = np.array([p_fn() for p_fn, _ in self.branches], dtype=float)
        total = probs.sum()
        if not np.isfinite(total) or total <= 0:
            raise ValueError(
                f"Branch probabilities must sum to a positive finite number, got {total!r}"
            )
        probs = probs / total
        choice = np.random.choice(len(self.branches), p=probs)
        return self.branches[choice][1].simulate(deterministic=deterministic)

@dataclass
class Decision(Node):
    """Root choice node: maps a strategy name to the subtree for that strategy.

    Unlike the other nodes, every method takes the name of the option to
    evaluate; an unknown name raises ``KeyError`` from the dict lookup.
    """

    options: Dict[str, Node]  # strategy name -> root node of that strategy

    def expected_cost(self, option: str) -> float:
        """Return the expected cost of the subtree registered under *option*."""
        chosen = self.options[option]
        return chosen.expected_cost()

    def expected_outcomes(self, option: str) -> dict:
        """Return the expected outcomes of the subtree registered under *option*."""
        chosen = self.options[option]
        return chosen.expected_outcomes()

    def simulate(self, option: str, deterministic=False) -> Tuple[float, dict, str]:
        """Simulate one walk through the subtree registered under *option*."""
        chosen = self.options[option]
        return chosen.simulate(deterministic=deterministic)

@dataclass
class AddCost(Node):
    """Pass-through node that adds its own cost to whatever ``next_node`` yields."""

    # A fixed amount, a zero-argument callable returning an amount, or a
    # distribution-like object exposing ``mean()`` / ``rvs()`` (e.g. a frozen
    # scipy distribution).
    cost: Union[float, Callable[[], float]]
    next_node: Node

    def get_cost(self, deterministic=False) -> float:
        """Resolve ``self.cost`` to a number.

        Deterministic mode prefers ``cost.mean()`` when available; stochastic
        mode calls a callable cost, or draws with ``cost.rvs()``. Plain numbers
        are returned unchanged in both modes.
        """
        cost = self.cost
        # Frozen scipy distributions are not callable, so the mean must be
        # looked up without requiring callability first.
        if deterministic and hasattr(cost, "mean"):
            return cost.mean()
        if callable(cost):
            return cost()
        if not deterministic and hasattr(cost, "rvs"):
            return cost.rvs()
        return cost

    def expected_cost(self) -> float:
        """Return this node's deterministic cost plus the child's expected cost."""
        return self.get_cost(deterministic=True) + self.next_node.expected_cost()

    def expected_outcomes(self) -> dict:
        """Return the child's expected outcomes unchanged (this node adds cost only)."""
        return self.next_node.expected_outcomes()

    def simulate(self, deterministic=False) -> Tuple[float, dict, str]:
        """Simulate the child and add this node's cost to the child's cost."""
        cost_here = self.get_cost(deterministic=deterministic)
        cost_next, outcomes_next, label = self.next_node.simulate(deterministic=deterministic)
        return cost_here + cost_next, outcomes_next, label


# 5. TREE CONSTRUCTION
# --------------------------------------------------------
def make_tree(imd_decile: int, deterministic: bool = False,
              override_p_fhqs_uptake: Union[float, None] = None,
              override_cost_fhqs_text: Union[float, None] = None,
              override_cost_symptom: Union[float, None] = None) -> Decision:
    """Build the two-arm decision tree for one woman in the given IMD decile.

    All parameters are resolved to numbers once, while the tree is built: with
    ``deterministic=True`` every distribution contributes its mean, otherwise
    one random draw is taken per parameter.

    Args:
        imd_decile: Key into ``PARAMS["IMD_REFERRAL"]``; selects the control-arm
            referral probability. An unknown key raises ``KeyError``.
        deterministic: Use distribution means instead of random draws.
        override_p_fhqs_uptake: If not None, replaces the FHQS uptake
            probability (used for one-way sensitivity analysis).
        override_cost_fhqs_text: If not None, replaces the FHQS text cost.
        override_cost_symptom: If not None, replaces the symptomatic
            presentation cost.

    Returns:
        A ``Decision`` with options ``"SEND_TEXT"`` (intervention) and
        ``"NO_TEXT"`` (control).
    """

    def get_val(dist):
        # Plain numbers are fixed parameters; anything else is treated as a
        # frozen distribution exposing mean() and rvs().
        if isinstance(dist, (float, int)):
            return dist
        return dist.mean() if deterministic else dist.rvs()

    def pick(override, param_key):
        # An explicit override wins; otherwise resolve the named parameter.
        return override if override is not None else get_val(PARAMS[param_key])

    def constant(value):
        # Chance/AddCost expect zero-argument callables; wrap a resolved number.
        return lambda: value

    # Terminal nodes (shared by both arms)
    no_sympt_term = Terminal(0.0, label="NO_SYMPT_PRESENT")  # no symptomatic detection, no cost
    sympt_term = Terminal(pick(override_cost_symptom, "COST_SYMPTOM"), label="SYMPT_PRESENT")  # symptomatic detection cost
    mod_risk_term = Terminal(0.0, label="MODERATE_RISK")  # identified as moderate risk, no cost added here
    high_risk_term = Terminal(get_val(PARAMS["COST_HIGH_RISK"]), label="HIGH_RISK")  # identified as high risk, clinical genetics cost

    # Symptomatic detection (any arm)
    def make_sympt_node(beta_dist):
        p = get_val(beta_dist)
        return Chance([
            (constant(p), sympt_term),         # detected symptomatically
            (constant(1 - p), no_sympt_term),  # not detected symptomatically
        ])

    low_risk_sympt_node = make_sympt_node(PARAMS["P_SYMPT_LOWRISK"])      # women classified as low risk
    unclass_sympt_node = make_sympt_node(PARAMS["P_SYMPT_UNCLASSIFIED"])  # women never risk-classified

    def dirichlet_probs(dist):
        # Returns (p_low, p_moderate, p_high): the normalised alphas when
        # deterministic, otherwise one random draw from the Dirichlet.
        if deterministic:
            alpha = dist.alpha
            total = sum(alpha)
            return tuple(a / total for a in alpha)
        draw = dist.rvs()[0]  # rvs() returns a 2-D array with one row per sample
        return draw[0], draw[1], draw[2]

    # BC risk split, shared by the FHQS (intervention) and referral (control) paths
    def make_risk_split_node(dist):
        p_low, p_mod, p_high = dirichlet_probs(dist)
        return Chance([
            (constant(p_low), low_risk_sympt_node),  # low risk: may still present symptomatically
            (constant(p_mod), mod_risk_term),        # moderate risk terminal
            (constant(p_high), high_risk_term),      # high risk terminal
        ])

    # Override/sampled values. The statement order below fixes the order in
    # which random draws are consumed in stochastic mode; keep it stable.
    p_uptake_val = pick(override_p_fhqs_uptake, "P_FHQS_UPTAKE")
    cost_uptake_val = get_val(PARAMS["COST_UPTAKE"])
    cost_text_val = pick(override_cost_fhqs_text, "COST_FHQS_TEXT")

    # FHQS uptake (intervention)
    uptake_node = Chance([
        # FHQS completed: add risk assessment cost, then split by BC risk
        (constant(p_uptake_val),
         AddCost(constant(cost_uptake_val), make_risk_split_node(PARAMS["P_RISK_SPLIT_FHQS"]))),
        # FHQS not completed: no cost, unclassified symptomatic detection
        (constant(1 - p_uptake_val), unclass_sympt_node),
    ])

    # Text-message cost is incurred for everyone in the intervention arm
    send_text_branch = AddCost(constant(cost_text_val), uptake_node)

    p_referral_val = get_val(PARAMS["IMD_REFERRAL"][imd_decile])
    cost_referral_val = get_val(PARAMS["COST_REFERRAL"])

    # Referral to genetics (control)
    referral_node = Chance([
        # Referred: add referral cost, then split by BC risk
        (constant(p_referral_val),
         AddCost(constant(cost_referral_val), make_risk_split_node(PARAMS["P_RISK_SPLIT_REFERRAL"]))),
        # Not referred: no cost, unclassified symptomatic detection
        (constant(1 - p_referral_val), unclass_sympt_node),
    ])

    return Decision({
        "SEND_TEXT": send_text_branch,
        "NO_TEXT": referral_node,
    })

# 6. EVALUATE TREE
# -----------------------------
def evaluate_tree(imd: int, p_fhqs_uptake: float = None, cost_fhqs_text: float = None, cost_symptom: float = None):
    """Evaluate both arms of the deterministic tree for one IMD decile.

    The three optional arguments are forwarded to ``make_tree`` as overrides;
    ``None`` means "use the model's own parameter".

    Returns a dict keyed by arm (``"SEND_TEXT"``, ``"NO_TEXT"``); each value
    holds the expected ``"cost"`` plus the expected outcome entries.
    """
    tree = make_tree(
        imd,
        deterministic=True,
        override_p_fhqs_uptake=p_fhqs_uptake,
        override_cost_fhqs_text=cost_fhqs_text,
        override_cost_symptom=cost_symptom,
    )

    summary = {}
    for arm in ("SEND_TEXT", "NO_TEXT"):
        arm_cost = tree.expected_cost(arm)
        arm_outcomes = tree.expected_outcomes(arm)
        summary[arm] = {"cost": arm_cost, **arm_outcomes}
    return summary

# 7. EVALUATE FULL COHORT DETERMINISTICALLY (BASE CASE ANALYSIS)
# ---------------------------------------------------------------

# Initialize results for whole cohort
totals = {
    arm: {"cost": 0.0, "symptomatic": 0.0, "moderate": 0.0, "high": 0.0}
    for arm in ("SEND_TEXT", "NO_TEXT")
}

# Same accumulators, one set per IMD decile present in the cohort
# (keys keep the order in which deciles first appear).
totals_by_imd = {
    imd: {
        arm: {"cost": 0.0, "symptomatic": 0.0, "moderate": 0.0, "high": 0.0}
        for arm in ("SEND_TEXT", "NO_TEXT")
    }
    for imd in cohort["imd"].unique()
}

# The deterministic tree depends only on the IMD decile, so evaluate it once
# per decile instead of rebuilding it for every woman in the cohort.
base_case_by_imd = {}

# Accumulate woman by woman, in cohort order, so the floating-point totals
# are the same as summing each individual's result.
for imd in cohort["imd"].to_numpy():
    if imd not in base_case_by_imd:
        base_case_by_imd[imd] = evaluate_tree(imd)
    result = base_case_by_imd[imd]

    for arm in ["SEND_TEXT", "NO_TEXT"]:
        # Update overall totals
        for k in totals[arm]:
            totals[arm][k] += result[arm][k]

        # Update IMD-specific totals
        for k in totals_by_imd[imd][arm]:
            totals_by_imd[imd][arm][k] += result[arm][k]

# Print full cohort totals
N = len(cohort)
print(f"\n=== TOTAL EXPECTED FOR {N} WOMEN ===")
for arm in ["SEND_TEXT", "NO_TEXT"]:
    print(f"\n{arm}:")
    print(f"  Total cost: £{totals[arm]['cost']:.2f}")
    print(f"  Total symptomatic cases: {totals[arm]['symptomatic']:.2f}")
    print(f"  Total moderate risk cases: {totals[arm]['moderate']:.2f}")
    print(f"  Total high risk cases: {totals[arm]['high']:.2f}")

# Print totals by IMD group
print("\n=== TOTALS BY IMD GROUP ===")
for imd, results in totals_by_imd.items():
    print(f"\nIMD {imd}:")
    for arm in ["SEND_TEXT", "NO_TEXT"]:
        print(f"  {arm}:")
        print(f"    Total cost: £{results[arm]['cost']:.2f}")
        print(f"    Total symptomatic cases: {results[arm]['symptomatic']:.2f}")
        print(f"    Total moderate risk cases: {results[arm]['moderate']:.2f}")
        print(f"    Total high risk cases: {results[arm]['high']:.2f}")


# 8. ONE-WAY SENSITIVITY ANALYSIS
# -----------------------------
def one_way_sensitivity_analysis(param_name: str, values: List[float]) -> pd.DataFrame:
    """Run a one-way sensitivity analysis over the whole cohort.

    For each entry of *values*, the named parameter is overridden, the
    deterministic tree is evaluated for every woman in ``cohort``, and the
    per-arm totals are summed.

    Args:
        param_name: One of ``"P_FHQS_UPTAKE"``, ``"COST_FHQS_TEXT"`` or
            ``"COST_SYMPTOM"``.
        values: Values to substitute for that parameter, one row per value.

    Returns:
        A DataFrame with a ``value`` column followed by
        ``total_{cost,symptomatic,moderate,high}_send`` and the same four
        columns with the ``_no_text`` suffix.

    Raises:
        ValueError: if *param_name* is not supported. This is checked up
            front, so it is raised even when *values* is empty.
    """
    # Maps the public parameter name to the evaluate_tree keyword it overrides.
    override_kwarg = {
        "P_FHQS_UPTAKE": "p_fhqs_uptake",
        "COST_FHQS_TEXT": "cost_fhqs_text",
        "COST_SYMPTOM": "cost_symptom",
    }.get(param_name)
    if override_kwarg is None:
        raise ValueError("Unsupported parameter")

    arms = (("SEND_TEXT", "send"), ("NO_TEXT", "no_text"))  # (tree option, column suffix)
    metrics = ("cost", "symptomatic", "moderate", "high")
    imd_per_woman = cohort["imd"].to_numpy()

    results = []
    for val in values:
        # The deterministic tree depends only on (val, IMD decile), so each
        # decile is evaluated once per value rather than once per woman.
        result_by_imd = {}
        sums = {arm: dict.fromkeys(metrics, 0.0) for arm, _ in arms}

        # Accumulate woman by woman, in cohort order, so the floating-point
        # totals match a plain per-individual summation.
        for imd in imd_per_woman:
            if imd not in result_by_imd:
                result_by_imd[imd] = evaluate_tree(imd, **{override_kwarg: val})
            res = result_by_imd[imd]
            for arm, _ in arms:
                for metric in metrics:
                    sums[arm][metric] += res[arm][metric]

        record = {"value": val}
        for arm, suffix in arms:
            for metric in metrics:
                record[f"total_{metric}_{suffix}"] = sums[arm][metric]
        results.append(record)

    return pd.DataFrame(results)


# 9. RESULTS FOR ONE-WAY SENSITIVITY ANALYSES
# ---------------------------------------------------
# FHQS uptake probability: 5 evenly spaced values from 0.05 to 0.25
uptake_FHQS_range = np.linspace(start=0.05, stop=0.25, num=5)
owsa_total_uptake = one_way_sensitivity_analysis(param_name="P_FHQS_UPTAKE", values=uptake_FHQS_range)
print("\nTOTAL RESULTS FOR FHQS UPTAKE PROBABILITY OWSA", owsa_total_uptake, sep="\n")

# FHQS set-up / text cost: 5 evenly spaced values from 10 to 50
cost_FHQS_range = np.linspace(start=10, stop=50, num=5)
owsa_total_cost = one_way_sensitivity_analysis(param_name="COST_FHQS_TEXT", values=cost_FHQS_range)
print("\nTOTAL RESULTS FOR FHQS SETUP COST OWSA", owsa_total_cost, sep="\n")

# Symptomatic detection cost: 5 evenly spaced values from 800 to 1500
cost_symptom_range = np.linspace(start=800, stop=1500, num=5)
owsa_total_symptom = one_way_sensitivity_analysis(param_name="COST_SYMPTOM", values=cost_symptom_range)
print("\nTOTAL RESULTS FOR SYMPTOM COST OWSA", owsa_total_symptom, sep="\n")


=== TOTAL EXPECTED FOR 50000 WOMEN ===

SEND_TEXT:
  Total cost: £2325519.53
  Total symptomatic cases: 785.89
  Total moderate risk cases: 706.39
  Total high risk cases: 68.58

NO_TEXT:
  Total cost: £931312.69
  Total symptomatic cases: 809.79
  Total moderate risk cases: 9.49
  Total high risk cases: 2.18

=== TOTALS BY IMD GROUP ===

IMD 8:
  SEND_TEXT:
    Total cost: £232784.50
    Total symptomatic cases: 78.67
    Total moderate risk cases: 70.71
    Total high risk cases: 6.86
  NO_TEXT:
    Total cost: £93275.14
    Total symptomatic cases: 81.05
    Total moderate risk cases: 1.18
    Total high risk cases: 0.27

IMD 5:
  SEND_TEXT:
    Total cost: £229714.82
    Total symptomatic cases: 77.63
    Total moderate risk cases: 69.78
    Total high risk cases: 6.77
  NO_TEXT:
    Total cost: £91984.90
    Total symptomatic cases: 79.99
    Total moderate risk cases: 0.89
    Total high risk cases: 0.21

IMD 4:
  SEND_TEXT:
    Total cost: £230552.01
    Total symptomatic cases