# In [1]:
# Quick-reference summary of the public TrueLift API, echoed to the cell output.
_API_SUMMARY = """
class TrueLift:
  def __init__(self, df, id_label, control_label, 
               target_value_label, prediction_value_label):
    (spark_df, str, str, str, str) -> None
    Input dataframe must contain 4 columns:
    label for row_id, control/treatment, target_value, and prediction_value.
    
  def report_grad(self, repartition=True, N=10):
    (bool, int) -> int, spark_df
    Calculate gradient and the minimum subset size.

  def update_pred(self, new_pred_df, new_pred_label):
    (spark_df, str) -> None
    Update self.df[self.pred_label] to the values from the input table.
"""
print(_API_SUMMARY)

# In [2]:
import pyspark.sql.functions as F
from pyspark.sql.types import *
import numpy as np
from pyspark.ml.feature import QuantileDiscretizer

# In [3]:
class TrueLift:
  """
  Compute the per-row gradient of a lift-calibration loss on a Spark table.

  Rows are bucketed into N quantile subsets of the prediction column, each
  subset is further split into "Top" / "Mid" / "Bot" segments, and the
  gradient of the loss (see loss_function) with respect to each row's
  prediction is assembled from per-subset statistics.

  Column conventions used throughout:
    control_label value 1 -> control row, value 0 -> treatment row.
  """

  def __init__(self, df, id_label, control_label,
               target_value_label, prediction_value_label):
    """
    (spark_df, str, str, str, str) -> None
    Input dataframe must contain 4 columns:
    label for row_id, control/treatment, target_value, and prediction_value.

    Sets self.df (persisted, restricted to the 4 columns), the global lift
    (mean treatment target - mean control target) and the row counts.
    Raises ValueError when the table has no control rows or no treatment
    rows with a non-null target, since the global lift is undefined then.
    """
    self.id_label = id_label
    self.control_label = control_label
    self.target_value_label = target_value_label
    self.pred_label = prediction_value_label
    self.df = df.select(id_label, control_label,
                        target_value_label, prediction_value_label)
    self.df.persist()
    # Kept as a (lazy) attribute for callers that inspect it.
    self.global_stats_df = self.df.groupby(self.control_label
                                  ).agg(F.mean(self.target_value_label
                                        ).alias("ave_target_value"))
    # One pass over the persisted table yields every global statistic;
    # the comparisons are evaluated by Spark, exactly like the per-group
    # filters they replace.
    is_control = F.col(self.control_label) == 1
    is_treatment = F.col(self.control_label) == 0
    target = F.col(self.target_value_label)
    totals = self.df.agg(
        F.count(F.lit(1)).alias("data_size"),
        F.count(F.when(is_control, 1)).alias("control_size"),
        F.count(F.when(is_treatment, 1)).alias("treatment_size"),
        F.mean(F.when(is_control, target)).alias("control_value"),
        F.mean(F.when(is_treatment, target)).alias("treatment_value")
    ).first()
    control_value = totals["control_value"]
    treatment_value = totals["treatment_value"]
    if control_value is None or treatment_value is None:
      raise ValueError(
          "Column %r must contain both control (1) and treatment (0) rows "
          "with non-null %r values." % (control_label, target_value_label))
    self.global_lift = treatment_value - control_value
    self.data_size = totals["data_size"]
    self.treatment_size = totals["treatment_size"]
    self.control_size = totals["control_size"]

  @staticmethod
  def _lookup_udf(values, offset):
    """
    (list, int) -> spark_udf
    Build a UDF mapping a 1-based subset number s to values[s + offset].
    Only the plain list is captured by the closure (never self), and the
    result is a double so split values are not truncated to 32 bits.
    """
    table = [float(v) for v in values]
    return F.udf(lambda subset: table[subset + offset], DoubleType())

  def get_cuts(self, N=10):
    """
    (int) -> None
    Get the threshold values which can split the data
    into N subsets based on the prediction value.
    Stores the fitted bucketizer in self.qt and its splits in
    self.subset_cuts.
    """
    self.subset_count = N
    # The output column must not already exist when fitting/transforming.
    self.df = self.df.drop("subset")
    discretizer = QuantileDiscretizer(numBuckets=N,
                                      inputCol=self.pred_label,
                                      outputCol="subset")
    self.qt = discretizer.fit(self.df)
    self.subset_cuts = self.qt.getSplits()

  def split_subsets(self):
    """
    () -> None
    Assign subset values to the table using the bucketizer from get_cuts.
    Subsets are numbered from 1 (the bucket index plus one).
    """
    bucketed = self.qt.transform(self.df.drop("subset"))
    self.df = bucketed.withColumn("subset",
                                  F.col("subset").cast(IntegerType()) + 1)

  def get_subset_stats(self):
    """
    () -> int
    Get average values within each subset (stored in self.subset_stats_df)
    and return the smallest control/treatment group size over all subsets.
    A subset with no control rows or no treatment rows counts as size 0,
    so the return value exposes empty groups instead of skipping them.
    """
    is_control = F.col(self.control_label) == 1
    is_treatment = F.col(self.control_label) == 0
    totals = self.df.groupby("subset").agg(
        F.count(self.pred_label).alias("subset_size"),
        F.mean(self.pred_label).alias("ave_pred_value"))
    control = self.df.filter(is_control).groupby("subset").agg(
        F.count(self.pred_label).alias("control_subset_size"),
        F.mean(self.target_value_label).alias("ave_control_target_value"))
    treatment = self.df.filter(is_treatment).groupby("subset").agg(
        F.count(self.pred_label).alias("treatment_subset_size"),
        F.mean(self.target_value_label).alias("ave_treatment_target_value"))
    stats = totals.join(control, ["subset"], how="left"
                 ).join(treatment, ["subset"], how="left")
    self.subset_stats_df = stats.withColumn(
        "ave_lift",
        F.col("ave_treatment_target_value") - F.col("ave_control_target_value"))
    # The left joins leave null sizes for empty groups and min() ignores
    # nulls, so nulls are mapped to 0 for the reported minimum only (the
    # stored table is left untouched).
    smallest = self.subset_stats_df.agg(
        F.min(F.coalesce(F.col("control_subset_size"), F.lit(0))),
        F.min(F.coalesce(F.col("treatment_subset_size"), F.lit(0)))
    ).first()
    return min(0 if size is None else size for size in smallest)

  def loss_function(self):
    """
    () -> float
    Output the value of the loss function.
    Base on current subset assignments.

    Loss = sum_n (S_n / data_size)
                 * ((P_n - lbar_n)^2 - (lbar_n - global_lift)^2)
    Also adds a "Loss" column to self.subset_stats_df.
    """
    weight = F.col("subset_size") / self.data_size
    bias_sq = (F.col("ave_pred_value") - F.col("ave_lift"))**2
    spread_sq = (F.col("ave_lift") - self.global_lift)**2
    self.subset_stats_df = self.subset_stats_df.withColumn(
        "Loss", weight * (bias_sq - spread_sq))
    return self.subset_stats_df.agg(F.sum("Loss")).first()[0]

  def get_tri_segment_cuts(self):
    """
    () -> None
    Calculate the threshold values that can split each
    subset into 3 segments.

    For subset n with bounds (lo, hi) the cuts sit one third of the width
    inside each bound. The outermost bounds are presumably -inf/+inf (as
    produced by QuantileDiscretizer), so the first subset's upper cut and
    the last subset's lower cut are mirrored from the neighbouring subset
    instead. Results go to self.upper_cuts / self.lower_cuts (one entry
    per subset). Raises ValueError with fewer than 3 subsets.
    """
    cuts = np.asarray(self.subset_cuts, dtype=float)
    if cuts.size < 4:
      raise ValueError(
          "Need at least 3 subsets (4 split points) to build segment "
          "cuts, got %d split points." % cuts.size)
    # Explicit float literals: 1/3 would be 0 under Python 2 division.
    one_third, two_thirds = 1.0 / 3.0, 2.0 / 3.0
    upper = list(one_third * cuts[1:-1] + two_thirds * cuts[2:])   # subsets 2..N
    lower = list(two_thirds * cuts[:-2] + one_third * cuts[1:-1])  # subsets 1..N-1
    # Mirror subset 2's lower cut around the first finite boundary.
    upper = [cuts[1] - (lower[1] - cuts[1])] + upper
    # Mirror subset N-1's upper cut around the last finite boundary.
    lower = lower + [cuts[-2] + (cuts[-2] - upper[-2])]
    self.upper_cuts = [float(v) for v in upper]
    self.lower_cuts = [float(v) for v in lower]

  def assign_segments(self):
    """
    () -> None
    Attach the segment values to the table: "Top" when the prediction is
    above the subset's upper cut, "Bot" when below its lower cut,
    otherwise "Mid".
    """
    pred = F.col(self.pred_label)
    with_cuts = self.df.withColumn(
        "Upper_Cut", self._lookup_udf(self.upper_cuts, -1)("subset")
    ).withColumn(
        "Lower_Cut", self._lookup_udf(self.lower_cuts, -1)("subset"))
    segment = F.when(pred > F.col("Upper_Cut"), "Top"
              ).when(pred < F.col("Lower_Cut"), "Bot"
              ).otherwise("Mid")
    self.df = with_cuts.withColumn("Segment", segment
                      ).drop("Upper_Cut", "Lower_Cut")

  def get_grad_lookup(self):
    """
    () -> None
    Obtain all parameters required to calculate the gradient
    which are constant within each subset, and store them in
    self.grad_lookup_df together with the same quantities of the
    neighbouring subsets (suffixes "(n+1)" and "(n-1)").
    """
    # Attach all cut-value information into the stats table.
    lookup = self.subset_stats_df
    cut_sources = (("Upper_Cut", self.upper_cuts, -1),
                   ("Lower_Cut", self.lower_cuts, -1),
                   ("Upper_Bound", self.subset_cuts, 0),
                   ("Lower_Bound", self.subset_cuts, -1))
    for name, values, offset in cut_sources:
      lookup = lookup.withColumn(name, self._lookup_udf(values, offset)("subset"))
    # Rename some columns to match the paper's convention.
    renames = (("subset_size", "S_n"),
               ("ave_pred_value", "P_n"),
               ("control_subset_size", "SC_n"),
               ("treatment_subset_size", "ST_n"),
               ("ave_lift", "lbar_n"),
               ("ave_control_target_value", "ybarC_n"),
               ("ave_treatment_target_value", "ybarT_n"))
    for old_name, new_name in renames:
      lookup = lookup.withColumnRenamed(old_name, new_name)
    # Calculate the parameters in the gradient which are the same within a subset.
    bias = F.col("P_n") - F.col("lbar_n")
    spread = F.col("lbar_n") - self.global_lift
    lookup = lookup.withColumn("dLdp_bias", 2*bias/self.data_size)
    lookup = lookup.withColumn(
        "dp_top", (F.col("Upper_Bound") - F.col("Upper_Cut"))/2)
    lookup = lookup.withColumn(
        "dp_bot", (F.col("Lower_Bound") - F.col("Lower_Cut"))/2)
    lookup = lookup.withColumn(
        "dLdlbar_n", (-2*bias - 2*spread)*(F.col("S_n")/self.data_size))
    lookup = lookup.withColumn(
        "dLdS_n", (bias**2 - spread**2)/self.data_size)
    shift_columns = ["SC_n", "ST_n", "ybarC_n", "ybarT_n",
                     "dLdlbar_n", "dLdS_n"]
    lookup = lookup.select("subset", "dLdp_bias", "dp_top", "dp_bot",
                           *shift_columns)
    # Relabel each row as its neighbour's subset number so that a join on
    # "subset" brings subset n+1 (resp. n-1) values next to subset n.
    np1_df = lookup.select((F.col("subset") - 1).alias("subset"),
                           *[F.col(c).alias(c[:-1] + "(n+1)") for c in shift_columns])
    nm1_df = lookup.select((F.col("subset") + 1).alias("subset"),
                           *[F.col(c).alias(c[:-1] + "(n-1)") for c in shift_columns])
    self.grad_lookup_df = lookup.join(np1_df, ["subset"], how="left"
                                ).join(nm1_df, ["subset"], how="left")

  def _edge_gradient(self, dp_column, neighbor, is_control):
    """
    (str, str, bool) -> spark_column
    Gradient expression for a row in a "Top" or "Bot" segment.
    dp_column is "dp_top" or "dp_bot"; neighbor is the column suffix of
    the adjacent subset ("(n+1)" or "(n-1)"); is_control selects the
    control ("C") or treatment ("T") statistics and the matching signs.
    """
    y = F.col(self.target_value_label)
    group = "C" if is_control else "T"
    ybar_own = F.col("ybar%s_n" % group)
    size_own = F.col("S%s_n" % group)
    ybar_nbr = F.col("ybar%s_%s" % (group, neighbor))
    size_nbr = F.col("S%s_%s" % (group, neighbor))
    if is_control:
      own_term = (y - ybar_own)/size_own
      nbr_term = (ybar_nbr - y)/size_nbr
    else:
      own_term = (ybar_own - y)/size_own
      nbr_term = (y - ybar_nbr)/size_nbr
    return (F.col("dLdp_bias")
            + (1/F.col(dp_column))
              *(F.col("dLdlbar_n")*own_term
                + F.col("dLdlbar_" + neighbor)*nbr_term
                - F.col("dLdS_n") + F.col("dLdS_" + neighbor)))

  def calc_grad(self):
    """
    () -> spark_df
    Join the per-subset lookup onto the rows and return a table with the
    row id and its "gradient". "Mid" rows only get the bias term; "Top"
    and "Bot" rows add the terms for moving into the next/previous subset.
    """
    segment = F.col("Segment")
    control_flag = F.col(self.control_label)
    joined = self.df.join(self.grad_lookup_df, ["subset"], how="left")
    gradient = F.when(segment == "Mid", F.col("dLdp_bias")
               ).when((segment == "Top") & (control_flag == 0),
                      self._edge_gradient("dp_top", "(n+1)", False)
               ).when((segment == "Top") & (control_flag == 1),
                      self._edge_gradient("dp_top", "(n+1)", True)
               ).when((segment == "Bot") & (control_flag == 0),
                      self._edge_gradient("dp_bot", "(n-1)", False)
               ).otherwise(
                      # Every remaining row uses the Bot/control formula.
                      self._edge_gradient("dp_bot", "(n-1)", True))
    return joined.withColumn("gradient", gradient
                ).select(self.id_label, "gradient")

  def report_grad(self, repartition=True, N=10):
    """
    (bool, int) -> int, spark_df
    Calculate gradient and the minimum subset size.

    With repartition=True the quantile cuts are recomputed for N subsets;
    otherwise the existing cuts are reused. If no cuts exist yet (first
    call), they are computed regardless of the flag.
    """
    needs_cuts = repartition or getattr(self, "qt", None) is None
    if needs_cuts:
      self.get_cuts(N)
    self.split_subsets()
    min_size = self.get_subset_stats()
    if needs_cuts:
      self.get_tri_segment_cuts()
    self.assign_segments()
    self.get_grad_lookup()
    return min_size, self.calc_grad()

  def update_pred(self, new_pred_df, new_pred_label):
    """
    (spark_df, str) -> None
    Update self.df[self.pred_label] to the values from the input table.

    Rows are matched on the id column; ids missing from new_pred_df get a
    null prediction. Only the id and new_pred_label columns of
    new_pred_df are used, and new_pred_label may equal self.pred_label.
    """
    # A temporary name keeps the join unambiguous even when the incoming
    # column is named like the existing prediction column.
    incoming = "__truelift_new_pred__"
    new_values = new_pred_df.select(self.id_label,
                                    F.col(new_pred_label).alias(incoming))
    self.df = self.df.drop(self.pred_label
                   ).join(new_values, [self.id_label], how="left"
                   ).withColumnRenamed(incoming, self.pred_label)