In [2]:
"""Utility functions for benchmarking online learning"""
from __future__ import division
import numpy as np
import keras
from keras.utils import np_utils

from keras.datasets import mnist, cifar10, cifar100
from keras.optimizers import Adam, RMSprop, SGD
import keras.backend as K

import pickle
import gzip

import tensorflow as tf
from keras.layers.core import Dense
from keras.layers import Conv2D, AveragePooling2D, MaxPool2D, Flatten, InputLayer
from keras.datasets import mnist
from keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt
#from util import computer_fisher, ewc_reg
import time
import cv2
from keras.utils import to_categorical
from keras.utils import np_utils

from torchvision import transforms
from PIL import Image
import argparse
import os.path
import random
import torch


BD_Ratio = 0.01
BD_Image = None

import numpy as np
import keras.backend as K
from keras.regularizers import Regularizer



def load_back_images():
    """Load the backdoor trigger images for MNIST and CIFAR.

    Reads ``bd.jpg`` and ``cifar_bd.jpg`` from the current working directory.
    The MNIST trigger is converted to grayscale and binarised (pixels above
    200 become 255, all others 0). Both triggers are then cast to float32 and
    divided by 255.

    Returns:
        Tuple ``(mnist_bd, cifar_bd)`` of float32 arrays.

    Raises:
        IOError: if either image cannot be read. ``cv2.imread`` reports
            failure by returning ``None`` rather than raising, which would
            otherwise show up later as an unrelated-looking error.
    """
    mnist_raw = cv2.imread("bd.jpg")
    if mnist_raw is None:
        raise IOError("Could not read backdoor trigger image 'bd.jpg'")
    mnist_bd = cv2.cvtColor(mnist_raw, cv2.COLOR_BGR2GRAY)
    (thresh, mnist_bd) = cv2.threshold(mnist_bd, 200, 255, cv2.THRESH_BINARY)
    mnist_bd = mnist_bd.astype('float32')
    mnist_bd /= 255

    cifar_bd = cv2.imread("cifar_bd.jpg")
    if cifar_bd is None:
        raise IOError("Could not read backdoor trigger image 'cifar_bd.jpg'")
    # NOTE(review): the colour trigger is used exactly as cv2.imread returns it
    # (no channel reordering) -- confirm this matches the dataset it is added to.
    cifar_bd = cifar_bd.astype('float32')
    cifar_bd /= 255
    return mnist_bd, cifar_bd
    

def get_back_door_dataset(x_train, y_train,  bd_single_target_label=0, 
                          num_classes = 10):
    """Build a set of backdoored (trigger-stamped) samples relabelled to one class.

    Samples are drawn without replacement from ``x_train`` by rejection
    sampling, skipping every sample whose class (argmax of its row in
    ``y_train``) already equals ``bd_single_target_label``. The module-level
    ``BD_Image`` is added to each selected image and the result is labelled
    with the one-hot encoding of ``bd_single_target_label``.

    Args:
        x_train: indexable collection of images; each image must flatten to
            the same number of elements as ``BD_Image``.
        y_train: labels aligned with ``x_train``; the class of a sample is the
            argmax of its row.
        bd_single_target_label: class index assigned to every backdoored
            sample; samples already in this class are never selected.
        num_classes: length of the one-hot target vector.

    Returns:
        ``(bd_X, bd_y)`` where ``bd_X`` is a 2-D array with one flattened
        backdoored image per row and ``bd_y`` is a list holding the one-hot
        target once per row.

    Raises:
        ValueError: if the data is empty or contains fewer eligible
            (non-target) samples than have to be drawn; the original
            rejection loop would spin forever in the latter case.

    Side effects:
        Re-seeds NumPy's global RNG with 104 so the selection is repeatable.
    """
    if len(x_train) == 0:
        raise ValueError("Cannot build a backdoor dataset from empty data")

    bd_images_count = int(BD_Ratio * len(x_train))
    # NOTE(review): the loop below collects bd_images_count + 1 samples (it
    # stops once the count is exceeded). Kept as-is so previously recorded
    # results stay reproducible -- confirm whether the extra sample is wanted.
    needed = bd_images_count + 1

    bd_label = to_categorical(bd_single_target_label,
                              num_classes=num_classes)

    # Class index of every sample, computed once instead of per draw.
    labels = np.asarray(y_train).reshape(len(y_train), -1).argmax(axis=1)
    eligible = int(np.count_nonzero(labels != bd_single_target_label))
    if eligible < needed:
        raise ValueError(
            "Need {0} samples outside class {1} but only {2} are available".format(
                needed, bd_single_target_label, eligible))

    np.random.seed(104)
    sample = []          # keeps the draw order
    chosen = set()       # O(1) duplicate check
    while len(sample) < needed:
        rand_index = np.random.randint(0,
                        high=len(x_train))
        if rand_index in chosen:
            continue
        if labels[rand_index] == bd_single_target_label:
            continue
        chosen.add(rand_index)
        sample.append(rand_index)

    trigger = BD_Image.flatten()
    bd_X = [np.add(x_train[index].flatten(), trigger) for index in sample]
    bd_y = [bd_label] * len(sample)

    # One row per sample; the row width follows the flattened image size
    # (784 for the MNIST data used in this notebook).
    bd_X = np.asarray(bd_X).reshape(len(sample), -1)
    return bd_X, bd_y

  
