In [1]:
import tensorflow.compat.v1 as tf
import pandas as pd
import numpy as np
import random
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
tf.config.list_physical_devices('GPU')
tf.disable_v2_behavior()

Instructions for updating:
non-resource variables are not supported in the long term


In [2]:
SEED = 1
# Fix numpy seed for reproducibility
np.random.seed(SEED)
# Fix python random seed for reproducibility
random.seed(SEED)
# Fix tensorflow graph-level seed for reproducibility
tf.set_random_seed(SEED)

In [3]:
def to_categorical(y, num_classes=None, dtype='float32'):
    """Converts a class vector (integers) to binary class matrix.

    E.g. for use with categorical_crossentropy.

    # Arguments
        y: class vector to be converted into a matrix
            (integers from 0 to num_classes).
        num_classes: total number of classes.
        dtype: The data type expected by the input, as a string
            (`float32`, `float64`, `int32`...)

    # Returns
        A binary matrix representation of the input. The classes axis
        is placed last.

    # Example

    ```python
    # Consider an array of 5 labels out of a set of 3 classes {0, 1, 2}:
    > labels
    array([0, 2, 1, 2, 0])
    # `to_categorical` converts this into a matrix with as many
    # columns as there are classes. The number of rows
    # stays the same.
    > to_categorical(labels)
    array([[ 1.,  0.,  0.],
           [ 0.,  0.,  1.],
           [ 0.,  1.,  0.],
           [ 0.,  0.,  1.],
           [ 1.,  0.,  0.]], dtype=float32)
    ```
    """

    y = np.array(y, dtype='int')
    input_shape = y.shape
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        input_shape = tuple(input_shape[:-1])
    y = y.ravel()
    if not num_classes:
        num_classes = np.max(y) + 1
    n = y.shape[0]
    categorical = np.zeros((n, num_classes), dtype=dtype)
    categorical[np.arange(n), y] = 1
    output_shape = input_shape + (num_classes, )
    categorical = np.reshape(categorical, output_shape)
    return categorical

In [6]:
def data_preparation():
    column_names = [
        'age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education',
        'wage_per_hour', 'hs_college', 'marital_stat', 'major_ind_code',
        'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member',
        'unemp_reason', 'full_or_part_emp', 'capital_gains', 'capital_losses',
        'stock_dividends', 'tax_filer_stat', 'region_prev_res',
        'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ', 'instance_weight',
        'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same',
        'mig_prev_sunbelt', 'num_emp', 'fam_under_18', 'country_father',
        'country_mother', 'country_self', 'citizenship', 'own_or_self',
        'vet_question', 'vet_benefits', 'weeks_worked', 'year', 'income_50k'
    ]
    train_df = pd.read_csv('census-income.data.gz',
                           delimiter=',',
                           header=None,
                           index_col=None,
                           names=column_names)
    other_df = pd.read_csv('census-income.test.gz',
                           delimiter=',',
                           header=None,
                           index_col=None,
                           names=column_names)
    label_columns = ['income_50k', 'marital_stat']
    categorical_columns = [
        'class_worker', 'det_ind_code', 'det_occ_code', 'education',
        'hs_college', 'major_ind_code', 'major_occ_code', 'race',
        'hisp_origin', 'sex', 'union_member', 'unemp_reason',
        'full_or_part_emp', 'tax_filer_stat', 'region_prev_res',
        'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ', 'mig_chg_msa',
        'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt',
        'fam_under_18', 'country_father', 'country_mother', 'country_self',
        'citizenship', 'vet_question'
    ]
    train_raw_labels = train_df[label_columns]
    other_raw_labels = other_df[label_columns]
    transformed_train = pd.get_dummies(train_df.drop(label_columns, axis=1),
                                       columns=categorical_columns)
    transformed_other = pd.get_dummies(other_df.drop(label_columns, axis=1),
                                       columns=categorical_columns)
    # Filling the missing column in the other set
    transformed_other[
        'det_hh_fam_stat_ Grandchild <18 ever marr not in subfamily'] = 0

    # One-hot encoding categorical labels
    train_income = to_categorical(
        (train_raw_labels.income_50k == ' 50000+.').astype(int), num_classes=2)
    train_marital = to_categorical(
        (train_raw_labels.marital_stat == ' Never married').astype(int),
        num_classes=2)
    other_income = to_categorical(
        (other_raw_labels.income_50k == ' 50000+.').astype(int), num_classes=2)
    other_marital = to_categorical(
        (other_raw_labels.marital_stat == ' Never married').astype(int),
        num_classes=2)
    dict_outputs = {
        'income': train_income.shape[1],
        'marital': train_marital.shape[1]
    }
    dict_train_labels = {'income': train_income, 'marital': train_marital}
    dict_other_labels = {'income': other_income, 'marital': other_marital}
    output_info = [(dict_outputs[key], key)
                   for key in sorted(dict_outputs.keys())]
    # Split the other dataset into 1:1 validation to test according to the paper
    validation_indices = transformed_other.sample(frac=0.5,
                                                  replace=False,
                                                 random_state=SEED).index
    test_indices = list(set(transformed_other.index) - set(validation_indices))
    validation_data = transformed_other.iloc[validation_indices]
    validation_label = [
        dict_other_labels[key][validation_indices]
        for key in sorted(dict_other_labels.keys())
    ]
    test_data = transformed_other.iloc[test_indices]
    test_label = [
        dict_other_labels[key][test_indices]
        for key in sorted(dict_other_labels.keys())
    ]
    train_data = transformed_train
    train_label = [
        dict_train_labels[key] for key in sorted(dict_train_labels.keys())
    ]

    return train_data, train_label, validation_data, validation_label, test_data, test_label, output_info

