In [1]:
# Load the package manager and activate the project at this path, so the packages
# loaded in the following cells come from that project's environment.
# NOTE(review): absolute, user-specific path -- point it at your own XLA.jl checkout.
using Pkg
Pkg.activate("/home/wkh/XLA.jl")

"/home/wkh/XLA.jl/Project.toml"

In [None]:
# Load Distributions, re-resolve the active environment, then load Distributions again.
# NOTE(review): presumably a workaround for the "failed to load from a cache file"
# warning for Distributions that appears in the next cell's output (which itself
# suggests trying `Pkg.resolve()`) -- confirm whether this cell is still needed.
using Distributions
Pkg.resolve()
using Distributions

In [2]:
using TensorFlow, XLA, Flux, Zygote, Printf

┌ Info: Precompiling TensorFlow [1d978283-2c37-5f34-9a8e-e9c0ece82495]
└ @ Base loading.jl:1186
└ @ TensorFlow ~/.julia/packages/TensorFlow/eu9qM/src/TensorFlow.jl:3
│ - If you have Distributions checked out for development and have
│   added Test as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with Distributions
│   exception = ErrorException("Required dependency Distributions [31c24e10-a181-5473-b8eb-7969acd0382f] failed to load from a cache file.")
└ @ Base loading.jl:969
└ @ TensorFlow /home/wkh/.julia/packages/TensorFlow/eu9qM/src/TensorFlow.jl:3
┌ Info: Recompiling stale cache file /home/wkh/.julia/compiled/v1.1/Distributions/xILW0.ji for Distributions [31c24e10-a181-5473-b8eb-7969acd0382f]
└ @ Base loading.jl:1184
│ - If you have Distributions checked out for development and have
│   added Test as a dependency but haven't updated your primary
│   environment's manifest file, try 

In [3]:
# Load the three example source files from the XLA.jl checkout, in this order.
# NOTE(review): absolute, user-specific paths.  Helpers used in later cells (e.g.
# `resnet50`, `map_to_tpu`, `unpack_pixels`, `make_onehot`) are not defined in this
# notebook and are presumably provided by these files -- confirm.
include("/home/wkh/XLA.jl/examples/resnet50.jl")
include("/home/wkh/XLA.jl/examples/preprocessing_utils.jl")
include("/home/wkh/XLA.jl/examples/model_utils.jl")

epoch_loop (generic function with 1 method)

In [4]:
# Build the model, then derive `tpu_model`, which is what gets passed to `train_loop`
# when the training loop is compiled further down.
# NOTE(review): `resnet50` and `map_to_tpu` are defined outside this notebook;
# presumably `map_to_tpu` converts the model's arrays into XRTArrays -- confirm.
model = resnet50()
# The trailing semicolon suppresses the notebook's display of the result.
tpu_model = map_to_tpu(model);

In [5]:
"""
    get_minibatch_data(::Val{batch_size})

Read one minibatch from the XLA infeed and return it as the tuple `(x, y)`.

The infeed is declared to deliver two flat `UInt32` vectors: one of length
`224*224*batch_size` (the image data) and one of length `batch_size` (the labels).
The image vector is reshaped to `(224, 224, batch_size)` and passed through
`unpack_pixels`; the label vector is passed through `make_onehot`.

`batch_size` is taken from the `Val` type parameter so that the array shapes are
part of the types handed to the infeed.
"""
function get_minibatch_data(::Val{batch_size}) where {batch_size}
    # Declare what the infeed will hand us.  The sizes here must match what the host
    # actually feeds; a mismatch is not reported cleanly, so get them right.
    infeed = XLA.HloInfeed(Tuple{
        XRTArray{UInt32, (224*224*batch_size,), 1},
        XRTArray{UInt32, (batch_size,), 1},
    })

    # Perform the infeed read; the second element of the result is not needed here.
    token = XLA.HloAfterAll()()
    (flat_pixels, raw_labels), _ = infeed(token)

    # The image data arrives as a 1-dimensional array: restore its (224, 224, batch)
    # layout, then do the pixel unpacking / channel normalization.
    images = unpack_pixels(reshape(flat_pixels, (224, 224, batch_size)))

    # Turn the labels into their (dense) onehot representation.
    # Commented-out alternative formulation, kept from the original:
    #y = convert(XRTArray{Float32}, Flux.OneHotMatrix(1000, convert(XRTArray{Int64}, y)))
    onehot_labels = make_onehot(raw_labels)

    return images, onehot_labels
end

get_minibatch_data (generic function with 1 method)

In [6]:
# Optimizer state for plain gradient descent.  Instances are passed as the first
# argument of `update!(opt::SGD, model, Δ)`, which forwards `η` to the update step.
struct SGD
    # Learning rate; the only data this optimizer needs to bundle with itself.
    # Typed as a 0-dimensional (scalar) Float32 XRTArray.
    η::XRTArray{Float32,(),0}
end

"""
    update!(model::XRTArray, Δ::XRTArray, η)

Array leaf of the update recursion: return `model - (Δ .* η)`.

Despite the `!` in the name, nothing is modified in place; the result is a new
value that the caller must use in place of `model`.
"""
function update!(model::XRTArray, Δ::XRTArray, η)
    scaled_step = Δ .* η
    return model - scaled_step
end

"""
    update!(model, Δ::Nothing, η)

No-gradient case of the update recursion: when the gradient for this node is
`nothing`, there is nothing to apply, so `model` is returned unchanged.
"""
function update!(model, Δ::Nothing, η)
    return model
end

