In [2]:
import tensorflow as tf
from tensorflow.keras import optimizers, metrics, layers, datasets,Sequential
from tensorflow import keras

In [3]:
class BasicBlock(layers.Layer):
    """Two-convolution residual block with an optional projection shortcut.

    The main path is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    Its output is added to the shortcut branch and passed through a final
    ReLU.

    Args:
        filter_num: Number of output channels of both 3x3 convolutions.
        stride: Stride of the first convolution. When it is not 1 the
            shortcut becomes a strided 1x1 convolution so both branches
            have matching shapes; otherwise the input is reused unchanged.

    NOTE(review): with ``stride == 1`` the shortcut is the raw input, so the
    element-wise add presumably requires the input to already have
    ``filter_num`` channels -- confirm against callers.
    """

    def __init__(self, filter_num, stride=1):
        # Initialise the Keras Layer base class.
        super().__init__()

        self.conv1 = layers.Conv2D(filter_num, kernel_size=[3, 3], strides=stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        # One ReLU layer is enough; it is only applied after the first
        # BatchNormalization (the final activation uses tf.nn.relu).
        self.relu = layers.Activation('relu')

        self.conv2 = layers.Conv2D(filter_num, kernel_size=[3, 3], strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()

        if stride != 1:
            # Projection shortcut: match the spatial size reduced by conv1.
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride, padding='same'))
        else:
            # Identity shortcut.
            self.downsample = lambda x: x

    def call(self, inputs, training=None):
        """Run the block on ``inputs``.

        Args:
            inputs: Input feature map.
            training: Forwarded to the BatchNormalization layers so they
                switch between batch statistics and moving averages.

        Returns:
            ReLU of the sum of the main path and the shortcut.
        """
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)

        # The shortcut holds no training-dependent layers, so no flag is needed.
        identity = self.downsample(inputs)

        output = layers.add([out, identity])
        output = tf.nn.relu(output)

        return output

class ResNet(keras.Model):
    """ResNet-style classifier assembled from four stages of BasicBlock.

    Args:
        layer_dims: Sequence of four ints giving the number of BasicBlocks
            in each stage, e.g. ``[2, 2, 2, 2]``.
        num_classes: Size of the final Dense layer (number of logits).
    """

    def __init__(self, layer_dims, num_classes=100):
        super().__init__()
        # Stem: 3x3 convolution (default padding), BatchNorm, ReLU and a
        # stride-1 max-pool that keeps the spatial size.
        self.stem = Sequential([layers.Conv2D(64, (3, 3), strides=(1, 1)),
                                layers.BatchNormalization(),
                                layers.Activation('relu'),
                                layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')])
        # Four stages; every stage after the first halves the resolution
        # with a stride-2 first block while the channel count doubles.
        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)
        # Stage output: [b, h, w, 512], assuming the default channels-last layout.
        self.avgpool = layers.GlobalAveragePooling2D()
        # After pooling: [b, 512]
        self.fc = layers.Dense(num_classes)

    def call(self, inputs, training=None):
        """Compute class logits for a batch of images.

        Args:
            inputs: Batch of input images.
            training: Forwarded to the stem and every stage so their
                BatchNormalization layers run in the requested mode.

        Returns:
            Un-normalised logits of shape [b, num_classes].
        """
        x = self.stem(inputs, training=training)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x, training=training)
        # [b, c]
        x = self.avgpool(x)
        # [b, num_classes]
        return self.fc(x)

    def build_resblock(self, filter_num, blocks, stride=1):
        """Build one stage made of ``blocks`` BasicBlocks.

        Args:
            filter_num: Channel count used by every block in the stage.
            blocks: Number of BasicBlocks in the stage.
            stride: Stride of the first block only; the rest use stride 1.

        Returns:
            A Sequential model containing the stacked blocks.
        """
        res_blocks = Sequential()
        # Only the first block may downsample.
        res_blocks.add(BasicBlock(filter_num, stride=stride))

        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num, stride=1))

        return res_blocks

