In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from modules.quantize import quantize, quantize_grad, QConv2d, QLinear, RangeBN
import torch.nn as nn
import torch.nn.functional as F
import csv
import torch.optim as optim
########################################################################
# Specify Parameters
output_file = 'ant_100k.csv'   # results CSV; runNet() appends one row per trained net
max_bits = 100000              # value written to column 0 of every CSV row
########################################################################
# Initialization
# Per-layer weight bit widths.  These zeros are placeholders: main() assigns
# real widths before each runNet() call, which reads them to build the net.
conv1_w = conv2_w = fc1_w = fc2_w = fc3_w = 0

# ToTensor yields values in [0, 1]; normalizing each RGB channel with
# mean 0.5 / std 0.5 rescales them to [-1, 1].
_HALF = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_HALF, _HALF),
])

# CIFAR-10 must already be present under this directory (download=False).
_CIFAR10_ROOT = '../data'

trainset = torchvision.datasets.CIFAR10(
    root=_CIFAR10_ROOT, train=True, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root=_CIFAR10_ROOT, train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=4, shuffle=False, num_workers=2)

# Human-readable names for the ten label indices, in label order.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

########################################################################
# Names of the module-level bit-width globals, in network order.
_LAYERS = ('conv1_w', 'conv2_w', 'fc1_w', 'fc2_w', 'fc3_w')


def main(layer='fc3_w', widths=range(2, 9), default_width=8):
    """Sweep the weight bit width of one layer, training one network per width.

    For each value in ``widths`` the chosen ``layer`` gets that bit width and
    every other layer is held at ``default_width``; runNet() then trains,
    evaluates and appends one CSV row.  The defaults reproduce the fc3 sweep
    (2..8 bits, all other layers at 8 bits).  Single-layer sweeps for the
    other layers are e.g. ``main('conv1_w', range(1, 9))``; the bit-budget
    search over all layers is ``_budget_sweep()``.

    Args:
        layer: one of 'conv1_w', 'conv2_w', 'fc1_w', 'fc2_w', 'fc3_w'.
        widths: iterable of bit widths to try for ``layer``.
        default_width: bit width used for every other layer.

    Raises:
        ValueError: if ``layer`` is not one of the known layer names.
    """
    if layer not in _LAYERS:
        raise ValueError('unknown layer %r; expected one of %s' % (layer, _LAYERS))
    for width in widths:
        config = dict.fromkeys(_LAYERS, default_width)
        config[layer] = width
        _set_widths(config)
        runNet()


def _set_widths(config):
    """Assign the five module-level bit widths that runNet() reads.

    Args:
        config: mapping from every name in _LAYERS to an int bit width.
    """
    global conv1_w, conv2_w, fc1_w, fc2_w, fc3_w
    conv1_w, conv2_w, fc1_w, fc2_w, fc3_w = (config[name] for name in _LAYERS)


def _budget_sweep():
    """Train every width combination whose weighted bit total is under max_bits.

    conv1 ranges over 1..2 bits and every other layer over 1..8 bits; a
    network is trained only when the weighted total is strictly below the
    module-level ``max_bits``.  Combinations are visited with conv1 varying
    slowest and fc3 fastest.
    """
    import itertools

    # NOTE(review): 48000, 10080 and 840 equal the fc1/fc2/fc3 weight counts
    # (400*120, 120*84, 84*10), but 4704 = 6*28*28 and 1600 = 16*10*10 are the
    # conv1/conv2 *output* sizes, not their weight counts (6*3*5*5 = 450 and
    # 16*6*5*5 = 2400).  Confirm which quantity the budget should measure.
    multipliers = (4704, 1600, 48000, 10080, 840)
    search_space = (range(1, 3), range(1, 9), range(1, 9), range(1, 9), range(1, 9))
    for combo in itertools.product(*search_space):
        total_bits = sum(bits * count for bits, count in zip(combo, multipliers))
        if total_bits < max_bits:
            _set_widths(dict(zip(_LAYERS, combo)))
            runNet()

