Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: test test-race test-all lint vet fuzz bench coverage clean hooks
.PHONY: test test-race test-all lint vet fuzz bench bench-gpu-cuda test-gpu-cuda coverage clean hooks

test:
go test -count=1 -timeout=5m ./...
Expand All @@ -10,6 +10,8 @@ test-all: test-race
cd backend/braket && go test -race -count=1 -timeout=5m ./...
cd observe/otelbridge && go build ./...
cd observe/prombridge && go build ./...
cd sim/gpu/cuda && go build ./...
cd sim/gpu/metal && go build ./...

lint:
golangci-lint run ./...
Expand All @@ -20,6 +22,8 @@ vet:
cd backend/braket && go vet ./...
cd observe/otelbridge && go vet ./...
cd observe/prombridge && go vet ./...
cd sim/gpu/cuda && go vet ./...
cd sim/gpu/metal && go vet ./...

fuzz:
go test ./qasm/parser -run=^$$ -fuzz=FuzzParse -fuzztime=30s
Expand All @@ -41,6 +45,12 @@ bench:
go test ./sim/statevector/ -bench=. -count=5 -benchmem -run=^$$ -timeout=10m
go test ./sim/densitymatrix/ -bench=. -count=5 -benchmem -run=^$$ -timeout=10m

# test-gpu-cuda runs the tests of the separate sim/gpu/cuda module with the
# "cuda" build tag enabled (the GPU sources are guarded by //go:build cuda and
# link against -lcustatevec and -lcudart).
test-gpu-cuda:
cd sim/gpu/cuda && go test -tags cuda -count=1 -timeout=5m ./...

# bench-gpu-cuda runs the benchmarks in sim/gpu/cuda with the "cuda" build tag.
# -run=^$$ ($$ is Make's escape for a literal $) matches no tests, so only
# benchmarks execute; each one is repeated 5 times with allocation stats.
bench-gpu-cuda:
cd sim/gpu/cuda && go test -tags cuda -bench=. -count=5 -benchmem -run=^$$ -timeout=10m

coverage:
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out -o coverage.html
Expand Down
37 changes: 31 additions & 6 deletions backend/local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,24 @@ import (

"github.com/splch/goqu/backend"
"github.com/splch/goqu/observe"
"github.com/splch/goqu/sim"
"github.com/splch/goqu/sim/pulsesim"
"github.com/splch/goqu/sim/statevector"
"github.com/splch/goqu/transpile/target"
)

var _ backend.Backend = (*Backend)(nil)

// SimFactory creates a Simulator for the given number of qubits.
// New installs a default factory that returns a CPU statevector simulator;
// WithSimulator replaces it. Submit calls the factory to obtain a simulator
// and closes that simulator (via defer) once it is done with it.
type SimFactory func(numQubits int) (sim.Simulator, error)

// Backend runs circuits on a local simulator. The simulator is produced by
// simFactory, which New sets to a CPU statevector factory unless an Option
// overrides it.
type Backend struct {
maxQubits int
results sync.Map // jobID → *backend.Result
logger *slog.Logger
maxQubits int
simFactory SimFactory // builds the simulator used by Submit; see WithSimulator
results sync.Map // jobID → *backend.Result
logger *slog.Logger
}

// Option configures a local Backend.
Expand All @@ -39,9 +45,21 @@ func WithLogger(l *slog.Logger) Option {
return func(b *Backend) { b.logger = l }
}

// WithSimulator sets a custom simulator factory. This allows plugging in
// alternative simulators (e.g., GPU-accelerated) while keeping the same
// backend API.
//
// A nil factory is ignored and the previously configured factory stays in
// place: Submit invokes the factory unconditionally, so storing nil would
// turn every submission into a nil-function panic.
func WithSimulator(f SimFactory) Option {
	return func(b *Backend) {
		if f == nil {
			return
		}
		b.simFactory = f
	}
}

