# Query Vast DB

## Vast DB endpoint

In [1]:
import os

# --- VAST DB connection settings ---------------------------------------------
# All values come from the environment. os.getenv returns None for any variable
# that is not set, so nothing fails in this cell if one is missing.
VASTDB_ENDPOINT, VASTDB_ACCESS_KEY, VASTDB_SECRET_KEY = (
    os.getenv(env_var)
    for env_var in ("VASTDB_ENDPOINT", "VASTDB_ACCESS_KEY", "VASTDB_SECRET_KEY")
)

# The NYT bucket (database) is reused for this demo for now.
VASTDB_NYT_BUCKET = os.getenv("VASTDB_NYT_BUCKET")

# Location of the particle table queried throughout this notebook.
bucket_name, schema_name, table_name = VASTDB_NYT_BUCKET, 'cosmology', 'particles'

## Python SDK Connection

In [2]:
# Source: https://vast-data.github.io/data-platform-field-docs/vast_database/ingestion/python_sdk_parquet_import.html

import io
import os
import pyarrow as pa
from pyarrow import csv as pa_csv
import pyarrow.parquet as pq
from io import StringIO
import numpy as np
import pandas as pd
import vastdb
from vastdb.config import QueryConfig

def connect_to_vastdb(endpoint, access_key, secret_key):
    """Open a VastDB session.

    Args:
        endpoint: Address of the VastDB endpoint.
        access_key: Access key used to authenticate.
        secret_key: Secret key used to authenticate.

    Returns:
        The session object produced by ``vastdb.connect``.

    Raises:
        RuntimeError: Wrapping (and chained from) any exception raised while
            connecting.
    """
    try:
        vast_session = vastdb.connect(
            endpoint=endpoint,
            access=access_key,
            secret=secret_key,
        )
        print("Connected to VastDB")
        return vast_session
    except Exception as err:
        message = f"Failed to connect to VastDB: {err}"
        raise RuntimeError(message) from err

def query_vastdb(session, bucket_name, schema_name, table_name, limit=None):
    """Read rows from an existing VastDB table into a pandas DataFrame.

    This function only reads. It never creates schemas or tables, because an
    empty table has no meaningful schema to create it with.

    Args:
        session: Session returned by ``connect_to_vastdb``.
        bucket_name: Bucket (database) that holds the schema.
        schema_name: Schema that holds the table.
        table_name: Table to read.
        limit: If truthy, return at most this many rows. A single split and
            sub-split are scanned, and only the first batch is kept. If
            ``None`` (or ``0``), the whole table is read.

    Returns:
        pandas.DataFrame with the selected rows. If ``limit`` is set and the
        table is empty, the frame has zero rows but still carries the columns.

    Raises:
        The SDK's error for a missing bucket, schema or table. The lookups use
        their default ``fail_if_missing`` behaviour.
    """
    with session.transaction() as tx:
        table = tx.bucket(bucket_name).schema(schema_name).table(table_name)

        if not limit:
            return table.select().read_all().to_pandas()

        # See: https://vast-data.github.io/data-platform-field-docs/vast_database/sdk_ref/limit_n.html
        config = QueryConfig(
            num_splits=1,                     # Scan with a single split
            num_sub_splits=1,                 # ...divided into a single sub-split
            limit_rows_per_sub_split=limit,   # Sub-split processes `limit` rows at a time
        )
        batches = table.select(config=config)
        first_batch = next(batches, None)
        if first_batch is None:
            # Empty table: a bare next() would leak StopIteration to the caller.
            return batches.schema.empty_table().to_pandas()
        # Slice defensively so callers never receive more than `limit` rows.
        return first_batch.slice(0, limit).to_pandas()

In [3]:
# Open a single session; every query below reuses it.
session = connect_to_vastdb(
    endpoint=VASTDB_ENDPOINT,
    access_key=VASTDB_ACCESS_KEY,
    secret_key=VASTDB_SECRET_KEY,
)

Connected to VastDB


## Inspect a few records

In [4]:
import time

