Skip to content

Commit

Permalink
Add type annotation and black/isort checks
Browse files Browse the repository at this point in the history
Adds type annotation to the library including mypy sanity checks. Also
adds black and isort as the builtin sanity tests.

The MANIFEST.in has been updated to also include other files that should
be present in the sdist for people repackaging the library.
  • Loading branch information
jborean93 committed Jun 7, 2024
1 parent 0256cec commit 5f71c99
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 24 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ jobs:
fail-fast: false
matrix:
include:
- python-version: '3.7'
- python-version: '3.8'
- python-version: '3.9'
- python-version: '3.10'
Expand All @@ -47,6 +46,10 @@ jobs:
python -m pip install .
python -m pip install -r requirements.txt
python -m black . --check
python -m isort . --check-only
python -m mypy .
python -m tests.test_server &
python -m pytest \
Expand Down
5 changes: 5 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
include CONTRIBUTORS.rst
include LICENSE
include requirements.txt

recursive-include tests *
recursive-exclude tests *.pyc
34 changes: 34 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,37 @@ requires = [
"setuptools >= 42.0.0", # Supports license_files
]
build-backend = "setuptools.build_meta"

[tool.black]
exclude = '''
/(
\.git
| \.venv
| build
| dist
)/
'''

[tool.isort]
profile = "black"

[tool.mypy]
exclude = "setup\\.py|build/|tests/"
mypy_path = "$MYPY_CONFIG_FILE_DIR"
python_version = "3.8"
show_error_codes = true
show_column_numbers = true
disallow_any_unimported = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
no_implicit_reexport = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
warn_unreachable = true

[[tool.mypy.overrides]]
module = "requests.packages.urllib3.*"
ignore_missing_imports = true
Empty file added requests_ntlm/py.typed
Empty file.
52 changes: 31 additions & 21 deletions requests_ntlm/requests_ntlm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from __future__ import annotations

import warnings
import base64
import typing as t

from urllib.parse import urlparse

import requests
import spnego

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.exceptions import UnsupportedAlgorithm
from requests.auth import AuthBase
from requests.packages.urllib3.response import HTTPResponse
import spnego


class ShimSessionSecurity:
Expand All @@ -19,7 +23,7 @@ class ShimSessionSecurity:
def __init__(self, context: spnego.ContextProxy) -> None:
self._context = context

def wrap(self, message) -> t.Tuple[bytes, bytes]:
def wrap(self, message: bytes) -> tuple[bytes, bytes]:
wrap_res = self._context.wrap(message, encrypt=True)
signature = wrap_res.data[:16]
data = wrap_res.data[16:]
Expand All @@ -44,7 +48,13 @@ class HttpNtlmAuth(AuthBase):
Supports pass-the-hash.
"""

def __init__(self, username, password, session=None, send_cbt=True):
def __init__(
self,
username: str | None,
password: str | None,
session: None = None,
send_cbt: bool = True,
) -> None:
"""Create an authentication handler for NTLM over HTTP.
:param str username: Username in 'domain\\username' format
Expand All @@ -59,16 +69,16 @@ def __init__(self, username, password, session=None, send_cbt=True):
# This exposes the encrypt/decrypt methods used to encrypt and decrypt messages
# sent after ntlm authentication. These methods are utilised by libraries that
# call requests_ntlm to encrypt and decrypt the messages sent after authentication
self.session_security = None
self.session_security: ShimSessionSecurity | None = None

