Notebook for training a diffusion model on MNIST data.

Sources:

[The Network Code](https://github.com/FluxML/model-zoo/blob/master/vision/diffusion_mnist/diffusion_mnist.jl)

[The Sampling Code](https://github.com/FluxML/model-zoo/blob/master/vision/diffusion_mnist/diffusion_plot.jl)

[Useful Explainer](https://yang-song.net/blog/2021/score/)

[Original Paper](https://arxiv.org/pdf/2011.13456.pdf)

TO DO:
- Optimize code: training is slower than expected (about 5 min for 10 epochs with 1000 images)
    - This is mainly due to compilation. Once compiled, training takes about 18 min for 60k images and 30 epochs
- Make compatible with job submission
- Include more samplers: Predictor-Corrector and differential-equation solver (see explainer and sampling code links)
- Image inpainting: adapt the sampling to image inpainting (see the Python code attached to the original paper)

In [1]:
# Dependencies
using Pkg
Pkg.activate("/home/mverlaan/bathy_machine_learning/julia_ml_tests.jl.git/")
# include("diffusion_unet_model.jl")
include("unet_large.jl")
# include("diffusion_sampling.jl")
using MLDatasets
using CUDA
using ProgressMeter: Progress, next!
using BSON

[32m[1m  Activating[22m[39m project at `~/bathy_machine_learning/julia_ml_tests.jl.git`


In [8]:
# Hyper Parameters
# -- data --
n_images = 60_000                # how many MNIST training images are loaded
# -- network --
channels = [32, 64, 128, 256]    # channel count per U-Net conv layer
embed_dim = 256                  # dimensionality of the Fourier projection
scale = 30f0                     # scale parameter of the Fourier projection (Float32)
# -- optimiser --
lr = 1e-4                        # learning rate handed to Adam

# Training schedule
batch_size = 32    # images per mini-batch
nr_epochs = 50     # full passes over the training data
# Function used to move the data loader and the model to the GPU.
# Kept as the last statement so the cell output still shows `gpu`.
device = gpu

gpu (generic function with 5 methods)

In [9]:
# Load the first `n_images` samples (images and labels) of the MNIST training split
xtrain, ytrain = MNIST(:train)[1:n_images]
# Add a singleton third axis so the images are stored as 28×28×1×N
xtrain = reshape(xtrain, 28, 28, 1, :);

In [10]:
# Mini-batch iterator for the training loop.
# The whole loader is passed through `device`; the recorded output shows the
# loader wrapping the data in a `MappedData` with `gpu`, and its first batch
# consisting of CuArrays.

data_loader = device(
    Flux.DataLoader((xtrain, ytrain); batchsize=batch_size, shuffle=true)
)

1875-element DataLoader(::MLUtils.MappedData{:auto, typeof(gpu), Tuple{Array{Float32, 4}, Vector{Int64}}}, shuffle=true, batchsize=32)
  with first element:
  (28×28×1×32 CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, 32-element CuArray{Int64, 1, CUDA.Mem.DeviceBuffer},)

In [5]:
# U-Net model, constructed and then moved with `device`
unet_model = device(UNET(channels, embed_dim, scale))

# Adam optimiser with learning rate `lr`
opt = Adam(lr)

# Implicit parameter collection consumed by the training loop.
# NOTE(review): `Flux.params` is the implicit-parameter API, which newer Flux
# releases deprecate; confirm against the Flux version pinned in this project.
params = Flux.params(unet_model);

In [6]:
# Sanity check of the loss function on the images of one mini-batch.
# The loader shuffles, so which batch comes first varies between runs.

test_loss = let batch = first(data_loader)
    score_loss(unet_model, batch[1])
end

1113.8665f0

In [11]:
# Training loop: the progress bar advances once per epoch and reports the
# mean of that epoch's per-batch losses.

progress = Progress(length(1:nr_epochs))
for epoch in 1:nr_epochs
    println("Starting epoch no. $(epoch)")

    # One slot per mini-batch; averaged after the epoch's last batch
    batch_losses = zeros(length(data_loader))

    # Labels are not needed for the score loss, so they are discarded
    for (batch_no, (x, _)) in enumerate(data_loader)
        # Loss value together with the gradients w.r.t. the implicit `params`
        loss, grad = Flux.withgradient(() -> score_loss(unet_model, x), params)
        Flux.Optimise.update!(opt, params, grad)
        batch_losses[batch_no] = loss
    end
    next!(progress, showvalues=[(:loss, mean(batch_losses))])
end

Starting epoch no. 1


Starting epoch no. 2


[32mProgress:   4%|█▋                                       |  ETA: 0:40:10[39m[K
[34m  loss:  36.16273620707194[39m[K[A

Starting epoch no. 3



[K[A[32mProgress:   6%|██▌                                      |  ETA: 0:39:25[39m[K
[34m  loss:  26.968733470662436[39m[K[A

Starting epoch no. 4



[K[A[32mProgress:   8%|███▎                                     |  ETA: 0:38:32[39m[K
[34m  loss:  22.977690797424316[39m[K[A

Starting epoch no. 5



[K[A[32mProgress:  10%|████▏                                    |  ETA: 0:37:41[39m[K
[34m  loss:  20.871688801574706[39m[K[A

Starting epoch no. 6



[K[A[32mProgress:  12%|████▉                                    |  ETA: 0:36:53[39m[K
[34m  loss:  19.697931122843425[39m[K[A

Starting epoch no. 7



[K[A[32mProgress:  14%|█████▊                                   |  ETA: 0:36:03[39m[K
[34m  loss:  18.903551963297527[39m[K[A

Starting epoch no. 8



[K[A[32mProgress:  16%|██████▌                                  |  ETA: 0:35:12[39m[K
[34m  loss:  18.44206424560547[39m[K[A

Starting epoch no. 9



[K[A[32mProgress:  18%|███████▍                                 |  ETA: 0:34:20[39m[K
[34m  loss:  17.70435539855957[39m[K[A

Starting epoch no. 10



[K[A[32mProgress:  20%|████████▎                                |  ETA: 0:33:28[39m[K
[34m  loss:  17.507859184265136[39m[K[A

Starting epoch no. 11



[K[A[32mProgress:  22%|█████████                                |  ETA: 0:32:37[39m[K
[34m  loss:  17.205056246185304[39m[K[A

Starting epoch no. 12



[K[A[32mProgress:  24%|█████████▉                               |  ETA: 0:31:46[39m[K
[34m  loss:  16.914793475087485[39m[K[A

Starting epoch no. 13



[K[A[32mProgress:  26%|██████████▋                              |  ETA: 0:30:55[39m[K
[34m  loss:  16.7863297609965[39m[K[A

Starting epoch no. 14



[K[A[32mProgress:  28%|███████████▌                             |  ETA: 0:30:04[39m[K
[34m  loss:  16.432853061421714[39m[K[A

Starting epoch no. 15



[K[A[32mProgress:  30%|████████████▎                            |  ETA: 0:29:16[39m[K
[34m  loss:  16.458108537038168[39m[K[A

Starting epoch no. 16



[K[A[32mProgress:  32%|█████████████▏                           |  ETA: 0:28:26[39m[K
[34m  loss:  16.251362398783368[39m[K[A

Starting epoch no. 17



[K[A[32mProgress:  34%|██████████████                           |  ETA: 0:27:37[39m[K
[34m  loss:  16.010787084197997[39m[K[A

Starting epoch no. 18



[K[A[32mProgress:  36%|██████████████▊                          |  ETA: 0:26:47[39m[K
[34m  loss:  15.929264224243164[39m[K[A

Starting epoch no. 19



[K[A[32mProgress:  38%|███████████████▋                         |  ETA: 0:25:56[39m[K
[34m  loss:  15.753002808125814[39m[K[A

Starting epoch no. 20



[K[A[32mProgress:  40%|████████████████▍                        |  ETA: 0:25:05[39m[K
[34m  loss:  15.675912519582113[39m[K[A

Starting epoch no. 21



[K[A[32mProgress:  42%|█████████████████▎                       |  ETA: 0:24:14[39m[K
[34m  loss:  15.490304504648844[39m[K[A

Starting epoch no. 22



[K[A[32mProgress:  44%|██████████████████                       |  ETA: 0:23:23[39m[K
[34m  loss:  15.301053005472818[39m[K[A

Starting epoch no. 23



[K[A[32mProgress:  46%|██████████████████▉                      |  ETA: 0:22:32[39m[K
[34m  loss:  15.476687419891357[39m[K[A

Starting epoch no. 24



[K[A[32mProgress:  48%|███████████████████▋                     |  ETA: 0:21:42[39m[K
[34m  loss:  15.344008082834879[39m[K[A

Starting epoch no. 25



[K[A[32mProgress:  50%|████████████████████▌                    |  ETA: 0:20:51[39m[K
[34m  loss:  15.19048392232259[39m[K[A

Starting epoch no. 26



[K[A[32mProgress:  52%|█████████████████████▍                   |  ETA: 0:20:01[39m[K
[34m  loss:  15.222222065989177[39m[K[A

Starting epoch no. 27



[K[A[32mProgress:  54%|██████████████████████▏                  |  ETA: 0:19:11[39m[K
[34m  loss:  14.986678818766276[39m[K[A

Starting epoch no. 28



[K[A[32mProgress:  56%|███████████████████████                  |  ETA: 0:18:21[39m[K
[34m  loss:  15.058788909403484[39m[K[A

Starting epoch no. 29



[K[A[32mProgress:  58%|███████████████████████▊                 |  ETA: 0:17:31[39m[K
[34m  loss:  15.063174933624268[39m[K[A

Starting epoch no. 30



[K[A[32mProgress:  60%|████████████████████████▋                |  ETA: 0:16:41[39m[K
[34m  loss:  15.013068962860107[39m[K[A

Starting epoch no. 31



[K[A[32mProgress:  62%|█████████████████████████▍               |  ETA: 0:15:51[39m[K
[34m  loss:  14.996673910268148[39m[K[A

Starting epoch no. 32



[K[A[32mProgress:  64%|██████████████████████████▎              |  ETA: 0:15:01[39m[K
[34m  loss:  14.768852278645832[39m[K[A

Starting epoch no. 33



[K[A[32mProgress:  66%|███████████████████████████              |  ETA: 0:14:11[39m[K
[34m  loss:  14.954444603474935[39m[K[A

Starting epoch no. 34



[K[A[32mProgress:  68%|███████████████████████████▉             |  ETA: 0:13:20[39m[K
[34m  loss:  14.714894054921468[39m[K[A

Starting epoch no. 35



[K[A[32mProgress:  70%|████████████████████████████▊            |  ETA: 0:12:30[39m[K
[34m  loss:  14.767870145161947[39m[K[A

Starting epoch no. 36



[K[A[32mProgress:  72%|█████████████████████████████▌           |  ETA: 0:11:40[39m[K
[34m  loss:  14.736205023701986[39m[K[A

Starting epoch no. 37



[K[A[32mProgress:  74%|██████████████████████████████▍          |  ETA: 0:10:50[39m[K
[34m  loss:  14.634233628845214[39m[K[A

Starting epoch no. 38



[K[A[32mProgress:  76%|███████████████████████████████▏         |  ETA: 0:10:00[39m[K
[34m  loss:  14.566524045562744[39m[K[A

Starting epoch no. 39



[K[A[32mProgress:  78%|████████████████████████████████         |  ETA: 0:09:10[39m[K
[34m  loss:  14.443915852101643[39m[K[A

Starting epoch no. 40



[K[A[32mProgress:  80%|████████████████████████████████▊        |  ETA: 0:08:20[39m[K
[34m  loss:  14.55989525197347[39m[K[A

Starting epoch no. 41



[K[A[32mProgress:  82%|█████████████████████████████████▋       |  ETA: 0:07:30[39m[K
[34m  loss:  14.616378134155273[39m[K[A

Starting epoch no. 42



[K[A[32mProgress:  84%|██████████████████████████████████▌      |  ETA: 0:06:40[39m[K
[34m  loss:  14.446352947743733[39m[K[A

Starting epoch no. 43



[K[A[32mProgress:  86%|███████████████████████████████████▎     |  ETA: 0:05:50[39m[K
[34m  loss:  14.420840850830078[39m[K[A

Starting epoch no. 44



[K[A[32mProgress:  88%|████████████████████████████████████▏    |  ETA: 0:05:00[39m[K
[34m  loss:  14.28539483795166[39m[K[A

Starting epoch no. 45



[K[A[32mProgress:  90%|████████████████████████████████████▉    |  ETA: 0:04:10[39m[K
[34m  loss:  14.218820863342286[39m[K[A

Starting epoch no. 46



[K[A[32mProgress:  92%|█████████████████████████████████████▊   |  ETA: 0:03:20[39m[K
[34m  loss:  14.256929047393799[39m[K[A

Starting epoch no. 47



[K[A[32mProgress:  94%|██████████████████████████████████████▌  |  ETA: 0:02:30[39m[K
[34m  loss:  14.268455834706625[39m[K[A

Starting epoch no. 48



[K[A[32mProgress:  96%|███████████████████████████████████████▍ |  ETA: 0:01:40[39m[K
[34m  loss:  14.252161951446533[39m[K[A

Starting epoch no. 49



[K[A[32mProgress:  98%|████████████████████████████████████████▏|  ETA: 0:00:50[39m[K
[34m  loss:  14.170373460388184[39m[K[A

Starting epoch no. 50



[K[A[32mProgress: 100%|█████████████████████████████████████████| Time: 0:41:39[39m[K
[34m  loss:  14.226620218658447[39m[K




In [None]:
# Evaluate `unet_model` as the last expression so the notebook displays it
unet_model

In [None]:
# Bring the model back to the CPU; the result is kept under a separate name

unet_cpu = cpu(unet_model)

In [None]:
# Images of one mini-batch, moved to the CPU.
# The loader was built with shuffle=true, so this is a freshly drawn first batch.
test_sample = cpu(first(data_loader)[1])

In [None]:
# Evaluate the loss with the CPU model on the CPU sample.
# This reassigns `test_loss`, replacing the value from the earlier GPU check.
test_loss = score_loss(unet_cpu, test_sample)

In [12]:
# CPU version of the model under the name `model`, which is the name the
# commented-out `BSON.@save save_path model` line refers to.
model = cpu(unet_model)
# Relative checkpoint path; left as the last statement so the cell displays it.
save_path = joinpath("models", "diffusion_fullunet_mnist60k_epoch50.bson")

"models/diffusion_fullunet_mnist60k_epoch50.bson"

In [13]:
# BSON.@save save_path model