def runNet():
    """Train and evaluate one quantized CIFAR-10 network, then log the result.

    The per-layer weight bit widths come from the module-level globals
    conv1_w, conv2_w, fc1_w, fc2_w and fc3_w (main() sets them before each
    call).  The model is trained for 10 epochs on ``trainloader`` and
    evaluated in a single pass over ``testloader``.  Overall and per-class
    accuracy are printed, and one row is appended to ``output_file``:

        max_bits, conv1_w, conv2_w, fc1_w, fc2_w, fc3_w,
        overall accuracy %, then one accuracy % per entry of ``classes``
    """
    print("Starting")
    net = _build_net()
    _train(net, trainloader, epochs=10)
    print('Finished Training')

    correct, total, class_correct, class_total = _evaluate(net, testloader, len(classes))
    overall = _percent(correct, total)
    print('Accuracy of the network on the %d test images: %d %%' % (total, overall))

    per_class = [_percent(hit, seen) for hit, seen in zip(class_correct, class_total)]
    for name, acc in zip(classes, per_class):
        print('Accuracy of %5s : %2d %%' % (name, acc))

    row = [str(value) for value in (max_bits, conv1_w, conv2_w, fc1_w, fc2_w, fc3_w)]
    row.append("{:.3f}".format(overall))
    row.extend("{:.3f}".format(acc) for acc in per_class)
    _append_csv_row(output_file, row)
    print('Appended results to %s' % output_file)


def _build_net():
    """Build the LeNet-style CNN using the bit widths currently in the globals."""

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # QConv2d/QLinear come from modules.quantize; num_bits_weight is
            # taken from the module-level globals at construction time.
            self.conv1 = QConv2d(3, 6, 5, num_bits_weight=conv1_w)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = QConv2d(6, 16, 5, num_bits_weight=conv2_w)
            self.fc1 = QLinear(16 * 5 * 5, 120, num_bits_weight=fc1_w)
            self.fc2 = QLinear(120, 84, num_bits_weight=fc2_w)
            self.fc3 = QLinear(84, 10, num_bits_weight=fc3_w)

        def update(self):
            # Re-run __init__, replacing every layer with a newly constructed
            # one that uses the bit widths currently held in the globals.
            self.__init__()

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            # Flatten to 16*5*5 = 400 features per sample (assumes 32x32 inputs).
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.fc3(x)

    return Net()


