This notebook highlights on-going exploratory work comparing Postgresql query performance with PySpark SQL queries using the SDSS MaNGA DR15 dataset.

In [1]:
%%configure -f
{"name": "brian-query-benchmarks-2", "executorMemory": "36G", "numExecutors": 15, "executorCores": 10,
 "conf": {"spark.yarn.appMasterEnv.PYSPARK_PYTHON":"python3"}}

ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
21,application_1609885494103_0025,pyspark,idle,Link,Link,


In [2]:
import numpy as np
import matplotlib.pyplot as plt
import time

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
22,application_1609885494103_0026,pyspark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [3]:
# create a simple timer context manager
class Timer(object):
    def __enter__(self):
        self.start_time = time.perf_counter()
        return self
   
    def __exit__(self, *exc_info):
        self.end_time = time.perf_counter()
        elapsed = self.end_time-self.start_time
        print('Elapsed time [sec]:', elapsed)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [4]:
# creates a temporary "database" table
with Timer():
    drpall = spark.read.parquet('hdfs:///manga/brian-test/dr15/v2_4_3/drpall')
    drpall.createOrReplaceTempView('drpall')
    dapall = spark.read.parquet('hdfs:///manga/brian-test/dr15/v2_4_3/dapall')
    dapall.createOrReplaceTempView('dapall')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Elapsed time [sec]: 47.467089012265205

In [5]:
# creates a temporary "database" table for the DAP maps
with Timer():
    maps = spark.read.parquet('hdfs:///manga/brian-test/dr15/v2_4_3/maps')
    maps.createOrReplaceTempView('maps')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Elapsed time [sec]: 2.9763775691390038

In [6]:
# creates a temporary table for the DRP cubes
with Timer():
    cubes = spark.read.parquet('hdfs:///manga/arik-test/dr15/v2_4_3/logcube_voxel')
    cubes.createOrReplaceTempView('cubes')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Elapsed time [sec]: 1.4119328930974007

## Example Query 1
Select all galaxies with an H-alpha flux value > 5 in more than 20% of "good" spaxels, with "good" defined as a measured DAP value != -1

### Raw SQL 

SELECT anon_1.mangadatadb_cube_mangaid, anon_1.mangadatadb_cube_plate, concat(anon_1.mangadatadb_cube_plate, '-', anon_1.mangadatadb_ifudesign_name) AS plateifu, anon_1.mangadatadb_ifudesign_name
FROM (SELECT mangadatadb.cube.mangaid AS mangadatadb_cube_mangaid, mangadatadb.cube.plate AS mangadatadb_cube_plate, concat(mangadatadb.cube.plate, '-', mangadatadb.ifudesign.name) AS plateifu, mangadatadb.ifudesign.name AS mangadatadb_ifudesign_name, mangadapdb.cleanspaxelprop7.emline_gflux_ha_6564 AS mangadapdb_cleanspaxelprop7_emline_gflux_ha_6564, mangadapdb.cleanspaxelprop7.x AS mangadapdb_cleanspaxelprop7_x, mangadapdb.cleanspaxelprop7.y AS mangadapdb_cleanspaxelprop7_y
FROM mangadatadb.cube JOIN mangadatadb.ifudesign ON mangadatadb.ifudesign.pk = mangadatadb.cube.ifudesign_pk JOIN mangadapdb.file ON mangadatadb.cube.pk = mangadapdb.file.cube_pk JOIN mangadapdb.cleanspaxelprop7 ON mangadapdb.file.pk = mangadapdb.cleanspaxelprop7.file_pk JOIN mangadatadb.pipeline_info AS drpalias ON drpalias.pk = mangadatadb.cube.pipeline_info_pk JOIN mangadatadb.pipeline_info AS dapalias ON dapalias.pk = mangadapdb.file.pipeline_info_pk JOIN (SELECT mangadapdb.cleanspaxelprop7.file_pk AS binfile, count(mangadapdb.cleanspaxelprop7.pk) AS goodcount
FROM mangadapdb.cleanspaxelprop7
WHERE mangadapdb.cleanspaxelprop7.binid_binned_spectra != -1 AND mangadapdb.cleanspaxelprop7.binid_stellar_continua != -1 AND mangadapdb.cleanspaxelprop7.binid_spectral_indices != -1 AND mangadapdb.cleanspaxelprop7.binid_em_line_moments != -1 AND mangadapdb.cleanspaxelprop7.binid_em_line_models != -1 GROUP BY mangadapdb.cleanspaxelprop7.file_pk) AS bingood ON bingood.binfile = mangadapdb.cleanspaxelprop7.file_pk JOIN (SELECT mangadapdb.cleanspaxelprop7.file_pk AS valfile, count(mangadapdb.cleanspaxelprop7.pk) AS valcount
FROM mangadapdb.cleanspaxelprop7
WHERE mangadapdb.cleanspaxelprop7.emline_gflux_ha_6564 > 5 GROUP BY mangadapdb.cleanspaxelprop7.file_pk) AS goodhacount ON goodhacount.valfile = mangadapdb.cleanspaxelprop7.file_pk
WHERE drpalias.pk = 32 AND dapalias.pk = 34 AND goodhacount.valcount >= 0.2 * bingood.goodcount) AS anon_1 GROUP BY anon_1.mangadatadb_cube_mangaid, anon_1.mangadatadb_cube_plate, concat(anon_1.mangadatadb_cube_plate, '-', anon_1.mangadatadb_ifudesign_name), anon_1.mangadatadb_ifudesign_name

