In [None]:
import tensorflow as tf
import numpy as np
import time
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

# Load MNIST and preprocess it.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Add a trailing channel axis for Conv2D and scale pixel values to [0, 1].
# -1 lets NumPy infer the sample count instead of hard-coding 60000 / 10000,
# so the same code works on any subset of the data.
train_images = train_images.reshape((-1, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((-1, 28, 28, 1)).astype('float32') / 255
# One-hot encode the integer labels. num_classes is pinned to 10 (the width of
# the model's output heads) rather than inferred from max(label) + 1, which
# would produce a narrower matrix if the highest digit were absent.
train_labels = to_categorical(train_labels, num_classes=10)
test_labels = to_categorical(test_labels, num_classes=10)

# B-LeNet model definition: a LeNet-style CNN with one early-exit branch.
class BLeNet(tf.keras.Model):
    """LeNet-style classifier with an additional early-exit head.

    Main path: Conv2D(6, 5x5, relu) -> MaxPool(2x2) -> Conv2D(16, 5x5, relu)
    -> MaxPool(2x2) -> Flatten -> Dense(120, relu) -> Dense(84, relu)
    -> Dense(10, softmax).

    Branch path: Flatten of the first convolution's output (taken before
    pooling) -> Dense(10, softmax).

    ``call`` returns ``(main_output, branch_output)``; both heads end in a
    softmax activation, so both are class-probability tensors.
    """

    def __init__(self):
        super().__init__()
        kl = tf.keras.layers
        # Keep this construction order: Keras assigns default layer names
        # (conv2d, conv2d_1, dense, dense_1, ...) in the order layers are built.
        self.conv1 = kl.Conv2D(6, (5, 5), activation='relu')
        self.pool1 = kl.MaxPooling2D((2, 2))
        self.conv2 = kl.Conv2D(16, (5, 5), activation='relu')
        self.pool2 = kl.MaxPooling2D((2, 2))
        self.flatten = kl.Flatten()
        self.dense1 = kl.Dense(120, activation='relu')
        self.dense2 = kl.Dense(84, activation='relu')
        self.dense3 = kl.Dense(10, activation='softmax')
        self.branch_dense = kl.Dense(10, activation='softmax')

    def call(self, x):
        # Convolutional trunk.
        conv1_out = self.conv1(x)
        conv2_out = self.conv2(self.pool1(conv1_out))
        pooled = self.pool2(conv2_out)

        # Early-exit head on the un-pooled first-convolution features.
        branch_output = self.branch_dense(self.flatten(conv1_out))

        # Main head.
        hidden = self.dense2(self.dense1(self.flatten(pooled)))
        main_output = self.dense3(hidden)
        return main_output, branch_output

# Instantiate the model.
model = BLeNet()

# Loss function. Both model heads already output softmax probabilities, so
# from_logits stays False (its default). Note: train_step receives this object
# but computes its loss through weighted_loss instead.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

# Adam optimizer; beta_1 is set to 0.99 here (the Keras default is 0.9).
optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.001,
    beta_1=0.99,
    beta_2=0.999,
)

# Per-exit loss weights used by train_step:
# index 0 -> main exit, index 1 -> branch exit.
weights = np.asarray([1.0, 0.3], dtype=np.float32)

# Weighted loss function.
def weighted_loss(y_true, y_pred, weights):
    """Return the batch mean of the categorical cross-entropy times ``weights``.

    Args:
        y_true: One-hot target labels.
        y_pred: Predicted class probabilities (cross-entropy is computed with
            the default ``from_logits=False``).
        weights: Multiplier applied to every per-sample loss before averaging;
            train_step passes one element of the exit-weight vector here.

    Returns:
        A scalar tensor.
    """
    per_sample = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    scaled = per_sample * weights
    return tf.reduce_mean(scaled)

# Training step definition.
@tf.function
def train_step(images, labels, model, loss_fn, optimizer, weights):
    """Run one optimization step on a mini-batch and report batch accuracy.

    The total loss is
    ``mean(weights[0] * CE(main)) + mean(weights[1] * CE(branch))``,
    computed with ``weighted_loss``. ``loss_fn`` is accepted but not used.

    Returns:
        ``(loss, accuracy_main, accuracy_branch)`` for this batch. The
        accuracies are measured on the predictions made before the weight
        update is applied.
    """
    with tf.GradientTape() as tape:
        main_output, branch_output = model(images)
        main_term = weighted_loss(labels, main_output, weights[0])
        branch_term = weighted_loss(labels, branch_output, weights[1])
        loss = main_term + branch_term
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Accuracy: fraction of samples whose argmax matches the one-hot label.
    target_cls = tf.argmax(labels, 1)
    hits_main = tf.equal(tf.argmax(main_output, 1), target_cls)
    hits_branch = tf.equal(tf.argmax(branch_output, 1), target_cls)
    accuracy_main = tf.reduce_mean(tf.cast(hits_main, tf.float32))
    accuracy_branch = tf.reduce_mean(tf.cast(hits_branch, tf.float32))

    return loss, accuracy_main, accuracy_branch

# Train the model on the training set.
EPOCHS = 20       # number of passes over the training data
BATCH_SIZE = 32   # mini-batch size

for epoch in range(EPOCHS):
    # Reshuffle every epoch. Without this, every epoch replays the exact same
    # sequence of mini-batches, which is a known weakness for mini-batch SGD.
    order = np.random.permutation(len(train_images))

    total_loss = 0.0
    total_accuracy_main = 0.0
    total_accuracy_branch = 0.0
    num_batches = 0
    for start in range(0, len(order), BATCH_SIZE):
        batch_idx = order[start:start + BATCH_SIZE]
        images = train_images[batch_idx]
        labels = train_labels[batch_idx]
        loss, accuracy_main, accuracy_branch = train_step(images, labels, model, loss_fn, optimizer, weights)
        total_loss += loss
        total_accuracy_main += accuracy_main
        total_accuracy_branch += accuracy_branch
        num_batches += 1

    # Per-epoch averages of the per-batch metrics.
    avg_loss = total_loss / num_batches
    avg_accuracy_main = total_accuracy_main / num_batches
    avg_accuracy_branch = total_accuracy_branch / num_batches
    print(f"Epoch {epoch}, Loss: {avg_loss.numpy()}, Main Accuracy: {avg_accuracy_main.numpy()*100:.2f}%, Branch Accuracy: {avg_accuracy_branch.numpy()*100:.2f}%")

# Early-exit ("fast") inference.
def fast_inference(model, x, threshold, verbose=False):
    """Classify ``x`` using the earliest exit whose prediction is confident.

    Exits are checked in the order branch, then main. An exit is accepted
    when the Shannon entropy (in nats) of its class distribution is below
    ``threshold`` for every sample in ``x``. If no exit qualifies, the main
    (last) exit's prediction is returned.

    Args:
        model: Callable returning ``(main_output, branch_output)``. Both are
            used directly as class-probability vectors; BLeNet ends each head
            with a softmax activation.
        x: Batch of inputs accepted by ``model``.
        threshold: Entropy threshold; a lower value is stricter.
        verbose: If True, print the per-sample entropy at each exit checked.

    Returns:
        ``(predictions, exit_index)``. ``predictions`` is an array of argmax
        class indices (one per sample). ``exit_index`` is 0 for the branch
        exit and 1 for the main exit.

    Note:
        ``model(x)`` runs the whole network, so both heads are always
        computed. Choosing an early exit does not skip any computation.
    """
    main_output, branch_output = model(x)
    outputs = [branch_output, main_output]

    for exit_index, output in enumerate(outputs):
        # The heads already emit softmax probabilities. Applying another
        # softmax here would squeeze them toward uniform (entropy >= ~2.23
        # nats for 10 classes), so no sample could pass a small threshold.
        probs = np.asarray(output)
        # Per-sample entropy; the epsilon guards against log(0).
        entropy = -np.sum(probs * np.log(probs + 1e-20), axis=-1)
        if verbose:
            print(entropy)
        if np.all(entropy < threshold):
            return np.argmax(probs, axis=-1), exit_index

    # No exit was confident enough: fall back to the main (last) exit.
    return np.argmax(np.asarray(main_output), axis=-1), len(outputs) - 1

# Entropy threshold: an exit's prediction is accepted when its output entropy
# is below this value.
threshold = 0.025

# Run early-exit inference on the test set one image at a time and time it.
exit_counts = np.zeros(2)
total_correct = 0

# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed time (time.time() can jump if the system clock is adjusted).
start_time = time.perf_counter()

for image, label in zip(test_images, test_labels):
    # `exit_idx`, not `exit`, to avoid shadowing the builtin.
    prediction, exit_idx = fast_inference(model, np.expand_dims(image, axis=0), threshold)
    exit_counts[exit_idx] += 1
    # The batch holds a single sample, so `prediction` holds one class index.
    if np.asarray(prediction).item() == np.argmax(label):
        total_correct += 1

end_time = time.perf_counter()
total_time = end_time - start_time

# Report accuracy, elapsed time, and the fraction of samples taking each exit.
accuracy = total_correct / len(test_images)
exit_ratios = exit_counts / len(test_images)

print(f"Accuracy: {accuracy*100:.2f}%")
print(f"Time: {total_time:.2f} seconds")
print(f"Exit ratios: {exit_ratios}")

Epoch 0, Loss: 0.3487972021102905, Main Accuracy: 92.15%, Branch Accuracy: 91.62%
Epoch 1, Loss: 0.10347329825162888, Main Accuracy: 97.83%, Branch Accuracy: 96.79%
Epoch 2, Loss: 0.07394653558731079, Main Accuracy: 98.43%, Branch Accuracy: 97.73%
Epoch 3, Loss: 0.05916569381952286, Main Accuracy: 98.75%, Branch Accuracy: 98.18%
Epoch 4, Loss: 0.04864739999175072, Main Accuracy: 98.97%, Branch Accuracy: 98.47%
Epoch 5, Loss: 0.038913026452064514, Main Accuracy: 99.19%, Branch Accuracy: 98.62%
Epoch 6, Loss: 0.03509292006492615, Main Accuracy: 99.22%, Branch Accuracy: 98.75%
Epoch 7, Loss: 0.033243004232645035, Main Accuracy: 99.27%, Branch Accuracy: 98.89%
Epoch 8, Loss: 0.030967870727181435, Main Accuracy: 99.35%, Branch Accuracy: 99.02%
Epoch 9, Loss: 0.02431274950504303, Main Accuracy: 99.50%, Branch Accuracy: 99.13%
Epoch 10, Loss: 0.02086465060710907, Main Accuracy: 99.61%, Branch Accuracy: 99.23%
Epoch 11, Loss: 0.01951160468161106, Main Accuracy: 99.62%, Branch Accuracy: 99.27%


KeyboardInterrupt: 