In [82]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.losses import Loss
import time 
# from tensorflow.keras.applications import *
# from matplotlib import plotly as plt
import matplotlib.pyplot as plt
import numpy as np

In [83]:
batch_size = 32
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

N = x_train.shape[0] // 2
x_train = np.concatenate((x_train[:N], x_train[N:]), axis=1)
y_train = np.stack((y_train[:N], y_train[N:]), axis=1)
N = x_test.shape[0] // 2

x_test = np.concatenate((x_test[:N], x_test[N:]), axis=1)
y_test = np.stack((y_test[:N], y_test[N:]), axis=1)

print(x_train.shape)
print(y_train.shape)


input_shape = x_train[0].shape

x_train = tf.convert_to_tensor(x_train, dtype=tf.float32)
y_train = tf.convert_to_tensor(y_train, dtype=tf.float32)
x_test = tf.convert_to_tensor(x_test, dtype=tf.float32)
y_test = tf.convert_to_tensor(y_test, dtype=tf.float32)

x_test = tf.expand_dims(x_test, -1)
x_train = tf.expand_dims(x_train, -1)
print(x_test.shape)
print(y_test.shape)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(batch_size)

input_shape = x_test[0].shape
num_classes = 10
print(input_shape)



(30000, 56, 28)
(30000, 2)
(5000, 56, 28, 1)
(5000, 2)
(56, 28, 1)


In [84]:

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dense(2*num_classes, activation="softmax"),
    ]
)

model.summary()


optimizer = tf.keras.optimizers.Adam(0.001)

cce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# loss_fn = cce_loss
def loss_fn(y, y_hat): 
#     print(y[:, 0], y_hat[:, :10])
    return cce_loss(y[:, 0], y_hat[:, :10]) + cce_loss(y[:, 0], y_hat[:, 10:])

#         cce_loss(y_batch_train[:, 0], logits[:, :10]) + cce_loss(y_batch_train[:, 0], logits[:, 10:])



Model: "sequential_26"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_104 (Conv2D)         (None, 54, 26, 64)        640       
                                                                 
 conv2d_105 (Conv2D)         (None, 52, 24, 64)        36928     
                                                                 
 conv2d_106 (Conv2D)         (None, 50, 22, 32)        18464     
                                                                 
 max_pooling2d_52 (MaxPoolin  (None, 25, 11, 32)       0         
 g2D)                                                            
                                                                 
 conv2d_107 (Conv2D)         (None, 23, 9, 32)         9248      
                                                                 
 max_pooling2d_53 (MaxPoolin  (None, 11, 4, 32)        0         
 g2D)                                                

In [85]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=loss_fn
)

model.fit(
    train_dataset,
    epochs=6,
    validation_data=test_dataset,
)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<keras.callbacks.History at 0x7f46e0147c10>

In [72]:

forward_pass_times = []
backward_pass_times = []

epochs = 5
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        
#         print(y_batch_train[0])
#         plt.figure()
#         plt.imshow(x_batch_train[0, :, :, 0]) 
#         plt.show()       
        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:

            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            start = time.time()
            logits = model(x_batch_train, training=True)  # Logits for this minibatch
            end = time.time()
            forward_pass_times.append(end - start)

            # Compute the loss value for this minibatch.
#             print(y_batch_train.shape) 
#             print(logits.shape) 
            loss_value = loss_fn(y_batch_train, logits)

#             loss_value = cce_loss(y_batch_train[:, 0], logits[:, :10]) + cce_loss(y_batch_train[:, 0], logits[:, 10:])

        start = time.time()
        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        backward_pass_times.append(time.time() - start)
        
        if step > 100: 
            break

#     print(f'forward pass 1-batch time = {np.array(forward_pass_times).mean()}')
#     print(f'backwards pass 1-batch time = {np.array(backward_pass_times).mean()}')

    print(f"Training loss (for one epoch): {float(loss_value)})")


Start of epoch 0
Training loss (for one epoch): 4.607485294342041)

Start of epoch 1
Training loss (for one epoch): 4.607485294342041)

Start of epoch 2
Training loss (for one epoch): 4.638735294342041)

Start of epoch 3
Training loss (for one epoch): 4.701235294342041)

Start of epoch 4
Training loss (for one epoch): 4.669985294342041)
