Skip to content

Commit

Permalink
Introduces support for PyTorch only. (#4)
Browse files Browse the repository at this point in the history
* Remove support for TensorBoard causing error with protobuf

* Change loading/saving ONNX model to PyTorch state dict

* Update docs

* Add return type in create_explainer func and fix object type from kwargs in explainers

* Bump minor version
  • Loading branch information
adamwawrzynski committed Dec 21, 2022
1 parent 270fb4d commit f75f5ed
Show file tree
Hide file tree
Showing 21 changed files with 166 additions and 1,014 deletions.
84 changes: 40 additions & 44 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,32 +1,28 @@
# AutoXAI

## Requirements
# Requirements

To separate runtime environments for different services and repositories, it is
recommended to use a virtual Python environment, e.g. `virtualenv`. After
installing it, create a new environment and activate it. The project uses Python
version `3.10`.
The project was tested using Python version `3.8`.

In the example below, the `-p` parameter specifies the Python version
and the second parameter the name of the virtual environment, for
example `env`.
## Poetry

To separate runtime environments for different services and repositories, it is
recommended to use a virtual Python environment. You can configure `Poetry` to
create new virtual environment in project directory of every repository. To
install `Poetry` follow instruction at https://python-poetry.org/docs/#installing-with-the-official-installer. We are using `Poetry` in version
`1.2.1`. To install specific version You have to provide desired package
version:
```bash
virtualenv -p python3.10 .venv
source .venv/bin/activate
curl -sSL https://install.python-poetry.org | POETRY_VERSION=1.2.1 python3 -
```

The project uses the `poetry` package, which manages the dependencies in the
project. To install it first update the `pip` package and then install `poetry`
version `1.2.1`.

After installation configure creation of virtual environments in directory
of project.
```bash
python -m pip install --upgrade pip
poetry config virtualenvs.create true
poetry config virtualenvs.in-project true
```

Instructions for installation of `poetry`:
https://python-poetry.org/docs/#installing-with-the-official-installer.

The final step is to install all the dependencies defined in the
`pyproject.toml` file.

Expand All @@ -35,6 +31,18 @@ poetry install
```

Once all the steps have been completed, the environment is ready to go.
Virtual environment by default will be created with name `.venv` inside
project directory.

### Installation errors

If You encounter errors during dependencies installation You can disable
parallel installer, remove current virtual environment and remove `artifacts`
and `cache` directories from `poetry` root directory (by default is under
`/home/<user>/.cache/pypoetry/`). To disable parallel installer run:
```bash
poetry config installer.parallel false
```

## Pre-commit hooks setup

Expand All @@ -43,13 +51,9 @@ our [pre-commit][https://pre-commit.com/] hooks as the very first step after
cloning the repository:

```bash
poetry install
poetry run pre-commit install
```

pre-commit: https://pre-commit.com/


# Note

At the moment only explainable algorithms for image classification are
Expand All @@ -71,27 +75,19 @@ cache_directory/
└── <date>
├── <uuid>
│ ├── data
│ │ ├── input_data
│ │ │ └── <data>.pkl
│ │ ├── normalized
│ │ │ └── <data>.pkl
│ │ ├── original
│ │ │ └── <data>.pkl
│ │ └── predictions
│ │ └── <data>.pkl
│ ├── explanations
│ │ ├── <method1>
│ │ │ └── figures
│ │ │ ├── attributes.npy
│ │ │ └── params.json.pkl
│ │ ├── <method2>
│ │ │ └── figures
│ │ │ ├── attributes.npy
│ │ │ └── params.json.pkl
│ │ ├── ...
└── ...
│ │ ├── <data>.pkl
| | └─── ...
│ ├── labels
│ │ └── idx_to_label.json.pkl
| └── training
| ├── <epoch>
| | └── model.pt
... ...
```

Another part of this module is GUI interface to view explanations and
modify parameters of explainable algorithms. As a PoC application in
`streamlit` is developed.
## Examples

In `example/streamlit_app/` directory You can find sample application with
simple GUI to present interactive explanations of given models.
Scripts in `example/` directory contain samples of training models using
different callbacks.
2 changes: 1 addition & 1 deletion example/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import uuid

import torch
from mnist_model import LitMNIST
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from streamlit_app.mnist_model import LitMNIST

from src.cache_manager import LocalDirCacheManager
from src.callback import CustomPytorchLightningCallback
Expand Down
51 changes: 0 additions & 51 deletions example/mnist_tensorboard.py

This file was deleted.

11 changes: 2 additions & 9 deletions example/offline.py → example/offline_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,19 +176,12 @@ def main():
input_data,
)

