# Import HDF5 file into Vast DB

## Source Data

In [1]:
SOURCE_URL = "https://virgodb.cosma.dur.ac.uk/public/agb/snapshot_028_z000p000/"

## Vast DB destination

In [2]:
import os

VASTDB_ENDPOINT = os.getenv("VASTDB_ENDPOINT")
VASTDB_ACCESS_KEY = os.getenv("VASTDB_ACCESS_KEY")
VASTDB_SECRET_KEY = os.getenv("VASTDB_SECRET_KEY")

# Use NYT BUCKET (DB) for now
VASTDB_NYT_BUCKET=os.getenv("VASTDB_NYT_BUCKET")
schema_name = 'cosmology'
table_name = 'particles'

url = "https://virgodb.cosma.dur.ac.uk/public/agb/snapshot_028_z000p000/"

## Vast DB utility functions

In [3]:
# Source: https://vast-data.github.io/data-platform-field-docs/vast_database/ingestion/python_sdk_parquet_import.html

import io
import os
import pyarrow as pa
from pyarrow import csv as pa_csv
import pyarrow.parquet as pq
from io import StringIO
import numpy as np
import pandas as pd
import vastdb
from vastdb.config import QueryConfig

def read_parquet(file_path):
    """Reads Parquet data from a file."""
    try:
        return pq.read_table(file_path)
    except Exception as e:
        raise RuntimeError(f"Error reading Parquet file: {e}") from e

def connect_to_vastdb(endpoint, access_key, secret_key):
    """Connects to VastDB."""
    try:
        session = vastdb.connect(endpoint=endpoint, access=access_key, secret=secret_key)
        print("Connected to VastDB")
        return session
    except Exception as e:
        raise RuntimeError(f"Failed to connect to VastDB: {e}") from e

def write_to_vastdb(session, bucket_name, schema_name, table_name, pa_table):
    """Writes data to VastDB."""
    with session.transaction() as tx:
        bucket = tx.bucket(bucket_name)
        schema = bucket.schema(schema_name, fail_if_missing=False) or bucket.create_schema(schema_name)

        table = schema.table(table_name, fail_if_missing=False) or schema.create_table(table_name, pa_table.schema)

        columns_to_add = get_columns_to_add(table.arrow_schema, pa_table.schema)
        for column in columns_to_add:
            table.add_column(column)

        table.insert(pa_table)

def get_columns_to_add(existing_schema, desired_schema):
    """Identifies columns to add to an existing schema."""
    existing_fields = set(existing_schema.names)
    desired_fields = set(desired_schema.names)
    return [pa.schema([pa.field(name, desired_schema.field(name).type)]) for name in desired_fields - existing_fields]


def query_vastdb(session, bucket_name, schema_name, table_name, limit=None):
    """Writes data to VastDB."""
    with session.transaction() as tx:
        bucket = tx.bucket(bucket_name)
        schema = bucket.schema(schema_name, fail_if_missing=False) or bucket.create_schema(schema_name)
        table = schema.table(table_name, fail_if_missing=False) or schema.create_table(table_name, pa_table.schema)

        if limit:
            # See: https://vast-data.github.io/data-platform-field-docs/vast_database/sdk_ref/limit_n.html
            config = QueryConfig(
                num_splits=1,                	  # Manually specify 1 split
                num_sub_splits=1,                 # Each split will be divided into 1 subsplits
                limit_rows_per_sub_split=limit,   # Each subsplit will process 10 rows at a time
            )
            batches = table.select(config=config)
            first_batch = next(batches)
            return first_batch.to_pandas()
        else:
            return table.select().read_all().to_pandas()

def drop_vastdb_table(session, bucket_name, schema_name, table_name):
    """Writes data to VastDB."""
    with session.transaction() as tx:
        bucket = tx.bucket(bucket_name)
        schema = bucket.schema(schema_name, fail_if_missing=False) or bucket.create_schema(schema_name)

        table = schema.table(table_name, fail_if_missing=False)
        if table:
            table.drop()


In [4]:
session = connect_to_vastdb(VASTDB_ENDPOINT, VASTDB_ACCESS_KEY, VASTDB_SECRET_KEY)

Connected to VastDB


## List files

In [5]:
import requests
from bs4 import BeautifulSoup

def get_file_list(source_url):
    # Send a GET request to fetch the HTML content of the directory
    response = requests.get(source_url)
    
    # Check if the request was successful
    if response.status_code == 200:
        # Parse the HTML content using BeautifulSoup
        soup = BeautifulSoup(response.text, 'html.parser')
        
        # Find all <a> tags (which typically contain links)
        files = []
        for link in soup.find_all('a'):
            href = link.get('href')
            if href and href.endswith('.hdf5'):
                files.append(f"{source_url}/{href}")
        
        return files
    else:
        print(f"Failed to access {url}. HTTP Status Code: {response.status_code}")