def get_task_data_by_index(task_index, training_datasets, validation_datasets):
    """Fetch the train and test split of a single task.

    Args:
        task_index: position of the task in both dataset sequences.
        training_datasets: sequence whose items expose X at [0] and y at [1].
        validation_datasets: same layout as ``training_datasets``.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` for the requested task.
    """
    train_split = training_datasets[task_index]
    test_split = validation_datasets[task_index]
    train_pair = (train_split[0], train_split[1])
    test_pair = (test_split[0], test_split[1])
    return train_pair, test_pair
  
def get_X_and_X_BD_data_for_model_training(task, training_datasets, validation_datasets):
    """Assemble clean and backdoored train/test data for one task.

    The backdoored training set is the clean training set with the samples
    produced by ``get_back_door_dataset`` (target class 0, 10 classes)
    appended. The backdoored test set consists of the backdoored samples only.

    Returns:
        ``(x_train, y_train), (x_test, y_test), (x_bd_train, y_bd_train),
        (x_bd_test, y_bd_test)``.
    """
    clean_train, clean_test = get_task_data_by_index(
        task_index=task,
        training_datasets=training_datasets,
        validation_datasets=validation_datasets)
    x_train, y_train = clean_train
    x_test, y_test = clean_test
    print("Current Task {0} Traing Examples Count={1}".format(task,str(len(x_train))))

    # Trigger-stamped samples drawn from the train and from the test split.
    x_bd_train_sample, y_bd_train_sample = get_back_door_dataset(
        x_train, y_train, bd_single_target_label=0, num_classes=10)
    x_bd_test, y_bd_test = get_back_door_dataset(
        x_test, y_test, bd_single_target_label=0, num_classes=10)

    # Poisoned training set = clean samples followed by the backdoored ones.
    x_bd_train = np.append(x_train, x_bd_train_sample, axis=0)
    y_bd_train = np.append(y_train, y_bd_train_sample, axis=0)

    return (x_train, y_train), (x_test, y_test), (x_bd_train, y_bd_train), (x_bd_test, y_bd_test)
    
