In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from qiskit import QuantumCircuit, transpile, Aer, assemble
import pennylane as qml

class MultiHeadAttentionBase(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 dropout: float = 0.1,
                 mask=None,
                 use_bias=False):
        super(MultiHeadAttentionBase, self).__init__()

        assert embed_dim % num_heads == 0, f"Embedding dimension ({embed_dim}) should be divisible by number of heads ({num_heads})"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.d_k = embed_dim // num_heads  # projection dimensions
        self.k_linear = None
        self.q_linear = None
        self.v_linear = None
        self.combine_heads = None
        self.dropout = nn.Dropout(dropout)
        self.attn_weights = None
    
    def separate_heads(self, x):
        '''
        split into N heads
        from (batch_size, seq_len, embed_dim)
        to   (batch_size, seq_len, num_heads, embed_dim)
        then transpose (1,2) to (batch_size, num_heads, seq_len, embed_dim)
        to make mat mult straightforward for each head
        '''
        batch_size = x.size(0)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def attention(self, query, key, value, mask=None, dropout=None):
        '''
        Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k))V
        '''
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        # see also: https://tensorchiefs.github.io/dlday2018/tutorial/einsum.html
        #scores = torch.einsum('bijh, bkjh -> bikh', query, key) / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        attn = torch.matmul(scores, value)
        return attn, scores
    
    def downstream(self, query, key, value, batch_size, mask=None):
        Q = self.separate_heads(query)
        K = self.separate_heads(key)
        V = self.separate_heads(value)

        x, self.attn_weights = self.attention(Q, K, V, mask, dropout=self.dropout)

        concat = x.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)

        return concat
        # output = self.combine_heads(concat)
        # return output

    def forward(self, x, mask=None):
        raise NotImplementedError("Base class does not execute forward function.")

class MultiHeadAttentionClassical(MultiHeadAttentionBase):
    def __init__(self, 
                 embed_dim: int,
                 num_heads: int,
                 dropout=0.1,
                 mask=None,
                 use_bias=False):
        super(MultiHeadAttentionClassical, self).__init__(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, mask=mask, use_bias=use_bias)

        self.k_linear = nn.Linear(embed_dim, embed_dim, bias=use_bias)
        self.q_linear = nn.Linear(embed_dim, embed_dim, bias=use_bias)
        self.v_linear = nn.Linear(embed_dim, embed_dim, bias=use_bias)
        self.combine_heads = nn.Linear(embed_dim, embed_dim, bias=use_bias)
        self.head_dim = embed_dim // num_heads
    
    def forward(self, x, mask=None):
        batch_size, seq_len, embed_dim = x.size()
        assert embed_dim == self.embed_dim, f"Input embedding ({embed_dim}) does not match layer embedding size ({self.embed_dim})"

        K = self.k_linear(x)
        Q = self.q_linear(x)
        V = self.v_linear(x)

        x = self.downstream(Q, K, V, batch_size, mask)
        output = self.combine_heads(x)
        return output

class MultiHeadAttentionQuantum(MultiHeadAttentionBase):
    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 dropout=0.1,
                 mask=None,
                 use_bias=False,
                 n_qubits: int = 16,
                 n_qlayers: int = 1,
                 q_device="default.qubit"):
        super(MultiHeadAttentionQuantum, self).__init__(embed_dim, num_heads, dropout=dropout, mask=mask, use_bias=use_bias)
        
        # todo: add intermediate layer to "dress" quantum circuit
        assert n_qubits == embed_dim, "Number of qubits ({n_qubits}) does not match embedding dim ({embed_dim})"

        self.n_qubits = n_qubits
        self.n_qlayers = n_qlayers
        self.q_device = q_device
        self.head_dim = embed_dim // num_heads
        if 'qulacs' in q_device:
            self.dev = qml.device(q_device, wires=self.n_qubits, gpu=True)
        elif 'braket' in q_device:
            self.dev = qml.device(q_device, wires=self.n_qubits, parallel=True)
        else:
            self.dev = qml.device(q_device, wires=self.n_qubits)

        def _circuit(inputs, weights):
            qml.templates.AngleEmbedding(inputs, wires=range(self.n_qubits))
            qml.templates.BasicEntanglerLayers(weights, wires=range(n_qubits))
            return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

        self.qlayer = qml.QNode(_circuit, self.dev, interface="torch")
        self.weight_shapes = {"weights": (n_qlayers, n_qubits)}
        print(f"weight_shapes = (n_qlayers, n_qubits) = ({n_qlayers}, {self.n_qubits})")

        self.k_linear = qml.qnn.TorchLayer(self.qlayer, self.weight_shapes)
        self.q_linear = qml.qnn.TorchLayer(self.qlayer, self.weight_shapes)
        self.v_linear = qml.qnn.TorchLayer(self.qlayer, self.weight_shapes)
        self.combine_heads = qml.qnn.TorchLayer(self.qlayer, self.weight_shapes)

    def forward(self, x, mask=None):
        batch_size, seq_len, embed_dim = x.size()
        assert embed_dim == self.embed_dim, f"Input embedding ({embed_dim}) does not match layer embedding size ({self.embed_dim})"

        K = [self.k_linear(x[:, t, :]) for t in range(seq_len)]
        Q = [self.q_linear(x[:, t, :]) for t in range(seq_len)]
        V = [self.v_linear(x[:, t, :]) for t in range(seq_len)]

        K = torch.Tensor(pad_sequence(K))
        Q = torch.Tensor(pad_sequence(Q))
        V = torch.Tensor(pad_sequence(V))

        x = self.downstream(Q, K, V, batch_size, mask)
        output = [self.combine_heads(x[:, t, :]) for t in range(seq_len)]
        output = torch.Tensor(pad_sequence(output))
        return output