In [6]:
# check the first few urls
get_file_list(SOURCE_URL)[1:5]

['https://virgodb.cosma.dur.ac.uk/public/agb/snapshot_028_z000p000//snap_028_z000p000.1.hdf5',
 'https://virgodb.cosma.dur.ac.uk/public/agb/snapshot_028_z000p000//snap_028_z000p000.2.hdf5',
 'https://virgodb.cosma.dur.ac.uk/public/agb/snapshot_028_z000p000//snap_028_z000p000.3.hdf5',
 'https://virgodb.cosma.dur.ac.uk/public/agb/snapshot_028_z000p000//snap_028_z000p000.4.hdf5']

## Download and Import Data

In [7]:
# Ensure we don't have old data in the table
drop_vastdb_table(session, VASTDB_NYT_BUCKET, schema_name, table_name)

In [8]:
from tqdm import tqdm
import os
import requests

def download_file(file_url, save_dir):
    """
    Downloads a file from the given URL and saves it to the specified directory with a progress bar.
    
    Args:
        file_url (str): URL of the file to download.
        save_dir (str): Directory where the file will be saved.
    
    Returns:
        str: Local path to the downloaded file.
    """
    filename = file_url.split('/')[-1]
    local_path = os.path.join(save_dir, filename)

    response = requests.get(file_url, stream=True)
    response.raise_for_status()  # Raise HTTP errors, if any

    total_size = int(response.headers.get('content-length', 0))  # Total file size in bytes
    chunk_size = 8192  # Read in chunks of 8 KB

    with open(local_path, 'wb') as f, tqdm(
        desc=f"Downloading {filename}",
        total=total_size,
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
        mininterval=0.5,
    ) as progress:
        for chunk in response.iter_content(chunk_size=chunk_size):
            f.write(chunk)
            progress.update(len(chunk))

    return local_path


In [9]:
import h5py
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import os
from tqdm.notebook import tqdm

def import_table_particles(hdf5_file_path):

    with h5py.File(hdf5_file_path, 'r') as f:
        # Get dataset sizes
        coordinates_dataset = f['PartType0/Coordinates']
        velocities_dataset = f['PartType0/Velocity']
        mass_dataset = f['PartType0/Mass']
    
        chunk_size = 10000  # Number of rows to process at a time
        total_rows = coordinates_dataset.shape[0]
    
        # Create a progress bar using tqdm
        progress_bar = tqdm(total=total_rows, 
                            unit="rows", 
                            desc="Importing 'Particles' to Vast DB",
                            mininterval=0.5)
    
        for start in range(0, total_rows, chunk_size):
            end = min(start + chunk_size, total_rows)
    
            # Load a chunk of data
            coordinates_chunk = coordinates_dataset[start:end]
            velocities_chunk = velocities_dataset[start:end]
            mass_chunk = mass_dataset[start:end]
    
            # Convert to Arrow Table
            table = pa.Table.from_pandas(pd.DataFrame({
                'Coordinates': list(coordinates_chunk),
                'Velocity': list(velocities_chunk),
                'Mass': mass_chunk
            }))
    
            write_to_vastdb(session, VASTDB_NYT_BUCKET, schema_name, table_name, table)
            progress_bar.update(chunk_size)
            
        progress_bar.close()

In [None]:
import shutil

# Temporary directory to save downloaded files
temp_dir = "temp_files"

if os.path.exists(temp_dir):
    shutil.rmtree(temp_dir)

os.makedirs(temp_dir, exist_ok=False)

for file_url in get_file_list(SOURCE_URL):
    local_path = download_file(file_url, temp_dir)    
    import_table_particles(local_path)

    if os.path.exists(local_path):
        os.remove(local_path)

# Cleanup temporary directory
os.rmdir(temp_dir)

Downloading snap_028_z000p000.0.hdf5:   0%|          | 0.00/1.84G [00:00<?, ?B/s]

Importing 'Particles' to Vast DB:   0%|          | 0/12643607 [00:00<?, ?rows/s]

Downloading snap_028_z000p000.1.hdf5:   0%|          | 0.00/1.80G [00:00<?, ?B/s]

Importing 'Particles' to Vast DB:   0%|          | 0/12597522 [00:00<?, ?rows/s]

Downloading snap_028_z000p000.2.hdf5:   0%|          | 0.00/1.65G [00:00<?, ?B/s]

## Verify data in Vast DB

In [None]:
df = query_vastdb(session, VASTDB_NYT_BUCKET, schema_name, table_name, limit=5)
df