diff --git a/qasm/parser/parser.go b/qasm/parser/parser.go index 34a1c88..3359753 100644 --- a/qasm/parser/parser.go +++ b/qasm/parser/parser.go @@ -68,6 +68,9 @@ type parser struct { // Gate definitions: name → gatedef. gates map[string]*gatedef + + // Parameter substitution for gate body expansion. + paramValues map[string]float64 } type gatedef struct { @@ -375,6 +378,61 @@ func (p *parser) parseGateDecl() error { return nil } +// expandGateCall inlines a user-defined gate by re-parsing its body tokens +// with qubit and parameter substitutions. +func (p *parser) expandGateCall(name string, gd *gatedef, params []float64, qubits []int) error { + if len(params) != len(gd.params) { + return fmt.Errorf("gate %s requires %d parameters, got %d", name, len(gd.params), len(params)) + } + if len(qubits) != len(gd.qubits) { + return fmt.Errorf("gate %s requires %d qubit arguments, got %d", name, len(gd.qubits), len(qubits)) + } + + // Empty body - identity gate, nothing to emit. + if len(gd.body) == 0 { + return nil + } + + // Build sub-parser to expand body tokens. + bodyTokens := make([]token.Token, len(gd.body)+1) + copy(bodyTokens, gd.body) + bodyTokens[len(gd.body)] = token.Token{Type: token.EOF} + + sub := &parser{ + tokens: bodyTokens, + cfg: p.cfg, + gates: p.gates, + qregs: make(map[string]register), + cregs: p.cregs, + numQubits: p.numQubits, + numClbits: p.numClbits, + metadata: make(map[string]string), + } + + // Map gate qubit parameter names to actual qubit indices. + for i, qname := range gd.qubits { + sub.qregs[qname] = register{start: qubits[i], size: 1} + } + + // Map gate parameter names to actual values. + if len(gd.params) > 0 { + sub.paramValues = make(map[string]float64, len(gd.params)) + for i, pname := range gd.params { + sub.paramValues[pname] = params[i] + } + } + + // Parse the body. + for sub.peek() != token.EOF { + if err := sub.parseStatement(); err != nil { + return fmt.Errorf("in gate %s: %w", name, err) + } + } + + p.ops = append(p.ops, sub.ops...) + return nil +} + func (p *parser) parseMeasure() error { t := p.advance() // consume 'measure' qubits, err := p.parseQubitArgs() @@ -991,6 +1049,14 @@ func (p *parser) parseGateCall() error { return err } + // Expand user-defined gate bodies inline when no modifiers are applied. + totalModControls := ctrlCount + negctrlCount + if totalModControls == 0 && invCount == 0 && !hasPow { + if gd, ok := p.gates[gateName]; ok { + return p.expandGateCall(gateName, gd, params, qubits) + } + } + g, err := p.resolveGate(gateName, params) if err != nil { if p.cfg.strict { @@ -1387,6 +1453,13 @@ func (p *parser) parsePrimary() (float64, error) { _, err = p.expect(token.RPAREN) return v, err case token.IDENT: + // Check gate parameter substitution before function dispatch. + if p.paramValues != nil { + if v, ok := p.paramValues[t.Literal]; ok { + p.advance() + return v, nil + } + } // Built-in functions: sin, cos, tan, sqrt, exp, log, arccos, etc. p.advance() fname := t.Literal diff --git a/qasm/parser/parser_test.go b/qasm/parser/parser_test.go index 30f9689..f590f89 100644 --- a/qasm/parser/parser_test.go +++ b/qasm/parser/parser_test.go @@ -140,12 +140,15 @@ c = measure q; if err != nil { t.Fatal(err) } - // mygate (opaque) + 2 measurements = 3 ops - if len(c.Ops()) != 3 { - t.Errorf("len(Ops) = %d, want 3", len(c.Ops())) + // mygate expands to h + cx, plus 2 measurements = 4 ops + if len(c.Ops()) != 4 { + t.Errorf("len(Ops) = %d, want 4", len(c.Ops())) } - if c.Ops()[0].Gate.Name() != "mygate" { - t.Errorf("Ops[0].Gate.Name() = %q, want mygate", c.Ops()[0].Gate.Name()) + if c.Ops()[0].Gate.Name() != "H" { + t.Errorf("Ops[0].Gate.Name() = %q, want H", c.Ops()[0].Gate.Name()) + } + if c.Ops()[1].Gate.Name() != "CNOT" { + t.Errorf("Ops[1].Gate.Name() = %q, want CNOT", c.Ops()[1].Gate.Name()) } } diff --git a/sim/statevector/dynamic.go b/sim/statevector/dynamic.go index 0afe2d0..ffeffa0 100644 --- a/sim/statevector/dynamic.go +++ b/sim/statevector/dynamic.go @@ -155,7 +155,11 @@ func readClassicalValue(clbits []int, indices []int) int { func (s *Sim) applyOp(op ir.Operation) error { switch op.Gate.Qubits() { case 1: - s.applyGate1(op.Qubits[0], op.Gate.Matrix()) + m := op.Gate.Matrix() + if m == nil { + return fmt.Errorf("gate %q has no matrix representation", op.Gate.Name()) + } + s.applyGate1(op.Qubits[0], m) case 2: s.dispatchGate2(op.Gate, op.Qubits[0], op.Qubits[1]) case 3: diff --git a/sim/statevector/kernel2q.go b/sim/statevector/kernel2q.go index d57e130..545810a 100644 --- a/sim/statevector/kernel2q.go +++ b/sim/statevector/kernel2q.go @@ -103,6 +103,9 @@ func (s *Sim) dispatchGate2(g gate.Gate, q0, q1 int) { // Generic fallback. m := g.Matrix() + if m == nil { + return // opaque gate with no matrix - treat as identity + } if parallel { s.kernel2qGenericParallel(q0, q1, m) } else { diff --git a/sim/statevector/kernel3q.go b/sim/statevector/kernel3q.go index 44afb91..1906da7 100644 --- a/sim/statevector/kernel3q.go +++ b/sim/statevector/kernel3q.go @@ -36,6 +36,9 @@ func (s *Sim) dispatchGate3(g gate.Gate, q0, q1, q2 int) { // Generic fallback. m := g.Matrix() + if m == nil { + return // opaque gate with no matrix - treat as identity + } if parallel { s.kernel3qGenericParallel(q0, q1, q2, m) } else { diff --git a/sim/statevector/kernel_controlled.go b/sim/statevector/kernel_controlled.go index ce23735..ef62149 100644 --- a/sim/statevector/kernel_controlled.go +++ b/sim/statevector/kernel_controlled.go @@ -16,6 +16,9 @@ func (s *Sim) dispatchControlled(cg gate.ControlledGate, qubits []int) { switch cg.Inner().Qubits() { case 1: m := cg.Inner().Matrix() + if m == nil { + return // opaque gate with no matrix - treat as identity + } if s.numQubits >= 17 { s.applyControlledGate1Parallel(controls, targets[0], m) } else { @@ -23,6 +26,9 @@ func (s *Sim) dispatchControlled(cg gate.ControlledGate, qubits []int) { } case 2: m := cg.Inner().Matrix() + if m == nil { + return // opaque gate with no matrix - treat as identity + } if s.numQubits >= 17 { s.applyControlledGate2Parallel(controls, targets[0], targets[1], m) } else { diff --git a/sim/statevector/sim.go b/sim/statevector/sim.go index 03c438c..bbcb3f3 100644 --- a/sim/statevector/sim.go +++ b/sim/statevector/sim.go @@ -136,7 +136,11 @@ func (s *Sim) Apply(c *ir.Circuit) error { } switch op.Gate.Qubits() { case 1: - s.applyGate1(op.Qubits[0], op.Gate.Matrix()) + m := op.Gate.Matrix() + if m == nil { + return fmt.Errorf("gate %q has no matrix representation", op.Gate.Name()) + } + s.applyGate1(op.Qubits[0], m) case 2: s.dispatchGate2(op.Gate, op.Qubits[0], op.Qubits[1]) case 3: