In [1]:
from __future__ import annotations

from collections import defaultdict
from enum import Enum
from typing import Annotated, ClassVar, Generic, Literal, Type, TypeVar, Union

from pydantic import BaseModel, Field, model_validator

# ------------------------------------------------------------------
# 1.  PORT SHAPES
# ------------------------------------------------------------------


class NoPort(BaseModel):
    # Connector shape that carries no addressing information beyond its tag.
    kind: Literal["none"] = "none"


class IntPort(BaseModel):
    # Connector addressed by a single integer.
    kind: Literal["int"] = "int"
    port: int  # 0-based output or input number


class ModulePort(BaseModel):
    # Connector addressed by a (module slot, channel) pair.
    kind: Literal["module"] = "module"
    module: int  # slot in a crate
    port: int  # channel inside that slot


# Discriminated union on ``kind`` => pydantic selects the matching port model
# when validating, and the tag is part of every dumped port (automatic JSON
# tagging & validation).
Port = Annotated[Union[NoPort, IntPort, ModulePort], Field(discriminator="kind")]

# ------------------------------------------------------------------
# 2.  INSTRUMENTS
# ------------------------------------------------------------------

# Type parameter for a port-shape model (e.g. NoPort / IntPort / ModulePort);
# used by Instrument (its declared Port class) and by Association.
P = TypeVar("P", bound=BaseModel)


class InstrumentKind(str, Enum):
    """
    Roles an instrument can play on a bus.

    Instruments advertise their roles through the ``kinds`` class attribute;
    bus implementations map each association field to the role it requires
    via ``required_kinds``.  Members are ``str`` subclasses, so they compare
    equal to their plain-string values.
    """

    RF_AWG = "rf_awg"
    LF_AWG = "lf_awg"
    UP_CONVERTER = "upconverter"
    MIXER = "mixer"
    CURRENT_GEN = "current_gen"
    ADC = "adc"
    VNA = "vna"
    OSCILLOSCOPE = "oscilloscope"
    FRIDGE = "fridge"


class ImplTag(str, Enum):
    """
    Identifiers of concrete bus implementations.

    A BusImplementation subclass names its own tag in ``kind``; an
    instrument lists every tag it may take part in via ``impl_tags``.
    """

    # Pulsed-drive buses
    PDRIVE_RF_AWG_SINGLE = "pdrive_rf_awg_single"  # one RF-AWG connector
    PDRIVE_LF_AWG_UPCONV = "pdrive_lf_awg_upconv"  # LF-AWG I/Q + up-converter
    PDRIVE_LF_AWG_UPCONV_MIXER = "pdrive_lf_awg_upconv_mixer"  # not used by any implementation shown here

    # DC-flux buses
    DCFLUX_SINGLE_SRC = "dcflux_single"  # one current source

    # Extend with new tags as new implementations are added.


class Instrument(BaseModel, Generic[P]):
    """
    Generic instrument.

    Sub-classes MUST set:
      * Port       – the shape of their connectors (a port model class;
                     Association checks ports with isinstance() against it)
      * kinds      – tuple of one or more InstrumentKind roles they can play
      * impl_tags  – tuple of ImplTag identifiers of the bus-implementations
                     they are legal in; a single tag still needs the
                     trailing comma: ``(ImplTag.X,)``

    Optionally they override validate_port(..) for value-range checks and
    setup(..) to configure the hardware for a given implementation.
    """

    name: str

    Port: ClassVar[Type[P]]
    kinds: ClassVar[tuple[InstrumentKind, ...]]
    impl_tags: ClassVar[tuple[ImplTag, ...]]

    # Identity hash, so instruments can key the per-instrument grouping in
    # BusImplementation.setup().  NOTE: equality is still pydantic's
    # field-based BaseModel.__eq__, so two distinct instruments with equal
    # fields compare equal while hashing differently.
    __hash__ = object.__hash__

    @classmethod
    def validate_port(cls, port: P) -> P:  # pragma: no cover
        """
        Check instrument-specific limits on ``port``.

        Called by Association after the port's shape has been checked
        (Association ignores the returned value).  Override to raise
        ValueError for out-of-range values; the default accepts any port.
        """
        return port

    def setup(self, ctx: InstrumentContext) -> None:
        """
        Override in subclasses.

        The default implementation just prints, so nothing breaks if
        you forget to implement the hook on a new instrument.

        Args:
            ctx: the implementation tag plus only this instrument's
                associations, keyed by their field name in the implementation.
        """
        # Interpolate the tag's *value*: from Python 3.11 on, formatting a
        # (str, Enum) member in an f-string yields "ImplTag.NAME" instead,
        # which would change this dry-run output between Python versions.
        print(f"[DRY-RUN] {self.name}: would set up for {ctx.impl_tag.value} with {list(ctx.associations)}")


# --- some example instruments --------------------------------------


class QbloxCluster(Instrument[IntPort]):
    # Cluster usable in the RF-AWG or LF-AWG role; connectors are output
    # numbers 0-7 (enforced by validate_port).
    Port = IntPort
    kinds = (InstrumentKind.RF_AWG, InstrumentKind.LF_AWG)
    impl_tags = (
        ImplTag.PDRIVE_RF_AWG_SINGLE,
        ImplTag.PDRIVE_LF_AWG_UPCONV,
    )

    @classmethod
    def validate_port(cls, port: IntPort) -> IntPort:
        """Accept only output numbers 0-7; raise ValueError otherwise."""
        if not 0 <= port.port <= 7:
            raise ValueError("Cluster has 8 outputs (0-7)")
        return port

    def setup(self, ctx: InstrumentContext) -> None:
        """
        Configure this cluster for the implementation named by ``ctx.impl_tag``.

        The base class prints a dry-run line first; the actual hardware calls
        are still commented-out placeholders.

        Raises:
            KeyError: for PDRIVE_RF_AWG_SINGLE when ``ctx`` has no "rf" entry.
            NotImplementedError: for any implementation tag not handled here.
        """
        super().setup(ctx)
        if ctx.impl_tag is ImplTag.PDRIVE_RF_AWG_SINGLE:
            # only one association, key "rf"
            rf_port = ctx.associations["rf"].port  # IntPort
            # self.device.set_rf_awg(channel=rf_port.port)         # placeholder call

        elif ctx.impl_tag is ImplTag.PDRIVE_LF_AWG_UPCONV:
            # PulsedDrive_LF_AWG_UpConv does not require "i" and "q" to sit on
            # the same cluster, so this context may hold only one of them;
            # indexing with [...] would raise KeyError for the missing one.
            i_assoc = ctx.associations.get("i")
            q_assoc = ctx.associations.get("q")
            i_port = i_assoc.port if i_assoc is not None else None  # IntPort | None
            q_port = q_assoc.port if q_assoc is not None else None  # IntPort | None
            # two LF-AWG channels drive I/Q (either may live on another cluster)
            # self.device.set_lf_awg_iq(i_channel=i_port and i_port.port,
            #                           q_channel=q_port and q_port.port)  # placeholder call

        else:
            raise NotImplementedError(f"{self.__class__.__name__} cannot handle {ctx.impl_tag.value}")


class ZurichUpConverter(Instrument[IntPort]):
    # Up-converter: ports 0 and 1 take I/Q, port 2 is the RF output.
    Port = IntPort
    kinds = (InstrumentKind.UP_CONVERTER,)
    # Must be a tuple (note the trailing comma): a bare ImplTag is a str, so
    # ``tag in impl_tags`` would silently become a substring test.
    impl_tags = (ImplTag.PDRIVE_LF_AWG_UPCONV,)

    @classmethod
    def validate_port(cls, port: IntPort) -> IntPort:
        """Accept only ports 0, 1 (I/Q) and 2 (RF out); raise ValueError otherwise."""
        if port.port not in (0, 1, 2):
            raise ValueError("Up-converter ports are 0, 1 (I/Q) and 2 (RF out)")
        return port


class DeltaCurrentGen(Instrument[NoPort]):
    # Current source for DC-flux buses; its connector needs no address
    # (NoPort), and validate_port is not overridden, so any NoPort is accepted.
    Port = NoPort
    kinds = (InstrumentKind.CURRENT_GEN,)
    impl_tags = (ImplTag.DCFLUX_SINGLE_SRC,)


# ------------------------------------------------------------------
# 4.  ASSOCIATION  (instrument + one connector)
# ------------------------------------------------------------------


I = TypeVar("I", bound=Instrument)  # concrete instrument type held by an Association


class Association(BaseModel, Generic[I, P]):
    """
    One connector of one instrument: ``instrument`` plus the ``port`` on it.

    After field validation the port is checked against the instrument: its
    model class must be the instrument's declared ``Port`` shape, and the
    instrument's ``validate_port`` hook must accept its values.
    """

    instrument: I
    port: P

    @model_validator(mode="after")
    def _shape_and_value_checks(self):
        owner = self.instrument
        port_shape = owner.Port
        if isinstance(self.port, port_shape):
            # Shape is right; now apply the instrument's own range rules
            # (raises ValueError on a bad value, result is not used).
            owner.validate_port(self.port)
            return self
        raise ValueError(f"Port must be {port_shape.__name__}")


# Shortcut when we do not care about the generics in annotations: any
# Instrument, with the port validated against the discriminated ``Port`` union.
AnyAssoc = Association[Instrument, Port]

# ------------------------------------------------------------------
# 5.  BUS IMPLEMENTATIONS
# ------------------------------------------------------------------


def _as_member_tuple(declared) -> tuple:
    """
    Normalise an instrument's ``kinds`` / ``impl_tags`` declaration to a tuple.

    Both are meant to be tuples, but a single member written without the
    trailing comma (``impl_tags = ImplTag.X``) is a bare ``(str, Enum)``
    member.  ``x in member`` then means ``str.__contains__`` – a substring
    match (e.g. "pdrive_lf_awg_upconv" is "in" "pdrive_lf_awg_upconv_mixer").
    Such scalars are wrapped in a 1-tuple so membership is exact.
    """
    if isinstance(declared, (str, Enum)):
        return (declared,)
    return tuple(declared)


class BusImplementation(BaseModel):
    """
    A *concrete* way to realise a logical bus.

    Sub-classes set:
      * kind           – a short id, referenced by instruments.impl_tags
      * required_kinds – mapping {field_name: InstrumentKind}

    Each field named in required_kinds must be an Association and satisfy:
      1) the instrument kind matches, AND
      2) that instrument declares compatibility via impl_tags.
    A ValueError is raised otherwise.
    """

    kind: ClassVar[ImplTag]
    required_kinds: ClassVar[dict[str, InstrumentKind]]

    @model_validator(mode="after")
    def _check_roles(self):
        for fld, kind in self.required_kinds.items():
            assoc: AnyAssoc = getattr(self, fld)
            instr = assoc.instrument
            instr_kinds = _as_member_tuple(instr.kinds)

            # instrument must be of the right *role*
            if kind not in instr_kinds:
                raise ValueError(f"Field “{fld}” expects a {kind.value}, got {instr_kinds}")

            # instrument must allow this *implementation*
            if self.kind not in _as_member_tuple(instr.impl_tags):
                # .value keeps the message identical across Python versions
                raise ValueError(f"{instr.name} is not allowed in implementation {self.kind.value}")
        return self

    def setup(self) -> None:
        """
        Walk through all Association fields, group them per instrument,
        create an InstrumentContext for each, and call its setup().

        Instruments are visited in the order they first appear in
        ``required_kinds``; each one's setup() is called exactly once.
        """
        # Keyed by the instrument object itself (Instrument uses identity hash).
        per_instr: dict[Instrument, dict[str, AnyAssoc]] = defaultdict(dict)

        # 1) gather every Association that is part of this impl
        for field_name in self.required_kinds:
            assoc: AnyAssoc = getattr(self, field_name)
            per_instr[assoc.instrument][field_name] = assoc

        # 2) call setup once per instrument
        for instr, mapping in per_instr.items():
            ctx = InstrumentContext(
                impl_tag=self.kind,  # ← an ImplTag, not a str
                associations=mapping,
            )
            instr.setup(ctx)


class InstrumentContext(BaseModel):
    """
    What an instrument needs to know in order to set itself up.

    Built by BusImplementation.setup() once per instrument and passed to
    that instrument's setup().

    * impl_tag     – which wiring-scheme (ImplTag) we are in
    * associations – mapping {"field-name-in-impl": Association}
                     containing *only* the connectors that belong
                     to *this* instrument inside the implementation
    """

    impl_tag: ImplTag
    associations: dict[str, AnyAssoc]


# --- Pulsed-drive IMPLEMENTATION A  (single RF-AWG) ----------------


class PulsedDrive_RF_AWG(BusImplementation):
    # Pulsed drive from a single connector on an RF-AWG instrument.
    kind = ImplTag.PDRIVE_RF_AWG_SINGLE
    required_kinds = {"rf": InstrumentKind.RF_AWG}

    rf: AnyAssoc  # must be on an instrument with the RF_AWG role


# --- Pulsed-drive IMPLEMENTATION B  (LF I/Q + Up-converter) --------


class PulsedDrive_LF_AWG_UpConv(BusImplementation):
    # Pulsed drive from two LF-AWG connectors (I and Q) wired into an
    # up-converter whose RF output is ``up_out``.
    kind = ImplTag.PDRIVE_LF_AWG_UPCONV
    required_kinds = {
        "i": InstrumentKind.LF_AWG,
        "q": InstrumentKind.LF_AWG,
        "up_i": InstrumentKind.UP_CONVERTER,
        "up_q": InstrumentKind.UP_CONVERTER,
        "up_out": InstrumentKind.UP_CONVERTER,
    }

    i: AnyAssoc  # LF-AWG connector carrying I
    q: AnyAssoc  # LF-AWG connector carrying Q (may be on another AWG)
    up_i: AnyAssoc  # up-converter input for I
    up_q: AnyAssoc  # up-converter input for Q
    up_out: AnyAssoc  # up-converter RF output

    # extra rule: the three up-converter ports must belong to the *same* unit
    @model_validator(mode="after")
    def _same_upconverter(self):
        """Reject wiring that spreads I, Q and RF-out over different units."""
        if not (self.up_i.instrument is self.up_q.instrument is self.up_out.instrument):
            raise ValueError("I, Q and RF-out must be on the same up-converter")
        return self

    # extra rule: no physical connector may be used for two roles
    @model_validator(mode="after")
    def _distinct_ports(self):
        """
        Reject wiring that reuses one connector for two roles.

        A connector is "the same" when both associations point at the same
        instrument object and their port models compare equal.
        """
        up_assocs = (self.up_i, self.up_q, self.up_out)
        if any(
            first.instrument is second.instrument and first.port == second.port
            for idx, first in enumerate(up_assocs)
            for second in up_assocs[idx + 1 :]
        ):
            raise ValueError("I, Q and RF-out must use distinct up-converter ports")
        if self.i.instrument is self.q.instrument and self.i.port == self.q.port:
            raise ValueError("I and Q must use distinct AWG ports")
        return self


# --- DC-Flux IMPLEMENTATION  (single current source) ---------------


class DCFlux_Single(BusImplementation):
    # DC flux from a single current-source connector.
    kind = ImplTag.DCFLUX_SINGLE_SRC
    required_kinds = {"src": InstrumentKind.CURRENT_GEN}

    src: AnyAssoc  # must be on an instrument with the CURRENT_GEN role


# ------------------------------------------------------------------
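# 3.  BUS HIERARCHY
#     (placed after section 5 because the concrete buses reference the
#      implementation classes when their class bodies are executed)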
# ------------------------------------------------------------------


# Type parameter for the implementation type a Bus holds in ``impl``.
BI = TypeVar("BI", bound="BusImplementation")


class Bus(BaseModel, Generic[BI]):
    """
    A named logical bus realised by one BusImplementation.

    Concrete subclasses parametrise the generic with the implementation
    type(s) they accept and list the same classes in ``AllowedImpls``;
    the list is re-checked at runtime after field validation.  Buses that
    declare no ``AllowedImpls`` (the abstract categories) cannot be built.
    """

    name: str
    impl: BI

    # list of classes, e.g. (PulsedDrive_RF_AWG, PulsedDrive_LF_AWG_UpConv)
    AllowedImpls: ClassVar[tuple[type[BusImplementation], ...]]

    @model_validator(mode="after")
    def _check_impl_class(self):
        # NOTE: pydantic v2 does not wrap a TypeError raised in a validator
        # into a ValidationError – callers see the TypeError itself.
        bus_cls = type(self)
        allowed_impls = getattr(bus_cls, "AllowedImpls", None)
        if allowed_impls is None:
            # Category buses (DriveBus, ReadoutBus, ...) only annotate the
            # ClassVar; fail clearly instead of with an AttributeError.
            raise TypeError(f"{bus_cls.__name__} declares no AllowedImpls and cannot be instantiated directly")
        if not isinstance(self.impl, allowed_impls):
            allowed = ", ".join(c.__name__ for c in allowed_impls)
            raise TypeError(f"{bus_cls.__name__} accepts only: {allowed}")
        return self

    # expose the hardware-setup entry point
    def setup(self) -> None:
        """Delegate to the implementation's setup()."""
        self.impl.setup()


# Top-level categories.  They stay generic over the implementation type so a
# concrete bus can both narrow ``impl`` *and* be an instance of its category
# (e.g. isinstance(drive_bus, DriveBus) is True for a PulsedDriveBus).
class DriveBus(Bus[BI], Generic[BI]):
    """Category: drive buses."""


class ReadoutBus(Bus[BI], Generic[BI]):
    """Category: readout buses."""


class FluxBus(Bus[BI], Generic[BI]):
    """Category: flux buses."""


# Concrete buses
class PulsedDriveBus(DriveBus[Union[PulsedDrive_RF_AWG, PulsedDrive_LF_AWG_UpConv]]):
    AllowedImpls = (PulsedDrive_RF_AWG, PulsedDrive_LF_AWG_UpConv)


class ContinuousDriveBus(DriveBus):
    pass  # placeholder: no AllowedImpls declared yet


class ADCReadoutBus(ReadoutBus):
    pass  # placeholder: no AllowedImpls declared yet


class VNAReadoutBus(ReadoutBus):
    pass  # placeholder: no AllowedImpls declared yet


class OscilloscopeReadoutBus(ReadoutBus):
    pass  # placeholder: no AllowedImpls declared yet


class FridgeReadoutBus(ReadoutBus):
    pass  # placeholder: no AllowedImpls declared yet


class ACFluxBus(FluxBus):
    pass  # placeholder: no AllowedImpls declared yet


class DCFluxBus(FluxBus[DCFlux_Single]):
    AllowedImpls = (DCFlux_Single,)

In [2]:
# ------------------------------------------------------------------
# 6.  EXAMPLE USAGE
# ------------------------------------------------------------------

# ----------------------------------------------------------------
# create some hardware
# ----------------------------------------------------------------
awg0 = QbloxCluster(name="cluster-0")
awg1 = QbloxCluster(name="cluster-1")  # not wired into any bus below
upc = ZurichUpConverter(name="up-0")
dc = DeltaCurrentGen(name="delta-0")

# logical buses for one qubit
# --- build drive bus with a LF-AWG + up-converter implementation ------
drive_bus = PulsedDriveBus(
    name="drive-q0",
    impl=PulsedDrive_LF_AWG_UpConv(
        i=AnyAssoc(instrument=awg0, port=IntPort(port=0)),
        q=AnyAssoc(instrument=awg0, port=IntPort(port=1)),
        up_i=AnyAssoc(instrument=upc, port=IntPort(port=0)),
        up_q=AnyAssoc(instrument=upc, port=IntPort(port=1)),
        up_out=AnyAssoc(instrument=upc, port=IntPort(port=2)),
    ),
)

# --- build flux bus with its single-source implementation -------------
flux_bus = DCFluxBus(
    name="flux-q0",
    impl=DCFlux_Single(
        src=AnyAssoc(instrument=dc, port=NoPort()),
    ),
)

# Sanity-check the wiring: each bus keeps the implementation and the very
# instrument objects it was built with (validation does not copy them).
assert isinstance(drive_bus.impl, PulsedDrive_LF_AWG_UpConv)
assert drive_bus.impl.up_out.instrument is upc
assert isinstance(flux_bus.impl, DCFlux_Single)
assert flux_bus.impl.src.instrument is dc

In [3]:
# Run the hardware setup for each bus: a bus delegates to its implementation,
# which calls setup() once per instrument involved (currently a dry run that
# only prints what would be configured).
drive_bus.setup()
flux_bus.setup()

[DRY-RUN] cluster-0: would set up for pdrive_lf_awg_upconv with ['i', 'q']
[DRY-RUN] up-0: would set up for pdrive_lf_awg_upconv with ['up_i', 'up_q', 'up_out']
[DRY-RUN] delta-0: would set up for dcflux_single with ['src']
