Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vdk-control-cli: pass the printer class in JobDeploy creation and add additional memory printer #2477

Merged
merged 7 commits into from
Jul 26, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vdk.internal.control.command_groups.job.deploy_cli_impl import JobDeploy
from vdk.internal.control.configuration.defaults_config import load_default_team_name
from vdk.internal.control.utils import cli_utils
from vdk.internal.control.utils import output_printer
from vdk.internal.control.utils.cli_utils import get_or_prompt


Expand Down Expand Up @@ -169,7 +170,7 @@ def deploy(
rest_api_url: str,
output: str,
):
cmd = JobDeploy(rest_api_url, output)
cmd = JobDeploy(rest_api_url, output_printer.create_printer(output))
if operation == DeployOperation.UPDATE.value or enabled is not None:
name = get_or_prompt("Job Name", name)
team = get_or_prompt("Job Team", team)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
from vdk.internal.control.job.job_config import JobConfig
from vdk.internal.control.rest_lib.factory import ApiClientFactory
from vdk.internal.control.rest_lib.rest_client_errors import ApiClientErrorDecorator
from vdk.internal.control.utils import output_printer
from vdk.internal.control.utils.cli_utils import get_or_prompt
from vdk.internal.control.utils.output_printer import OutputFormat
from vdk.internal.control.utils.output_printer import Printer
from vdk.internal.control.utils.output_printer import PrinterJson
from vdk.internal.control.utils.output_printer import PrinterText

log = logging.getLogger(__name__)

Expand All @@ -30,16 +31,15 @@ class JobDeploy:
ZIP_ARCHIVE_TYPE = "zip"
ARCHIVE_SUFFIX = "-archive"

def __init__(self, rest_api_url: str, output_format: str):
def __init__(self, rest_api_url: str, printer: Printer):
self.deploy_api = ApiClientFactory(rest_api_url).get_deploy_api()
self.jobs_api = ApiClientFactory(rest_api_url).get_jobs_api()
self.job_sources_api = ApiClientFactory(rest_api_url).get_jobs_sources_api()
# support for multiple deployments is not implemented yet so we can put anything here.
# Ultimately this will be user facing parameter (possibly fetched from config.ini)
self.__deployment_id = "production"
self.__job_archive = JobArchive()
self.__output_format = output_format
self.__printer = output_printer.create_printer(self.__output_format)
self.__printer = printer

