In [12]:
from Pooling import Pooling_2D_density_3D
from Conv_Layer import Conv_RBS_density_I2_3D
import torch
import torch.nn as nn  # the neural network library of pytorch
import load_dataset_letao as load  # module with function to load MNIST
from toolbox import reduce_MNIST_dataset
from training import test_net, train_net

In [13]:
class QCNN(nn.Module):
    def __init__(self, I, O, J, K, device):
        super(QCNN, self).__init__()
        self.conv1 = Conv_RBS_density_I2_3D(I,K,J,device)
        self.pool1 = Pooling_2D_density_3D(I, O, J, device)
        self.conv2 = Conv_RBS_density_I2_3D(O,K,J,device)
        self.pool2 = Pooling_2D_density_3D(O, O//2, J, device)
        self.fc = nn.Linear((O//2)*(O//2)*J, 10)

    def forward(self, x):
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        device_cpu = torch.device("cpu")
        d = torch.stack([torch.diag(p2[i]) for i in range(batch_size)]).to(device_cpu)
        output = self.fc(d)
        return output.to(device)    # return x for visualization

In [18]:
I = 12 # dimension of image we use
O = I//2 # dimension after pooling
J = 2 # number of channel
K = 2 # size of kernel
k = 3
batch_size = 10 # batch number
scala = 6000 # time we reduce dataset
learning_rate = 1e-2
device = torch.device("mps")

train_loader, test_loader, dim_in, dim_out = load.load_MNIST(batch_size=batch_size)
reduced_loader = reduce_MNIST_dataset(train_loader, scala)
reduced_test_loader = reduce_MNIST_dataset(test_loader, 100)

conv_network = QCNN(I, O, J, K, device)
optimizer = torch.optim.Adam(conv_network.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()

loss_list = []
accuracy_list = []
for epoch in range(10):
    train_loss, train_accuracy = train_net(batch_size, I, J, k, conv_network, train_loader, criterion, optimizer, device)
    loss_list.append(train_loss)
    accuracy_list.append(train_accuracy*100)
    print(f'Epoch {epoch}: Loss = {train_loss:.6f}, accuracy = {train_accuracy*100:.4f} %')

Epoch 0: Loss = 2.291581, accuracy = 0.0000 %
Epoch 1: Loss = 2.269141, accuracy = 0.0000 %
Epoch 2: Loss = 2.247290, accuracy = 10.0000 %
Epoch 3: Loss = 2.226036, accuracy = 20.0000 %
Epoch 4: Loss = 2.205383, accuracy = 30.0000 %
Epoch 5: Loss = 2.185327, accuracy = 40.0000 %
Epoch 6: Loss = 2.165860, accuracy = 30.0000 %
Epoch 7: Loss = 2.146968, accuracy = 30.0000 %
Epoch 8: Loss = 2.128629, accuracy = 30.0000 %
Epoch 9: Loss = 2.110826, accuracy = 30.0000 %


In [19]:
test_loss, test_accuracy = test_net(batch_size, I, J, k, conv_network, reduced_test_loader, criterion, device)
print(f'Evaluation on test set: Loss = {test_loss:.6f}, accuracy = {test_accuracy*100:.4f} %')

Evaluation on test set: Loss = 2.313761, accuracy = 14.0000 %


In [20]:
total_params = sum(p.numel() for p in conv_network.parameters())
print(f"Number of parameters: {total_params}")

Number of parameters: 196