def get_X_BD_data_for_model_eval(task_index):
    """Build an evaluation set in which about half of a task's test images carry the trigger.

    Reads the test split of ``task_index`` from the module-level
    ``training_datasets`` / ``validation_datasets``. Indices are drawn without
    replacement by rejection sampling, never picking a sample whose class
    (argmax of its label row) is 0. ``BD_Image`` is added to the selected
    images. Labels are left unchanged.

    Args:
        task_index: index of the task whose test split is used.

    Returns:
        ``(x_test_bd, y_test_bd), (x_test, y_test)``: the partly backdoored
        copy of the test split followed by the untouched split.

    Raises:
        ValueError: if the test split is empty or has fewer samples outside
            class 0 than have to be drawn (the original loop would never
            terminate in the latter case).

    Side effects:
        Re-seeds NumPy's global RNG with 104 and prints a few sanity lines.
    """
    (_, _), (x_test, y_test) = get_task_data_by_index(task_index=task_index,
                                            training_datasets=training_datasets,
                                            validation_datasets=validation_datasets)
    BD_ratio_test = 0.5
    sample_bd_size = int(BD_ratio_test * len(y_test))
    print("Task No {0} Test BD count {1} out of {2}". format(task_index, sample_bd_size, len(y_test)))

    if len(y_test) == 0:
        raise ValueError("Test split of task {0} is empty".format(task_index))

    # NOTE(review): sample_bd_size + 1 indices are collected (the loop stops
    # once the size is exceeded). Kept unchanged for reproducibility -- confirm.
    needed = sample_bd_size + 1

    # Class index of every test sample, computed once instead of per draw.
    labels = np.asarray(y_test).reshape(len(y_test), -1).argmax(axis=1)
    eligible = int(np.count_nonzero(labels != 0))
    if eligible < needed:
        raise ValueError(
            "Need {0} test samples outside class 0 but only {1} are available".format(
                needed, eligible))

    np.random.seed(104)
    # A set gives O(1) membership tests both while sampling and in the pass
    # over the whole test split below (the list made both quadratic).
    sample_bd_indexes = set()
    while len(sample_bd_indexes) < needed:
        rand_index = np.random.randint(0,
                        high=len(y_test))
        if rand_index in sample_bd_indexes:
            continue
        if labels[rand_index] == 0:
            continue
        sample_bd_indexes.add(rand_index)

    trigger = BD_Image.flatten()
    x_test_bd = []
    y_test_bd = []
    for index in range(0, len(x_test)):
        if index in sample_bd_indexes:
            # NOTE(review): backdoored rows are flattened while clean rows are
            # kept as stored -- assumes x_test rows are already flat; confirm.
            X = np.add(x_test[index].flatten(), trigger)
        else:
            X = x_test[index]
        x_test_bd.append(X)
        y_test_bd.append(y_test[index])

    x_test_bd = np.asarray(x_test_bd)
    y_test_bd = np.asarray(y_test_bd)

    print("Sanity Test")
    print("Both must be equal")
    print(str(len(x_test_bd)) + "=" + str(len(x_test)))
    print(str(len(y_test_bd)) + "=" + str(len(y_test)))

    return (x_test_bd, y_test_bd), (x_test, y_test)
  
  
  

def computer_fisher(model, imgset, num_sample=30):
    """Estimate a diagonal Fisher-information term for every weight of ``model``.

    For ``num_sample`` randomly chosen rows of ``imgset`` the gradient of
    ``log(model.output)`` with respect to each entry of ``model.weights`` is
    evaluated, squared and averaged.

    Args:
        model: Keras model exposing ``weights``, ``input`` and ``output``.
        imgset: array of inputs; samples are taken along axis 0.
        num_sample: number of random samples to average over.

    Returns:
        1-D object array with one entry per element of ``model.weights``;
        entry ``i`` is an array shaped like weight ``i`` holding the mean
        squared gradient.

    NOTE(review): the differentiated quantity is ``log(model.output)`` over
    all output units, not the log-probability of a sampled or true label --
    confirm this is the intended Fisher estimate.
    """
    weights = model.weights

    # The weights have different shapes, so store the accumulators in an
    # explicit object array; np.array(list_of_ragged_arrays) is rejected by
    # recent NumPy releases.
    f_accum = np.empty(len(weights), dtype=object)
    for i, weight in enumerate(weights):
        f_accum[i] = np.zeros(K.int_shape(weight))

    # Build the gradient graph and the backend function once. The original
    # rebuilt both for every (sample, weight) pair, which grows the graph and
    # dominates the run time without changing the result.
    grads = K.gradients(K.log(model.output), weights)
    grad_fn = K.function([model.input], grads)

    for _ in range(num_sample):
        img_index = np.random.randint(imgset.shape[0])
        sample_grads = grad_fn([np.expand_dims(imgset[img_index], 0)])
        for m, grad in enumerate(sample_grads):
            f_accum[m] += np.square(grad)
    f_accum /= num_sample
    return f_accum


# class ewc_reg(Regularizer):
#     def __init__(self, fisher, prior_weights, Lambda=0.1):
#         self.fisher = fisher
#         self.prior_weights = prior_weights
#         self.Lambda = Lambda

#     def __call__(self, x):
#       regularization = 0.
#       regularization += self.Lambda * K.sum(self.fisher * K.square(x - self.prior_weights))
#       return regularization

