In [0]:
from pyspark import SparkConf
from pyspark.sql.session import SparkSession
from pyspark.sql.types import StructField,StructType,IntegerType
from pyspark.sql.functions import col,when

# Start (or reuse, if one already exists) a local Spark session.
spark = (
    SparkSession.builder
    .master("local[3]")
    .appName("app")
    .getOrCreate()
)

In [0]:
# Parent-pointer representation of a tree: one row per node, holding the
# node's id and its parent's id (null where the node has no parent).
schema = StructType(
    [
        StructField("id", IntegerType(), nullable=False),
        StructField("p_id", IntegerType(), nullable=True),
    ]
)
data = [
    (1, None),
    (2, 1),
    (3, 1),
    (4, 2),
    (5, 2),
]
tree = spark.createDataFrame(data, schema=schema)
tree.show()

+---+----+
| id|p_id|
+---+----+
|  1|null|
|  2|   1|
|  3|   1|
|  4|   2|
|  5|   2|
+---+----+



In [0]:
# Classify every node as Root / Inner / Leaf without pulling data to the driver.
# Collecting the parent ids and inlining them into an isin() list only works
# while that list fits in driver memory; a left join against the distinct,
# non-null parent ids expresses the same membership test and stays distributed.
parent_ids = tree.select(col("p_id").alias("parent_id")).dropna().distinct()

node_types = (
    # parent_ids is distinct, so the left join never duplicates a node row.
    tree.join(parent_ids, col("id") == col("parent_id"), "left")
    .select(
        "id",
        # Branch order matters: a node without a parent is Root even if it
        # also appears as somebody's parent.
        when(col("p_id").isNull(), "Root")
        .when(col("parent_id").isNotNull(), "Inner")
        .otherwise("Leaf")
        .alias("type"),
    )
    # A join gives no row-order guarantee; sort so the displayed table is stable.
    .orderBy("id")
)
node_types.show()

+---+-----+
| id| type|
+---+-----+
|  1| Root|
|  2|Inner|
|  3| Leaf|
|  4| Leaf|
|  5| Leaf|
+---+-----+



In [0]:
# Classify nodes as Root / Inner / Leaf with Spark SQL over a temp view of `tree`.
tree.createOrReplaceTempView("tree")

# Adjacent string literals are concatenated into a single one-line query.
node_type_query = (
    "select id, "
    "case when p_id is null then 'Root' "
    "when `id` in (select p_id from tree where p_id is not null) then 'Inner' "
    "else 'Leaf' end as type "
    "from tree"
)
spark.sql(node_type_query).show()

+---+-----+
| id| type|
+---+-----+
|  1| Root|
|  2|Inner|
|  3| Leaf|
|  4| Leaf|
|  5| Leaf|
+---+-----+



In [0]:
# Shut down the Spark session now that the notebook is finished with it.
spark.stop()