// New creates a local simulator backend.
func New(opts ...Option) *Backend {
b := &Backend{maxQubits: 28, logger: slog.Default()}
b := &Backend{
maxQubits: 28,
simFactory: func(numQubits int) (sim.Simulator, error) {
return statevector.New(numQubits), nil
},
logger: slog.Default(),
}
for _, opt := range opts {
opt(b)
}
Expand Down Expand Up @@ -92,8 +110,15 @@ func (b *Backend) Submit(ctx context.Context, req *backend.SubmitRequest) (*back
)

start := time.Now()
sim := statevector.New(nq)
counts, err := sim.Run(req.Circuit, req.Shots)
s, sErr := b.simFactory(nq)
if sErr != nil {
if simDone != nil {
simDone(sErr)
}
return nil, fmt.Errorf("local: %w", sErr)
}
defer func() { _ = s.Close() }()
counts, err := s.Run(req.Circuit, req.Shots)
elapsed := time.Since(start)

if simDone != nil {
Expand Down
3 changes: 3 additions & 0 deletions sim/densitymatrix/sim.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import (
"github.com/splch/goqu/sim/pauli"
)

// Close does nothing and always returns nil; it is a no-op for the CPU
// density matrix simulator.
func (s *Sim) Close() error {
	return nil
}

// parallelThreshold is the minimum number of qubits before enabling parallel kernels.
// At 9 qubits dim=512 and the density matrix has 262K elements; the heavier
// per-element work (row + column passes) justifies a lower threshold than statevector.
Expand Down
64 changes: 64 additions & 0 deletions sim/gpu/cuda/benchmark_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//go:build cuda

package cuda

import (
	"fmt"
	"testing"

	"github.com/splch/goqu/circuit/builder"
	"github.com/splch/goqu/sim/statevector"
)

// BenchmarkGPU_GHZ times the CUDA simulator evolving an n-qubit GHZ circuit
// (H on qubit 0 followed by a chain of CNOTs) for several qubit counts.
// A sub-benchmark is skipped when the simulator cannot be created.
func BenchmarkGPU_GHZ(b *testing.B) {
	qubitCounts := []int{12, 16, 20, 24, 28}
	for _, nq := range qubitCounts {
		// Build the circuit once per size, before the sub-benchmark starts.
		ghz := builder.New("ghz", nq)
		ghz.H(0)
		for target := 1; target < nq; target++ {
			ghz.CNOT(target-1, target)
		}
		circ, buildErr := ghz.Build()
		if buildErr != nil {
			b.Fatal(buildErr)
		}

		b.Run(qName("GPU", nq), func(sb *testing.B) {
			gpu, newErr := New(nq)
			if newErr != nil {
				sb.Skip("CUDA not available:", newErr)
			}
			defer gpu.Close()
			// Exclude simulator construction from the measurement.
			sb.ResetTimer()
			for i := 0; i < sb.N; i++ {
				gpu.Evolve(circ)
			}
		})
	}
}

// BenchmarkCPU_GHZ times the CPU statevector simulator on the same GHZ
// circuits and qubit counts as the GPU benchmark, for comparison.
func BenchmarkCPU_GHZ(b *testing.B) {
	qubitCounts := []int{12, 16, 20, 24, 28}
	for _, nq := range qubitCounts {
		// Build the circuit once per size, before the sub-benchmark starts.
		ghz := builder.New("ghz", nq)
		ghz.H(0)
		for target := 1; target < nq; target++ {
			ghz.CNOT(target-1, target)
		}
		circ, buildErr := ghz.Build()
		if buildErr != nil {
			b.Fatal(buildErr)
		}

		b.Run(qName("CPU", nq), func(sb *testing.B) {
			cpu := statevector.New(nq)
			// Exclude simulator construction from the measurement.
			sb.ResetTimer()
			for i := 0; i < sb.N; i++ {
				cpu.Evolve(circ)
			}
		})
	}
}

// qName builds a sub-benchmark name of the form "<prefix>_<NN>Q", where NN is
// the qubit count zero-padded to at least two digits (e.g. "GPU_05Q",
// "CPU_28Q"). Counts of 100 or more are printed in full rather than being
// squeezed into two characters.
func qName(prefix string, nq int) string {
	return fmt.Sprintf("%s_%02dQ", prefix, nq)
}
79 changes: 79 additions & 0 deletions sim/gpu/cuda/expect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//go:build cuda

package cuda

/*
#include <custatevec.h>

// goPauliExpect computes ⟨ψ|P|ψ⟩ for a single Pauli string on the GPU.
//
// pauliOps[i] acts on index bit basisBits[i]; both arrays have nBasisBits
// entries. The result is written to *expectationValue.
//
// custatevecComputeExpectation is the matrix-based API (it expects a matrix,
// layout, compute type and workspace), so Pauli strings go through
// custatevecComputeExpectationsOnPauliBasis instead. That function is batched
// and takes arrays of strings; a one-element batch is built here.
static custatevecStatus_t goPauliExpect(
custatevecHandle_t handle,
const void *sv, cudaDataType_t svDataType, int nQubits,
double *expectationValue,
const custatevecPauli_t *pauliOps, const int32_t *basisBits, int nBasisBits
) {
const custatevecPauli_t *pauliOpsArray[1] = { pauliOps };
const int32_t *basisBitsArray[1] = { basisBits };
const uint32_t nBasisBitsArray[1] = { (uint32_t)nBasisBits };
return custatevecComputeExpectationsOnPauliBasis(
handle, sv, svDataType, (uint32_t)nQubits,
expectationValue,
pauliOpsArray, 1,
basisBitsArray, nBasisBitsArray
);
}
*/
import "C"
import (
"fmt"

"github.com/splch/goqu/sim/pauli"
)

// pauliToCUSV translates a goqu Pauli value into the corresponding cuStateVec
// Pauli enum constant. Anything other than X, Y or Z — including pauli.I —
// maps to the identity.
func pauliToCUSV(p pauli.Pauli) C.custatevecPauli_t {
	switch p {
	case pauli.X:
		return C.CUSTATEVEC_PAULI_X
	case pauli.Y:
		return C.CUSTATEVEC_PAULI_Y
	case pauli.Z:
		return C.CUSTATEVEC_PAULI_Z
	}
	// pauli.I and any unrecognized value both fall back to identity.
	return C.CUSTATEVEC_PAULI_I
}

// ExpectPauliString evaluates Re(⟨ψ|P|ψ⟩) on the GPU for the Pauli string ps.
//
// It returns an error when ps spans a different number of qubits than the
// simulator, or when the cuStateVec call reports a non-success status. A
// string with no operators yields the real part of its coefficient without
// making any GPU call.
func (s *Sim) ExpectPauliString(ps pauli.PauliString) (float64, error) {
	if got, want := ps.NumQubits(), s.numQubits; got != want {
		return 0, fmt.Errorf("cuda: PauliString has %d qubits, simulator has %d",
			got, want)
	}

	terms := ps.Ops()
	if len(terms) == 0 {
		return real(ps.Coeff()), nil
	}

	// Flatten the operators into two parallel arrays: the Pauli for each
	// entry, and the range key that is passed on as its qubit index.
	paulis := make([]C.custatevecPauli_t, 0, len(terms))
	targets := make([]C.int32_t, 0, len(terms))
	for qubit, op := range terms {
		paulis = append(paulis, pauliToCUSV(op))
		targets = append(targets, C.int32_t(qubit))
	}

	var value C.double
	status := C.goPauliExpect(
		s.handle.h,
		s.devicePtr.ptr, C.CUDA_C_64F, C.int(s.numQubits),
		&value,
		&paulis[0], &targets[0], C.int(len(terms)),
	)
	if status != C.CUSTATEVEC_STATUS_SUCCESS {
		return 0, fmt.Errorf("custatevecComputeExpectation failed: status %d", int(status))
	}

	// Only the real part of the coefficient scales the result.
	return real(ps.Coeff()) * float64(value), nil
}
7 changes: 7 additions & 0 deletions sim/gpu/cuda/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module github.com/splch/goqu/sim/gpu/cuda

go 1.24

require github.com/splch/goqu v0.0.0

replace github.com/splch/goqu => ../../../
51 changes: 51 additions & 0 deletions sim/gpu/cuda/handle.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//go:build cuda

package cuda

/*
#cgo LDFLAGS: -lcustatevec -lcudart
#include <custatevec.h>
#include <cuda_runtime.h>

// goCreateHandle wraps custatevecCreate: it stores the new library handle in
// *handle and returns the status code from the call.
static custatevecStatus_t goCreateHandle(custatevecHandle_t *handle) {
return custatevecCreate(handle);
}

// goDestroyHandle wraps custatevecDestroy: it destroys the given library
// handle and returns the status code from the call.
static custatevecStatus_t goDestroyHandle(custatevecHandle_t handle) {
return custatevecDestroy(handle);
}
*/
import "C"
import (
"fmt"
"unsafe"
)

type (
	// cusvHandle bundles a cuStateVec library handle with a CUDA stream.
	// createHandle allocates both and destroyHandle releases both.
	cusvHandle struct {
		h      C.custatevecHandle_t
		stream C.cudaStream_t
	}

	// deviceAlloc pairs a raw pointer with the number of elements behind it.
	// NOTE(review): ptr presumably refers to GPU memory — confirm against the
	// allocation code, which is not in this file.
	deviceAlloc struct {
		ptr  unsafe.Pointer
		size int // number of complex128 elements
	}
)

// createHandle creates a cuStateVec handle and a CUDA stream and returns them
// bundled in a cusvHandle.
//
// On failure it returns the zero cusvHandle and an error carrying the numeric
// status code. If the stream cannot be created, the cuStateVec handle that was
// already created is destroyed first so it does not leak.
//
// NOTE(review): this function does not attach the stream to the handle (there
// is no custatevecSetStream call here) — confirm that happens elsewhere.
func createHandle() (cusvHandle, error) {
	var h cusvHandle
	if st := C.goCreateHandle(&h.h); st != C.CUSTATEVEC_STATUS_SUCCESS {
		return cusvHandle{}, fmt.Errorf("custatevecCreate failed: status %d", int(st))
	}
	if st := C.cudaStreamCreate(&h.stream); st != C.cudaSuccess {
		// Best-effort cleanup; the stream failure is the error worth reporting.
		C.goDestroyHandle(h.h)
		// Return the zero value rather than h, which would still carry the
		// handle that was just destroyed.
		return cusvHandle{}, fmt.Errorf("cudaStreamCreate failed: status %d", int(st))
	}
	return h, nil
}

// destroyHandle releases what h holds: the CUDA stream first, then the
// cuStateVec handle. The status codes of both calls are discarded.
func destroyHandle(h cusvHandle) {
	stream, handle := h.stream, h.h
	_ = C.cudaStreamDestroy(stream)
	_ = C.goDestroyHandle(handle)
}
Loading
Loading