<a href="https://colab.research.google.com/github/yutan0565/colab_git/blob/main/code/MobileNet_V3_Small_%EC%98%88%EC%8B%9C.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 기본 모델 형성

In [None]:
# Install the TensorFlow Model Optimization toolkit (imported below as tfmot); -q passes pip's quiet flag.
!pip install -q tensorflow-model-optimization

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import os
import datetime
import time
import tempfile
import pathlib

from tensorflow.keras.applications import MobileNetV3Small
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow import keras
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger, TensorBoard
import tensorflow_model_optimization as tfmot

In [None]:
# Load the CIFAR-10 dataset through Keras.
(raw_train_x, raw_train_y), (raw_test_x, raw_test_y) = tf.keras.datasets.cifar10.load_data()

# Split into train / validation / test: the first 45000 training samples are
# used for training and the remainder for validation. Every image array is
# cast to float32 and divided by 255.0.
train_x, valid_x, test_x = (
    pixels.astype(np.float32) / 255.0
    for pixels in (raw_train_x[:45000], raw_train_x[45000:], raw_test_x)
)

# Labels follow exactly the same split as the images.
train_y, valid_y = raw_train_y[:45000], raw_train_y[45000:]
test_y = raw_test_y

# Human-readable class names.
labels = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

In [None]:
# Convolutional backbone: MobileNetV3Small without its own classification head.
# NOTE(review): with the `weights` argument commented out, Keras falls back to
# its documented default for that parameter; pass weights=None explicitly if a
# randomly initialised backbone is what is wanted.
# NOTE(review): some Keras versions build input rescaling into MobileNetV3 and
# expect 0-255 pixels, while this notebook feeds images divided by 255.0 --
# confirm against the installed version.
mobile = MobileNetV3Small(  #weights = 'imagenet',
                            include_top = False,
                            input_shape=(32,32,3)
                            )

# Custom fully-connected head for the 10-class problem.
# The output layer uses softmax: each image has exactly one class and the model
# is trained with sparse categorical cross-entropy, which expects a probability
# distribution over the classes (sigmoid scores are not normalised across them).
fc_layer = keras.Sequential([
                             layers.Flatten(),
                             layers.Dense(1024, activation = 'relu'),
                             layers.Dense(1024, activation = 'relu'),
                             layers.Dense(1024, activation = 'relu'),
                             layers.Dense(10, activation = "softmax")
                             ])

# Full network: backbone followed by the classification head.
model = keras.Sequential([mobile,
                          fc_layer
                          ])
model.summary()


In [None]:
# Callbacks applied while training.

# Stop after 30 epochs without improvement (EarlyStopping's default monitor is used).
early_stop = EarlyStopping(patience=30)

# Keep only the checkpoint with the lowest validation loss.
mc = ModelCheckpoint(
    "./best_model/mobile_original_checkpoint",
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

# Halve the learning rate when the validation loss has not improved for 5 epochs.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=5, factor=0.5)

# Optimizer used for the baseline training run.
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)

In [None]:
# Attach the optimizer and loss; `metrics` lists what is reported during training.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=opt,
    metrics=["accuracy"],
)

# Train with mini-batches of 32 (weights are updated once per batch), validating
# on the held-out split instead of using validation_split.
model.fit(
    train_x,
    train_y,
    batch_size=32,
    epochs=100,
    validation_data=(valid_x, valid_y),
    callbacks=[early_stop, reduce_lr, mc],
    verbose=1,
)

In [None]:
# Loss and accuracy of the in-memory model on the test split.
loss, acc = model.evaluate(test_x, test_y)
print("loss=", loss)
print("acc=", acc)

# Per-class scores for every test image (one row per image).
y_ = model.predict(test_x)

# The predicted class is the index of the highest score in each row.
predicted = y_.argmax(axis=1)
print(predicted)

## int_8 Quantization 진행

In [None]:
# Reload the checkpoint saved during training; the saved file also contains the model architecture.
model = tf.keras.models.load_model('./best_model/mobile_original_checkpoint')

