### ResNet and ImageNet

This template implements the following:
1. Reading an ImageFolder-style dataset, preprocessing it, and splitting it into batches
2. Loading and saving the model
3. Wrapping model training and testing in reusable functions

In [1]:
using Flux, Metalhead, Statistics
using Flux: onehotbatch, onecold, logitcrossentropy, throttle, flatten
using Metalhead: trainimgs
using Parameters: @with_kw
using Images: channelview
using Statistics: mean
using Base.Iterators: partition
using CUDAapi

In [2]:
using CUDAapi, CUDAdrv, CUDAnative
gpu_id = 0  ## set < 0 for no cuda, >= 0 for using a specific device (if available)

if has_cuda_gpu() && gpu_id >=0
    device!(gpu_id)
    device = Flux.gpu
    @info "Training on GPU-$(gpu_id)"
else
    device = Flux.cpu
    @info "Training on CPU"
end

┌ Info: Training on GPU-0
└ @ Main In[2]:7


In [3]:
using Parameters: @with_kw
@with_kw mutable struct Args
    batch_size::Int = 64
    lr::Float64 = 5e-5
    epochs::Int = 10
    patience::Int = 5
    data_workers::Int = 4
    train_data_dir::String = "/home/zhangzhi/Data/ImageNet2012/train"
    val_data_dir::String = "/home/zhangzhi/Data/ImageNet2012/train"  # NOTE: same directory as train_data_dir; point this at a held-out split for real validation
end

Args

In [4]:
args = Args()

Args
  batch_size: Int64 64
  lr: Float64 5.0e-5
  epochs: Int64 10
  patience: Int64 5
  data_workers: Int64 4
  train_data_dir: String "/home/zhangzhi/Data/ImageNet2012/train"
  val_data_dir: String "/home/zhangzhi/Data/ImageNet2012/train"
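
Because `Args` can be constructed with keywords, individual hyperparameters can be overridden without restating the defaults. The sketch below illustrates the idea using `Base.@kwdef` from the standard library as a stand-in for `Parameters.@with_kw` (an assumption, made so the snippet runs without extra packages). `DemoArgs` is a hypothetical name, not part of this notebook.

```julia
# Base.@kwdef is used here as a stand-in for Parameters.@with_kw.
Base.@kwdef mutable struct DemoArgs
    batch_size::Int = 64
    lr::Float64 = 5e-5
    epochs::Int = 10
end

# Override only the fields that differ from the defaults.
small = DemoArgs(batch_size = 32, epochs = 2)

# The struct is mutable, so fields can also be changed afterwards.
small.lr = 1e-4
```

The same keyword style should work with the `Args` struct above, for example for a short smoke-test run with fewer epochs.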


Following PyTorch's approach, multiple workers read the dataset, preprocess it, and split it into batches. This logic is packaged in a separate file, `dataset.jl`.
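
The contents of `dataset.jl` are not shown here, so the following is only a minimal sketch of the general idea, not the actual implementation. It splits a file list into batches and lets a background task fill a small buffer while the consumer iterates. It uses a single task and a `Channel` rather than separate worker processes, which is a simplification. The file names and `load_stub` are hypothetical placeholders.

```julia
using Base.Iterators: partition

# Hypothetical file list standing in for the ImageNet image paths.
files = ["img_$(i).JPEG" for i in 1:10]
batch_size = 4

# Split the list into consecutive batches; the last one may be shorter.
batches = collect(partition(files, batch_size))

# Hypothetical stand-in for decoding and preprocessing one image.
load_stub(path) = length(path)

# A background task fills a buffer of two batches while the consumer iterates.
loader = Channel{Vector{Int}}(2) do ch
    for batch in batches
        put!(ch, map(load_stub, batch))
    end
end

loaded = collect(loader)

# Number of full batches if a trailing partial batch is dropped.
full_batches = div(length(files), batch_size)
```

The training log further down reports 20018 batches for 1281167 images at a batch size of 64, which equals `div(1281167, 64)`. This suggests, but does not confirm, that the loader keeps only full batches.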

In [5]:
include("dataset.jl")

In [6]:
train_dataset = ImagenetDataset(args.train_data_dir, args.data_workers, args.batch_size, imagenet_train_data_loader)
val_dataset = ImagenetDataset(args.val_data_dir, args.data_workers, args.batch_size, imagenet_val_data_loader)

┌ Info: Adding 4 new data workers...
└ @ Main /data/zhangzhi/julia/dataset.jl:190
┌ Info: Adding 4 new data workers...
└ @ Main /data/zhangzhi/julia/dataset.jl:190


ImagenetDataset("/home/zhangzhi/Data/ImageNet2012/train", 64, imagenet_val_data_loader, ["n01440764/n01440764_10026.JPEG", "n01440764/n01440764_10027.JPEG", "n01440764/n01440764_10029.JPEG", "n01440764/n01440764_10040.JPEG", "n01440764/n01440764_10042.JPEG", "n01440764/n01440764_10043.JPEG", "n01440764/n01440764_10048.JPEG", "n01440764/n01440764_10066.JPEG", "n01440764/n01440764_10074.JPEG", "n01440764/n01440764_1009.JPEG"  …  "n15075141/n15075141_9816.JPEG", "n15075141/n15075141_9819.JPEG", "n15075141/n15075141_9835.JPEG", "n15075141/n15075141_9855.JPEG", "n15075141/n15075141_9907.JPEG", "n15075141/n15075141_9915.JPEG", "n15075141/n15075141_9933.JPEG", "n15075141/n15075141_9942.JPEG", "n15075141/n15075141_999.JPEG", "n15075141/n15075141_9993.JPEG"], QueuePool([6, 7, 8, 9], RemoteChannel{Channel{Tuple}}(1, 1, 29), RemoteChannel{Channel{Tuple}}(1, 1, 30), RemoteChannel{Channel{Bool}}(1, 1, 31), 0, Dict{Int64,Any}()))

Define the ResNet model from Metalhead, dropping its last layer, and create an ADAM optimizer.

In [7]:
using Metalhead

resnet = ResNet()
model = Chain(resnet.layers[1:end-1]) |> device
Flux.trainmode!(model, true)
opt = ADAM(args.lr)
model.layers

(Chain(Conv((7, 7), 3=>64), MaxPool((3, 3), pad = (1, 1), stride = (2, 2)), Metalhead.ResidualBlock((Conv((1, 1), 64=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), Chain(Conv((1, 1), 64=>256), BatchNorm(256))), Metalhead.ResidualBlock((Conv((1, 1), 256=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), identity), Metalhead.ResidualBlock((Conv((1, 1), 256=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), identity), Metalhead.ResidualBlock((Conv((1, 1), 256=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), Chain(Conv((1, 1), 256=>512), BatchNorm(512))), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), identity), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Co
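
The cell above drops the last layer of the Metalhead model, and the training code uses `logitcrossentropy`, which expects unnormalized scores rather than probabilities. As a rough illustration of why the two go together, the sketch below re-implements both loss paths in plain Julia for a single sample. The `*_demo` functions are hypothetical illustrations, not Flux's own code, and Flux's versions additionally aggregate over a batch.

```julia
# Plain-Julia re-implementations for illustration; these are not Flux's own code.
softmax_demo(z) = exp.(z .- maximum(z)) ./ sum(exp.(z .- maximum(z)))
crossentropy_demo(p, y) = -sum(y .* log.(p))

# Log-sum-exp form: works directly on raw scores and avoids log(0).
function logitcrossentropy_demo(z, y)
    m = maximum(z)
    logp = z .- m .- log(sum(exp.(z .- m)))
    return -sum(y .* logp)
end

z = [2.0, -1.0, 0.5]      # raw scores for three classes
y = [1.0, 0.0, 0.0]       # one-hot target

a = crossentropy_demo(softmax_demo(z), y)
b = logitcrossentropy_demo(z, y)

# With no information (equal scores for 1000 classes) the loss is log(1000).
chance = logitcrossentropy_demo(zeros(1000), vcat(1.0, zeros(999)))
```

Here `a` and `b` agree up to rounding, and `chance` equals `log(1000)`, roughly 6.908. The first batch losses in the training log at the end of this notebook (about 6.9 to 7.2) sit close to that chance level.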

Wrap model training and testing in functions.

In [8]:
using BSON
using Tracker
using Statistics, Printf
using Flux.Optimise

function save_model(model, filename)
    model_state = Dict(
        :weights => Tracker.data.(params(model))
    )
    open(filename, "w") do io
        BSON.bson(io, model_state)
    end
end

function load_model!(model, filename)
    weights = BSON.load(filename)[:weights]
    Flux.loadparams!(model, weights)
    return model
end

@with_kw mutable struct State
    epoch::Int = 1
    train_loss_history = []
    val_loss_history = []
end

state = State()

# Run one optimisation step on a single minibatch and return its loss as a plain number
process_minibatch = (model, opt, x, y) -> begin
    # Move the batch to the selected device (GPU or CPU)
    x = x |> device
    y = y |> device
    loss(x, y) = logitcrossentropy(model(x), y)
    Flux.train!(loss, params(model), [(x, y)], opt)
    # Extra forward pass, after the update, to report the loss for this batch
    batch_loss = logitcrossentropy(model(x), y)
    @show batch_loss
    return Tracker.data(batch_loss |> cpu)
end


function train_epoch(model, opt)
    # Make sure a slot exists for this epoch, then reset its training loss history
    while length(state.train_loss_history) < state.epoch
        push!(state.train_loss_history, Float64[])
    end
    state.train_loss_history[state.epoch] = zeros(Float64, length(train_dataset))

    batch_idx = 1
    avg_batch_time = 0.0
    t_last = time()
    for (x, y) in train_dataset
        # Store training loss into loss history
        state.train_loss_history[state.epoch][batch_idx] = process_minibatch(model, opt, x, y)

        # Update average batch time
        t_now = time()
        avg_batch_time = .99*avg_batch_time + .01*(t_now - t_last)
        t_last = t_now

        # Calculate ETA
        time_left = avg_batch_time*(length(train_dataset) - batch_idx)
        hours = floor(Int,time_left/(60*60))
        minutes = floor(Int, (time_left - hours*60*60)/60)
        seconds = time_left - hours*60*60 - minutes*60
        eta = @sprintf("%dh%dm%ds", hours, minutes, seconds)

        # Show a smoothed loss approximation per-minibatch
        smoothed_loss = mean(state.train_loss_history[state.epoch][max(batch_idx-50,1):batch_idx])
        println(@sprintf(
            "[TRAIN %d - %d/%d]: avg loss: %.4f, avg time: %.2fs, ETA: %s ",
            state.epoch, batch_idx, length(train_dataset), smoothed_loss,
            avg_batch_time,  eta,
        ))

        batch_idx += 1
    end
end

function validate(model)
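    # Get the "fast model": a copy whose parameters are plain (untracked) arrays
    fast_model = Flux.mapleaves(Tracker.data, model)
    Flux.testmode!(fast_model, true)

    avg_loss = 0.0
    batch_idx = 1
    for (x, y) in val_dataset
        # Move the batch to the same device as the model
        x = x |> device
        y = y |> device

        # Push x through our fast model and calculate loss,
        # using the same loss as in training (logitcrossentropy on the model's raw outputs)
        y_hat = fast_model(x)
        avg_loss += cpu(logitcrossentropy(y_hat, y))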

        print(@sprintf(
            "\r[VAL %d - %d/%d]: %.2f",
            state.epoch, batch_idx, length(val_dataset), avg_loss/batch_idx,
        ))
        batch_idx += 1
    end
    avg_loss /= length(val_dataset)
    push!(state.val_loss_history, avg_loss)

    # Return the average loss for this epoch
    return avg_loss
end


function train(model, opt)
    # Initialize best_epoch to epoch 0, with Infinity loss
    best_epoch = (0, Inf)

    while state.epoch <= args.epochs
        # Early-stop if we don't improve after `args.patience` epochs
        if state.epoch > best_epoch[1] + args.patience
            @info("Losing patience at epoch $(state.epoch)!")
            break
        end

        # Train for an epoch
        train_epoch(model, opt)
        
        # Validate to see how much we've improved
        epoch_loss = validate(model)

        # Check to see if this epoch is the best we've seen so far
        if epoch_loss < best_epoch[2]
            best_epoch = (state.epoch, epoch_loss)
        end

        # Advance to the next epoch. Note: nothing is saved here; call
        # save_model inside the branch above to checkpoint the best weights.
        state.epoch += 1
    end
end

train (generic function with 1 method)
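
The patience logic in `train` can be checked in isolation. The following is a minimal sketch under stated assumptions: `early_stopping_demo` is a hypothetical helper, the validation losses are made up, and the `checkpoints` list only records where a call such as `save_model` could go.

```julia
# Hypothetical helper: returns the best epoch, its loss, and the epochs
# at which a checkpoint would have been written.
function early_stopping_demo(val_losses; patience = 5)
    best_epoch, best_loss = 0, Inf
    checkpoints = Int[]
    for (epoch, loss) in enumerate(val_losses)
        # Stop once more than `patience` epochs have passed since the best one.
        if epoch > best_epoch + patience
            break
        end
        if loss < best_loss
            best_epoch, best_loss = epoch, loss
            push!(checkpoints, epoch)   # stand-in for save_model(model, path)
        end
    end
    return best_epoch, best_loss, checkpoints
end

# Made-up validation losses: improvement stops after epoch 3.
losses = [2.0, 1.5, 1.2, 1.3, 1.25, 1.4, 1.21, 1.9]
best_epoch, best_loss, checkpoints = early_stopping_demo(losses; patience = 2)
```

With `patience = 2`, the loop stops before epoch 6 because epochs 4 and 5 did not improve on epoch 3, so the later values are never visited.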

In [None]:
# Train away, train away, train away |> 's/train/sail/ig'
@info("Beginning training run...")
train(model, opt)

┌ Info: Beginning training run...
└ @ Main In[9]:2
┌ Info: Creating IIS with 1281167 images
└ @ Main /data/zhangzhi/julia/dataset.jl:212


batch_loss = 6.9622507f0
[TRAIN 1 - 1/20018]: avg loss: 6.9623, avg time: 0.78s, ETA: 4h21m11s 
batch_loss = 6.9277363f0
[TRAIN 1 - 2/20018]: avg loss: 6.9450, avg time: 0.81s, ETA: 4h30m20s 
batch_loss = 6.942004f0
[TRAIN 1 - 3/20018]: avg loss: 6.9440, avg time: 0.84s, ETA: 4h38m37s 
batch_loss = 6.9768744f0
[TRAIN 1 - 4/20018]: avg loss: 6.9522, avg time: 0.86s, ETA: 4h45m42s 
batch_loss = 6.912876f0
[TRAIN 1 - 5/20018]: avg loss: 6.9443, avg time: 0.87s, ETA: 4h51m37s 
batch_loss = 7.0515375f0
[TRAIN 1 - 6/20018]: avg loss: 6.9622, avg time: 0.89s, ETA: 4h58m27s 
batch_loss = 7.0313864f0
[TRAIN 1 - 7/20018]: avg loss: 6.9721, avg time: 0.91s, ETA: 5h4m52s 
batch_loss = 7.167213f0
[TRAIN 1 - 8/20018]: avg loss: 6.9965, avg time: 0.93s, ETA: 5h11m43s 
batch_loss = 7.07827f0
[TRAIN 1 - 9/20018]: avg loss: 7.0056, avg time: 0.96s, ETA: 5h19m8s 
batch_loss = 7.0133095f0
[TRAIN 1 - 10/20018]: avg loss: 7.0063, avg time: 0.97s, ETA: 5h24m56s 
batch_loss = 7.131474f0
[TRAIN 1 - 11/20018]: 
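
In `train_epoch`, `avg_batch_time` starts at 0 and each new measurement gets a weight of only 0.01, so the early averages are pulled toward zero. That is consistent with the log above, where the average time and the ETA climb steadily over the first batches. One common remedy is the bias correction used in Adam, dividing by `1 - beta^t`. The sketch below is an illustration, not part of the notebook: `ema_corrected` and `format_eta` are hypothetical helpers, and the batch times are made up.

```julia
using Printf

# Exponential moving average with bias correction for the zero start.
function ema_corrected(samples; beta = 0.99)
    avg = 0.0
    out = Float64[]
    for (t, x) in enumerate(samples)
        avg = beta * avg + (1 - beta) * x
        push!(out, avg / (1 - beta^t))   # undo the pull toward the initial 0
    end
    return out
end

# Made-up batch times in seconds, all equal for easy checking.
times = fill(3.0, 5)
corrected = ema_corrected(times)

# Format a duration in the same "XhYmZs" style as the training loop's ETA string.
function format_eta(seconds_left)
    hours, rest = divrem(round(Int, seconds_left), 3600)
    minutes, seconds = divrem(rest, 60)
    return @sprintf("%dh%dm%ds", hours, minutes, seconds)
end

eta = format_eta(corrected[end] * 100)   # 100 batches remaining
```

For constant 3-second batches the corrected average is 3 seconds from the first step, whereas the uncorrected average after one batch would be 0.03 seconds.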