In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%cd ..

In [None]:
import sys
sys.path.append("src/")

In [None]:
import random

import numpy as np
import scipy.stats as stats
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from dataloader import load_data
from evaluate import *

In [None]:
sns.set_style("ticks")
# Both palettes are seaborn's 8-colour lists with entries 1 and 2 swapped.
dark_palette, deep_palette = (
  np.asarray(sns.color_palette(palette_name, 8))[[0, 2, 1, 3, 4, 5, 6, 7]]
  for palette_name in ("dark", "deep")
)
sns.set_palette(dark_palette)
# Active colour cycle (the dark palette set just above); indexed per method in plots.
plt_colors = sns.color_palette()
mpl.rcParams.update({
  "xtick.direction": "in",
  "ytick.direction": "in",
  "font.size": 12,
})

In [None]:
# Per-dataset arrays, keyed by dataset name and filled by the loading cell.
X = {}
y = {}
w = {}
t = {}
cens = {}

In [None]:
# Per-dataset results, keyed by dataset name and filled by the loading cell.
# ML (x-learner) results:
buckets_ml, arrs_ml, pvals_ml = {}, {}, {}
# Conventional (logreg) results:
buckets_conv, arrs_conv, pvals_conv = {}, {}, {}
# Baseline risk scores:
framingham, ascvd = {}, {}
# Predicted risk reductions per method:
pred_rr_ml, pred_rr_conv = {}, {}

In [None]:
for dataset in ("combined", "sprint", "accord"):
  # Per-patient inputs saved alongside the x-learner results.
  X[dataset] = np.load("results/xlearner/%s/X.npy" % dataset)
  w[dataset] = np.load("results/xlearner/%s/w.npy" % dataset)
  y[dataset] = np.load("results/xlearner/%s/y.npy" % dataset)
  t[dataset] = np.load("results/xlearner/%s/t.npy" % dataset)
  cens[dataset] = np.load("results/xlearner/%s/cens.npy" % dataset)

  framingham[dataset] = np.load("results/baselines/%s/framingham.npy" % dataset)
  ascvd[dataset] = np.load("results/baselines/%s/ascvd.npy" % dataset)

  pred_rr_ml[dataset] = np.load("results/xlearner/%s/pred_rr.npy" % dataset)
  pred_rr_conv[dataset] = np.load("results/logreg/%s/pred_rr.npy" % dataset)

  # arrs.npy stores a Python object (unwrapped with .item() and then indexed by
  # assignment). NumPy >= 1.16.3 refuses object arrays unless allow_pickle=True.
  # Unpickling can execute code, so only load result files this project wrote.
  buckets_ml[dataset] = np.load("results/xlearner/%s/buckets.npy" % dataset)
  arrs_ml[dataset] = np.load("results/xlearner/%s/arrs.npy" % dataset,
                             allow_pickle=True).item()
  pvals_ml[dataset] = np.load("results/xlearner/%s/pvals.npy" % dataset)
  buckets_conv[dataset] = np.load("results/logreg/%s/buckets.npy" % dataset)
  arrs_conv[dataset] = np.load("results/logreg/%s/arrs.npy" % dataset,
                               allow_pickle=True).item()
  pvals_conv[dataset] = np.load("results/logreg/%s/pvals.npy" % dataset)

In [None]:
def calc_corr_in_pred_rr(dataset):
  """Print the Pearson correlation between ML and conventional predicted ARRs.

  Args:
    dataset: key into `pred_rr_ml` / `pred_rr_conv`, e.g. "combined".
  """
  ml_pred = pred_rr_ml[dataset]
  conv_pred = pred_rr_conv[dataset]
  correlation = stats.pearsonr(ml_pred, conv_pred)
  print(correlation)

In [None]:
# Correlation between the two methods' predicted ARRs on the pooled data.
calc_corr_in_pred_rr(dataset="combined")

