In [1]:
import os
from contextlib import contextmanager

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from sqlalchemy import create_engine
from sqlalchemy.engine import URL
from sqlalchemy.exc import SQLAlchemyError
from sshtunnel import SSHTunnelForwarder

In [2]:
load_dotenv(dotenv_path="misc/.env")

# --- SSH jump host ------------------------------------------------------------
# Values come from the process environment, which load_dotenv() above fills in
# from misc/.env. Unset text settings come back as None; the port falls back
# to 22 when SSH_PORT is not set.
SSH_HOST = os.environ.get("SSH_HOST")
SSH_PORT = int(os.environ.get("SSH_PORT", "22"))
SSH_USER = os.environ.get("SSH_USER")
SSH_PASSWORD = os.environ.get("SSH_PASSWORD")

# --- PostgreSQL database (address as seen from the SSH host) -----------------
# The port falls back to 5432 when DB_PORT is not set.
DB_HOST = os.environ.get("DB_HOST")
DB_PORT = int(os.environ.get("DB_PORT", "5432"))
DB_NAME = os.environ.get("DB_NAME")
DB_USER = os.environ.get("DB_USER")
DB_PASSWORD = os.environ.get("DB_PASSWORD")


@contextmanager
def __create_ssh_tunnel(local_bind_port=5432):
    """
    Opens an SSH tunnel that forwards a local port to the remote database.

    Connects to ``SSH_HOST:SSH_PORT`` with password authentication and
    forwards ``127.0.0.1:<local_bind_port>`` to ``DB_HOST:DB_PORT``, where
    ``DB_HOST`` is the remote bind address of the tunnel. The tunnel is
    stopped when the ``with`` block exits, whether normally or via an
    exception.

    Args:
        local_bind_port: Local port the tunnel listens on. Defaults to 5432,
            the previously hard-coded value. Pass 0 to let the OS choose a
            free port, e.g. when a local PostgreSQL already uses 5432. The
            port actually bound is available as ``tunnel.local_bind_port``.

    Yields:
        The started ``SSHTunnelForwarder``.
    """
    tunnel = SSHTunnelForwarder(
        (SSH_HOST, SSH_PORT),
        ssh_username=SSH_USER,
        ssh_password=SSH_PASSWORD,
        remote_bind_address=(DB_HOST, DB_PORT),
        local_bind_address=("127.0.0.1", local_bind_port),
    )
    try:
        tunnel.start()
        yield tunnel
    finally:
        # Runs on normal exit, on errors in the caller's with-body, and when
        # start() itself fails, so the tunnel is never left running.
        tunnel.stop()


@contextmanager
def create_db_engine():
    """
    Creates an SQLAlchemy engine connected through the SSH tunnel.

    Opens the tunnel via ``__create_ssh_tunnel``, builds a PostgreSQL engine
    that talks to the tunnel's local endpoint, and yields it for use in a
    ``with`` statement. On exit the engine's connection pool is disposed
    before the tunnel is stopped.

    Yields:
        sqlalchemy.engine.Engine: engine bound to the forwarded local port.
    """
    with __create_ssh_tunnel() as tunnel:
        # Connect to the tunnel's *local* listener. DB_HOST is the tunnel's
        # remote bind address (resolved on the SSH server), so it is not
        # necessarily reachable from this machine, and never on the local
        # forwarded port. URL.create also escapes special characters in the
        # credentials, which an f-string URL does not.
        url = URL.create(
            drivername="postgresql",
            username=DB_USER,
            password=DB_PASSWORD,
            host=tunnel.local_bind_host,
            port=tunnel.local_bind_port,
            database=DB_NAME,
        )
        engine = create_engine(
            url, echo=False, pool_pre_ping=True, pool_recycle=3600
        )
        try:
            yield engine
        finally:
            # Close pooled connections before the tunnel beneath them goes away.
            engine.dispose()


def create_crowd_levels(df, target_district, quantiles=(0, 0.3, 0.7, 1)):
    """
    Bins one district's crowdedness values into ordinal crowd levels.

    Values are first ranked with ``rank(method="first")`` (ties broken by
    order of appearance), then split at the given quantiles with ``pd.qcut``.
    With the default quantiles, roughly the lowest 30% of rows get level 0,
    the middle 40% level 1 and the top 30% level 2.

    Args:
        df: DataFrame containing a column named ``target_district``.
        target_district: Name of the column to bin.
        quantiles: Increasing quantile edges from 0 to 1. ``len(quantiles) - 1``
            levels are produced, labelled 0, 1, 2, ... For five levels use
            ``(0, 0.2, 0.45, 0.65, 0.8, 1)``.

    Returns:
        tuple: ``(levels, target_column)``.
            ``levels`` is a ``uint8`` Series aligned with ``df``.
            ``target_column`` is ``target_district`` with underscores replaced
            by spaces, followed by ``"_c_lvl"``.

    Note:
        Missing values are not handled. They rank as NaN, get a NaN level,
        and the final ``uint8`` cast cannot represent that.
    """
    target_column = f'{target_district.replace("_", " ")}_c_lvl'

    # Ranking makes every value distinct, so tied raw values cannot produce
    # duplicate bin edges in qcut.
    ranks = df[target_district].rank(method="first")
    levels = pd.qcut(
        ranks,
        q=list(quantiles),
        labels=list(range(len(quantiles) - 1)),
    )

    return levels.astype(np.uint8), target_column

In [4]:
# Load the full `crowdedness` table through the SSH tunnel.
# With `chunksize`, read_sql_query returns a lazy iterator over the query
# result, so the chunks are concatenated *inside* the connection block,
# before the connection (and then the tunnel) is closed.
with create_db_engine() as engine, engine.connect() as conn:
    query = "SELECT * FROM crowdedness"
    chunks = pd.read_sql_query(sql=query, con=conn, chunksize=10**6)
    df = pd.concat(chunks, ignore_index=True, axis=0)