##### Copyright 2020 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Creating a custom Counterfactual Dataset

<div class="devsite-table-wrapper"><table class="tfo-notebook-buttons" align="left">
  <td><a target="_blank" href="https://www.tensorflow.org/responsible_ai/model_remediation/counterfactual/guide/creating_a_custom_counterfactual_dataset">
  <img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
</td>
<td>
  <a target="_blank" href="https://colab.research.google.com/github/tensorflow/model-remediation/blob/master/docs/counterfactual/guide/creating_a_custom_counterfactual_dataset.ipynb">
  <img src="https://www.tensorflow.org/images/colab_logo_32px.png">Run in Google Colab</a>
</td>
<td>
  <a target="_blank" href="https://github.com/tensorflow/model-remediation/blob/master/docs/counterfactual/guide/creating_a_custom_counterfactual_dataset.ipynb">
  <img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">View source on GitHub</a>
</td>
<td>
  <a target="_blank" href="https://storage.googleapis.com/tensorflow_docs/model-remediation/docs/counterfactual/guide/creating_a_custom_counterfactual_dataset.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
</td>
</table></div>

## Introduction

To create a Counterfactual model, you need to provide an instance of `CounterfactualPackedInputs` that contains the `original_input` and `counterfactual_data`. `CounterfactualPackedInputs` looks like the following:
```
CounterfactualPackedInputs(
  original_input=(x, y, sample_weight),
  counterfactual_data=(original_x, counterfactual_x,
                       counterfactual_sample_weight)
)
```
The `original_input` should be the original dataset that is used to train your Keras model. `counterfactual_data` should be a TensorFlow Dataset with the original `x` value, the corresponding `counterfactual_x` value, and the `counterfactual_sample_weight`. The `counterfactual_x` value is nearly identical to the original value but with one or more of the sensitive attributes removed or replaced. This dataset is used to compute a loss that pairs the model’s output on the original value with its output on the counterfactual value, with the goal of ensuring that the model’s prediction doesn’t change when the sensitive attribute is different. `original_input` and `counterfactual_data` need to be the same shape. You can duplicate values of `counterfactual_data` so that it’s the same shape as `original_input`.
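If `counterfactual_data` has fewer examples than `original_input`, one way to duplicate values is to cycle through the counterfactual examples until the counts match. The sketch below is only an illustration: it uses plain Python lists in place of TensorFlow Datasets, and `repeat_to_length` is a hypothetical helper, not part of TensorFlow Model Remediation.

```python
from itertools import cycle, islice


def repeat_to_length(counterfactual_examples, target_length):
    """Cycle through counterfactual examples until there are target_length of them.

    Hypothetical helper for illustration; lists stand in for tf.data Datasets.
    """
    if not counterfactual_examples:
        raise ValueError("counterfactual_examples must not be empty")
    return list(islice(cycle(counterfactual_examples), target_length))


# Five original examples, but only two counterfactual pairs.
original_input = [f"sentence {i}" for i in range(5)]
counterfactual_examples = [
    ("He is in the UK", "She is in the UK", 1.0),
    ("He went to school", "She went to school", 1.0),
]

padded = repeat_to_length(counterfactual_examples, len(original_input))
print(len(padded))    # 5
print(padded[2][1])   # She is in the UK
```

With `tf.data`, calling `Dataset.repeat()` followed by `Dataset.take(n)` has a similar effect.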

`counterfactual_data` needs to:
*   Have a `counterfactual_x` value that is nearly identical to `original_x`, with one or more of the sensitive attributes removed or replaced
*   Be the same shape as `original_input` (you can duplicate values so that they’re the same shape)

`counterfactual_data` does not need to:
*   Have overlap with data within `original_input` 
*   Have ground truth labels 

Here’s an example of what `counterfactual_data` might look like, where “he” is replaced with “she”:

```
original_x: "He is in the UK"
counterfactual_x: "She is in the UK"
counterfactual_sample_weight: 1

original_x: "He went to school"
counterfactual_x: "She went to school"
counterfactual_sample_weight: 1
```
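Records like these can be produced by swapping terms word by word. The following sketch uses Python’s `re` module rather than TensorFlow, and `make_counterfactual_records` and the `swaps` dictionary are hypothetical names chosen for illustration. As in the example above, every pair is given a weight of 1.

```python
import re

# Hypothetical mapping from sensitive terms to their replacements.
swaps = {"He": "She", "he": "she"}

# \b anchors each term so that only whole words are matched.
pattern = re.compile(r"\b(" + "|".join(map(re.escape, swaps)) + r")\b")


def make_counterfactual_records(sentences, weight=1):
    """Return (original_x, counterfactual_x, counterfactual_sample_weight) tuples."""
    records = []
    for original_x in sentences:
        counterfactual_x = pattern.sub(lambda m: swaps[m.group(0)], original_x)
        records.append((original_x, counterfactual_x, weight))
    return records


for record in make_counterfactual_records(["He is in the UK", "He went to school"]):
    print(record)
# ('He is in the UK', 'She is in the UK', 1)
# ('He went to school', 'She went to school', 1)
```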
If you have a text classifier, you can use `build_counterfactual_data` to help create a counterfactual dataset. For all other data types, you need to provide a counterfactual dataset directly. 


## Setup

You'll begin by installing TensorFlow Model Remediation.


In [None]:
!pip install --upgrade tensorflow-model-remediation

In [None]:
#@title Imports
import numpy as np
import tensorflow as tf
from tensorflow_model_remediation import counterfactual

## Create a simple Dataset

For demonstration purposes, you’ll create `counterfactual_data` from the `original_input` using `build_counterfactual_data`. Note that you can also construct `counterfactual_data` from unlabeled data instead of from `original_input`. You will create a simple dataset of ten copies of one sentence, “He is a doctor and she is a nurse” (each with an index appended), which will serve as the `original_input`.

If a model is only given this context, it will likely relate males to being a doctor and females to being a nurse. You could remove the gender terms to avoid that association by creating an instance of `counterfactual_data` with the sentence, “is a doctor and is a nurse”.

Note: The dataset created in this tutorial contains only a limited set of terms for describing gender. Further, this tutorial only demonstrates the steps for creating an instance of `counterfactual_data` and does not represent a real-world use case.

## Build `counterfactual_data`
`build_counterfactual_data` can be used in the following ways:


1.   **Remove terms**: Pass `build_counterfactual_data` a list of words that will be removed from the dataset via `tf.strings.regex_replace`.
2.   **Replace terms**: Create a custom function to pass to `build_counterfactual_data`. This might include using more specific regex functions to replace words within your original dataset or supporting non-text features.

`build_counterfactual_data` takes in `original_input` and either removes or replaces terms with regex (as described above). It filters `original_input` so that it only contains text with the identity terms that are being removed or replaced. Then, it reshapes `counterfactual_data` to the same size as `original_input` by duplicating values.
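The three steps described above (filter, remove or replace, duplicate) can be sketched in plain Python. This illustrates the described behavior under simplifying assumptions (lists instead of Datasets, whole-word matching, whitespace cleanup after removal); it is not the library’s implementation, and `sketch_build_counterfactual` is a hypothetical name.

```python
import re
from itertools import cycle, islice


def sketch_build_counterfactual(original_input, terms_to_remove):
    """Illustrative only: filter, remove terms, then duplicate to match size."""
    pattern = re.compile(r"\b(" + "|".join(map(re.escape, terms_to_remove)) + r")\b")
    # 1. Keep only examples that mention at least one sensitive term.
    matching = [x for x in original_input if pattern.search(x)]
    if not matching:
        return []
    # 2. Remove the terms and collapse the leftover whitespace.
    pairs = [(x, " ".join(pattern.sub("", x).split()), 1.0) for x in matching]
    # 3. Duplicate pairs so there is one per original example.
    return list(islice(cycle(pairs), len(original_input)))


original_input = ["He is a doctor and she is a nurse", "The sky is blue"]
result = sketch_build_counterfactual(original_input, ["He", "she"])
print(result[0][1])  # is a doctor and is a nurse
```

Only the first sentence survives the filter, so its pair is repeated to cover both original examples.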

### Option 1: List of Words to Remove
Pass a list of gender-related terms to remove to `build_counterfactual_data` through its `sensitive_terms_to_remove` argument.

When using simple regex to create `counterfactual_data`, keep in mind that it may alter words that shouldn’t be changed. A simple regex removal of a word such as “he” could change a number of other words in the dataset. For example, “theory” will become “tory” and “when” will become “wn”. It is good practice to check that the changes made to the `counterfactual_x` value make sense in the context of the `original_x` value.
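The pitfall above can be previewed with Python’s `re` module before building the dataset. The sketch compares a bare pattern with one anchored by `\b` (word boundary), which the RE2 syntax used by `tf.strings.regex_replace` also supports. `unexpected_changes` is a hypothetical, rough check: it compares word sets, so it ignores word order and repeated words.

```python
import re

sentence = "When he studied the theory"

# Naive removal: every occurrence of the substring "he" is deleted.
naive = re.sub("he", "", sentence)

# Word-boundary removal: only the standalone word "he" is deleted.
bounded = re.sub(r"\bhe\b", "", sentence)

print(repr(naive))    # 'Wn  studied t tory'
print(repr(bounded))  # 'When  studied the theory'


def unexpected_changes(original_x, counterfactual_x, sensitive_terms):
    """Return original words that were altered but are not sensitive terms."""
    kept = set(counterfactual_x.split())
    return [w for w in original_x.split() if w not in kept and w not in sensitive_terms]


print(unexpected_changes(sentence, naive, {"he"}))    # ['When', 'the', 'theory']
print(unexpected_changes(sentence, bounded, {"he"}))  # []
```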

In [None]:
#@title Create Counterfactual Inputs

simple_dataset_x = tf.constant(
    ["He is a doctor and she is a nurse" + str(i) for i in range(10)])
simple_dataset = tf.data.Dataset.from_tensor_slices(
    (simple_dataset_x, None, None))

counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
    original_input=simple_dataset,
    sensitive_terms_to_remove=['she', 'He'])

# Inspect the content of the TF Counterfactual Dataset
np.stack(list(counterfactual_data))

### Option 2: Custom Function  

For more flexibility around ways of modifying your original dataset, you can instead pass a custom function to `build_counterfactual_data`. 

In this example, you replace identity terms that reference the male gender with terms that reference the female gender. You can do this by writing a function that applies a dictionary of word replacements.
 
Note that the only requirement for the custom function is that it must be a callable that accepts and returns a tuple in the format used by [`Model.fit`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit).
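Before wiring a replacement function into TensorFlow, it can help to check its behavior on plain strings. The sketch below mirrors the contract described above, taking and returning an `(x, y, sample_weight)` tuple, but it uses Python’s `re` module instead of `tf.strings.regex_replace`. `replace_words_preview` and `preview_replacements` are hypothetical names. Each key is anchored with `\b`, so that, unlike a bare `"He"` pattern, words such as “Hello” are left alone.

```python
import re

# Hypothetical replacement dictionary for the preview.
preview_replacements = {"He": "She", "he": "she"}


def replace_words_preview(batch):
    """Apply whole-word replacements to x; pass y and sample_weight through."""
    x, y, sample_weight = batch
    for key, replacement in preview_replacements.items():
        x = [re.sub(r"\b" + re.escape(key) + r"\b", replacement, s) for s in x]
    return (x, y, sample_weight)


batch = (["He is a doctor", "Hello there"], None, None)
print(replace_words_preview(batch))
# (['She is a doctor', 'Hello there'], None, None)
```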


In [None]:
# Map each term to its replacement.
words_to_replace = {"He": "She"}

def replace_words(original_batch):
  original_x, _, original_sample_weight = (
      tf.keras.utils.unpack_x_y_sample_weight(original_batch))
  # Apply each regex replacement to the text features.
  for term, replacement in words_to_replace.items():
    original_x = tf.strings.regex_replace(original_x, term, replacement)
  return tf.keras.utils.pack_x_y_sample_weight(
      original_x, sample_weight=original_sample_weight)

counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
    original_input=simple_dataset,
    custom_counterfactual_function=replace_words)

# Inspect the content of the TF Counterfactual Dataset
np.stack(list(counterfactual_data))