In [None]:
# Based on code from https://github.com/floft/codats/blob/294ccca16e545f5e51ca9d3cb971fcab1625b5a3/models.py#L1086

In [1]:
"""
TensorFlow 2.0 implementation of VRNN
Based on my 1.x implementation:
https://github.com/floft/deep-activity-learning/blob/tf_1.x/vrnn.py
"""
import tensorflow as tf


class VRNN(tf.keras.layers.Layer):
    """
    Keras layer that runs a VRNNCell over a sequence with tf.keras.layers.RNN

    call() returns (rnn_output, other_outputs):
    - rnn_output is phi(z) when return_z is True, otherwise the LSTM hidden
      state h; when return_sequences is False only the last index along
      axis 1 is kept
    - other_outputs is [encoder_mu, encoder_sigma, decoder_mu, decoder_sigma,
      prior_mu, prior_sigma], always for every time step
    """
    def __init__(self, h_dim, z_dim, return_z=True, return_sequences=False,
            go_backwards=False, stateful=False, unroll=False, **kwargs):
        super().__init__(**kwargs)
        # Sizes handed to the cell in build()
        self.h_dim = h_dim
        self.z_dim = z_dim
        # What this wrapper hands back from call()
        self.return_z = return_z
        self.return_sequences = return_sequences
        # Options forwarded unchanged to tf.keras.layers.RNN
        self.go_backwards = go_backwards
        self.stateful = stateful
        self.unroll = unroll

    def build(self, input_shape):
        """ Create the cell and the RNN once the feature dimension is known """
        vrnn_cell = VRNNCell(input_shape[-1], self.h_dim, self.z_dim)
        # The inner RNN always returns full sequences so the VRNN
        # reconstruction loss can be computed; call() trims the main output
        # afterwards if return_sequences is False
        rnn_options = dict(
            return_sequences=True,
            return_state=False,
            go_backwards=self.go_backwards,
            stateful=self.stateful,
            unroll=self.unroll)
        self.rnn = tf.keras.layers.RNN(vrnn_cell, **rnn_options)

    def call(self, inputs, **kwargs):
        """ Run the RNN and split its ten outputs into (main output, extras) """
        # Same ordering as the tuple returned by VRNNCell.call(); the LSTM
        # cell state and phi(x) are not needed here
        (h, _c,
            enc_mu, enc_sigma,
            dec_mu, dec_sigma,
            pri_mu, pri_sigma,
            _x_1, z_1) = self.rnn(inputs, **kwargs)

        # VRADA uses z not h
        sequence = z_1 if self.return_z else h

        # Either the whole sequence or only its end
        rnn_output = sequence if self.return_sequences else sequence[:, -1]

        # For use in the loss; these always cover every time step
        return rnn_output, [enc_mu, enc_sigma, dec_mu, dec_sigma,
            pri_mu, pri_sigma]