In [7]:
train_data, train_label, validation_data, validation_label, test_data, test_label, output_info = data_preparation(
)

In [8]:
print(train_data.shape)
print(output_info)
print(len(train_label[0]))
print(len(train_label[1]))

(199523, 499)
[(2, 'income'), (2, 'marital')]
199523
199523


In [9]:
dim = train_data.shape[1]
num_tasks = 2
expert_units = 4
expert_num = 8
tower_units = 8
lr = 0.001
batch_size = 512
epoches = 100

In [10]:
feat = tf.placeholder(tf.float32, shape=[None, dim], name='feat')
income = tf.placeholder(tf.int32, shape=[None, num_tasks], name='income')
marital = tf.placeholder(tf.int32, shape=[None, num_tasks], name='matrial')
expert_weight = tf.get_variable(
    name='expert_weight',
    dtype=tf.float32,
    shape=(dim, expert_units, expert_num),
    initializer=tf.keras.initializers.VarianceScaling())
expert_bias = tf.get_variable(name='expert_bias',
                              dtype=tf.float32,
                              shape=(expert_units, expert_num),
                              initializer=tf.zeros_initializer())
gate_weight = []
for i in range(num_tasks):
    gate_weight.append(
        tf.get_variable(name='gate_weight_{}'.format(i),
                        dtype=tf.float32,
                        shape=(dim, expert_num),
                        initializer=tf.keras.initializers.VarianceScaling()))

gate_bias = []
for i in range(num_tasks):
    gate_bias.append(
        tf.get_variable(name='gate_bias_{}'.format(i),
                        dtype=tf.float32,
                        shape=(expert_num, ),
                        initializer=tf.zeros_initializer()))

In [11]:
expert_output = tf.tensordot(feat, expert_weight, axes=1)
expert_output = tf.add(expert_output, expert_bias)
expert_output = tf.nn.relu(expert_output)  # (?, 4, 8)

gate_outputs = []
final_outputs = []
for index, gw in enumerate(gate_weight):
    gate_output = tf.matmul(feat, gw)  # (?, 8)
    gate_output = tf.add(gate_output, gate_bias[index])
    gate_output = tf.nn.softmax(gate_output)
    gate_outputs.append(gate_output)
for gate_output in gate_outputs:
    expanded_gate_output = tf.expand_dims(gate_output, axis=1)  # (?, 1, 8)
    f = tf.multiply(expert_output, expanded_gate_output)
    f = tf.reduce_sum(f, axis=2)
    f = tf.reshape(f, [-1, expert_units])
    final_outputs.append(f)
res = []
for index, output in enumerate(final_outputs):
    tower_layer = tf.layers.dense(
        output,
        units=tower_units,
        activation=tf.nn.relu,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
        name='tower_layer_{}'.format(index))
    output_layer = tf.layers.dense(
        tower_layer,
        units=2,
        activation=tf.nn.softmax,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
        name='output_layer_{}'.format(index))  #(?, 2)
    res.append(output_layer)

Instructions for updating:
Use keras.layers.Dense instead.
Instructions for updating:
Please use `layer.__call__` method instead.


In [12]:
income_loss = tf.reduce_sum(
    tf.losses.log_loss(labels=income, predictions=res[0]))
matrial_loss = tf.reduce_sum(
    tf.losses.log_loss(labels=marital, predictions=res[1]))