def resnet18(num_classes=100):
    """Return a ResNet with two BasicBlocks in each of its four stages.

    Args:
        num_classes: Number of output logits. Defaults to 100, the value
            the model used before this parameter existed.
    """
    return ResNet([2, 2, 2, 2], num_classes=num_classes)

In [4]:
import os
# Ask TensorFlow's native logger to hide INFO and WARNING messages.
# NOTE(review): tensorflow has already been imported at this point; this
# variable is usually set before the first import to take effect -- confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [5]:
# Load MNIST: (x, y) is the training split, (x_val, y_val) the held-out split
# used for evaluation.
(x,y),(x_val,y_val) = datasets.mnist.load_data()

In [6]:
# Add a trailing channel axis so the images match the model's
# [batch, 28, 28, 1] input. Both splits are reshaped so training and
# evaluation data have the same rank.
x = tf.reshape(x, [-1, 28, 28, 1])
x_val = tf.reshape(x_val, [-1, 28, 28, 1])

In [8]:
def preprocess(x, y):
    """Convert one (image, label) pair to the dtypes and range used for training.

    Args:
        x: Image tensor with values on a 0..255 scale.
        y: Label tensor.

    Returns:
        Tuple of the image as float32 rescaled to [-1, 1] and the label as int32.
    """
    pixels = tf.cast(x, dtype=tf.float32)
    # Map 0..255 onto -1..1.
    scaled = 2 * pixels / 255. - 1
    labels = tf.cast(y, dtype=tf.int32)
    return scaled, labels

In [9]:
# Training pipeline: shuffle, normalise, then batch.
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(1000).map(preprocess).batch(128)

# Evaluation pipeline: built from the held-out split, not from train_db.
# The reshape guarantees the [N, 28, 28, 1] layout the model expects; it
# is a no-op if x_val already has the channel axis.
test_db = tf.data.Dataset.from_tensor_slices((tf.reshape(x_val, [-1, 28, 28, 1]), y_val))
test_db = test_db.map(preprocess).batch(128)

In [11]:
# Instantiate the network and create its weights for 28x28 single-channel
# images with an unspecified batch size.
model = resnet18()
model.build(input_shape=(None, 28, 28, 1))

# Adam optimizer with a fixed learning rate.
optimizer = optimizers.Adam(learning_rate=0.001)

In [None]:
# Custom training loop: one pass over train_db per epoch, followed by an
# accuracy measurement over test_db.
for epoch in range(20):
    # Batch variables get their own names so the dataset tensors `x` and
    # `y` are not overwritten by the loop.
    for step, (x_batch, y_batch) in enumerate(train_db):
        with tf.GradientTape() as tape:
            # training=True: BatchNormalization uses batch statistics and
            # updates its moving averages.
            logits = model(x_batch, training=True)
            # One-hot depth follows the model's output width instead of a
            # hard-coded class count.
            y_onehot = tf.one_hot(y_batch, depth=logits.shape[-1])
            loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
            loss = tf.reduce_mean(loss)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        if step % 100 == 0:
            print(epoch, step, 'loss', float(loss))

    # Evaluation: count correct predictions over the whole test set.
    total_num = 0
    total_correct = 0
    for x_batch, y_batch in test_db:
        # training=False: BatchNormalization uses its moving averages.
        logits = model(x_batch, training=False)
        # argmax of the logits equals argmax of their softmax, so the
        # softmax step is not needed for the prediction.
        pred = tf.cast(tf.argmax(logits, axis=1), dtype=tf.int32)
        correct = tf.reduce_sum(tf.cast(tf.equal(pred, y_batch), dtype=tf.int32))
        total_correct += int(correct)
        total_num += x_batch.shape[0]
    acc = total_correct / total_num
    print(epoch, 'acc', acc)
    