def _train(net, loader, epochs=10, log_every=2000):
    """Train ``net`` with Adam (lr=0.001) and cross-entropy loss.

    Prints the mean loss over each window of ``log_every`` mini-batches.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(loader):
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0


def _evaluate(net, loader, num_classes=10):
    """Count correct predictions overall and per class in one pass over ``loader``.

    Works for any batch size, including a short or single-sample final batch.

    Returns:
        (correct, total, class_correct, class_total), where the last two are
        lists of length ``num_classes`` indexed by label.
    """
    # NOTE(review): net is never switched to eval() mode (the original did not
    # either).  Whether QConv2d/QLinear from modules.quantize behave
    # differently in eval mode cannot be determined from this file; confirm
    # before adding net.eval().
    correct = 0
    total = 0
    class_correct = [0.0] * num_classes
    class_total = [0.0] * num_classes
    with torch.no_grad():
        for images, labels in loader:
            predicted = net(images).argmax(dim=1)
            hits = predicted == labels
            total += labels.size(0)
            correct += hits.sum().item()
            for label, hit in zip(labels.tolist(), hits.tolist()):
                class_correct[label] += hit
                class_total[label] += 1
    return correct, total, class_correct, class_total


def _percent(part, whole):
    """Return 100 * part / whole, or 0.0 when ``whole`` is 0 (nothing counted)."""
    return 100.0 * part / whole if whole else 0.0


def _append_csv_row(path, row):
    """Append ``row`` to the CSV at ``path``; the file is closed even on error."""
    with open(path, 'a', newline='') as outfile:
        csv.writer(outfile).writerow(row)


if __name__ == '__main__':
    # Script entry point: run the bit-width sweep configured in main().
    main()

Starting
[1,  2000] loss: 1.948
[1,  4000] loss: 1.636
[1,  6000] loss: 1.523
[1,  8000] loss: 1.481
[1, 10000] loss: 1.420
[1, 12000] loss: 1.377
[2,  2000] loss: 1.326
[2,  4000] loss: 1.304
[2,  6000] loss: 1.254
[2,  8000] loss: 1.275
[2, 10000] loss: 1.244
[2, 12000] loss: 1.245
[3,  2000] loss: 1.174
[3,  4000] loss: 1.154
[3,  6000] loss: 1.177
[3,  8000] loss: 1.155
[3, 10000] loss: 1.171
[3, 12000] loss: 1.148
[4,  2000] loss: 1.082
[4,  4000] loss: 1.084
[4,  6000] loss: 1.107
[4,  8000] loss: 1.087
[4, 10000] loss: 1.076
[4, 12000] loss: 1.090
[5,  2000] loss: 1.027
[5,  4000] loss: 1.032
[5,  6000] loss: 1.047
[5,  8000] loss: 1.031
[5, 10000] loss: 1.025
[5, 12000] loss: 1.050
[6,  2000] loss: 0.984
[6,  4000] loss: 0.982
[6,  6000] loss: 1.008
[6,  8000] loss: 1.006
[6, 10000] loss: 1.020
[6, 12000] loss: 1.010
[7,  2000] loss: 0.960
[7,  4000] loss: 0.954
[7,  6000] loss: 0.977
[7,  8000] loss: 0.968
[7, 10000] loss: 0.981
[7, 12000] loss: 0.981
[8,  2000] loss: 0.907
[8

[10,  2000] loss: 1.304
[10,  4000] loss: 1.321
[10,  6000] loss: 1.344
[10,  8000] loss: 1.357
[10, 10000] loss: 1.335
[10, 12000] loss: 1.321
Finished Training
Accuracy of the network on the 10000 test images: 52 %
Accuracy of plane : 39 %
Accuracy of   car : 81 %
Accuracy of  bird : 43 %
Accuracy of   cat : 40 %
Accuracy of  deer : 40 %
Accuracy of   dog : 59 %
Accuracy of  frog : 41 %
Accuracy of horse : 60 %
Accuracy of  ship : 77 %
Accuracy of truck : 39 %
Finished Training
Starting
[1,  2000] loss: 1.863
[1,  4000] loss: 1.589
[1,  6000] loss: 1.538
[1,  8000] loss: 1.453
[1, 10000] loss: 1.427
[1, 12000] loss: 1.382
[2,  2000] loss: 1.320
[2,  4000] loss: 1.294
[2,  6000] loss: 1.300
[2,  8000] loss: 1.262
[2, 10000] loss: 1.242
[2, 12000] loss: 1.275
[3,  2000] loss: 1.228
[3,  4000] loss: 1.230
[3,  6000] loss: 1.212
[3,  8000] loss: 1.189
[3, 10000] loss: 1.209
[3, 12000] loss: 1.197
[4,  2000] loss: 1.159
[4,  4000] loss: 1.161
[4,  6000] loss: 1.156
[4,  8000] loss: 1.156


[6,  8000] loss: 1.053
[6, 10000] loss: 1.009
[6, 12000] loss: 1.023
[7,  2000] loss: 0.951
[7,  4000] loss: 0.983
[7,  6000] loss: 0.990
[7,  8000] loss: 0.969
[7, 10000] loss: 1.010
[7, 12000] loss: 1.002
[8,  2000] loss: 0.933
[8,  4000] loss: 0.950
[8,  6000] loss: 0.945
[8,  8000] loss: 0.953
[8, 10000] loss: 1.007
[8, 12000] loss: 0.970
[9,  2000] loss: 0.910
[9,  4000] loss: 0.921
[9,  6000] loss: 0.934
[9,  8000] loss: 0.939
[9, 10000] loss: 0.959
[9, 12000] loss: 0.967
[10,  2000] loss: 0.871
[10,  4000] loss: 0.902
[10,  6000] loss: 0.931
[10,  8000] loss: 0.932
[10, 10000] loss: 0.937
[10, 12000] loss: 0.937
Finished Training
Accuracy of the network on the 10000 test images: 61 %
Accuracy of plane : 71 %
Accuracy of   car : 78 %
Accuracy of  bird : 47 %
Accuracy of   cat : 37 %
Accuracy of  deer : 56 %
Accuracy of   dog : 42 %
Accuracy of  frog : 80 %
Accuracy of horse : 63 %
Accuracy of  ship : 65 %
Accuracy of truck : 71 %
Finished Training
Starting
[1,  2000] loss: 1.877


[3,  2000] loss: 1.247
[3,  4000] loss: 1.238
[3,  6000] loss: 1.243
[3,  8000] loss: 1.255
[3, 10000] loss: 1.212
[3, 12000] loss: 1.248
[4,  2000] loss: 1.148
[4,  4000] loss: 1.202
[4,  6000] loss: 1.180
[4,  8000] loss: 1.179
[4, 10000] loss: 1.195
[4, 12000] loss: 1.198
[5,  2000] loss: 1.124
[5,  4000] loss: 1.131
[5,  6000] loss: 1.148
[5,  8000] loss: 1.150
[5, 10000] loss: 1.150
[5, 12000] loss: 1.149
[6,  2000] loss: 1.084
[6,  4000] loss: 1.122
[6,  6000] loss: 1.099
[6,  8000] loss: 1.105
[6, 10000] loss: 1.108
[6, 12000] loss: 1.142
[7,  2000] loss: 1.062
[7,  4000] loss: 1.059
[7,  6000] loss: 1.074
[7,  8000] loss: 1.097
[7, 10000] loss: 1.107
[7, 12000] loss: 1.097
[8,  2000] loss: 1.036
[8,  4000] loss: 1.051
[8,  6000] loss: 1.074
[8,  8000] loss: 1.071
[8, 10000] loss: 1.054
[8, 12000] loss: 1.088
[9,  2000] loss: 1.018
[9,  4000] loss: 1.061
[9,  6000] loss: 1.046
[9,  8000] loss: 1.039
[9, 10000] loss: 1.055
[9, 12000] loss: 1.048
[10,  2000] loss: 0.975
[10,  4000

[1,  2000] loss: 1.925
[1,  4000] loss: 1.658
[1,  6000] loss: 1.552
[1,  8000] loss: 1.506
[1, 10000] loss: 1.465
[1, 12000] loss: 1.444
[2,  2000] loss: 1.361
[2,  4000] loss: 1.345
[2,  6000] loss: 1.343
[2,  8000] loss: 1.345
[2, 10000] loss: 1.321
[2, 12000] loss: 1.308
[3,  2000] loss: 1.256
[3,  4000] loss: 1.246
[3,  6000] loss: 1.231
[3,  8000] loss: 1.263
[3, 10000] loss: 1.245
[3, 12000] loss: 1.210
[4,  2000] loss: 1.160
[4,  4000] loss: 1.163
[4,  6000] loss: 1.179
[4,  8000] loss: 1.183
[4, 10000] loss: 1.175
[4, 12000] loss: 1.171
[5,  2000] loss: 1.101
[5,  4000] loss: 1.118
[5,  6000] loss: 1.127
[5,  8000] loss: 1.128
[5, 10000] loss: 1.154
[5, 12000] loss: 1.120
[6,  2000] loss: 1.059
[6,  4000] loss: 1.061
[6,  6000] loss: 1.074
[6,  8000] loss: 1.088
[6, 10000] loss: 1.100
[6, 12000] loss: 1.097
[7,  2000] loss: 1.003
[7,  4000] loss: 1.053
[7,  6000] loss: 1.054
[7,  8000] loss: 1.046
[7, 10000] loss: 1.050
[7, 12000] loss: 1.078
[8,  2000] loss: 1.010
[8,  4000] 