### Postgres Results
Above postgres query time takes ~34 minutes to return 664 row results.

1st run
Time: 1998737.520 ms (33:18.738)

2nd run
Time: 2071235.639 ms (34:31.236)



### PySpark Query via DataFrames

In [7]:
# define good spaxels
good_spaxels = ((maps.binid_binned_spectra != -1) & 
                   (maps.binid_stellar_continua != -1) & 
                   (maps.binid_em_line_models != -1) & 
                   (maps.binid_em_line_moments != -1) & 
                   (maps.binid_spectral_indices != -1))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [8]:
# run the PySpark query
with Timer():
    # get total counts of number of good spaxels, grouped by plateifu
    tc = maps.filter(good_spaxels).groupby('plateifu').count().withColumnRenamed('count', 'totalc')

    # get counts of number of good spaxels with H-alpha > 5, grouped by plateifu
    fc = maps.filter(good_spaxels).filter(maps['emline_gflux_ha_6564'] > 5).groupby('plateifu').count().withColumnRenamed('count', 'filterc')

    # join the tables and filter where 
    tmp = tc.join(fc, 'plateifu')
    tmp.filter(tmp.filterc >= 0.2 * tmp.totalc).count()
    

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

664
Elapsed time [sec]: 17.228862084448338

### PySpark Query via Spark SQL

In [13]:
# define sql command to count total number of good spaxels
totsql = """ select f.plateifu, count(f.*) as tcount from maps as f \
where f.binid_binned_spectra != -1 and f.binid_stellar_continua != -1 and f.binid_em_line_models != -1 \
and f.binid_em_line_moments != -1 and f.binid_spectral_indices != -1 group by f.plateifu
"""

# define sql command to count number of good spaxels with H-alpha flux > 5
hasql = """ select f.plateifu, count(f.*) as vcount from maps as f \
where f.emline_gflux_ha_6564 > 5 and f.binid_binned_spectra != -1 and f.binid_stellar_continua != -1 \
and f.binid_em_line_models != -1 and f.binid_em_line_moments != -1 and f.binid_spectral_indices != -1 \
group by f.plateifu
"""

# construct the complete sql command to select galaxies that have an H-alpha flux > 5 
# in more than 20% of their spaxels 
sql = """ select t.plateifu, t.tcount, v.vcount from (select f.plateifu, count(f.*) as tcount from maps as f \
where f.binid_binned_spectra != -1 and f.binid_stellar_continua != -1 and f.binid_em_line_models != -1 \
and f.binid_em_line_moments != -1 and f.binid_spectral_indices != -1 group by f.plateifu) as t, (select f.plateifu, count(f.*) as vcount from maps as f \
where f.emline_gflux_ha_6564 > 5 and f.binid_binned_spectra != -1 and f.binid_stellar_continua != -1 \
and f.binid_em_line_models != -1 and f.binid_em_line_moments != -1 and f.binid_spectral_indices != -1 \
group by f.plateifu) as v where t.plateifu=v.plateifu and \
v.vcount >= 0.2*t.tcount
"""

with Timer():
    spark.sql(sql).count()


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

664
Elapsed time [sec]: 146.5893799038604

## Example Query 2
Select all galaxies with an NSA sersic_n index < 2, an H-alpha summed-EW > 6, and an NSA sersic log stellar mass between 9.5-11

### Raw SQL 

