# Satifsying requirements

In [None]:
!pip install pennylane --upgrade

In [None]:
!pip install torch

In [1]:
import torch
import pennylane as qml
import numpy as np

%matplotlib inline

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

Using cuda device


# Downloading data

In [3]:
from torchvision import datasets
from torchvision.transforms import ToTensor

train_data = datasets.MNIST(
    root = 'data',
    train = True,                         
    transform = ToTensor(), 
    download = True,            
)
test_data = datasets.MNIST(
    root = 'data', 
    train = False, 
    transform = ToTensor()
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data\MNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data\MNIST\raw\train-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data\MNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data\MNIST\raw\train-labels-idx1-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data\MNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data\MNIST\raw\t10k-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data\MNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data\MNIST\raw\t10k-labels-idx1-ubyte.gz to data\MNIST\raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


# Preparing data with DataLoaders

In [14]:
from torch.utils.data import DataLoader

loaders = {
    'train' : torch.utils.data.DataLoader(train_data, 
                                          batch_size=80, 
                                          shuffle=True, 
                                          num_workers=1,
                                          pin_memory=True),
    
    'test'  : torch.utils.data.DataLoader(test_data, 
                                          batch_size=100, 
                                          shuffle=True, 
                                          num_workers=1,
                                          pin_memory=True),
}

# Defining a NN

In [15]:
n_qubits = 4
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="torch")
def qnode(inputs, weights):
    qml.templates.AngleEmbedding(features=inputs, wires=range(n_qubits))
    
    qml.templates.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    
    return [qml.expval(qml.PauliY(wires=i)) for i in range(n_qubits)]

In [16]:
n_layers = 4
weight_shapes = {"weights": (n_layers, n_qubits, 3)}

In [17]:
import torch.nn as nn

class HybridNN(nn.Module):
    def __init__(self):
        super(HybridNN, self).__init__()
        self.conv1 = nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,              
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(         
            nn.Conv2d(
                in_channels=16,              
                out_channels=32,            
                kernel_size=5,              
                stride=1,                   
                padding=2,    
            ),     
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),                
        )
        self.fc_1 = nn.Linear(32 * 7 * 7, 16)
        
        # LIST USAGE?
        self.qlayer_1 = qml.qnn.TorchLayer(qnode, weight_shapes)
        self.qlayer_2 = qml.qnn.TorchLayer(qnode, weight_shapes)
        self.qlayer_3 = qml.qnn.TorchLayer(qnode, weight_shapes)
        self.qlayer_4 = qml.qnn.TorchLayer(qnode, weight_shapes)
        
        self.qlayer_1.to(device)
        self.qlayer_2.to(device)
        self.qlayer_3.to(device)
        self.qlayer_4.to(device)
        
        self.fc_2 = nn.Linear(16, 10)

        
    def forward(self, x):

        x = self.conv1(x)

        x = self.conv2(x)

        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1) 

        x = self.fc_1(x)
        #print('Before split')
        x_1, x_2, x_3, x_4 = torch.split(x, 4, dim=1) # second argument is number of elements in one new tensor
        #print('After split')
        #x = torch.Tensor(0)
        
        x_1 = self.qlayer_1(x_1)
        x_2 = self.qlayer_2(x_2)
        x_3 = self.qlayer_3(x_3)
        x_4 = self.qlayer_4(x_4)
        
        #print(x.device)
        
        x = torch.cat([x_1, x_2, x_3, x_4], axis=1)
        x = x.to(device)
        
        logits = self.fc_2(x)
        
        return logits

In [18]:
hnn = HybridNN()
hnn = hnn.to(device)
print(hnn)

HybridNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_1): Linear(in_features=1568, out_features=16, bias=True)
  (qlayer_1): <Quantum Torch Layer: func=qnode>
  (qlayer_2): <Quantum Torch Layer: func=qnode>
  (qlayer_3): <Quantum Torch Layer: func=qnode>
  (qlayer_4): <Quantum Torch Layer: func=qnode>
  (fc_2): Linear(in_features=16, out_features=10, bias=True)
)


# Training

In [19]:
loss_func = nn.CrossEntropyLoss()

In [20]:
from torch import optim

optimizer = optim.Adam(hnn.parameters(), lr = 0.01)  

In [21]:
from tqdm.notebook import trange
from torch.autograd import Variable

def train(num_epochs, model, loaders):
    
    model.train()
        
    # Train the model
    total_step = len(loaders['train'])
        
    for epoch in trange(num_epochs):
        for i, (images, labels) in enumerate(loaders['train']):
            
            b_x, b_y = images.to(device), labels.to(device)

            output = model(b_x)             
            loss = loss_func(output, b_y)
            
            optimizer.zero_grad()           
            
            loss.backward()               
            optimizer.step()                
            
            if (i+1) % 10 >= 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                       .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
        print('\n')

In [22]:
num_epochs = 10