@staticmethod
def __detect_keytab_files_in_job_directory(job_path: str) -> None:
Expand Down Expand Up @@ -193,7 +193,7 @@ def __update_deployment(self, name: str, team: str, deployment: DataJobDeploymen
self.deploy_api.deployment_update(
team_name=team, job_name=name, data_job_deployment=deployment
)
if self.__output_format == OutputFormat.TEXT.value:
if isinstance(self.__printer, PrinterText):
log.info(
f"Request to deploy Data Job {name} using version {deployment.job_version} finished successfully.\n"
f"It would take a few minutes for the Data Job to be deployed in the server.\n"
Expand Down Expand Up @@ -239,7 +239,7 @@ def show(self, name: str, team: str) -> None:
),
deployments,
)
if self.__output_format == OutputFormat.TEXT.value:
if isinstance(self.__printer, PrinterText):
click.echo(
"You can compare the version seen here to the one seen when "
"deploying to verify your deployment was successful."
Expand Down Expand Up @@ -283,7 +283,7 @@ def create(
"Team Name", team or job_config.get_team() or load_default_team_name()
)

if self.__output_format == OutputFormat.TEXT.value:
if isinstance(self.__printer, PrinterText):
log.info(
f"Deploy Data Job with name {name} from directory {job_path} ... \n"
)
Expand All @@ -294,10 +294,10 @@ def create(
try:
job_archive_binary = self.__archive_binary(archive_path)

if self.__output_format == OutputFormat.TEXT.value:
if isinstance(self.__printer, PrinterText):
log.info("Uploading the data job might take some time ...")
with click_spinner.spinner(
disable=(self.__output_format == OutputFormat.JSON.value)
disable=(isinstance(self.__printer, PrinterJson))
):
data_job_version = self.job_sources_api.sources_upload(
team_name=team,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2023-2023 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import abc
import io
import json
from enum import Enum
from enum import unique
Expand Down Expand Up @@ -60,7 +61,7 @@ def decorator(cls):


@printer("text")
class _PrinterText(Printer):
class PrinterText(Printer):
def print_table(self, table: Optional[List[Dict[str, Any]]]) -> None:
if table and len(table) > 0:
click.echo(tabulate(table, headers="keys", tablefmt="fancy_grid"))
Expand Down Expand Up @@ -93,7 +94,7 @@ def json_serial(obj):


@printer("json")
class _PrinterJson(Printer):
class PrinterJson(Printer):
def print_table(self, data: List[Dict[str, Any]]) -> None:
if data:
click.echo(json_format(data))
Expand All @@ -107,6 +108,35 @@ def print_dict(self, data: Dict[str, Any]) -> None:
click.echo("{}")


class InMemoryTextPrinter(Printer):
def __init__(self):
self.__output_buffer = io.StringIO()

def print_table(self, table: Optional[List[Dict[str, Any]]]) -> None:
if table and len(table) > 0:
print(
tabulate(table, headers="keys", tablefmt="fancy_grid"),
file=self.__output_buffer,
)
else:
print("No Data.", file=self.__output_buffer)

def print_dict(self, data: Optional[Dict[str, Any]]) -> None:
if data:
print(
tabulate(
[[k, v] for k, v in data.items()],
headers=("key", "value"),
),
file=self.__output_buffer,
)
else:
print("No Data.", file=self.__output_buffer)

def get_memory(self):
return self.__output_buffer.getvalue()


def create_printer(output_format: str) -> Printer:
"""
Creates a printer instance for the given output format.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
# Copyright 2023-2023 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import unittest
from typing import Any
from typing import Dict
from typing import List
from unittest.mock import patch

import pytest
from vdk.internal.control.utils import output_printer
from vdk.internal.control.utils.output_printer import _PrinterJson
from vdk.internal.control.utils.output_printer import _PrinterText
from vdk.internal.control.utils.output_printer import create_printer
from vdk.internal.control.utils.output_printer import InMemoryTextPrinter
from vdk.internal.control.utils.output_printer import Printer
from vdk.internal.control.utils.output_printer import PrinterJson
from vdk.internal.control.utils.output_printer import PrinterText


class TestPrinterText:
def test_print_dict(self):
with patch("click.echo") as mock_echo:
printer = _PrinterText()
printer = PrinterText()
data = {"key": "value"}

printer.print_dict(data)
Expand All @@ -26,7 +28,7 @@ def test_print_dict(self):

def test_print_table_with_data(self):
with patch("click.echo") as mock_echo:
printer = _PrinterText()
printer = PrinterText()

data = [{"key1": "value1", "key2": 2}, {"key1": "value3", "key2": 4}]

Expand All @@ -45,7 +47,7 @@ def test_print_table_with_data(self):

def test_print_table_with_no_data(self):
with patch("click.echo") as mock_echo:
printer = _PrinterText()
printer = PrinterText()
data = []

printer.print_table(data)
Expand All @@ -57,7 +59,7 @@ def test_print_table_with_no_data(self):
class TestPrinterJson:
def test_print_dict(self):
with patch("click.echo") as mock_echo:
printer = _PrinterJson()
printer = PrinterJson()

data = {"key": "value"}

Expand All @@ -68,7 +70,7 @@ def test_print_dict(self):

def test_print_table(self):
with patch("click.echo") as mock_echo:
printer = _PrinterJson()
printer = PrinterJson()
data = [
{"key1": "value1", "key2": "value2"},
{"key1": "value3", "key2": "value4"},
Expand All @@ -79,6 +81,53 @@ def test_print_table(self):
mock_echo.assert_called_once_with(expected_output)


class TestMemoryPrinter(unittest.TestCase):
def setUp(self):
self.printer = InMemoryTextPrinter()

def test_print_dict(self):
data = {"key": "value"}

self.printer.print_dict(data)

output = self.printer.get_memory().strip()

self.assertIn("key", output)
self.assertIn("value", output)

def test_print_table(self):
data = [
{"key1": "value1", "key2": "value2"},
{"key1": "value3", "key2": "value4"},
]
self.printer.print_table(data)

output = self.printer.get_memory().strip()

self.assertIn("key1", output)
self.assertIn("key2", output)
self.assertIn("value1", output)
self.assertIn("value2", output)
self.assertIn("value3", output)
self.assertIn("value4", output)

def test_print_dict_no_data(self):
self.printer.print_dict(None)

expected_output = "No Data."
actual_output = self.printer.get_memory().strip()

self.assertEqual(actual_output, expected_output)

def test_print_table_no_data(self):
self.printer.print_table(None)

expected_output = "No Data."
actual_output = self.printer.get_memory().strip()

self.assertEqual(actual_output, expected_output)


class TestCreatePrinter:
def test_create_printer_with_registered_format(self):
class MockPrinter(Printer):
Expand Down
Loading