path: str = os.path.join(experiment.path, "training", "0", "model.onnx")
path: str = os.path.join(experiment.path, "training", "0", "model.pt")

if not os.path.exists(path):
os.makedirs(Path(path).parent)

torch.onnx.export(
model,
input_data,
path,
verbose=True,
input_names=["conv1"],
output_names=["output1"],
)
torch.save(model.state_dict(), path)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions example/streamlit_app/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ application, the labels directory with the `idx_to_label.json.pkl` file,
which contains the JSON with the index-class mapping, and the training
directory, which contains directories corresponding to the training epoch
number, in which the models are stored, always with the same name
`model.onnx` in ONNX format.
`model.pt` in PyTorch state dict format.

Example log directory structure:
```bash
Expand All @@ -36,6 +36,6 @@ logs/
| └── idx_to_label.json.pkl
└── training
└── 0
└── model.onnx
└── model.pt

```
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def __init__(
torch.nn.Linear(hidden_size, self.num_classes),
)

mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
mnist_full = MNIST(
self.data_dir, train=True, download=True, transform=self.transform
)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
Expand Down
23 changes: 11 additions & 12 deletions example/streamlit_app/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,25 @@

from typing import List

import onnx
import onnx2torch
import torch
from torch import fx
from mnist_model import LitMNIST # pylint: disable = (import-error)


def load_model(model_path: str) -> fx.GraphModule:
"""Load model from local path and convert in into torch.fx.GraphModule.
def load_model(model_path: str) -> torch.nn.Module:
"""Load model's state dict from local path.
Args:
model_path: Path to local ONNX model.
model_path: Path to local model's state dict.
Returns:
Converted ONNX model to torch.fx.GraphModule.
Model with loaded state dict.
"""
return onnx2torch.convert(onnx.load(model_path))
model = LitMNIST(batch_size=1, data_dir=".")
model.load_state_dict(torch.load(model_path))
return model


def get_model_layers(model: fx.GraphModule) -> List[torch.nn.Module]:
def get_model_layers(model: torch.nn.Module) -> List[torch.nn.Module]:
"""Get all layers from given model.
Args:
Expand All @@ -31,8 +31,7 @@ def get_model_layers(model: fx.GraphModule) -> List[torch.nn.Module]:
"""
layers = []
for module in model.modules():
if not isinstance(module, fx.graph_module.GraphModule):
if isinstance(module, torch.nn.Conv2d):
layers.append(module)
if isinstance(module, torch.nn.Conv2d):
layers.append(module)

return layers
17 changes: 9 additions & 8 deletions example/streamlit_app/run_streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@

import numpy as np
import streamlit as st
from method_names import MethodName
from model_utils import get_model_layers, load_model
from settings import Settings
from streamlit_utils import (
from method_names import MethodName # pylint: disable = (import-error)
from model_utils import get_model_layers, load_model # pylint: disable = (import-error)
from settings import Settings # pylint: disable = (import-error)
from streamlit_utils import ( # pylint: disable = (import-error)
disable_explain,
initialize_session_state,
load_idx_to_labels,
load_input_data,
load_original_data,
load_subdir,
)
from torch import fx
from visualization_utils import convert_figure_to_numpy
from visualization_utils import ( # pylint: disable = (import-error)
convert_figure_to_numpy,
)

from src.explainer.base_explainer import CVExplainer
from src.explainer.gradcam import GuidedGradCAMCVExplainer, LayerGradCAMCVExplainer
Expand Down Expand Up @@ -274,10 +275,10 @@ def main_view() -> None:
hash_selectbox,
"training",
epoch_number,
"model.onnx",
"model.pt",
)

model: fx.GraphModule = load_model(model_path=model_path)
model = load_model(model_path=model_path)
model_layers = get_model_layers(model=model)
st.session_state[Settings.model_layers_key] = model_layers
method_string = st.session_state[Settings.method_label]
Expand Down
4 changes: 2 additions & 2 deletions example/streamlit_app/streamlit_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""File contains functions """
"""File contains functions to manipulate st.session_state and cache."""

import os
from typing import Any, Dict, List, Union, cast

import streamlit as st
import torch
from settings import Settings
from settings import Settings # pylint: disable = (import-error)

from src.cache_manager import LocalDirCacheManager

Expand Down
Loading

0 comments on commit f75f5ed

Please sign in to comment.