In [1]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.master('local[1]').appName('GroundStateEnergies').getOrCreate()

In [31]:
spark

In [17]:
base_cols = ['pubchem_id', 'Eat']
num_cols = 5
selected_cols = [str(i) for i in range(num_cols)]
selected_cols.extend(base_cols)
data = spark.read.csv('ground_state_energies/roboBohr.csv', header=True, inferSchema=True).select(selected_cols)
data.show(3)

+-----------------+------------------+------------------+------------------+------------------+----------+-------------------+
|                0|                 1|                 2|                 3|                 4|pubchem_id|                Eat|
+-----------------+------------------+------------------+------------------+------------------+----------+-------------------+
|73.51669471981023| 17.81776508458945| 12.46955101471414|12.458130159135047|12.454607457208622|     25004|-19.013762529999866|
|73.51669471981023|20.649126049739863|18.527788932970484|17.891534524101264|17.887995075841932|     25005|-10.161019429999953|
|73.51669471981023|17.830377224933926| 12.51226341499877| 12.40477532898951|12.394492613618938|     25006| -9.376619229999946|
+-----------------+------------------+------------------+------------------+------------------+----------+-------------------+
only showing top 3 rows



In [18]:
data.printSchema()

root
 |-- 0: double (nullable = true)
 |-- 1: double (nullable = true)
 |-- 2: double (nullable = true)
 |-- 3: double (nullable = true)
 |-- 4: double (nullable = true)
 |-- pubchem_id: integer (nullable = true)
 |-- Eat: double (nullable = true)



In [19]:
from pyspark.ml.feature import VectorAssembler
featureassembler = VectorAssembler(inputCols=[str(i) for i in range(num_cols)], outputCol="IndependentFeature")

In [20]:
train_data = featureassembler.transform(data)

In [23]:
train_data.select('IndependentFeature').show(4)

+--------------------+
|  IndependentFeature|
+--------------------+
|[73.5166947198102...|
|[73.5166947198102...|
|[73.5166947198102...|
|[73.5166947198102...|
+--------------------+
only showing top 4 rows



In [24]:
finalized_data = train_data.select('IndependentFeature', 'Eat')
finalized_data.show(3)

+--------------------+-------------------+
|  IndependentFeature|                Eat|
+--------------------+-------------------+
|[73.5166947198102...|-19.013762529999866|
|[73.5166947198102...|-10.161019429999953|
|[73.5166947198102...| -9.376619229999946|
+--------------------+-------------------+
only showing top 3 rows



In [25]:
from pyspark.ml.regression import LinearRegression
train_set, valid_set = finalized_data.randomSplit([0.8, 0.2])
regressor = LinearRegression(featuresCol='IndependentFeature', labelCol='Eat')
regressor=regressor.fit(train_set)

In [26]:
### Coefficient
regressor.coefficients

DenseVector([0.0016, 0.106, 0.0586, -0.0308, -0.2108])

In [27]:
### Intercepts
regressor.intercept

-10.748384396158793

In [28]:
### Predictions
pred_results = regressor.evaluate(train_set)

In [29]:
pred_results.predictions.show()

+--------------------+-------------------+-------------------+
|  IndependentFeature|                Eat|         prediction|
+--------------------+-------------------+-------------------+
|(5,[0,1],[53.3587...| -1.207594010000001| -8.194471994350016|
|(5,[0,1],[73.5166...|-0.9827892500000033| -8.087182287021717|
|[36.8581051994259...|-1.7550030800000034|-10.914281554584761|
|[36.8581051994259...|-1.7550030800000034|-10.914281554584761|
|[36.8581051994259...| -10.17353853999991|-11.638754093133144|
|[36.8581051994259...| -4.736708110000016|-11.641543549040769|
|[36.8581051994259...| -17.34876028999986|-11.633380546101439|
|[36.8581051994259...| -10.95664080999994|-11.644442393117687|
|[36.8581051994259...|-12.598071529999913| -11.64495306803491|
|[36.8581051994259...|  -8.97782656999992|-11.631240233878282|
|[36.8581051994259...|-6.5794563800000105|-11.646494478660287|
|[36.8581051994259...| -9.066991039999934|-11.641834049884503|
|[36.8581051994259...|  -11.3647322099999|-11.638318917

In [30]:
pred_results.meanAbsoluteError, pred_results.meanSquaredError

(2.8628827290379704, 12.77370833025871)