In [16]:
# Polars provides the lazy scan/query API used by every query function below.
import polars as pl
# Set the Polars table-display column setting to 10.
pl.Config.set_tbl_cols(10)
# Print the Polars version so the run records which release produced the outputs.
print(pl.__version__)

import os
import numpy as np
# Print the NumPy version for the same provenance reason.
print(np.__version__)

1.18.0
1.25.0


## Get list of files

In [2]:
# Default root directory containing one sub-directory per table.
parquet_folder = "/home/vikas/Desktop/Globus/Gaia/gaia_parquet"

def get_files(table_name, folder = None):
    """Build the glob pattern that matches every file of one table.

    Parameters
    ----------
    table_name : str
        Name of the sub-directory (one per table) below the root directory.
    folder : str or os.PathLike, optional
        Root directory to search in. When omitted (``None``) the
        module-level ``parquet_folder`` is used, so existing one-argument
        calls behave exactly as before. Passing it explicitly lets the
        notebook run against data stored somewhere else without editing
        the hard-coded default.

    Returns
    -------
    str
        ``<root>/<table_name>/*`` joined with the platform path separator.
    """
    # Resolve the default at call time so a later re-assignment of
    # `parquet_folder` in the notebook is still honoured.
    root = parquet_folder if folder is None else folder

    return os.path.join(root, table_name, "*")

## Run query

### Eager

In [3]:
# df_photometry = (
#     pl.read_parquet(get_files("photometry"))
#     .filter(pl.col("phot_bp_mean_flux") > 5)
# )

### Lazy

#### Number of unique objects

In [4]:
def stream_query_1():
    """Collect the distinct ``object_id`` values of the photometry table.

    The table is scanned lazily, reduced to the ``object_id`` column,
    de-duplicated, and then collected with ``streaming = True``.

    Returns
    -------
    polars.DataFrame
        One column, ``object_id``, holding each distinct value once.
    """
    photometry = pl.scan_parquet(get_files("photometry"))
    distinct_ids = photometry.select("object_id").unique()

    # Uncomment to inspect the streaming plan:
    # print(distinct_ids.explain(streaming = True))

    return distinct_ids.collect(streaming = True)

In [5]:
#stream_query_1().shape

#### Highest and lowest brightness

In [6]:
def stream_query_2():
    """Compute the smallest and largest ``phot_g_mean_mag`` per ``healpix``.

    The photometry table is scanned lazily, grouped by ``healpix``,
    aggregated, sorted by ``healpix`` and collected with
    ``streaming = True``.

    Returns
    -------
    polars.DataFrame
        Columns ``healpix``, ``min_phot_g_mean_mag`` and
        ``max_phot_g_mean_mag``, ordered by ``healpix``.
    """
    g_mag = pl.col("phot_g_mean_mag")

    per_healpix = (
        pl.scan_parquet(get_files("photometry"))
        .group_by("healpix")
        .agg(
            g_mag.min().alias("min_phot_g_mean_mag"),
            g_mag.max().alias("max_phot_g_mean_mag"),
        )
        .sort("healpix")
    )

    # Uncomment to inspect the streaming plan:
    # print(per_healpix.explain(streaming = True))

    return per_healpix.collect(streaming = True)

In [7]:
#stream_query_2()

#### Joins

In [8]:
def stream_query_3(max_healpix):
    """Left-join four tables on (object_id, healpix) and summarise per healpix.

    Photometry is the left-hand table; radial_velocity, astrometry and
    gspphot are joined onto it in that order. Every table is restricted to
    rows with ``healpix < max_healpix`` before joining. Nulls in the joined
    frame are replaced with the "zero" fill strategy, then the frame is
    grouped by ``healpix``, aggregated, sorted by ``healpix`` and collected
    with ``streaming = True``.

    NOTE(review): because the zero fill happens before aggregation, rows
    without a match in a joined table presumably contribute 0 to the
    min / max / mean of that table's columns -- confirm this is intended.

    Parameters
    ----------
    max_healpix
        Exclusive upper bound applied to the ``healpix`` column of every table.

    Returns
    -------
    polars.DataFrame
        One row per ``healpix`` with the min/max/mean summary columns.
    """
    join_keys = ["object_id", "healpix"]
    below_limit = pl.col("healpix") < max_healpix

    def scan_side_table(table_name, *value_columns):
        # Lazily read one table, keep its value columns plus the join keys,
        # then drop the rows at or above the healpix limit.
        return (
            pl.scan_parquet(get_files(table_name))
            .select(*value_columns, *join_keys)
            .filter(below_limit)
        )

    side_tables = (
        scan_side_table("radial_velocity", "radial_velocity"),
        scan_side_table("astrometry", "ra_error", "dec_error", "parallax_error"),
        scan_side_table("gspphot", "distance_gspphot"),
    )

    joined = (
        pl.scan_parquet(get_files("photometry"))
        .filter(below_limit)
        .select("object_id", "healpix", "phot_g_mean_mag")
    )
    for side_table in side_tables:
        joined = joined.join(side_table, on = join_keys, how = "left")

    summaries = [
        pl.col("phot_g_mean_mag").min().alias("min_phot_g_mean_mag"),
        pl.col("phot_g_mean_mag").max().alias("max_phot_g_mean_mag"),
        pl.col("radial_velocity").min().alias("min_radial_velocity"),
        pl.col("radial_velocity").max().alias("max_radial_velocity"),
        pl.col("ra_error").mean().alias("mean_ra_error"),
        pl.col("dec_error").mean().alias("mean_dec_error"),
        pl.col("parallax_error").mean().alias("mean_parallax_error"),
        pl.col("distance_gspphot").min().alias("min_distance_gspphot"),
        pl.col("distance_gspphot").max().alias("max_distance_gspphot"),
    ]

    per_healpix = (
        joined
        .fill_null(strategy = "zero")
        .group_by("healpix", maintain_order = False)
        .agg(summaries)
        .sort("healpix")
    )

    # Uncomment to inspect the streaming plan:
    # print(per_healpix.explain(streaming = True))

    return per_healpix.collect(streaming = True)