loss = income_loss + matrial_loss
income_value, income_auc = tf.metrics.auc(labels=tf.convert_to_tensor(income),
                            predictions=res[0],
                            name='income_auc')
marital_value,marital_auc = tf.metrics.auc(labels=tf.convert_to_tensor(marital),
                             predictions=res[1],
                             name='marital_auc')
optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

Instructions for updating:
The value of AUC returned by this may race with the update so this is deprecated. Please use tf.keras.metrics.AUC instead.


In [13]:
with tf.Session() as sess:
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    batch = int(len(train_data) / batch_size)
    print(batch)
    for i in range(epoches):
        for j in range(batch):
            start = j * batch_size
            end = (j + 1) * batch_size
            end = end if len(train_data) > end else len(train_data)
            batch_train_data = train_data[start:end]
            batch_train_income_label = train_label[0][start:end]
            batch_train_marital_label = train_label[1][start:end]
            feed_dict = {
                feat: train_data[start:end],
                income: batch_train_income_label,
                marital: batch_train_marital_label
            }
            _, l, i, m, iauc, mauc = sess.run(
                [optimizer, loss, income_loss, matrial_loss, income_auc, marital_auc],
                feed_dict)
            if j % 100 == 0:
                print("loss:", l, ", income_loss: ", i, ", marital_loss: ", m,
                      ", income_auc:", sess.run(income_value), ", marital_auc:", sess.run(marital_value))

Device mapping:
/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:00:06.0, compute capability: 6.1
/job:localhost/replica:0/task:0/device:XLA_GPU:0 -> device: XLA_GPU device

389
loss: 1.3841994 , income_loss:  0.6909992 , marital_loss:  0.6932001 , income_auc: 0.50922585 , marital_auc: 0.49419975
loss: 0.7074101 , income_loss:  0.26065642 , marital_loss:  0.44675368 , income_auc: 0.94902354 , marital_auc: 0.78495955
loss: 0.6219766 , income_loss:  0.1994258 , marital_loss:  0.4225508 , income_auc: 0.9517562 , marital_auc: 0.8703199
loss: 0.4885751 , income_loss:  0.15432076 , marital_loss:  0.33425435 , income_auc: 0.95447534 , marital_auc: 0.9002308
loss: 0.5129094 , income_loss:  0.15129158 , marital_loss:  0.36161783 , income_auc: 0.958119 , marital_auc: 0.9168046
loss: 0.5195577 , income_loss:  0.19078228 , marital_loss:  0.32877544 , income_auc: 0.963009

loss: 0.2395574 , income_loss:  0.1433561 , marital_loss:  0.0962013 , income_auc: 0.9877332 , marital_auc: 0.98644423
loss: 0.22839208 , income_loss:  0.095217526 , marital_loss:  0.13317455 , income_auc: 0.9877752 , marital_auc: 0.98654693
loss: 0.19260009 , income_loss:  0.11088859 , marital_loss:  0.08171149 , income_auc: 0.98781407 , marital_auc: 0.9866476
loss: 0.24142537 , income_loss:  0.1250449 , marital_loss:  0.11638046 , income_auc: 0.98784256 , marital_auc: 0.9867464
loss: 0.25646746 , income_loss:  0.13899684 , marital_loss:  0.11747062 , income_auc: 0.9878851 , marital_auc: 0.98684907
loss: 0.2693476 , income_loss:  0.094779685 , marital_loss:  0.17456794 , income_auc: 0.9879227 , marital_auc: 0.9868175
loss: 0.2565527 , income_loss:  0.111193776 , marital_loss:  0.14535892 , income_auc: 0.98795587 , marital_auc: 0.9868111
loss: 0.27503738 , income_loss:  0.12448758 , marital_loss:  0.1505498 , income_auc: 0.98798084 , marital_auc: 0.9868245
loss: 0.3043188 , income_loss

loss: 0.23069508 , income_loss:  0.13951372 , marital_loss:  0.09118137 , income_auc: 0.9891415 , marital_auc: 0.98976225
loss: 0.21225664 , income_loss:  0.09024753 , marital_loss:  0.122009106 , income_auc: 0.98915386 , marital_auc: 0.98979735
loss: 0.19640672 , income_loss:  0.11096042 , marital_loss:  0.08544631 , income_auc: 0.9891652 , marital_auc: 0.9898294
loss: 0.23126368 , income_loss:  0.113588974 , marital_loss:  0.11767471 , income_auc: 0.9891714 , marital_auc: 0.9898602
loss: 0.2383698 , income_loss:  0.14481074 , marital_loss:  0.093559064 , income_auc: 0.98918366 , marital_auc: 0.9898946
loss: 0.2136713 , income_loss:  0.09033228 , marital_loss:  0.12333903 , income_auc: 0.98919606 , marital_auc: 0.98992807
loss: 0.19737321 , income_loss:  0.11075134 , marital_loss:  0.08662187 , income_auc: 0.98920655 , marital_auc: 0.9899571
loss: 0.2379618 , income_loss:  0.11754872 , marital_loss:  0.12041309 , income_auc: 0.9892123 , marital_auc: 0.9899864
loss: 0.23593396 , income

