In [1]:
import oneflow as flow
import oneflow.typing as tp

In [2]:
BATCH_SIZE = 100

In [3]:
def lenet(data, train=False):
    initializer = flow.truncated_normal(0.1)
    conv1 = flow.layers.conv2d(
        data,
        32,
        5,
        padding="SAME",
        activation=flow.nn.relu,
        name="conv1",
        kernel_initializer=initializer,
    )
    pool1 = flow.nn.max_pool2d(
        conv1, ksize=2, strides=2, padding="SAME", name="pool1", data_format="NCHW"
    )
    conv2 = flow.layers.conv2d(
        pool1,
        64,
        5,
        padding="SAME",
        activation=flow.nn.relu,
        name="conv2",
        kernel_initializer=initializer,
    )
    pool2 = flow.nn.max_pool2d(
        conv2, ksize=2, strides=2, padding="SAME", name="pool2", data_format="NCHW"
    )
    reshape = flow.reshape(pool2, [pool2.shape[0], -1])
    hidden = flow.layers.dense(
        reshape,
        512,
        activation=flow.nn.relu,
        kernel_initializer=initializer,
        name="dense1",
    )
    if train:
        hidden = flow.nn.dropout(hidden, rate=0.5, name="dropout")
    return flow.layers.dense(hidden, 10, kernel_initializer=initializer, name="dense2")


In [4]:
#训练
@flow.global_function(type="train")
def train_job(
    images: tp.Numpy.Placeholder((BATCH_SIZE, 1, 28, 28), dtype=flow.float),
    labels: tp.Numpy.Placeholder((BATCH_SIZE,), dtype=flow.int32),
) -> tp.Numpy:
    with flow.scope.placement("gpu", "0:0"):
        logits = lenet(images, train=True)
        loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
            labels, logits, name="softmax_loss"
        )

    lr_scheduler = flow.optimizer.PiecewiseConstantScheduler([], [0.1])
    flow.optimizer.SGD(lr_scheduler, momentum=0).minimize(loss)
    return loss


In [5]:
flow.config.gpu_device_num(1)

In [6]:
check_point = flow.train.CheckPoint()
check_point.init()

In [7]:
   (train_images, train_labels), (test_images, test_labels) = flow.data.load_mnist(
        BATCH_SIZE, BATCH_SIZE
    )

File mnist.npz already exist, path: ./mnist.npz


In [8]:
for epoch in range(20):
    for i, (images, labels) in enumerate(zip(train_images, train_labels)):
        loss = train_job(images, labels)
        if i % 20 == 0:
            print(loss.mean())

5.4304714
1.1816862
0.42266613
0.3483231
0.23331955
0.24564157
0.2680526
0.23363036
0.24496683
0.12906724
0.2550579
0.081116155
0.17715953
0.11035161
0.10084554
0.1626459
0.19110979
0.07018511
0.17930636
0.047810417
0.083348
0.20465498
0.12702033
0.19150552
0.13384208
0.064903006
0.12788668
0.18227144
0.084598154
0.08224721
0.07146042
0.1441957
0.128395
0.026951086
0.040065285
0.16551672
0.095258825
0.09995119
0.08750807
0.06958059
0.071375236
0.04716398
0.029938051
0.04587161
0.052131426
0.11289297
0.08389154
0.05752547
0.11423657
0.061923917
0.089823395
0.069212444
0.082397126
0.097064696
0.049636483
0.03532106
0.09834539
0.14888397
0.053040933
0.06039332
0.054609302
0.0954063
0.057975076
0.054392677
0.055142924
0.11056551
0.090062626
0.024433298
0.07084119
0.051630173
0.082517155
0.017306574
0.017913405
0.05696125
0.034146875
0.07452651
0.05722817
0.040682297
0.1078422
0.04957526
0.024497861
0.09798545
0.07414208
0.08338656
0.07315427
0.013292475
0.02770225
0.090204656
0.05740983
0.

In [9]:
check_point.save("./model/lenet_models_1")