In [1]:
import findspark

In [2]:
import os

# Point findspark at the Spark installation. Honour the SPARK_HOME environment variable when it
# is set so the notebook can run on other machines; otherwise fall back to the original local
# install path used when this notebook was recorded.
findspark.init(os.environ.get('SPARK_HOME', '/home/lakshmi/spark-2.4.5-bin-hadoop2.7'))

In [3]:
import pyspark

Create a Spark session

In [4]:
from pyspark.sql import SparkSession

In [5]:
spark = SparkSession.builder.appName("My_Spark").getOrCreate()

Define a schema for reading the input file

In [6]:
from pyspark.sql.types import StringType, DecimalType, FloatType, IntegerType, StructField, StructType

In [7]:
# Explicit schema for popdensity.csv, declared as (column name, Spark type, nullable) triples.
# Note: the printSchema() output further down reports every column as nullable = true, so the
# False flags on country/area are not reflected in the loaded DataFrame.
data_schema = StructType([
    StructField(col_name, col_type, is_nullable)
    for col_name, col_type, is_nullable in (
        ("country", StringType(), False),
        ("area", DecimalType(10,2), False),
        ("female", IntegerType(), True),
        ("male", IntegerType(), True),
        ("population", IntegerType(), True),
        ("density", FloatType(), True),
    )
])

In [8]:
data_schema.fieldNames()

['country', 'area', 'female', 'male', 'population', 'density']

Read the data from a CSV into a DataFrame

In [9]:
df = spark.read.csv('/media/sf_Ubuntu_Shared/popdensity.csv', schema=data_schema, header=True)

In [10]:
type(df)

pyspark.sql.dataframe.DataFrame

In [11]:
df.printSchema()

root
 |-- country: string (nullable = true)
 |-- area: decimal(10,2) (nullable = true)
 |-- female: integer (nullable = true)
 |-- male: integer (nullable = true)
 |-- population: integer (nullable = true)
 |-- density: float (nullable = true)



In [12]:
df.columns

['country', 'area', 'female', 'male', 'population', 'density']

In [13]:
df.show(5)

+---------------+--------+------+----+----------+----------+
|        country|    area|female|male|population|   density|
+---------------+--------+------+----+----------+----------+
|BadenWrttemberg|35751.65|  5465|5271|     10736|0.30029383|
|         Bayern|70551.57|  6366|6103|     12469|0.17673597|
|         Berlin|  891.85|  1736|1660|      3396| 3.8078153|
|    Brandenburg|29478.61|  1293|1267|      2560|0.08684263|
|         Bremen|  404.28|   342| 321|       663| 1.6399525|
+---------------+--------+------+----+----------+----------+
only showing top 5 rows



## Accessing Data from a DataFrame

In [14]:
df['country'] # this returns a column

Column<b'country'>

In [15]:
type(df['country'])

pyspark.sql.column.Column

In [16]:
df.country

Column<b'country'>

In [17]:
df.select('country').show()

+--------------------+
|             country|
+--------------------+
|     BadenWrttemberg|
|              Bayern|
|              Berlin|
|         Brandenburg|
|              Bremen|
|             Hamburg|
|              Hessen|
|Mecklenburg-Vorpo...|
|       Niedersachsen|
| Nordrhein-Westfalen|
|     Rheinland-Pfalz|
|            Saarland|
|             Sachsen|
|      Sachsen-Anhalt|
|  Schleswig-Holstein|
|           Thuringen|
+--------------------+



In [18]:
type(df.select('country')) # select() method returns a dataframe

pyspark.sql.dataframe.DataFrame

In [19]:
df.select('country', 'population').show() # selecting multiple columns

+--------------------+----------+
|             country|population|
+--------------------+----------+
|     BadenWrttemberg|     10736|
|              Bayern|     12469|
|              Berlin|      3396|
|         Brandenburg|      2560|
|              Bremen|       663|
|             Hamburg|      1743|
|              Hessen|      6092|
|Mecklenburg-Vorpo...|      1707|
|       Niedersachsen|      7994|
| Nordrhein-Westfalen|     18058|
|     Rheinland-Pfalz|      4059|
|            Saarland|      1050|
|             Sachsen|      4274|
|      Sachsen-Anhalt|      2470|
|  Schleswig-Holstein|      2833|
|           Thuringen|      2335|
+--------------------+----------+



In [20]:
df.select(['country', 'population'])

DataFrame[country: string, population: int]

In [21]:
# Derived column: arithmetic on a Column yields a new Column, and alias() gives it a display name.
# NOTE(review): judging by the values (e.g. 10736 for BadenWrttemberg), `population` may already be
# expressed in thousands, in which case this quotient is in millions and the alias is misleading -- confirm.
df.select(df.country, (df.population/1000).alias('Population in Thousands')).show()

+--------------------+-----------------------+
|             country|Population in Thousands|
+--------------------+-----------------------+
|     BadenWrttemberg|                 10.736|
|              Bayern|                 12.469|
|              Berlin|                  3.396|
|         Brandenburg|                   2.56|
|              Bremen|                  0.663|
|             Hamburg|                  1.743|
|              Hessen|                  6.092|
|Mecklenburg-Vorpo...|                  1.707|
|       Niedersachsen|                  7.994|
| Nordrhein-Westfalen|                 18.058|
|     Rheinland-Pfalz|                  4.059|
|            Saarland|                   1.05|
|             Sachsen|                  4.274|
|      Sachsen-Anhalt|                   2.47|
|  Schleswig-Holstein|                  2.833|
|           Thuringen|                  2.335|
+--------------------+-----------------------+



In [22]:
# head(n) returns the first n rows as a list of Row objects on the driver.
# Only the n requested rows are brought to the driver, so keep n small; it is collect() that
# pulls the entire DataFrame into the driver's memory.
df.head(5)

[Row(country='BadenWrttemberg', area=Decimal('35751.65'), female=5465, male=5271, population=10736, density=0.30029383301734924),
 Row(country='Bayern', area=Decimal('70551.57'), female=6366, male=6103, population=12469, density=0.17673596739768982),
 Row(country='Berlin', area=Decimal('891.85'), female=1736, male=1660, population=3396, density=3.8078153133392334),
 Row(country='Brandenburg', area=Decimal('29478.61'), female=1293, male=1267, population=2560, density=0.0868426263332367),
 Row(country='Bremen', area=Decimal('404.28'), female=342, male=321, population=663, density=1.639952540397644)]

In [24]:
# head(1) returns a *list* holding one Row (see the output below), hence the [0] indexing in the
# cells that follow. Calling head() with no argument returns the Row itself.
single_row = df.head(1)

In [27]:
single_row

[Row(country='BadenWrttemberg', area=Decimal('35751.65'), female=5465, male=5271, population=10736, density=0.30029383301734924)]

In [30]:
single_row[0]['population'] # accessing a column from the row

10736

In [31]:
single_row[0]['area'] # accessing a column from the row

Decimal('35751.65')

## Adding / renaming columns