#     def get_config(self):
#         return {'Lambda': float(Lambda)}
      
      
class ewc_reg(Regularizer):
    """EWC-style quadratic penalty for one weight tensor, summed over previous tasks.

    Args:
        fisher: sequence with one entry per previous task; each entry is
            indexable by ``c`` and yields the Fisher term of this tensor.
        prior_weights: sequence with one entry per previous task; each entry
            is indexable by ``c`` and yields the stored value of this tensor.
        c: index of the regularised tensor inside each per-task entry.
        Lambda: multiplier applied to every per-task penalty term.
    """

    def __init__(self, fisher, prior_weights,c, Lambda=0.1):
        self.fisher = fisher
        self.prior_weights = prior_weights
        self.c = c
        self.Lambda = Lambda

    def __call__(self, x):
        """Return sum over tasks of ``Lambda * sum(F[c] * (x - w[c])**2)``."""
        regularization = 0.
        for f, m in zip(self.fisher, self.prior_weights):
            regularization += self.Lambda * K.sum(f[self.c] * K.square(x - m[self.c]))
        return regularization

    def get_config(self):
        # Use the instance's own strength; the bare name `Lambda` only
        # resolved through an unrelated module-level global.
        return {'Lambda': float(self.Lambda)}


class ewc_reg_new(Regularizer):
    """Quadratic penalty that pulls a weight tensor towards two stored reference points.

    Args:
        fisher: weighting of the squared distance to ``prior_weights``.
        fisher_new: weighting of the squared distance to ``prior_weights_new``.
        prior_weights: first reference value of the tensor.
        prior_weights_new: second reference value of the tensor.
        Lambda: multiplier applied to both penalty terms.
    """

    def __init__(self, fisher,fisher_new, prior_weights,prior_weights_new, Lambda=0.1):
        self.fisher = fisher
        self.fisher_new = fisher_new
        self.prior_weights = prior_weights
        self.prior_weights_new = prior_weights_new
        self.Lambda = Lambda

    def __call__(self, x):
        """Return the sum of both Fisher-weighted squared distances, each scaled by Lambda."""
        regularization = 0.
        regularization += self.Lambda * K.sum(self.fisher * K.square(x - self.prior_weights)) + self.Lambda * K.sum(self.fisher_new * K.square(x - self.prior_weights_new))
        return regularization

    def get_config(self):
        # Use the instance's own strength; the bare name `Lambda` only
        # resolved through an unrelated module-level global.
        return {'Lambda': float(self.Lambda)}


def split_dataset_by_labels(X, y, task_labels, nb_classes=None, multihead=False):
    """Partition a labelled dataset into one sub-dataset per label group.

    Args:
        X: data, indexable with a boolean mask along the first axis
        y: integer labels aligned with X
        task_labels: list of list of labels, one list per resulting dataset
        nb_classes: number of classes for the one-hot encoding; defaults to
            the number of distinct values in y
        multihead: when true, the labels of each group are renumbered from 0
            and encoded with ``len(labels)`` classes instead of ``nb_classes``
    Returns:
        List of (X, y_one_hot) tuples, in the order of ``task_labels``
    """
    if nb_classes is None:
        nb_classes = len(np.unique(y))

    def _subset_for(labels):
        # Boolean mask of the samples whose label belongs to this group.
        mask = np.in1d(y, labels)
        if not multihead:
            return X[mask], np_utils.to_categorical(y[mask], nb_classes)
        # Renumber this group's labels as 0 .. len(labels) - 1.
        remap = np.arange(nb_classes)
        remap[labels] = np.arange(len(labels))
        return X[mask], np_utils.to_categorical(remap[y[mask]], len(labels))

    return [_subset_for(labels) for labels in task_labels]


def load_mnist(split='train'):
    """Load one MNIST split as float32 images scaled by 1/255 with one-hot labels.

    Args:
        split: ``'train'`` selects the training split; any other value
            selects the test split.

    Returns:
        ``(X, y)`` where X has a single channel axis placed according to
        ``K.image_data_format()`` and y is one-hot encoded over 10 classes.
    """
    img_rows, img_cols = 28, 28
    nb_classes = 10

    train_pair, test_pair = mnist.load_data()
    X, y = train_pair if split == 'train' else test_pair

    # Put the channel axis where the backend expects it.
    if K.image_data_format() == 'channels_first':
        sample_shape = (1, img_rows, img_cols)
    else:
        sample_shape = (img_rows, img_cols, 1)
    X = X.reshape((X.shape[0],) + sample_shape)

    X = X.astype('float32')
    X /= 255

    return X, np_utils.to_categorical(y, nb_classes)

