Skip to content

Commit

Permalink
Use psycopg(3) for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lpsinger committed Apr 16, 2024
1 parent b0e7312 commit cc42e5a
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 111 deletions.
4 changes: 3 additions & 1 deletion conftest.py
Expand Up @@ -7,7 +7,9 @@
@pytest.fixture
def engine(postgresql):
"""Create an SQLAlchemy engine with a disposable PostgreSQL database."""
return sa.create_engine('postgresql://', poolclass=sa.pool.StaticPool,
return sa.create_engine('postgresql+psycopg://',
poolclass=sa.pool.StaticPool,
pool_reset_on_return=None,
creator=lambda: postgresql)


Expand Down
4 changes: 3 additions & 1 deletion healpix_alchemy/tests/benchmarks/conftest.py
Expand Up @@ -8,7 +8,9 @@

@pytest.fixture
def engine(postgresql):
return sa.create_engine('postgresql://', poolclass=sa.pool.StaticPool,
return sa.create_engine('postgresql+psycopg://',
poolclass=sa.pool.StaticPool,
pool_reset_on_return=None,
creator=lambda: postgresql)


Expand Down
40 changes: 19 additions & 21 deletions healpix_alchemy/tests/benchmarks/data.py
Expand Up @@ -2,11 +2,9 @@
Notes
-----
We use the psycopg2 ``copy_from`` rather than SQLAlchemy for fast insertion.
We use the psycopg ``copy`` rather than SQLAlchemy for fast insertion.
"""
import io

from astropy.coordinates import SkyCoord, uniform_spherical_random_surface
from astropy import units as u
from mocpy import MOC
Expand Down Expand Up @@ -76,8 +74,8 @@ def get_random_galaxies(n, cursor):
points = SkyCoord(get_random_points(n, RANDOM_GALAXIES_SEED))
hpx = HPX.skycoord_to_healpix(points)

f = io.StringIO('\n'.join(f'{i}' for i in hpx))
cursor.copy_from(f, Galaxy.__tablename__, columns=('hpx',))
with cursor.copy(f'COPY {Galaxy.__tablename__} (hpx) FROM STDIN') as copy:
copy.write('\n'.join(f'{i}' for i in hpx))

return points

Expand All @@ -87,16 +85,16 @@ def get_random_fields(n, cursor):
footprints = get_footprints_grid(*get_ztf_footprint_corners(), centers)
mocs = [MOC.from_polygon_skycoord(footprint) for footprint in footprints]

f = io.StringIO('\n'.join(f'{i}' for i in range(len(mocs))))
cursor.copy_from(f, Field.__tablename__)
with cursor.copy(f'COPY {Field.__tablename__} FROM STDIN') as copy:
copy.write('\n'.join(f'{i}' for i in range(len(mocs))))

f = io.StringIO(
'\n'.join(
f'{i}\t{hpx}'
for i, moc in enumerate(mocs) for hpx in Tile.tiles_from(moc)
with cursor.copy(f'COPY {FieldTile.__tablename__} FROM STDIN') as copy:
copy.write(
'\n'.join(
f'{i}\t{hpx}'
for i, moc in enumerate(mocs) for hpx in Tile.tiles_from(moc)
)
)
)
cursor.copy_from(f, FieldTile.__tablename__)

return mocs

Expand All @@ -120,15 +118,15 @@ def get_random_sky_map(n, cursor):
probdensity = rng.uniform(0, 1, size=len(tiles) - 1)
probdensity /= np.sum(np.diff(tiles) * probdensity) * PIXEL_AREA

f = io.StringIO('1')
cursor.copy_from(f, Skymap.__tablename__)
with cursor.copy(f'COPY {Skymap.__tablename__} FROM STDIN') as copy:
copy.write('1')

f = io.StringIO(
'\n'.join(
f'1\t[{lo},{hi})\t{p}'
for lo, hi, p in zip(tiles[:-1], tiles[1:], probdensity)
with cursor.copy(f'COPY {SkymapTile.__tablename__} FROM STDIN') as copy:
copy.write(
'\n'.join(
f'1\t[{lo},{hi})\t{p}'
for lo, hi, p in zip(tiles[:-1], tiles[1:], probdensity)
)
)
)
cursor.copy_from(f, SkymapTile.__tablename__)

return tiles, probdensity

0 comments on commit cc42e5a

Please sign in to comment.