# Fetch a handful of rows to inspect the table layout. perf_counter is a
# monotonic, high-resolution clock meant for measuring elapsed time.
start_time = time.perf_counter()
df = query_vastdb(session, bucket_name, schema_name, table_name, limit=5)
end_time = time.perf_counter()

print(f"Query execution time: {end_time - start_time} seconds")

Query execution time: 0.08570432662963867 seconds


In [5]:
# Rich display of the sample returned by the limit=5 query above.
df

Unnamed: 0,Coordinates,Velocity,Mass
0,"[0.06941523640696762, 0.05991952653932303, 0.0...","[-89.80153, 146.60365, 17.358175]",0.000123
1,"[0.016258423139909595, 0.13074506414506318, 0....","[-76.61928, 137.16403, 27.057497]",0.000123
2,"[0.05543383231113987, 0.12075238123551599, 0.0...","[-71.89609, 142.62872, 25.440746]",0.000123
3,"[0.09699105193528355, 0.1417782287453339, 0.05...","[-76.60565, 151.27818, 13.464693]",0.000123
4,"[0.045465816091822514, 0.16487884582547954, 0....","[-64.08211, 141.82565, 43.24954]",0.000123


### Compute the Total Mass (PartType0)

Only PartType0 (gas particles) has been loaded into the DB, so this total covers gas particles only.

In [6]:
import time

import duckdb

start_time = time.perf_counter()

# Sum of the Mass column over every row in the table. Only PartType0 has
# been loaded, so this is the gas mass. Only the Mass column is requested from
# VastDB. DuckDB reads the `batches` Arrow reader by its Python variable name.
# The `with` block closes the DuckDB connection once the query is done.
with duckdb.connect() as conn, session.transaction() as tx:
    table = tx.bucket(bucket_name).schema(schema_name).table(table_name)
    batches = table.select(columns=['Mass'])
    print(conn.execute(
        """
        SELECT SUM(Mass) FROM batches
        """
    ).arrow())

end_time = time.perf_counter()
print(f"Query execution time: {end_time - start_time} seconds")

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

pyarrow.Table
sum(Mass): double
----
sum(Mass): [[89401.89030224655]]
Query execution time: 11.10551118850708 seconds


### Average Velocity of Gas Particles (PartType0) - DuckDB

In [7]:
import time

import duckdb

start_time = time.perf_counter()

# Per-component mean velocity over the whole table. DuckDB reads the `batches`
# Arrow reader by its Python variable name. DuckDB list indexing is 1-based, so
# Velocity[1], [2], [3] are the first three components, labelled X, Y, Z.
# The `with` block closes the DuckDB connection once the query is done.
with duckdb.connect() as conn, session.transaction() as tx:
    table = tx.bucket(bucket_name).schema(schema_name).table(table_name)
    batches = table.select(columns=['Velocity'])
    print(conn.execute(
        """
        SELECT
           AVG(Velocity[1]) AS AvgVelocity_X,
           AVG(Velocity[2]) AS AvgVelocity_Y,
           AVG(Velocity[3]) AS AvgVelocity_Z
        FROM batches
        """
    ).arrow())

end_time = time.perf_counter()
print(f"Query execution time: {end_time - start_time} seconds")

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

pyarrow.Table
AvgVelocity_X: double
AvgVelocity_Y: double
AvgVelocity_Z: double
----
AvgVelocity_X: [[12.279645131906921]]
AvgVelocity_Y: [[9.441816902890864]]
AvgVelocity_Z: [[123.35455341785374]]
Query execution time: 86.42967963218689 seconds


### Average Velocity of Gas Particles (PartType0) - Python

Options for accelerating this scan:

- Distributed processing, e.g. Spark, Trino, Dask
- The VAST DB query engine (release 5.3+)

In [8]:
import sys
import time

import numpy as np
import pyarrow as pa

# Same aggregation as the DuckDB cell, computed batch by batch in NumPy so a
# running result can be printed while the scan progresses.
start_time = time.perf_counter()

total_count = 0   # Number of velocity vectors seen so far
batch_count = 0   # Number of record batches processed

# Running sums. Each batch is summed with dtype=float64 so the accumulators
# stay float64 even if the Velocity values are float32 (their printed precision
# suggests so -- TODO confirm).
sum_velocities_x = 0.0
sum_velocities_y = 0.0
sum_velocities_z = 0.0

