### Survival regression

spark.ml implements the accelerated failure time (AFT) model, a parametric survival regression model for censored data. Because it models the logarithm of the survival time, it is often called a log-linear model for survival analysis. Unlike the proportional hazards model, which serves the same purpose, the AFT model is easier to parallelize: each instance contributes to the objective function independently, so partial sums can be computed on separate partitions and then added.
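For reference, here is a sketch of the model and its objective. It assumes the Weibull lifetime distribution described in the Spark MLlib guide, and it writes the intercept $b$ separately to match `model.intercept` below. Each observed time $t_i$ with features $x_i$ satisfies a linear model on the log scale, where $\epsilon_i$ follows the standard (minimum) extreme-value distribution. This makes $t_i$ Weibull-distributed with scale $\lambda_i = e^{\beta^\top x_i + b}$ and shape $1/\sigma$. With censor indicator $\delta_i$ ($1$ if the event occurred, $0$ if censored), the log-likelihood, up to terms that do not depend on the parameters, is a plain sum over instances:

$$
\log t_i = \beta^\top x_i + b + \sigma \epsilon_i,
\qquad
\epsilon_i = \frac{\log t_i - \beta^\top x_i - b}{\sigma}
$$

$$
\iota(\beta, b, \sigma) = -\sum_{i=1}^{n} \left[ \delta_i \log \sigma - \delta_i \epsilon_i + e^{\epsilon_i} \right]
$$

The sum form is what makes the model easy to parallelize: every term depends only on its own row.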

```python
# Locate the local Spark installation and make pyspark importable
import findspark
findspark.init()
```

```python
from pyspark.ml.regression import AFTSurvivalRegression
from pyspark.ml.linalg import Vectors
```

```python
from pyspark.sql import SparkSession

# Start (or reuse) a Spark session for this notebook
spark = SparkSession.builder.appName("surv").getOrCreate()
```

```python
# label: observed survival time (must be positive)
# censor: 1.0 if the event occurred (uncensored), 0.0 if the row is censored
training = spark.createDataFrame([
    (1.218, 1.0, Vectors.dense(1.560, -0.605)),
    (2.949, 0.0, Vectors.dense(0.346, 2.158)),
    (3.627, 0.0, Vectors.dense(1.380, 0.231)),
    (0.273, 1.0, Vectors.dense(0.520, 1.151)),
    (4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"])
```

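```python
# Besides the point prediction, output the 30% and 60% quantiles of the
# predicted survival time; setting quantilesCol adds them as a column
quantileProbabilities = [0.3, 0.6]
aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities,
                            quantilesCol="quantiles")

model = aft.fit(training)
```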
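`fit` estimates the coefficients, intercept, and scale by maximizing the Weibull AFT likelihood. The following plain-Python version of that objective is a minimal sketch, not Spark's internal code. It computes the negative log-likelihood, dropping terms that do not depend on the parameters, as a sum of per-row terms, and evaluates it at the fitted values printed by the next cell. `instance_loss` and `total_loss` are hypothetical helper names.

```python
import math

# Plain-Python copy of the training rows: (label, censor, features)
rows = [
    (1.218, 1.0, [1.560, -0.605]),
    (2.949, 0.0, [0.346, 2.158]),
    (3.627, 0.0, [1.380, 0.231]),
    (0.273, 1.0, [0.520, 1.151]),
    (4.199, 0.0, [0.795, -0.226]),
]

def instance_loss(label, censor, features, coefficients, intercept, scale):
    """Negative log-likelihood contribution of a single row (Weibull AFT)."""
    margin = sum(c * x for c, x in zip(coefficients, features)) + intercept
    eps = (math.log(label) - margin) / scale
    # censor = 1.0: density term; censor = 0.0: only the survival term exp(eps)
    return censor * math.log(scale) - censor * eps + math.exp(eps)

def total_loss(rows, coefficients, intercept, scale):
    # A plain sum over rows: each partition could compute its own partial sum
    return sum(instance_loss(t, d, x, coefficients, intercept, scale)
               for t, d, x in rows)

# Fitted values as printed from model.coefficients, model.intercept, model.scale
coefficients = [-0.4963111466650707, 0.19844437699933098]
intercept = 2.63809461510401
scale = 1.5472345574364692

print(total_loss(rows, coefficients, intercept, scale))
```

Because the fit sits at the minimum of this objective, nudging any of the four fitted values should increase the printed loss.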

```python
print("Coefficients: " + str(model.coefficients))
print("Intercept: " + str(model.intercept))
print("Scale: " + str(model.scale))
model.transform(training).show(truncate=False)
```

```
Coefficients: [-0.4963111466650707,0.19844437699933098]
Intercept: 2.63809461510401
Scale: 1.5472345574364692
+-----+------+--------------+------------------+---------------------------------------+
|label|censor|features      |prediction        |quantiles                              |
+-----+------+--------------+------------------+---------------------------------------+
|1.218|1.0   |[1.56,-0.605] |5.718979487635007 |[1.1603238947151657,4.99545601027477]  |
|2.949|0.0   |[0.346,2.158] |18.07652118149533 |[3.6675458454717362,15.789611866277625]|
|3.627|0.0   |[1.38,0.231]  |7.381861804239096 |[1.497706130519082,6.44796261233896]   |
|0.273|1.0   |[0.52,1.151]  |13.577612501425284|[2.7547621481506837,11.8598722240697]  |
|4.199|0.0   |[0.795,-0.226]|9.013097744073898 |[1.8286676321297812,7.87282650587843]  |
+-----+------+--------------+------------------+---------------------------------------+
```
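The `prediction` and `quantiles` columns can be recomputed by hand from the printed coefficients, intercept, and scale. The sketch below assumes the Weibull parametrization from the Spark MLlib guide. Under that parametrization, `prediction` is the Weibull scale $\lambda = \exp(\beta^\top x + b)$, and the $p$-quantile is $\lambda \cdot (-\log(1-p))^{\sigma}$. `predict` and `predict_quantiles` are hypothetical helpers, not PySpark APIs.

```python
import math

# Fitted parameters copied from the printed output above
coefficients = [-0.4963111466650707, 0.19844437699933098]
intercept = 2.63809461510401
scale = 1.5472345574364692
quantile_probabilities = [0.3, 0.6]

def predict(features):
    # prediction = exp(beta . x + intercept), the Weibull scale parameter
    return math.exp(sum(c * x for c, x in zip(coefficients, features)) + intercept)

def predict_quantiles(features):
    # Weibull p-quantile with shape 1/scale: lambda * (-log(1 - p)) ** scale
    lam = predict(features)
    return [lam * (-math.log1p(-p)) ** scale for p in quantile_probabilities]

for features in ([1.56, -0.605], [0.346, 2.158], [1.38, 0.231],
                 [0.52, 1.151], [0.795, -0.226]):
    print(features, predict(features), predict_quantiles(features))
```

Since $-\log(1-p) = 1$ at $p = 1 - 1/e \approx 0.632$, `prediction` is itself the 63.2% quantile of the fitted distribution. That is why it exceeds the 0.6 quantile in every row.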