def construct_split_mnist(task_labels,  split='train', multihead=False):
    """Build split-MNIST tasks, one per group of labels.

        Args:
                task_labels: list of list of labels, one for each dataset
                split: ``'train'`` uses the training data, anything else the
                    test data
                multihead: forwarded to ``split_dataset_by_labels``

        Returns:
            List of (X, y) tuples representing each dataset
    """
    nb_classes = 10
    img_rows, img_cols = 28, 28

    train_pair, test_pair = mnist.load_data()
    X, y = train_pair if split == 'train' else test_pair

    # Put the channel axis where the backend expects it.
    if K.image_data_format() == 'channels_first':
        sample_shape = (1, img_rows, img_cols)
    else:
        sample_shape = (img_rows, img_cols, 1)
    X = X.reshape((X.shape[0],) + sample_shape)

    # Scale pixel values by 1/255.
    X = X.astype('float32')
    X /= 255

    # Labels stay as integers here; the one-hot encoding happens per task.
    return split_dataset_by_labels(X, y, task_labels, nb_classes, multihead)
  
  

def construct_permute_mnist(num_tasks=2,  split='train', permute_all=False, subsample=1):
    """Create permuted-MNIST tasks from flattened images.

        Args:
                num_tasks: number of tasks (one pixel permutation per task)
                split: accepted for signature compatibility; not read by the body
                permute_all: when true the first task is permuted as well,
                    otherwise it keeps the original pixel order
                subsample: keep every ``subsample``-th sample

        Returns:
            ``[train_tasks, test_tasks]``; each is a list of (X, y) tuples,
            one per task, with one-hot labels over 10 classes
    """
    nb_classes = 10
    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    def _flatten_scale_subsample(images, labels):
        # 784-wide float32 rows scaled by 1/255, then strided subsampling.
        flat = images.reshape(-1, 784).astype('float32')
        flat /= 255
        return flat[::subsample], labels[::subsample]

    X_train, y_train = _flatten_scale_subsample(X_train, y_train)
    X_test, y_test = _flatten_scale_subsample(X_test, y_test)

    # One pixel ordering per task, shared by the train and the test split.
    n_pixels = X_train.shape[1]
    permutations = []
    for task_no in range(num_tasks):
        order = np.arange(n_pixels, dtype=int)
        if permute_all or task_no > 0:
            np.random.shuffle(order)
        permutations.append(order)

    return [
        [(X[:, perm], np_utils.to_categorical(y, nb_classes)) for perm in permutations]
        for (X, y) in ((X_train, y_train), (X_test, y_test))
    ]

      

# Seed NumPy's global RNG so the random sampling below is repeatable.
np.random.seed(104)

# Arguments passed to every model.fit call below.
Batch_size = 65536
Epochs = 100

# `global` statements at module level are no-ops, and `counter` was declared
# but never assigned, so those declarations have been dropped.
# NOTE(review): the regularizers below are constructed without an explicit
# Lambda argument, so they use their own default rather than this value.
S_lam = 0.1
Lambda = S_lam


# ##################Split MNIST####################

# task_labels = [[0,1], [2,3], [4,5]]#, [4,5], [6,7], [8,9]]
# #task_labels = [[0,1], [2,3], [4,5], [6,7], [8,9]]
# # task_labels = [[0,1,2,3,4], [5,6,7,8,9]]
# n_tasks = len(task_labels)
# training_datasets =  construct_split_mnist(task_labels, split='train')
# validation_datasets = construct_split_mnist(task_labels, split='test')

# # print(validation_datasets[0][1].shape)




# ##################Permuted MNIST####################
# n_tasks = 5
# full_datasets, final_test_datasets = construct_permute_mnist(num_tasks=n_tasks)
# # training_datasets, validation_datasets = utils.mk_training_validation_splits(full_datasets, split_fractions=(0.9, 0.1))
# training_datasets = full_datasets
# validation_datasets = final_test_datasets



##################Rotated MNIST####################
tasks_tr, tasks_te = torch.load('/content/drive/My Drive/mnist_rotations.pt')
cl_no =10
n_tasks = 3
print("loading MNIST rotations done......")