train(num_epochs, hnn, loaders)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch [1/10], Step [1/750], Loss: 2.3065
Epoch [1/10], Step [2/750], Loss: 2.3337
Epoch [1/10], Step [3/750], Loss: 2.3223
Epoch [1/10], Step [4/750], Loss: 2.3139
Epoch [1/10], Step [5/750], Loss: 2.2920
Epoch [1/10], Step [6/750], Loss: 2.3268
Epoch [1/10], Step [7/750], Loss: 2.3550
Epoch [1/10], Step [8/750], Loss: 2.3303
Epoch [1/10], Step [9/750], Loss: 2.3282
Epoch [1/10], Step [10/750], Loss: 2.3036
Epoch [1/10], Step [11/750], Loss: 2.3068
Epoch [1/10], Step [12/750], Loss: 2.3012
Epoch [1/10], Step [13/750], Loss: 2.2896
Epoch [1/10], Step [14/750], Loss: 2.2757
Epoch [1/10], Step [15/750], Loss: 2.2919
Epoch [1/10], Step [16/750], Loss: 2.2775
Epoch [1/10], Step [17/750], Loss: 2.2916
Epoch [1/10], Step [18/750], Loss: 2.3085
Epoch [1/10], Step [19/750], Loss: 2.3402
Epoch [1/10], Step [20/750], Loss: 2.2993
Epoch [1/10], Step [21/750], Loss: 2.3222
Epoch [1/10], Step [22/750], Loss: 2.3096
Epoch [1/10], Step [23/750], Loss: 2.3076
Epoch [1/10], Step [24/750], Loss: 2.2927
E