class FeedForwardBase(nn.Module):
    def __init__(self, embed_dim, ffn_dim, dropout=0.1):
        super(FeedForwardBase, self).__init__()
        self.linear_1 = nn.Linear(embed_dim, ffn_dim)
        self.linear_2 = nn.Linear(ffn_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        raise NotImplementedError("Base class does not implement forward function")


class FeedForwardClassical(FeedForwardBase):
    def __init__(self, embed_dim, ffn_dim, dropout=0.1):
        super(FeedForwardClassical, self).__init__(embed_dim, ffn_dim, dropout)

    def forward(self, x):
        x = F.relu(self.linear_1(x))
        x = self.dropout(x)
        x = self.linear_2(x)
        return x


class FeedForwardQuantum(FeedForwardBase):
    def __init__(self, embed_dim, n_qubits, n_qlayers=1, dropout=0.1, q_device="default.qubit"):
        super(FeedForwardQuantum, self).__init__(embed_dim, ffn_dim=n_qubits, dropout=dropout)

        self.n_qubits = n_qubits
        if 'qulacs' in q_device:
            self.dev = qml.device(q_device, wires=self.n_qubits, gpu=True)
        elif 'braket' in q_device:
            self.dev = qml.device(q_device, wires=self.n_qubits, parallel=True)
        else:
            self.dev = qml.device(q_device, wires=self.n_qubits)

        def _circuit(inputs, weights):
            qml.templates.AngleEmbedding(inputs, wires=range(self.n_qubits))
            qml.templates.BasicEntanglerLayers(weights, wires=range(n_qubits))
            return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]
        self.qlayer = qml.QNode(_circuit, self.dev, interface="torch")
        self.weight_shapes = {"weights": (n_qlayers, n_qubits)}
        self.vqc = qml.qnn.TorchLayer(self.qlayer, self.weight_shapes)
    
    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        x = self.linear_1(x)
        X = [self.vqc(x[:, t, :]) for t in range(seq_len)]
        x = torch.Tensor(pad_sequence(X))
        # dropout?
        x = self.linear_2(x)
        return x


class TransformerBlockBase(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 num_head: int,
                 ff_dim: int,
                 n_qubits_transformer: int = 0,
                 n_qubits_ffn: int = 0,
                 n_qlayers: int = 1,
                 dropout: float = 0.1,
                 mask=None):
        super(TransformerBlockBase, self).__init__()
        self.attn = None
        self.ffn = None
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_output = self.attn(x)
        x = self.norm1(attn_output + x)
        x = self.dropout1(x)

        ff_output = self.ffn(x)
        x = self.norm2(ff_output + x)
        x = self.dropout2(x)

        return x


class TransformerBlockClassical(TransformerBlockBase):
    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 ff_dim: int,
                 dropout: float = 0.1,
                 mask=None):
        super(TransformerBlockClassical, self).__init__(embed_dim, num_heads, ff_dim, dropout, mask)
        self.attn = MultiHeadAttentionClassical(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, mask=mask)
        self.ffn = FeedForwardClassical(embed_dim, ff_dim)


