In [None]:
# findspark.init() is called before the pyspark imports on purpose: it locates
# the local Spark installation and makes the pyspark package importable.
import findspark
findspark.init()
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import AFTSurvivalRegression

In [None]:
# Create the Spark session used by the cells below, or reuse one that is
# already running in this kernel.
spark = (
    SparkSession.builder
    .appName("Survival Regression")
    .getOrCreate()
)

### Logged Time To Failure Data

- `label` = observed time, in some unit (say months), until the equipment fails or the observation is censored
- `censor` = 1 means uncensored: the failure occurred, so `label` is the time to failure
- `censor` = 0 means censored: the failure has not occurred, so `label` is the time to another event, say maintenance
- `features` contains the feature columns, such as machine age and temperature. For a larger real-world example, see:

Heart Attack Study (WHAS500):
https://web.archive.org/web/20170517071528/http://www.umass.edu/statdata/statdata/data/whas500.txt


In [None]:
# Each record is (label, censor, features): the observed time, the censoring
# indicator, and a dense vector holding two feature values.
failure_records = [
    (1.218, 1.0, Vectors.dense(1.560, -0.605)),
    (2.949, 0.0, Vectors.dense(0.346, 2.158)),
    (3.627, 0.0, Vectors.dense(1.380, 0.231)),
    (0.273, 1.0, Vectors.dense(0.520, 1.151)),
    (4.199, 0.0, Vectors.dense(0.795, -0.226)),
]
training = spark.createDataFrame(failure_records).toDF("label", "censor", "features")

In [None]:
# Print the training rows so the label / censor / features columns can be checked.
training.show()

### Predict two quantiles of the time to failure: at 30% and at 60% probability

In [None]:
# Probability levels at which the model reports time-to-failure quantiles.
# The notebook text describes the 30% and 60% levels, so the second entry is
# 0.6; each value must lie strictly between 0 and 1.
quantileProbabilities = (0.3, 0.6)

### Train the model on the training data above with AFTSurvivalRegression

In [None]:
# Configure the AFT survival regression estimator: which columns hold the
# features, label and censoring flag, and where to write the quantiles.
aft = AFTSurvivalRegression(
    featuresCol="features",
    labelCol="label",
    censorCol="censor",
    quantileProbabilities=quantileProbabilities,
    quantilesCol="quantiles",
)

# Fit the estimator on the training DataFrame.
model = aft.fit(training)

### Print the coefficients, intercept and scale parameter for AFT survival regression

In [None]:
# Report the fitted parameters, one per line, in a fixed order.
for template, fitted_value in (
    ("Coefficients: {}", model.coefficients),
    ("Intercept: {}", model.intercept),
    ("Scale: {}", model.scale),
):
    print(template.format(fitted_value))


### Transform the data with the fitted model

- `prediction` = predicted time to failure when `censor` = 1 (uncensored)
- `prediction` = predicted time to another event, such as maintenance, when `censor` = 0 (censored)
- 1st element of `quantiles` = predicted time at the 30% probability level
- 2nd element of `quantiles` = predicted time at the 60% probability level

In [None]:
# Apply the fitted model to the training rows and show the result without
# truncating the column values.
predictions = model.transform(training)
predictions.show(truncate=False)