In [None]:
# Calibration data for quantization
def representative_data_gen():
  """Yield 100 single-image batches drawn from `train_x`, each wrapped in a list."""
  calibration_set = tf.data.Dataset.from_tensor_slices(train_x).batch(1).take(100)
  for sample in calibration_set:
    yield [sample]

In [None]:
# Plain conversion to the TFLite format with no optimization options set (the model stays float).
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

In [None]:
# Convert the Keras model to TFLite with int8 quantization.
converter = tf.lite.TFLiteConverter.from_keras_model(model)

# Restrict the converter to the int8 builtin op set and use uint8 at the
# model's input and output.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

# Enable the default optimizations, calibrated with the representative samples.
converter.representative_dataset = representative_data_gen
converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model_quant = converter.convert()

In [None]:
# Verify that the quantized model's input and output tensors have the expected dtype.
interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)

input_type = interpreter.get_input_details()[0]['dtype']
output_type = interpreter.get_output_details()[0]['dtype']

print('input: ', input_type)
print('output: ', output_type)

In [None]:
# Model export step.
# Directory that will hold the .tflite files: created (including parents) when
# missing, left as-is when it already exists.
tflite_models_dir = pathlib.Path("./tflite/mobile_tflite/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)


In [None]:
# Save the float model in TFLite format.
# Path.write_bytes opens, writes and closes the file, so no handle is left
# open (the previous open(...).write(...) never closed it).
tflite_model_file = tflite_models_dir/"mobile_original_tflite.tflite"
tflite_model_file.write_bytes(tflite_model)

# Save the quantized model in TFLite format.
tflite_model_quant_file = tflite_models_dir/"mobile_quantization_tflite.tflite"
tflite_model_quant_file.write_bytes(tflite_model_quant)

In [None]:
# Inference helper used when evaluating TFLite models
def run_tflite_model(tflite_file, test_image_indices):
  """Run a TFLite model on selected images of the global `test_x`.

  Args:
    tflite_file: path of the .tflite file (converted with `str` before loading).
    test_image_indices: sized iterable of indices into `test_x`.

  Returns:
    Integer NumPy array with one predicted class index (argmax of the model
    output) per requested image, in the order the indices were given.
  """
  # Initialize the interpreter
  interpreter = tf.lite.Interpreter(model_path=str(tflite_file))
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]

  # These do not change between images, so look them up once.
  input_dtype = input_details["dtype"]
  is_quantized = input_dtype == np.uint8
  if is_quantized:
    input_scale, input_zero_point = input_details["quantization"]
    uint8_range = np.iinfo(np.uint8)

  predictions = np.zeros((len(test_image_indices),), dtype=int)
  for i, test_image_index in enumerate(test_image_indices):
    test_image = test_x[test_image_index]

    # For a quantized input, map the float image onto the uint8 grid.
    # Round to the nearest level and clip before casting: a bare astype()
    # truncates toward zero (k - epsilon becomes k - 1) and wraps values that
    # fall outside the uint8 range.
    if is_quantized:
      test_image = test_image / input_scale + input_zero_point
      test_image = np.clip(np.rint(test_image), uint8_range.min, uint8_range.max)

    test_image = np.expand_dims(test_image, axis=0).astype(input_dtype)
    interpreter.set_tensor(input_details["index"], test_image)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details["index"])[0]

    predictions[i] = output.argmax()

  return predictions

In [None]:
def evaluate_model(tflite_file, model_type):
  """Print the accuracy of a TFLite model over the whole global test set.

  Args:
    tflite_file: path of the .tflite file to evaluate.
    model_type: label used in the printed message.
  """
  sample_count = len(test_x)
  predictions = run_tflite_model(tflite_file, range(test_x.shape[0]))

  # Compare against the flattened labels and convert the hit count to a percentage.
  correct = np.sum(test_y.reshape(-1) == predictions)
  accuracy = (correct * 100) / sample_count

  print('%s model accuracy is %.4f%% (Number of test samples=%d)' % (
      model_type, accuracy, sample_count))

