Skip to content

Added type annotations for public API + flake8 fixes #627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions firebase_admin/__init__.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
import json
import os
import threading
from typing import Any, Callable, Dict, Optional

from firebase_admin import credentials
from firebase_admin.__about__ import __version__
@@ -31,7 +32,8 @@
_CONFIG_VALID_KEYS = ['databaseAuthVariableOverride', 'databaseURL', 'httpTimeout', 'projectId',
'storageBucket']

def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME):

def initialize_app(credential: Optional[credentials.Base] = None, options: Optional[Dict[str, Any]] = None, name: str = _DEFAULT_APP_NAME) -> "App":
"""Initializes and returns a new App instance.

Creates a new App instance using the specified options
@@ -83,7 +85,7 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME):
'you call initialize_app().').format(name))


def delete_app(app):
def delete_app(app: "App"):
"""Gracefully deletes an App instance.

Args:
@@ -98,7 +100,7 @@ def delete_app(app):
with _apps_lock:
if _apps.get(app.name) is app:
del _apps[app.name]
app._cleanup() # pylint: disable=protected-access
app._cleanup() # pylint: disable=protected-access
return
if app.name == _DEFAULT_APP_NAME:
raise ValueError(
@@ -111,7 +113,7 @@ def delete_app(app):
'second argument.').format(app.name))


def get_app(name=_DEFAULT_APP_NAME):
def get_app(name: str = _DEFAULT_APP_NAME) -> "App":
"""Retrieves an App instance by name.

Args:
@@ -190,7 +192,7 @@ class App:
common to all Firebase APIs.
"""

def __init__(self, name, credential, options):
def __init__(self, name: str, credential: credentials.Base, options: Optional[Dict[str, Any]]):
"""Constructs a new App using the provided name and options.

Args:
@@ -265,7 +267,7 @@ def _lookup_project_id(self):
App._validate_project_id(self._options.get('projectId'))
return project_id

def _get_service(self, name, initializer):
def _get_service(self, name: str, initializer: Callable):
"""Returns the service instance identified by the given name.

Services are functional entities exposed by the Admin SDK (e.g. auth, database). Each
8 changes: 4 additions & 4 deletions firebase_admin/_utils.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
"""Internal utilities common to all modules."""

import json
from typing import Callable, Optional

import google.auth
import requests
@@ -76,7 +77,7 @@
}


def _get_initialized_app(app):
def _get_initialized_app(app: Optional[firebase_admin.App]):
"""Returns a reference to an initialized App instance."""
if app is None:
return firebase_admin.get_app()
@@ -92,10 +93,9 @@ def _get_initialized_app(app):
' firebase_admin.App, but given "{0}".'.format(type(app)))



def get_app_service(app, name, initializer):
def get_app_service(app: Optional[firebase_admin.App], name: str, initializer: Callable):
app = _get_initialized_app(app)
return app._get_service(name, initializer) # pylint: disable=protected-access
return app._get_service(name, initializer) # pylint: disable=protected-access


def handle_platform_error_from_requests(error, handle_func=None):
9 changes: 6 additions & 3 deletions firebase_admin/credentials.py
Original file line number Diff line number Diff line change
@@ -16,11 +16,13 @@
import collections
import json
import pathlib
from typing import Any, Dict, Union

import google.auth
from google.auth.transport import requests
from google.oauth2 import credentials
from google.oauth2 import service_account
import google.auth.credentials


_request = requests.Request()
@@ -44,7 +46,7 @@
class Base:
"""Provides OAuth2 access tokens for accessing Firebase services."""

def get_access_token(self):
def get_access_token(self) -> AccessTokenInfo:
"""Fetches a Google OAuth2 access token using this credential instance.

Returns:
@@ -54,7 +56,7 @@ def get_access_token(self):
google_cred.refresh(_request)
return AccessTokenInfo(google_cred.token, google_cred.expiry)

def get_credential(self):
def get_credential(self) -> google.auth.credentials.Credentials:
"""Returns the Google credential instance used for authentication."""
raise NotImplementedError

@@ -64,7 +66,7 @@ class Certificate(Base):

_CREDENTIAL_TYPE = 'service_account'

def __init__(self, cert):
def __init__(self, cert: Union[str, Dict[str, Any]]):
"""Initializes a credential from a Google service account certificate.

