##### Copyright 2020 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Creating a custom Counterfactual Dataset

<div class="devsite-table-wrapper"><table class="tfo-notebook-buttons" align="left">
  <td><a target="_blank" href="https://www.tensorflow.org/responsible_ai/model_remediation/counterfactual/guide/creating_a_custom_counterfactual_dataset">
  <img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
</td>
<td>
  <a target="_blank" href="https://colab.research.google.com/github/tensorflow/model-remediation/blob/master/docs/counterfactual/guide/creating_a_custom_counterfactual_dataset.ipynb">
  <img src="https://www.tensorflow.org/images/colab_logo_32px.png">Run in Google Colab</a>
</td>
<td>
  <a target="_blank" href="https://github.com/tensorflow/model-remediation/blob/master/docs/counterfactual/guide/creating_a_custom_counterfactual_dataset.ipynb">
  <img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">View source on GitHub</a>
</td>
<td>
  <a target="_blank" href="https://storage.googleapis.com/tensorflow_docs/model-remediation/docs/counterfactual/guide/creating_a_custom_counterfactual_dataset.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
</td>
</table></div>

## Introduction
To use Counterfactuals, you need a counterfactual dataset, so that the loss function can pair each original value with its counterfactual value.

Unlike your training dataset, you don’t need to include labels in your counterfactual dataset. However, the data must be the same as the training dataset, with sensitive terms replaced or removed (depending on the type of correlation you are aiming to test). You also need to ensure that there are a sufficient number of sensitive attributes to get a meaningful result. 

You can create a counterfactual dataset in one of three ways:

1.   Use `build_counterfactual_dataset` and pass a list of words that will be removed from the dataset via `tf.strings.regex_replace`.
2.   Pass a custom function to `build_counterfactual_dataset` to control how the counterfactual is produced. This might include using more specific regex functions or replacing words within your original dataset.
3.   Provide your own counterfactual dataset that follows the same tuple structure as `CounterfactualPackedInputs`.

When using simple regex to create the counterfactual dataset, keep in mind that it may alter words that shouldn’t be changed, which could impact performance. It is important to understand both your original dataset and the counterfactual dataset so that you produce an appropriate pairing for your model.
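
One way to get a feel for the original dataset before building counterfactuals is to count how many examples contain a sensitive term as a whole word. The following is a minimal sketch in plain Python using the standard `re` module. The sample sentences, the term list, and the variable names are hypothetical and are not part of TensorFlow Model Remediation.

```python
import re

# Hypothetical sample data and term list, for illustration only.
examples = [
    "He is a doctor and she is a nurse",
    "The theater opens at noon",
    "She chaired the meeting",
]
sensitive_terms = ["he", "she"]

# \b restricts matches to whole words, so "The" and "theater" do not count.
# IGNORECASE also catches capitalized forms such as "He" and "She".
pattern = re.compile(
    r"\b(?:" + "|".join(map(re.escape, sensitive_terms)) + r")\b",
    flags=re.IGNORECASE,
)

# Keep only the examples that mention at least one sensitive term.
matching = [text for text in examples if pattern.search(text)]
coverage = len(matching) / len(examples)
print(f"{len(matching)} of {len(examples)} examples contain a sensitive term")
```

If only a small share of the examples match, the counterfactual dataset may be too small to give a meaningful result.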

**Note**: The dataset created in this tutorial contains only a limited set of terms for describing gender. Further, this tutorial only demonstrates the steps for creating a counterfactual dataset and does not represent a real-world use case.


## Setup

We begin by installing TensorFlow Model Remediation.


In [None]:
!pip install --upgrade tensorflow-model-remediation

In [None]:
#@title Imports
import numpy as np
import tensorflow as tf
from tensorflow_model_remediation import counterfactual

## Create a simple TF Dataset

For the purposes of this Colab, we will create a simple dataset with one sentence: “He is a doctor and she is a nurse”.

If a model is only given this context, it will likely relate males to being a doctor and females to being a nurse. You could remove the gender terms to avoid that association by creating a counterfactual dataset containing the phrase, “is a doctor and is a nurse”. 
 
### Option 1: List of Words to Remove

To use `build_counterfactual_dataset` this way, first come up with a list of gender-specific terms for it to remove.

Note that a term is removed wherever it matches, including inside longer words, so it can remove more content than intended. For example, passing only the word “he” would change “the” into “t” and “there” into “tre”. Including the surrounding spaces in the terms you pass helps avoid this.
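
The substring pitfall described above can be reproduced with Python's standard `re` module. This is only an illustration of the matching behavior, not the library's implementation. The sample sentence and variable names are hypothetical.

```python
import re

# Hypothetical sentence, for illustration only.
sentence = "He is there with the doctor"

# Naive removal: "he" is matched anywhere, including inside other words.
naive = re.sub("he", "", sentence)
print(naive)       # He is tre with t doctor

# Whole-word removal: \b anchors the match at word boundaries.
# The optional trailing space is consumed so no double space is left behind.
whole_word = re.sub(r"\b(?:he|He)\b ?", "", sentence)
print(whole_word)  # is there with the doctor
```

The naive pattern damages “there” and “the” while leaving the capitalized “He” in place. The word-boundary pattern removes only the standalone term.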


In [None]:
#@title Create Counterfactual Inputs

simple_dataset_x = tf.constant(
    ["He is a doctor and she is a nurse" + str(i) for i in range(10)])
simple_dataset = tf.data.Dataset.from_tensor_slices(
            (simple_dataset_x, None, None))

counterfactual_data = counterfactual.keras.utils.build_counterfactual_dataset(
    original_dataset=simple_dataset,
    sensitive_terms_to_remove=['she', 'He'])

# Inspect the content of the TF Counterfactual Dataset
np.stack(list(counterfactual_data))

### Option 2: Custom Function  

For more flexibility around ways of modifying your original dataset, you can instead pass a custom function to `build_counterfactual_dataset`. 

In this example, consider replacing identity terms that reference the male gender with terms that reference the female gender. You can do this by writing a function that applies a dictionary of replacements.
 
Note that the only limitation on the custom function is that it must be a callable that accepts and returns a tuple in the format used in `Model.fit`.
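
When a replacement dictionary maps terms in both directions, applying the replacements one after another can undo earlier ones (for example, “He” becomes “She” and is then turned back into “He”). The sketch below shows one way to avoid that in plain Python by replacing every term in a single pass. The mapping and the `swap_words` helper are hypothetical and operate on ordinary Python strings, not on tensors.

```python
import re

# Hypothetical two-way mapping, for illustration only.
words_to_swap = {"He": "She", "he": "she", "She": "He", "she": "he"}

# One pattern that matches any key as a whole word.
pattern = re.compile(
    r"\b(?:" + "|".join(map(re.escape, words_to_swap)) + r")\b")

def swap_words(text):
  # Each match is looked up and replaced exactly once, so a replaced
  # word is never matched again by a later rule.
  return pattern.sub(lambda match: words_to_swap[match.group(0)], text)

swapped = swap_words("He is a doctor and she is a nurse")
print(swapped)  # She is a doctor and he is a nurse
```

Whether a single-pass swap is needed depends on the mapping. For a one-way dictionary, sequential replacement gives the same result.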


In [None]:
words_to_replace = {"He": "She"}

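def replace_words(original_batch):
  # Unpack the batch; labels are not needed in a counterfactual dataset.
  original_x, _, original_sample_weight = (
      tf.keras.utils.unpack_x_y_sample_weight(original_batch))
  # Apply each replacement in turn to the input strings.
  for word, replacement in words_to_replace.items():
    original_x = tf.strings.regex_replace(original_x, word, replacement)
  # Repack without labels, keeping any sample weights.
  return tf.keras.utils.pack_x_y_sample_weight(
      original_x, sample_weight=original_sample_weight)
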
counterfactual_data = counterfactual.keras.utils.build_counterfactual_dataset(
    original_dataset=simple_dataset,
    custom_counterfactual_function=replace_words)

# Inspect the content of the TF Counterfactual Dataset
np.stack(list(counterfactual_data))

### Option 3: Provide a Dataset 

Build a `CounterfactualPackedInputs` dataset directly to pass to the Counterfactual model.