In [1]:
import os

import polars as pl

# Show at most 10 columns whenever polars renders a DataFrame in this notebook.
pl.Config.set_tbl_cols(10)
# Print the polars version so the outputs recorded below can be tied to it.
print(pl.__version__)

1.18.0


## Build the glob pattern for a table's Parquet files

In [2]:
# Root directory of the Gaia parquet export (one sub-directory per table).
# Override it with the GAIA_PARQUET_DIR environment variable to run this
# notebook on another machine; the default is the original author's local path.
parquet_folder = os.environ.get(
    "GAIA_PARQUET_DIR", "/home/vikas/Desktop/Globus/Gaia/gaia_parquet"
)


def get_files(table_name, pattern="*"):
    """Return a glob pattern that matches the files of one table.

    Each table lives in its own sub-directory of ``parquet_folder``
    (e.g. ``<parquet_folder>/photometry/``). The returned string is handed
    to ``pl.scan_parquet``, which expands the glob.

    Parameters
    ----------
    table_name : str
        Name of the table's sub-directory, e.g. ``"photometry"``.
    pattern : str, default "*"
        File-name glob inside the table directory. The default matches every
        file; pass e.g. ``"*.parquet"`` to skip non-parquet files.

    Returns
    -------
    str
        ``<parquet_folder>/<table_name>/<pattern>``.
    """
    return os.path.join(parquet_folder, table_name, pattern)

## Run query

### Eager

In [3]:
# df_photometry = (
#     pl.read_parquet(get_files("photometry"))
#     .filter(pl.col("phot_bp_mean_flux") > 5)
# )

### Lazy

#### Number of unique objects (distinct `object_id` values; the row count of the result is the number)

In [4]:
def stream_query_1():
    """Return the distinct ``object_id`` values of the photometry table.

    The result is a one-column DataFrame; its height is the number of unique
    objects. The lazy query is collected with ``streaming=True``.
    """
    distinct_ids = (
        pl.scan_parquet(get_files("photometry"))
        .select("object_id")
        .unique()
    )
    return distinct_ids.collect(streaming=True)

In [5]:
#stream_query_1().shape

#### Highest and lowest brightness per `healpix` (min/max `phot_g_mean_mag`)

In [6]:
def stream_query_2():
    """Minimum and maximum ``phot_g_mean_mag`` for every ``healpix`` value.

    Groups the photometry table by ``healpix`` and returns one row per group
    with ``min_phot_g_mean_mag`` and ``max_phot_g_mean_mag``, sorted by
    ``healpix``. The lazy query is collected with ``streaming=True``.

    NOTE(review): assuming ``phot_g_mean_mag`` is an astronomical magnitude,
    smaller values are brighter, so ``min_`` is the brightest source.
    """
    g_mag = pl.col("phot_g_mean_mag")

    per_healpix = (
        pl.scan_parquet(get_files("photometry"))
        .group_by("healpix")
        .agg(
            g_mag.min().alias("min_phot_g_mean_mag"),
            g_mag.max().alias("max_phot_g_mean_mag"),
        )
        .sort("healpix")
    )
    return per_healpix.collect(streaming=True)

In [7]:
#stream_query_2()

#### Joins

In [8]:
def _scan_healpix_subset(table_name, columns, max_healpix):
    """Lazily scan ``table_name``: keep ``columns`` + join keys, rows with ``healpix < max_healpix``."""
    return (
        pl.scan_parquet(get_files(table_name))
        .select(*columns, "object_id", "healpix")
        .filter(pl.col("healpix") < max_healpix)
    )