"""
    update!(model, Δ, η)

Structural case of the update recursion.

Walks `model` and `Δ` field by field (by position, via `getfield`), calling
`update!` on each pair of fields, and returns a value shaped like `model` built
from the results.  More specific methods handle the cases where `model` and `Δ`
are both `XRTArray`s (the actual parameter step) or where `Δ` is `nothing`.

- A `model` with no fields is a leaf and is returned unchanged.
- A `Tuple` comes back as a tuple of the updated fields.
- Anything else is rebuilt by calling `typeof(model)` with the updated fields as
  positional arguments.

Nothing is mutated; the caller must use the returned value.
"""
function update!(model, Δ, η)
    field_count = nfields(model)

    # Leaf without fields: there is nothing to descend into.
    field_count == 0 && return model

    # Mutation (e.g. push!-ing results onto a list) is not available here, so the
    # updated fields are produced functionally as a tuple of known length.
    updated = ntuple(
        idx -> update!(getfield(model, idx), getfield(Δ, idx), η),
        Val(field_count),
    )

    # Tuples are returned as-is; other containers are reconstructed from their fields.
    return model isa Tuple ? updated : typeof(model)(updated...)
end

"""
    update!(opt::SGD, model, Δ)

Entry point for an `SGD` update step: applies the gradient `Δ` to `model` using the
optimizer's learning rate `opt.η`, by delegating to `update!(model, Δ, η)`, and
returns the result.
"""
function update!(opt::SGD, model, Δ)
    learning_rate = opt.η
    return update!(model, Δ, learning_rate)
end

update! (generic function with 4 methods)

In [7]:
# Define our training loop.
#
# Runs `nbatches` iterations of: read a minibatch from the infeed, compute the loss and
# its pullback with Zygote, apply an SGD step to `model`, and send the loss out through
# the XLA outfeed.
#
# Arguments:
#   ::Val{batch_size} - minibatch size as a type parameter; forwarded to
#                       `get_minibatch_data`.
#   model             - callable as `model(x)`; its output is fed to
#                       `logitcrossentropy` together with the labels.
#   nbatches          - iteration counter; it is compared against `XRTArray(0)` and
#                       decremented by `XRTArray(1)`, so it is expected to be an
#                       XRTArray scalar.
#   η                 - learning rate, handed to the `SGD` constructor.
#
# Returns the model as it stands after the last update.
function train_loop(::Val{batch_size}, model, nbatches, η) where {batch_size}
    # Wrap the learning rate in the optimizer object; SGD carries no other state.
    opt = SGD(η)

    # Run until nbatches is zero
    while nbatches > XRTArray(0)
        # Get next minibatch of data: an (x, y) tuple
        mb_data = get_minibatch_data(Val(batch_size))

        # Let block to fend off the inference demons: `x` and `y` are bound as fresh
        # locals for the closure below to capture.
        loss, back = let x = mb_data[1], y = mb_data[2]
            # Calculate forward pass to get loss, and compile backwards pass
            # to get the updates to our model weights.  Note that the closure's own
            # `model` parameter shadows the outer `model`, which is passed in as the
            # final argument.
            Zygote._forward(
                Zygote.Context{Nothing}(nothing),
                model -> logitcrossentropy(model(x), y),
                model,
            )
        end

        # Evaluate the backwards pass, seeded with a Float32 one.
        # NOTE(review): `Zygote.tailmemaybe` presumably drops the leading entry of the
        # pullback result (the sensitivity for the closure itself, which captures `x`
        # and `y`), so that `[1]` is the gradient with respect to `model` -- confirm
        # against the Zygote version in use.
        Δ_model = Zygote.tailmemaybe(back(1f0))[1]

        # Update parameters via our optimizer; `update!` returns the new model rather
        # than mutating, so rebind `model`.
        model = update!(opt, model, Δ_model)

        # Outfeed the loss, reshaped to a 1-element vector
        loss = reshape(loss, (1,))
        XLA.HloOutfeed()((loss,), XLA.HloAfterAll()())

        # Count down the batches
        nbatches -= XRTArray(1)
    end
    
    # At the end of all things, return the trained model
    return model
end

train_loop (generic function with 1 method)

In [8]:
# Initialize XLA.jl against the TPU endpoint at this address; the call's result is
# displayed below as a TPUSession.
# NOTE(review): the address is hard-coded for the original author's setup, and
# `reset=true` presumably clears any existing state on the TPU first -- confirm.
XLA.initialize!("10.240.1.2:8470"; reset=true)

2019-11-14 16:27:03.054477: W tensorflow/core/distributed_runtime/rpc/grpc_session.cc:349] GrpcSession::ListDevices will initialize the session with an empty graph and other defaults because the session has not yet been created.


TPUSession(<8 TPU chips in 2x2x2 topology>)

In [9]:
# Train in batch sizes of 128, for 10000 batches with a learning rate of 1e-2
batch_size = 128
num_batches = 10000
# Float32 literal (`f0` suffix)
η = 0.01f0

# We need to work around a TensorFlow bug which doesn't choose the default infeed/outfeed
# device placement properly.  We do so by explicitly placing all operations on the first TPU core
sess = XLA.global_session()
tpu_device = first(all_tpu_devices(sess))

# Compile the model.  The batch size is passed as a `Val` (type parameter), while the
# batch count and learning rate are wrapped in XRTArrays.  Wall-clock time is taken
# around the macro call so the compile duration can be reported.
# NOTE(review): presumably `@tpu_compile` only compiles and returns a handle, without
# running the loop -- confirm.
t_start = time()
compilation_handle = @tpu_compile devices=[tpu_device] train_loop(Val(batch_size), tpu_model, XRTArray(num_batches), XRTArray(η));
t_end = time()

# Report the elapsed compile time in seconds, to one decimal place
println(@sprintf("=> Compiled training loop in %.1f seconds", t_end - t_start))

=> Compiled training loop in 262.5 seconds