class VRNNCell(tf.keras.layers.Layer):
    """
    VRNN cell implementation for use in VRADA

    One step takes the input x and the previous LSTM state (h, c) and computes:
    - the prior (mu, sigma) from h
    - phi(x), a linear map of x
    - the encoder (mu, sigma) from [phi(x), h]
    - a sample z = enc_sigma*eps + enc_mu and phi(z)
    - the decoder (mu, sigma) from [phi(z), h]
    - the next LSTM state from [phi(x), phi(z)]

    All ten tensors are returned both as the step output and as the next
    state, in the order given by state_size.

    Based on:
    - https://github.com/phreeza/tensorflow-vrnn/blob/master/model_vrnn.py
    - https://github.com/kimkilho/tensorflow-vrnn/blob/master/cell.py
    - https://github.com/kimkilho/tensorflow-vrnn/blob/master/main.py
    - https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/rnn_cell_impl.py
    """
    def __init__(self, x_dim, h_dim, z_dim, **kwargs):
        """
        x_dim: number of input features
        h_dim: size of the LSTM hidden/cell state
        z_dim: size of the latent variable z
        kwargs: forwarded to tf.keras.layers.Layer
        """
        # Initialize the Keras layer machinery before assigning attributes,
        # in particular before attaching the inner LSTMCell sub-layer
        super().__init__(**kwargs)

        self.x_dim = x_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Dimensions of x input, hidden layers, latent variable (z)
        self.n_x = self.x_dim
        self.n_h = self.h_dim
        self.n_z = self.z_dim

        # Dimensions of phi(x) and phi(z)
        self.n_x_1 = self.x_dim
        self.n_z_1 = self.z_dim

        # Dimensions of encoder, decoder, and prior hidden layers
        self.n_enc_hidden = self.z_dim
        self.n_dec_hidden = self.x_dim
        self.n_prior_hidden = self.z_dim

        # Order: h, c, encoder mu/sigma, decoder mu/sigma, prior mu/sigma,
        # phi(x), phi(z). Note: first two are the state of the LSTM
        self.state_size = (
            self.n_h, self.n_h,
            self.n_z, self.n_z,
            self.n_x, self.n_x,
            self.n_z, self.n_z,
            self.n_x_1, self.n_z_1)

        # What cell we're going to use internally for the RNN
        self.cell = tf.keras.layers.LSTMCell(h_dim)

    def _new_matrix(self, name, rows, cols):
        """ Create a (rows, cols) weight matrix with Glorot-uniform init """
        # The name is passed by keyword so the call does not depend on the
        # positional parameter order of add_weight
        return self.add_weight(name=name, shape=(rows, cols),
            initializer="glorot_uniform")

    def _new_bias(self, name, size):
        """ Create a (size,) bias vector using tf.constant_initializer() """
        return self.add_weight(name=name, shape=(size,),
            initializer=tf.constant_initializer())

    def build(self, input_shape):
        """ Create all weights; creation order matches the original layout """
        # Input: previous hidden state
        self.prior_h = self._new_matrix("prior/hidden/weights",
            self.n_h, self.n_prior_hidden)
        self.prior_mu = self._new_matrix("prior/mu/weights",
            self.n_prior_hidden, self.n_z)
        self.prior_sigma = self._new_matrix("prior/sigma/weights",
            self.n_prior_hidden, self.n_z)

        self.prior_h_b = self._new_bias("prior/hidden/bias", self.n_prior_hidden)
        self.prior_sigma_b = self._new_bias("prior/sigma/bias", self.n_z)
        self.prior_mu_b = self._new_bias("prior/mu/bias", self.n_z)

        # Input: x
        self.x_1 = self._new_matrix("phi_x/weights", self.n_x, self.n_x_1)
        self.x_1_b = self._new_bias("phi_x/bias", self.n_x_1)

        # Input: x and previous hidden state
        self.encoder_h = self._new_matrix("encoder/hidden/weights",
            self.n_x_1+self.n_h, self.n_enc_hidden)
        self.encoder_mu = self._new_matrix("encoder/mu/weights",
            self.n_enc_hidden, self.n_z)
        self.encoder_sigma = self._new_matrix("encoder/sigma/weights",
            self.n_enc_hidden, self.n_z)

        self.encoder_h_b = self._new_bias("encoder/hidden/bias", self.n_enc_hidden)
        self.encoder_sigma_b = self._new_bias("encoder/sigma/bias", self.n_z)
        self.encoder_mu_b = self._new_bias("encoder/mu/bias", self.n_z)

        # Input: z = enc_sigma*eps + enc_mu -- i.e. reparameterization trick
        self.z_1 = self._new_matrix("phi_z/weights", self.n_z, self.n_z_1)
        self.z_1_b = self._new_bias("phi_z/bias", self.n_z_1)

        # Input: latent variable (z) and previous hidden state
        self.decoder_h = self._new_matrix("decoder/hidden/weights",
            self.n_z+self.n_h, self.n_dec_hidden)
        self.decoder_mu = self._new_matrix("decoder/mu/weights",
            self.n_dec_hidden, self.n_x)
        self.decoder_sigma = self._new_matrix("decoder/sigma/weights",
            self.n_dec_hidden, self.n_x)

        self.decoder_h_b = self._new_bias("decoder/hidden/bias", self.n_dec_hidden)
        self.decoder_sigma_b = self._new_bias("decoder/sigma/bias", self.n_x)
        self.decoder_mu_b = self._new_bias("decoder/mu/bias", self.n_x)

    def call(self, inputs, states, **kwargs):
        """
        Compute one VRNN step

        inputs: the input x for this time step
        states: previous state in state_size order; only states[0] (h) and
            states[1] (c) are read
        Returns (next_state, next_state), i.e. the full state tuple is also
        used as the step output.
        """
        # Get relevant states
        h = states[0]
        c = states[1]  # only passed to the LSTM

        # Input: previous hidden state (h)
        prior_h = tf.nn.relu(tf.matmul(h, self.prior_h) + self.prior_h_b)
        prior_sigma = tf.nn.softplus(tf.matmul(prior_h, self.prior_sigma) + self.prior_sigma_b)  # >= 0
        prior_mu = tf.matmul(prior_h, self.prior_mu) + self.prior_mu_b

        # Input: x
        # TODO removed ReLU since in the dataset not all x values are positive
        x_1 = tf.matmul(inputs, self.x_1) + self.x_1_b

        # Input: x and previous hidden state
        encoder_input = tf.concat((x_1, h), 1)
        encoder_h = tf.nn.relu(tf.matmul(encoder_input, self.encoder_h) + self.encoder_h_b)
        encoder_sigma = tf.nn.softplus(tf.matmul(encoder_h, self.encoder_sigma) + self.encoder_sigma_b)
        encoder_mu = tf.matmul(encoder_h, self.encoder_mu) + self.encoder_mu_b

        # Input: z = enc_sigma*eps + enc_mu -- i.e. reparameterization trick
        batch_size = tf.shape(inputs)[0]
        eps = tf.keras.backend.random_normal((batch_size, self.n_z), dtype=tf.float32)
        z = encoder_sigma*eps + encoder_mu
        z_1 = tf.nn.relu(tf.matmul(z, self.z_1) + self.z_1_b)

        # Input: latent variable (z) and previous hidden state
        decoder_input = tf.concat((z_1, h), 1)
        decoder_h = tf.nn.relu(tf.matmul(decoder_input, self.decoder_h) + self.decoder_h_b)
        decoder_sigma = tf.nn.softplus(tf.matmul(decoder_h, self.decoder_sigma) + self.decoder_sigma_b)
        decoder_mu = tf.matmul(decoder_h, self.decoder_mu) + self.decoder_mu_b

        # Pass to the LSTM cell. Its state is the pair [h, c], passed in as a
        # list and returned as the pair (h_next, c_next); the cell's own
        # output is discarded.
        rnn_cell_input = tf.concat((x_1, z_1), 1)
        _, (h_next, c_next) = self.cell(rnn_cell_input, [h, c])  # Note: (h,c) in Keras (c,h) in tf contrib

        # VRNN state
        next_state = (
            h_next,
            c_next,
            encoder_mu,
            encoder_sigma,
            decoder_mu,
            decoder_sigma,
            prior_mu,
            prior_sigma,
            x_1,
            z_1,
        )

        # Output and next state are the same tuple so every quantity is
        # available per time step from the wrapping RNN
        return next_state, next_state