In [None]:
def forest_plot(dataset):
  """Draw and save a forest plot of average risk reduction per assignment group.

  Four rows are drawn:
    - ML and conventional under BENEFIT_ASSIGNMENT;
    - ML and conventional under NO_BENEFIT_ASSIGNMENT.
  Each row is the interval returned by `get_range` over that method's per-trial
  ARR values (`arrs_ml` / `arrs_conv`), with a diamond at the interval's middle
  element. Each row is annotated with the interval, the P value (`pvals_ml` /
  `pvals_conv`) and the number of patients in that bucket. The figure is saved
  to ./paper/img/<dataset>_forest_plot.pdf.

  Args:
    dataset: key into the module-level result dicts, e.g. "sprint".
  """
  # Raw strings: "\m" and "\ " are mathtext escapes, not Python escapes.
  ytick_labels = [
    r"$\mathbf{Benefit}$",
    "    Machine Learning",
    "    Conventional",
    r"$\mathbf{No\ benefit}$   ",
    "    Machine Learning",
    "    Conventional"]
  ytick_locs = [0, -1, -1.5, -3, -4, -4.5]

  # Each row is (y position, per-trial ARRs, N patients in bucket, P value, colour).
  # P-value index 0 belongs to the benefit group and index 1 to the no-benefit group.
  rows = [
    (ytick_locs[1], arrs_ml[dataset][BENEFIT_ASSIGNMENT],
     np.sum(buckets_ml[dataset] == BENEFIT_ASSIGNMENT), pvals_ml[dataset][0], plt_colors[0]),
    (ytick_locs[2], arrs_conv[dataset][BENEFIT_ASSIGNMENT],
     np.sum(buckets_conv[dataset] == BENEFIT_ASSIGNMENT), pvals_conv[dataset][0], plt_colors[1]),
    (ytick_locs[4], arrs_ml[dataset][NO_BENEFIT_ASSIGNMENT],
     np.sum(buckets_ml[dataset] == NO_BENEFIT_ASSIGNMENT), pvals_ml[dataset][1], plt_colors[0]),
    (ytick_locs[5], arrs_conv[dataset][NO_BENEFIT_ASSIGNMENT],
     np.sum(buckets_conv[dataset] == NO_BENEFIT_ASSIGNMENT), pvals_conv[dataset][1], plt_colors[1]),
  ]
  # Each row uses its own ARR list, so the methods may have different trial
  # counts. get_range is evaluated once per row, so the plotted interval and the
  # printed interval always agree.
  ranges = [get_range(np.asarray(row[1])) for row in rows]
  # Annotations start just right of the widest interval, and never left of x=0.
  text_x = max(0, max(np.max(rng) for rng in ranges)) + 0.01

  fig, ax = plt.subplots(figsize=(12, 4), tight_layout=True)
  for (y_loc, _, n_patients, pval, color), rng in zip(rows, ranges):
    ax.plot(rng, [y_loc] * len(rng), "-|", color=color)
    ax.plot([rng[1]], [y_loc], "D", color=color)
    if pval < 0.01:
      label = "{:.4f} [{:.4f} {:.4f}], $P < 0.01$, $N$ = {}".format(
        rng[1], rng[0], rng[2], n_patients)
    else:
      label = "{:.4f} [{:.4f} {:.4f}], $P$ = {:.2f}, $N$ = {}".format(
        rng[1], rng[0], rng[2], pval, n_patients)
    ax.text(text_x, y_loc, label)

  ax.set_yticks(ytick_locs)
  ax.set_yticklabels(ytick_labels, ha="left")
  # Tick label extents are only known after a draw. The widest left-aligned
  # label sets the padding so the labels sit clear of the axis.
  fig.canvas.draw()
  ax.tick_params(axis="y", which="both", length=0)
  yax = ax.get_yaxis()
  pad = max(tick.label1.get_window_extent().width + 5 for tick in yax.majorTicks)
  yax.set_tick_params(pad=pad)

  ax.set_xlabel("Average Risk Reduction")
  ax.set_xlim([-0.08, 0.16])
  ax.set_ylim([-5.9, 0.5])
  ax.axvline(x=0.0, linestyle="--", color="grey")
  fig.savefig("./paper/img/{}_forest_plot.pdf".format(dataset))

In [None]:
# Forest plot for the SPRINT trial.
forest_plot(dataset="sprint")