In [None]:
# Timing helper used for the FPS measurement
def run_tflite_time(tflite_file, test_image_indices):
  """Measure per-image inference time of a TFLite model on the global `test_x`.

  Args:
    tflite_file: path of the .tflite file (converted with `str` before loading).
    test_image_indices: sized iterable of indices into `test_x`.

  Returns:
    Float NumPy array holding, for each requested image, the elapsed seconds
    for set_tensor + invoke + get_tensor. (The array is named `fps` for
    historical reasons; the caller inverts the mean to obtain frames/second.)
  """
  interpreter = tf.lite.Interpreter(model_path=str(tflite_file))
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]

  # These do not change between images, so look them up once.
  input_dtype = input_details["dtype"]
  is_quantized = input_dtype == np.uint8
  if is_quantized:
    input_scale, input_zero_point = input_details["quantization"]
    uint8_range = np.iinfo(np.uint8)

  fps = np.zeros((len(test_image_indices),))
  for i, test_image_index in enumerate(test_image_indices):
    test_image = test_x[test_image_index]

    # Input preparation is kept entirely outside the timed region so that only
    # the interpreter calls are measured. Round and clip before the uint8
    # cast; a bare astype() truncates toward zero.
    if is_quantized:
      test_image = test_image / input_scale + input_zero_point
      test_image = np.clip(np.rint(test_image), uint8_range.min, uint8_range.max)
    test_image = np.expand_dims(test_image, axis=0).astype(input_dtype)

    # perf_counter is a monotonic high-resolution clock meant for timing short
    # intervals, unlike the wall-clock time.time().
    start = time.perf_counter()
    interpreter.set_tensor(input_details["index"], test_image)
    interpreter.invoke()
    interpreter.get_tensor(output_details["index"])
    fps[i] = time.perf_counter() - start

  return fps

In [None]:
def test_model_time(tflite_file, model_type):
  """Print mean latency and FPS of a TFLite model over the whole global test set.

  Args:
    tflite_file: path of the .tflite file to time.
    model_type: label used in the printed summary line.

  Prints the mean seconds per image, its inverse (frames per second), and a
  labelled summary line.
  """
  test_image_indices = range(test_x.shape[0])
  fps = run_tflite_time(tflite_file, test_image_indices)

  # run_tflite_time returns seconds per image; compute the mean once.
  mean_latency = np.mean(fps)
  frames_per_second = 1 / mean_latency

  print(mean_latency)
  print(frames_per_second)
  # FPS is a rate, not a percentage: the stray '%%' after the value was removed.
  print('%s model FPS is %.4f (Number of test samples=%d)' % (model_type, frames_per_second, len(test_x)))

## Pruning 진행

In [None]:
# Reload the trained baseline and write an optimizer-free .h5 copy to a temp file.
model = tf.keras.models.load_model('./best_model/mobile_original_checkpoint')
keras_fd, keras_file = tempfile.mkstemp('.h5')
# mkstemp returns an already-open OS file descriptor; only the path is needed
# here, so close it right away instead of leaking it.
os.close(keras_fd)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)

# Shared settings for the pruning runs.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

batch_size = 32
epochs = 30
validation_split = 0.1 

# Number of optimizer steps over all epochs, given the images left after the
# validation split.
num_images = train_x.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

In [None]:
# Sparsity = 0.2: prune, then fine-tune.
pruning_params_2 = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.2,
                                                               begin_step=0,
                                                               end_step=-1,
                                                               frequency = 100
                                                               )
}

# Wrap a freshly loaded copy of the baseline instead of the shared `model`.
# prune_low_magnitude wraps the layers of the model it receives rather than
# copying them (NOTE(review): confirm for the installed tfmot version), so
# wrapping the same object in several runs would let one run's pruning and
# fine-tuning leak into the others.
base_model_2 = tf.keras.models.load_model('./best_model/mobile_original_checkpoint')
mobile_pruning_pruning_2 = prune_low_magnitude(base_model_2, **pruning_params_2)