Epoch [1/10], Step [195/750], Loss: 2.2771
Epoch [1/10], Step [196/750], Loss: 2.2900
Epoch [1/10], Step [197/750], Loss: 2.2959
Epoch [1/10], Step [198/750], Loss: 2.2908
Epoch [1/10], Step [199/750], Loss: 2.3200
Epoch [1/10], Step [200/750], Loss: 2.2952
Epoch [1/10], Step [201/750], Loss: 2.3165
Epoch [1/10], Step [202/750], Loss: 2.3062
Epoch [1/10], Step [203/750], Loss: 2.2935
Epoch [1/10], Step [204/750], Loss: 2.3076
Epoch [1/10], Step [205/750], Loss: 2.3355
Epoch [1/10], Step [206/750], Loss: 2.2897
Epoch [1/10], Step [207/750], Loss: 2.3061
Epoch [1/10], Step [208/750], Loss: 2.3267
Epoch [1/10], Step [209/750], Loss: 2.3014
Epoch [1/10], Step [210/750], Loss: 2.2986
Epoch [1/10], Step [211/750], Loss: 2.2916
Epoch [1/10], Step [212/750], Loss: 2.3219
Epoch [1/10], Step [213/750], Loss: 2.2953
Epoch [1/10], Step [214/750], Loss: 2.2941
Epoch [1/10], Step [215/750], Loss: 2.2706
Epoch [1/10], Step [216/750], Loss: 2.2949
Epoch [1/10], Step [217/750], Loss: 2.2886
Epoch [1/10

Epoch [1/10], Step [386/750], Loss: 2.2936
Epoch [1/10], Step [387/750], Loss: 2.3101
Epoch [1/10], Step [388/750], Loss: 2.3073
Epoch [1/10], Step [389/750], Loss: 2.3123
Epoch [1/10], Step [390/750], Loss: 2.3046
Epoch [1/10], Step [391/750], Loss: 2.3341
Epoch [1/10], Step [392/750], Loss: 2.2924
Epoch [1/10], Step [393/750], Loss: 2.3058
Epoch [1/10], Step [394/750], Loss: 2.3113
Epoch [1/10], Step [395/750], Loss: 2.3195
Epoch [1/10], Step [396/750], Loss: 2.2928
Epoch [1/10], Step [397/750], Loss: 2.3184
Epoch [1/10], Step [398/750], Loss: 2.3230
Epoch [1/10], Step [399/750], Loss: 2.3165
Epoch [1/10], Step [400/750], Loss: 2.2860
Epoch [1/10], Step [401/750], Loss: 2.2897
Epoch [1/10], Step [402/750], Loss: 2.3036
Epoch [1/10], Step [403/750], Loss: 2.2988
Epoch [1/10], Step [404/750], Loss: 2.3194
Epoch [1/10], Step [405/750], Loss: 2.2832
Epoch [1/10], Step [406/750], Loss: 2.2914
Epoch [1/10], Step [407/750], Loss: 2.3129
Epoch [1/10], Step [408/750], Loss: 2.3012
Epoch [1/10

Epoch [1/10], Step [577/750], Loss: 1.3545
Epoch [1/10], Step [578/750], Loss: 1.4237
Epoch [1/10], Step [579/750], Loss: 1.5933
Epoch [1/10], Step [580/750], Loss: 1.5615
Epoch [1/10], Step [581/750], Loss: 1.4054
Epoch [1/10], Step [582/750], Loss: 1.3799
Epoch [1/10], Step [583/750], Loss: 1.4223
Epoch [1/10], Step [584/750], Loss: 1.5746
Epoch [1/10], Step [585/750], Loss: 1.2518
Epoch [1/10], Step [586/750], Loss: 1.4110
Epoch [1/10], Step [587/750], Loss: 1.3898
Epoch [1/10], Step [588/750], Loss: 1.2303
Epoch [1/10], Step [589/750], Loss: 1.2305
Epoch [1/10], Step [590/750], Loss: 1.2329
Epoch [1/10], Step [591/750], Loss: 1.1342
Epoch [1/10], Step [592/750], Loss: 1.0177
Epoch [1/10], Step [593/750], Loss: 1.0425
Epoch [1/10], Step [594/750], Loss: 1.1719
Epoch [1/10], Step [595/750], Loss: 0.9923
Epoch [1/10], Step [596/750], Loss: 1.0732
Epoch [1/10], Step [597/750], Loss: 1.1644
Epoch [1/10], Step [598/750], Loss: 1.0897
Epoch [1/10], Step [599/750], Loss: 1.1435
Epoch [1/10

Epoch [2/10], Step [19/750], Loss: 0.1904
Epoch [2/10], Step [20/750], Loss: 0.2470
Epoch [2/10], Step [21/750], Loss: 0.1969
Epoch [2/10], Step [22/750], Loss: 0.2551
Epoch [2/10], Step [23/750], Loss: 0.3907
Epoch [2/10], Step [24/750], Loss: 0.2097
Epoch [2/10], Step [25/750], Loss: 0.2837
Epoch [2/10], Step [26/750], Loss: 0.1640
Epoch [2/10], Step [27/750], Loss: 0.1918
Epoch [2/10], Step [28/750], Loss: 0.2699
Epoch [2/10], Step [29/750], Loss: 0.2699
Epoch [2/10], Step [30/750], Loss: 0.2024
Epoch [2/10], Step [31/750], Loss: 0.2806
Epoch [2/10], Step [32/750], Loss: 0.1683
Epoch [2/10], Step [33/750], Loss: 0.2571
Epoch [2/10], Step [34/750], Loss: 0.3458
Epoch [2/10], Step [35/750], Loss: 0.2169
Epoch [2/10], Step [36/750], Loss: 0.3032
Epoch [2/10], Step [37/750], Loss: 0.1485
Epoch [2/10], Step [38/750], Loss: 0.3964
Epoch [2/10], Step [39/750], Loss: 0.2321
Epoch [2/10], Step [40/750], Loss: 0.2198
Epoch [2/10], Step [41/750], Loss: 0.2295
Epoch [2/10], Step [42/750], Loss:

Epoch [2/10], Step [212/750], Loss: 0.1046
Epoch [2/10], Step [213/750], Loss: 0.0861
Epoch [2/10], Step [214/750], Loss: 0.1569
Epoch [2/10], Step [215/750], Loss: 0.2906
Epoch [2/10], Step [216/750], Loss: 0.1742
Epoch [2/10], Step [217/750], Loss: 0.0585
Epoch [2/10], Step [218/750], Loss: 0.0781
Epoch [2/10], Step [219/750], Loss: 0.1096
Epoch [2/10], Step [220/750], Loss: 0.2526
Epoch [2/10], Step [221/750], Loss: 0.0723
Epoch [2/10], Step [222/750], Loss: 0.1298
Epoch [2/10], Step [223/750], Loss: 0.1234
Epoch [2/10], Step [224/750], Loss: 0.2062
Epoch [2/10], Step [225/750], Loss: 0.1539
Epoch [2/10], Step [226/750], Loss: 0.1556
Epoch [2/10], Step [227/750], Loss: 0.2143
Epoch [2/10], Step [228/750], Loss: 0.1553
Epoch [2/10], Step [229/750], Loss: 0.1941
Epoch [2/10], Step [230/750], Loss: 0.0609
Epoch [2/10], Step [231/750], Loss: 0.2267
Epoch [2/10], Step [232/750], Loss: 0.1557
Epoch [2/10], Step [233/750], Loss: 0.3113
Epoch [2/10], Step [234/750], Loss: 0.1549
Epoch [2/10

Epoch [2/10], Step [403/750], Loss: 0.1738
Epoch [2/10], Step [404/750], Loss: 0.1876
Epoch [2/10], Step [405/750], Loss: 0.1947
Epoch [2/10], Step [406/750], Loss: 0.0978
Epoch [2/10], Step [407/750], Loss: 0.0952
Epoch [2/10], Step [408/750], Loss: 0.1366
Epoch [2/10], Step [409/750], Loss: 0.1206
Epoch [2/10], Step [410/750], Loss: 0.1181
Epoch [2/10], Step [411/750], Loss: 0.1857
Epoch [2/10], Step [412/750], Loss: 0.2928
Epoch [2/10], Step [413/750], Loss: 0.0680
Epoch [2/10], Step [414/750], Loss: 0.1216
Epoch [2/10], Step [415/750], Loss: 0.0800
Epoch [2/10], Step [416/750], Loss: 0.0754
Epoch [2/10], Step [417/750], Loss: 0.1759
Epoch [2/10], Step [418/750], Loss: 0.1098
Epoch [2/10], Step [419/750], Loss: 0.0670
Epoch [2/10], Step [420/750], Loss: 0.1037
Epoch [2/10], Step [421/750], Loss: 0.1536
Epoch [2/10], Step [422/750], Loss: 0.1013
Epoch [2/10], Step [423/750], Loss: 0.1039
Epoch [2/10], Step [424/750], Loss: 0.1073
Epoch [2/10], Step [425/750], Loss: 0.2407
Epoch [2/10

Epoch [2/10], Step [594/750], Loss: 0.0878
Epoch [2/10], Step [595/750], Loss: 0.1967
Epoch [2/10], Step [596/750], Loss: 0.1661
Epoch [2/10], Step [597/750], Loss: 0.1337
Epoch [2/10], Step [598/750], Loss: 0.1373
Epoch [2/10], Step [599/750], Loss: 0.0702
Epoch [2/10], Step [600/750], Loss: 0.0693
Epoch [2/10], Step [601/750], Loss: 0.0449
Epoch [2/10], Step [602/750], Loss: 0.0934
Epoch [2/10], Step [603/750], Loss: 0.0918
Epoch [2/10], Step [604/750], Loss: 0.2555
Epoch [2/10], Step [605/750], Loss: 0.0386
Epoch [2/10], Step [606/750], Loss: 0.1522
Epoch [2/10], Step [607/750], Loss: 0.2560
Epoch [2/10], Step [608/750], Loss: 0.0719
Epoch [2/10], Step [609/750], Loss: 0.1132
Epoch [2/10], Step [610/750], Loss: 0.2592
Epoch [2/10], Step [611/750], Loss: 0.1382
Epoch [2/10], Step [612/750], Loss: 0.2045
Epoch [2/10], Step [613/750], Loss: 0.1390
Epoch [2/10], Step [614/750], Loss: 0.1136
Epoch [2/10], Step [615/750], Loss: 0.1981
Epoch [2/10], Step [616/750], Loss: 0.1010
Epoch [2/10

Epoch [3/10], Step [36/750], Loss: 0.0916
Epoch [3/10], Step [37/750], Loss: 0.1822
Epoch [3/10], Step [38/750], Loss: 0.0600
Epoch [3/10], Step [39/750], Loss: 0.1761
Epoch [3/10], Step [40/750], Loss: 0.0463
Epoch [3/10], Step [41/750], Loss: 0.0454
Epoch [3/10], Step [42/750], Loss: 0.0927
Epoch [3/10], Step [43/750], Loss: 0.1636
Epoch [3/10], Step [44/750], Loss: 0.1590
Epoch [3/10], Step [45/750], Loss: 0.1446
Epoch [3/10], Step [46/750], Loss: 0.0429
Epoch [3/10], Step [47/750], Loss: 0.0564
Epoch [3/10], Step [48/750], Loss: 0.0339
Epoch [3/10], Step [49/750], Loss: 0.0957
Epoch [3/10], Step [50/750], Loss: 0.0980
Epoch [3/10], Step [51/750], Loss: 0.0575
Epoch [3/10], Step [52/750], Loss: 0.0913
Epoch [3/10], Step [53/750], Loss: 0.1165
Epoch [3/10], Step [54/750], Loss: 0.1674
Epoch [3/10], Step [55/750], Loss: 0.0872
Epoch [3/10], Step [56/750], Loss: 0.1405
Epoch [3/10], Step [57/750], Loss: 0.0743
Epoch [3/10], Step [58/750], Loss: 0.1028
Epoch [3/10], Step [59/750], Loss:

Epoch [3/10], Step [229/750], Loss: 0.0543
Epoch [3/10], Step [230/750], Loss: 0.0518
Epoch [3/10], Step [231/750], Loss: 0.0642
Epoch [3/10], Step [232/750], Loss: 0.1012
Epoch [3/10], Step [233/750], Loss: 0.0957
Epoch [3/10], Step [234/750], Loss: 0.0287
Epoch [3/10], Step [235/750], Loss: 0.1178
Epoch [3/10], Step [236/750], Loss: 0.1777
Epoch [3/10], Step [237/750], Loss: 0.1124
Epoch [3/10], Step [238/750], Loss: 0.0345
Epoch [3/10], Step [239/750], Loss: 0.0757
Epoch [3/10], Step [240/750], Loss: 0.0645
Epoch [3/10], Step [241/750], Loss: 0.1206
Epoch [3/10], Step [242/750], Loss: 0.1167
Epoch [3/10], Step [243/750], Loss: 0.0121
Epoch [3/10], Step [244/750], Loss: 0.0961
Epoch [3/10], Step [245/750], Loss: 0.0620
Epoch [3/10], Step [246/750], Loss: 0.1768
Epoch [3/10], Step [247/750], Loss: 0.0209
Epoch [3/10], Step [248/750], Loss: 0.0176
Epoch [3/10], Step [249/750], Loss: 0.0699
Epoch [3/10], Step [250/750], Loss: 0.1638
Epoch [3/10], Step [251/750], Loss: 0.2227
Epoch [3/10

Epoch [3/10], Step [420/750], Loss: 0.0995
Epoch [3/10], Step [421/750], Loss: 0.0847
Epoch [3/10], Step [422/750], Loss: 0.1331
Epoch [3/10], Step [423/750], Loss: 0.0272
Epoch [3/10], Step [424/750], Loss: 0.0419
Epoch [3/10], Step [425/750], Loss: 0.1092
Epoch [3/10], Step [426/750], Loss: 0.1638
Epoch [3/10], Step [427/750], Loss: 0.0463
Epoch [3/10], Step [428/750], Loss: 0.0697
Epoch [3/10], Step [429/750], Loss: 0.0345
Epoch [3/10], Step [430/750], Loss: 0.1507
Epoch [3/10], Step [431/750], Loss: 0.0331
Epoch [3/10], Step [432/750], Loss: 0.0638
Epoch [3/10], Step [433/750], Loss: 0.1492
Epoch [3/10], Step [434/750], Loss: 0.0942
Epoch [3/10], Step [435/750], Loss: 0.2812
Epoch [3/10], Step [436/750], Loss: 0.1244
Epoch [3/10], Step [437/750], Loss: 0.0293
Epoch [3/10], Step [438/750], Loss: 0.1162
Epoch [3/10], Step [439/750], Loss: 0.1088
Epoch [3/10], Step [440/750], Loss: 0.0470
Epoch [3/10], Step [441/750], Loss: 0.1156
Epoch [3/10], Step [442/750], Loss: 0.1906
Epoch [3/10

Epoch [3/10], Step [611/750], Loss: 0.0434
Epoch [3/10], Step [612/750], Loss: 0.0761
Epoch [3/10], Step [613/750], Loss: 0.0815
Epoch [3/10], Step [614/750], Loss: 0.0785
Epoch [3/10], Step [615/750], Loss: 0.0705
Epoch [3/10], Step [616/750], Loss: 0.0871
Epoch [3/10], Step [617/750], Loss: 0.0825
Epoch [3/10], Step [618/750], Loss: 0.0302
Epoch [3/10], Step [619/750], Loss: 0.0475
Epoch [3/10], Step [620/750], Loss: 0.1122
Epoch [3/10], Step [621/750], Loss: 0.1508
Epoch [3/10], Step [622/750], Loss: 0.0677
Epoch [3/10], Step [623/750], Loss: 0.0916
Epoch [3/10], Step [624/750], Loss: 0.1241
Epoch [3/10], Step [625/750], Loss: 0.1685
Epoch [3/10], Step [626/750], Loss: 0.0751
Epoch [3/10], Step [627/750], Loss: 0.1364
Epoch [3/10], Step [628/750], Loss: 0.0176
Epoch [3/10], Step [629/750], Loss: 0.0841
Epoch [3/10], Step [630/750], Loss: 0.0260
Epoch [3/10], Step [631/750], Loss: 0.1489
Epoch [3/10], Step [632/750], Loss: 0.1213
Epoch [3/10], Step [633/750], Loss: 0.1065
Epoch [3/10

Epoch [4/10], Step [53/750], Loss: 0.0903
Epoch [4/10], Step [54/750], Loss: 0.0211
Epoch [4/10], Step [55/750], Loss: 0.1639
Epoch [4/10], Step [56/750], Loss: 0.0849
Epoch [4/10], Step [57/750], Loss: 0.0280
Epoch [4/10], Step [58/750], Loss: 0.1908
Epoch [4/10], Step [59/750], Loss: 0.0145
Epoch [4/10], Step [60/750], Loss: 0.2173
Epoch [4/10], Step [61/750], Loss: 0.1912
Epoch [4/10], Step [62/750], Loss: 0.0788
Epoch [4/10], Step [63/750], Loss: 0.0537
Epoch [4/10], Step [64/750], Loss: 0.1137
Epoch [4/10], Step [65/750], Loss: 0.1977
Epoch [4/10], Step [66/750], Loss: 0.1166
Epoch [4/10], Step [67/750], Loss: 0.0236
Epoch [4/10], Step [68/750], Loss: 0.0380
Epoch [4/10], Step [69/750], Loss: 0.0319
Epoch [4/10], Step [70/750], Loss: 0.1745
Epoch [4/10], Step [71/750], Loss: 0.0392
Epoch [4/10], Step [72/750], Loss: 0.0742
Epoch [4/10], Step [73/750], Loss: 0.1300
Epoch [4/10], Step [74/750], Loss: 0.0128
Epoch [4/10], Step [75/750], Loss: 0.1098
Epoch [4/10], Step [76/750], Loss:

Epoch [4/10], Step [245/750], Loss: 0.0645
Epoch [4/10], Step [246/750], Loss: 0.0185
Epoch [4/10], Step [247/750], Loss: 0.0906
Epoch [4/10], Step [248/750], Loss: 0.0837
Epoch [4/10], Step [249/750], Loss: 0.1016
Epoch [4/10], Step [250/750], Loss: 0.0210
Epoch [4/10], Step [251/750], Loss: 0.0940
Epoch [4/10], Step [252/750], Loss: 0.0494
Epoch [4/10], Step [253/750], Loss: 0.0485
Epoch [4/10], Step [254/750], Loss: 0.0874
Epoch [4/10], Step [255/750], Loss: 0.0405
Epoch [4/10], Step [256/750], Loss: 0.0681
Epoch [4/10], Step [257/750], Loss: 0.0452
Epoch [4/10], Step [258/750], Loss: 0.0824
Epoch [4/10], Step [259/750], Loss: 0.0924
Epoch [4/10], Step [260/750], Loss: 0.1595
Epoch [4/10], Step [261/750], Loss: 0.0369
Epoch [4/10], Step [262/750], Loss: 0.1046
Epoch [4/10], Step [263/750], Loss: 0.0416
Epoch [4/10], Step [264/750], Loss: 0.0231
Epoch [4/10], Step [265/750], Loss: 0.0786
Epoch [4/10], Step [266/750], Loss: 0.1942
Epoch [4/10], Step [267/750], Loss: 0.1319
Epoch [4/10

Epoch [4/10], Step [436/750], Loss: 0.0094
Epoch [4/10], Step [437/750], Loss: 0.0368
Epoch [4/10], Step [438/750], Loss: 0.0885
Epoch [4/10], Step [439/750], Loss: 0.1445
Epoch [4/10], Step [440/750], Loss: 0.1041
Epoch [4/10], Step [441/750], Loss: 0.0625
Epoch [4/10], Step [442/750], Loss: 0.0673
Epoch [4/10], Step [443/750], Loss: 0.1150
Epoch [4/10], Step [444/750], Loss: 0.0756
Epoch [4/10], Step [445/750], Loss: 0.0614
Epoch [4/10], Step [446/750], Loss: 0.0295
Epoch [4/10], Step [447/750], Loss: 0.0358
Epoch [4/10], Step [448/750], Loss: 0.0999
Epoch [4/10], Step [449/750], Loss: 0.0812
Epoch [4/10], Step [450/750], Loss: 0.0186
Epoch [4/10], Step [451/750], Loss: 0.0078
Epoch [4/10], Step [452/750], Loss: 0.0295
Epoch [4/10], Step [453/750], Loss: 0.0723
Epoch [4/10], Step [454/750], Loss: 0.1051
Epoch [4/10], Step [455/750], Loss: 0.0823
Epoch [4/10], Step [456/750], Loss: 0.0067
Epoch [4/10], Step [457/750], Loss: 0.0601
Epoch [4/10], Step [458/750], Loss: 0.0634
Epoch [4/10

Epoch [4/10], Step [627/750], Loss: 0.1542
Epoch [4/10], Step [628/750], Loss: 0.1836
Epoch [4/10], Step [629/750], Loss: 0.0094
Epoch [4/10], Step [630/750], Loss: 0.1437
Epoch [4/10], Step [631/750], Loss: 0.0997
Epoch [4/10], Step [632/750], Loss: 0.0653
Epoch [4/10], Step [633/750], Loss: 0.1269
Epoch [4/10], Step [634/750], Loss: 0.0329
Epoch [4/10], Step [635/750], Loss: 0.0956
Epoch [4/10], Step [636/750], Loss: 0.0384
Epoch [4/10], Step [637/750], Loss: 0.1710
Epoch [4/10], Step [638/750], Loss: 0.0593
Epoch [4/10], Step [639/750], Loss: 0.1965
Epoch [4/10], Step [640/750], Loss: 0.0609
Epoch [4/10], Step [641/750], Loss: 0.1351
Epoch [4/10], Step [642/750], Loss: 0.0487
Epoch [4/10], Step [643/750], Loss: 0.1105
Epoch [4/10], Step [644/750], Loss: 0.1322
Epoch [4/10], Step [645/750], Loss: 0.1033
Epoch [4/10], Step [646/750], Loss: 0.0289
Epoch [4/10], Step [647/750], Loss: 0.0872
Epoch [4/10], Step [648/750], Loss: 0.0526
Epoch [4/10], Step [649/750], Loss: 0.1445
Epoch [4/10

Epoch [5/10], Step [70/750], Loss: 0.0830
Epoch [5/10], Step [71/750], Loss: 0.0405
Epoch [5/10], Step [72/750], Loss: 0.0302
Epoch [5/10], Step [73/750], Loss: 0.0274
Epoch [5/10], Step [74/750], Loss: 0.0599
Epoch [5/10], Step [75/750], Loss: 0.0053
Epoch [5/10], Step [76/750], Loss: 0.0743
Epoch [5/10], Step [77/750], Loss: 0.0639
Epoch [5/10], Step [78/750], Loss: 0.0627
Epoch [5/10], Step [79/750], Loss: 0.0430
Epoch [5/10], Step [80/750], Loss: 0.0613
Epoch [5/10], Step [81/750], Loss: 0.0754
Epoch [5/10], Step [82/750], Loss: 0.0142
Epoch [5/10], Step [83/750], Loss: 0.0212
Epoch [5/10], Step [84/750], Loss: 0.0220
Epoch [5/10], Step [85/750], Loss: 0.0998
Epoch [5/10], Step [86/750], Loss: 0.1002
Epoch [5/10], Step [87/750], Loss: 0.0348
Epoch [5/10], Step [88/750], Loss: 0.1492
Epoch [5/10], Step [89/750], Loss: 0.1094
Epoch [5/10], Step [90/750], Loss: 0.1156
Epoch [5/10], Step [91/750], Loss: 0.0538
Epoch [5/10], Step [92/750], Loss: 0.1503
Epoch [5/10], Step [93/750], Loss:

Epoch [5/10], Step [262/750], Loss: 0.0479
Epoch [5/10], Step [263/750], Loss: 0.0681
Epoch [5/10], Step [264/750], Loss: 0.0591
Epoch [5/10], Step [265/750], Loss: 0.0165
Epoch [5/10], Step [266/750], Loss: 0.0250
Epoch [5/10], Step [267/750], Loss: 0.1383
Epoch [5/10], Step [268/750], Loss: 0.0797
Epoch [5/10], Step [269/750], Loss: 0.0459
Epoch [5/10], Step [270/750], Loss: 0.1310
Epoch [5/10], Step [271/750], Loss: 0.0753
Epoch [5/10], Step [272/750], Loss: 0.0196
Epoch [5/10], Step [273/750], Loss: 0.0279
Epoch [5/10], Step [274/750], Loss: 0.0708
Epoch [5/10], Step [275/750], Loss: 0.0093
Epoch [5/10], Step [276/750], Loss: 0.0248
Epoch [5/10], Step [277/750], Loss: 0.0520
Epoch [5/10], Step [278/750], Loss: 0.0194
Epoch [5/10], Step [279/750], Loss: 0.0540
Epoch [5/10], Step [280/750], Loss: 0.1428
Epoch [5/10], Step [281/750], Loss: 0.1105
Epoch [5/10], Step [282/750], Loss: 0.0585
Epoch [5/10], Step [283/750], Loss: 0.0346
Epoch [5/10], Step [284/750], Loss: 0.0565
Epoch [5/10

Epoch [5/10], Step [453/750], Loss: 0.0349
Epoch [5/10], Step [454/750], Loss: 0.0263
Epoch [5/10], Step [455/750], Loss: 0.0116
Epoch [5/10], Step [456/750], Loss: 0.0183
Epoch [5/10], Step [457/750], Loss: 0.0746
Epoch [5/10], Step [458/750], Loss: 0.1054
Epoch [5/10], Step [459/750], Loss: 0.0074
Epoch [5/10], Step [460/750], Loss: 0.0642
Epoch [5/10], Step [461/750], Loss: 0.0640
Epoch [5/10], Step [462/750], Loss: 0.0666
Epoch [5/10], Step [463/750], Loss: 0.1834
Epoch [5/10], Step [464/750], Loss: 0.1375
Epoch [5/10], Step [465/750], Loss: 0.0319
Epoch [5/10], Step [466/750], Loss: 0.0418
Epoch [5/10], Step [467/750], Loss: 0.0193
Epoch [5/10], Step [468/750], Loss: 0.0537
Epoch [5/10], Step [469/750], Loss: 0.0677
Epoch [5/10], Step [470/750], Loss: 0.1509
Epoch [5/10], Step [471/750], Loss: 0.0473
Epoch [5/10], Step [472/750], Loss: 0.0403
Epoch [5/10], Step [473/750], Loss: 0.0932
Epoch [5/10], Step [474/750], Loss: 0.0686
Epoch [5/10], Step [475/750], Loss: 0.1142
Epoch [5/10

Epoch [5/10], Step [644/750], Loss: 0.0475
Epoch [5/10], Step [645/750], Loss: 0.1652
Epoch [5/10], Step [646/750], Loss: 0.0809
Epoch [5/10], Step [647/750], Loss: 0.0372
Epoch [5/10], Step [648/750], Loss: 0.0250
Epoch [5/10], Step [649/750], Loss: 0.1380
Epoch [5/10], Step [650/750], Loss: 0.1147
Epoch [5/10], Step [651/750], Loss: 0.1279
Epoch [5/10], Step [652/750], Loss: 0.0314
Epoch [5/10], Step [653/750], Loss: 0.0845
Epoch [5/10], Step [654/750], Loss: 0.0137
Epoch [5/10], Step [655/750], Loss: 0.0513
Epoch [5/10], Step [656/750], Loss: 0.1325
Epoch [5/10], Step [657/750], Loss: 0.0711
Epoch [5/10], Step [658/750], Loss: 0.0213
Epoch [5/10], Step [659/750], Loss: 0.0441
Epoch [5/10], Step [660/750], Loss: 0.1109
Epoch [5/10], Step [661/750], Loss: 0.0310
Epoch [5/10], Step [662/750], Loss: 0.2770
Epoch [5/10], Step [663/750], Loss: 0.1026
Epoch [5/10], Step [664/750], Loss: 0.0368
Epoch [5/10], Step [665/750], Loss: 0.0974
Epoch [5/10], Step [666/750], Loss: 0.1023
Epoch [5/10

Epoch [6/10], Step [87/750], Loss: 0.0984
Epoch [6/10], Step [88/750], Loss: 0.0745
Epoch [6/10], Step [89/750], Loss: 0.0444
Epoch [6/10], Step [90/750], Loss: 0.1814
Epoch [6/10], Step [91/750], Loss: 0.1217
Epoch [6/10], Step [92/750], Loss: 0.1048
Epoch [6/10], Step [93/750], Loss: 0.1389
Epoch [6/10], Step [94/750], Loss: 0.0933
Epoch [6/10], Step [95/750], Loss: 0.0342
Epoch [6/10], Step [96/750], Loss: 0.2622
Epoch [6/10], Step [97/750], Loss: 0.0985
Epoch [6/10], Step [98/750], Loss: 0.0654
Epoch [6/10], Step [99/750], Loss: 0.0350
Epoch [6/10], Step [100/750], Loss: 0.0259
Epoch [6/10], Step [101/750], Loss: 0.0519
Epoch [6/10], Step [102/750], Loss: 0.0526
Epoch [6/10], Step [103/750], Loss: 0.0229
Epoch [6/10], Step [104/750], Loss: 0.0875
Epoch [6/10], Step [105/750], Loss: 0.0414
Epoch [6/10], Step [106/750], Loss: 0.0355
Epoch [6/10], Step [107/750], Loss: 0.0186
Epoch [6/10], Step [108/750], Loss: 0.0470
Epoch [6/10], Step [109/750], Loss: 0.0881
Epoch [6/10], Step [110/

Epoch [6/10], Step [278/750], Loss: 0.0258
Epoch [6/10], Step [279/750], Loss: 0.1326
Epoch [6/10], Step [280/750], Loss: 0.0384
Epoch [6/10], Step [281/750], Loss: 0.0598
Epoch [6/10], Step [282/750], Loss: 0.0940
Epoch [6/10], Step [283/750], Loss: 0.0279
Epoch [6/10], Step [284/750], Loss: 0.0326
Epoch [6/10], Step [285/750], Loss: 0.1020
Epoch [6/10], Step [286/750], Loss: 0.1097
Epoch [6/10], Step [287/750], Loss: 0.0740
Epoch [6/10], Step [288/750], Loss: 0.0155
Epoch [6/10], Step [289/750], Loss: 0.0188
Epoch [6/10], Step [290/750], Loss: 0.0922
Epoch [6/10], Step [291/750], Loss: 0.1727
Epoch [6/10], Step [292/750], Loss: 0.0165
Epoch [6/10], Step [293/750], Loss: 0.0145
Epoch [6/10], Step [294/750], Loss: 0.1172
Epoch [6/10], Step [295/750], Loss: 0.0383
Epoch [6/10], Step [296/750], Loss: 0.2102
Epoch [6/10], Step [297/750], Loss: 0.0867
Epoch [6/10], Step [298/750], Loss: 0.0522
Epoch [6/10], Step [299/750], Loss: 0.0163
Epoch [6/10], Step [300/750], Loss: 0.0658
Epoch [6/10

Epoch [6/10], Step [469/750], Loss: 0.0987
Epoch [6/10], Step [470/750], Loss: 0.1138
Epoch [6/10], Step [471/750], Loss: 0.1162
Epoch [6/10], Step [472/750], Loss: 0.0268
Epoch [6/10], Step [473/750], Loss: 0.0473
Epoch [6/10], Step [474/750], Loss: 0.1611
Epoch [6/10], Step [475/750], Loss: 0.1006
Epoch [6/10], Step [476/750], Loss: 0.0381
Epoch [6/10], Step [477/750], Loss: 0.0937
Epoch [6/10], Step [478/750], Loss: 0.0296
Epoch [6/10], Step [479/750], Loss: 0.1200
Epoch [6/10], Step [480/750], Loss: 0.1079
Epoch [6/10], Step [481/750], Loss: 0.0620
Epoch [6/10], Step [482/750], Loss: 0.0444
Epoch [6/10], Step [483/750], Loss: 0.0633
Epoch [6/10], Step [484/750], Loss: 0.0910
Epoch [6/10], Step [485/750], Loss: 0.1598
Epoch [6/10], Step [486/750], Loss: 0.0493
Epoch [6/10], Step [487/750], Loss: 0.1296
Epoch [6/10], Step [488/750], Loss: 0.1472
Epoch [6/10], Step [489/750], Loss: 0.0524
Epoch [6/10], Step [490/750], Loss: 0.0107
Epoch [6/10], Step [491/750], Loss: 0.0806
Epoch [6/10

Epoch [6/10], Step [660/750], Loss: 0.0292
Epoch [6/10], Step [661/750], Loss: 0.1198
Epoch [6/10], Step [662/750], Loss: 0.1501
Epoch [6/10], Step [663/750], Loss: 0.0983
Epoch [6/10], Step [664/750], Loss: 0.2468
Epoch [6/10], Step [665/750], Loss: 0.0206
Epoch [6/10], Step [666/750], Loss: 0.0481
Epoch [6/10], Step [667/750], Loss: 0.0643
Epoch [6/10], Step [668/750], Loss: 0.0490
Epoch [6/10], Step [669/750], Loss: 0.0696
Epoch [6/10], Step [670/750], Loss: 0.0439
Epoch [6/10], Step [671/750], Loss: 0.1298
Epoch [6/10], Step [672/750], Loss: 0.1504
Epoch [6/10], Step [673/750], Loss: 0.1183
Epoch [6/10], Step [674/750], Loss: 0.1097
Epoch [6/10], Step [675/750], Loss: 0.0061
Epoch [6/10], Step [676/750], Loss: 0.1431
Epoch [6/10], Step [677/750], Loss: 0.0104
Epoch [6/10], Step [678/750], Loss: 0.0049
Epoch [6/10], Step [679/750], Loss: 0.0939
Epoch [6/10], Step [680/750], Loss: 0.1143
Epoch [6/10], Step [681/750], Loss: 0.0873
Epoch [6/10], Step [682/750], Loss: 0.1333
Epoch [6/10

Epoch [7/10], Step [103/750], Loss: 0.0386
Epoch [7/10], Step [104/750], Loss: 0.0441
Epoch [7/10], Step [105/750], Loss: 0.0157
Epoch [7/10], Step [106/750], Loss: 0.1158
Epoch [7/10], Step [107/750], Loss: 0.0078
Epoch [7/10], Step [108/750], Loss: 0.0191
Epoch [7/10], Step [109/750], Loss: 0.0107
Epoch [7/10], Step [110/750], Loss: 0.0690
Epoch [7/10], Step [111/750], Loss: 0.1122
Epoch [7/10], Step [112/750], Loss: 0.0566
Epoch [7/10], Step [113/750], Loss: 0.0408
Epoch [7/10], Step [114/750], Loss: 0.0360
Epoch [7/10], Step [115/750], Loss: 0.0191
Epoch [7/10], Step [116/750], Loss: 0.0489
Epoch [7/10], Step [117/750], Loss: 0.1250
Epoch [7/10], Step [118/750], Loss: 0.0781
Epoch [7/10], Step [119/750], Loss: 0.0371
Epoch [7/10], Step [120/750], Loss: 0.0255
Epoch [7/10], Step [121/750], Loss: 0.1545
Epoch [7/10], Step [122/750], Loss: 0.0151
Epoch [7/10], Step [123/750], Loss: 0.1340
Epoch [7/10], Step [124/750], Loss: 0.0141
Epoch [7/10], Step [125/750], Loss: 0.1242
Epoch [7/10

KeyboardInterrupt: 

In [23]:
torch.save(hnn.state_dict(), 'trained_model_16qubits_4_layers_with_strong_entagling.pt')

In [None]:
20:08 -- 30 steps