Skip to content

Commit

Permalink
Typing fixes for new version of mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
avylove committed Jan 30, 2023
1 parent cf6f774 commit 2d11181
Showing 5 changed files with 83 additions and 96 deletions.
6 changes: 2 additions & 4 deletions lisa/environment.py
Original file line number Diff line number Diff line change
@@ -489,15 +489,13 @@ def load_environments(
class EnvironmentHookSpec:
@hookspec
def get_environment_information(self, environment: Environment) -> Dict[str, str]:
...
raise NotImplementedError


class EnvironmentHookImpl:
@hookimpl
def get_environment_information(self, environment: Environment) -> Dict[str, str]:
information: Dict[str, str] = {}
information["name"] = environment.name

information: Dict[str, str] = {"name": environment.name}
if environment.nodes:
node = environment.default_node
try:
13 changes: 6 additions & 7 deletions lisa/sut_orchestrator/aws/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import InitVar, dataclass, field
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from dataclasses_json import dataclass_json
@@ -67,16 +67,15 @@ class AwsNodeSchema:
data_disk_size: int = 32
disk_type: str = ""

# for marketplace image, which need to accept terms
_marketplace: InitVar[Optional[AwsVmMarketplaceSchema]] = None
def __post_init__(self) -> None:

# Caching for marketplace image
self._marketplace: Optional[AwsVmMarketplaceSchema] = None

@property
def marketplace(self) -> AwsVmMarketplaceSchema:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_marketplace"):
self._marketplace: Optional[AwsVmMarketplaceSchema] = None

if not self._marketplace:
if self._marketplace is None:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"
142 changes: 71 additions & 71 deletions lisa/sut_orchestrator/azure/common.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@

import re
import sys
from dataclasses import InitVar, dataclass, field
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from threading import Lock
@@ -175,11 +175,12 @@ class AzureNodeSchema:
# image.
is_linux: Optional[bool] = None

_marketplace: InitVar[Optional[AzureVmMarketplaceSchema]] = None
def __post_init__(self) -> None:

_shared_gallery: InitVar[Optional[SharedImageGallerySchema]] = None
# Caching
self._marketplace: Optional[AzureVmMarketplaceSchema] = None
self._shared_gallery: Optional[SharedImageGallerySchema] = None

def __post_init__(self, *args: Any, **kwargs: Any) -> None:
# trim whitespace of values.
strip_strs(
self,
@@ -201,80 +202,78 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:

@property
def marketplace(self) -> Optional[AzureVmMarketplaceSchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_marketplace"):
self._marketplace: Optional[AzureVmMarketplaceSchema] = None
marketplace: Optional[AzureVmMarketplaceSchema] = self._marketplace
if not marketplace:
if isinstance(self.marketplace_raw, dict):

if self._marketplace is not None:
return self._marketplace

if isinstance(self.marketplace_raw, dict):
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
self.marketplace_raw = {
k: v.lower() for k, v in self.marketplace_raw.items()
}
self._marketplace = schema.load_by_type(
AzureVmMarketplaceSchema, self.marketplace_raw
)
# This step makes sure marketplace_raw is validated, and
# filters out any unwanted content.
self.marketplace_raw = self._marketplace.to_dict() # type: ignore

elif self.marketplace_raw:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"

self.marketplace_raw = self.marketplace_raw.strip()

if self.marketplace_raw:
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
self.marketplace_raw = dict(
(k, v.lower()) for k, v in self.marketplace_raw.items()
)
marketplace = schema.load_by_type(
AzureVmMarketplaceSchema, self.marketplace_raw
)
# this step makes marketplace_raw is validated, and
# filter out any unwanted content.
self.marketplace_raw = marketplace.to_dict() # type: ignore
elif self.marketplace_raw:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"

self.marketplace_raw = self.marketplace_raw.strip()

if self.marketplace_raw:
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
marketplace_strings = re.split(
r"[:\s]+", self.marketplace_raw.lower()
marketplace_strings = re.split(r"[:\s]+", self.marketplace_raw.lower())

if len(marketplace_strings) != 4:
raise LisaException(
"Invalid value for the provided marketplace "
f"parameter: '{self.marketplace_raw}'."
"The marketplace parameter should be in the format: "
"'<Publisher> <Offer> <Sku> <Version>' "
"or '<Publisher>:<Offer>:<Sku>:<Version>'"
)
self._marketplace = AzureVmMarketplaceSchema(*marketplace_strings)
# marketplace_raw is used
self.marketplace_raw = (
self._marketplace.to_dict() # type: ignore [attr-defined]
)

if len(marketplace_strings) == 4:
marketplace = AzureVmMarketplaceSchema(*marketplace_strings)
# marketplace_raw is used
self.marketplace_raw = marketplace.to_dict() # type: ignore
else:
raise LisaException(
f"Invalid value for the provided marketplace "
f"parameter: '{self.marketplace_raw}'."
f"The marketplace parameter should be in the format: "
f"'<Publisher> <Offer> <Sku> <Version>' "
f"or '<Publisher>:<Offer>:<Sku>:<Version>'"
)
self._marketplace = marketplace
return marketplace
return self._marketplace

@marketplace.setter
def marketplace(self, value: Optional[AzureVmMarketplaceSchema]) -> None:
self._marketplace = value
if value is None:
self.marketplace_raw = None
else:
self.marketplace_raw = value.to_dict() # type: ignore
# dataclass_json doesn't use a protocol return type, so to_dict() is unknown
self.marketplace_raw = (
None if value is None else value.to_dict() # type: ignore [attr-defined]
)

@property
def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_shared_gallery"):
self._shared_gallery: Optional[SharedImageGallerySchema] = None
shared_gallery: Optional[SharedImageGallerySchema] = self._shared_gallery
if shared_gallery:
return shared_gallery

if self._shared_gallery is not None:
return self._shared_gallery

if isinstance(self.shared_gallery_raw, dict):
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
self.shared_gallery_raw = dict(
(k, v.lower()) for k, v in self.shared_gallery_raw.items()
)
self.shared_gallery_raw = {
k: v.lower() for k, v in self.shared_gallery_raw.items()
}

shared_gallery = schema.load_by_type(
SharedImageGallerySchema, self.shared_gallery_raw
)
@@ -283,6 +282,8 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
# this step makes shared_gallery_raw is validated, and
# filter out any unwanted content.
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
self._shared_gallery = shared_gallery

elif self.shared_gallery_raw:
assert isinstance(
self.shared_gallery_raw, str
@@ -299,11 +300,12 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
# shared_gallery_raw is used
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
elif len(shared_gallery_strings) == 3:
shared_gallery = SharedImageGallerySchema(
self._shared_gallery = SharedImageGallerySchema(
self.subscription_id, None, *shared_gallery_strings
)
# shared_gallery_raw is used
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore

else:
raise LisaException(
f"Invalid value for the provided shared gallery "
@@ -313,16 +315,16 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
f"<image_definition>/<image_version>' or '<image_gallery>/"
f"<image_definition>/<image_version>'"
)
self._shared_gallery = shared_gallery
return shared_gallery

return self._shared_gallery

@shared_gallery.setter
def shared_gallery(self, value: Optional[SharedImageGallerySchema]) -> None:
self._shared_gallery = value
if value is None:
self.shared_gallery_raw = None
else:
self.shared_gallery_raw = value.to_dict() # type: ignore
# dataclass_json doesn't use a protocol return type, so to_dict() is unknown
self.shared_gallery_raw = (
None if value is None else value.to_dict() # type: ignore [attr-defined]
)

def get_image_name(self) -> str:
result = ""
@@ -365,9 +367,7 @@ def from_node_runbook(cls, runbook: AzureNodeSchema) -> "AzureNodeArmParameter":
parameters["shared_gallery_raw"] = parameters["shared_gallery"]
del parameters["shared_gallery"]

arm_parameters = AzureNodeArmParameter(**parameters)

return arm_parameters
return AzureNodeArmParameter(**parameters)


class DataDiskCreateOption:
1 change: 1 addition & 0 deletions lisa/tools/bzip2.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ class Bzip2(Tool):
def command(self) -> str:
return "bzip2"

@property
def can_install(self) -> bool:
return True

17 changes: 3 additions & 14 deletions lisa/util/__init__.py
Original file line number Diff line number Diff line change
@@ -159,9 +159,7 @@ def __init__(self, os: "OperatingSystem", message: str = "") -> None:
self.version = os.information.full_version
self.kernel_version = ""
if hasattr(os, "get_kernel_information"):
self.kernel_version = (
os.get_kernel_information().raw_version # type: ignore
)
self.kernel_version = os.get_kernel_information().raw_version
self._extended_message = message

def __str__(self) -> str:
@@ -505,18 +503,9 @@ def find_group_in_lines(


def deep_update_dict(src: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]:
if (
dest is None
or isinstance(dest, int)
or isinstance(dest, bool)
or isinstance(dest, float)
or isinstance(dest, str)
):
result = dest
else:
result = dest.copy()

if isinstance(result, dict):
if isinstance(dest, dict):
result = dest.copy()
for key, value in src.items():
if isinstance(value, dict) and key in dest:
value = deep_update_dict(value, dest[key])

0 comments on commit 2d11181

Please sign in to comment.