SELECT mangadatadb.cube.mangaid AS "cube.mangaid", mangadatadb.cube.plate AS "cube.plate", concat(mangadatadb.cube.plate, '-', mangadatadb.ifudesign.name) AS "cube.plateifu", mangadatadb.ifudesign.name AS "ifu.name", mangadapdb.cleanspaxelprop7.emline_sew_ha_6564 AS emline_sew_ha_6564, mangasampledb.nsa.sersic_n AS "nsa.sersic_n", CAST(CASE WHEN (mangasampledb.nsa.sersic_mass > 0.0) THEN log(mangasampledb.nsa.sersic_mass) WHEN (mangasampledb.nsa.sersic_mass = 0.0) THEN 0.0 END AS FLOAT) AS "nsa.sersic_logmass", mangadapdb.cleanspaxelprop7.x AS "spaxelprop.x", mangadapdb.cleanspaxelprop7.y AS "spaxelprop.y"
FROM mangadatadb.cube JOIN mangadatadb.ifudesign ON mangadatadb.ifudesign.pk = mangadatadb.cube.ifudesign_pk JOIN mangadapdb.file ON mangadatadb.cube.pk = mangadapdb.file.cube_pk JOIN mangadapdb.cleanspaxelprop7 ON mangadapdb.file.pk = mangadapdb.cleanspaxelprop7.file_pk JOIN mangasampledb.manga_target ON mangasampledb.manga_target.pk = mangadatadb.cube.manga_target_pk JOIN mangasampledb.manga_target_to_nsa ON mangasampledb.manga_target.pk = mangasampledb.manga_target_to_nsa.manga_target_pk JOIN mangasampledb.nsa ON mangasampledb.nsa.pk = mangasampledb.manga_target_to_nsa.nsa_pk JOIN mangadatadb.pipeline_info AS drpalias ON drpalias.pk = mangadatadb.cube.pipeline_info_pk JOIN mangadatadb.pipeline_info AS dapalias ON dapalias.pk = mangadapdb.file.pipeline_info_pk
WHERE CAST(CASE WHEN (mangasampledb.nsa.sersic_mass > 0.0) THEN log(mangasampledb.nsa.sersic_mass) WHEN (mangasampledb.nsa.sersic_mass = 0.0) THEN 0.0 END AS FLOAT) >= 9.5 AND CAST(CASE WHEN (mangasampledb.nsa.sersic_mass > 0.0) THEN log(mangasampledb.nsa.sersic_mass) WHEN (mangasampledb.nsa.sersic_mass = 0.0) THEN 0.0 END AS FLOAT) < 11.0 AND mangasampledb.nsa.sersic_n < 2.0 AND mangadapdb.cleanspaxelprop7.emline_sew_ha_6564 > 6.0 AND drpalias.pk = 32 AND dapalias.pk = 34

### Postgres Results
Above postgres query takes ~3 mins to run first time, ~15 seconds after query caching, returning 1,235,317 row results.

1st run
Time: 201567.467 ms (03:21.567)

2nd run
Time: 15866.201 ms (00:15.866)


### PySpark Query via DataFrames

In [11]:
from pyspark.sql.functions import log10
with Timer():
    # filter drpall on sersic index and log of stellar mass
    sub = drpall.filter((drpall.nsa_sersic_n < 2.0) & 
                  (log10(drpall.nsa_sersic_mass) >=9.5) & 
                  (log10(drpall.nsa_sersic_mass) < 11.))

    # filter the maps on h-alpha EW
    ew = maps.filter(good_spaxels).filter(maps.emline_sew_ha_6564 > 6)

    # join the tables and select some,count
    tmp = ew.join(sub, 'plateifu')
    tmp.select(tmp.plateifu, tmp.x, tmp.y, tmp.emline_sew_ha_6564, tmp.nsa_sersic_n, log10(tmp.nsa_sersic_mass)).count()


1235317
Elapsed time [sec]: 3.6993655618280172

### PySpark Query via Spark SQL

In [12]:
# create sql command
sql = """select f.plateifu, f.emline_sew_ha_6564, d.nsa_sersic_n, log10(d.nsa_sersic_mass) \
from maps as f join drpall as d on d.plateifu=f.plateifu \
where (f.emline_sew_ha_6564 > 6 and f.binid_binned_spectra != -1 and f.binid_stellar_continua != -1 and \
f.binid_em_line_models != -1 and f.binid_em_line_moments != -1 and f.binid_spectral_indices != -1 and \
d.nsa_sersic_n < 2.0 and log10(d.nsa_sersic_mass) between 9.5 and 11.0)"""

# run Spark sql

with Timer():
    spark.sql(sql).count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

1235317
Elapsed time [sec]: 3.1305666361004114