diff --git a/Makefile b/Makefile index e7a06904..83268315 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ # -SRC = *.py standard_importer worldbank_wdi who_gho un_wpp +SRC = *.py standard_importer worldbank_wdi who_gho un_wpp owid steps tests default: @echo 'Available commands:' diff --git a/owid/__init__.py b/owid/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/owid/dtypes.py b/owid/dtypes.py new file mode 100644 index 00000000..165d7882 --- /dev/null +++ b/owid/dtypes.py @@ -0,0 +1,254 @@ +# -*- coding: utf-8 -*- +# +# __init__.py +# importers +# + +""" +A first cut at making a Pythonic API that can handle data and metadata together. + +Philosophy: + +- Incremental: you should still be able to use with some or all metadata missing +- General: where possible, aim for things many people would want, not just OWID +""" + +from collections import defaultdict +from os import path +from typing import Protocol, Iterator, List, Dict, Any, NoReturn, Optional +from dataclasses import dataclass, field +import datetime as dt +import uuid +import re + +import pandas as pd +from dataclasses_json import dataclass_json + + +@dataclass_json +@dataclass +class Source: + name: Optional[str] = None + description: Optional[str] = None + url: Optional[str] = None + source_data_url: Optional[str] = None + owid_data_url: Optional[str] = None + date_accessed: Optional[str] = None + publication_date: Optional[str] = None + publication_year: Optional[int] = None + + +@dataclass_json +@dataclass +class AboutThisDataset: + """ + Metadata for an entire dataset, meant to be shared by all tables in this dataset. + Most of this comes directly from Walden. + + Goal: you can build an addressing scheme from this metadata. + """ + + namespace: Optional[str] = None + short_name: Optional[str] = None + title: Optional[str] = None + description: Optional[str] = None + sources: List[Source] = field(default_factory=list) + license_name: Optional[str] = None + license_url: Optional[str] = None + + +@dataclass +class AboutThisTable: + """ + Metadata for a table within a broader dataset. + """ + + # every dataset needs a name, use a UUID if we have nothing better + short_name: str = field(default_factory=lambda: str(uuid.uuid4())) + + primary_key: Optional[List[str]] = None + schema: Optional[Dict[str, "AboutThisSeries"]] = None + + # only used if all fields in the table share the same dataset + dataset: Optional[AboutThisDataset] = None + + def is_empty(self) -> bool: + return _is_uuid4(self.short_name) and all( + getattr(self, f) is None + for f in AboutThisTable.__dataclass_fields__ + if f != "short_name" + ) + + +@dataclass +class AboutThisSeries: + """ + Metadata for an individual field in a table. + """ + + name: Optional[str] = None + title: Optional[str] = None + description: Optional[str] = None + source_name: Optional[str] = None + dataset: Optional[AboutThisDataset] = None + + # XXX add units, type, etc + + def is_empty(self) -> bool: + return all( + getattr(self, f) is None for f in AboutThisSeries.__dataclass_fields__ + ) + + +class RichDataFrame(pd.DataFrame): + """ + A data frame that contains metadata about where it came from and about + the columns it has. Use it like a normal data frame, except that you can also + add metadata to it, field by field, or using the `metadata` attribute. 
+ """ + + def __init__(self, *args, metadata: Optional[AboutThisTable] = None, **kwargs): + super().__init__(*args, **kwargs) + self.metadata = metadata or AboutThisTable() + if not self.primary_key: + self.primary_key = _detect_primary_key(self) + + @property + def _constructor(self): + return RichDataFrame + + @property + def _constructor_sliced(self): + return RichSeries + + _metadata = list(AboutThisTable.__dataclass_fields__) + + def set_metadata(self, metadata: Optional[AboutThisTable]) -> None: + metadata = metadata or AboutThisTable() + for field in AboutThisTable.__dataclass_fields__: + value = getattr(metadata, field) + setattr(self, field, value) + + def get_metadata(self) -> Optional[AboutThisTable]: + return AboutThisTable( + **{f: getattr(self, f, None) for f in AboutThisTable.__dataclass_fields__} + ) + + def ensure_named(self) -> None: + """ + Make sure this table has a name we can refer to it with. + """ + self.metadata.short_name = str(uuid.uuid4()) + + def to_feather(self, *args, **kwargs) -> None: + if self.has_index(): + pd.DataFrame.to_feather(self.reset_index(), *args, **kwargs) + + def has_index(self) -> bool: + names = self.index.names + return len(names) != 1 or names[0] is not None + + def all_columns(self): + "Return all column names, including those in the index." + cols = [col for col in self.index.names if col] + cols += list(self.columns) + return cols + + metadata = property(get_metadata, set_metadata) + + +class RichSeries(pd.Series): + """ + A pandas Series with optional metadata about this column and where it came from. + Use it like a normal series, or enrich it with fields from AboutThisSeries. + """ + + def __init__(self, *args, metadata: Optional[AboutThisSeries] = None, **kwargs): + super().__init__(*args, **kwargs) + self.metadata = metadata or AboutThisSeries(name=kwargs.get("name")) + + @property + def _constructor(self): + return RichSeries + + @property + def _constructor_expanddim(self): + return RichDataFrame + + _metadata = list(AboutThisSeries.__dataclass_fields__) + + def set_metadata(self, metadata: Optional[AboutThisSeries]) -> None: + for field in AboutThisSeries.__dataclass_fields__: + if metadata is not None: + value = getattr(metadata, field) + else: + value = None + setattr(self, field, value) + + def get_metadata(self) -> Optional[AboutThisSeries]: + return AboutThisSeries( + **{f: getattr(self, f, None) for f in AboutThisSeries.__dataclass_fields__} + ) + + metadata = property(get_metadata, set_metadata) + + +class Dataset(Protocol): + metadata: AboutThisDataset + + def __iter__(self) -> Iterator[RichDataFrame]: + ... + + def __len__(self) -> int: + ... + + def __getitem__(self, table_name: str) -> RichDataFrame: + ... + + def add(self, table: RichDataFrame) -> None: + ... + + +class SerializableDataset(Dataset, Protocol): + def save(self, path: str) -> None: + ... + + @staticmethod + def load(path: str) -> "Dataset": + ... 
+ + +def _detect_primary_key(df: pd.DataFrame) -> Optional[List[str]]: + primary_key: List[str] = list(df.index.names) + if primary_key[0] is not None: + return primary_key + + return None + + +@dataclass +class InMemoryDataset(SerializableDataset): + tables: Dict[str, RichDataFrame] = field(default_factory=dict) + metadata: AboutThisDataset = field(default_factory=AboutThisDataset) + + def __len__(self) -> int: + return len(self.tables) + + def __iter__(self) -> Iterator[RichDataFrame]: + yield from self.tables.values() + + def __getitem__(self, table_name: str) -> RichDataFrame: + return self.tables[table_name] + + def add(self, table: RichDataFrame) -> None: + # link the table's metadata to that of the entire dataset + table.dataset = self.metadata + + # add the table to our collection + self.tables[table.metadata.short_name] = table # type: ignore + + +def _is_uuid4(name: str) -> bool: + return bool( + re.match("^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", name) + ) diff --git a/owid/formats.py b/owid/formats.py new file mode 100644 index 00000000..fc1dee0d --- /dev/null +++ b/owid/formats.py @@ -0,0 +1,264 @@ +# +# formats.py +# +# Load and save datasets and tables in different on-disk formats. +# + + +from dataclasses import dataclass +from typing import Any, Iterator, NoReturn, Dict, List, Optional +import atexit +import tempfile +import json +from os import path +import os +from collections import defaultdict +import shutil + +from . import dtypes + +import frictionless +import pandas as pd + +# ---------------------------------------------------------------------------------------------# +# Frictionless data format +# ---------------------------------------------------------------------------------------------# + + +def error_on_set(self: Any, v: Any) -> NoReturn: + raise Exception("setter not implemented") + + +class Frictionless: + """ + Encoding and decoding according to the frictionless standard. 
+ + See: https://specs.frictionlessdata.io/schemas/data-package.json + """ + + @staticmethod + def save(ds: dtypes.Dataset, dirname: str) -> None: + package_file = path.join(dirname, "datapackage.json") + + if path.exists(dirname): + # save over the top of an existing dataset, but be careful about it + if not path.exists(package_file): + raise Exception( + f"refuse to save over the top of a non-dataset folder: {dirname}" + ) + + shutil.rmtree(dirname) + + os.mkdir(dirname) + + resources = [] + for table in ds: + resource = Frictionless.save_table(table, dirname) + resources.append(resource) + + m = Frictionless.encode_metadata(ds.metadata) + m["resources"] = resources + with open(package_file, "w") as ostream: + json.dump(m, ostream, indent=2) + + @staticmethod + def save_table(table: dtypes.RichDataFrame, dirname: str) -> dict: + metadata: dtypes.AboutThisTable = table.metadata + + if not metadata.short_name: + raise Exception("cannot serialise a table without a short name for it") + + # save the data + dest_file = path.join(dirname, f"{metadata.short_name}.feather") + table.to_feather(dest_file) + + # return the resource metadata + schema: Dict[str, dtypes.AboutThisSeries] = table.metadata.schema # type: ignore + d = { + "name": metadata.short_name, + "path": path.basename(dest_file), + "schema": { + "primaryKey": table.primary_key, + "fields": { + col: Frictionless.encode_series_metadata(schema[col]) + for col in schema + }, + }, + } + return d + + @staticmethod + def load(dirname: str) -> "FrictionlessDataset": + pkg_file = path.join(dirname, "datapackage.json") + return FrictionlessDataset(pkg_file) + + @staticmethod + def encode_metadata(metadata: dtypes.AboutThisDataset) -> dict: + d = { + "_namespace": metadata.namespace, + "name": metadata.short_name, + "title": metadata.title, + "description": metadata.description, + "licenses": [{"name": metadata.license_name, "path": metadata.license_url}], + "sources": [ + { + "title": source.name, + "path": source.url, + "_date_accessed": source.date_accessed, + "_publication_date": source.publication_date, + "_publication_year": source.publication_year, + "_description": source.description, + "_source_data_url": source.source_data_url, + "_owid_data_url": source.owid_data_url, + } + for source in metadata.sources + ], + } + return pruned(d) + + @staticmethod + def decode_metadata(metadata: dict) -> dtypes.AboutThisDataset: + # we only support one license, but frictionless allows many + licenses: List[dict] = metadata.get("licenses", defaultdict(lambda: None)) + if len(licenses) > 1: + raise ValueError("OWID datasets only support one license per dataset") + license = licenses[0] + + sources = [defaultdict(lambda: None, source) for source in metadata["sources"]] + + return dtypes.AboutThisDataset( + namespace=metadata.get("_namespace"), + short_name=metadata.get("name"), + title=metadata.get("title"), + description=metadata.get("description"), + license_name=license["name"], + license_url=license["path"], + sources=[ + dtypes.Source( + name=source["title"], + description=source["_description"], + url=source["path"], + source_data_url=source["_source_data_url"], + owid_data_url=source["_owid_data_url"], + date_accessed=source["_date_accessed"], + publication_date=source["_publication_date"], + publication_year=source["_publication_year"], + ) + for source in sources + ], + ) + + @staticmethod + def table_from_resource( + resource: frictionless.Resource, base_dir: str, dataset: dtypes.AboutThisDataset + ): + metadata: dtypes.AboutThisTable = 
Frictionless.decode_table_metadata( + resource, dataset + ) + + filename = path.join(base_dir, resource.path) # type: ignore + df = pd.read_feather(filename) + df = dtypes.RichDataFrame(df, metadata=metadata) + + if df.primary_key: + df.set_index(df.primary_key, inplace=True) + + return df + + @staticmethod + def decode_table_metadata( + resource: frictionless.Resource, + dataset: Optional[dtypes.AboutThisDataset] = None, + ) -> dtypes.AboutThisTable: + fields: Dict[str, Any] = resource.schema["fields"] # type: ignore + return dtypes.AboutThisTable( + short_name=resource.name, # type: ignore + primary_key=resource.schema.get("primaryKey"), + schema={ + col: Frictionless.decode_series_metadata( + resource.schema.fields[col], dataset + ) + for col in fields + }, + dataset=dataset, + ) + + @staticmethod + def encode_series_metadata(metadata: dtypes.AboutThisSeries) -> Dict[str, Any]: + return { + "name": metadata.name, + "title": metadata.title, + "description": metadata.description, + "_source_name": metadata.source_name, + } + + @staticmethod + def decode_series_metadata( + schema: Dict[str, Any], dataset: Optional[dtypes.AboutThisDataset] = None + ) -> dtypes.AboutThisSeries: + return dtypes.AboutThisSeries( + name=schema["name"], + title=schema.get("title"), + description=schema.get("description"), + source_name=schema.get("_source_name"), + dataset=dataset, + ) + + +class FrictionlessDataset: + """ + A dataset is a folder in Frictionless data format, containing many CSV files and one + huge datapackage.json files containing schemas and metadata for them all. + """ + + def __init__(self, pkg_file: str): + self.pkg_file = pkg_file + self.pkg = frictionless.Package(pkg_file) + self.metadata = Frictionless.decode_metadata(self.pkg) + + def __len__(self) -> int: + return len(self.pkg.resources) # type: ignore + + def __iter__(self) -> Iterator[dtypes.RichDataFrame]: + base_dir = path.dirname(self.pkg_file) + for resource in self.pkg.resources: # type: ignore + table = Frictionless.table_from_resource(resource, base_dir, self.metadata) + yield table + + def __getitem__(self, table_name: str) -> dtypes.RichDataFrame: + base_dir = path.dirname(self.pkg_file) + (t,) = [r for r in self.pkg.resources if r["name"] == table_name] # type: ignore + return Frictionless.table_from_resource(t, base_dir, self.metadata) + + def save(self, path: str) -> None: + Frictionless.save(self, path) + + @staticmethod + def load(path: str) -> "FrictionlessDataset": + return Frictionless.load(path) + + def add(self, table: dtypes.RichDataFrame) -> None: + raise Exception("not yet implemented") + + +def _get_primary_key(df: pd.DataFrame) -> List[str]: + primary_key: List[str] = list(df.index.names) + if primary_key[0] is not None: + return primary_key + + return [] + + +def pruned(v: Any) -> Any: + "Prune a JSON-like document to remove any (k, v) pairs where the value is None." 
+ if isinstance(v, dict): + v = v.copy() + for k in list(v): + v[k] = pruned(v[k]) + if v[k] is None: + del v[k] + + elif isinstance(v, list): + return [pruned(x) for x in v] + + return v diff --git a/requirements.txt b/requirements.txt index 548bc7fe..e2c3ff6e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,6 @@ requests==2.26.0 aiohttp==3.7.4.post0 beautifulsoup4==4.9.3 click==8.0.1 -dataclasses==0.8 pandas==1.3.2 pdfminer==20191125 pymysql==1.0.2 @@ -12,3 +11,6 @@ python-dotenv==0.19.0 simplejson==3.17.5 tqdm==4.62.2 Unidecode==1.2.0 +frictionless[pandas] +dataclasses_json +pyarrow diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py new file mode 100644 index 00000000..2dd350fb --- /dev/null +++ b/tests/test_dtypes.py @@ -0,0 +1,82 @@ +from owid import dtypes +import pandas as pd + + +def test_rich_series_no_metadata(): + # no metadata case + s1 = dtypes.RichSeries([1, 2, 3]) + assert s1.metadata.is_empty() + assert s1.metadata == dtypes.AboutThisSeries() + + +def test_rich_series_keeps_name(): + # compatibility of metadata with name + s2 = dtypes.RichSeries([1, 2, 3], name="numbers") + assert s2.metadata == dtypes.AboutThisSeries(name="numbers") + + +def test_rich_series_slicing_and_access(): + m3 = dtypes.AboutThisSeries( + name="gdp", title="GDP per capita in 2011 international dollars" + ) + s3 = dtypes.RichSeries( + [53015, 54008, 55335], + index=pd.MultiIndex.from_tuples( + [("usa", 2016), ("usa", 2017), ("usa", 2018)], names=["country", "year"] + ), + metadata=m3, + ) + assert s3.metadata == m3 # construction + assert s3.iloc[:2].metadata == m3 # slicing + assert s3.title == m3.title # individual access + + +def test_rich_dataframe_no_metadata(): + d = dtypes.RichDataFrame({"a": [1, 2, 3], "b": ["dog", "sheep", "pig"]}) + assert d.metadata.is_empty() + + +def test_rich_dataframe_detect_primary_key(): + d = dtypes.RichDataFrame( + {"c": ["dog", "sheep", "pig"]}, + index=pd.MultiIndex.from_tuples( + [(1, 2020), (2, 2020), (2, 2021)], names=["a", "b"] + ), + metadata=dtypes.AboutThisTable(short_name="example"), + ) + assert d.primary_key == ["a", "b"] + assert d.metadata == dtypes.AboutThisTable( + short_name="example", primary_key=["a", "b"] + ) + + +def test_rich_dataframe_creates_rich_series(): + gho = dtypes.AboutThisDataset(short_name="GHO") + d = dtypes.RichDataFrame( + {"c": ["dog", "sheep", "pig"]}, + index=pd.MultiIndex.from_tuples( + [(1, 2020), (2, 2020), (2, 2021)], names=["a", "b"] + ), + metadata=dtypes.AboutThisTable(dataset=gho), + ) + s = d.c + assert isinstance(s, dtypes.RichSeries) + assert s.dataset == gho + assert s.metadata.dataset == gho + assert s.metadata.name == "c" + + +def test_rich_dataframe_metadata_survives_copying(): + metadata = dtypes.AboutThisTable(dataset=dtypes.AboutThisDataset(short_name="GHO")) + d = dtypes.RichDataFrame( + {"c": ["dog", "sheep", "pig"]}, + index=pd.MultiIndex.from_tuples( + [(1, 2020), (2, 2020), (2, 2021)], names=["a", "b"] + ), + metadata=metadata, + ) + assert not d.metadata.is_empty() + + # try slicing and copying + assert d.iloc[:1].metadata == d.metadata # type: ignore + assert d.copy().metadata == d.metadata diff --git a/tests/test_formats.py b/tests/test_formats.py new file mode 100644 index 00000000..73d6e9e7 --- /dev/null +++ b/tests/test_formats.py @@ -0,0 +1,243 @@ +# +# test_formats.py +# + +from inspect import trace +import tempfile +import random +from typing import Any, Dict +import typing +import datetime as dt +import os +from os import path + +from owid import dtypes +from owid.formats 
import Frictionless + +import frictionless +import pandas as pd + + +def test_encode_metadata_to_frictionless(): + metadata = mock_dataset_metadata() + d = Frictionless.encode_metadata(metadata) + + # the frictionless standard requires at least one resources, so add a dummy one + d["resources"] = [ + { + "path": "https://owid-test.nyc3.digitaloceanspaces.com/importers/01-minimal.csv" + } + ] + + assert metadata.namespace == d["_namespace"] + assert metadata.short_name == d["name"] + assert metadata.title == d["title"] + assert metadata.description == d["description"] + assert metadata.license_name == d["licenses"][0]["name"] + assert metadata.license_url == d["licenses"][0]["path"] + for i, source in enumerate(metadata.sources): + s = d["sources"][i] + assert source.name == s["title"] + assert source.description == s["_description"] + assert source.url == s["path"] + assert source.source_data_url == s["_source_data_url"] + assert source.owid_data_url == s["_owid_data_url"] + + # validate against the frictionless standard + assert frictionless.validate(d).errors == [] + + +def test_decode_metadata_from_frictionless(): + d = { + "_namespace": "drinks", + "name": "very_fancy", + "title": "All cocktails known to mankind", + "description": "Long markdown doc...", + "sources": [ + { + "title": "Bartender's guide 2040", + "path": "https://dev.null/", + "_description": "An extremely long markdown description...", + "_source_data_url": "https://dev.null/example.csv", + "_owid_data_url": "https://fake.ourworldindata.org/example.csv", + } + ], + "licenses": [ + { + "name": "CC-BY-NC-4.0", + "path": "https://creativecommons.org/licenses/by-nc/4.0/", + } + ], + } + metadata = Frictionless.decode_metadata(d) + assert metadata.namespace == d["_namespace"] + assert metadata.short_name == d["name"] + assert metadata.title == d["title"] + assert metadata.description == d["description"] + assert metadata.license_name == d["licenses"][0]["name"] + assert metadata.license_url == d["licenses"][0]["path"] + for i, s in enumerate(d["sources"]): + source = metadata.sources[i] + assert source.name == s["title"] + assert source.description == s["_description"] + assert source.url == s["path"] + assert source.source_data_url == s["_source_data_url"] + assert source.owid_data_url == s["_owid_data_url"] + + +def test_frictionless_series_metadata_roundtrip(): + m1: dtypes.AboutThisSeries = attr_updated( + mock(dtypes.AboutThisSeries), dataset=None + ) + + s = Frictionless.encode_series_metadata(m1) + + m2 = Frictionless.decode_series_metadata(s) + + assert m1 == m2 + + +def test_frictionless_round_trip(): + "Check that we can encode data to frictionless in a lossless way." 
+ # set up dataset + metadata = mock(dtypes.AboutThisDataset) + df = dtypes.RichDataFrame( + { + "ice_cream": ["black sesame", "marshmallow", "pepparkakor"], + }, + index=pd.Index(["AUS", "USA", "SWE"], name="country"), + metadata=dtypes.AboutThisTable( + short_name="best_flavours", + primary_key=["country"], + schema={ + "country": attr_updated(mock(dtypes.AboutThisSeries), dataset=metadata), + "ice_cream": attr_updated( + mock(dtypes.AboutThisSeries), dataset=metadata + ), + }, + ), + ) + ds: dtypes.SerializableDataset = dtypes.InMemoryDataset(metadata=metadata) + ds.add(df) + + with tempfile.TemporaryDirectory() as temp_dir: + # get rid of the auto-created directory + os.rmdir(temp_dir) + + # save to disk + Frictionless.save(ds, temp_dir) + + # check that the package validates clean + package_file = path.join(temp_dir, "datapackage.json") + assert path.exists(package_file) + + # XXX failing validation with internal error in frictionless + # assert frictionless.validate(package_file).errors == [] + + # read from disk + ds2 = Frictionless.load(temp_dir) + + print(ds["best_flavours"].metadata.schema) + print(ds2["best_flavours"].metadata.schema) + + assert_ds_eq(ds, ds2) + + +def assert_dataclass_eq(lhs: Any, rhs: Any, _type: Any) -> None: + for f in _type.__dataclass_fields__: + assert getattr(lhs, f) == getattr(rhs, f), f + + +def assert_df_eq(lhs: dtypes.RichDataFrame, rhs: dtypes.RichDataFrame) -> None: + # assert lhs.metadata == rhs.metadata + assert_dataclass_eq(lhs.metadata, rhs.metadata, type(lhs.metadata)) + assert lhs.to_dict() == rhs.to_dict() + + +def assert_ds_eq(lhs: dtypes.Dataset, rhs: dtypes.Dataset) -> None: + # assert lhs.metadata == rhs.metadata + assert_dataclass_eq(lhs.metadata, rhs.metadata, type(lhs.metadata)) + assert len(lhs) == len(rhs) + for lhs_t, rhs_t in zip(lhs, rhs): + assert_df_eq(lhs_t, rhs_t) + + +_MOCK_STRINGS = None + + +def mock_dataset_metadata() -> dtypes.AboutThisDataset: + return dtypes.AboutThisDataset( + **{ + f.name: mock(f.type) + for f in dtypes.AboutThisDataset.__dataclass_fields__.values() + } + ) + + +def is_optional_type(_type: type) -> bool: + return ( + getattr(_type, "__origin__", None) == typing.Union + and len(getattr(_type, "__args__", ())) == 2 + and getattr(_type, "__args__")[1] == type(None) + ) + + +def strip_option(_type: type) -> type: + return _type.__args__[0] # type: ignore + + +def mock(_type: type) -> Any: + global _MOCK_STRINGS + + if is_optional_type(_type): + _type = strip_option(_type) + + if hasattr(_type, "__forward_arg__"): + raise ValueError(_type) + + if _type == int: + return random.randint(0, 1000) + + elif _type == float: + return 10 * random.random() / random.random() + + elif _type == dt.date: + return dt.date.fromordinal( + dt.date.today().toordinal() - random.randint(0, 1000) + ) + + elif _type == str: + if not _MOCK_STRINGS: + _MOCK_STRINGS = [l.strip() for l in open("/usr/share/dict/words")] + + # some strings in the frictionless standard must be lowercase with no spaces + return random.choice(_MOCK_STRINGS).lower() + + elif getattr(_type, "_name", None) == "List": + # e.g. List[int] + return [mock(_type.__args__[0]) for i in range(random.randint(1, 4))] # type: ignore + + elif getattr(_type, "_name", None) == "Dict": + # e.g. 
Dict[str, int] + _from, _to = _type.__args__ # type: ignore + return {mock(_from): mock(_to) for i in range(random.randint(1, 8))} + + elif hasattr(_type, "__dataclass_fields__"): + # all dataclasses + return _type( + **{ + f.name: mock(f.type) + for f in _type.__dataclass_fields__.values() # type: ignore + } + ) + + raise ValueError(f"don't know how to mock type: {_type}") + + +T = typing.TypeVar("T") + + +def attr_updated(obj: T, **kwargs) -> T: + for k, v in kwargs.items(): + setattr(obj, k, v) + return obj
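# Illustrative sketch (editor's addition, assumed usage based on the tests above,
# not part of the patch): the RichSeries / RichDataFrame classes in owid/dtypes.py
# carry their metadata through pandas' `_metadata` and `_constructor` hooks, so
# column- and table-level metadata survives slicing, copying and column access.
# This mirrors test_rich_series_slicing_and_access and
# test_rich_dataframe_creates_rich_series; the values are illustrative.

import pandas as pd
from owid import dtypes

gdp = dtypes.RichSeries(
    [53015, 54008, 55335],
    index=pd.MultiIndex.from_tuples(
        [("usa", 2016), ("usa", 2017), ("usa", 2018)], names=["country", "year"]
    ),
    metadata=dtypes.AboutThisSeries(
        name="gdp", title="GDP per capita in 2011 international dollars"
    ),
)

assert gdp.iloc[:2].metadata == gdp.metadata  # metadata survives slicing
assert gdp.title == "GDP per capita in 2011 international dollars"  # field access

table = dtypes.RichDataFrame(
    {"gdp": [53015, 54008]},
    index=pd.MultiIndex.from_tuples(
        [("usa", 2016), ("usa", 2017)], names=["country", "year"]
    ),
    metadata=dtypes.AboutThisTable(short_name="gdp_per_capita"),
)
assert table.primary_key == ["country", "year"]  # detected from the index
assert isinstance(table.gdp, dtypes.RichSeries)  # columns come back enriched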
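# Illustrative sketch (editor's addition, assumed usage, not part of the patch):
# a dataset built with owid.dtypes can be written out as a frictionless data
# package and read back with its metadata attached, as exercised by
# test_frictionless_round_trip. The directory path, dataset names and values
# below are illustrative; this relies on the frictionless[pandas] and pyarrow
# dependencies added to requirements.txt, and the source itself notes that
# frictionless's own validation of the resulting package is still flaky.

import pandas as pd
from owid import dtypes
from owid.formats import Frictionless

flavours = dtypes.RichDataFrame(
    {"ice_cream": ["black sesame", "marshmallow", "pepparkakor"]},
    index=pd.Index(["AUS", "USA", "SWE"], name="country"),
    metadata=dtypes.AboutThisTable(
        short_name="best_flavours",
        primary_key=["country"],
        schema={
            "country": dtypes.AboutThisSeries(name="country", title="Country code"),
            "ice_cream": dtypes.AboutThisSeries(name="ice_cream", title="Favourite flavour"),
        },
    ),
)

ds = dtypes.InMemoryDataset(
    metadata=dtypes.AboutThisDataset(
        namespace="examples",
        short_name="flavours",
        title="Favourite ice cream flavours",
        description="A tiny example dataset.",
        license_name="CC-BY-4.0",
        license_url="https://creativecommons.org/licenses/by/4.0/",
        sources=[dtypes.Source(name="Example source", url="https://example.org/")],
    )
)
ds.add(flavours)  # links the table to the dataset metadata, keyed by short_name

Frictionless.save(ds, "/tmp/flavours")    # writes datapackage.json + best_flavours.feather
ds2 = Frictionless.load("/tmp/flavours")  # FrictionlessDataset backed by that folder
print(ds2["best_flavours"].metadata)      # table metadata round-trips with the data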