In [4]:



class ModelBase(tf.keras.Model):
    """
    Base model class (inheriting from Keras' Model class)

    Subclasses are expected to set self.feature_extractor,
    self.task_classifier and self.domain_classifier, each either a single
    model or a list of models.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _get_trainable_variables_list(self, model_list):
        """ Concatenate the trainable variables of every model in the list """
        return [var for m in model_list for var in m.trainable_variables]

    def _get_trainable_variables(self, model):
        """ Trainable variables of a single model or of a list of models """
        if not isinstance(model, list):
            return model.trainable_variables

        return self._get_trainable_variables_list(model)

    @property
    def trainable_variables_fe(self):
        return self._get_trainable_variables(self.feature_extractor)

    @property
    def trainable_variables_task(self):
        return self._get_trainable_variables(self.task_classifier)

    @property
    def trainable_variables_domain(self):
        return self._get_trainable_variables(self.domain_classifier)

    @property
    def trainable_variables_task_fe(self):
        fe_vars = self.trainable_variables_fe
        task_vars = self.trainable_variables_task
        return fe_vars + task_vars

    @property
    def trainable_variables_task_fe_domain(self):
        fe_vars = self.trainable_variables_fe
        task_vars = self.trainable_variables_task
        domain_vars = self.trainable_variables_domain
        return fe_vars + task_vars + domain_vars

    @property
    def trainable_variables(self):
        """ Returns all trainable variables in the model """
        return self.trainable_variables_task_fe_domain

    def set_learning_phase(self, training):
        """ Set the Keras learning phase to 1/0 when training is exactly
        True/False; any other value (e.g. None) leaves it untouched """
        # Manually set the learning phase since we probably aren't using .fit()
        # but layers like batch norm and dropout still need to know if
        # training/testing
        if isinstance(training, bool):
            tf.keras.backend.set_learning_phase(1 if training else 0)

    # Allow easily overriding each part of the call() function, without having
    # to override call() in its entirety
    def call_feature_extractor(self, inputs, which_fe=None, which_tc=None,
            which_dc=None, **kwargs):
        """ Run the feature extractor, or entry which_fe if it is a list """
        if which_fe is None:
            return self.feature_extractor(inputs, **kwargs)

        assert isinstance(self.feature_extractor, list)
        return self.feature_extractor[which_fe](inputs, **kwargs)

    def call_task_classifier(self, fe, which_fe=None, which_tc=None,
            which_dc=None, **kwargs):
        """ Run the task classifier, or entry which_tc if it is a list """
        if which_tc is None:
            return self.task_classifier(fe, **kwargs)

        assert isinstance(self.task_classifier, list)
        return self.task_classifier[which_tc](fe, **kwargs)

    def call_domain_classifier(self, fe, task, which_fe=None, which_tc=None,
            which_dc=None, **kwargs):
        """ Run the domain classifier, or entry which_dc if it is a list.
        The task output is accepted but not used here. """
        if which_dc is None:
            return self.domain_classifier(fe, **kwargs)

        assert isinstance(self.domain_classifier, list)
        return self.domain_classifier[which_dc](fe, **kwargs)

    def call(self, inputs, training=None, **kwargs):
        """ Feature extractor -> task classifier -> domain classifier.
        Returns (task output, domain output, feature extractor output). """
        self.set_learning_phase(training)
        features = self.call_feature_extractor(inputs, **kwargs)
        task_output = self.call_task_classifier(features, **kwargs)
        domain_output = self.call_domain_classifier(features, task_output,
            **kwargs)
        return task_output, domain_output, features


class DannModelBase:
    """ DANN adds a gradient reversal layer before the domain classifier

    This is a mixin: it deliberately does not inherit from CnnModelBase or any
    other specific model so that it can be combined with CnnModelBase,
    RnnModelBase, etc. through multiple inheritance. It must come first in the
    base list so that its super() calls reach the concrete model class.
    """
    def __init__(self, num_classes, num_domains, global_step,
            total_steps, **kwargs):
        # global_step and total_steps are consumed here; everything else goes
        # to the next class in the MRO
        super().__init__(num_classes, num_domains, **kwargs)
        self.flip_gradient = FlipGradient(global_step,
            DannGrlSchedule(total_steps))

    def call_domain_classifier(self, fe, task, **kwargs):
        """ Send the feature extractor output through the GRL, then on to the
        next class's domain classifier """
        reversed_fe = self.flip_gradient(fe, **kwargs)
        return super().call_domain_classifier(reversed_fe, task, **kwargs)
    


