In [1]:
from utils import *
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from train import *

In [2]:
batch_size = 2
train_dataset = get_data(path = 'datasets/train',batch_size=batch_size)
val_dataset = get_data(path = 'datasets/valid',batch_size=batch_size)
labels = train_dataset.class_names

Found 2821 files belonging to 5 classes.
Found 818 files belonging to 5 classes.


In [3]:
MobileNetV2_model = tf.keras.applications.MobileNetV2(include_top = False,input_shape=(512,512,3),input_tensor=None)



In [4]:
status = True # Nếu gặp block_4_add thì chuyển status từ False sang True
for layer in MobileNetV2_model.layers:
    layer.trainable = status

In [5]:
model = predictor(backbone = MobileNetV2_model,num_classes = len(labels))

In [6]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 Conv1 (Conv2D)                 (None, 256, 256, 32  864         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 bn_Conv1 (BatchNormalization)  (None, 256, 256, 32  128         ['Conv1[0][0]']                  
                                )                                                             

                                                                                                  
 block_3_expand (Conv2D)        (None, 128, 128, 14  3456        ['block_2_add[0][0]']            
                                4)                                                                
                                                                                                  
 block_3_expand_BN (BatchNormal  (None, 128, 128, 14  576        ['block_3_expand[0][0]']         
 ization)                       4)                                                                
                                                                                                  
 block_3_expand_relu (ReLU)     (None, 128, 128, 14  0           ['block_3_expand_BN[0][0]']      
                                4)                                                                
                                                                                                  
 block_3_p

                                                                                                  
 block_6_depthwise_BN (BatchNor  (None, 32, 32, 192)  768        ['block_6_depthwise[0][0]']      
 malization)                                                                                      
                                                                                                  
 block_6_depthwise_relu (ReLU)  (None, 32, 32, 192)  0           ['block_6_depthwise_BN[0][0]']   
                                                                                                  
 block_6_project (Conv2D)       (None, 32, 32, 64)   12288       ['block_6_depthwise_relu[0][0]'] 
                                                                                                  
 block_6_project_BN (BatchNorma  (None, 32, 32, 64)  256         ['block_6_project[0][0]']        
 lization)                                                                                        
          

 lization)                                                                                        
                                                                                                  
 block_10_expand_relu (ReLU)    (None, 32, 32, 384)  0           ['block_10_expand_BN[0][0]']     
                                                                                                  
 block_10_depthwise (DepthwiseC  (None, 32, 32, 384)  3456       ['block_10_expand_relu[0][0]']   
 onv2D)                                                                                           
                                                                                                  
 block_10_depthwise_BN (BatchNo  (None, 32, 32, 384)  1536       ['block_10_depthwise[0][0]']     
 rmalization)                                                                                     
                                                                                                  
 block_10_

 alization)                                                                                       
                                                                                                  
 block_14_expand (Conv2D)       (None, 16, 16, 960)  153600      ['block_13_project_BN[0][0]']    
                                                                                                  
 block_14_expand_BN (BatchNorma  (None, 16, 16, 960)  3840       ['block_14_expand[0][0]']        
 lization)                                                                                        
                                                                                                  
 block_14_expand_relu (ReLU)    (None, 16, 16, 960)  0           ['block_14_expand_BN[0][0]']     
                                                                                                  
 block_14_depthwise (DepthwiseC  (None, 16, 16, 960)  8640       ['block_14_expand_relu[0][0]']   
 onv2D)   

 dropout (Dropout)              (None, 512)          0           ['dense_1[0][0]']                
                                                                                                  
 dense_2 (Dense)                (None, 5)            2565        ['dropout[0][0]']                
                                                                                                  
Total params: 338,330,693
Trainable params: 338,296,581
Non-trainable params: 34,112
__________________________________________________________________________________________________


In [7]:
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = tf.keras.metrics.Accuracy()
val_acc_metric = tf.keras.metrics.Accuracy()

In [8]:
train_dataset = train_dataset.shuffle(buffer_size=1024)
val_dataset = val_dataset

In [9]:
# Keep results for plotting
train_loss_results = []
train_accuracy_results = []

val_loss_results = []
val_accuracy_results = []

