# Linear Algebra Threading

In [5]:
using LinearAlgebra

A = rand(2000, 2000);
B = rand(2000, 2000);

# Precompile the matrix multiplication
A*B;

# Single thread
begin
    BLAS.set_num_threads(1)
    @show BLAS.get_num_threads()
    @time A*B
end

# All threads on the machine
begin
    BLAS.set_num_threads(Sys.CPU_THREADS)
    @show BLAS.get_num_threads()
    @time A*B
end

BLAS.get_num_threads() = 1
  0.211251 seconds (2 allocations: 30.518 MiB)
BLAS.get_num_threads() = 4
  0.124283 seconds (2 allocations: 30.518 MiB, 1.41% gc time)


2000×2000 Matrix{Float64}:
 485.913  498.432  484.765  479.025  …  499.488  481.598  491.073  487.787
 498.413  502.77   488.947  489.869     508.926  491.233  502.087  499.032
 494.15   501.504  486.227  492.665     497.294  487.428  495.098  499.047
 500.173  506.289  495.724  491.75      498.706  482.809  497.076  503.163
 504.567  513.291  504.303  500.12      506.772  494.67   506.264  508.77
 498.99   495.169  490.198  492.764  …  505.851  479.577  499.977  503.583
 482.553  487.092  483.417  482.972     490.316  478.439  490.105  493.793
 499.127  506.137  498.5    489.557     506.614  495.374  499.704  504.379
 495.89   505.844  489.836  489.214     507.048  486.017  502.481  504.705
 501.3    495.071  499.276  494.009     502.404  489.003  500.824  501.011
 487.408  504.001  497.758  495.156  …  501.588  484.862  501.403  499.836
 482.87   496.119  484.207  484.988     494.966  484.551  488.235  497.132
 490.514  494.055  487.844  481.352     495.964  479.838  489.168  495.748

# Julia threading

In [6]:
using Base.Threads

In [7]:
nthreads()

1

There are three main ways of approaching multithreading:

1. Using `@threads` to parallelize a for loop to run in multiple threads.
2. Using `@spawn` and `@sync` to spawn tasks in threads and synchronize them at the end of the block.
3. Using `@spawn` and `fetch` to spawn tasks and fetch their return values once they are complete.



Example of `@threads`

In [None]:
a = zeros(Int, 2*nthreads())
@threads for i in eachindex(a)
    a[i] = threadid()
end
println(a)

Example of `@spawn` and `@sync`

In [None]:
function task(b, chunk)
     for i in chunk
         b[i] = threadid()
     end
end

# Using @sync and @spawn macros (also dynamic scheduling)
b = zeros(Int, 2 * nthreads())
chunks = Iterators.partition(eachindex(b), length(b) ÷ nthreads())
@sync for chunk in chunks
    @spawn task(b, chunk)
end
println(b)

Example of `@spawn` and `fetch`

In [None]:
# Using @spawn and fetch
t = [@spawn threadid() for _ in 1:2*nthreads()]
c = fetch.(t)

# Performance comparison

In [None]:
function sqrt_array(A)
  B = similar(A)
  for i in eachindex(A)
      @inbounds B[i] = sqrt(A[i])
  end
  B
end

In [None]:
function threaded_sqrt_array(A)
  B = similar(A)
  @threads for i in eachindex(A)
      @inbounds B[i] = sqrt(A[i])
  end
  B
end

In [None]:
function sqrt_array!(A, B, chunk)
  for i in chunk
      @inbounds B[i] = sqrt(A[i])
  end
end

function threaded_sqrt_array2(A)
  B = similar(A)
  chunks = Iterators.partition(eachindex(A), length(A) ÷ nthreads())
  @sync for chunk in chunks
      @spawn sqrt_array!(A, B, chunk)
  end
  B
end

In [None]:
A = rand(1000, 1000)

using BenchmarkTools

@btime sqrt_array(A);
@btime threaded_sqrt_array(A);
@btime threaded_sqrt_array2(A);

Output from one attempt with `-t 3`

```
julia> @btime sqrt_array(A);
  665.875 μs (2 allocations: 7.63 MiB)

julia> @btime threaded_sqrt_array(A);
  263.750 μs (20 allocations: 7.63 MiB)

julia> @btime threaded_sqrt_array2(A);
  418.916 μs (32 allocations: 7.63 MiB)
```

# Race conditions

`@thread` is fast, powerful and easy to implement but it can only be used if each operation is **independent of each other**. Otherwise, race conditions might lead to wrong results. The following is an example:

In [None]:
# Slow but correct

function sqrt_sum(A)
  s = zero(eltype(A))
  for i in eachindex(A)
      @inbounds s += sqrt(A[i])
  end
  return s
end

In [None]:
# Fast but incorrect. Returns the sum only for a subset of the whole input array.

function threaded_sqrt_sum(A)
  s = zero(eltype(A))
  @threads for i in eachindex(A)
      @inbounds s += sqrt(A[i])
  end
  return s
en

In [None]:
# Fast and correct but after refactoring

function sqrt_sum(A, chunk)
  s = zero(eltype(A))
  for i in chunk
      @inbounds s += sqrt(A[i])
  end
  return s
end

function threaded_sqrt_sum_workaround(A)
  chunks = Iterators.partition(eachindex(A), length(A) ÷ nthreads())
  tasks = map(chunks) do chunk
      @spawn sqrt_sum(A, chunk)
  end
  s = mapreduce(fetch, +, tasks; init=zero(eltype(A)))
  return s
end