class RnnModelBase(ModelBase):
    """ RNN-based model - for R-DANN and VRADA

    The feature extractor is a VradaFeatureExtractor; the task and domain
    classifiers are each three Dense(50) layers followed by a Dense output
    layer, none with an activation.
    """
    def __init__(self, num_classes, num_domains, model_name, vrada, **kwargs):
        # Note: model_name is ignored; only one RNN-based model is defined
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.num_domains = num_domains
        self.feature_extractor = VradaFeatureExtractor(vrada)
        self.task_classifier = tf.keras.Sequential(
            [tf.keras.layers.Dense(50) for _ in range(3)]
            + [tf.keras.layers.Dense(num_classes)])
        self.domain_classifier = tf.keras.Sequential(
            [tf.keras.layers.Dense(50) for _ in range(3)]
            + [tf.keras.layers.Dense(num_domains)])

    def call(self, inputs, training=None, **kwargs):
        """ The RNN feature extractor returns a pair (output, RNN state); the
        state is needed for the loss, but the classifiers only get the output,
        i.e. fe[0] rather than fe. The full pair is returned to the caller. """
        self.set_learning_phase(training)
        fe = self.call_feature_extractor(inputs, **kwargs)
        features = fe[0]
        task = self.call_task_classifier(features, **kwargs)
        domain = self.call_domain_classifier(features, task, **kwargs)
        return task, domain, fe

    
    

class VradaFeatureExtractor(tf.keras.Model):
    """
    RNN followed by a small dense stack, returning (features, rnn_state)

    We need the VRNN state, so Sequential can't be used directly (it can't
    return an intermediate layer's extra outputs), and neither can the
    functional API since we don't know the input shape.

    Note: rnn_state is only populated if vrada=True; otherwise it is None
    """
    def __init__(self, vrada=True, **kwargs):
        super().__init__(**kwargs)
        # Must be an actual bool, not merely truthy/falsy
        assert isinstance(vrada, bool)
        self.vrada = vrada

        if vrada:
            # Use z for predictions in VRADA like in original paper
            self.rnn = VRNN(100, 100, return_z=True, return_sequences=False)
        else:
            self.rnn = tf.keras.layers.LSTM(100, return_sequences=False)

        dense_stack = [tf.keras.layers.Flatten()]
        dense_stack += [tf.keras.layers.Dense(100) for _ in range(3)]
        self.fe = tf.keras.Sequential(dense_stack)

    def call(self, inputs, **kwargs):
        """ Returns (dense-stack output, extra VRNN outputs or None) """
        rnn_result = self.rnn(inputs, **kwargs)

        if self.vrada:
            # VRNN returns (output, list of tensors for the loss)
            rnn_output, rnn_state = rnn_result
        else:
            # Plain LSTM returns only its output
            rnn_output, rnn_state = rnn_result, None

        return self.fe(rnn_output, **kwargs), rnn_state


class RnnModelBase(ModelBase):
    """ RNN-based model - for R-DANN and VRADA

    The feature extractor is a VradaFeatureExtractor; the task and domain
    classifiers are each three Dense(50) layers followed by a Dense output
    layer, none with an activation.
    """
    def __init__(self, num_classes, num_domains, model_name, vrada, **kwargs):
        # Note: model_name is ignored; only one RNN-based model is defined
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.num_domains = num_domains
        self.feature_extractor = VradaFeatureExtractor(vrada)
        self.task_classifier = tf.keras.Sequential(
            [tf.keras.layers.Dense(50) for _ in range(3)]
            + [tf.keras.layers.Dense(num_classes)])
        self.domain_classifier = tf.keras.Sequential(
            [tf.keras.layers.Dense(50) for _ in range(3)]
            + [tf.keras.layers.Dense(num_domains)])

    def call(self, inputs, training=None, **kwargs):
        """ The RNN feature extractor returns a pair (output, RNN state); the
        state is needed for the loss, but the classifiers only get the output,
        i.e. fe[0] rather than fe. The full pair is returned to the caller. """
        self.set_learning_phase(training)
        fe = self.call_feature_extractor(inputs, **kwargs)
        features = fe[0]
        task = self.call_task_classifier(features, **kwargs)
        domain = self.call_domain_classifier(features, task, **kwargs)
        return task, domain, fe

    
class VradaModel(DannModelBase, RnnModelBase):
    """ DannModelBase combined with RnnModelBase, always constructed with
    vrada=True """
    def __init__(self, *args, **kwargs):
        # vrada is fixed here, so callers must not pass it themselves (it
        # would be a duplicate keyword argument)
        super().__init__(*args, vrada=True, **kwargs)
        