In [10]:
import time
from tqdm import tqdm
epochs = 100
start_time = time.time()
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    
    """
    Training
    """
    train_epoch_loss_avg = tf.keras.metrics.Mean()
    pbar = enumerate(train_dataset)
    pbar = tqdm(pbar, total=len(train_dataset), desc='Epoch',bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
    # Iterate over the batches of the dataset.
    for step,(x_batch_train,y_batch_train) in pbar:
#         y_batch_train = tf.keras.utils.to_categorical(y_batch_train,num_classes=len(labels))
        loss_value,logits,grads = train_on_epoch(x_batch_train,y_batch_train,model,loss_fn)   
#         with tf.GradientTape() as tape:
#             logits = model(x_batch_train, training=True)
#             loss_value = loss_fn(y_batch_train, logits)
#         grads = tape.gradient(loss_value, model.trainable_weights)
        
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        predictions = tf.argmax(logits, axis=1, output_type=tf.int32)
        # Update training metric.
        train_epoch_loss_avg.update_state(loss_value)
        train_acc_metric.update_state(y_batch_train, predictions)
        pbar.set_description("Epoch %i ( loss %.4f - acc %.4f )" % (epoch,train_epoch_loss_avg.result(),train_acc_metric.result()))
        pbar.refresh() # to show immediately the update
        
    train_loss_results.append(train_epoch_loss_avg.result())
    train_accuracy_results.append(train_acc_metric.result())
    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()
    
    """
    Validation
    """
    val_epoch_loss_avg = tf.keras.metrics.Mean()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
#         y_batch_val = tf.keras.utils.to_categorical(y_batch_val,num_classes=len(labels))
        val_logits,loss_value = val_on_epoch(x_batch_val,y_batch_val,model,loss_fn)
        
        #val_logits = model(x_batch_val, training=False)
        #loss_value = loss_fn(y_batch_val, val_logits)
        # Update val metrics
        
        val_acc_metric.update_state(y_batch_val, val_logits)
        val_epoch_loss_avg.update_state(loss_value)
    
    val_loss_results.append(val_epoch_loss_avg.result())
    val_accuracy_results.append(val_acc_metric.result())
    
    print("Validation loss %.4f - acc %.4f" % (val_epoch_loss_avg.result(),val_acc_metric.result()))
    val_acc_metric.reset_states()
    model.save_weights(f"./runs/trainig/weights/cp-{epoch}.ckpt")
    model.load_weights(f"./runs/trainig/weights/cp-{epoch}.ckpt")
print("Time taken: %.2fs" % (time.time() - start_time))


Start of epoch 0


Epoch 0 ( loss 0.7234 - acc 0.7703 ): 100%|██████████| 1411/1411 [03:47<00:00,  6.21it/s]                              


Validation loss 0.1322 - acc 0.9572

Start of epoch 1


Epoch 1 ( loss 0.1216 - acc 0.9553 ): 100%|██████████| 1411/1411 [03:41<00:00,  6.38it/s]                              


Validation loss 0.1201 - acc 0.9658

Start of epoch 2


Epoch 2 ( loss 0.0571 - acc 0.9780 ): 100%|██████████| 1411/1411 [03:40<00:00,  6.41it/s]                              


Validation loss 0.1171 - acc 0.9645

Start of epoch 3


Epoch 3 ( loss 0.0266 - acc 0.9922 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.44it/s]                              


Validation loss 0.0879 - acc 0.9731

Start of epoch 4


Epoch 4 ( loss 0.0368 - acc 0.9872 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.42it/s]                              


Validation loss 0.0905 - acc 0.9792

Start of epoch 5


Epoch 5 ( loss 0.0281 - acc 0.9918 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.43it/s]                              


Validation loss 0.0635 - acc 0.9841

Start of epoch 6


Epoch 6 ( loss 0.0112 - acc 0.9950 ): 100%|██████████| 1411/1411 [03:38<00:00,  6.45it/s]                              


Validation loss 0.0952 - acc 0.9731

Start of epoch 7


Epoch 7 ( loss 0.0062 - acc 0.9979 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.44it/s]                              


Validation loss 0.0811 - acc 0.9817

Start of epoch 8


Epoch 8 ( loss 0.0072 - acc 0.9968 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.44it/s]                              


Validation loss 0.0752 - acc 0.9804

Start of epoch 9


Epoch 9 ( loss 0.0043 - acc 0.9982 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.42it/s]                              


Validation loss 0.0725 - acc 0.9878

Start of epoch 10


Epoch 10 ( loss 0.0088 - acc 0.9965 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.43it/s]                             


Validation loss 0.1525 - acc 0.9658

Start of epoch 11


Epoch 11 ( loss 0.0148 - acc 0.9957 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.42it/s]                             


Validation loss 0.0642 - acc 0.9768

Start of epoch 12


Epoch 12 ( loss 0.0034 - acc 0.9989 ): 100%|██████████| 1411/1411 [03:44<00:00,  6.30it/s]                             


Validation loss 0.0693 - acc 0.9792

Start of epoch 13


Epoch 13 ( loss 0.0060 - acc 0.9993 ): 100%|██████████| 1411/1411 [03:41<00:00,  6.37it/s]                             


Validation loss 0.0691 - acc 0.9829

Start of epoch 14


Epoch 14 ( loss 0.0055 - acc 0.9993 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.43it/s]                             


Validation loss 0.0685 - acc 0.9829

Start of epoch 15


Epoch 15 ( loss 0.0035 - acc 0.9989 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.43it/s]                             


Validation loss 0.0699 - acc 0.9841

Start of epoch 16


Epoch 16 ( loss 0.0053 - acc 0.9982 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.44it/s]                             


Validation loss 0.0720 - acc 0.9792

Start of epoch 17


Epoch 17 ( loss 0.0013 - acc 1.0000 ): 100%|██████████| 1411/1411 [03:38<00:00,  6.45it/s]                             


Validation loss 0.0580 - acc 0.9829

Start of epoch 18


Epoch 18 ( loss 0.0054 - acc 0.9982 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.43it/s]                             


Validation loss 0.0458 - acc 0.9902

Start of epoch 19


Epoch 19 ( loss 0.0013 - acc 0.9996 ): 100%|██████████| 1411/1411 [03:40<00:00,  6.41it/s]                             


Validation loss 0.0919 - acc 0.9829

Start of epoch 20


Epoch 20 ( loss 0.0039 - acc 0.9986 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.43it/s]                             


Validation loss 0.0518 - acc 0.9866

Start of epoch 21


Epoch 21 ( loss 0.0034 - acc 0.9986 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.44it/s]                             


Validation loss 0.0605 - acc 0.9841

Start of epoch 22


Epoch 22 ( loss 0.0013 - acc 1.0000 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.42it/s]                             


Validation loss 0.0568 - acc 0.9878

Start of epoch 23


Epoch 23 ( loss 0.0009 - acc 1.0000 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.42it/s]                             


Validation loss 0.0590 - acc 0.9866

Start of epoch 24


Epoch 24 ( loss 0.0005 - acc 1.0000 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.43it/s]                             


Validation loss 0.0599 - acc 0.9841

Start of epoch 25


Epoch 25 ( loss 0.0004 - acc 1.0000 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.41it/s]                             


Validation loss 0.0523 - acc 0.9853

Start of epoch 26


Epoch 26 ( loss 0.0005 - acc 1.0000 ): 100%|██████████| 1411/1411 [03:43<00:00,  6.33it/s]                             


Validation loss 0.0808 - acc 0.9817

Start of epoch 27


Epoch 27 ( loss 0.0009 - acc 0.9996 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.43it/s]                             


Validation loss 0.0484 - acc 0.9890

Start of epoch 28


Epoch 28 ( loss 0.0014 - acc 1.0000 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.43it/s]                             


Validation loss 0.0533 - acc 0.9890

Start of epoch 29


Epoch 29 ( loss 0.0004 - acc 1.0000 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.41it/s]                             


Validation loss 0.0531 - acc 0.9890

Start of epoch 30


Epoch 30 ( loss 0.0003 - acc 1.0000 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.42it/s]                             


Validation loss 0.0693 - acc 0.9841

Start of epoch 31


Epoch 31 ( loss 0.0005 - acc 1.0000 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.43it/s]                             


Validation loss 0.0592 - acc 0.9866

Start of epoch 32


Epoch 32 ( loss 0.0046 - acc 0.9989 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.44it/s]                             


Validation loss 0.0710 - acc 0.9890

Start of epoch 33


Epoch 33 ( loss 0.0013 - acc 0.9993 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.43it/s]                             


Validation loss 0.0954 - acc 0.9804

Start of epoch 34


Epoch 34 ( loss 0.0032 - acc 0.9993 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.44it/s]                             


Validation loss 0.0627 - acc 0.9853

Start of epoch 35


Epoch 35 ( loss 0.0015 - acc 0.9993 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.42it/s]                             


Validation loss 0.0711 - acc 0.9841

Start of epoch 36


Epoch 36 ( loss 0.0005 - acc 0.9996 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.44it/s]                             


Validation loss 0.0825 - acc 0.9817

Start of epoch 37


Epoch 37 ( loss 0.0013 - acc 0.9996 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.43it/s]                             


Validation loss 0.0629 - acc 0.9841

Start of epoch 38


Epoch 38 ( loss 0.0012 - acc 0.9993 ): 100%|██████████| 1411/1411 [03:39<00:00,  6.44it/s]                             


Validation loss 0.0728 - acc 0.9841


UnknownError: Failed to WriteFile: ./runs/trainig/weights/cp-38.ckpt_temp/part-00000-of-00001.data-00000-of-00001.tempstate12880951803857840392 : There is not enough space on the disk.
; operation in progress [Op:SaveV2]