loss: 0.23428568 , income_loss:  0.14579535 , marital_loss:  0.08849033 , income_auc: 0.98965174 , marital_auc: 0.9913048
loss: 0.20934252 , income_loss:  0.08996467 , marital_loss:  0.11937785 , income_auc: 0.9896579 , marital_auc: 0.9913208
loss: 0.19532044 , income_loss:  0.10961789 , marital_loss:  0.08570255 , income_auc: 0.98966336 , marital_auc: 0.99133396
loss: 0.2406155 , income_loss:  0.11428918 , marital_loss:  0.12632632 , income_auc: 0.98966634 , marital_auc: 0.991348
loss: 0.23488522 , income_loss:  0.14338459 , marital_loss:  0.091500625 , income_auc: 0.98967326 , marital_auc: 0.99136484
loss: 0.21244714 , income_loss:  0.090920545 , marital_loss:  0.121526584 , income_auc: 0.98967916 , marital_auc: 0.99138075
loss: 0.19462585 , income_loss:  0.10944137 , marital_loss:  0.08518449 , income_auc: 0.9896844 , marital_auc: 0.99139315
loss: 0.2286559 , income_loss:  0.11592483 , marital_loss:  0.112731084 , income_auc: 0.989687 , marital_auc: 0.9914047
loss: 0.23265713 , inco

loss: 0.2337298 , income_loss:  0.14604448 , marital_loss:  0.08768532 , income_auc: 0.9899297 , marital_auc: 0.99206305
loss: 0.20372248 , income_loss:  0.08899007 , marital_loss:  0.11473241 , income_auc: 0.9899334 , marital_auc: 0.9920729
loss: 0.19393332 , income_loss:  0.108185686 , marital_loss:  0.08574764 , income_auc: 0.9899368 , marital_auc: 0.99207944
loss: 0.22807488 , income_loss:  0.1146286 , marital_loss:  0.11344628 , income_auc: 0.9899384 , marital_auc: 0.99208754
loss: 0.22896388 , income_loss:  0.14456281 , marital_loss:  0.08440108 , income_auc: 0.9899424 , marital_auc: 0.9920973
loss: 0.20323765 , income_loss:  0.08747188 , marital_loss:  0.11576578 , income_auc: 0.98994595 , marital_auc: 0.992107
loss: 0.19434854 , income_loss:  0.10812537 , marital_loss:  0.08622317 , income_auc: 0.9899492 , marital_auc: 0.99211335
loss: 0.22520599 , income_loss:  0.11339313 , marital_loss:  0.11181286 , income_auc: 0.98995054 , marital_auc: 0.9921214
loss: 0.23210816 , income_lo

loss: 0.22626331 , income_loss:  0.1447585 , marital_loss:  0.08150483 , income_auc: 0.99010694 , marital_auc: 0.9925415
loss: 0.20438105 , income_loss:  0.08679496 , marital_loss:  0.1175861 , income_auc: 0.99010944 , marital_auc: 0.9925483
loss: 0.19190851 , income_loss:  0.10778789 , marital_loss:  0.084120624 , income_auc: 0.99011165 , marital_auc: 0.9925531
loss: 0.2201227 , income_loss:  0.11386099 , marital_loss:  0.106261715 , income_auc: 0.9901123 , marital_auc: 0.99255896
loss: 0.22605291 , income_loss:  0.1445864 , marital_loss:  0.08146652 , income_auc: 0.9901153 , marital_auc: 0.99256593
loss: 0.204088 , income_loss:  0.08742665 , marital_loss:  0.11666136 , income_auc: 0.9901177 , marital_auc: 0.99257284
loss: 0.1901353 , income_loss:  0.10599379 , marital_loss:  0.08414151 , income_auc: 0.99012 , marital_auc: 0.99257755
loss: 0.22097662 , income_loss:  0.113487214 , marital_loss:  0.10748941 , income_auc: 0.9901209 , marital_auc: 0.99258316
loss: 0.23227921 , income_loss