# Pseudo Data Generation

In [13]:
from pathlib import Path

import numpy as np
import pandas as pd

# Fix the global NumPy seed so repeated runs draw the same random numbers
np.random.seed(42)

# Per-arm trial parameters kept side by side, one row per regimen:
#   (regimen, patients allocated, favourable response rate, gastro symptom rate)
# Row order is significant: it is the iteration order of the dicts derived below.
_REGIMEN_TABLE = (
    ("Split I", 407, 0.912, 0.10),
    ("Split II", 415, 0.915, 0.09),
    ("Control", 418, 0.867, 0.17),
)

# Regimen allocation
regimen_counts = {arm: size for arm, size, _, _ in _REGIMEN_TABLE}

# Outcome distributions from the paper
favourable_response_rates = {arm: rate for arm, _, rate, _ in _REGIMEN_TABLE}

# Gastro symptoms (higher in Control)
gastro_symptom_rates = {arm: rate for arm, _, _, rate in _REGIMEN_TABLE}

# Drug resistance distribution: category -> probability (probabilities sum to 1)
_DRUG_RESISTANCE = {'None': 0.85, 'H': 0.08, 'R': 0.05, 'H+R': 0.02}
drug_resistance_options = list(_DRUG_RESISTANCE)
drug_resistance_probs = list(_DRUG_RESISTANCE.values())

# Helper function to generate patients for each group
def generate_patients(regimen, n):
    """Generate a synthetic patient table for one treatment regimen.

    Parameters
    ----------
    regimen : str
        Regimen name. Must be a key of the module-level
        ``favourable_response_rates`` and ``gastro_symptom_rates`` tables.
    n : int
        Number of patients to generate (``0`` yields an empty frame).

    Returns
    -------
    pandas.DataFrame
        One row per patient with baseline characteristics, outcome flags,
        adverse events and follow-up duration. ``PatientID`` is
        ``"<regimen without whitespace>-<1-based index>"``.

    Raises
    ------
    KeyError
        If ``regimen`` is missing from either rate table.

    Notes
    -----
    All draws come from NumPy's global random state, so for a given seed the
    output depends on the order of the calls below; keep that order stable.
    """
    # Ages follow a normal distribution truncated to whole years. The lower
    # tail of that distribution dips below zero, so floor the draw at 0 to
    # avoid emitting negative ages.
    # NOTE(review): the trial's real eligibility range is not visible here;
    # tighten the floor if the paper specifies one.
    age = np.maximum(np.random.normal(32.5, 11.5, n), 0).astype(int)
    weight = np.round(np.random.normal(40.5, 6.5, n), 1)
    sex = np.random.choice(['Male', 'Female'], size=n)
    smear_grade = np.random.choice(['0/1+', '2+/3+'], size=n, p=[0.45, 0.55])
    prev_treatment = np.random.choice(['<15 days', '>=15 days'], size=n, p=[0.8, 0.2])

    # These two flags are constant: no generated patient has them set.
    hiv_positive = [False] * n
    vision_defect = [False] * n

    # Outcome flags: every patient is either favourable or unfavourable, and
    # favourable patients are further split into relapse (4%) vs quiescent.
    favourable = np.random.rand(n) < favourable_response_rates[regimen]
    unfavourable = ~favourable
    relapse = favourable & (np.random.rand(n) < 0.04)
    quiescent = favourable & ~relapse

    gastro_symptoms = np.random.rand(n) < gastro_symptom_rates[regimen]
    drug_resistance = np.random.choice(drug_resistance_options, size=n, p=drug_resistance_probs)
    compliance = np.random.choice(['>75%', '<=75%'], size=n, p=[0.9, 0.1])

    # Unfavourable patients get a follow-up of 1-59 months (randint's upper
    # bound is exclusive); everyone else is followed for the full 60 months.
    follow_up = np.where(unfavourable, np.random.randint(1, 60, size=n), 60)

    # Build the ID prefix from the whole regimen name. A two-character prefix
    # maps both "Split I" and "Split II" to "Sp", which makes IDs collide
    # once the per-regimen frames are concatenated.
    id_prefix = "".join(regimen.split())

    return pd.DataFrame({
        "PatientID": [f"{id_prefix}-{i}" for i in range(1, n + 1)],
        "Regimen": [regimen] * n,
        "Age": age,
        "Weight": weight,
        "Sex": sex,
        "SmearGrade": smear_grade,
        "PreviousTreatment": prev_treatment,
        "HIV_Positive": hiv_positive,
        "VisionDefect": vision_defect,
        "FavourableResponse": favourable,
        "UnfavourableResponse": unfavourable,
        "QuiescentAt5Yrs": quiescent,
        "Relapse": relapse,
        "GastroSymptoms": gastro_symptoms,
        "DrugResistance": drug_resistance,
        "Compliance": compliance,
        "FollowUpMonths": follow_up,
    })

# Combine all patients: one generated frame per regimen, stacked in the
# iteration order of regimen_counts and re-indexed from 0.
df_all = pd.concat(
    [generate_patients(regimen, count) for regimen, count in regimen_counts.items()],
    ignore_index=True,
)

# print(df_all.head())

# Save to CSV
csv_path = "data/synthetic_tb_clinical_trial_data.csv"
# to_csv does not create missing parent directories, so make sure the target
# folder exists before writing (no-op when it is already there).
Path(csv_path).parent.mkdir(parents=True, exist_ok=True)
df_all.to_csv(csv_path, index=False)

# csv_path