def retry_using_http_NTLM_auth(
self,
auth_header_field,
auth_header,
response,
auth_type,
args,
):
auth_header_field: str,
auth_header: str,
response: requests.Response,
auth_type: str,
args: t.Any,
) -> requests.Response:
# Get the certificate of the server if using HTTPS for CBT
server_certificate_hash = self._get_server_cert(response)
cbt = None
Expand All @@ -84,7 +94,7 @@ def retry_using_http_NTLM_auth(
content_length = int(
response.request.headers.get("Content-Length", "0"), base=10
)
if hasattr(response.request.body, "seek"):
if response.request.body and hasattr(response.request.body, "seek"):
if content_length > 0:
response.request.body.seek(-content_length, 1)
else:
Expand All @@ -96,7 +106,7 @@ def retry_using_http_NTLM_auth(
response.raw.release_conn()
request = response.request.copy()

target_hostname = urlparse(response.url).hostname
target_hostname = t.cast(str, urlparse(response.url).hostname)
spnego_options = spnego.NegotiateOptions.none
if self.username and self.password:
# If a username and password are specified force spnego to use the
Expand All @@ -117,7 +127,7 @@ def retry_using_http_NTLM_auth(
options=spnego_options,
)
# Perform the first step of the NTLM authentication
negotiate_message = base64.b64encode(client.step()).decode()
negotiate_message = base64.b64encode(client.step() or b"").decode()
auth = "%s %s" % (auth_type, negotiate_message)

request.headers[auth_header] = auth
Expand All @@ -129,7 +139,7 @@ def retry_using_http_NTLM_auth(
# challenge and not the real content, so the content will be short
# anyway.
args_nostream = dict(args, stream=False)
response2 = response.connection.send(request, **args_nostream)
response2 = response.connection.send(request, **args_nostream) # type: ignore[attr-defined]

# needed to make NTLM auth compatible with requests-2.3.0

Expand Down Expand Up @@ -165,7 +175,7 @@ def retry_using_http_NTLM_auth(
# Parse the challenge in the ntlm context and perform
# the second step of authentication
val = base64.b64decode(ntlm_header_value[len(auth_strip) :].encode())
authenticate_message = base64.b64encode(client.step(val)).decode()
authenticate_message = base64.b64encode(client.step(val) or b"").decode()

auth = "%s %s" % (auth_type, authenticate_message)
request.headers[auth_header] = auth
Expand All @@ -179,7 +189,7 @@ def retry_using_http_NTLM_auth(

return response3

def response_hook(self, r, **kwargs):
def response_hook(self, r: requests.Response, **kwargs: t.Any) -> requests.Response:
"""The actual hook handler."""
if r.status_code == 401:
# Handle server auth.
Expand Down Expand Up @@ -209,7 +219,7 @@ def response_hook(self, r, **kwargs):

return r

def _get_server_cert(self, response):
def _get_server_cert(self, response: requests.Response) -> bytes | None:
"""
Get the certificate at the request_url and return it as a hash. Will get the raw socket from the
original response from the server. This socket is then checked if it is an SSL socket and then used to
Expand Down Expand Up @@ -240,7 +250,7 @@ def _get_server_cert(self, response):

return None

def __call__(self, r):
def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
# we must keep the connection because NTLM authenticates the
# connection, not single requests
r.headers["Connection"] = "Keep-Alive"
Expand All @@ -249,7 +259,7 @@ def __call__(self, r):
return r


def _auth_type_from_header(header):
def _auth_type_from_header(header: str) -> str | None:
"""
Given a WWW-Authenticate or Proxy-Authenticate header, returns the
authentication type to use. We prefer NTLM over Negotiate if the server
Expand All @@ -263,7 +273,7 @@ def _auth_type_from_header(header):
return None


def _get_certificate_hash(certificate_der):
def _get_certificate_hash(certificate_der: bytes) -> bytes | None:
# https://tools.ietf.org/html/rfc5929#section-4.1
cert = x509.load_der_x509_certificate(certificate_der, default_backend())

Expand All @@ -279,7 +289,7 @@ def _get_certificate_hash(certificate_der):

# if the cert signature algorithm is either md5 or sha1 then use sha256
# otherwise use the signature algorithm
if hash_algorithm.name in ["md5", "sha1"]:
if not hash_algorithm or hash_algorithm.name in ["md5", "sha1"]:
digest = hashes.Hash(hashes.SHA256(), default_backend())
else:
digest = hashes.Hash(hash_algorithm, default_backend())
Expand Down
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
black == 24.4.2
isort == 5.13.2
requests>=2.0.0
pyspnego
cryptography>=1.3
flask
mypy == 1.10.0
pytest
pytest-cov
types-requests
wheel
6 changes: 4 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@ classifiers =
Intended Audience :: Developers
Programming Language :: Python
Programming Language :: Python :: 3
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
License :: OSI Approved :: ISC License (ISCL)

[options]
python_requires = >= 3.7
python_requires = >= 3.8
install_requires =
cryptography >= 1.3
pyspnego >= 0.4.0
requests >= 2.0.0

[tools.setuptools.package-data]
"requests_ntlm" = ["py.typed"]

0 comments on commit 5f71c99

Please sign in to comment.