In [None]:
def _plot_calibration_panel(pred_rr_by_dataset, dataset, cens_time, n_bins, color, title):
  """Draw one predicted-vs-observed ARR calibration panel on the current axes.

  Plots the points returned by `calibration` and the fitted line from its slope
  and intercept, annotated with those values. A grey identity line marks
  perfect calibration.
  """
  _, slope, intercept, pred_rr, obs_rr = calibration(pred_rr_by_dataset[dataset],
                                                      y[dataset],
                                                      w[dataset],
                                                      t[dataset],
                                                      cens_time, n_bins=n_bins)
  plt.scatter(pred_rr, obs_rr, alpha=0.5, color=color)
  # Fitted calibration line, drawn slightly past the visible x-range.
  line_x = [-0.15, 0.25]
  plt.plot(line_x, [slope * v + intercept for v in line_x], "--", color=color)
  plt.title(title, fontsize=12, fontweight="bold")
  plt.xlim([-0.15, 0.20])
  plt.ylim([-0.15, 0.20])
  plt.text(-0.12, 0.15, "Slope: {:.2f}, Intercept: {:.2f}".format(slope, intercept))
  plt.xlabel("Predicted ARR")
  plt.ylabel("Observed ARR")
  plt.plot((-0.3, 0.3), (-0.3, 0.3), "--", color="grey")


