Skip to content

Commit

Permalink
Import model_card_toolkit as mct (#285)
Browse files Browse the repository at this point in the history
  • Loading branch information
codesue committed Jun 12, 2023
1 parent 5598a50 commit 42d1860
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 92 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,21 @@ for more installation options.

## Getting Started

import model_card_toolkit
import model_card_toolkit as mct

# Initialize the Model Card Toolkit with a path to store generate assets
model_card_output_path = ...
mct = model_card_toolkit.ModelCardToolkit(model_card_output_path)
toolkit = mct.ModelCardToolkit(model_card_output_path)

# Initialize the model_card_toolkit.ModelCard, which can be freely populated
model_card = mct.scaffold_assets()
# Initialize the ModelCard, which can be freely populated
model_card = toolkit.scaffold_assets()
model_card.model_details.name = 'My Model'

# Write the model card data to a proto file
mct.update_model_card(model_card)
toolkit.update_model_card(model_card)

# Return the model card document as an HTML page
html = mct.export_format()
html = toolkit.export_format()

## Model Card Generation on TFX

Expand Down
10 changes: 5 additions & 5 deletions model_card_toolkit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,21 @@ class ModelCardToolkit():
Standard workflow:
```python
import model_card_toolkit
import model_card_toolkit as mct
# Initialize the Model Card Toolkit with a path to store generate assets
model_card_dir_path = ...
mct = model_card_toolkit.ModelCardToolkit(model_card_dir_path)
toolkit = mct.ModelCardToolkit(model_card_dir_path)
# Initialize the ModelCard, which can be freely populated
model_card = mct.scaffold_assets()
model_card = toolkit.scaffold_assets()
model_card.model_details.name = 'My Model'
# Write the model card data to a proto file
mct.update_model_card(model_card)
toolkit.update_model_card(model_card)
# Return the model card document as an HTML page
html = mct.export_format()
html = toolkit.export_format()
```
"""
def __init__(
Expand Down
36 changes: 18 additions & 18 deletions model_card_toolkit/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def setUp(self):

def test_scaffold_assets(self):
output_dir = self.mct_dir
mct = core.ModelCardToolkit(output_dir=output_dir)
self.assertEqual(mct.output_dir, output_dir)
mct.scaffold_assets()
toolkit = core.ModelCardToolkit(output_dir=output_dir)
self.assertEqual(toolkit.output_dir, output_dir)
toolkit.scaffold_assets()
self.assertIn(
'default_template.html.jinja',
os.listdir(os.path.join(output_dir, 'template/html'))
Expand All @@ -49,17 +49,17 @@ def test_scaffold_assets(self):
)

def test_scaffold_assets_with_json(self):
mct = core.ModelCardToolkit(output_dir=self.mct_dir)
mc = mct.scaffold_assets({'model_details': {
toolkit = core.ModelCardToolkit(output_dir=self.mct_dir)
mc = toolkit.scaffold_assets({'model_details': {
'name': 'json_test',
}})
self.assertEqual(mc.model_details.name, 'json_test')

def test_update_model_card_with_valid_model_card(self):
mct = core.ModelCardToolkit(output_dir=self.mct_dir)
valid_model_card = mct.scaffold_assets()
toolkit = core.ModelCardToolkit(output_dir=self.mct_dir)
valid_model_card = toolkit.scaffold_assets()
valid_model_card.model_details.name = 'My Model'
mct.update_model_card(valid_model_card)
toolkit.update_model_card(valid_model_card)
proto_path = os.path.join(self.mct_dir, 'data/model_card.proto')

model_card_proto = io_utils.parse_proto_file(
Expand All @@ -71,8 +71,8 @@ def test_update_model_card_with_valid_model_card_as_proto(self):
valid_model_card = model_card_pb2.ModelCard()
valid_model_card.model_details.name = 'My Model'

mct = core.ModelCardToolkit(output_dir=self.mct_dir)
mct.update_model_card(valid_model_card)
toolkit = core.ModelCardToolkit(output_dir=self.mct_dir)
toolkit.update_model_card(valid_model_card)
proto_path = os.path.join(self.mct_dir, 'data/model_card.proto')

model_card_proto = io_utils.parse_proto_file(
Expand All @@ -81,11 +81,11 @@ def test_update_model_card_with_valid_model_card_as_proto(self):
self.assertEqual(model_card_proto, valid_model_card)

def test_export_format(self):
mct = core.ModelCardToolkit(output_dir=self.mct_dir)
mc = mct.scaffold_assets()
toolkit = core.ModelCardToolkit(output_dir=self.mct_dir)
mc = toolkit.scaffold_assets()
mc.model_details.name = 'My Model'
mct.update_model_card(mc)
result = mct.export_format()
toolkit.update_model_card(mc)
result = toolkit.export_format()

proto_path = os.path.join(self.mct_dir, 'data/model_card.proto')
self.assertTrue(os.path.exists(proto_path))
Expand All @@ -102,16 +102,16 @@ def test_export_format(self):
self.assertIn('My Model', content)

def test_export_format_with_customized_template_and_output_name(self):
mct = core.ModelCardToolkit(output_dir=self.mct_dir)
mc = mct.scaffold_assets()
toolkit = core.ModelCardToolkit(output_dir=self.mct_dir)
mc = toolkit.scaffold_assets()
mc.model_details.name = 'My Model'
mct.update_model_card(mc)
toolkit.update_model_card(mc)

template_path = os.path.join(
self.mct_dir, 'template/html/default_template.html.jinja'
)
output_file = 'my_model_card.html'
result = mct.export_format(
result = toolkit.export_format(
template_path=template_path, output_file=output_file
)

Expand Down
4 changes: 2 additions & 2 deletions model_card_toolkit/core_tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ def test_scaffold_assets_with_store(
num_eval_artifacts = 1
output_dir = self.mct_dir
store = tf_testdata_utils.get_tfx_pipeline_metadata_store(self.tmp_db_path)
mct = core.ModelCardToolkit(
toolkit = core.ModelCardToolkit(
output_dir=output_dir, mlmd_source=tf_sources.MlmdSource(
store=store, model_uri=tf_testdata_utils.TFX_0_21_MODEL_URI
)
)
mc = mct.scaffold_assets()
mc = toolkit.scaffold_assets()
self.assertIsNotNone(mc.model_details.name)
self.assertIsNotNone(mc.model_details.version.name)
self.assertIn(
Expand Down
14 changes: 7 additions & 7 deletions model_card_toolkit/documentation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@
## Getting Started

```
import model_card_toolkit
import model_card_toolkit as mct
# Initialize the Model Card Toolkit with a path to store generate assets
model_card_output_path = ...
mct = model_card_toolkit.ModelCardToolkit(model_card_output_path)
toolkit = mct.ModelCardToolkit(model_card_output_path)
# Initialize the model_card_toolkit.ModelCard, which can be freely populated
model_card = mct.scaffold_assets()
# Initialize the ModelCard, which can be freely populated
model_card = toolkit.scaffold_assets()
model_card.model_details.name = 'My Model'
# Write the model card data to a file
mct.update_model_card(model_card) # writes to proto
mct.update_model_card_json(model_card) # writes to JSON
toolkit.update_model_card(model_card) # writes to proto
toolkit.update_model_card_json(model_card) # writes to JSON
# Return the model card document as an HTML page, and save to file
html = mct.export_format()
html = toolkit.export_format()
```

## Tutorials
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@
"from datetime import date\n",
"from io import BytesIO\n",
"from IPython import display\n",
"import model_card_toolkit as mctlib\n",
"import model_card_toolkit as mct\n",
"from sklearn.datasets import load_breast_cancer\n",
"from sklearn.ensemble import GradientBoostingClassifier\n",
"from sklearn.model_selection import train_test_split\n",
Expand Down Expand Up @@ -330,9 +330,9 @@
},
"outputs": [],
"source": [
"mct = mctlib.ModelCardToolkit()\n",
"toolkit = mct.ModelCardToolkit()\n",
"\n",
"model_card = mct.scaffold_assets()"
"model_card = toolkit.scaffold_assets()"
]
},
{
Expand All @@ -357,46 +357,46 @@
" 'This model predicts whether breast cancer is benign or malignant based on '\n",
" 'image measurements.')\n",
"model_card.model_details.owners = [\n",
" mctlib.Owner(name= 'Model Cards Team', contact='model-cards@google.com')\n",
" mct.Owner(name= 'Model Cards Team', contact='model-cards@google.com')\n",
"]\n",
"model_card.model_details.references = [\n",
" mctlib.Reference(reference='https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)'),\n",
" mctlib.Reference(reference='https://minds.wisconsin.edu/bitstream/handle/1793/59692/TR1131.pdf')\n",
" mct.Reference(reference='https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)'),\n",
" mct.Reference(reference='https://minds.wisconsin.edu/bitstream/handle/1793/59692/TR1131.pdf')\n",
"]\n",
"model_card.model_details.version.name = str(uuid.uuid4())\n",
"model_card.model_details.version.date = str(date.today())\n",
"\n",
"model_card.considerations.ethical_considerations = [mctlib.Risk(\n",
"model_card.considerations.ethical_considerations = [mct.Risk(\n",
" name=('Manual selection of image sections to digitize could create '\n",
" 'selection bias'),\n",
" mitigation_strategy='Automate the selection process'\n",
")]\n",
"model_card.considerations.limitations = [mctlib.Limitation(description='Breast cancer diagnosis')]\n",
"model_card.considerations.use_cases = [mctlib.UseCase(description='Breast cancer diagnosis')]\n",
"model_card.considerations.users = [mctlib.User(description='Medical professionals'), mctlib.User(description='ML researchers')]\n",
"model_card.considerations.limitations = [mct.Limitation(description='Breast cancer diagnosis')]\n",
"model_card.considerations.use_cases = [mct.UseCase(description='Breast cancer diagnosis')]\n",
"model_card.considerations.users = [mct.User(description='Medical professionals'), mct.User(description='ML researchers')]\n",
"\n",
"model_card.model_parameters.data.append(mctlib.Dataset())\n",
"model_card.model_parameters.data.append(mct.Dataset())\n",
"model_card.model_parameters.data[0].graphics.description = (\n",
" f'{len(X_train)} rows with {len(X_train.columns)} features')\n",
"model_card.model_parameters.data[0].graphics.collection = [\n",
" mctlib.Graphic(image=mean_radius_train),\n",
" mctlib.Graphic(image=mean_texture_train)\n",
" mct.Graphic(image=mean_radius_train),\n",
" mct.Graphic(image=mean_texture_train)\n",
"]\n",
"model_card.model_parameters.data.append(mctlib.Dataset())\n",
"model_card.model_parameters.data.append(mct.Dataset())\n",
"model_card.model_parameters.data[1].graphics.description = (\n",
" f'{len(X_test)} rows with {len(X_test.columns)} features')\n",
"model_card.model_parameters.data[1].graphics.collection = [\n",
" mctlib.Graphic(image=mean_radius_test),\n",
" mctlib.Graphic(image=mean_texture_test)\n",
" mct.Graphic(image=mean_radius_test),\n",
" mct.Graphic(image=mean_texture_test)\n",
"]\n",
"model_card.quantitative_analysis.graphics.description = (\n",
" 'ROC curve and confusion matrix')\n",
"model_card.quantitative_analysis.graphics.collection = [\n",
" mctlib.Graphic(image=roc_curve),\n",
" mctlib.Graphic(image=confusion_matrix)\n",
" mct.Graphic(image=roc_curve),\n",
" mct.Graphic(image=confusion_matrix)\n",
"]\n",
"\n",
"mct.update_model_card(model_card)"
"toolkit.update_model_card(model_card)"
]
},
{
Expand All @@ -418,7 +418,7 @@
"source": [
"# Return the model card document as an HTML page\n",
"\n",
"html = mct.export_format()\n",
"html = toolkit.export_format()\n",
"\n",
"display.display(display.HTML(html))"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import model_card_toolkit as mctlib\n",
"import model_card_toolkit as mct\n",
"from model_card_toolkit.documentation.examples import cats_vs_dogs\n",
"from model_card_toolkit.utils.graphics import figure_to_base64str\n",
"import tempfile\n",
Expand Down Expand Up @@ -282,10 +282,10 @@
"source": [
"# https://github.com/tensorflow/model-card-toolkit/blob/master/model_card_toolkit/model_card_toolkit.py\n",
"model_card_dir = tempfile.mkdtemp()\n",
"mct = mctlib.ModelCardToolkit(model_card_dir)\n",
"toolkit = mct.ModelCardToolkit(model_card_dir)\n",
"\n",
"# https://github.com/tensorflow/model-card-toolkit/blob/master/model_card_toolkit/model_card.py\n",
"model_card = mct.scaffold_assets()"
"model_card = toolkit.scaffold_assets()"
]
},
{
Expand Down Expand Up @@ -336,15 +336,15 @@
" 'performed with high accuracy on both Cat and Dog images.'\n",
")\n",
"model_card.model_details.owners = [\n",
" mctlib.Owner(name='Model Cards Team', contact='model-cards@google.com')\n",
" mct.Owner(name='Model Cards Team', contact='model-cards@google.com')\n",
"]\n",
"model_card.model_details.version = mctlib.Version(name='v1.0', date='08/28/2020')\n",
"model_card.model_details.version = mct.Version(name='v1.0', date='08/28/2020')\n",
"model_card.model_details.references = [\n",
" mctlib.Reference(reference='https://www.tensorflow.org/guide/keras/transfer_learning'),\n",
" mctlib.Reference(reference='https://arxiv.org/abs/1801.04381'),\n",
" mct.Reference(reference='https://www.tensorflow.org/guide/keras/transfer_learning'),\n",
" mct.Reference(reference='https://arxiv.org/abs/1801.04381'),\n",
"]\n",
"model_card.model_details.licenses = [mctlib.License(identifier='Apache-2.0')]\n",
"model_card.model_details.citations = [mctlib.Citation(citation='https://github.com/tensorflow/model-card-toolkit/blob/master/model_card_toolkit/documentation/examples/Standalone_Model_Card_Toolkit_Demo.ipynb')]"
"model_card.model_details.licenses = [mct.License(identifier='Apache-2.0')]\n",
"model_card.model_details.citations = [mct.Citation(citation='https://github.com/tensorflow/model-card-toolkit/blob/master/model_card_toolkit/documentation/examples/Standalone_Model_Card_Toolkit_Demo.ipynb')]"
]
},
{
Expand All @@ -369,9 +369,9 @@
"outputs": [],
"source": [
"model_card.quantitative_analysis.performance_metrics = [\n",
" mctlib.PerformanceMetric(type='accuracy', value=str(accuracy)),\n",
" mctlib.PerformanceMetric(type='accuracy', value=str(cat_accuracy), slice='cat'),\n",
" mctlib.PerformanceMetric(type='accuracy', value=str(dog_accuracy), slice='Dog'),\n",
" mct.PerformanceMetric(type='accuracy', value=str(accuracy)),\n",
" mct.PerformanceMetric(type='accuracy', value=str(cat_accuracy), slice='cat'),\n",
" mct.PerformanceMetric(type='accuracy', value=str(dog_accuracy), slice='Dog'),\n",
"]"
]
},
Expand All @@ -395,12 +395,12 @@
"outputs": [],
"source": [
"model_card.considerations.use_cases = [\n",
" mctlib.UseCase(description='This model classifies images of cats and dogs.')\n",
" mct.UseCase(description='This model classifies images of cats and dogs.')\n",
"]\n",
"model_card.considerations.limitations = [\n",
" mctlib.Limitation(description='This model is not able to classify images of other classes.')\n",
" mct.Limitation(description='This model is not able to classify images of other classes.')\n",
"]\n",
"model_card.considerations.ethical_considerations = [mctlib.Risk(\n",
"model_card.considerations.ethical_considerations = [mct.Risk(\n",
" name=\n",
" 'While distinguishing between cats and dogs is generally agreed to be '\n",
" 'a benign application of machine learning, harmful results can occur '\n",
Expand Down Expand Up @@ -487,12 +487,12 @@
},
"outputs": [],
"source": [
"model_card.model_parameters.data.append(mctlib.Dataset())\n",
"model_card.model_parameters.data.append(mct.Dataset())\n",
"model_card.model_parameters.data[0].graphics.collection = [\n",
" mctlib.Graphic(name='Validation Set Size', image=validation_set_size_barchart),\n",
" mct.Graphic(name='Validation Set Size', image=validation_set_size_barchart),\n",
"]\n",
"model_card.quantitative_analysis.graphics.collection = [\n",
" mctlib.Graphic(name='Accuracy', image=accuracy_barchart),\n",
" mct.Graphic(name='Accuracy', image=accuracy_barchart),\n",
"]"
]
},
Expand All @@ -516,7 +516,7 @@
},
"outputs": [],
"source": [
"mct.update_model_card(model_card)"
"toolkit.update_model_card(model_card)"
]
},
{
Expand All @@ -537,7 +537,7 @@
"outputs": [],
"source": [
"# Generate a model card document in HTML (default)\n",
"html_doc = mct.export_format()\n",
"html_doc = toolkit.export_format()\n",
"\n",
"# Display the model card document in HTML\n",
"display.display(display.HTML(html_doc))"
Expand All @@ -562,7 +562,7 @@
"source": [
"# Generate a model card document in Markdown\n",
"md_path = os.path.join(model_card_dir, 'template/md/default_template.md.jinja')\n",
"md_doc = mct.export_format(template_path=md_path, output_file='model_card.md')\n",
"md_doc = toolkit.export_format(template_path=md_path, output_file='model_card.md')\n",
"\n",
"# Display the model card document in Markdown\n",
"display.display(display.Markdown(md_doc))"
Expand Down
Loading

0 comments on commit 42d1860

Please sign in to comment.