Skip to content

Commit

Permalink
Drop string-bytes comparisons
Browse files Browse the repository at this point in the history
They were mostly internal ones, coming from dicts containing strings and
bytes as keys.

Close #147
  • Loading branch information
dvarrazzo committed Nov 13, 2021
1 parent 6e6aca0 commit b92f696
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 29 deletions.
8 changes: 3 additions & 5 deletions psycopg/psycopg/_encodings.py
Expand Up @@ -60,7 +60,6 @@
}

py_codecs: Dict[Union[bytes, str], str] = {}
py_codecs.update((k, v) for k, v in _py_codecs.items())
py_codecs.update((k.encode(), v) for k, v in _py_codecs.items())

pg_codecs = {v: k.encode() for k, v in _py_codecs.items()}
Expand All @@ -79,7 +78,7 @@ def py2pgenc(name: str) -> bytes:
return pg_codecs[codecs.lookup(name).name]


def pg2pyenc(name: Union[bytes, str]) -> str:
def pg2pyenc(name: bytes) -> str:
"""Convert a Python encoding name to PostgreSQL encoding name.
Raise NotSupportedError if the PostgreSQL encoding is not supported by
Expand All @@ -88,6 +87,5 @@ def pg2pyenc(name: Union[bytes, str]) -> str:
try:
return py_codecs[name]
except KeyError:
if isinstance(name, bytes):
name = name.decode("utf8", "replace")
raise NotSupportedError(f"codec not available in Python: {name!r}")
sname = name.decode("utf8", "replace")
raise NotSupportedError(f"codec not available in Python: {sname!r}")
25 changes: 9 additions & 16 deletions psycopg/psycopg/_queries.py
Expand Up @@ -63,20 +63,22 @@ def convert(self, query: Query, vars: Optional[Params]) -> None:
The results of this function can be obtained accessing the object
attributes (`query`, `params`, `types`, `formats`).
"""
if isinstance(query, Composable):
query = query.as_bytes(self._tx)
if isinstance(query, str):
bquery = query.encode(self._encoding)
elif isinstance(query, Composable):
bquery = query.as_bytes(self._tx)
else:
bquery = query

if vars is not None:
(
self.query,
self._want_formats,
self._order,
self._parts,
) = _query2pg(query, self._encoding)
) = _query2pg(bquery, self._encoding)
else:
if isinstance(query, str):
query = query.encode(self._encoding)
self.query = query
self.query = bquery
self._want_formats = self._order = None

self.dump(vars)
Expand All @@ -103,7 +105,7 @@ def dump(self, vars: Optional[Params]) -> None:

@lru_cache()
def _query2pg(
query: Union[bytes, str], encoding: str
query: bytes, encoding: str
) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]:
"""
Convert Python query and params into something Postgres understands.
Expand All @@ -115,15 +117,6 @@ def _query2pg(
(sequence of names used in the query, in the position they appear)
``parts`` (splits of queries and placeholders).
"""
if isinstance(query, str):
query = query.encode(encoding)
if not isinstance(query, bytes):
# encoding from str already happened
raise TypeError(
f"the query should be str or bytes,"
f" got {type(query).__name__} instead"
)

parts = _split_query(query, encoding)
order: Optional[List[str]] = None
chunks: List[bytes] = []
Expand Down
2 changes: 1 addition & 1 deletion psycopg/psycopg/types/range.py
Expand Up @@ -455,7 +455,7 @@ def load_range_text(
if item is None:
item = m.group(4)
if item is not None:
upper = load(_re_undouble.sub(r"\1", item))
upper = load(_re_undouble.sub(rb"\1", item))
else:
upper = load(item)

Expand Down
6 changes: 1 addition & 5 deletions tests/pq/test_pgconn.py
Expand Up @@ -352,11 +352,7 @@ def test_used_password(pgconn, dsn, monkeypatch):
or [i for i in info if i.keyword == b"password"][0].val is not None
)
if has_password:
# The assumption that the password is needed is broken on the Travis
# PG 10 setup so let's skip that
print("\n".join(map(str, sorted(os.environ.items()))))
if not (os.environ.get("TRAVIS") and os.environ.get("PGVER") == "10"):
assert pgconn.used_password
assert pgconn.used_password

pgconn.finish()
pgconn.used_password
Expand Down
2 changes: 1 addition & 1 deletion tests/test_conninfo.py
Expand Up @@ -251,7 +251,7 @@ def test_timezone_warn(self, conn, caplog):

def test_encoding(self, conn):
enc = conn.execute("show client_encoding").fetchone()[0]
assert conn.info.encoding == pg2pyenc(enc)
assert conn.info.encoding == pg2pyenc(enc.encode())

@pytest.mark.parametrize(
"enc, out, codec",
Expand Down
1 change: 0 additions & 1 deletion tests/types/test_composite.py
Expand Up @@ -79,7 +79,6 @@ def test_dump_builtin_empty_range(conn, fmt_in):
f"select pg_typeof(%{fmt_in})",
[info.python_type(10, Range(empty=True), [])],
)
print(cur._query.params[0])
assert cur.fetchone()[0] == "tmptype"


Expand Down

0 comments on commit b92f696

Please sign in to comment.