In PySpark, randomSplit() splits a DataFrame (or RDD) into multiple random, non-overlapping parts according to a list of weights, which makes it handy for creating train/test datasets in ML workflows.

In [0]:
# Sample dataset: 15 rows
data = [
    (1, "Alice", 25, "HR"),
    (2, "Bob", 30, "IT"),
    (3, "Cathy", 28, "Finance"),
    (4, "David", 35, "IT"),
    (5, "Eva", 40, "HR"),
    (6, "Frank", 29, "Finance"),
    (7, "Grace", 31, "IT"),
    (8, "Helen", 27, "Finance"),
    (9, "Ian", 26, "HR"),
    (10, "Jack", 33, "IT"),
    (11, "Kelly", 34, "Finance"),
    (12, "Leo", 32, "IT"),
    (13, "Mona", 29, "HR"),
    (14, "Nina", 36, "Finance"),
    (15, "Oscar", 38, "IT")
]

columns = ["id", "name", "age", "department"]

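# `spark` is the SparkSession that notebook environments such as Databricks predefine
df = spark.createDataFrame(data, columns)

# display() is a Databricks notebook helper; use df.show() in plain PySpark
df.display()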

id,name,age,department
1,Alice,25,HR
2,Bob,30,IT
3,Cathy,28,Finance
4,David,35,IT
5,Eva,40,HR
6,Frank,29,Finance
7,Grace,31,IT
8,Helen,27,Finance
9,Ian,26,HR
10,Jack,33,IT


In [0]:
# 80/20 split; the weights are expected proportions, so the counts below are approximate
trainDF, testDF = df.randomSplit([0.8, 0.2], seed=42)

print("Training set count:", trainDF.count())
print("Test set count:", testDF.count())


Training set count: 13
Test set count: 2


In [0]:
print("Training Data:")
trainDF.display()

print("Test Data:")
testDF.display()

Training Data:


id,name,age,department
1,Alice,25,HR
3,Cathy,28,Finance
4,David,35,IT
6,Frank,29,Finance
7,Grace,31,IT
8,Helen,27,Finance
9,Ian,26,HR
10,Jack,33,IT
11,Kelly,34,Finance
12,Leo,32,IT


Test Data:


id,name,age,department
2,Bob,30,IT
5,Eva,40,HR


In [0]:
trainDF, valDF, testDF = df.randomSplit([0.7, 0.2, 0.1], seed=123)

print("Train:", trainDF.count())
print("Validation:", valDF.count())
print("Test:", testDF.count())


Train: 14
Validation: 0
Test: 1
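
The empty validation set above is not a bug. Here is a minimal sketch of why, assuming each row is assigned to a split by its own independent random draw. This is a simplified model of random splitting, not Spark's actual implementation: `simulate_random_split` and `prob_empty` are hypothetical helpers written for this illustration, and Python's `random` module will not reproduce Spark's assignments for the same seed.

```python
import random


def simulate_random_split(rows, weights, seed):
    """Assign each row to one split using one uniform draw per row.

    Simplified model for illustration only; not Spark's code.
    """
    total = sum(weights)
    # Cumulative upper bounds of each split's slice of [0, 1).
    bounds = []
    running = 0.0
    for w in weights:
        running += w / total
        bounds.append(running)
    bounds[-1] = 1.0  # guard against floating-point drift

    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        x = rng.random()  # uniform draw in [0, 1)
        for i, upper in enumerate(bounds):
            if x < upper:
                splits[i].append(row)
                break
    return splits


def prob_empty(weight, n_rows):
    """Chance that a split with normalized weight `weight` receives no rows."""
    return (1.0 - weight) ** n_rows


rows = list(range(1, 16))  # stand-in for the 15 row ids
train, val, test = simulate_random_split(rows, [0.7, 0.2, 0.1], seed=123)
print(len(train), len(val), len(test))  # sizes vary with the seed

print(round(prob_empty(0.2, 15), 4))  # 0.0352
print(round(prob_empty(0.1, 15), 4))  # 0.2059
```

Under this model, a 20% split of 15 rows comes up empty about 3.5% of the time and a 10% split about 21% of the time, so very small datasets are a poor fit for weight-based splitting.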


### ✅ Key Notes

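The weights don't need to sum to 1; PySpark normalizes them internally. They are expected proportions, not exact ones: rows are assigned at random, so on a small dataset the actual counts can differ noticeably from the weights (13/2 rather than 12/3 above, and an empty validation set in the three-way split).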

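Always set a seed for reproducible splits. Reproducibility also assumes the input DataFrame is computed the same way on each run, with the same partitioning and row order.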

randomSplit() returns a list of DataFrames, one per weight and in the same order, so unpack it into as many variables as there are weights.
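
To illustrate the normalization note, here is a small sketch in plain Python. `normalize_weights` is a hypothetical helper, not part of the PySpark API; it only shows the arithmetic of dividing each weight by the total, and the input validation is this sketch's own choice.

```python
def normalize_weights(weights):
    """Scale non-negative weights so they sum to 1 (illustrative helper)."""
    if any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative")
    total = float(sum(weights))
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    return [w / total for w in weights]


print(normalize_weights([8, 2]))  # [0.8, 0.2]
print(normalize_weights([3, 1]))  # [0.75, 0.25]
```

So weights such as [8.0, 2.0] describe the same expected 80/20 split as [0.8, 0.2].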