In [2]:
from pyspark.ml.classification import LogisticRegression
from pyspark.sql import SparkSession

if __name__ == "__main__":
    # Example script: fit two elastic-net regularized logistic regression
    # models (default family and multinomial family) on the same LIBSVM
    # dataset and print the learned parameters of each.
    spark = SparkSession\
        .builder\
        .appName("LogisticRegressionWithElasticNet")\
        .getOrCreate()

    # Everything that can fail (reading the file, fitting the models) runs
    # inside try/finally so the Spark session is always stopped, even when
    # an exception is raised part-way through.
    try:
        # Load training data (LIBSVM format, relative path)
        training = spark.read.format("libsvm").load("sample_libsvm_data.txt")

        lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)

        # Fit the model
        lrModel = lr.fit(training)

        # Print the coefficients and intercept for logistic regression
        print("Coefficients: " + str(lrModel.coefficients))
        print("Intercept: " + str(lrModel.intercept))

        # We can also use the multinomial family for binary classification.
        # Same hyperparameters as above; only the family differs.
        mlr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8, family="multinomial")

        # Fit the model
        mlrModel = mlr.fit(training)

        # Print the coefficients and intercepts for logistic regression with multinomial family
        print("Multinomial coefficients: " + str(mlrModel.coefficientMatrix))
        print("Multinomial intercepts: " + str(mlrModel.interceptVector))
    finally:
        # Release the session regardless of success or failure.
        spark.stop()

Coefficients: (692,[272,300,323,350,351,378,379,405,406,407,428,433,434,435,455,456,461,462,483,484,489,490,496,511,512,517,539,540,568],[-7.52068987138421e-05,-8.115773146847101e-05,3.814692771846369e-05,0.0003776490540424337,0.00034051483661944103,0.0005514455157343105,0.0004085386116096913,0.000419746733274946,0.0008119171358670028,0.0005027708372668751,-2.3929260406601844e-05,0.000574504802090229,0.0009037546426803721,7.818229700244018e-05,-2.1787551952912764e-05,-3.4021658217896256e-05,0.0004966517360637634,0.0008190557828370367,-8.017982139522704e-05,-2.7431694037836214e-05,0.0004810832226238988,0.00048408017626778765,-8.926472920011488e-06,-0.00034148812330427335,-8.950592574121486e-05,0.00048645469116892167,-8.478698005186209e-05,-0.0004234783215831763,-7.29653577763134e-05])
Intercept: -0.5991460286401435
Multinomial coefficients: 2 X 692 CSRMatrix
(0,272) 0.0001
(0,300) 0.0001
(0,350) -0.0002
(0,351) -0.0001
(0,378) -0.0003
(0,379) -0.0002
(0,405) -0.0002
(0,406) -0.0004
(0,4