In [7]:
class MethodDann(MethodBase):
    """ DANN training method: trains on the sum of the task and domain losses
    and updates the domain classifier a second time with its own optimizer """
    def __init__(self, source_datasets, target_dataset,
            global_step, total_steps, *args, **kwargs):
        # These must exist before the base constructor runs, since it calls
        # create_model() which reads them
        self.global_step = global_step  # should be TF variable
        self.total_steps = total_steps
        super().__init__(source_datasets, target_dataset, *args, **kwargs)
        self.loss_names.extend(["task", "domain"])

    def create_model(self, model_name):
        return models.DannModel(self.num_classes, self.domain_outputs,
            self.global_step, self.total_steps, model_name=model_name)

    def create_optimizers(self):
        """ Base optimizers plus "d_opt" for the extra domain update """
        optimizers = super().create_optimizers()
        # We need an additional optimizer for DANN
        domain_lr = FLAGS.lr*FLAGS.lr_domain_mult
        optimizers["d_opt"] = self.create_optimizer(learning_rate=domain_lr)
        return optimizers

    def create_losses(self):
        # Note: at the moment these are the same, but if we go back to
        # single-source, then the domain classifier may be sigmoid not softmax
        super().create_losses()
        self.domain_loss = make_loss()

    def prepare_data(self, data_sources, data_target):
        """ Stack all source batches, then append the target batch.
        Returns (x, task_y_true, domain_y_true). """
        assert data_target is not None, "cannot run DANN without target"
        source_xs, source_ys, source_domains = data_sources
        x_b, y_b, domain_b = data_target

        # Concatenate all source domains' data
        x_a = tf.concat(source_xs, axis=0)
        y_a = tf.concat(source_ys, axis=0)
        domain_a = tf.concat(source_domains, axis=0)

        # Concatenate for adaptation. The target labels can't be used during
        # unsupervised domain adaptation, so all-zero labels stand in for them
        x = tf.concat((x_a, x_b), axis=0)
        hidden_target_labels = tf.zeros_like(y_b)
        task_y_true = tf.concat((y_a, hidden_target_labels), axis=0)
        domain_y_true = tf.concat((domain_a, domain_b), axis=0)

        return x, task_y_true, domain_y_true

    def compute_losses(self, x, task_y_true, domain_y_true, task_y_pred,
            domain_y_pred, fe_output, which_model, training):
        """ Returns [total_loss, task_loss, d_loss]; the task loss only uses
        rows whose domain label is non-zero """
        source_rows = tf.where(tf.not_equal(domain_y_true, 0))
        source_y_true = tf.gather(task_y_true, source_rows)
        source_y_pred = tf.gather(task_y_pred, source_rows)

        task_loss = self.task_loss(source_y_true, source_y_pred)
        d_loss = self.domain_loss(domain_y_true, domain_y_pred)
        return [task_loss + d_loss, task_loss, d_loss]

    def compute_gradients(self, tape, losses, which_model):
        """ Returns [gradients of the total loss w.r.t. all variables,
        gradients of the domain loss w.r.t. the domain classifier] """
        total_loss, _, d_loss = losses
        model = self.model[which_model]
        all_grad = tape.gradient(total_loss,
            model.trainable_variables_task_fe_domain)
        domain_grad = tape.gradient(d_loss, model.trainable_variables_domain)
        return [all_grad, domain_grad]

    def apply_gradients(self, gradients, which_model):
        all_grad, domain_grad = gradients
        model = self.model[which_model]
        optimizers = self.opt[which_model]
        optimizers["opt"].apply_gradients(
            zip(all_grad, model.trainable_variables_task_fe_domain))
        # Update discriminator again
        optimizers["d_opt"].apply_gradients(
            zip(domain_grad, model.trainable_variables_domain))

In [6]:


class MethodBase:
    def __init__(self, source_datasets, target_dataset, model_name,
            *args, ensemble_size=1, trainable=True, moving_average=False,
            share_most_weights=False, **kwargs):
        """
        Store the configuration, count the domains, then build the iterators,
        optimizers, models and losses (one optimizer dict and one model per
        ensemble member) and register models and optimizers for checkpointing.

        Extra positional/keyword arguments are accepted but not used here.
        """
        self.source_datasets = source_datasets
        self.target_dataset = target_dataset
        self.moving_average = moving_average
        self.ensemble_size = ensemble_size
        assert ensemble_size > 0, "ensemble_size should be >= 1"
        self.share_most_weights = share_most_weights  # for HeterogeneousBase

        # Support multiple targets when we add that functionality
        source_count = len(source_datasets)
        self.num_source_domains = source_count
        self.num_domains = source_count

        if target_dataset is not None:
            if isinstance(target_dataset, list):
                target_count = len(target_dataset)
            elif isinstance(target_dataset, load_datasets.Dataset):
                target_count = 1
            else:
                raise NotImplementedError("target_dataset should be either one "
                    "load_datasets.Dataset() or a list of them, "
                    "but is "+str(target_dataset))

            self.num_domains += target_count

        # How to calculate the number of domain outputs
        self.domain_outputs = self.calculate_domain_outputs()

        # The model needs num_classes. Take it from the first source: there is
        # always at least one, and load_da() has already verified they all
        # agree
        self.num_classes = source_datasets[0].num_classes

        # What we want in the checkpoint (created before the create_*() hooks
        # run)
        self.checkpoint_variables = {}

        # Initialize components -- support ensemble, training all simultaneously
        # I think will be faster / more efficient overall time-wise
        self.create_iterators()
        self.opt = [self.create_optimizers() for _ in range(ensemble_size)]
        self.model = [self.create_model(model_name) for _ in range(ensemble_size)]
        self.create_losses()

        # Checkpoint/save the model and optimizers
        for index, member in enumerate(self.model):
            self.checkpoint_variables["model_" + str(index)] = member

        for index, optimizers in enumerate(self.opt):
            for opt_name, optimizer in optimizers.items():
                key = "opt_" + opt_name + "_" + str(index)
                self.checkpoint_variables[key] = optimizer

        # Names of the losses returned in compute_losses
        self.loss_names = ["total"]

        # Should this method be trained (if not, then in main.py the config
        # is written and then it exits)
        self.trainable = trainable

    def calculate_domain_outputs(self):
        """ Number of outputs the domain classifier should have.

        The default is the total number of domains. Override this where that
        is wrong, e.g. domain generalization ignores the target and so only
        needs one output per source domain. """
        domain_outputs = self.num_domains
        return domain_outputs

    def create_iterators(self):
        """ Set up the training iterators and the train/test evaluation
        datasets for every source and, if there is one, the target; the
        target attributes are None when there is no target """
        sources = self.source_datasets
        self.source_train_iterators = [iter(ds.train) for ds in sources]
        self.source_train_eval_datasets = [ds.train_evaluation for ds in sources]
        self.source_test_eval_datasets = [ds.test_evaluation for ds in sources]

        target = self.target_dataset

        if target is None:
            self.target_train_iterator = None
            self.target_train_eval_dataset = None
            self.target_test_eval_dataset = None
        else:
            self.target_train_iterator = iter(target.train)
            self.target_train_eval_dataset = target.train_evaluation
            self.target_test_eval_dataset = target.test_evaluation

    def create_optimizer(self, *args, **kwargs):
        """ Create a single Adam optimizer from the given arguments, wrapped
        in tfa's MovingAverage when self.moving_average is set """
        adam = tf.keras.optimizers.Adam(*args, **kwargs)

        if not self.moving_average:
            return adam

        return tfa.optimizers.MovingAverage(adam)

    def create_optimizers(self):
        """ Dict of the optimizers for one model; by default just "opt" with
        learning rate FLAGS.lr """
        main_optimizer = self.create_optimizer(learning_rate=FLAGS.lr)
        return {"opt": main_optimizer}

    def create_model(self, model_name):
        """ Build one model; called once per ensemble member """
        model = models.BasicModel(self.num_classes, self.domain_outputs,
            model_name=model_name)
        return model

    def create_losses(self):
        """ Create the loss objects; by default only the task loss """
        self.task_loss = make_loss()

    def get_next_train_data(self):
        """ Pull the next training batch from every source iterator and, if
        there is one, from the target iterator, then add the domain labels """
        # Note we will use this same exact data in Metrics() as we use in
        # train_step()
        data_sources = [next(it) for it in self.source_train_iterators]

        target_iterator = self.target_train_iterator
        data_target = None if target_iterator is None else next(target_iterator)

        return self.get_next_batch_both(data_sources, data_target)

    def domain_label(self, index, is_target):
        """ Default domain labeling. index should be a non-negative integer.
        The target is always 0 and source #i is i+1:
        0 = target
        1 = source #0
        2 = source #1
        3 = source #2
        ...
        """
        return 0 if is_target else index+1

    @tf.function
    def get_next_batch_both(self, data_sources, data_target):
        """ Add domain labels to the source batches and the target batch.
        This wrapper is compiled and meant for training; evaluation calls the
        two underlying functions directly instead. """
        labeled_sources = self.get_next_batch_multiple(data_sources,
            is_target=False)
        labeled_target = self.get_next_batch_single(data_target,
            is_target=True)
        return labeled_sources, labeled_target

    def get_next_batch_multiple(self, data, is_target):
        """
        Attach domain labels to a list of (x, y) batches, e.g. the result of
        [next(x) for x in iterators]. The label tensor for each batch has the
        shape of y and is filled with domain_label() for that batch's position
        in the list.
        Returns None if data is None, otherwise: (
            [x_a1, x_a2, x_a3, ...],
            [y_a1, y_a2, y_a3, ...],
            [domain_a1, domain_a2, domain_a3, ...]
        )
        """
        if data is None:
            return None

        assert not is_target or len(data) == 1, \
            "only support one target at present"

        xs, ys, ds = [], [], []

        for position, batch in enumerate(data):
            x, y = batch
            label = self.domain_label(index=position, is_target=is_target)
            xs.append(x)
            ys.append(y)
            ds.append(tf.ones_like(y)*label)

        return (xs, ys, ds)

    def get_next_batch_single(self, data, is_target, index=0):
        """
        Attach a domain label to a single (x, y) batch, e.g. the result of
        next(iterator). For target data index must be 0 since only one target
        is supported at the moment. For source data (is_target False) index
        may be any source domain, which is what evaluation relies on when it
        processes each source individually.
        Returns None if data is None, otherwise (x, y, domain)
        """
        if data is None:
            return None

        assert not is_target or index == 0, \
            "only support one target at present"

        x, y = data
        label = self.domain_label(index=index, is_target=is_target)
        return (x, y, tf.ones_like(y)*label)

    # Allow easily overriding each part of the train_step() function, without
    # having to override train_step() in its entirety
    def prepare_data(self, data_sources, data_target):
        """ Turn the labeled batches into model inputs by concatenating all
        sources along axis 0. Returns (x, task_y_true, domain_y_true).

        Note: do not change the domain labels in here, since you presumably
        want the same labels during evaluation too; do that in
        domain_label() instead. """
        # By default (e.g. for no adaptation or domain generalization), ignore
        # the target data
        source_xs, source_ys, source_domains = data_sources
        return (
            tf.concat(source_xs, axis=0),
            tf.concat(source_ys, axis=0),
            tf.concat(source_domains, axis=0),
        )

    def prepare_data_eval(self, data, is_target):
        """ Evaluation counterpart of prepare_data(): data is (x, y, domain)
        where each element must be a list of tensors; each list is
        concatenated along axis 0. is_target is not used here. """
        xs, ys, domains = data

        assert isinstance(xs, list), \
            "Must pass x=[...] even if only one domain for tf.function consistency"
        assert isinstance(ys, list), \
            "Must pass y=[...] even if only one domain for tf.function consistency"
        assert isinstance(domains, list), \
            "Must pass domain=[...] even if only one domain for tf.function consistency"

        # Concatenate all the data (e.g. if multiple source domains)
        return (
            tf.concat(xs, axis=0),
            tf.concat(ys, axis=0),
            tf.concat(domains, axis=0),
        )

    def post_data_eval(self, task_y_true, task_y_pred, domain_y_true,
            domain_y_pred):
        """ Hook for post-processing after the model has run. The model
        outputs logits, so by default both predictions go through a softmax to
        give probability distributions during evaluation; the true labels are
        passed through unchanged. """
        task_probabilities = tf.nn.softmax(task_y_pred)
        domain_probabilities = tf.nn.softmax(domain_y_pred)
        return (task_y_true, task_probabilities,
            domain_y_true, domain_probabilities)

    def call_model(self, x, which_model, is_target=None, **kwargs):
        """ Run x through ensemble member which_model. is_target is accepted
        but not forwarded to the model. """
        model = self.model[which_model]
        return model(x, **kwargs)

    def compute_losses(self, x, task_y_true, domain_y_true, task_y_pred,
            domain_y_pred, fe_output, which_model, training):
        """ By default the only loss is the task loss on the true vs.
        predicted task labels; the other arguments are unused here """
        # Maybe: regularization = sum(model.losses) and add to loss
        task_loss = self.task_loss(task_y_true, task_y_pred)
        return task_loss

    def compute_gradients(self, tape, loss, which_model):
        """ Gradient of the loss w.r.t. the feature extractor and task
        classifier variables of ensemble member which_model """
        variables = self.model[which_model].trainable_variables_task_fe
        return tape.gradient(loss, variables)

    def apply_gradients(self, grad, which_model):
        """ Apply the gradients to the feature extractor and task classifier
        variables of ensemble member which_model using its "opt" optimizer """
        optimizer = self.opt[which_model]["opt"]
        variables = self.model[which_model].trainable_variables_task_fe
        optimizer.apply_gradients(zip(grad, variables))

    def train_step(self):
        """
        Load the batch(es), then hand them to the compiled _train_step() which
        prepares the data, runs the model, computes the losses and applies the
        gradients.

        Customize the individual parts by overriding prepare_data(),
        call_model(), compute_losses(), compute_gradients(), and
        apply_gradients().

        Returns the first loaded batch (sources, target) so the exact same
        training batch can be used for the "train" evaluation metrics.
        """
        # TensorFlow errors constructing the graph (with tf.function, which
        # makes training faster) if we don't know the data size. Thus, first
        # load batches, then pass to compiled train step.
        batches = []

        for _ in range(self.ensemble_size):
            batches.append(self.get_next_train_data())

            # If desired, use the same batch for each of the models.
            if FLAGS.ensemble_same_data:
                break

        all_data_sources = [sources for sources, _ in batches]
        all_data_target = [target for _, target in batches]

        self._train_step(all_data_sources, all_data_target)

        # We return the first one since we don't really care about the "train"
        # evaluation metrics that much.
        return all_data_sources[0], all_data_target[0]

    @tf.function
    def _train_step(self, all_data_sources, all_data_target):
        """ The compiled part of train_step. We can't compile everything since
        some parts of the model need to know the shape of the data apparently.
        The first batch is passed in because to compile this, TF needs to know
        the shape. Doesn't look pretty... but it runs...

        all_data_sources / all_data_target: lists of already-loaded batches as
        built by train_step(); they hold one entry per ensemble member, or a
        single entry when FLAGS.ensemble_same_data is set.
        """
        for i in range(self.ensemble_size):
            # Get random batch for this model in the ensemble (either same for
            # all or different for each)
            if FLAGS.ensemble_same_data:
                data_sources = all_data_sources[0]
                data_target = all_data_target[0]
            else:
                data_sources = all_data_sources[i]
                data_target = all_data_target[i]

            # Prepare
            x, task_y_true, domain_y_true = self.prepare_data(data_sources,
                data_target)

            # Run batch through the model and compute loss. The tape is
            # persistent because compute_gradients() may call tape.gradient()
            # more than once (MethodDann does).
            with tf.GradientTape(persistent=True) as tape:
                task_y_pred, domain_y_pred, fe_output = self.call_model(
                    x, which_model=i, training=True)
                losses = self.compute_losses(x, task_y_true, domain_y_true,
                    task_y_pred, domain_y_pred, fe_output, which_model=i,
                    training=True)

            # Update model; drop the reference to the persistent tape once the
            # gradients have been computed
            gradients = self.compute_gradients(tape, losses, which_model=i)
            del tape
            self.apply_gradients(gradients, which_model=i)

    def eval_step(self, data, is_target):
        """ Evaluate a batch of source or target data, called in metrics.py.

        data is unpacked into (x, y, domain); any of the three that is not
        already a list gets wrapped in a one-element list, so that
        eval_step_list() can handle both sources and target domains with
        the same code. Returns whatever eval_step_list() returns. """
        x, y, domain = data

        x, y, domain = (
            item if isinstance(item, list) else [item]
            for item in (x, y, domain)
        )

        return self.eval_step_list((x, y, domain), is_target)

    def add_multiple_losses(self, losses, average=False):
        """
        Combine several equally-long loss lists element-wise.

        losses = [
            [total_loss1, task_loss1, ...],
            [total_loss2, task_loss2, ...],
            ...
        ]
        returns a new list [total_loss, task_loss, ...] holding either the
        element-wise sum (default) or, if average=True, that sum divided by
        len(losses). The lists passed in are never modified.

        Raises AssertionError if losses is empty or if a later inner list
        has a different length than the first one.
        """
        losses_added = None

        for loss_list in losses:
            if losses_added is None:
                # Copy instead of aliasing: accumulating into the caller's
                # first list would silently overwrite its values.
                losses_added = list(loss_list)
            else:
                assert len(losses_added) == len(loss_list), \
                    "subsequent losses have different length than the first"

                # Out-of-place addition so that elements supporting in-place
                # "+=" (e.g. arrays) are not modified either.
                losses_added = [
                    total + loss
                    for total, loss in zip(losses_added, loss_list)
                ]

        assert losses_added is not None, \
            "must return losses from at least one domain"

        if average:
            return [loss / len(losses) for loss in losses_added]

        return losses_added

    #@tf.function  # faster not to compile
    def eval_step_list(self, data, is_target):
        """
        Evaluate one batch with every model in the ensemble and average the
        results. Override the preparation in prepare_data_eval().

        Returns a tuple (task_y_true, task_y_pred, domain_y_true,
        domain_y_pred, losses): the first four are means over the ensemble
        (axis 0) of the post-processed values, and losses is the per-entry
        average computed by add_multiple_losses().
        """
        x, true_task_labels, true_domain_labels = self.prepare_data_eval(
            data, is_target)

        per_model_losses = []
        task_true, task_pred = [], []
        domain_true, domain_pred = [], []

        for model_index in range(self.ensemble_size):
            # Forward pass in inference mode
            task_y_pred, domain_y_pred, fe_output = self.call_model(x,
                which_model=model_index, is_target=is_target, training=False)

            # Losses are normalized to a list so they can be combined
            # entry by entry across the models.
            losses = self.compute_losses(x, true_task_labels,
                true_domain_labels, task_y_pred, domain_y_pred, fe_output,
                which_model=model_index, training=False)
            per_model_losses.append(
                losses if isinstance(losses, list) else [losses])

            # Post-process data (e.g. compute softmax from logits)
            task_y_true, task_y_pred, domain_y_true, domain_y_pred = \
                self.post_data_eval(true_task_labels, task_y_pred,
                    true_domain_labels, domain_y_pred)

            task_true.append(task_y_true)
            task_pred.append(task_y_pred)
            domain_true.append(domain_y_true)
            domain_pred.append(domain_y_pred)

        # Combine information from each model in the ensemble by averaging.
        #
        # Note: this is how the ensemble predictions are made with
        # InceptionTime having an ensemble of 5 models -- they average the
        # softmax outputs over the ensemble (and we have softmax here after
        # the post_data_eval() call). See their code:
        # https://github.com/hfawaz/InceptionTime/blob/master/classifiers/nne.py
        def ensemble_mean(values):
            return tf.math.reduce_mean(values, axis=0)

        return (
            ensemble_mean(task_true),
            ensemble_mean(task_pred),
            ensemble_mean(domain_true),
            ensemble_mean(domain_pred),
            self.add_multiple_losses(per_model_losses, average=True),
        )