In [32]:
df1 = df.withColumn('diff', df.female-df.male) #adding a new column

In [33]:
df1.show()

+--------------------+--------+------+----+----------+----------+----+
|             country|    area|female|male|population|   density|diff|
+--------------------+--------+------+----+----------+----------+----+
|     BadenWrttemberg|35751.65|  5465|5271|     10736|0.30029383| 194|
|              Bayern|70551.57|  6366|6103|     12469|0.17673597| 263|
|              Berlin|  891.85|  1736|1660|      3396| 3.8078153|  76|
|         Brandenburg|29478.61|  1293|1267|      2560|0.08684263|  26|
|              Bremen|  404.28|   342| 321|       663| 1.6399525|  21|
|             Hamburg|  755.16|   894| 849|      1743| 2.3081203|  45|
|              Hessen|21114.79|  3109|2983|      6092|0.28851813| 126|
|Mecklenburg-Vorpo...|23180.14|   861| 846|      1707|0.07364062|  15|
|       Niedersachsen|47624.20|  4076|3918|      7994|0.16785584| 158|
| Nordrhein-Westfalen|34085.29|  9261|8797|     18058| 0.5297887| 464|
|     Rheinland-Pfalz|19853.36|  2069|1990|      4059|0.20444901|  79|
|     

In [34]:
df1.withColumnRenamed('diff', 'diff_in_population').show()

+--------------------+--------+------+----+----------+----------+------------------+
|             country|    area|female|male|population|   density|diff_in_population|
+--------------------+--------+------+----+----------+----------+------------------+
|     BadenWrttemberg|35751.65|  5465|5271|     10736|0.30029383|               194|
|              Bayern|70551.57|  6366|6103|     12469|0.17673597|               263|
|              Berlin|  891.85|  1736|1660|      3396| 3.8078153|                76|
|         Brandenburg|29478.61|  1293|1267|      2560|0.08684263|                26|
|              Bremen|  404.28|   342| 321|       663| 1.6399525|                21|
|             Hamburg|  755.16|   894| 849|      1743| 2.3081203|                45|
|              Hessen|21114.79|  3109|2983|      6092|0.28851813|               126|
|Mecklenburg-Vorpo...|23180.14|   861| 846|      1707|0.07364062|                15|
|       Niedersachsen|47624.20|  4076|3918|      7994|0.16785584|

## SQL commands using dataframe

In [35]:
df.createOrReplaceTempView("Population")

In [36]:
spark.sql("SELECT * FROM Population WHERE population>10000").show()

+-------------------+--------+------+----+----------+----------+
|            country|    area|female|male|population|   density|
+-------------------+--------+------+----+----------+----------+
|    BadenWrttemberg|35751.65|  5465|5271|     10736|0.30029383|
|             Bayern|70551.57|  6366|6103|     12469|0.17673597|
|Nordrhein-Westfalen|34085.29|  9261|8797|     18058| 0.5297887|
+-------------------+--------+------+----+----------+----------+



In [37]:
df.count()

16

In [38]:
L1 = df.collect() # collect() returns the rows of dataframe as a list

In [39]:
len(L1) # note that this is same as df.count()

16

In [40]:
L1[0] # first row

Row(country='BadenWrttemberg', area=Decimal('35751.65'), female=5465, male=5271, population=10736, density=0.30029383301734924)

## Filtering data based on conditions

In [41]:
from pyspark.sql.types import DateType

In [42]:
# Explicit schema for appl_stock.csv: one (column name, Spark type) pair per CSV column,
# every column nullable.
schema2 = StructType([
    StructField(col_name, col_type, True)
    for col_name, col_type in (
        ("Date", DateType()),
        ("Open", FloatType()),
        ("High", FloatType()),
        ("Low", FloatType()),
        ("Close", FloatType()),
        ("Volume", IntegerType()),
        ("Adj_Close", FloatType()),
    )
])

In [50]:
# Read the stock data using the explicit schema2 defined above.
# NOTE(review): the original remark on this line said the DateType format was not working, but the
# printSchema()/show() output below lists Date as `date` with yyyy-MM-dd values -- that remark
# looks outdated; confirm. This cell also rebinds df1, which earlier held the population frame
# with the extra `diff` column.
df1 = spark.read.csv(path="/media/sf_Ubuntu_Shared/appl_stock.csv", schema=schema2, header=True)

In [47]:
# Alternative read that lets Spark infer the column types instead of using schema2.
# NOTE(review): the execution counts show this cell (In [47]) ran BEFORE the explicit-schema
# read (In [50]), and the date/float schema printed below matches schema2. On a top-to-bottom
# re-run this cell would rebind df1 last, so the outputs that follow may not reproduce -- confirm.
df1 = spark.read.csv(path="/media/sf_Ubuntu_Shared/appl_stock.csv", header=True, inferSchema=True)

In [51]:
df1.printSchema()

root
 |-- Date: date (nullable = true)
 |-- Open: float (nullable = true)
 |-- High: float (nullable = true)
 |-- Low: float (nullable = true)
 |-- Close: float (nullable = true)
 |-- Volume: integer (nullable = true)
 |-- Adj_Close: float (nullable = true)



In [52]:
df1.show(5)

+----------+---------+------+------+---------+---------+---------+
|      Date|     Open|  High|   Low|    Close|   Volume|Adj_Close|
+----------+---------+------+------+---------+---------+---------+
|2010-01-04|   213.43| 214.5|212.38|   214.01|123432400| 27.72704|
|2010-01-05|214.59999|215.59|213.25|214.37999|150476200|27.774977|
|2010-01-06|214.37999|215.23|210.75|   210.97|138040000|27.333178|
|2010-01-07|   211.75| 212.0|209.05|   210.58|119282800| 27.28265|
|2010-01-08|210.29999| 212.0|209.06|211.98001|111902700|27.464033|
+----------+---------+------+------+---------+---------+---------+
only showing top 5 rows



In [53]:
df1.count()

1762

In [54]:
df1.describe().show()

+-------+------------------+-----------------+------------------+------------------+-------------------+------------------+
|summary|              Open|             High|               Low|             Close|             Volume|         Adj_Close|
+-------+------------------+-----------------+------------------+------------------+-------------------+------------------+
|  count|              1762|             1762|              1762|              1762|               1762|              1762|
|   mean|313.07631053340015|315.9112879420788| 309.8282404974289| 312.9270658330668|9.422577587968218E7| 75.00174113976158|
| stddev|185.29946731081264|186.8981766989906|183.38391663940038|185.14710364848838|6.020518776592709E7|28.574929738834484|
|    min|              90.0|             90.7|             89.47|             90.28|           11475900|         24.881912|
|    max|            702.41|           705.07|            699.57|         702.10004|          470249500|         127.96609|
+-------

Filtering using a SQL expression string

In [56]:
df1.filter("Open > 500").show(5)

