In [4]:
# breakhistf2v9.ipynb  v1.0.0 
# 
#  --------------------------------------------------
#  Hangzhou Domain Zones Technology Co., Ltd

#  Apache Licence 2.0       https://www.apache.org/licenses/LICENSE-2.0
#  --------------------------------------------------


import os
import tensorflow as tf
import numpy as np
import csv
import matplotlib.pyplot as plt
from tensorflow import keras
import neural_structured_learning as nsl
from tensorflow.python.keras.api._v2.keras import layers, optimizers, losses
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import TensorBoard

tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
  
%matplotlib inline

def read_csv(csvnamepath, filename):
     

    # read from csv file
    images, labels = [], []
    with open(os.path.join(csvnamepath, filename)) as f:
        reader = csv.reader(f)
        for row in reader:
            img, label = row
            label = int(label)

            images.append(img)
            labels.append(label)

    assert len(images) == len(labels)

    return images, labels

 
def load_breakhis(dirc, mode='train'):
    classdir2label={}
    filename="tf2breakhisCSV"
    filedirs = os.listdir( dirc)
    for filedir in filedirs:
        if not os.path.isdir(os.path.join(dirc,filedir)):
            continue
        classdir2label[filedir]=len(classdir2label.keys())
    

     
    images, labels = read_csv(os.path.join(os.path.abspath('.'),'tf2breakhis'), filename )  
    if mode == 'train':  # 60%
        images = images[:int(0.6 * len(images))]
        labels = labels[:int(0.6 * len(labels))]
    elif mode == 'val':  # 20% = 60%->80%
        images = images[int(0.6 * len(images)):int(0.8 * len(images))]
        labels = labels[int(0.6 * len(labels)):int(0.8 * len(labels))]
    else:  # 20% = 80%->100%
        images = images[int(0.8 * len(images)):]
        labels = labels[int(0.8 * len(labels)):]
    return images, labels, classdir2label
 
 

# 预处理的函数，复制过来。
@tf.function
def preprocess(x,y):
    # x: 图片的路径，y：图片的数字编码
   
    x = tf.io.read_file(x)
     
    x = tf.image.decode_jpeg(x, channels=3) # RGBA
    
    #x = tf.image.resize(x, [350, 230])
    x = tf.image.resize_with_crop_or_pad(x,896,896)   
  
    x = tf.image.resize(x, [224, 224])
    
    x = tf.image.random_flip_left_right(x)
    #plt.imshow(x)
  
    # x = tf.image.random_flip_up_down(x)
    #x = tf.image.random_crop(x, [224,224,3])
    #plt.imshow(x)
    
    # x: [0,255]=> -1~1
    x = tf.cast(x, dtype=tf.float32) / 255.
    # x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=8)

    return x, y


# 预处理的函数，复制过来。
@tf.function
def preprocess1(x,y):
    # x: 图片的路径，y：图片的数字编码
   
    x = tf.io.read_file(x)
     
    x = tf.image.decode_jpeg(x, channels=3) # RGBA
    
    #x = tf.image.resize(x, [350, 230])
    x = tf.image.resize_with_crop_or_pad(x,896,896)   
    x = tf.image.resize(x, [224, 224])
    # x = tf.image.random_flip_up_down(x)
    #x = tf.image.random_crop(x, [224,224,3])
      
    # x: [0,255]=> -1~1
    x = tf.cast(x, dtype=tf.float32) / 255.
    # x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=8)

    return x, y




 

batchsz = 32
epochnum =10
# creat train db   一般训练的时候需要shuffle。其它是不需要的。
images, labels, table = load_breakhis(os.path.join(os.path.abspath('.'),'tf2breakhis'), 'train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))  # 变成个Dataset对象。
db_train = db_train.shuffle(1000).repeat(2).map(preprocess).shuffle(1000).batch(batchsz)  # map函数图片路径变为内容。
 

# crate validation db
images2, labels2, table = load_breakhis(os.path.join(os.path.abspath('.'),'tf2breakhis'), 'val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.shuffle(1000).repeat(2).map(preprocess).shuffle(1000).batch(batchsz)
# create test db
images3, labels3, table = load_breakhis(os.path.join(os.path.abspath('.'),'tf2breakhis'), 'test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess1).batch(batchsz)

 
resnetdense = keras.Sequential([
    layers.Conv2D(16,5,3),
    layers.MaxPool2D(3,3),
    layers.ReLU(),
    layers.Conv2D(64,5,3),
    layers.MaxPool2D(2,2),
    layers.ReLU(),
    
    layers.Flatten(),
    #
    layers.Dense(128), 
    layers.Dropout(rate=0.5),
    
    
    
    layers.Dense(64),  
    layers.ReLU(),
    layers.Dense(8)
   ])


 

# 首先创建Resnet18
# resnet = ResNet(8)

resnetdense.build(input_shape=(batchsz, 224, 224, 3))
resnetdense.summary()

# monitor监听器, 连续5个验证准确率不增加，这个事情触发。
 
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    min_delta=0.001,
    patience=100

)

  

 
# 网络的装配。
resnetdense.compile(optimizer=optimizers.Adam(lr=1e-4),
               loss=losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])

# 完成标准的train，val, test;
# 标准的逻辑必须通过db_val挑选模型的参数，就需要提供一个earlystopping技术，

LOGDIR='log/breakhistf2v9' 
resnetdense.fit(db_train, validation_data=db_val, 
                        validation_freq=1, epochs=100, 
                        callbacks=[TensorBoard(log_dir=LOGDIR)])   # 1个epoch验证1次。触发了这个事情，提前停止了。
resnetdense.evaluate(db_test)
 
 

    
    
 

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_2 (Conv2D)            multiple                  1216      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 multiple                  0         
_________________________________________________________________
re_lu_3 (ReLU)               multiple                  0         
_________________________________________________________________
conv2d_3 (Conv2D)            multiple                  25664     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 multiple                  0         
_________________________________________________________________
re_lu_4 (ReLU)               multiple                  0         
_________________________________________________________________
flatten_1 (Flatten)          multiple                 



[1.0454127418994903, 0.60429835]