class TransformerBlockQuantum(TransformerBlockBase):
    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 ffn_dim: int,
                 n_qubits_transformer: int = 0,
                 n_qubits_ffn: int = 0,
                 n_qlayers: int = 1,
                 dropout: float = 0.1,
                 mask=None,
                 q_device='default.qubit'):
        super(TransformerBlockQuantum, self).__init__(embed_dim, num_heads, ffn_dim, dropout, mask)
        
        self.n_qubits_transformer = n_qubits_transformer
        self.n_qubits_ffn = n_qubits_ffn
        self.n_qlayers = n_qlayers

        self.attn = MultiHeadAttentionQuantum(embed_dim,
                                              num_heads,
                                              n_qubits=n_qubits_transformer,
                                              n_qlayers=n_qlayers,
                                              dropout=dropout,
                                              mask=mask,
                                              q_device=q_device)
        if n_qubits_ffn > 0:
            self.ffn = FeedForwardQuantum(embed_dim, n_qubits_ffn, n_qlayers, q_device=q_device)
        else:
            self.ffn = FeedForwardClassical(embed_dim, ffn_dim)


class PositionalEncoder(nn.Module):
    def __init__(self, embed_dim, max_seq_le
            transformer_blocks = [
                TransformerBlockClassical(embed_dim, num_heads, ffn_dim) for _ in range(num_blocks)
            ]

        self.transformers = nn.Sequential(*transformer_blocks)
        if self.num_classes > 2:
            self.class_logits = nn.Linear(embed_dim, num_classes)
        else:
            self.class_logits = nn.Linear(embed_dim, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        tokens = self.token_embedding(x)
        # batch_size, seq_len, embed_dim = x.size()
        x = self.pos_embedding(tokens)
        x = self.transformers(x)
        x = x.mean(dim=1)  # global average pooling, works in 1D
        x = self.dropout(x)
        # x = self.class_logits(x)
        # return F.log_softmax(x, dim=1)
        return self.class_logits(x)n=512):
        super().__init__()
        self.embed_dim = embed_dim
        
        # create constant 'pe' matrix with values dependant on pos and i
        pe = torch.zeros(max_seq_len, embed_dim)
        for pos in range(max_seq_len):
            for i in range(0, embed_dim, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/embed_dim)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/embed_dim)))
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # make embeddings relatively larger
        x = x * math.sqrt(self.embed_dim)
        #add constant to embedding
        seq_len = x.size(1)
        x = x + torch.autograd.Variable(self.pe[:,:seq_len], requires_grad=False)  # .cuda()
        return x