# The network's last layer already applies an activation, so its outputs are
# not logits; from_logits must therefore be False.
mobile_pruning_pruning_2.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

logdir = tempfile.mkdtemp()

# UpdatePruningStep is passed to fit() together with the summary logger.
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

mobile_pruning_pruning_2.fit(train_x, train_y,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)

In [None]:
# Sparsity = 0.4: prune, then fine-tune.
pruning_params_4 = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.4,
                                                               begin_step=0,
                                                               end_step=-1,
                                                               frequency = 100
                                                               )
}

# Wrap a freshly loaded copy of the baseline instead of the shared `model`.
# prune_low_magnitude wraps the layers of the model it receives rather than
# copying them (NOTE(review): confirm for the installed tfmot version), so
# wrapping the same object in several runs would let one run's pruning and
# fine-tuning leak into the others.
base_model_4 = tf.keras.models.load_model('./best_model/mobile_original_checkpoint')
mobile_pruning_pruning_4 = prune_low_magnitude(base_model_4, **pruning_params_4)

# The network's last layer already applies an activation, so its outputs are
# not logits; from_logits must therefore be False.
mobile_pruning_pruning_4.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

logdir = tempfile.mkdtemp()

# UpdatePruningStep is passed to fit() together with the summary logger.
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

mobile_pruning_pruning_4.fit(train_x, train_y,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)

In [None]:
# Sparsity = 0.6: prune, then fine-tune.
pruning_params_6 = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.6,
                                                               begin_step=0,
                                                               end_step=-1,
                                                               frequency = 100
                                                               )
}

# Wrap a freshly loaded copy of the baseline instead of the shared `model`.
# prune_low_magnitude wraps the layers of the model it receives rather than
# copying them (NOTE(review): confirm for the installed tfmot version), so
# wrapping the same object in several runs would let one run's pruning and
# fine-tuning leak into the others.
base_model_6 = tf.keras.models.load_model('./best_model/mobile_original_checkpoint')
mobile_pruning_pruning_6 = prune_low_magnitude(base_model_6, **pruning_params_6)

# The network's last layer already applies an activation, so its outputs are
# not logits; from_logits must therefore be False.
mobile_pruning_pruning_6.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

logdir = tempfile.mkdtemp()

# UpdatePruningStep is passed to fit() together with the summary logger.
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

mobile_pruning_pruning_6.fit(train_x, train_y,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)

In [None]:
# Sparsity = 0.8: prune, then fine-tune.
pruning_params_8 = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.8,
                                                               begin_step=0,
                                                               end_step=-1,
                                                               frequency = 100
                                                               )
}

# Wrap a freshly loaded copy of the baseline instead of the shared `model`.
# prune_low_magnitude wraps the layers of the model it receives rather than
# copying them (NOTE(review): confirm for the installed tfmot version), so
# wrapping the same object in several runs would let one run's pruning and
# fine-tuning leak into the others.
base_model_8 = tf.keras.models.load_model('./best_model/mobile_original_checkpoint')
mobile_pruning_pruning_8 = prune_low_magnitude(base_model_8, **pruning_params_8)

# The network's last layer already applies an activation, so its outputs are
# not logits; from_logits must therefore be False.
mobile_pruning_pruning_8.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

logdir = tempfile.mkdtemp()

# UpdatePruningStep is passed to fit() together with the summary logger.
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

mobile_pruning_pruning_8.fit(train_x, train_y,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)

In [None]:
# Pass each fine-tuned model through strip_pruning to obtain the models that
# are exported below.
(
    mobile_pruning_pruning_2_eport,
    mobile_pruning_pruning_4_eport,
    mobile_pruning_pruning_6_eport,
    mobile_pruning_pruning_8_eport,
) = (
    tfmot.sparsity.keras.strip_pruning(pruned)
    for pruned in (
        mobile_pruning_pruning_2,
        mobile_pruning_pruning_4,
        mobile_pruning_pruning_6,
        mobile_pruning_pruning_8,
    )
)

