2 changes: 2 additions & 0 deletions appengine/standard_python3/pubsub/main.py
@@ -108,6 +108,7 @@ def receive_messages_handler():

# [END gae_standard_pubsub_auth_push]


# [START gae_standard_pubsub_push]
@app.route("/pubsub/push", methods=["POST"])
def receive_pubsub_messages_handler():
@@ -121,6 +122,7 @@ def receive_pubsub_messages_handler():
# Returning any 2xx status indicates successful receipt of the message.
return "OK", 200


# [END gae_standard_pubsub_push]


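For context, the handler these hunks touch implements Pub/Sub push delivery; the only change here is blank-line spacing around the top-level definitions. A minimal sketch of such a push endpoint, assuming the standard push envelope (the body parsing is not part of this diff):

import base64

from flask import Flask, request

app = Flask(__name__)


@app.route("/pubsub/push", methods=["POST"])
def receive_pubsub_messages_handler():
    # Pub/Sub push delivers a JSON envelope; the payload is base64-encoded
    # under message.data.
    envelope = request.get_json()
    payload = base64.b64decode(envelope["message"]["data"]).decode("utf-8")
    print(f"Received: {payload}")
    # Returning any 2xx status indicates successful receipt of the message.
    return "OK", 200
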
5 changes: 1 addition & 4 deletions appengine/standard_python3/pubsub/main_test.py
@@ -101,10 +101,7 @@ def test_push_endpoint(monkeypatch, client, fake_token):
assert r.status_code == 200

# Push request without JWT token validation
url = (
"/pubsub/push?token="
+ os.environ["PUBSUB_VERIFICATION_TOKEN"]
)
url = "/pubsub/push?token=" + os.environ["PUBSUB_VERIFICATION_TOKEN"]

r = client.post(
url,
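The POST body is elided above; a hedged sketch of the kind of push envelope such a test typically sends (field values are illustrative, not taken from this diff):

import base64
import json

envelope = {
    "message": {
        "data": base64.b64encode(b"Test message").decode("utf-8"),
        "messageId": "1",
    },
    "subscription": "projects/my-project/subscriptions/my-subscription",
}
body = json.dumps(envelope)
# client.post(url, data=body, content_type="application/json")
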
22 changes: 13 additions & 9 deletions composer/workflows/airflow_db_cleanup.py
@@ -111,9 +111,11 @@
},
{
"airflow_db_model": TaskInstance,
"age_check_column": TaskInstance.start_date
if AIRFLOW_VERSION < ["2", "2", "0"]
else TaskInstance.start_date,
"age_check_column": (
TaskInstance.start_date
if AIRFLOW_VERSION < ["2", "2", "0"]
else TaskInstance.start_date
),
"keep_last": False,
"keep_last_filters": None,
"keep_last_group_by": None,
@@ -127,9 +129,9 @@
},
{
"airflow_db_model": XCom,
"age_check_column": XCom.execution_date
if AIRFLOW_VERSION < ["2", "2", "5"]
else XCom.timestamp,
"age_check_column": (
XCom.execution_date if AIRFLOW_VERSION < ["2", "2", "5"] else XCom.timestamp
),
"keep_last": False,
"keep_last_filters": None,
"keep_last_group_by": None,
@@ -157,9 +159,11 @@
DATABASE_OBJECTS.append(
{
"airflow_db_model": TaskReschedule,
"age_check_column": TaskReschedule.execution_date
if AIRFLOW_VERSION < ["2", "2", "0"]
else TaskReschedule.start_date,
"age_check_column": (
TaskReschedule.execution_date
if AIRFLOW_VERSION < ["2", "2", "0"]
else TaskReschedule.start_date
),
"keep_last": False,
"keep_last_filters": None,
"keep_last_group_by": None,
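Each entry above gates age_check_column on AIRFLOW_VERSION, since columns such as XCom.execution_date were replaced in later releases; the reformat only wraps the conditional expression in parentheses. A minimal sketch of the gate, assuming AIRFLOW_VERSION is built by splitting airflow.__version__ on "." (note that this compares components as strings, exactly as the DAG does):

import airflow

AIRFLOW_VERSION = airflow.__version__.split(".")

if AIRFLOW_VERSION < ["2", "2", "5"]:
    print("older release: age out XCom rows by execution_date")
else:
    print("newer release: age out XCom rows by timestamp")
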
18 changes: 9 additions & 9 deletions datacatalog/quickstart/quickstart.py
@@ -56,21 +56,21 @@ def quickstart(override_values):

tag_template.fields["source"] = datacatalog_v1.types.TagTemplateField()
tag_template.fields["source"].display_name = "Source of data asset"
tag_template.fields[
"source"
].type_.primitive_type = datacatalog_v1.types.FieldType.PrimitiveType.STRING
tag_template.fields["source"].type_.primitive_type = (
datacatalog_v1.types.FieldType.PrimitiveType.STRING
)

tag_template.fields["num_rows"] = datacatalog_v1.types.TagTemplateField()
tag_template.fields["num_rows"].display_name = "Number of rows in data asset"
tag_template.fields[
"num_rows"
].type_.primitive_type = datacatalog_v1.types.FieldType.PrimitiveType.DOUBLE
tag_template.fields["num_rows"].type_.primitive_type = (
datacatalog_v1.types.FieldType.PrimitiveType.DOUBLE
)

tag_template.fields["has_pii"] = datacatalog_v1.types.TagTemplateField()
tag_template.fields["has_pii"].display_name = "Has PII"
tag_template.fields[
"has_pii"
].type_.primitive_type = datacatalog_v1.types.FieldType.PrimitiveType.BOOL
tag_template.fields["has_pii"].type_.primitive_type = (
datacatalog_v1.types.FieldType.PrimitiveType.BOOL
)

tag_template.fields["pii_type"] = datacatalog_v1.types.TagTemplateField()
tag_template.fields["pii_type"].display_name = "PII type"
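Each rewritten assignment sets a field's primitive type on a tag template; the new form parenthesizes the right-hand side instead of splitting the subscripted target, which is what newer Black releases produce. A minimal sketch of the pattern (the template setup is illustrative):

from google.cloud import datacatalog_v1

tag_template = datacatalog_v1.types.TagTemplate()
tag_template.display_name = "Demo Tag Template"

# One field per line of metadata; the type assignment is the reformatted pattern.
tag_template.fields["source"] = datacatalog_v1.types.TagTemplateField()
tag_template.fields["source"].display_name = "Source of data asset"
tag_template.fields["source"].type_.primitive_type = (
    datacatalog_v1.types.FieldType.PrimitiveType.STRING
)
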
6 changes: 3 additions & 3 deletions datacatalog/snippets/conftest.py
@@ -110,9 +110,9 @@ def random_existing_tag_template_id(client, project_id, resources_to_delete):
random_tag_template_id = f"python_sample_{temp_suffix()}"
random_tag_template = datacatalog_v1.types.TagTemplate()
random_tag_template.fields["source"] = datacatalog_v1.types.TagTemplateField()
random_tag_template.fields[
"source"
].type_.primitive_type = datacatalog_v1.FieldType.PrimitiveType.STRING.value
random_tag_template.fields["source"].type_.primitive_type = (
datacatalog_v1.FieldType.PrimitiveType.STRING.value
)
random_tag_template = client.create_tag_template(
parent=datacatalog_v1.DataCatalogClient.common_location_path(
project_id, LOCATION
60 changes: 31 additions & 29 deletions dataflow/conftest.py
@@ -527,7 +527,7 @@ def cloud_build_submit(
cmd = ["gcloud", "auth", "configure-docker"]
logging.info(f"{cmd}")
subprocess.check_call(cmd)
gcr_project = project.replace(':', '/')
gcr_project = project.replace(":", "/")

if substitutions:
cmd_substitutions = [
@@ -568,8 +568,7 @@ def cloud_build_submit(
]
logging.info(f"{cmd}")
subprocess.check_call(cmd)
logging.info(
f"Created image: gcr.io/{gcr_project}/{image_name}:{UUID}")
logging.info(f"Created image: gcr.io/{gcr_project}/{image_name}:{UUID}")
yield f"{image_name}:{UUID}"
else:
raise ValueError("must specify either `config` or `image_name`")
@@ -587,8 +586,7 @@ def cloud_build_submit(
]
logging.info(f"{cmd}")
subprocess.check_call(cmd)
logging.info(
f"Deleted image: gcr.io/{gcr_project}/{image_name}:{UUID}")
logging.info(f"Deleted image: gcr.io/{gcr_project}/{image_name}:{UUID}")

@staticmethod
def dataflow_job_url(
@@ -765,7 +763,7 @@ def dataflow_flex_template_build(
) -> str:
# https://cloud.google.com/sdk/gcloud/reference/dataflow/flex-template/build
template_gcs_path = f"gs://{bucket_name}/{template_file}"
gcr_project = project.replace(':', '/')
gcr_project = project.replace(":", "/")
cmd = [
"gcloud",
"dataflow",
@@ -774,7 +772,7 @@
template_gcs_path,
f"--project={project}",
f"--image=gcr.io/{gcr_project}/{image_name}",
"--sdk-language=PYTHON"
"--sdk-language=PYTHON",
]
if metadata_file:
cmd.append(f"--metadata-file={metadata_file}")
@@ -794,34 +792,38 @@ def dataflow_flex_template_run(
parameters: dict[str, str] = {},
project: str = PROJECT,
region: str = REGION,
additional_experiments: dict[str,str] = {},
additional_experiments: dict[str, str] = {},
) -> str:
import yaml

# https://cloud.google.com/sdk/gcloud/reference/dataflow/flex-template/run
unique_job_name = Utils.hyphen_name(job_name)
logging.info(f"dataflow_job_name: {unique_job_name}")
cmd = [
"gcloud",
"dataflow",
"flex-template",
"run",
unique_job_name,
f"--template-file-gcs-location={template_path}",
f"--project={project}",
f"--region={region}",
f"--staging-location=gs://{bucket_name}/staging",
] + [
f"--parameters={name}={value}"
for name, value in {
**parameters,
}.items()
] + [
f"--additional-experiments={name}={value}"
for name, value in {
**additional_experiments,
}.items()
]
cmd = (
[
"gcloud",
"dataflow",
"flex-template",
"run",
unique_job_name,
f"--template-file-gcs-location={template_path}",
f"--project={project}",
f"--region={region}",
f"--staging-location=gs://{bucket_name}/staging",
]
+ [
f"--parameters={name}={value}"
for name, value in {
**parameters,
}.items()
]
+ [
f"--additional-experiments={name}={value}"
for name, value in {
**additional_experiments,
}.items()
]
)
logging.info(f"{cmd}")

stdout = subprocess.check_output(cmd).decode("utf-8")
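The rewritten dataflow_flex_template_run builds its gcloud invocation by concatenating a base command with per-flag list comprehensions. A minimal, runnable illustration of that composition (all values are made up):

base = ["gcloud", "dataflow", "flex-template", "run", "my-job"]
parameters = {"input": "gs://my-bucket/in", "output": "gs://my-bucket/out"}
experiments = {"enable_prime": "true"}

cmd = (
    base
    + [f"--parameters={name}={value}" for name, value in parameters.items()]
    + [f"--additional-experiments={name}={value}" for name, value in experiments.items()]
)
print(cmd)  # one flag per dict entry, appended after the base command
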
@@ -20,4 +20,5 @@
"""

import setuptools

setuptools.setup()
9 changes: 6 additions & 3 deletions dataflow/gemma-flex-template/custom_model_gemma.py
@@ -163,14 +163,17 @@ def run_inference(
with beam.Pipeline(options=beam_options) as pipeline:
_ = (
pipeline
| "Subscribe to Pub/Sub" >> beam.io.ReadFromPubSub(subscription=args.messages_subscription)
| "Subscribe to Pub/Sub"
>> beam.io.ReadFromPubSub(subscription=args.messages_subscription)
| "Decode" >> beam.Map(lambda msg: msg.decode("utf-8"))
| "RunInference Gemma" >> RunInference(handler)
| "Format output" >> beam.Map(
| "Format output"
>> beam.Map(
lambda response: json.dumps(
{"input": response.example, "outputs": response.inference}
)
)
| "Encode" >> beam.Map(lambda msg: msg.encode("utf-8"))
| "Publish to Pub/Sub" >> beam.io.gcp.pubsub.WriteToPubSub(topic=args.responses_topic)
| "Publish to Pub/Sub"
>> beam.io.gcp.pubsub.WriteToPubSub(topic=args.responses_topic)
)
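The reflowed pipeline reads prompts from Pub/Sub, runs Gemma inference, and publishes JSON responses. A hedged sketch of the flag parsing this snippet assumes (only the attribute names are taken from the diff; the rest is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--messages_subscription", required=True)
parser.add_argument("--responses_topic", required=True)
# beam_args would be forwarded to PipelineOptions for the Beam runner.
args, beam_args = parser.parse_known_args()
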
4 changes: 1 addition & 3 deletions dataflow/gemma-flex-template/noxfile_config.py
@@ -20,7 +20,5 @@
# The Python version used is defined by the Dockerfile, and the job
# submission environment must match.
"ignored_versions": ["2.7", "3.6", "3.7", "3.8", "3.9", "3.11", "3.12"],
"envs": {
"PYTHONPATH": ".."
},
"envs": {"PYTHONPATH": ".."},
}
22 changes: 13 additions & 9 deletions dataflow/gemma/custom_model_gemma.py
@@ -35,7 +35,7 @@ def __init__(
self,
model_name: str = "gemma_2B",
):
""" Implementation of the ModelHandler interface for Gemma using text as input.
"""Implementation of the ModelHandler interface for Gemma using text as input.

Example Usage::

@@ -48,7 +48,7 @@ def __init__(
self._env_vars = {}

def share_model_across_processes(self) -> bool:
""" Indicates if the model should be loaded once-per-VM rather than
"""Indicates if the model should be loaded once-per-VM rather than
once-per-worker-process on a VM. Because Gemma is a large language model,
this will always return True to avoid OOM errors.
"""
@@ -62,7 +62,7 @@ def run_inference(
self,
batch: Sequence[str],
model: GemmaCausalLM,
inference_args: Optional[dict[str, Any]] = None
inference_args: Optional[dict[str, Any]] = None,
) -> Iterable[PredictionResult]:
"""Runs inferences on a batch of text strings.

@@ -85,7 +85,8 @@
class FormatOutput(beam.DoFn):
def process(self, element, *args, **kwargs):
yield "Input: {input}, Output: {output}".format(
input=element.example, output=element.inference)
input=element.example, output=element.inference
)


if __name__ == "__main__":
@@ -119,13 +120,16 @@ def process(self, element, *args, **kwargs):

pipeline = beam.Pipeline(options=beam_options)
_ = (
pipeline | "Read Topic" >>
beam.io.ReadFromPubSub(subscription=args.messages_subscription)
pipeline
| "Read Topic"
>> beam.io.ReadFromPubSub(subscription=args.messages_subscription)
| "Parse" >> beam.Map(lambda x: x.decode("utf-8"))
| "RunInference-Gemma" >> RunInference(
| "RunInference-Gemma"
>> RunInference(
GemmaModelHandler(args.model_path)
) # Send the prompts to the model and get responses.
| "Format Output" >> beam.ParDo(FormatOutput()) # Format the output.
| "Publish Result" >>
beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic))
| "Publish Result"
>> beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic)
)
pipeline.run()
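share_model_across_processes returns True so the large model is loaded once per VM rather than once per worker process. A minimal batch-mode usage sketch for the handler and DoFn defined above (the prompt and model name are illustrative, and Gemma weights must be available to the workers):

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference

# GemmaModelHandler and FormatOutput are the classes defined in this file.
with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | beam.Create(["What is Dataflow?"])
        | RunInference(GemmaModelHandler("gemma_2B"))
        | beam.ParDo(FormatOutput())
        | beam.Map(print)
    )
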
40 changes: 20 additions & 20 deletions dataflow/gemma/e2e_test.py
@@ -70,8 +70,9 @@ def messages_topic(pubsub_topic: Callable[[str], str]) -> str:


@pytest.fixture(scope="session")
def messages_subscription(pubsub_subscription: Callable[[str, str], str],
messages_topic: str) -> str:
def messages_subscription(
pubsub_subscription: Callable[[str, str], str], messages_topic: str
) -> str:
return pubsub_subscription("messages", messages_topic)


@@ -81,20 +82,21 @@


@pytest.fixture(scope="session")
def responses_subscription(pubsub_subscription: Callable[[str, str], str],
responses_topic: str) -> str:
def responses_subscription(
pubsub_subscription: Callable[[str, str], str], responses_topic: str
) -> str:
return pubsub_subscription("responses", responses_topic)


@pytest.fixture(scope="session")
def dataflow_job(
project: str,
bucket_name: str,
location: str,
unique_name: str,
container_image: str,
messages_subscription: str,
responses_topic: str,
project: str,
bucket_name: str,
location: str,
unique_name: str,
container_image: str,
messages_subscription: str,
responses_topic: str,
) -> Iterator[str]:
# Launch the streaming Dataflow pipeline.
conftest.run_cmd(
@@ -127,20 +129,18 @@

@pytest.mark.timeout(3600)
def test_pipeline_dataflow(
project: str,
location: str,
dataflow_job: str,
messages_topic: str,
responses_subscription: str,
project: str,
location: str,
dataflow_job: str,
messages_topic: str,
responses_subscription: str,
) -> None:
print(f"Waiting for the Dataflow workers to start: {dataflow_job}")
conftest.wait_until(
lambda: conftest.dataflow_num_workers(project, location, dataflow_job)
> 0,
lambda: conftest.dataflow_num_workers(project, location, dataflow_job) > 0,
"workers are running",
)
num_workers = conftest.dataflow_num_workers(project, location,
dataflow_job)
num_workers = conftest.dataflow_num_workers(project, location, dataflow_job)
print(f"Dataflow job num_workers: {num_workers}")

messages = ["This is a test for a Python sample."]
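The test polls until Dataflow workers come up via conftest.wait_until. A hedged sketch of what such a polling helper typically looks like (the real conftest may differ in signature and backoff strategy):

import time
from typing import Callable


def wait_until(condition: Callable[[], bool], description: str, timeout: int = 600) -> None:
    # Poll the condition every 10 seconds until it holds or the timeout expires.
    start = time.time()
    while time.time() - start < timeout:
        if condition():
            return
        time.sleep(10)
    raise TimeoutError(f"timed out waiting until {description}")
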
4 changes: 1 addition & 3 deletions dataflow/gemma/noxfile_config.py
@@ -20,7 +20,5 @@
# The Python version used is defined by the Dockerfile, and the job
# submission environment must match.
"ignored_versions": ["2.7", "3.6", "3.7", "3.8", "3.9", "3.10", "3.12"],
"envs": {
"PYTHONPATH": ".."
},
"envs": {"PYTHONPATH": ".."},
}