In [13]:
# %%time

# stream_query_3(1300)

#### Red and blue shift

In [10]:
def stream_query_4():
    """Sum the red-shift and blue-shift flags over the radial_velocity table.

    Each row gets a ``red`` flag (1 when ``radial_velocity > 0``, else 0)
    and a ``blue`` flag (-1 when ``radial_velocity < 0``, else 0). The two
    flag columns are then summed, so ``sum_blue`` is the negated count of
    rows with negative velocity.

    Returns
    -------
    polars.DataFrame
        A single row with the columns ``sum_red`` and ``sum_blue``.
    """
    velocity = pl.col("radial_velocity")

    # Red shift: +1 for a positive velocity, 0 otherwise.
    red_flag = pl.when(velocity > 0).then(1).otherwise(0).alias("red")
    # Blue shift: -1 for a negative velocity, 0 otherwise.
    blue_flag = pl.when(velocity < 0).then(-1).otherwise(0).alias("blue")

    flagged = (
        pl.scan_parquet(get_files("radial_velocity"))
        .select("radial_velocity")
        .with_columns(red_flag, blue_flag)
    )

    # Collapse each flag column to its total.
    totals = flagged.select(
        pl.col("red").sum().alias("sum_red"),
        pl.col("blue").sum().alias("sum_blue"),
    )

    # Uncomment to inspect the streaming plan:
    # print(totals.explain(streaming=True))

    return totals.collect(streaming = True)

In [11]:
# stream_query_4()

#### Calculate ZPC

In [28]:
def stream_query_5():
    """Summarise G magnitude and the derived ``g_ZP`` value per ``healpix``.

    For every photometry row the function derives

        g_ZP = phot_g_mean_mag + 2.5 * log10(phot_g_mean_flux)

    and then, per ``healpix``, takes the minimum and maximum of
    ``phot_g_mean_mag`` and the mean of ``g_ZP``. The result is sorted by
    ``healpix`` and collected with ``streaming = True``.

    Returns
    -------
    polars.DataFrame
        Columns ``healpix``, ``min_phot_g_mean_mag``,
        ``max_phot_g_mean_mag`` and ``mean_g_ZP``.
    """
    g_mag = pl.col("phot_g_mean_mag")
    g_flux = pl.col("phot_g_mean_flux")

    # Magnitude plus 2.5 times the base-10 logarithm of the flux.
    zero_point = (g_mag + (2.5 * g_flux.log(base = 10))).alias("g_ZP")

    per_healpix = (
        pl.scan_parquet(get_files("photometry"))
        .select("healpix", "phot_g_mean_mag", "phot_g_mean_flux")
        .with_columns(zero_point)
        .group_by("healpix")
        .agg(
            g_mag.min().alias("min_phot_g_mean_mag"),
            g_mag.max().alias("max_phot_g_mean_mag"),
            pl.col("g_ZP").mean().alias("mean_g_ZP"),
        )
        .sort("healpix")
    )

    # Uncomment to inspect the streaming plan:
    # print(per_healpix.explain(streaming = True))

    return per_healpix.collect(streaming = True)

In [29]:
# Run the per-healpix magnitude / g_ZP query; the returned frame is the cell output.
stream_query_5()

healpix,min_phot_g_mean_mag,max_phot_g_mean_mag,mean_g_ZP
i32,f32,f32,f32
0,6.855546,19.842781,25.687496
1,7.355467,19.706682,25.687496
2,7.312822,19.992142,25.68749
3,6.964732,19.499153,25.687496
4,3.382374,20.126875,25.687496
…,…,…,…
2685,6.030447,20.092436,25.687498
2686,4.124845,20.017384,25.687498
2687,4.425611,20.595608,25.687496
2688,4.03991,19.728607,25.687498