class Classifier(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 num_blocks: int,
                 num_classes: int,
                 vocab_size: int,
                 ffn_dim: int = 32,
                 n_qubits_transformer: int = 0,
                 n_qubits_ffn: int = 0,
                 n_qlayers: int = 1,
                 dropout=0.1,
                 q_device="device.qubit"):
        super(Classifier, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.num_classes = num_classes
        self.vocab_size = vocab_size

        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = PositionalEncoder(embed_dim)

        print(f"++ There will be {num_blocks} transformer blocks")

        if n_qubits_transformer > 0:
            print(f"++ Transformer will use {n_qubits_transformer} qubits and {n_qlayers} q layers")
            if n_qubits_ffn > 0:
                print(f"The feed-forward head will use {n_qubits_ffn} qubits")
            else:
                print(f"The feed-forward head will be classical")
            print(f"Using quantum device {q_device}")

            transformer_blocks = [
                TransformerBlockQuantum(embed_dim, num_heads, ffn_dim,
                                        n_qubits_transformer=n_qubits_transformer,
                                        n_qubits_ffn=n_qubits_ffn,
                                        n_qlayers=n_qlayers, 
                                        q_device=q_device) for _ in range(num_blocks)
            ]
        else:
            transformer_blocks = [
                TransformerBlockClassical(embed_dim, num_heads, ffn_dim) for _ in range(num_blocks)
            ]

        self.transformers = nn.Sequential(*transformer_blocks)
        if self.num_classes > 2:
            self.class_logits = nn.Linear(embed_dim, num_classes)
        else:
            self.class_logits = nn.Linear(embed_dim, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        tokens = self.token_embedding(x)
        # batch_size, seq_len, embed_dim = x.size()
        x = self.pos_embedding(tokens)
        x = self.transformers(x)
        x = x.mean(dim=1)  # global average pooling, works in 1D
        x = self.dropout(x)
        # x = self.class_logits(x)
        # return F.log_softmax(x, dim=1)
        return self.class_logits(x)


In [2]:
# Example usage
embed_size = 16
heads = 4
value_len = 4
key_len = 4
query_len = 4
N = 4

torch.manual_seed(42)
# Instantiate the QuantumSelfAttention module
qsa_Quantum = MultiHeadAttentionQuantum(embed_size, heads)
qsa_Classical = MultiHeadAttentionClassical(embed_size, heads)
# Generate random values, keys, queries, and mask
values = torch.randn(N, value_len, heads, qsa_Quantum.head_dim)
keys = torch.randn(N, key_len, heads, qsa_Quantum.head_dim)
queries = torch.randn(N, query_len, heads, qsa_Quantum.head_dim)
mask = torch.ones(N, query_len, key_len)
x_q = qsa_Quantum.downstream(queries, keys, values, N, mask)
out_q = qsa_Quantum(x_q, mask)

print("out_q:")
print(out_q)



weight_shapes = (n_qlayers, n_qubits) = (1, 16)
out_q:
tensor([[[ 9.1934e-04,  1.3679e-02,  1.3462e-02, -5.9553e-03, -5.1082e-03,
          -3.7311e-03,  2.2146e-03, -6.2246e-04,  3.3670e-04, -3.3611e-04,
          -2.6084e-04, -1.6018e-04,  1.5706e-04,  8.5215e-05, -4.8265e-05,
           2.2669e-05],
         [ 1.2854e-03,  1.6908e-02,  1.6113e-02, -8.3684e-03, -7.2838e-03,
          -5.1987e-03,  3.0728e-03, -8.7038e-04,  4.7081e-04, -4.6999e-04,
          -3.6473e-04, -2.2398e-04,  2.1961e-04,  1.1916e-04, -6.7489e-05,
           3.1698e-05],
         [ 1.1288e-03,  1.6270e-02,  1.5852e-02, -7.3465e-03, -6.3927e-03,
          -4.5648e-03,  2.6979e-03, -7.6428e-04,  4.1341e-04, -4.1269e-04,
          -3.2026e-04, -1.9667e-04,  1.9284e-04,  1.0463e-04, -5.9261e-05,
           2.7833e-05],
         [ 1.5351e-03,  1.9078e-02,  1.7828e-02, -9.9927e-03, -8.6974e-03,
          -6.2079e-03,  3.6693e-03, -1.0393e-03,  5.6222e-04, -5.6124e-04,
          -4.3554e-04, -2.6746e-04,  2.6225e-04,

In [3]:
# Generate random values, keys, queries, and mask
values_c = torch.randn(N, value_len, heads, qsa_Classical.head_dim)
keys_c = torch.randn(N, key_len, heads, qsa_Classical.head_dim)
queries_c = torch.randn(N, query_len, heads, qsa_Classical.head_dim)
#mask = torch.ones(N, query_len, key_len)
x_c = qsa_Classical.downstream(queries, keys, values, N, mask)
out_c = qsa_Classical(x_c, mask)

print("out_c:")
print(out_c)

out_c:
tensor([[[-0.1598,  0.0918,  0.1000,  0.0930, -0.1281,  0.0658, -0.1756,
           0.0853,  0.0256, -0.0069, -0.2796, -0.2439, -0.1438,  0.0064,
           0.0707,  0.2677],
         [-0.1972,  0.1191,  0.1033,  0.0780, -0.1503,  0.0826, -0.1941,
           0.1044,  0.0169, -0.0244, -0.2912, -0.2563, -0.1859, -0.0179,
           0.0327,  0.3085],
         [-0.2243,  0.1408,  0.1115,  0.0802, -0.1532,  0.0667, -0.1783,
           0.1599,  0.0127, -0.0547, -0.3139, -0.2527, -0.2357, -0.0086,
           0.0213,  0.2924],
         [-0.2390,  0.1785,  0.1053,  0.0630, -0.1724,  0.0450, -0.1905,
           0.1624,  0.0283, -0.0786, -0.3114, -0.2471, -0.2167,  0.0135,
           0.0033,  0.2659]],

        [[-0.3112, -0.0482,  0.0032, -0.0293,  0.1031,  0.0432, -0.0412,
           0.0608, -0.0408, -0.2106,  0.0133, -0.0696,  0.0805, -0.1356,
          -0.0376, -0.0929],
         [-0.4017,  0.0316, -0.0315, -0.0845,  0.1029,  0.0494,  0.0059,
           0.1573, -0.0324, -0.3441,  0.041

In [4]:
from qiskit import QuantumCircuit, transpile
from qiskit.visualization import circuit_drawer

# Create the quantum circuit
circuit = QuantumCircuit(4)

# Add gate operations to the circuit for keys
for i in range(4):
    circuit.rx(0.4, i)
    circuit.ry(0.5, i)
    circuit.rz(0.6, i)

# Measure the keys
circuit.measure_all()

# Draw the circuit diagram
circuit_drawer(circuit, output='text')

: 