Service account certificates can be downloaded as JSON files from the Firebase console.
@@ -158,6 +160,7 @@ def _load_credential(self):
if not self._g_credential:
self._g_credential, self._project_id = google.auth.default(scopes=_scopes)


class RefreshToken(Base):
"""A credential initialized from an existing refresh token."""

14 changes: 8 additions & 6 deletions firebase_admin/firestore.py
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@
"""

try:
from google.cloud import firestore # pylint: disable=import-error,no-name-in-module
from google.cloud import firestore # pylint: disable=import-error,no-name-in-module
existing = globals().keys()
for key, value in firestore.__dict__.items():
if not key.startswith('_') and key not in existing:
@@ -28,13 +28,15 @@
raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure '
'to install the "google-cloud-firestore" module.')

from firebase_admin import _utils
from firebase_admin import _utils, App
import google.auth.credentials
from typing import Optional


_FIRESTORE_ATTRIBUTE = '_firestore'


def client(app=None):
def client(app: Optional[App] = None) -> firestore.Client:
"""Returns a client that can be used to interact with Google Cloud Firestore.

Args:
@@ -57,14 +59,14 @@ def client(app=None):
class _FirestoreClient:
"""Holds a Google Cloud Firestore client instance."""

def __init__(self, credentials, project):
def __init__(self, credentials: google.auth.credentials.Credentials, project: str):
self._client = firestore.Client(credentials=credentials, project=project)

def get(self):
def get(self) -> firestore.Client:
return self._client

@classmethod
def from_app(cls, app):
def from_app(cls, app: App):
"""Creates a new _FirestoreClient for the specified app."""
credentials = app.credential.get_credential()
project = app.project_id
5 changes: 5 additions & 0 deletions firebase_admin/messaging.py
Original file line number Diff line number Diff line change
@@ -95,6 +95,7 @@
def _get_messaging_service(app):
return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService)


def send(message, dry_run=False, app=None):
"""Sends the given message via Firebase Cloud Messaging (FCM).

@@ -115,6 +116,7 @@ def send(message, dry_run=False, app=None):
"""
return _get_messaging_service(app).send(message, dry_run)


def send_all(messages, dry_run=False, app=None):
"""Sends the given list of messages via Firebase Cloud Messaging as a single batch.

@@ -135,6 +137,7 @@ def send_all(messages, dry_run=False, app=None):
"""
return _get_messaging_service(app).send_all(messages, dry_run)


def send_multicast(multicast_message, dry_run=False, app=None):
"""Sends the given mutlicast message to all tokens via Firebase Cloud Messaging (FCM).

@@ -166,6 +169,7 @@ def send_multicast(multicast_message, dry_run=False, app=None):
) for token in multicast_message.tokens]
return _get_messaging_service(app).send_all(messages, dry_run)


def subscribe_to_topic(tokens, topic, app=None):
"""Subscribes a list of registration tokens to an FCM topic.

@@ -185,6 +189,7 @@ def subscribe_to_topic(tokens, topic, app=None):
return _get_messaging_service(app).make_topic_management_request(
tokens, topic, 'iid/v1:batchAdd')