def stream_query_3(max_healpix, explain=False):
    """Per-``healpix`` summary statistics joined across four tables.

    Starts from the ``photometry`` table and left-joins ``radial_velocity``,
    ``astrometry`` and ``gspphot`` on ``(object_id, healpix)``. Every input is
    restricted to rows with ``healpix < max_healpix``. For each ``healpix``
    it reports:

    * min/max of ``phot_g_mean_mag``, ``radial_velocity`` and
      ``distance_gspphot``;
    * mean of ``ra_error``, ``dec_error`` and ``parallax_error``.

    Photometry rows without a match in a joined table carry nulls in that
    table's columns. Polars' ``min``/``max``/``mean`` skip nulls, so such rows
    do not contribute to that statistic. A group with no matches at all gets
    a null statistic.

    NOTE(review): the recorded output of this notebook shows a null
    ``mean_parallax_error`` for every ``healpix``. Worth checking the
    ``parallax_error`` column in the astrometry files.

    Parameters
    ----------
    max_healpix : int
        Exclusive upper bound on ``healpix``; used to limit the size of the run.
    explain : bool, default False
        If True, print the streaming query plan of the final query before
        collecting it.

    Returns
    -------
    polars.DataFrame
        One row per ``healpix`` value, sorted by ``healpix``.
    """
    join_keys = ["object_id", "healpix"]

    photometry = _scan_healpix_subset("photometry", ["phot_g_mean_mag"], max_healpix)
    radial_velocity = _scan_healpix_subset("radial_velocity", ["radial_velocity"], max_healpix)
    astrometry = _scan_healpix_subset(
        "astrometry", ["ra_error", "dec_error", "parallax_error"], max_healpix
    )
    gspphot = _scan_healpix_subset("gspphot", ["distance_gspphot"], max_healpix)

    summary = (
        photometry
        .join(radial_velocity, on=join_keys, how="left")
        .join(astrometry, on=join_keys, how="left")
        .join(gspphot, on=join_keys, how="left")
        # Deliberately no fill_null(0) here. Zero-filled placeholders would be
        # aggregated as real measurements: they bias the mean errors toward 0
        # and force min/max bounds to include 0 (e.g. min distance_gspphot = 0).
        .group_by("healpix")
        .agg(
            pl.col("phot_g_mean_mag").min().alias("min_phot_g_mean_mag"),
            pl.col("phot_g_mean_mag").max().alias("max_phot_g_mean_mag"),
            pl.col("radial_velocity").min().alias("min_radial_velocity"),
            pl.col("radial_velocity").max().alias("max_radial_velocity"),
            pl.col("ra_error").mean().alias("mean_ra_error"),
            pl.col("dec_error").mean().alias("mean_dec_error"),
            pl.col("parallax_error").mean().alias("mean_parallax_error"),
            pl.col("distance_gspphot").min().alias("min_distance_gspphot"),
            pl.col("distance_gspphot").max().alias("max_distance_gspphot"),
        )
        .sort("healpix")
    )

    if explain:
        print(summary.explain(streaming=True))

    return summary.collect(streaming=True)

In [12]:
%%time

# Summarise only rows with healpix < 1300 to keep the run short; the
# recorded run below took about 15 s wall time.
stream_query_3(1300)

CPU times: user 1min 1s, sys: 4.8 s, total: 1min 5s
Wall time: 15.2 s


healpix,min_phot_g_mean_mag,max_phot_g_mean_mag,min_radial_velocity,max_radial_velocity,mean_ra_error,mean_dec_error,mean_parallax_error,min_distance_gspphot,max_distance_gspphot
i32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,6.855546,19.842781,-185.370361,244.173599,0.056582,0.051986,,41.3937,19867.748047
1,7.355467,19.706682,-224.071869,243.473114,0.058384,0.054546,,8.6131,10106.150391
2,7.312822,19.992142,-545.814758,241.080765,0.062313,0.056831,,24.5298,15997.161133
3,6.964732,19.499153,-295.097992,235.114029,0.060839,0.054957,,30.2407,15285.828125
4,3.382374,20.126875,-258.272003,256.130188,0.062764,0.056046,,31.565901,9622.666992
…,…,…,…,…,…,…,…,…,…
1295,6.074595,20.110273,-98.658363,296.361847,0.034101,0.049451,,15.0559,14408.869141
1296,4.80866,20.670574,-93.955078,348.503418,0.039399,0.049156,,26.889999,12304.650391
1297,6.93891,20.097332,-44.543076,232.561188,0.031051,0.052891,,73.528,8851.900391
1298,6.049368,20.341269,-145.648224,426.696136,0.03497,0.048439,,66.710197,11965.27832


#### Red- and blueshift counts (sign of `radial_velocity`)

In [10]:
def stream_query_4():
    """Count positive and negative ``radial_velocity`` values.

    Returns a one-row DataFrame with:

    * ``sum_red``: number of rows with ``radial_velocity > 0`` (each scores +1);
    * ``sum_blue``: minus the number of rows with ``radial_velocity < 0``
      (each scores -1, so this value is <= 0).

    Rows whose ``radial_velocity`` is 0 or null score 0 in both columns.
    The lazy query is collected with ``streaming=True``.
    """
    rv = pl.col("radial_velocity")
    red_score = pl.when(rv > 0).then(1).otherwise(0)    # red shift
    blue_score = pl.when(rv < 0).then(-1).otherwise(0)  # blue shift

    totals = pl.scan_parquet(get_files("radial_velocity")).select(
        red_score.sum().alias("sum_red"),
        blue_score.sum().alias("sum_blue"),
    )
    return totals.collect(streaming=True)

In [11]:
# stream_query_4()