+----------+---------+---------+---------+---------+---------+---------+
|      Date|     Open|     High|      Low|    Close|   Volume|Adj_Close|
+----------+---------+---------+---------+---------+---------+---------+
|2012-02-14|504.65997|509.56003|    502.0|   509.46|115099600| 66.00541|
|2012-02-15|   514.26|526.29004|496.88998|497.66998|376530000|  64.4779|
|2012-02-17|   503.11|507.77002|    500.3|   502.12|133951300| 65.05444|
|2012-02-21|   506.88|514.85004|   504.12|514.85004|151398800|66.703735|
|2012-02-22|   513.08|   515.49|509.07004|   513.04|120825600| 66.46923|
+----------+---------+---------+---------+---------+---------+---------+
only showing top 5 rows



In [57]:
df1.filter("Open > 500").count() # using SQL format expression

401

In [58]:
df1.filter(df1['Open'] > 500).count() # using python expression

401

In [59]:
df1.filter( (df1['Open'] >= 500) & (df1['Open'] < 510) ).count()

25

In [62]:
# PITFALL: Python's `and` does not combine the two conditions. It evaluates to its second operand
# when the first is truthy (any non-empty string is), so this call is really df1.filter("Open < 510"),
# which explains the count of 1386 instead of 25.
# Correct forms: a single SQL string "Open >= 500 and Open < 510", or Column expressions joined with `&`.
df1.filter( ("Open >= 500") and ("Open < 510") ).count() #Incorrect results

1386

In [65]:
df1.select('Date', 'Open', 'Close').filter("Open >= 500").count()

401

In [67]:
 #Unlike SQL queries the order of select and where clauses can be interchanged
df1.filter("Open >= 500").select('Date', 'Open', 'Close').count()

401

In [69]:
df1.createOrReplaceTempView("appleStock")

In [70]:
spark.sql("select * from appleStock where (Open >= 500) and (Open < 510)").count() # SQL query works

25

In [71]:
results = df1.filter( (df1['Open'] >= 500) & (df1['Open'] < 510) ).collect() # return a list of rows

In [72]:
type(results)

list

In [73]:
len(results)

25

In [74]:
results