In [None]:
# Convert each stripped model to TFLite with no optimization options set.
# `converter` is rebound for every model; only the converted bytes are kept.
converter = tf.lite.TFLiteConverter.from_keras_model(mobile_pruning_pruning_2_eport)
mobile_pruning_pruning_2_tflite = converter.convert()
converter = tf.lite.TFLiteConverter.from_keras_model(mobile_pruning_pruning_4_eport)
mobile_pruning_pruning_4_tflite = converter.convert()
converter = tf.lite.TFLiteConverter.from_keras_model(mobile_pruning_pruning_6_eport)
mobile_pruning_pruning_6_tflite = converter.convert()
converter = tf.lite.TFLiteConverter.from_keras_model(mobile_pruning_pruning_8_eport)
mobile_pruning_pruning_8_tflite = converter.convert()

In [None]:
# Write each pruned TFLite model into the export directory.
# Path.write_bytes opens, writes and closes the file, so no handle is left
# open (the previous open(...).write(...) never closed it).
mobile_pruning_pruning_2_file = tflite_models_dir/"mobile_pruning_pruning_2_tflite.tflite"
mobile_pruning_pruning_2_file.write_bytes(mobile_pruning_pruning_2_tflite)

mobile_pruning_pruning_4_file = tflite_models_dir/"mobile_pruning_pruning_4_tflite.tflite"
mobile_pruning_pruning_4_file.write_bytes(mobile_pruning_pruning_4_tflite)

mobile_pruning_pruning_6_file = tflite_models_dir/"mobile_pruning_pruning_6_tflite.tflite"
mobile_pruning_pruning_6_file.write_bytes(mobile_pruning_pruning_6_tflite)

mobile_pruning_pruning_8_file = tflite_models_dir/"mobile_pruning_pruning_8_tflite.tflite"
mobile_pruning_pruning_8_file.write_bytes(mobile_pruning_pruning_8_tflite)

In [None]:
# Compressed size of a model file, in bytes
def get_gzipped_model_size(file):
  """Print and return the size in bytes of `file` after zip (DEFLATE) compression.

  Args:
    file: path of the file to compress.

  Returns:
    Size in bytes of a temporary zip archive containing only `file`.
    The archive is deleted before returning.
  """
  import os
  import zipfile

  # mkstemp returns an already-open descriptor; close it so it is not leaked
  # and so ZipFile can reopen the path on every platform.
  fd, zipped_file = tempfile.mkstemp('.zip')
  os.close(fd)
  try:
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
      f.write(file)
    size = os.path.getsize(zipped_file)
  finally:
    # Do not leave one temporary archive behind per call.
    os.remove(zipped_file)

  print(size)
  return size

# 결과창

In [None]:
# Report the test-set accuracy of every exported TFLite model.
for tflite_path, tag in (
    (tflite_model_file, "Float"),
    (tflite_model_quant_file, "Quantized"),
    (mobile_pruning_pruning_2_file, "Pruning_2"),
    (mobile_pruning_pruning_4_file, "Pruning_4"),
    (mobile_pruning_pruning_6_file, "Pruning_6"),
    (mobile_pruning_pruning_8_file, "Pruning_8"),
):
  evaluate_model(tflite_path, model_type=tag)

In [None]:
# Report the inference speed of every exported TFLite model.
for tflite_path, tag in (
    (tflite_model_file, "Float"),
    (tflite_model_quant_file, "Quantized"),
    (mobile_pruning_pruning_2_file, "Pruning_2"),
    (mobile_pruning_pruning_4_file, "Pruning_4"),
    (mobile_pruning_pruning_6_file, "Pruning_6"),
    (mobile_pruning_pruning_8_file, "Pruning_8"),
):
  test_model_time(tflite_path, model_type=tag)

In [None]:
# Report the compressed size of every exported TFLite model.
for tflite_path in (
    tflite_model_file,
    tflite_model_quant_file,
    mobile_pruning_pruning_2_file,
    mobile_pruning_pruning_4_file,
    mobile_pruning_pruning_6_file,
    mobile_pruning_pruning_8_file,
):
  get_gzipped_model_size(tflite_path)