
Add wandb support #16

Merged Jan 10, 2023 (20 commits)

Changes from all commits
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -32,6 +32,7 @@ repos:
         files: autoxai
         types: [ python ]
         entry: poetry run mypy
+        args: ["--ignore-missing-imports"]
       - id: locks
         name: Update locks
         files: pyproject.toml
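
The amended hook can be exercised locally with pre-commit; a minimal sketch, assuming the dev environment was set up via `poetry install`:

```bash
# Run every configured hook (including the updated mypy hook) against all files.
poetry run pre-commit run --all-files
```
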
118 changes: 113 additions & 5 deletions README.md
@@ -1,6 +1,114 @@
# AutoXAI

AutoXAI simplifies the application of e**X**plainable **AI** algorithms to explain the
performance of neural network models during training. The library acts as an
aggregator of existing libraries with implementations of various XAI algorithms and
seeks to facilitate and popularize their use in machine learning projects.

Currently, only algorithms related to computer vision are supported, but we plan to
add support for text, tabular and multimodal data problems in the future.

## Table of contents:
* [Installation](#installation)
* [GPU acceleration](#gpu-acceleration)
* [Manual installation](#manual-installation)
* [Getting started](#getting-started)
* [Development](#development)
* [Requirements](#requirements)
* [CUDA](#cuda)
* [Poetry](#poetry)
* [pyenv](#pyenv)
* [Installation errors](#installation-errors)
* [pre-commit hooks](#pre-commit-hooks-setup)
* [Note](#note)
* [Artifacts directory structure](#artifacts-directory-structure)
* [Examples](#examples)

# Installation

Installation requirements:
* `Python` >= 3.8 & < 3.11

## GPU acceleration

In order to use the torch library with GPU acceleration, you need to install a
dedicated version of torch that matches both the CUDA drivers installed on your
machine and the torch version supported by the library, at the moment `torch==1.12.1`.
A list of `torch` wheels with CUDA support can be found at
[https://download.pytorch.org/whl/torch/](https://download.pytorch.org/whl/torch/).
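
For example, a sketch of installing a CUDA-enabled wheel; the `+cu116` suffix is an assumption and should be replaced with the tag matching your CUDA drivers:

```bash
# Install torch 1.12.1 built against CUDA 11.6 (adjust cu116 to your setup).
pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
```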

## Manual installation

If you would like to install from source, you can build a `wheel` package using
`poetry`. This assumes that `poetry` is already installed; you can find instructions
for installing `poetry` [here](#poetry). To build the `wheel` package run:

```bash
git clone https://github.com/softwaremill/AutoXAI.git
cd AutoXAI/
poetry install
poetry build
```

As a result you will get a `wheel` file inside the `dist/` directory that you can
install via `pip`:
```bash
pip install dist/autoxai-0.3.1-py3-none-any.whl
```

# Getting started

To use the AutoXAI library in your ML project, simply add an additional object of type
`WandBCallback` to the callback list of the `pytorch-lightning` `Trainer`.
Currently, only the Weights and Biases experiment tracking tool is supported.

Below is a code snippet from the example (`example/mnist_wandb.py`):

```python
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

import wandb
from autoxai.callbacks.wandb_callback import WandBCallback
from autoxai.explainer.gradient_shap import GradientSHAPCVExplainer
from autoxai.explainer.integrated_gradients import IntegratedGradientsCVExplainer

...
wandb.login()
wandb_logger = WandbLogger(project=project_name, log_model="all")
callback = WandBCallback(
    wandb_logger=wandb_logger,
    explainers=[
        IntegratedGradientsCVExplainer(),
        GradientSHAPCVExplainer(),
    ],
    idx_to_label={index: index for index in range(0, 10)},
)
model = LitMNIST()
trainer = Trainer(
    accelerator="gpu",
    devices=1 if torch.cuda.is_available() else None,
    max_epochs=max_epochs,
    logger=wandb_logger,
    callbacks=[callback],
)
trainer.fit(model)
```
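
The snippet above references `LitMNIST`, `project_name`, and `max_epochs` without defining them. For orientation, here is a minimal sketch of what `LitMNIST` might look like; the actual definition lives in `example/mnist_wandb.py`, and the architecture, hyperparameters, and dataloaders below are illustrative assumptions:

```python
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class LitMNIST(LightningModule):
    """Hypothetical MNIST classifier, used only to make the snippet self-contained."""

    def __init__(self, hidden_size: int = 64, learning_rate: float = 2e-4):
        super().__init__()
        self.learning_rate = learning_rate
        self.model = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(28 * 28, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", F.cross_entropy(self(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self) -> DataLoader:
        dataset = MNIST(".", train=True, download=True, transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=32)

    def val_dataloader(self) -> DataLoader:
        # WandBCallback reads validation samples, so a val dataloader is required.
        dataset = MNIST(".", train=False, download=True, transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=32)
```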

## CLI

A CLI tool is available to update the artifacts of an experiment tracked in
Weights and Biases. It allows you to create XAI explanations offline and send
them to W&B. To check the available options, run:

```bash
autoxai-wandb-updater --help
```

# Development

## Requirements

The project was tested using Python version `3.8`.

@@ -89,14 +197,14 @@ cloning the repository:
poetry run pre-commit install
```

## Note
---
At the moment only explainable algorithms for image classification are
implemented to test the design of the architecture. In the future more
algorithms and more computer vision tasks will be introduced. In the end the
module should work with all types of tasks (NLP, etc.).

### Artifacts directory structure

The module is designed to operate in two modes: offline and online. In offline
mode the user can explain an already trained model against test data. In online
@@ -120,7 +228,7 @@ cache_directory/
... ...
```

### Examples

In the `example/streamlit_app/` directory you can find a sample application with a
simple GUI that presents interactive explanations of given models.
4 changes: 2 additions & 2 deletions autoxai/callback.py
@@ -63,10 +63,10 @@ def on_sanity_check_end(

         logger.info("Saving all validation samples to data directory.")
         index: int = 0
+        item: torch.Tensor
         for dataloader in trainer.val_dataloaders:
             for batch in dataloader:
-                items, predictions = batch
-                for item, _ in zip(items, predictions):
+                for item, _ in zip(*batch):
                     self.cache_manager.save_artifact(
                         path=os.path.join(self.experiment.path_to_data, str(index)),
                         obj=item,
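
The `zip(*batch)` rewrite iterates sample/label pairs without first unpacking the batch into named tensors; a toy sketch of the equivalence (tensor shapes are illustrative):

```python
import torch

batch = (torch.zeros(4, 1, 28, 28), torch.arange(4))  # (items, labels)

# Before: items, labels = batch; then zip(items, labels)
# After:  zip(*batch) unpacks the tuple directly into zip.
for item, label in zip(*batch):
    assert item.shape == (1, 28, 28)  # one sample per iteration
```
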
Empty file added autoxai/callbacks/__init__.py
231 changes: 231 additions & 0 deletions autoxai/callbacks/wandb_callback.py
@@ -0,0 +1,231 @@
"""Callback for Weights and Biases."""
from collections import defaultdict
from typing import Dict, Generator, List, Optional, Tuple

import matplotlib
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

import wandb
from autoxai.context_manager import AutoXaiExplainer, ExplainerWithParams
from autoxai.explainer.base_explainer import CVExplainer

AttributeMapType = Dict[str, List[torch.Tensor]]
CaptionMapType = Dict[str, List[str]]
FigureMapType = Dict[str, List[matplotlib.pyplot.Figure]]


class WandBCallback(pl.callbacks.Callback):
adamwawrzynski marked this conversation as resolved.
Show resolved Hide resolved
"""Library callback for Weights and Biases."""

def __init__( # pylint: disable = (too-many-arguments)
self,
wandb_logger: WandbLogger,
explainers: List[ExplainerWithParams],
idx_to_label: Dict[int, str],
max_artifacts: int = 3,
):
"""Initialize Callback class.

Args:
wandb_logger: Pytorch-lightning wandb logger.
idx_to_label: Index to label mapping.
explainers: List of explainer algorithms of type ExplainerWithParams.
idx_to_label: Dictionary with mapping from model index to label.
max_artifacts: Number of maximum number of artifacts to be logged.
Defaults to 3.
"""
super().__init__()
self.explainers = explainers
self.wandb_logger = wandb_logger
self.idx_to_label = idx_to_label
self.max_artifacts = max_artifacts

def _save_idx_mapping(self) -> None:
"""Saving index to label mapping to experiment logs directory."""
self.wandb_logger.log_table(
key="idx2label",
columns=["index", "label"],
data=[[key, val] for key, val in self.idx_to_label.items()],
)

def iterate_dataloader(
self, dataloader_list: List[DataLoader], max_items: int
) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
"""Iterate over dataloader list with constraint on max items returned.

Args:
dataloader: Trainer dataloader.
max_items: Max items to return.

Yields:
Tuple containing training sample and corresponding label.
"""
index: int = 0
dataloader: DataLoader
item: torch.Tensor
target_label: torch.Tensor
for dataloader in dataloader_list:
for batch in dataloader:
for item, target_label in zip(*batch):
if index >= max_items:
break

index += 1
yield item, target_label

def explain( # pylint: disable = (too-many-arguments)
self,
model: pl.LightningModule,
item: torch.Tensor,
target_label: torch.Tensor,
attributes_dict: AttributeMapType,
caption_dict: CaptionMapType,
figures_dict: AttributeMapType,
) -> Tuple[AttributeMapType, CaptionMapType, AttributeMapType,]:
"""Calculate explainer attributes, creates captions and figures.

Args:
model: Model to explain.
item: Input data sample tensor.
target_label: Sample label.
attributes_dict: List of attributes for every explainer and sample.
caption_dict: List of captions for every explainer and sample.
figures_dict: List of figures for every explainer and sample.

Returns:
Tuple of maps containing attributes, captions and figures for
every explainer and sample.
"""
with AutoXaiExplainer(
model=model,
explainers=self.explainers,
target=int(target_label.item()),
) as xai_model:
_, attributes = xai_model(item.to(model.device))

for explainer in self.explainers:
explainer_name: str = explainer.explainer_name.name
explainer_attributes: torch.Tensor = attributes[explainer_name]
attributes_dict[explainer_name].append(explainer_attributes)
caption_dict[explainer_name].append(f"label: {target_label}")
figure = CVExplainer.visualize(
attributions=explainer_attributes,
transformed_img=item,
)
figures_dict[explainer_name].append(figure)

return attributes_dict, caption_dict, figures_dict

def on_train_start(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule, # pylint: disable = (unused-argument)
) -> None:
"""Save index to labels mapping and validation samples to experiment
at `fit`.

Args:
trainer: Trainer object.
pl_module: Model to explain.
"""
if trainer.val_dataloaders is None:
return

self._save_idx_mapping()

image_matrix: Optional[torch.Tensor] = None
image_labels: List[str] = []

for item, target_label in self.iterate_dataloader(
dataloader_list=trainer.val_dataloaders, max_items=self.max_artifacts
adamwawrzynski marked this conversation as resolved.
Show resolved Hide resolved
):
if image_matrix is None:
image_matrix = item
else:
image_matrix = torch.cat( # pylint: disable = (no-member)
[image_matrix, item]
)

image_labels.append(f"label: {target_label.item()}")

if image_matrix is None:
return

list_of_images: List[torch.Tensor] = list(torch.split(image_matrix, 1))
self.wandb_logger.log_image(
adamwawrzynski marked this conversation as resolved.
Show resolved Hide resolved
key="validation_data",
images=list_of_images[: min(len(list_of_images), self.max_artifacts)],
caption=image_labels[: min(len(image_labels), self.max_artifacts)],
)

def on_validation_epoch_end( # pylint: disable = (too-many-arguments, too-many-locals)
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
) -> None:
"""Export model's state dict in log directory on validation epoch end.

Args:
trainer: Trainer object.
pl_module: Model to explain.
"""
if trainer.val_dataloaders is None:
return

attributes_dict: AttributeMapType = defaultdict(list)
caption_dict: CaptionMapType = defaultdict(list)
figures_dict: AttributeMapType = defaultdict(list)

for item, target_label in self.iterate_dataloader(
dataloader_list=trainer.val_dataloaders,
max_items=self.max_artifacts,
):
attributes_dict, caption_dict, figures_dict = self.explain(
model=pl_module,
item=item,
target_label=target_label,
attributes_dict=attributes_dict,
caption_dict=caption_dict,
figures_dict=figures_dict,
)

self.log_explanations(
attributes_dict=attributes_dict,
caption_dict=caption_dict,
figures_dict=figures_dict,
)

def log_explanations(
self,
attributes_dict: AttributeMapType,
caption_dict: CaptionMapType,
figures_dict: AttributeMapType,
) -> None:
"""Log explanation artifacts to W&B experiment.

Args:
attributes_dict: Tensor attributes for every sample and every explainer.
caption_dict: Caption for every sample and every explainer.
figures_dict: Figure with attributes for every sample and every explainer.
"""
# upload artifacts to the wandb experiment
for explainer in self.explainers:
explainer_name: str = explainer.explainer_name.name
self.wandb_logger.log_image(
key=f"{explainer_name}",
images=[val.numpy() for val in attributes_dict[explainer_name]],
caption=caption_dict[explainer_name],
)

# matplotlib Figures can not be directly logged via WandbLogger
# we have to use native Run object from wandb which is more powerfull
wandb_image_list: List[wandb.Image] = []
for figure in figures_dict[explainer_name]:
wandb_image_list.append(wandb.Image(figure))

self.wandb_logger.experiment.log(
{f"{explainer_name}_explanations": wandb_image_list}
)
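
To make the map aliases at the top of the file concrete, a toy illustration of the shapes these dictionaries take; the explainer key name is illustrative:

```python
import torch

# AttributeMapType: explainer name -> one attribution tensor per sample
attributes_dict = {
    "INTEGRATED_GRADIENTS": [torch.zeros(1, 28, 28), torch.zeros(1, 28, 28)],
}
# CaptionMapType: explainer name -> one caption per sample
caption_dict = {"INTEGRATED_GRADIENTS": ["label: 3", "label: 7"]}
```
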
Empty file added autoxai/cli/__init__.py