In [2]:
using Pkg
Pkg.add("BenchmarkTools")
using LinearAlgebra
using BenchmarkTools

[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.11/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.11/Manifest.toml`


In [3]:
mutable struct BiCGSTAB_smth{T}
  A::AbstractMatrix{T}
  b::Vector{T}
  x::Vector{T}
  r::Vector{T}
  r_hat::Vector{T}
  p::Vector{T}
  v::Vector{T}
  h::Vector{T}
  s::Vector{T}
  t::Vector{T}
  alpha::T
  omega::T
  rho_old::T
  tol::T
  max_iter::Int
end

In [18]:
function initialize_state(A, b; tol, max_iter)
  T = eltype(b)
  n = length(b)
  x = zeros(T, n)
  r = b - A * x
  return BiCGSTAB_smth(
      A, b, x, r, rand(T,n),
      zeros(T, n), zeros(T, n),
      zeros(T, n), zeros(T, n), zeros(T, n),
      one(T), one(T), one(T),
      tol, max_iter
  )
end

initialize_state (generic function with 1 method)

In [5]:
function update_p!(state, rho_new, iter)
  if iter == 1
      state.p .= state.r
  else
      beta = (rho_new / state.rho_old) * (state.alpha / state.omega)
      state.p .= state.r .+ beta .* (state.p .- state.omega .* state.v)
  end
end

update_p! (generic function with 1 method)

In [6]:
function step!(state)
  A = state.A
  r_hat = state.r_hat
  r = state.r

  rho_new = dot(r_hat, r)
  if abs(rho_new) < 1e-14
      error("Breakdown: rho_new == 0")
  end

  update_p!(state, rho_new, 1) 
  state.v .= A * state.p
  state.alpha = rho_new / dot(r_hat, state.v)

  state.h .= state.x .+ state.alpha .* state.p
  state.s .= state.r .- state.alpha .* state.v

  if norm(state.s) < state.tol
      state.x .= state.h
      return true  # converged
  end

  state.t .= A * state.s
  denom = dot(state.t, state.t)
  if denom == 0.0
      error("Breakdown: t ⋅ t == 0")
  end

  state.omega = dot(state.t, state.s) / denom
  state.x .= state.h .+ state.omega .* state.s
  state.r .= state.s .- state.omega .* state.t
  state.rho_old = rho_new

  return norm(state.r) < state.tol
end


step! (generic function with 1 method)

In [9]:
function solve_bicgstab!(state)
  if norm(state.r) < state.tol
    println("Initial guess is already within tolerance.")
    return state.x, 0 
end
  for i in 1:state.max_iter
      if step!(state)
          println("Converged at step $i")
          return state.x, i
      end
  end
  @warn "Did not converge."
  return state.x, state.max_iter
end

solve_bicgstab! (generic function with 1 method)

In [None]:
using Test 

@testset "small system" begin
  A = [4.0 1.0; 1.0 3.0]
  b = [1.0, 2.0]
  x0 = zeros(2)
  tol = 1e-8
  max_iter = 100

  state = initialize_state(A, b, tol=tol, max_iter=max_iter)
  x, iters = solve_bicgstab!(state)

  x_expected = A \ b

  @test isapprox(x, x_expected; atol=1e-4)
  @test iters < max_iter
end

Converged at step 7
[0m[1mTest Summary:                          | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1mTime[22m
BiCGSTAB solves small system correctly | [32m   2  [39m[36m    2  [39m[0m1.0s


Test.DefaultTestSet("BiCGSTAB solves small system correctly", Any[], 2, false, false, true, 1.745423701475551e9, 1.745423702478149e9, false, "/Users/eronagashi/distr_computing/serialBiCGSTAB/src/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X11sZmlsZQ==.jl")

In [None]:
@testset "2x2 system" begin
  A = [4.0 1.0; 1.0 3.0]
  b = [1.0, 2.0]
  x0 = zeros(2)
  state = initialize_state(A, b; tol=1e-8, max_iter=100)
  x, iters = solve_bicgstab!(state)
  x_expected = A \ b
  @test isapprox(x, x_expected; atol=1e-6)
  @test iters < 100
end

Converged at step 7
[0m[1mTest Summary:                    | [22m[32m[1mPass  [22m[39m[36m[1mTotal  [22m[39m[0m[1mTime[22m
BiCGSTAB solves small 2x2 system | [32m   2  [39m[36m    2  [39m[0m0.0s


Test.DefaultTestSet("BiCGSTAB solves small 2x2 system", Any[], 2, false, false, true, 1.745424046498745e9, 1.745424046508364e9, false, "/Users/eronagashi/distr_computing/serialBiCGSTAB/src/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X20sZmlsZQ==.jl")

In [32]:
@testset "Initial guess is already a solution" begin
  A = Matrix{Float64}(I, 3, 3)
  b = [1.0, 2.0, 3.0]
  x0 = b
  state = initialize_state(A, b,tol=1e-8, max_iter=10)
  x, iters = solve_bicgstab!(state)
  @test isapprox(x, b; atol=1e-10)
  @test iters == 0
end

Converged at step 1
Initial guess is already a solution: [91m[1mTest Failed[22m[39m at [39m[1m/Users/eronagashi/distr_computing/serialBiCGSTAB/src/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X23sZmlsZQ==.jl:8[22m
  Expression: iters == 0
   Evaluated: 1 == 0

Stacktrace:
 [1] [0m[1mmacro expansion[22m
[90m   @[39m [90m~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/Test/src/[39m[90m[4mTest.jl:679[24m[39m[90m [inlined][39m
 [2] [0m[1mmacro expansion[22m
[90m   @[39m [90m~/distr_computing/serialBiCGSTAB/src/[39m[90m[4mjl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X23sZmlsZQ==.jl:8[24m[39m[90m [inlined][39m
 [3] [0m[1mmacro expansion[22m
[90m   @[39m [90m~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/Test/src/[39m[90m[4mTest.jl:1704[24m[39m[90m [inlined][39m
 [4] top-level scope
[90m   @[39m [90m~/distr_computing/serialBiCGSTAB/src/[39m[90m[4mjl_notebook_cell_df34

TestSetException: Some tests did not pass: 1 passed, 1 failed, 0 errored, 0 broken.