avg_velocities_x = 0.0
avg_velocities_y = 0.0
avg_velocities_z = 0.0

with session.transaction() as tx:
    table = tx.bucket(bucket_name).schema(schema_name).table(table_name)
    batches = table.select(columns=['Velocity'])

    for batch in batches:
        batch_count += 1

        velocities = batch.column("Velocity")
        count = len(velocities)
        if count == 0:
            continue  # Nothing to add, and reshape below needs at least one row

        # Flatten the list column into one contiguous value buffer and view it
        # as a (rows, components) matrix. This is vectorized, with no per-row
        # Python objects. Assumes every row has the same number of components
        # (X, Y, Z).
        components = velocities.flatten().to_numpy(zero_copy_only=False)
        vectors = components.reshape(count, -1)

        batch_sums = vectors.sum(axis=0, dtype=np.float64)
        sum_velocities_x += batch_sums[0]
        sum_velocities_y += batch_sums[1]
        sum_velocities_z += batch_sums[2]
        total_count += count

        # Running averages over everything seen so far, for the progress line.
        avg_velocities_x = sum_velocities_x / total_count
        avg_velocities_y = sum_velocities_y / total_count
        avg_velocities_z = sum_velocities_z / total_count

        sys.stdout.write(f"\rBatches processed: {batch_count}. "
                         f"Average Velocities: "
                         f"X = {avg_velocities_x:.5f}, "
                         f"Y = {avg_velocities_y:.5f}, "
                         f"Z = {avg_velocities_z:.5f}"
                        )
        sys.stdout.flush()

end_time = time.perf_counter()
print(f"\nQuery execution time: {end_time - start_time} seconds")

Batches processed: 5484. Average Velocities: X = 12.27965, Y = 9.44182, Z = 123.3545552
Query execution time: 582.431357383728 seconds


### Predicate pushdown example

In [9]:
import time

from ibis import _

# The ibis predicate is passed to table.select, so VAST DB applies the filter
# (predicate pushdown) rather than pandas filtering after a full read.
MASS_THRESHOLD = 0.001
PREDICATE = (_.Mass > MASS_THRESHOLD)
COLUMNS = ['Coordinates', 'Velocity', 'Mass']

start_time = time.perf_counter()

with session.transaction() as tx:
    table = tx.bucket(bucket_name).schema(schema_name).table(table_name)
    batches = table.select(columns=COLUMNS, predicate=PREDICATE)
    df = batches.read_all().to_pandas()

end_time = time.perf_counter()
print(f"Query execution time: {end_time - start_time} seconds")

Query execution time: 2.82857346534729 seconds


In [10]:
# Rows with Mass > 0.001 returned by the pushed-down predicate query above.
df

Unnamed: 0,Coordinates,Velocity,Mass
0,"[0.5366043257165535, 2.9507586855666257, 0.060...","[-62.40552, 82.05905, 79.45422]",0.002860
1,"[1.4083577130599776, 3.8619471494089255, 0.659...","[-52.00628, 111.871796, 72.877495]",0.001037
2,"[1.896930915327498, 3.6891011985705258, 0.1617...","[40.638702, 177.69398, 125.804886]",0.042142
3,"[4.174399015605908, 51.07290815200347, 31.7342...","[-41.858036, -98.15191, 315.02325]",0.001285
4,"[3.949060406567913, 51.522647710560065, 31.236...","[41.77037, 40.92835, 526.1117]",0.002155
...,...,...,...
8036,"[2.994530653978653, 51.043874127236656, 31.200...","[-167.65419, 15.782176, 420.5042]",0.001188
8037,"[2.7969713901060724, 51.47328076505169, 31.570...","[-296.2971, 398.66193, -148.71107]",0.001089
8038,"[2.7784863031871385, 51.13566860463838, 31.394...","[105.399414, -230.19371, 216.46248]",0.004032
8039,"[2.7664752607531082, 51.18302120799442, 31.475...","[263.1287, -70.09283, -423.9448]",0.001118