def plot_expected_vs_obs_rr(dataset, n_bins=5, cens_time=365.25 * 3, bin_strategy="rr"):
  """Side-by-side calibration plots (ML left, conventional right) for `dataset`.

  Saves ./paper/img/<dataset>_calibration_curve_by_pred_risk.pdf.

  Args:
    dataset: key into the module-level data/prediction dicts.
    n_bins: number of bins passed on to `calibration`.
    cens_time: time horizon passed on to `calibration` (presumably in days,
      i.e. 3 years by default; confirm against `calibration`).
    bin_strategy: currently unused; kept so existing calls keep working.
  """
  plt.figure(figsize=(12, 4))
  panels = ((pred_rr_ml, plt_colors[0], "Machine Learning"),
            (pred_rr_conv, plt_colors[1], "Conventional"))
  for position, (preds, color, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    _plot_calibration_panel(preds, dataset, cens_time, n_bins, color, title)
  plt.savefig("./paper/img/{}_calibration_curve_by_pred_risk.pdf".format(dataset))

In [None]:
# Calibration of predicted vs observed ARR on the pooled data.
plot_expected_vs_obs_rr(dataset="combined", bin_strategy="rr")

In [None]:
def plot_pred_rr_against_baseline_decile(dataset, baseline_risk=None, n_bins=10):
  """Box-plot predicted ARR (ML vs conventional) within quantile groups of baseline risk.

  Patients are split into `n_bins` equal-frequency groups of
  `baseline_risk[dataset]` (group 1 = lowest risk). For each group, the
  distributions of `pred_rr_ml[dataset]` and `pred_rr_conv[dataset]` are drawn
  side by side. Saves ./paper/img/<dataset>_pred_rr_baseline_decile.pdf.

  Args:
    dataset: key into the module-level dicts, e.g. "combined".
    baseline_risk: dict mapping dataset name to per-patient baseline risk,
      e.g. `framingham` or `ascvd`. Required: no cell in this notebook defines
      the `cox` dict that was previously used as the default.
    n_bins: number of quantile groups (10 gives deciles).

  Raises:
    ValueError: if `baseline_risk` is not given.
  """
  if baseline_risk is None:
    raise ValueError("baseline_risk is required, e.g. framingham or ascvd")
  risk = np.asarray(baseline_risk[dataset])
  edges = np.percentile(risk, q=np.linspace(0, 100, n_bins + 1))
  # Digitizing against every edge except the last maps the minimum to group 1
  # and the maximum to group n_bins, so labels are 1..n_bins for any n_bins.
  group = np.digitize(risk, edges[:-1])
  ml_pred, conv_pred = pred_rr_ml[dataset], pred_rr_conv[dataset]
  plt.figure(figsize=(10, 4))
  plt.axhline(y=0.0, linestyle="--", color="grey")
  sns.boxplot(x=np.r_[group, group],
              y=np.r_[ml_pred, conv_pred],
              hue=np.r_[["Machine Learning"] * len(ml_pred),
                        ["Conventional"] * len(conv_pred)],
              palette=deep_palette,
              showfliers=False)
  plt.ylabel("Predicted ARR")
  plt.xlabel("Baseline Risk Decile" if n_bins == 10 else "Baseline Risk Quantile Group")
  plt.ylim((-0.15, 0.20))
  plt.savefig("./paper/img/{}_pred_rr_baseline_decile.pdf".format(dataset))

In [None]:
# Predicted ARR by Framingham baseline-risk decile.
plot_pred_rr_against_baseline_decile(dataset="combined", baseline_risk=framingham)

In [None]:
def calc_baseline_decile_table(dataset, baseline_risk=None, n_bins=10):
  """Print per-quantile-group summaries of predicted ARR as CSV rows.

  Patients of `dataset` are split into `n_bins` equal-frequency groups of
  `baseline_risk[dataset]` (group 1 = lowest). One line is printed per group:
    - the group number;
    - median, 25th and 75th percentile of the ML predicted ARR;
    - the same three statistics for the conventional predicted ARR.

  Args:
    dataset: key into the module-level dicts, e.g. "combined".
    baseline_risk: dict mapping dataset name to per-patient baseline risk,
      e.g. `framingham` or `ascvd`. Required.
    n_bins: number of quantile groups (10 gives deciles).

  Raises:
    ValueError: if `baseline_risk` is not given.
  """
  if baseline_risk is None:
    raise ValueError("baseline_risk is required, e.g. framingham or ascvd")
  risk = np.asarray(baseline_risk[dataset])
  edges = np.percentile(risk, q=np.linspace(0, 100, n_bins + 1))
  # Labels 1..n_bins: the minimum falls in group 1, the maximum in group n_bins.
  group = np.digitize(risk, edges[:-1])
  # Predictions must come from the same dataset as the risk groups above.
  ml_pred, conv_pred = pred_rr_ml[dataset], pred_rr_conv[dataset]
  for g in np.unique(group):
    in_group = group == g
    u_med, u_q25, u_q75 = np.percentile(ml_pred[in_group], [50, 25, 75])
    v_med, v_q25, v_q75 = np.percentile(conv_pred[in_group], [50, 25, 75])
    print(f"{int(g)},{u_med:.4f},{u_q25:.4f},{u_q75:.4f},{v_med:.4f},{v_q25:.4f},{v_q75:.4f}")

In [None]:
# Table of predicted ARR quartiles by ASCVD baseline-risk decile.
calc_baseline_decile_table(dataset="combined", baseline_risk=ascvd)

In [None]:
def plot_predicted_rr(dataset):
  """Overlay kernel density estimates of predicted ARR for both methods.

  The ML curve is drawn first, then the conventional one. Saves
  ./paper/img/<dataset>_pred_rr_distributions.pdf.
  """
  plt.figure(figsize=(8, 4))
  # NOTE(review): `shade=` is deprecated in seaborn >= 0.11 in favour of `fill=`;
  # kept here to match the installed version's existing behaviour.
  for method_label, predictions in (("Machine Learning", pred_rr_ml),
                                    ("Conventional", pred_rr_conv)):
    sns.kdeplot(predictions[dataset], label=method_label, shade=True)
  plt.ylabel("Density")
  plt.xlabel("Predicted absolute risk reduction")
  plt.xlim([-0.15, 0.15])
  plt.legend()
  plt.savefig("./paper/img/{}_pred_rr_distributions.pdf".format(dataset))

In [None]:
# Distributions of predicted ARR on the pooled data.
plot_predicted_rr(dataset="combined")

In [None]:
def calculate_summary_stats(dataset, bucket=True):
  """Print per-column mean (std) of `X[dataset]` as CSV-like lines.

  Column names come from `load_data("accord")["cols"]` regardless of `dataset`
  (presumably all datasets share ACCORD's column layout; confirm).

  When `bucket` is true, the following is printed once for the ML predictions
  and once for the conventional predictions:
    - a header line;
    - the count of predicted ARR > 0 and the count of predicted ARR <= 0;
    - one "col:,mean (std),mean (std)" line per column, with the > 0 group
      first and the <= 0 group second.
  Otherwise one "col:,mean (std)" line per column over all patients is printed.
  """
  cols = load_data("accord")["cols"]
  if not bucket:
    for col_idx, col in enumerate(cols):
      column = X[dataset][:, col_idx]
      print("{}:,{:.2f} ({:.2f})".format(col, column.mean(), column.std()))
    return
  for header, source in (("== ml [BEN | NOEFF]", pred_rr_ml),
                         ("== conv [BEN | NOEFF]", pred_rr_conv)):
    print(header)
    predictions = source[dataset]
    benefit = predictions > 0
    no_effect = predictions <= 0
    print(sum(benefit))
    print(sum(no_effect))
    for col_idx, col in enumerate(cols):
      column = X[dataset][:, col_idx]
      ben, noeff = column[benefit], column[no_effect]
      print("{}:,{:.2f} ({:.2f}),{:.2f} ({:.2f})".format(
        col, ben.mean(), ben.std(), noeff.mean(), noeff.std()))

In [None]:
# Covariate summaries split by predicted benefit.
calculate_summary_stats(dataset="combined", bucket=True)

In [None]:
# Covariate summaries over all patients.
calculate_summary_stats(dataset="combined", bucket=False)

In [None]:
def plot_matching_patient_pairs(dataset):
  random.seed(1)
  plt.figure(figsize=(11, 4))
  plt.subplot(1,2,1)
  tuples = list(zip(pred_rr_ml[dataset][cens[dataset] == 0], 
                    y[dataset][cens[dataset] == 0], 
                    w[dataset][cens[dataset] == 0]))
  untreated = list(filter(lambda t: t[2] == 0, tuples))
  treated = list(filter(lambda t: t[2] == 1, tuples))
  if len(treated) < len(untreated):
    untreated = random.sample(untreated, len(treated))
  if len(untreated) < len(treated):
    treated = random.sample(treated, len(untreated))
  assert len(untreated) == len(treated)
  untreated = sorted(untreated, key=lambda t: t[0])
  treated = sorted(treated, key=lambda t: t[0])
  plt.scatter(np.array(treated)[:,0], np.array(untreated)[:,0], marker=".", alpha=1e-2)
  plt.plot((-0.3, 0.3), (-0.3, 0.3), "--", color="grey")
  plt.xlabel("Predicted ARR, intensive arm")
  plt.ylabel("Predicted ARR, standard arm")
  plt.ylim(-0.3, 0.3)
  plt.xlim(-0.3, 0.3)
  plt.title("Machine Learning", fontsize=12, fontweight="bold")
  plt.subplot(1,2,2)
  tuples = list(zip(pred_rr_conv[dataset], y[dataset], w[dataset]))
  untreated = list(filter(lambda t: t[2] == 0, tuples))
  treated = list(filter(lambda t: t[2] == 1, tuples))
  if len(treated) < len(untreated):
    untreated = random.sample(untreated, len(treated))
  if len(untreated) < len(treated):
    treated = random.sample(treated, len(untreated))
  assert len(untreated) == len(treated)
  untreated = sorted(untreated, key=lambda t: t[0])
  treated = sorted(treated, key=lambda t: t[0])
  plt.scatter(np.array(treated)[:,0], np.array(untreated)[:,0], marker=".", alpha=2e-3, color=plt_colors[1])
  plt.plot((-0.3, 0.3), (-0.3, 0.3), "--", color="grey")
  plt.xlabel("Predicted ARR, intensive arm")
  plt.ylabel("Predicted ARR, standard arm")
  plt.title("Conventional", fontsize=12, fontweight="bold")
  plt.ylim(-0.3, 0.3)
  plt.xlim(-0.3, 0.3)
  plt.savefig("./paper/img/{}_matching_patient_pairs.pdf".format(dataset))

In [None]:
# Arm-balance check of predicted ARR on the pooled data.
plot_matching_patient_pairs(dataset="combined")