##### Copyright 2020 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Model Remediation Case Study

<div class="devsite-table-wrapper"><table class="tfo-notebook-buttons" align="left">
  <td><a target="_blank" href="https://www.tensorflow.org/responsible_ai/model_remediation/counterfactual/guide/counterfactual_keras">
  <img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
</td>
<td>
  <a target="_blank" href="https://colab.research.google.com/github/tensorflow/model-remediation/blob/master/docs/counterfactual/guide/counterfactual_keras.ipynb">
  <img src="https://www.tensorflow.org/images/colab_logo_32px.png">Run in Google Colab</a>
</td>
<td>
  <a target="_blank" href="https://github.com/tensorflow/model-remediation/blob/master/docs/counterfactual/guide/counterfactual_keras.ipynb">
  <img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">View source on GitHub</a>
</td>
<td>
  <a target="_blank" href="https://storage.googleapis.com/tensorflow_docs/model-remediation/docs/counterfactual/guide/counterfactual_keras.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
</td>
  <td>
    <a href="https://tfhub.dev/"><img src="https://www.tensorflow.org/images/hub_logo_32px.png" />See TF Hub model</a>
  </td>
</table></div>

In this notebook, you’ll train a text classifier to identify written content that could be considered toxic or harmful. Counterfactual can be used to improve classifier robustness, by identifying and mitigating correlations between identity terms and the toxicity score. For example, this type of correlation was seen in the Perspective API, which uses machine learning to identify toxic comments. Perspective API takes comment text as input and returns a "score" from 0 to 1 that indicates the probability that the comment is similar to toxic comments it's seen in the past. A score of 0 signifies 0% probability that the comment is toxic, a score of 1 indicates 100% probability that the comment is toxic, and a score of 0.5 denotes a 50% probability that the comment is toxic (i.e., that the model is not sure).

After the initial launch of Perspective API, external users discovered a positive correlation between identity terms containing information on race or sexual orientation and toxicity score. For example, the phrase "I am a gay black woman" received a toxicity score of 0.87. In this case, the identity terms were not being used pejoratively, so this example was classified incorrectly. Counterfactual can be used to remediate this incorrect correlation and improve classifier precision in similar cases.

Specifically, you will:

1.   Build a baseline model and measure its performance on text containing references to gender groups.
2.   Build a counterfactual dataset and evaluate the model’s performance on flip rate and flip count to determine if Counterfactual should be applied. 
3.   Train with the Counterfactual technique to avoid unintended correlation between model output and sensitive identity terms.
4.   Evaluate the new model’s performance on the flip rate and flip count.

This tutorial demonstrates the Counterfactual technique with a minimal workflow; it does not lay out a principled approach to fairness in machine learning. The Counterfactual technique is one tool in the broader [Responsible AI Toolkit](https://www.tensorflow.org/responsible_ai). You should still evaluate your model's overall error rates, as demonstrated in the MinDiff tutorial. This tutorial also does not address potential shortcomings in the dataset or tune the configuration. In a production setting, you would want to approach each of these fairness concerns with rigor. For more information on evaluating for fairness, see the guide for [Fairness Indicators](https://www.tensorflow.org/responsible_ai/fairness_indicators/guide).


## Setup

You begin by installing Fairness Indicators and TensorFlow Model Remediation.


In [None]:
#@title Installs
!pip install --upgrade tensorflow-model-remediation
!pip install --upgrade fairness-indicators

Import all necessary components, including Counterfactual and Fairness Indicators for evaluation.

In [None]:
#@title Imports
import os
import requests
import tempfile
import zipfile
 
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_model_analysis as tfma
from google.protobuf import text_format

 
# Import Counterfactuals.
from tensorflow_model_remediation import counterfactual

You use a [utility function](#Utility Functions) called `download_and_process_civil_comments_data` to download the preprocessed data and prepare the labels to match the model’s output shape. The function also downloads the data as TFRecords to make later evaluation quicker. Alternatively, you can convert the Pandas DataFrame into TFRecords with any available utility conversion function.

First, define a few useful constants. You will train the model on the `comment_text` feature, with `toxicity` as the target label. Note that the batch size here is chosen arbitrarily; in a production setting you would need to tune it for best performance.

In [None]:
TEXT_FEATURE = 'comment_text'
LABEL = 'toxicity'
BATCH_SIZE = 512

In [None]:
#@title Utility Functions
np.random.seed(1)
tf.random.set_seed(1)

def download_and_process_civil_comments_data():
  """Download and process the civil comments dataset into a Pandas DataFrame."""

  # Download data.
  toxicity_data_url = 'https://storage.googleapis.com/civil_comments_dataset/'
  train_csv_file = tf.keras.utils.get_file(
      'train_df_processed.csv', toxicity_data_url + 'train_df_processed.csv')
  validate_csv_file = tf.keras.utils.get_file(
      'validate_df_processed.csv',
      toxicity_data_url + 'validate_df_processed.csv')

  # Get validation data as TFRecords.
  validate_tfrecord_file = tf.keras.utils.get_file(
      'validate_tf_processed.tfrecord',
      toxicity_data_url + 'validate_tf_processed.tfrecord')

  # Read data into Pandas DataFrame.
  data_train = pd.read_csv(train_csv_file)
  data_validate = pd.read_csv(validate_csv_file)

  # Fix type interpretation.
  data_train[TEXT_FEATURE] = data_train[TEXT_FEATURE].astype(str)
  data_validate[TEXT_FEATURE] = data_validate[TEXT_FEATURE].astype(str)

  # Shape labels to match output.
  labels_train = data_train[LABEL].values.reshape(-1, 1) * 1.0
  labels_validate = data_validate[LABEL].values.reshape(-1, 1) * 1.0

  return data_train, data_validate, validate_tfrecord_file, labels_train, labels_validate

data_train, data_validate, validate_tfrecord_file, labels_train, labels_validate = download_and_process_civil_comments_data()

def _create_embedding_layer(hub_url):
  return hub.KerasLayer(
      hub_url, output_shape=[128], input_shape=[], dtype=tf.string)
  
def create_keras_sequential_model(
    hub_url='https://tfhub.dev/google/tf2-preview/nnlm-en-dim128/1',
    cnn_filter_sizes=[128, 128, 128],
    cnn_kernel_sizes=[5, 5, 5],
    cnn_pooling_sizes=[5, 5, 40]):
  """Create baseline keras sequential model."""

  model = tf.keras.Sequential()

  # Embedding layer.
  hub_layer = _create_embedding_layer(hub_url)
  model.add(hub_layer)
  # Add a length-1 sequence axis: Conv1D expects (batch, steps, channels).
  model.add(tf.keras.layers.Reshape((1, 128)))

  # Convolution layers.
  for filter_size, kernel_size, pool_size in zip(cnn_filter_sizes,
                                                 cnn_kernel_sizes,
                                                 cnn_pooling_sizes):
    model.add(
        tf.keras.layers.Conv1D(
            filter_size, kernel_size, activation='relu', padding='same'))
    model.add(tf.keras.layers.MaxPooling1D(pool_size, padding='same'))

  # Flatten, fully connected, and output layers.
  model.add(tf.keras.layers.Flatten())
  model.add(tf.keras.layers.Dense(128, activation='relu'))
  model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

  return model

## Define and train the baseline model

Before applying [Counterfactual](https://arxiv.org/abs/1809.10610), you need a baseline model to compare against. Counterfactual is then used to mitigate any unintended correlation between the toxicity score and identity terms. For example, if the dataset includes a comment such as “she is toxic”, it could push the model to be slightly biased against women. Counterfactual adjusts the loss between an original and a counterfactual example, which compensates for the gender bias.

To save time, the cell below loads a pretrained baseline model by default. Clear `use_pretrained_model` to train the baseline from scratch instead.


In [None]:
use_pretrained_model = True #@param {type:"boolean"}

if use_pretrained_model:
  # Download and unpack a pretrained baseline model.
  URL = 'https://storage.googleapis.com/civil_comments_model/baseline_model.zip'
  ZIPPATH = 'baseline_model.zip'
  DIRPATH = '/tmp/baseline_model'
  r = requests.get(URL, allow_redirects=True)
  r.raise_for_status()
  with open(ZIPPATH, 'wb') as zip_file:
    zip_file.write(r.content)

  # The archive is extracted relative to the filesystem root, and the model
  # is then loaded from DIRPATH.
  with zipfile.ZipFile(ZIPPATH, 'r') as zip_ref:
    zip_ref.extractall('/')
  baseline_model = tf.keras.models.load_model(
      DIRPATH, custom_objects={'KerasLayer': hub.KerasLayer})

else:
  optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
  # The model's final layer applies a sigmoid, so the loss receives
  # probabilities rather than logits.
  loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)

  baseline_model = create_keras_sequential_model()
  baseline_model.compile(optimizer=optimizer, loss=loss,
                         metrics=['accuracy'])

  baseline_model.fit(x=data_train[TEXT_FEATURE],
                     y=labels_train, batch_size=BATCH_SIZE,
                     epochs=10)

To evaluate the original model's performance using Fairness Indicators you will need to save the model.

In [None]:
#@title Save Model
base_dir = tempfile.mkdtemp(prefix='saved_models')
baseline_model_location = os.path.join(base_dir, 'model_export_baseline')
baseline_model.save(baseline_model_location, save_format='tf')

## Determine if Counterfactual is needed

For the purpose of this example, you will check whether this model incorrectly correlates gender terms with the toxicity of a sentence. As in the Perspective API scenario, this may occur if the dataset includes a comment such as “she is toxic”. Counterfactual can be used to adjust the loss between an original and a counterfactual example, which compensates for the gender bias.

### Preparing the Counterfactual Dataset

To use Counterfactual, you first need to create a corresponding counterfactual dataset. This dataset should be a `tf.data.Dataset` that includes only the original x values that have a distinct counterfactual, with each element taking the form (`original_x`, `counterfactual_x`, `counterfactual_sample_weight`). Note that the number of datapoints in your dataset that require a counterfactual will typically be small. Values within the counterfactual dataset are expected to repeat so that it matches the shape of the original dataset.
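As a rough illustration of that structure, the sketch below uses plain Python lists rather than a `tf.data.Dataset`. The helper `build_counterfactual_pairs` and the sample sentences are hypothetical and are not part of the library; the helper keeps only examples whose text changes once a term is removed and emits (`original_x`, `counterfactual_x`, `counterfactual_sample_weight`) tuples.

```python
import re

def build_counterfactual_pairs(texts, terms, sample_weight=1.0):
    """Return (original_x, counterfactual_x, sample_weight) tuples.

    Hypothetical helper for illustration only. Examples that contain none
    of the terms are dropped, because their counterfactual would be
    identical to the original.
    """
    pattern = re.compile(r'\b(?:' + '|'.join(map(re.escape, terms)) + r')\b')
    pairs = []
    for text in texts:
        counterfactual_text = pattern.sub('', text)
        if counterfactual_text != text:
            pairs.append((text, counterfactual_text, sample_weight))
    return pairs

# Made-up sentences; only the first and third contain a listed term.
texts = ['she is a doctor', 'the weather is nice', 'my brother called']
pairs = build_counterfactual_pairs(texts, ['she', 'brother'])
for original_x, counterfactual_x, weight in pairs:
    print(repr(original_x), '->', repr(counterfactual_x), weight)
```

Only two of the three sentences produce a pair, which is why the counterfactual dataset is usually much smaller than the original one.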

To understand your options for producing a counterfactual dataset, see the Creating a Custom Counterfactual Dataset Colab.  

For this situation, you will remove a list of gender-specific terms using `build_counterfactual_dataset`. Note that the default function within `build_counterfactual_dataset` uses [tf.strings.regex_replace](https://www.tensorflow.org/api_docs/python/tf/strings/regex_replace), which could remove more than you intended. For example, passing only the word “he” would change “the” into “t” and “there” into “tre”. Additionally, you pass the list of words through `create_capitalization_regex_list` so that both upper- and lower-case first letters are matched, which covers terms that appear at the beginning of a sentence.
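The substring pitfall described above can be reproduced with Python's standard `re` module, used here as a stand-in for `tf.strings.regex_replace` (the two regex engines differ in some details, but both support the `\b` word boundary). The sample sentences are made up for illustration.

```python
import re

sentence = 'he said the others were there'

# Without word boundaries, "he" is also removed from inside other words.
naive = re.sub('he', '', sentence)

# With \b on both sides, only the standalone word "he" is removed.
bounded = re.sub(r'\bhe\b', '', sentence)

# A character class for the first letter also matches a capitalized term.
capitalized = re.sub(r'\b[Hh]e\b', '', 'He said he would')

print(repr(naive))
print(repr(bounded))
print(repr(capitalized))
```

Note the raw string prefix (`r'...'`): in an ordinary Python string, `'\b'` is a backspace character rather than a regex word boundary.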

In this list, we include non-pejorative terms only because pejorative terms should have a different toxicity score. Requiring equal predictions across examples with pejorative terms can accidentally harm the more vulnerable group.

In [None]:
#@title Create Counterfactual Inputs

sensitive_terms_to_remove = [
  'aunt', 'boy', 'brother', 'dad', 'daughter', 'father', 'female', 'gay',
  'girl', 'grandma', 'grandpa', 'grandson', 'grannie', 'granny', 'he',
  'heir', 'her', 'him', 'his', 'hubbies', 'hubby', 'husband', 'king',
  'knight', 'lad', 'ladies', 'lady', 'lesbian', 'lord', 'man', 'male',
  'mom', 'mother', 'mum', 'nephew', 'niece', 'prince', 'princess',
  'queen', 'queens', 'she', 'sister', 'son', 'uncle', 'waiter',
  'waitress', 'wife', 'wives', 'woman', 'women'
]

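def create_capitalization_regex_list():
  """Build one regex per term that matches either capitalization."""
  return_list = []
  for term in sensitive_terms_to_remove:
    # Raw f-string so that \b is a regex word boundary rather than a
    # backspace character. The character class matches either case of the
    # first letter, for example "[Hh]e".
    return_list.append(
        rf'\b[{term[0].upper()}{term[0].lower()}]{term[1:]}\b')
  return return_list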

# Convert the Pandas DataFrame to a TF Dataset
dataset_train_main = tf.data.Dataset.from_tensor_slices(
    (data_train[TEXT_FEATURE].values, labels_train)).batch(BATCH_SIZE)

counterfactual_data = counterfactual.keras.utils.build_counterfactual_dataset(
    original_dataset=dataset_train_main,
    sensitive_terms_to_remove=create_capitalization_regex_list())

counterfactual_packed_input = counterfactual.keras.utils.pack_counterfactual_data(
  dataset_train_main,
  counterfactual_data)

The cell above creates a counterfactual dataset in which the sensitive terms have been removed, then packs the original and counterfactual datasets together so that they can be passed to the `CounterfactualModel`. Note that `build_counterfactual_dataset` returns only the original values that contain the sensitive terms.

## Calculate the Flip Rate and Flip Count
Next, run Fairness Indicators. As a reminder, you will calculate only the flip rate and flip count to see whether the model incorrectly associates gender identity terms with toxicity. A “flip” is defined as a classifier giving a different decision when the identity term in the example changes. Flip count measures the number of times the classifier gives a different decision when the identity term in a given example is changed. Flip rate measures the probability that the classifier gives a different decision when the identity term in a given example is changed.
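To make these definitions concrete, here is a minimal pure-Python sketch. The function `flip_stats`, the 0.5 threshold, and the scores are all hypothetical, and the TFMA metrics may differ in details such as weighting and thresholds. It compares the decision on each original example with the decision on its counterfactual.

```python
def flip_stats(original_scores, counterfactual_scores, threshold=0.5):
    """Count decision flips between original and counterfactual predictions."""
    negative_to_positive = 0
    positive_to_negative = 0
    for original, counterfactual in zip(original_scores, counterfactual_scores):
        original_positive = original >= threshold
        counterfactual_positive = counterfactual >= threshold
        if not original_positive and counterfactual_positive:
            negative_to_positive += 1
        elif original_positive and not counterfactual_positive:
            positive_to_negative += 1
    flip_count = negative_to_positive + positive_to_negative
    return {
        'flip_count': flip_count,
        'flip_rate': flip_count / len(original_scores),
        'negative_to_positive': negative_to_positive,
        'positive_to_negative': positive_to_negative,
    }

# Made-up toxicity scores for four examples and their counterfactuals.
original_scores = [0.9, 0.2, 0.6, 0.4]
counterfactual_scores = [0.3, 0.1, 0.7, 0.8]
stats = flip_stats(original_scores, counterfactual_scores)
print(stats)
```

In this toy data, two of the four decisions change, one in each direction, so the flip count is 2 and the flip rate is 0.5.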

To compute model performance, the utility function makes a few convenient choices for metrics, slices, and classifier thresholds.

In [None]:
#@title Run Model Analysis

def get_eval_results(model_location,
                     eval_result_path,
                     validate_tfrecord_file,
                     slice_selection='gender',
                     compute_confidence_intervals=True):
  """Get Fairness Indicators eval results."""
  # Define slices that you want the evaluation to run on.
  eval_config = text_format.Parse(
      """
    model_specs {
     label_key: '%s'
   }
   metrics_specs {
     metrics {class_name: "AUC"}
     metrics {class_name: "ExampleCount"}
     metrics {class_name: "Accuracy"}
     metrics {
        class_name: "FairnessIndicators"
     }
     metrics {
        class_name: "FlipRate"
        config: '{ "counterfactual_prediction_key": "toxicity", '
                  '"example_id_key": 1 }'
     }
   }
   slicing_specs {
     feature_keys: '%s'
   }
   slicing_specs {}
   options {
       compute_confidence_intervals { value: %s }
       disabled_outputs{values: "analysis"}
   }
   """ % (LABEL,
          slice_selection, 'true' if compute_confidence_intervals else 'false'),
      tfma.EvalConfig())
  
  eval_shared_model = tfma.default_eval_shared_model(
      eval_saved_model_path=model_location, tags=[tf.saved_model.SERVING])

  return tfma.run_model_analysis(
      eval_shared_model=eval_shared_model,
      data_location=validate_tfrecord_file,
      eval_config=eval_config,
      output_path=eval_result_path)
  
base_dir = tempfile.mkdtemp(prefix='eval')
eval_dir = os.path.join(base_dir, 'tfma_eval_result_no_cf')
base_eval_result = get_eval_results(
    baseline_model_location,
    eval_dir,
    validate_tfrecord_file,
    slice_selection='gender')

In [None]:
#@title Render Evaluation Results
tfma.addons.fairness.view.widget_view.render_fairness_indicator(
    eval_result=base_eval_result)


Let’s look at the evaluation results above. Once Fairness Indicators renders, select “flip_rate/overall”, which is sliced by gender. You’ll notice that there are four gender slices within this dataset: “female”, “male”, “transgender”, and “other_gender”. This Colab focuses on “female” and “male” because the example count for the other gender slices is low within this dataset.

You’ll notice that the flip rate is about 13% for female and about 14% for male, both higher than the 7% for the overall dataset. Additionally, comparing these numbers to the total counts found within “flip_rate/negative_to_positive” and “flip_rate/positive_to_negative”, you can see that the likelihood for both male and female to incorrectly flip from negative to positive is high. In other words, the model is more likely to predict that text is toxic if it includes gender terms.
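The comparison between per-slice and overall flip rates can be sketched as follows. The records below are hypothetical and far smaller than the real dataset; the point is only to show how a slice's rate can sit well above the overall rate when most examples contain no gender terms.

```python
from collections import defaultdict

# Hypothetical records:
# (slice name, decision on original, decision on counterfactual).
records = [
    ('female', False, True),
    ('female', False, False),
    ('male', True, True),
    ('male', False, True),
    ('none', False, False),
    ('none', True, True),
    ('none', False, False),
    ('none', False, False),
]

flips = defaultdict(int)
totals = defaultdict(int)
for slice_name, original_decision, counterfactual_decision in records:
    totals[slice_name] += 1
    # A flip is any change in decision; True counts as 1 when added.
    flips[slice_name] += original_decision != counterfactual_decision

slice_flip_rates = {name: flips[name] / totals[name] for name in totals}
overall_flip_rate = sum(flips.values()) / sum(totals.values())
print(slice_flip_rates)
print(overall_flip_rate)
```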

You'll now use Counterfactual remediation to try to reduce the flip rate and flip count for gender-related terms in the dataset.

### Training and Evaluating the Counterfactual Model

To train with Counterfactual, take the original model and wrap it in a `CounterfactualModel` with a corresponding `loss` and `loss_weight`. You will start with a `loss_weight` of 1.5, but this parameter can be tuned for your use case, since the best value depends on your model and product requirements.
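To give a feel for what `loss_weight` controls, the sketch below assumes the combined objective is the task loss plus `loss_weight` times a pairwise loss between predictions on original and counterfactual examples. Both the formula and the choice of mean squared error as the pairwise term are illustrative assumptions, the function names are hypothetical, and none of this is the library's implementation.

```python
import math

def binary_cross_entropy(labels, predictions, eps=1e-7):
    """Mean binary cross-entropy computed on probabilities."""
    total = 0.0
    for y, p in zip(labels, predictions):
        p = min(max(p, eps), 1.0 - eps)  # Avoid log(0).
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(labels)

def pairwise_mse(original_predictions, counterfactual_predictions):
    """Mean squared gap between original and counterfactual predictions."""
    pairs = list(zip(original_predictions, counterfactual_predictions))
    return sum((o - c) ** 2 for o, c in pairs) / len(pairs)

def combined_loss(labels, predictions, original_predictions,
                  counterfactual_predictions, loss_weight=1.5):
    """Assumed form: task loss + loss_weight * counterfactual loss."""
    task_loss = binary_cross_entropy(labels, predictions)
    counterfactual_loss = pairwise_mse(
        original_predictions, counterfactual_predictions)
    return task_loss + loss_weight * counterfactual_loss

# Made-up values: two labeled examples, one of which has a counterfactual.
labels = [1.0, 0.0]
predictions = [0.8, 0.3]
original_predictions = [0.8]
counterfactual_predictions = [0.4]

task_only = combined_loss(labels, predictions, original_predictions,
                          counterfactual_predictions, loss_weight=0.0)
weighted = combined_loss(labels, predictions, original_predictions,
                         counterfactual_predictions, loss_weight=1.5)
print(task_only, weighted)
```

Under this assumed form, a larger `loss_weight` penalizes prediction gaps between an example and its counterfactual more heavily relative to the task loss, which is the trade-off you tune.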

Next compile the model normally (using the regular non-Counterfactual loss) and fit to train.

In [None]:
#@title Train Model

counterfactual_weight = 1.5 #@param {type:"number"}
 
base_dir = tempfile.mkdtemp(prefix='saved_models')
counterfactual_model_location = os.path.join(
    base_dir, 'model_export_counterfactual')
 
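# Pass the weight chosen above; otherwise `counterfactual_weight` is unused.
counterfactual_model = counterfactual.keras.CounterfactualModel(
    baseline_model,
    loss_weight=counterfactual_weight)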
 
# Compile the model normally after wrapping the original model.
# Note that this means we use the baseline model's loss here. The model
# outputs sigmoid probabilities, so the loss does not expect logits.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
counterfactual_model.compile(optimizer=optimizer, loss=loss,
                             metrics=['accuracy'])
 
counterfactual_model.fit(counterfactual_packed_input,
                         epochs=10)
 
counterfactual_model.save_original_model(counterfactual_model_location,
                                         save_format='tf')

Next evaluate the results with the Counterfactual model:

In [None]:
def get_eval_results_counterfactual(
                     baseline_model_location,
                     counterfactual_model_location,
                     eval_result_path,
                     validate_tfrecord_file,
                     slice_selection='gender',
                     compute_confidence_intervals=True):
  """Get Fairness Indicators eval results."""
  # Define slices that you want the evaluation to run on.
  eval_config = text_format.Parse(
      """
    model_specs {
     name: 'original'
     label_key: '%s'
   }
   model_specs {
     name: 'counterfactual'
     label_key: '%s'
     is_baseline: true
   }
   metrics_specs {
     metrics {class_name: "AUC"}
     metrics {class_name: "ExampleCount"}
     metrics {class_name: "Accuracy"}
     metrics {
        class_name: "FairnessIndicators"
     }
     metrics {
        class_name: "FlipRate"
        config: '{ "example_ids_count": 0 }'
     }
     metrics {
        class_name: "FlipCount"
        config: '{ "example_ids_count": 0 }'
     }
   }
   slicing_specs {
     feature_keys: '%s'
   }
   slicing_specs {}
   options {
       disabled_outputs{ values: "analysis"}
   }
   """ % (LABEL, LABEL, slice_selection,),
      tfma.EvalConfig())

  eval_shared_models = [
      tfma.default_eval_shared_model(
          model_name='original',
          eval_saved_model_path=baseline_model_location,
          eval_config=eval_config,
          tags=[tf.saved_model.SERVING]),
      tfma.default_eval_shared_model(
          model_name='counterfactual',
          eval_saved_model_path=counterfactual_model_location,
          eval_config=eval_config,
          tags=[tf.saved_model.SERVING]),
    ]
  
  return tfma.run_model_analysis(
      eval_shared_model=eval_shared_models,
      data_location=validate_tfrecord_file,
      eval_config=eval_config,
      output_path=eval_result_path)
 
counterfactual_eval_dir = os.path.join(base_dir, 'tfma_eval_result_cf') 
counterfactual_eval_result = get_eval_results_counterfactual(
  baseline_model_location,
  counterfactual_model_location,
  counterfactual_eval_dir,
  validate_tfrecord_file)

In [None]:
#@title Render Evaluation Results
 
counterfactual_model_comparison_results = {
    'base_model': base_eval_result,
    'counterfactual': counterfactual_eval_result.get_results()[0],
}
tfma.addons.fairness.view.widget_view.render_fairness_indicator(
    multi_eval_results=counterfactual_model_comparison_results
)

The cell above passes both the original and Counterfactual models into Fairness Indicators together to give a side-by-side comparison. Once again, select “flip_rate/overall” and compare the results for female and male between the two models. You should notice that the flip rates for overall, female, and male have all decreased by about 85%, leaving female at approximately 1.5% and male at approximately 2%.

Additionally, reviewing “flip_rate/negative_to_positive” and “flip_rate/positive_to_negative” again, you’ll notice that the model is still more likely to flip gender-related content to toxic, but the total count has decreased by over 90%.
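The relative decrease in flip rate can be checked with a line of arithmetic. The sketch below uses the approximate flip rates quoted in this tutorial (about 13% and 14% before remediation, about 1.5% and 2% after); the helper name `relative_reduction` is hypothetical, and your own run may produce somewhat different numbers.

```python
def relative_reduction(before, after):
    """Fractional decrease from `before` to `after`."""
    return (before - after) / before

# Approximate flip rates quoted in this tutorial: (baseline, Counterfactual).
flip_rates = {'female': (0.13, 0.015), 'male': (0.14, 0.02)}

reductions = {
    slice_name: relative_reduction(before, after)
    for slice_name, (before, after) in flip_rates.items()
}
for slice_name, reduction in reductions.items():
    print(f'{slice_name}: flip rate reduced by {reduction:.1%}')
```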


