### 🔹 What is `corr()` in PySpark?

In PySpark, `DataFrame.stat.corr()` (also available as `DataFrame.corr()`) computes the Pearson correlation coefficient between two numeric columns of a DataFrame.
It measures the strength and direction of the linear relationship between the columns and returns a single float between -1 and 1:

- **+1** → perfect positive correlation
- **-1** → perfect negative correlation
- **0** → no linear correlation
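The plain-Python sketch below shows what the coefficient means. It implements the Pearson formula: the sum of products of deviations from the means, divided by the square root of the product of the two sums of squared deviations. The `pearson` helper is illustrative only. It is not Spark's implementation and needs no Spark session. The last call shows why 0 means *no linear* correlation: y = (x - 3)² depends perfectly on x, yet r is 0.

```python
import math


def pearson(xs, ys):
    """Illustrative Pearson correlation of two equal-length numeric sequences."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two sequences of the same length (>= 2)")
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Sum of products of deviations (proportional to the covariance)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    # Sums of squared deviations (proportional to the variances)
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    # The 1/n factors cancel, so they are omitted.
    # Undefined (ZeroDivisionError) if either sequence is constant.
    return cov / math.sqrt(var_x * var_y)


x = [1, 2, 3, 4, 5]
print(pearson(x, [2 * v for v in x]))         # 1.0  (perfect positive)
print(pearson(x, [-3 * v + 7 for v in x]))    # -1.0 (perfect negative)
print(pearson(x, [(v - 3) ** 2 for v in x]))  # 0.0  (no linear relationship)
```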

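### 🔹 Syntax

```python
DataFrame.stat.corr(col1, col2, method=None)
```

Parameters:

- `col1`, `col2` → names of the two columns (must be numeric)
- `method` → optional; only `"pearson"` (the default) is currently supported. For Spearman correlation, use `Correlation.corr()` from `pyspark.ml.stat`.

Returns: the correlation coefficient as a Python `float`.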




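```python
# Create a sample DataFrame (on Databricks, `spark` is predefined)
data = [
    (1, 10, 100),
    (2, 20, 200),
    (3, 30, 300),
    (4, 40, 400),
    (5, 50, 500)
]

columns = ["id", "sales", "profit"]

df = spark.createDataFrame(data, columns)

df.display()  # Databricks-specific; use df.show() outside Databricks
```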

| id | sales | profit |
| -- | ----- | ------ |
| 1  | 10    | 100    |
| 2  | 20    | 200    |
| 3  | 30    | 300    |
| 4  | 40    | 400    |
| 5  | 50    | 500    |


```python
# Find the correlation between sales and profit
corr_value = df.stat.corr("sales", "profit")

print("Correlation between Sales and Profit:", corr_value)
```


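```text
Correlation between Sales and Profit: 1.0
```

The result is exactly 1.0 because every `profit` value is 10 × `sales`: a perfect positive linear relationship.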


```python
# A second dataset where x and y are only partly related
data2 = [
    (1, 10, 30),
    (2, 15, 10),
    (3, 20, 25),
    (4, 25, 40),
    (5, 30, 35)
]

df2 = spark.createDataFrame(data2, ["id", "x", "y"])

df2.display()

# Compute correlation
print("Correlation between x and y:", df2.stat.corr("x", "y"))
```

| id | x  | y  |
| -- | -- | -- |
| 1  | 10 | 30 |
| 2  | 15 | 10 |
| 3  | 20 | 25 |
| 4  | 25 | 40 |
| 5  | 30 | 35 |


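```text
Correlation between x and y: 0.5494422557947561
```

A value of about 0.55 indicates a moderate positive linear relationship: y tends to increase with x, but the points do not lie on a straight line.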
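The 0.549 can be checked by hand with the Pearson formula. Here $\bar{x} = 20$ and $\bar{y} = 28$. The deviations are $x_i - \bar{x} = (-10, -5, 0, 5, 10)$ and $y_i - \bar{y} = (2, -18, -3, 12, 7)$, so:

$$
r = \frac{\sum_i (x_i-\bar{x})(y_i-\bar{y})}{\sqrt{\sum_i (x_i-\bar{x})^2 \,\sum_i (y_i-\bar{y})^2}}
  = \frac{-20 + 90 + 0 + 60 + 70}{\sqrt{250 \times 530}}
  = \frac{200}{\sqrt{132500}} \approx 0.5494
$$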


```python
from pyspark.ml.stat import Correlation
from pyspark.ml.feature import VectorAssembler

# Convert columns to a single vector column
assembler = VectorAssembler(inputCols=["sales", "profit"], outputCol="features")
df_vector = assembler.transform(df)

# Compute correlation matrix
corr_matrix = Correlation.corr(df_vector, "features").head()[0]
print("Correlation matrix:\n", corr_matrix.toArray())
```


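```text
Correlation matrix:
 [[1. 1.]
 [1. 1.]]
```

Entry `[i, j]` is the correlation between the i-th and j-th columns passed to `VectorAssembler` (here `sales` and `profit`). The matrix is therefore symmetric with 1s on the diagonal, and the off-diagonal 1.0 matches the `df.stat.corr()` result.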


### 🔹 Key Takeaways
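| Use Case                                 | Method                                                                                  |
| ---------------------------------------- | --------------------------------------------------------------------------------------- |
| Find the correlation between two columns | `df.stat.corr("col1", "col2")` (alias: `df.corr("col1", "col2")`)                       |
| Get a full correlation matrix            | `Correlation.corr()` from `pyspark.ml.stat`, on a vector column built with `VectorAssembler` |
| Default correlation type                 | Pearson                                                                                 |
| Spearman (rank) correlation              | `Correlation.corr(df_vector, "features", "spearman")`; `df.stat.corr` supports only Pearson |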