def _as_keras_task(task_record, num_classes):
    """Turn one stored task record into an (X, one-hot y) pair.

    Element [1] of the record is converted to a NumPy array and element [2]
    is one-hot encoded -- presumably images and integer labels; verify
    against the code that wrote the .pt file.
    """
    return (np.array(task_record[1]),
            np_utils.to_categorical(task_record[2], num_classes))


# Only the first n_tasks records of each list are used.
training_datasets = [_as_keras_task(tasks_tr[i], cl_no) for i in range(n_tasks)]
validation_datasets = [_as_keras_task(tasks_te[i], cl_no) for i in range(n_tasks)]




# The grayscale trigger is the one stamped onto samples by the backdoor helpers.
mnist_bd, cifar_bd = load_back_images()
BD_Image = mnist_bd

# Task A: train the unregularised baseline and store it as the starting
# point for the following tasks.
model = Sequential([
    Dense(256, activation='relu', input_dim=784),
    Dense(256, activation='relu'),
    Dense(256, activation='relu'),
    Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(
    training_datasets[0][0],
    training_datasets[0][1],
    batch_size=Batch_size,
    epochs=Epochs,
    validation_data=(validation_datasets[0][0], validation_datasets[0][1]),
)
model.save('MNISTA.h5')



##### Compute the Fisher Information for each parameter in Task A
print('Processing Fisher Information...')
I_old = computer_fisher(model, training_datasets[0][0])
print('Processing Finish!')

# Per-task history, seeded with the Task A Fisher terms and weights:
# one pair of lists for the clean chain, one for the backdoored chain.
I = [I_old]
I_BD = [I_old]
mod_weights = [model.weights]
mod_BD_weights = [model.weights]

# Accuracy logs. The "current task" lists start with a placeholder so that
# position tidx corresponds to task tidx in the loop below.
ewc_mod_acc = [0]
BD_ewc_mod_acc = [0]
orig_mod_acc = []
BD_orig_mod_acc = []

# Offset of the first weight tensor handed to the regularizers.
c = 0


def _build_ewc_mlp(fisher_history, weight_history, offset=0):
    """Build and compile the 784-256-256-256-10 MLP with an ewc_reg penalty on every kernel and bias.

    Weight tensors are numbered kernel, bias, kernel, bias, ... starting at
    ``offset``; that number is the index each regularizer uses to look up its
    Fisher term and stored weight in the per-task history entries.
    """
    layer_specs = [(256, 'relu'), (256, 'relu'), (256, 'relu'), (10, 'softmax')]
    net = Sequential()
    for layer_no, (units, activation) in enumerate(layer_specs):
        first_layer_args = {'input_dim': 784} if layer_no == 0 else {}
        net.add(Dense(units, activation=activation,
                      kernel_regularizer=ewc_reg(fisher_history, weight_history,
                                                 offset + 2 * layer_no),
                      bias_regularizer=ewc_reg(fisher_history, weight_history,
                                               offset + 2 * layer_no + 1),
                      **first_layer_args))
    net.compile('adam', 'categorical_crossentropy', metrics=['accuracy'])
    return net


for tidx in range(1,n_tasks):
    ############################################# Loading Data #############################################
    (x_train, y_train), (x_test, y_test), (x_bd_train, y_bd_train), (x_bd_test, y_bd_test) = get_X_and_X_BD_data_for_model_training(
                                        task=tidx, training_datasets=training_datasets,validation_datasets=validation_datasets)
    bd_val_x = np.array(x_bd_test)
    bd_val_y = np.array(y_bd_test)

    # Clean chain: regularised towards the clean history.
    model_ewcB = _build_ewc_mlp(I, mod_weights, c)
    # Backdoored chain: regularised towards its OWN history. The original
    # passed the clean I / mod_weights here as well, which left I_BD and
    # mod_BD_weights populated but never read.
    model_bd = _build_ewc_mlp(I_BD, mod_BD_weights, c)

    model_ewcB.load_weights('MNISTA.h5')
    model_ewcB.fit(training_datasets[tidx][0], training_datasets[tidx][1],Batch_size, Epochs, validation_data=(validation_datasets[tidx][0] ,validation_datasets[tidx][1]))
    if(tidx == 1):
        model_bd.load_weights('MNISTA.h5')
    else:
        model_bd.load_weights('MNISTA_bd.h5')
    print("Number of backdoor validation examples is {0}".format(len(y_bd_test)))
    model_bd.fit(x_bd_train,y_bd_train,Batch_size, Epochs, validation_data=(bd_val_x, bd_val_y))

    # Each chain continues from these files in the next iteration.
    model_ewcB.save('MNISTA.h5')
    model_bd.save('MNISTA_bd.h5')

    print('Processing Fisher Information...')
    # Use the data of the task that was just learned; the original always
    # used training_datasets[1], which is wrong for every tidx > 1.
    I_new = computer_fisher(model_ewcB, training_datasets[tidx][0])
    print('Processing Finish!')
    print('Processing Fisher Information with BD...')
    I_new_BD = computer_fisher(model_bd, x_bd_train)
    print('Processing with BD Finish!')

    I.append(I_new)
    mod_weights.append(model_ewcB.weights)

    I_BD.append(I_new_BD)
    mod_BD_weights.append(model_bd.weights)

    # Current Task Performance
    ewc_mod = 100 * model_ewcB.evaluate(validation_datasets[tidx][0],validation_datasets[tidx][1], verbose=0)[1]
    ewc_mod_acc.append(ewc_mod)
    BD_ewc_mod = 100 * model_bd.evaluate(bd_val_x, bd_val_y, verbose=0)[1]
    BD_ewc_mod_acc.append(BD_ewc_mod)
    # Previous Task Performance
    orig_mod = 100 * model_ewcB.evaluate(validation_datasets[tidx-1][0],validation_datasets[tidx-1][1], verbose=0)[1]
    orig_mod_acc.append(orig_mod)
    # NOTE(review): this evaluates the same current-task backdoor set as
    # BD_ewc_mod above, not data of the previous task -- confirm the intent.
    BD_orig_mod = 100 * model_bd.evaluate(bd_val_x, bd_val_y, verbose=0)[1]
    BD_orig_mod_acc.append(BD_orig_mod)

    print("Initial Task Original Accuracy: %.2f%%" % (100 * model.evaluate(validation_datasets[0][0], validation_datasets[0][1])[1]))

    print("Current Task= {0} penalty accuracy={1}".format(tidx,ewc_mod_acc[tidx]))
    print("Current Task= {0} penalty accuracy with BD ={1}".format(tidx,BD_ewc_mod_acc[tidx]))
    # Accuracy on every earlier task: clean model on clean data, backdoored
    # model on the partly triggered test split.
    for t in range(0,tidx):
        p_acc_orig = 100 * model_ewcB.evaluate(validation_datasets[t][0],validation_datasets[t][1], verbose=0)[1]
        print("Previous Task= {0} penalty accuracy={1}".format(t,p_acc_orig))
        (x_test_bd, y_test_bd), (x_test, y_test) = get_X_BD_data_for_model_eval(task_index=t)
        BD_eval_acc = 100 * model_bd.evaluate(x_test_bd,y_test_bd, verbose=0)[1]
        print("Previous Task= {0} penalty accuracy with BD ={1}".format(t,BD_eval_acc))
  
               
               
print("#"*25)
# Accuracy of the last clean model on the first task's validation data.
initial_task_metrics = model_ewcB.evaluate(validation_datasets[0][0], validation_datasets[0][1])
print("Initial task EWC End Accuracy: %.2f%%" % (100 * initial_task_metrics[1]))

# Accuracy of the last backdoored model on each task's partly triggered test split.
for tidx in range(0,n_tasks):
    print("#"*25)
    print("Evaluating the final model on the test data with the backdoor at the end...........")
    (x_test_bd, y_test_bd), (x_test, y_test) = get_X_BD_data_for_model_eval(task_index=tidx)
    bd_metrics = model_bd.evaluate(x_test_bd,y_test_bd, verbose=0)
    BD_eval_acc = 100 * bd_metrics[1]
    print("Eval Accuracy with backdoor added to the data on task{0} is {1}".format(tidx,BD_eval_acc))

 



Using TensorFlow backend.


loading MNIST rotations done......





Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

Train on 60000 samples, validate on 10000 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch

In [1]:
# Mount Google Drive at /content/drive (Colab only).
# NOTE(review): the experiment cell loads its data from
# '/content/drive/My Drive/...', so this cell has to be executed first on a
# fresh kernel even though it is stored after that cell.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