def unsubscribe_from_topic(tokens, topic, app=None):
"""Unsubscribes a list of registration tokens from an FCM topic.

20 changes: 10 additions & 10 deletions firebase_admin/ml.py
Original file line number Diff line number Diff line change
@@ -211,13 +211,13 @@ def from_dict(cls, data, app=None):
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
model = Model(model_format=tflite_format)
model._data = data_copy # pylint: disable=protected-access
model._app = app # pylint: disable=protected-access
model._app = app # pylint: disable=protected-access
return model

def _update_from_dict(self, data):
copy = Model.from_dict(data)
self.model_format = copy.model_format
self._data = copy._data # pylint: disable=protected-access
self._data = copy._data # pylint: disable=protected-access

def __eq__(self, other):
if isinstance(other, self.__class__):
@@ -334,7 +334,7 @@ def model_format(self):
def model_format(self, model_format):
if model_format is not None:
_validate_model_format(model_format)
self._model_format = model_format #Can be None
self._model_format = model_format # Can be None
return self

def as_dict(self, for_upload=False):
@@ -370,7 +370,7 @@ def from_dict(cls, data):
"""Create an instance of the object from a dict."""
data_copy = dict(data)
tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy))
tflite_format._data = data_copy # pylint: disable=protected-access
tflite_format._data = data_copy # pylint: disable=protected-access
return tflite_format

def __eq__(self, other):
@@ -405,7 +405,7 @@ def model_source(self, model_source):
if model_source is not None:
if not isinstance(model_source, TFLiteModelSource):
raise TypeError('Model source must be a TFLiteModelSource object.')
self._model_source = model_source # Can be None
self._model_source = model_source # Can be None

@property
def size_bytes(self):
@@ -485,7 +485,7 @@ def __init__(self, gcs_tflite_uri, app=None):

def __eq__(self, other):
if isinstance(other, self.__class__):
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
return False

def __ne__(self, other):
@@ -775,7 +775,7 @@ def _validate_display_name(display_name):

def _validate_tags(tags):
if not isinstance(tags, list) or not \
all(isinstance(tag, str) for tag in tags):
all(isinstance(tag, str) for tag in tags):
raise TypeError('Tags must be a list of strings.')
if not all(_TAG_PATTERN.match(tag) for tag in tags):
raise ValueError('Tag format is invalid.')
@@ -789,6 +789,7 @@ def _validate_gcs_tflite_uri(uri):
raise ValueError('GCS TFLite URI format is invalid.')
return uri


def _validate_auto_ml_model(model):
if not _AUTO_ML_MODEL_PATTERN.match(model):
raise ValueError('Model resource name format is invalid.')
@@ -809,7 +810,7 @@ def _validate_list_filter(list_filter):

def _validate_page_size(page_size):
if page_size is not None:
if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
# Specifically type() to disallow boolean which is a subtype of int
raise TypeError('Page size must be a number or None.')
if page_size < 1 or page_size > _MAX_PAGE_SIZE:
@@ -864,7 +865,7 @@ def _exponential_backoff(self, current_attempt, stop_time):

if stop_time is not None:
max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds()
if max_seconds_left < 1: # allow a bit of time for rpc
if max_seconds_left < 1: # allow a bit of time for rpc
raise exceptions.DeadlineExceededError('Polling max time exceeded.')
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
time.sleep(wait_time_seconds)
@@ -925,7 +926,6 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
# If the operation is not complete or timed out, return a (locked) model instead
return get_model(model_id).as_dict()


def create_model(self, model):
_validate_model(model)
try:
8 changes: 5 additions & 3 deletions firebase_admin/storage.py
Original file line number Diff line number Diff line change
@@ -25,12 +25,14 @@
raise ImportError('Failed to import the Cloud Storage library for Python. Make sure '
'to install the "google-cloud-storage" module.')

from firebase_admin import _utils
from firebase_admin import _utils, App
from typing import Optional


_STORAGE_ATTRIBUTE = '_storage'

def bucket(name=None, app=None) -> storage.Bucket:

def bucket(name: Optional[str] = None, app: Optional[App] = None) -> storage.Bucket:
"""Returns a handle to a Google Cloud Storage bucket.

If the name argument is not provided, uses the 'storageBucket' option specified when
@@ -67,7 +69,7 @@ def from_app(cls, app):
# significantly speeds up the initialization of the storage client.
return _StorageClient(credentials, app.project_id, default_bucket)

def bucket(self, name=None):
def bucket(self, name: Optional[str] = None):
"""Returns a handle to the specified Cloud Storage Bucket."""
bucket_name = name if name is not None else self._default_bucket
if bucket_name is None:
5 changes: 3 additions & 2 deletions firebase_admin/tenant_mgt.py
Original file line number Diff line number Diff line change
@@ -183,6 +183,7 @@ def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=Non
FirebaseError: If an error occurs while retrieving the user accounts.
"""
tenant_mgt_service = _get_tenant_mgt_service(app)

def download(page_token, max_results):
return tenant_mgt_service.list_tenants(page_token, max_results)
return ListTenantsPage(download, page_token, max_results)
@@ -206,7 +207,7 @@ class Tenant:
def __init__(self, data):
if not isinstance(data, dict):
raise ValueError('Invalid data argument in Tenant constructor: {0}'.format(data))
if not 'name' in data:
if 'name' not in data:
raise ValueError('Tenant response missing required keys.')

self._data = data
@@ -256,7 +257,7 @@ def auth_for_tenant(self, tenant_id):

client = auth.Client(self.app, tenant_id=tenant_id)
self.tenant_clients[tenant_id] = client
return client
return client

def get_tenant(self, tenant_id):
"""Gets the tenant corresponding to the given ``tenant_id``."""