[Row(Date=datetime.date(2012, 2, 14), Open=504.65997314453125, High=509.5600280761719, Low=502.0, Close=509.4599914550781, Volume=115099600, Adj_Close=66.00540924072266),
 Row(Date=datetime.date(2012, 2, 17), Open=503.1099853515625, High=507.77001953125, Low=500.29998779296875, Close=502.1199951171875, Volume=133951300, Adj_Close=65.054443359375),
 Row(Date=datetime.date(2012, 2, 21), Open=506.8800048828125, High=514.8500366210938, Low=504.1199951171875, Close=514.8500366210938, Volume=151398800, Adj_Close=66.7037353515625),
 Row(Date=datetime.date(2012, 12, 17), Open=508.92999267578125, High=520.0, Low=501.2300109863281, Close=518.8299560546875, Volume=189401800, Adj_Close=67.81631469726562),
 Row(Date=datetime.date(2013, 1, 14), Open=502.6800231933594, High=507.5, Low=498.510009765625, Close=501.75, Volume=183551900, Adj_Close=65.58379364013672),
 Row(Date=datetime.date(2013, 1, 22), Open=504.5600280761719, High=507.8799743652344, Low=496.6300048828125, Close=504.7699890136719, Volum

In [75]:
results[0]

Row(Date=datetime.date(2012, 2, 14), Open=504.65997314453125, High=509.5600280761719, Low=502.0, Close=509.4599914550781, Volume=115099600, Adj_Close=66.00540924072266)

In [76]:
type(results[0])

pyspark.sql.types.Row

In [77]:
results[0]['Open'] #accessing an element of the row

504.65997314453125

In [78]:
results[0].High # another method of accessing row element

509.5600280761719

In [53]:
# asDict() converts the Row into a plain Python dict keyed by column name.
# NOTE(review): the output shown below comes from an earlier run (In [53]): its 'Adj Close' key and
# string date do not match results[0] displayed above (Adj_Close, datetime.date) -- re-run to refresh.
results[0].asDict()

{'Date': '14/02/12',
 'Open': 504.659988,
 'High': 509.56002,
 'Low': 502.000008,
 'Close': 509.459991,
 'Volume': 115099600,
 'Adj Close': 66.005408}

## GroupBy and Aggregate functions

In [79]:
df2 = spark.read.csv("/media/sf_Ubuntu_Shared/sales_details.csv", inferSchema=True, header=True)

In [80]:
df2.printSchema()

root
 |-- Company: string (nullable = true)
 |-- Geo: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- Sales: integer (nullable = true)
 |-- Profit: integer (nullable = true)



In [81]:
df2.groupBy('Company') #groupBy returns a GroupedData object

<pyspark.sql.group.GroupedData at 0x7fc623a735f8>

In [82]:
df2.groupBy('Company').sum('Sales', 'Profit') #aggregate function returns a dataframe

DataFrame[Company: string, sum(Sales): bigint, sum(Profit): bigint]

In [83]:
df2.groupBy('Company').sum('Sales', 'Profit').show()

+-------+----------+-----------+
|Company|sum(Sales)|sum(Profit)|
+-------+----------+-----------+
|   APPL|      3700|        441|
|     FB|      2388|        535|
| GOOGLE|      1961|        291|
+-------+----------+-----------+



In [84]:
df2.groupBy('Company', 'Geo').mean('Sales', 'Profit').sort('Company', 'Geo').show() # groupBy multiple columns

+-------+----+------------------+------------------+
|Company| Geo|        avg(Sales)|       avg(Profit)|
+-------+----+------------------+------------------+
|   APPL|APAC|             340.0|45.666666666666664|
|   APPL|EMEA| 503.3333333333333|59.666666666666664|
|   APPL|  US|             390.0|41.666666666666664|
|     FB|APAC|             173.0|              75.0|
|     FB|EMEA| 345.3333333333333|              68.0|
|     FB|  US| 277.6666666666667|35.333333333333336|
| GOOGLE|APAC|190.66666666666666|              43.0|
| GOOGLE|EMEA| 278.6666666666667|32.333333333333336|
| GOOGLE|  US|184.33333333333334|21.666666666666668|
+-------+----+------------------+------------------+



In [85]:
df2.groupBy('Company', 'Geo').mean('Sales', 'Profit').sort('avg(Profit)').show() # sorting by a column not grouped by

+-------+----+------------------+------------------+
|Company| Geo|        avg(Sales)|       avg(Profit)|
+-------+----+------------------+------------------+
| GOOGLE|  US|184.33333333333334|21.666666666666668|
| GOOGLE|EMEA| 278.6666666666667|32.333333333333336|
|     FB|  US| 277.6666666666667|35.333333333333336|
|   APPL|  US|             390.0|41.666666666666664|
| GOOGLE|APAC|190.66666666666666|              43.0|
|   APPL|APAC|             340.0|45.666666666666664|
|   APPL|EMEA| 503.3333333333333|59.666666666666664|
|     FB|EMEA| 345.3333333333333|              68.0|
|     FB|APAC|             173.0|              75.0|
+-------+----+------------------+------------------+



In [86]:
df2.groupBy('Company', 'Year').sum('Sales', 'Profit').sort(['Company', 'Year'],ascending=[False, True]).show()

+-------+----+----------+-----------+
|Company|Year|sum(Sales)|sum(Profit)|
+-------+----+----------+-----------+
| GOOGLE|2017|       575|         95|
| GOOGLE|2018|       685|        137|
| GOOGLE|2019|       701|         59|
|     FB|2017|       440|         75|
|     FB|2018|      1200|        245|
|     FB|2019|       748|        215|
|   APPL|2017|       505|        100|
|   APPL|2018|      1350|        140|
|   APPL|2019|      1845|        201|
+-------+----+----------+-----------+



In [87]:
df2.groupBy('Company', 'Year').sum('Sales', 'Profit').sort(['Company', 'Year'],ascending=[False, True]).collect()

[Row(Company='GOOGLE', Year=2017, sum(Sales)=575, sum(Profit)=95),
 Row(Company='GOOGLE', Year=2018, sum(Sales)=685, sum(Profit)=137),
 Row(Company='GOOGLE', Year=2019, sum(Sales)=701, sum(Profit)=59),
 Row(Company='FB', Year=2017, sum(Sales)=440, sum(Profit)=75),
 Row(Company='FB', Year=2018, sum(Sales)=1200, sum(Profit)=245),
 Row(Company='FB', Year=2019, sum(Sales)=748, sum(Profit)=215),
 Row(Company='APPL', Year=2017, sum(Sales)=505, sum(Profit)=100),
 Row(Company='APPL', Year=2018, sum(Sales)=1350, sum(Profit)=140),
 Row(Company='APPL', Year=2019, sum(Sales)=1845, sum(Profit)=201)]

In [88]:
df2.groupBy('Company', 'Year').max('Sales', 'Profit').orderBy(['Company', 'Year'], ascending=[False,False]).show()

+-------+----+----------+-----------+
|Company|Year|max(Sales)|max(Profit)|
+-------+----+----------+-----------+
| GOOGLE|2019|       322|         24|
| GOOGLE|2018|       330|         80|
| GOOGLE|2017|       250|         40|
|     FB|2019|       526|        124|
|     FB|2018|       500|        125|
|     FB|2017|       230|         35|
|   APPL|2019|       760|         89|
|   APPL|2018|       600|         60|
|   APPL|2017|       200|         45|
+-------+----+----------+-----------+



In [89]:
df2.agg({'Sales':'max', 'Profit':'min'}).show() #Aggregate function on the entire dataframe;
# note that a different function is called for each column

+----------+-----------+
|max(Sales)|min(Profit)|
+----------+-----------+
|       760|         10|
+----------+-----------+



In [90]:
df2.groupBy('Company', 'Year').agg({'Sales':'sum', 'Profit':'max'}).sort(['Company', 'Year']).show()

+-------+----+----------+-----------+
|Company|Year|sum(Sales)|max(Profit)|
+-------+----+----------+-----------+
|   APPL|2017|       505|         45|
|   APPL|2018|      1350|         60|
|   APPL|2019|      1845|         89|
|     FB|2017|       440|         35|
|     FB|2018|      1200|        125|
|     FB|2019|       748|        124|
| GOOGLE|2017|       575|         40|
| GOOGLE|2018|       685|         80|
| GOOGLE|2019|       701|         24|
+-------+----+----------+-----------+



In [91]:
df2.groupBy('Company', 'Year').agg({'Sales':'sum', 'Profit':'max'}).sort(['Company', 'Year']).collect()

[Row(Company='APPL', Year=2017, sum(Sales)=505, max(Profit)=45),
 Row(Company='APPL', Year=2018, sum(Sales)=1350, max(Profit)=60),
 Row(Company='APPL', Year=2019, sum(Sales)=1845, max(Profit)=89),
 Row(Company='FB', Year=2017, sum(Sales)=440, max(Profit)=35),
 Row(Company='FB', Year=2018, sum(Sales)=1200, max(Profit)=125),
 Row(Company='FB', Year=2019, sum(Sales)=748, max(Profit)=124),
 Row(Company='GOOGLE', Year=2017, sum(Sales)=575, max(Profit)=40),
 Row(Company='GOOGLE', Year=2018, sum(Sales)=685, max(Profit)=80),
 Row(Company='GOOGLE', Year=2019, sum(Sales)=701, max(Profit)=24)]

In [92]:
df2.createOrReplaceTempView('salesDetails') # create a view for SQL querying

In [93]:
results1 = spark.sql("select sum(Sales), sum(Profit) from salesDetails where Company = 'GOOGLE'")

In [94]:
results1.show()

+----------+-----------+
|sum(Sales)|sum(Profit)|
+----------+-----------+
|      1961|        291|
+----------+-----------+



In [95]:
results2 = spark.sql("select Company, Year, sum(Sales), sum(Profit) from salesDetails group by Company, Year order by Company ASC, Year DESC")

In [96]:
results2.show()

+-------+----+----------+-----------+
|Company|Year|sum(Sales)|sum(Profit)|
+-------+----+----------+-----------+
|   APPL|2019|      1845|        201|
|   APPL|2018|      1350|        140|
|   APPL|2017|       505|        100|
|     FB|2019|       748|        215|
|     FB|2018|      1200|        245|
|     FB|2017|       440|         75|
| GOOGLE|2019|       701|         59|
| GOOGLE|2018|       685|        137|
| GOOGLE|2017|       575|         95|
+-------+----+----------+-----------+



In [97]:
df2.groupBy(['Company', 'Year']).sum('Sales', 'Profit').sort(['Company', 'Year'], ascending=[True, False]).show() #same command using Python functions

+-------+----+----------+-----------+
|Company|Year|sum(Sales)|sum(Profit)|
+-------+----+----------+-----------+
|   APPL|2019|      1845|        201|
|   APPL|2018|      1350|        140|
|   APPL|2017|       505|        100|
|     FB|2019|       748|        215|
|     FB|2018|      1200|        245|
|     FB|2017|       440|         75|
| GOOGLE|2019|       701|         59|
| GOOGLE|2018|       685|        137|
| GOOGLE|2017|       575|         95|
+-------+----+----------+-----------+



In [98]:
results3 = spark.sql("select Company, Year, sum(Sales), sum(Profit) from salesDetails group by Company, Year order by sum(Sales) DESC, sum(Profit) DESC")

In [99]:
results3.show()

+-------+----+----------+-----------+
|Company|Year|sum(Sales)|sum(Profit)|
+-------+----+----------+-----------+
|   APPL|2019|      1845|        201|
|   APPL|2018|      1350|        140|
|     FB|2018|      1200|        245|
|     FB|2019|       748|        215|
| GOOGLE|2019|       701|         59|
| GOOGLE|2018|       685|        137|
| GOOGLE|2017|       575|         95|
|   APPL|2017|       505|        100|
|     FB|2017|       440|         75|
+-------+----+----------+-----------+



In [100]:
# Fails on this Spark build with the AnalysisException shown below: once HAVING is present, the
# aggregate referenced in ORDER BY (sum(Sales)) is not resolved against the query's output columns.
# Possible rewrite (untested here -- verify): alias the aggregates in an inner query and apply
# WHERE / ORDER BY on those aliases in an outer query.
spark.sql("select Company, Year, sum(Sales), sum(Profit) from salesDetails group by Company, Year having (sum(Sales)>1000) order by sum(Sales) DESC, sum(Profit) DESC") # Not working

AnalysisException: "cannot resolve '`Sales`' given input columns: [salesdetails.Company, salesdetails.Year, sum(Sales), sum(Profit)]; line 1 pos 125;\n'Sort ['sum('Sales) DESC NULLS LAST, 'sum('Profit) DESC NULLS LAST], true\n+- Project [Company#978, Year#980, sum(Sales)#1398L, sum(Profit)#1399L]\n   +- Filter (sum(cast(Sales#981 as bigint))#1402L > cast(1000 as bigint))\n      +- Aggregate [Company#978, Year#980], [Company#978, Year#980, sum(cast(Sales#981 as bigint)) AS sum(Sales)#1398L, sum(cast(Profit#982 as bigint)) AS sum(Profit)#1399L, sum(cast(Sales#981 as bigint)) AS sum(cast(Sales#981 as bigint))#1402L]\n         +- SubqueryAlias `salesdetails`\n            +- Relation[Company#978,Geo#979,Year#980,Sales#981,Profit#982] csv\n"

In [101]:
results3.filter(results3['sum(Sales)'] > 1000).show() # workaround for the above issue

+-------+----+----------+-----------+
|Company|Year|sum(Sales)|sum(Profit)|
+-------+----+----------+-----------+
|   APPL|2019|      1845|        201|
|   APPL|2018|      1350|        140|
|     FB|2018|      1200|        245|
+-------+----+----------+-----------+



In [102]:
df2.groupBy().sum('Sales', 'Profit').show()

+----------+-----------+
|sum(Sales)|sum(Profit)|
+----------+-----------+
|      8049|       1267|
+----------+-----------+



In [104]:
df2.agg({'Sales':'sum', 'Profit':'sum'}).show()

+----------+-----------+
|sum(Sales)|sum(Profit)|
+----------+-----------+
|      8049|       1267|
+----------+-----------+



## Functions module

In [74]:
from pyspark.sql.functions import avg, stddev, sqrt, countDistinct, log, format_number

In [76]:
df2.select(avg(df2.Sales)).show()

+-----------------+
|       avg(Sales)|
+-----------------+
|298.1111111111111|
+-----------------+



In [77]:
df2.agg({'Sales': 'avg'}).show() # using alternate method

+-----------------+
|       avg(Sales)|
+-----------------+
|298.1111111111111|
+-----------------+



In [43]:
df2.select(stddev('Sales')) 

DataFrame[stddev_samp(Sales): double]

In [66]:
df2.select(stddev('Sales').alias('standard_deviation')).show() # stddev is an aggregate function

+------------------+
|standard_deviation|
+------------------+
|189.54385190470248|
+------------------+



In [86]:
df2.select(countDistinct('Company').alias("No._of_Companies")).show()

+----------------+
|No._of_Companies|
+----------------+
|               3|
+----------------+



In [87]:
sq_sales = df2.select(sqrt('Sales')) # sqrt operates on each value in the column

In [88]:
sq_sales.show() # sqrt operates on each value in the column

+------------------+
|       SQRT(Sales)|
+------------------+
|14.142135623730951|
|15.811388300841896|
|11.180339887498949|
|15.165750888103101|
|10.488088481701515|
|              10.0|
| 11.40175425099138|
|14.142135623730951|
|13.228756555322953|
|15.165750888103101|
| 18.16590212458495|
|11.180339887498949|
|22.360679774997898|
|              20.0|
|17.320508075688775|
| 24.49489742783178|
| 23.45207879911715|
|14.142135623730951|
|11.090536506409418|
|              16.0|
+------------------+
only showing top 20 rows



In [62]:
df2.select(countDistinct(df2['Sales']), countDistinct(df2['Profit'])).show() # aggregate function

+---------------------+----------------------+
|count(DISTINCT Sales)|count(DISTINCT Profit)|
+---------------------+----------------------+
|                   23|                    19|
+---------------------+----------------------+



In [75]:
log_profit = df2.select(log('Profit').alias('log_of_profit'))

In [76]:
log_profit.select(format_number('log_of_profit', 2).alias('log_of_profit')).show() #format_number function

+-------------+
|log_of_profit|
+-------------+
|         3.40|
|         3.69|
|         3.22|
|         3.56|
|         2.30|
|         3.40|
|         3.00|
|         3.81|
|         3.56|
|         2.48|
|         3.81|
|         4.38|
|         3.91|
|         4.25|
|         4.83|
|         4.09|
|         3.81|
|         3.56|
|         3.14|
|         2.48|
+-------------+
only showing top 20 rows



In [78]:
df2.select(format_number('Sales', 2).alias('Sales'), format_number('Profit', 2).alias("Profit")).show()

+------+------+
| Sales|Profit|
+------+------+
|200.00| 30.00|
|250.00| 40.00|
|125.00| 25.00|
|230.00| 35.00|
|110.00| 10.00|
|100.00| 30.00|
|130.00| 20.00|
|200.00| 45.00|
|175.00| 35.00|
|230.00| 12.00|
|330.00| 45.00|
|125.00| 80.00|
|500.00| 50.00|
|400.00| 70.00|
|300.00|125.00|
|600.00| 60.00|
|550.00| 45.00|
|200.00| 35.00|
|123.00| 23.00|
|256.00| 12.00|
+------+------+
only showing top 20 rows



## Join and Union

In [36]:
dfA = spark.read.csv("/media/sf_Ubuntu_Shared/File1.csv", inferSchema=True, header=True)

In [37]:
dfA.show()

+-----+--------+-----+
| Name|Sub_Code|Score|
+-----+--------+-----+
| John|     MAT|   99|
| Mary|     SCI|   78|
|Louis|     MAT|   89|
| Andy|     PSY|   67|
|Queen|     SCI|   45|
+-----+--------+-----+



In [38]:
dfA.collect()

[Row(Name='John', Sub_Code='MAT', Score=99),
 Row(Name='Mary', Sub_Code='SCI', Score=78),
 Row(Name='Louis', Sub_Code='MAT', Score=89),
 Row(Name='Andy', Sub_Code='PSY', Score=67),
 Row(Name='Queen', Sub_Code='SCI', Score=45)]

In [8]:
dfB = spark.read.csv("/media/sf_Ubuntu_Shared/File2.csv", inferSchema=True, header=True)

In [9]:
dfB.show()

+--------+----------+
|Sub_Code|  Sub_Name|
+--------+----------+
|     MAT|     Maths|
|     SCI|   Science|
|     PSY|Psychology|
+--------+----------+



In [10]:
dfC = dfA.join(dfB, on='Sub_Code', how='left')

In [11]:
dfC.select(['Name', 'Sub_Code', 'Sub_Name', 'Score']).show()

+-----+--------+----------+-----+
| Name|Sub_Code|  Sub_Name|Score|
+-----+--------+----------+-----+
| John|     MAT|     Maths|   99|
| Mary|     SCI|   Science|   78|
|Louis|     MAT|     Maths|   89|
| Andy|     PSY|Psychology|   67|
|Queen|     SCI|   Science|   45|
+-----+--------+----------+-----+



In [12]:
dfD = spark.read.csv("/media/sf_Ubuntu_Shared/File3.csv", inferSchema=True, header=True)

In [13]:
dfD.show()

+-----+-----+--------+
| Name|Score|Sub_Code|
+-----+-----+--------+
|Pilip|   72|     MAT|
|Karen|   99|     SCI|
+-----+-----+--------+



In [14]:
# unionByName aligns columns by NAME rather than by position, so dfD's different column order
# (Name, Score, Sub_Code) is handled correctly. Rough pandas analogue: row-wise concatenation
# (pd.concat; formerly DataFrame.append).
dfA.unionByName(dfD).show()

+-----+--------+-----+
| Name|Sub_Code|Score|
+-----+--------+-----+
| John|     MAT|   99|
| Mary|     SCI|   78|
|Louis|     MAT|   89|
| Andy|     PSY|   67|
|Queen|     SCI|   45|
|Pilip|     MAT|   72|
|Karen|     SCI|   99|
+-----+--------+-----+



In [15]:
dfE = spark.read.csv("/media/sf_Ubuntu_Shared/File_dup.csv", inferSchema=True, header=True)

In [16]:
dfE.count() # before de-duping

7

In [17]:
dfE = dfE.distinct() # remove duplicate rows

In [18]:
dfE.count() # after de-duping

5

## Missing Data handling

In [5]:
df1 = spark.read.csv("/media/sf_Ubuntu_Shared/Missing_Data.csv", inferSchema=True, header=True)

In [6]:
df1.show()

+-----+------+------+
| Name|   Age|Salary|
+-----+------+------+
| John|    NA|  43.5|
| None|    23|  20.5|
|Louis|    45|  null|
| Sean|    22|  67.8|
|  NaN|np.nan|  23.5|
| Bill|   nan|   122|
| null|    12|   NAN|
+-----+------+------+



In [7]:
# na.drop() removes rows that contain a real null. Placeholder strings such as 'NA', 'nan' or 'None'
# are ordinary values to Spark, so only the row with the empty Salary cell (Louis) is dropped.
df1.na.drop().show()

+----+------+------+
|Name|   Age|Salary|
+----+------+------+
|John|    NA|  43.5|
|None|    23|  20.5|
|Sean|    22|  67.8|
| NaN|np.nan|  23.5|
|Bill|   nan|   122|
|null|    12|   NAN|
+----+------+------+



In [8]:
from pyspark.sql.functions import isnull, isnan, mean, format_number

In [10]:
df1.select(isnull('Name'), isnull('Age'), isnull('Salary')).show() # only empty cell in the input file is considered as null

+--------------+-------------+----------------+
|(Name IS NULL)|(Age IS NULL)|(Salary IS NULL)|
+--------------+-------------+----------------+
|         false|        false|           false|
|         false|        false|           false|
|         false|        false|            true|
|         false|        false|           false|
|         false|        false|           false|
|         false|        false|           false|
|         false|        false|           false|
+--------------+-------------+----------------+



In [13]:
df1.select(isnull('Salary').alias("Salary")).groupBy('Salary').count().show()

+------+-----+
|Salary|count|
+------+-----+
|  true|    1|
| false|    6|
+------+-----+



In [17]:
df1.select(isnan('Name'), isnan('Age'), isnan('Salary')).show() # null is different from nan

+-----------+----------+-------------+
|isnan(Name)|isnan(Age)|isnan(Salary)|
+-----------+----------+-------------+
|      false|     false|        false|
|      false|     false|        false|
|      false|     false|        false|
|      false|     false|        false|
|       true|     false|        false|
|      false|     false|        false|
|      false|     false|        false|
+-----------+----------+-------------+



In [82]:
# Turn the textual missing-value placeholders seen in the file into real nulls
df2 = df1.replace(['NA', 'NaN', 'nan', 'np.nan', 'None', 'NAN', 'null' ], value=None)

In [83]:
df2.show()

+-----+----+------+
| Name| Age|Salary|
+-----+----+------+
| John|null|  43.5|
| null|  23|  20.5|
|Louis|  45|  null|
| Sean|  22|  67.8|
| null|null|  23.5|
| Bill|null|   122|
| null|  12|  null|
+-----+----+------+



In [84]:
df2.select(isnull('Name'), isnull('Age'), isnull('Salary')).show()

+--------------+-------------+----------------+
|(Name IS NULL)|(Age IS NULL)|(Salary IS NULL)|
+--------------+-------------+----------------+
|         false|         true|           false|
|          true|        false|           false|
|         false|        false|            true|
|         false|        false|           false|
|          true|         true|           false|
|         false|         true|           false|
|          true|        false|            true|
+--------------+-------------+----------------+



In [86]:
df2.collect() # In the Row objects, null is represented as None
# Age and Salary are still strings: the file contained non-numeric placeholders
# (e.g. 'NA', 'np.nan'), so presumably inferSchema typed these columns as string.
# We want to change them to numeric types.

[Row(Name='John', Age=None, Salary='43.5'),
 Row(Name=None, Age='23', Salary='20.5'),
 Row(Name='Louis', Age='45', Salary=None),
 Row(Name='Sean', Age='22', Salary='67.8'),
 Row(Name=None, Age=None, Salary='23.5'),
 Row(Name='Bill', Age=None, Salary='122'),
 Row(Name=None, Age='12', Salary=None)]

In [93]:
df2.dtypes # Note that Age and Salary are still stored as string, even though the placeholders are now null

[('Name', 'string'), ('Age', 'string'), ('Salary', 'string')]

In [87]:
data = df2.select(df2.Name, df2.Age.cast(IntegerType()), df2.Salary.cast(FloatType())).collect()

In [102]:
data

[Row(Name='John', Age=None, Salary=43.5),
 Row(Name=None, Age=23, Salary=20.5),
 Row(Name='Louis', Age=45, Salary=None),
 Row(Name='Sean', Age=22, Salary=67.80000305175781),
 Row(Name=None, Age=None, Salary=23.5),
 Row(Name='Bill', Age=None, Salary=122.0),
 Row(Name=None, Age=12, Salary=None)]

In [103]:
dataschema = StructType([StructField('Name', StringType(), True), StructField('Age', IntegerType(), True),
                        StructField('Salary', FloatType(), True)])

In [104]:
# Cast the string columns directly on the DataFrame instead of rebuilding it from rows
# collected to the driver: the work stays distributed and the result has the same
# columns and types (Name string, Age int, Salary float).
df3 = df2.select(df2.Name, df2.Age.cast(IntegerType()), df2.Salary.cast(FloatType()))

In [105]:
df3.collect() #Now the data type for numeric columns is correct

[Row(Name='John', Age=None, Salary=43.5),
 Row(Name=None, Age=23, Salary=20.5),
 Row(Name='Louis', Age=45, Salary=None),
 Row(Name='Sean', Age=22, Salary=67.80000305175781),
 Row(Name=None, Age=None, Salary=23.5),
 Row(Name='Bill', Age=None, Salary=122.0),
 Row(Name=None, Age=12, Salary=None)]

In [110]:
df3.select(format_number('Salary', 2)).collect()

[Row(format_number(Salary, 2)='43.50'),
 Row(format_number(Salary, 2)='20.50'),
 Row(format_number(Salary, 2)=None),
 Row(format_number(Salary, 2)='67.80'),
 Row(format_number(Salary, 2)='23.50'),
 Row(format_number(Salary, 2)='122.00'),
 Row(format_number(Salary, 2)=None)]

In [106]:
df3.na.drop().show()

+----+---+------+
|Name|Age|Salary|
+----+---+------+
|Sean| 22|  67.8|
+----+---+------+



In [107]:
df3.dtypes # Note that the data types are correct now

[('Name', 'string'), ('Age', 'int'), ('Salary', 'float')]

In [96]:
# Average of the non-null ages, truncated to an int so it can fill the integer Age column
mean_age = int(df3.agg(mean('Age')).first()[0])

In [97]:
mean_age

25

In [98]:
# approxQuantile returns a list with one value per requested probability:
# 0.5 asks for the median, 0.05 is the relative error passed to the approximation
median_salary = df3.approxQuantile('Salary', [0.5], 0.05)

In [99]:
median_salary

[43.5]

In [101]:
df3.fillna({'Name':'unknown', 'Age':mean_age, 'Salary':median_salary[0]}).show()

+-------+---+-----------------+
|   Name|Age|           Salary|
+-------+---+-----------------+
|   John| 25|             43.5|
|unknown| 23|             20.5|
|  Louis| 45|             43.5|
|   Sean| 22|67.80000305175781|
|unknown| 25|             23.5|
|   Bill| 25|            122.0|
|unknown| 12|             43.5|
+-------+---+-----------------+



## Dates and Timestamps

In [48]:
from pyspark.sql.functions import (to_date, year, format_number, month, dayofmonth, 
                                   dayofweek, dayofyear, weekofyear, quarter)

In [8]:
df1 = spark.read.csv("/media/sf_Ubuntu_Shared/appl_stock.csv", inferSchema=True, header=True)

In [9]:
df1.show(5)

+--------+----------+----------+----------+----------+---------+---------+
|    Date|      Open|      High|       Low|     Close|   Volume|Adj Close|
+--------+----------+----------+----------+----------+---------+---------+
|04/01/10|213.429998|214.499996|212.380001|214.009998|123432400|27.727039|
|05/01/10|214.599998|215.589994|213.249994|214.379993|150476200|27.774976|
|06/01/10|214.379993|    215.23|210.750004|210.969995|138040000|27.333178|
|07/01/10|    211.75|212.000006|209.050005|    210.58|119282800| 27.28265|
|08/01/10|210.299994|212.000006|209.060005|211.980005|111902700|27.464034|
+--------+----------+----------+----------+----------+---------+---------+
only showing top 5 rows



In [10]:
df1.dtypes

[('Date', 'string'),
 ('Open', 'double'),
 ('High', 'double'),
 ('Low', 'double'),
 ('Close', 'double'),
 ('Volume', 'int'),
 ('Adj Close', 'double')]

In [11]:
# Date was read as a string such as '04/01/10'; parse it as day/month/two-digit year
# into a proper date column named Date_new
df2 = df1.withColumn('Date_new', to_date(df1.Date, format='dd/MM/yy'))

In [12]:
df2 = df2.drop(df2.Date)

In [13]:
df2.show(5)

+----------+----------+----------+----------+---------+---------+----------+
|      Open|      High|       Low|     Close|   Volume|Adj Close|  Date_new|
+----------+----------+----------+----------+---------+---------+----------+
|213.429998|214.499996|212.380001|214.009998|123432400|27.727039|2010-01-04|
|214.599998|215.589994|213.249994|214.379993|150476200|27.774976|2010-01-05|
|214.379993|    215.23|210.750004|210.969995|138040000|27.333178|2010-01-06|
|    211.75|212.000006|209.050005|    210.58|119282800| 27.28265|2010-01-07|
|210.299994|212.000006|209.060005|211.980005|111902700|27.464034|2010-01-08|
+----------+----------+----------+----------+---------+---------+----------+
only showing top 5 rows



In [14]:
df2.dtypes

[('Open', 'double'),
 ('High', 'double'),
 ('Low', 'double'),
 ('Close', 'double'),
 ('Volume', 'int'),
 ('Adj Close', 'double'),
 ('Date_new', 'date')]

In [15]:
row_list = df2.collect()

In [16]:
row_list[:5]

[Row(Open=213.429998, High=214.499996, Low=212.380001, Close=214.009998, Volume=123432400, Adj Close=27.727039, Date_new=datetime.date(2010, 1, 4)),
 Row(Open=214.599998, High=215.589994, Low=213.249994, Close=214.379993, Volume=150476200, Adj Close=27.774976, Date_new=datetime.date(2010, 1, 5)),
 Row(Open=214.379993, High=215.23, Low=210.750004, Close=210.969995, Volume=138040000, Adj Close=27.333178, Date_new=datetime.date(2010, 1, 6)),
 Row(Open=211.75, High=212.000006, Low=209.050005, Close=210.58, Volume=119282800, Adj Close=27.28265, Date_new=datetime.date(2010, 1, 7)),
 Row(Open=210.299994, High=212.000006, Low=209.060005, Close=211.980005, Volume=111902700, Adj Close=27.464034, Date_new=datetime.date(2010, 1, 8))]

In [26]:
row_list[0]['Date_new']

datetime.date(2010, 1, 4)

In [18]:
df3 = df2.withColumn('Year', year(df2.Date_new))

In [19]:
df3.show()

+----------+----------+----------+----------+---------+---------+----------+----+
|      Open|      High|       Low|     Close|   Volume|Adj Close|  Date_new|Year|
+----------+----------+----------+----------+---------+---------+----------+----+
|213.429998|214.499996|212.380001|214.009998|123432400|27.727039|2010-01-04|2010|
|214.599998|215.589994|213.249994|214.379993|150476200|27.774976|2010-01-05|2010|
|214.379993|    215.23|210.750004|210.969995|138040000|27.333178|2010-01-06|2010|
|    211.75|212.000006|209.050005|    210.58|119282800| 27.28265|2010-01-07|2010|
|210.299994|212.000006|209.060005|211.980005|111902700|27.464034|2010-01-08|2010|
|212.799997|213.000002|208.450005|210.110003|115557400|27.221758|2010-01-11|2010|
|209.189995|209.769995|206.419998|207.720001|148614900| 26.91211|2010-01-12|2010|
|207.870005|210.929995|204.099998|210.650002|151473000| 27.29172|2010-01-13|2010|
|210.110003|210.459997|209.020004|    209.43|108223500|27.133657|2010-01-14|2010|
|210.929995|211.

In [20]:
df4 = df3.groupBy('Year').avg('Close')

In [21]:
df4.show()

+----+------------------+
|Year|        avg(Close)|
+----+------------------+
|2015|120.03999980555547|
|2013| 472.6348802857143|
|2014| 295.4023416507935|
|2012| 576.0497195640002|
|2016|104.60400786904763|
|2010| 259.8424600000002|
|2011|364.00432532142867|
+----+------------------+



In [109]:
# Yearly average close rounded to two decimals, most recent year first
(
    df4
    .select('Year', format_number(df4['avg(Close)'], 2).alias('avgClose'))
    .orderBy(df4['Year'], ascending=False)
    .show()
)

+----+--------+
|Year|avgClose|
+----+--------+
|2016|  104.60|
|2015|  120.04|
|2014|  295.40|
|2013|  472.63|
|2012|  576.05|
|2011|  364.00|
|2010|  259.84|
+----+--------+



In [31]:
DF1 = spark.createDataFrame(data=[(5,), (3,), (12,)], schema=['Age']) # sample dataframe

In [32]:
DF1.show()

+---+
|Age|
+---+
|  5|
|  3|
| 12|
+---+



In [35]:
DF2 = spark.createDataFrame(data=[(row_list[0]['Date_new'],), (row_list[100]['Date_new'],)], schema=['Date'])

In [36]:
DF2.show()

+----------+
|      Date|
+----------+
|2010-01-04|
|2010-05-27|
+----------+



In [40]:
DF2.select(year(DF2.Date).alias("Year")).show() # year function

+----+
|Year|
+----+
|2010|
|2010|
+----+



In [41]:
DF2.select(month(DF2.Date).alias("Month")).show() # month function

+-----+
|Month|
+-----+
|    1|
|    5|
+-----+



In [45]:
DF2.select(dayofmonth(DF2.Date).alias("Day of Month")).show() # day-of-month function

+------------+
|Day of Month|
+------------+
|           4|
|          27|
+------------+



In [46]:
DF2.select(dayofyear(DF2.Date).alias("Day of Year")).show() # day-of-year function

+-----------+
|Day of Year|
+-----------+
|          4|
|        147|
+-----------+



In [47]:
DF2.select(dayofweek(DF2.Date).alias("Day of Week")).show() # day-of-week function (1 = Sunday, so Monday 2010-01-04 gives 2)

+-----------+
|Day of Week|
+-----------+
|          2|
|          5|
+-----------+



In [49]:
DF2.select(weekofyear(DF2.Date).alias("Week")).show() # week function

+----+
|Week|
+----+
|   1|
|  21|
+----+



In [50]:
DF2.select(quarter(DF2.Date).alias("Quarter")).show() # quarter function

+-------+
|Quarter|
+-------+
|      1|
|      2|
+-------+



In [51]:
# Re-reading appl_stock.csv after fixing the date format in the source file itself

In [52]:
DF2 = spark.read.csv("/media/sf_Ubuntu_Shared/appl_stock.csv", inferSchema=True, header=True)

In [53]:
DF2.dtypes #note that the correct dtype has been assigned to Date column

[('Date', 'timestamp'),
 ('Open', 'double'),
 ('High', 'double'),
 ('Low', 'double'),
 ('Close', 'double'),
 ('Volume', 'int'),
 ('Adj Close', 'double')]

In [54]:
DF2.show(3)

+-------------------+----------+----------+----------+----------+---------+---------+
|               Date|      Open|      High|       Low|     Close|   Volume|Adj Close|
+-------------------+----------+----------+----------+----------+---------+---------+
|2010-01-04 00:00:00|213.429998|214.499996|212.380001|214.009998|123432400|27.727039|
|2010-01-05 00:00:00|214.599998|215.589994|213.249994|214.379993|150476200|27.774976|
|2010-01-06 00:00:00|214.379993|    215.23|210.750004|210.969995|138040000|27.333178|
+-------------------+----------+----------+----------+----------+---------+---------+
only showing top 3 rows



In [55]:
row_list = DF2.collect()

In [56]:
row_list[0]

Row(Date=datetime.datetime(2010, 1, 4, 0, 0), Open=213.429998, High=214.499996, Low=212.380001, Close=214.009998, Volume=123432400, Adj Close=27.727039)

In [57]:
row_list[0]['Date']

datetime.datetime(2010, 1, 4, 0, 0)

## Pandas UDF

In [7]:
df1 = spark.createDataFrame(data =[(3, 7), (6, 2), (1, 9)], 
                           schema = ['A', 'B'])

In [8]:
df1.show()

+---+---+
|  A|  B|
+---+---+
|  3|  7|
|  6|  2|
|  1|  9|
+---+---+



In [9]:
from pyspark.sql.functions import pandas_udf, col

In [10]:
import numpy, pandas

In [11]:
import pyarrow

In [22]:
# Turn on Arrow-based execution for this session, and enable the fallback setting
# alongside it
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "true")

In [16]:
def f1(x, y):
    """Return x squared plus y squared.

    Only ``**`` and ``+`` are used, so the function works on plain numbers
    and on any objects that overload those operators element-wise.
    """
    x_squared = x ** 2
    y_squared = y ** 2
    return x_squared + y_squared

In [17]:
# Wrap f1 as a pandas UDF whose result column is declared as IntegerType
f2 = pandas_udf(f=f1, returnType=IntegerType())

In [23]:
df1.select(f2(col('A'), col('B'))).show() # use pandas UDF

+--------+
|f1(A, B)|
+--------+
|      58|
|      40|
|      82|
+--------+



In [26]:
df2 = df1.select((df1.A > df1.B).alias('col1'))

In [27]:
df2.show()

+-----+
| col1|
+-----+
|false|
| true|
|false|
+-----+



In [29]:
df2.select(df2.col1.cast(FloatType())).show()

+----+
|col1|
+----+
| 0.0|
| 1.0|
| 0.0|
+----+



In [31]:
t = 4

In [32]:
df3 = df1.select((df1.A > t).alias('col'))

In [33]:
df3.show()

+-----+
|  col|
+-----+
|false|
| true|
|false|
+-----+



In [34]:
df3.select(df3.col.cast(IntegerType())).show()

+---+
|col|
+---+
|  0|
|  1|
|  0|
+---+

