In [1]:
using LinearAlgebra

In [2]:
mutable struct BiCGSTAB{T}
  A::AbstractMatrix{T}
  b::Vector{T}
  x::Vector{T}
  r::Vector{T}
  r_hat::Vector{T}
  p::Vector{T}
  v::Vector{T}
  h::Vector{T}
  s::Vector{T}
  t::Vector{T}
  alpha::T
  omega::T
  rho_old::T
  tol::T
  max_iter::Int
end

In [3]:
function initialize_state(A, b; tol=1e-8, max_iter=30)
  T = eltype(b)
  n = length(b)
  x = zeros(T, n)
  r = b - A * x
  return BiCGSTAB(
      A, b, x, r, rand(n),
      zeros(T, n), zeros(T, n),
      zeros(T, n), zeros(T, n), zeros(T, n),
      one(T), one(T), one(T),
      tol, max_iter
  )
end

initialize_state (generic function with 1 method)

In [4]:
function update_p!(state, rho_new, iter)
  if iter == 1
      state.p .= state.r
  else
      beta = (rho_new / state.rho_old) * (state.alpha / state.omega)
      state.p .= state.r .+ beta .* (state.p .- state.omega .* state.v)
  end
end

update_p! (generic function with 1 method)

In [5]:
function step!(state)
  A = state.A
  r_hat = state.r_hat
  r = state.r

  rho_new = dot(r_hat, r)
  if abs(rho_new) < 1e-14
      error("Breakdown: rho_new == 0")
  end

  update_p!(state, rho_new, 1) 
  state.v .= A * state.p
  state.alpha = rho_new / dot(r_hat, state.v)

  state.h .= state.x .+ state.alpha .* state.p
  state.s .= state.r .- state.alpha .* state.v

  if norm(state.s) < state.tol
      state.x .= state.h
      return true  # converged
  end

  state.t .= A * state.s
  denom = dot(state.t, state.t)
  if denom == 0.0
      error("Breakdown: t ⋅ t == 0")
  end

  state.omega = dot(state.t, state.s) / denom
  state.x .= state.h .+ state.omega .* state.s
  state.r .= state.s .- state.omega .* state.t
  state.rho_old = rho_new

  return norm(state.r) < state.tol
end


step! (generic function with 1 method)

In [15]:
function solve_bicgstab!(state)
  for i in 1:state.max_iter
      if step!(state)
          println("Converged at step $i")
          return state.x, i
      end
  end
  @warn "Did not converge."
  return state.x, state.max_iter
end

solve_bicgstab! (generic function with 1 method)

In [16]:
A = [2.0 -1.0;
     -1.0 2.0]
b = [0.0, 0.0]

state = initialize_state(A,b)
solution, steps = solve_bicgstab!(state)

ErrorException: Breakdown: rho_new == 0