In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from modules.quantize import quantize, quantize_grad, QConv2d, QLinear, RangeBN
import torch.nn as nn
import torch.nn.functional as F
import csv
import torch.optim as optim
########################################################################
# Specify Parameters
# CSV file that runNet() appends one result row to per trained network.
output_file = 'ant_100k.csv'
# Bit budget; runNet() records it in the first column of every CSV row.
max_bits = 100000
########################################################################
# Initialization
# Per-layer weight bit-widths. These are placeholders: main() overwrites
# them before every runNet() call, and runNet() reads them when it builds
# the network.
conv1_w = conv2_w = fc1_w = fc2_w = fc3_w = 0

# Convert images to tensors, then normalise each of the three channels
# with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 is expected to already exist under ../data (download=False).
trainset = torchvision.datasets.CIFAR10(
    root='../data', train=True, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='../data', train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=4, shuffle=False, num_workers=2)

# Human-readable class names, indexed by integer label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

########################################################################
def main():
    """Run the currently enabled weight bit-width sweep.

    Each configuration is published through the module-level ``conv1_w``,
    ``conv2_w``, ``fc1_w``, ``fc2_w`` and ``fc3_w`` globals, after which
    ``runNet()`` is called; it reads those globals when building the
    network.  The enabled sweep holds the first four layers at 8 bits and
    steps ``fc3_w`` from 2 through 8.
    """
    global conv1_w, conv2_w, fc1_w, fc2_w, fc3_w

    # ---- Disabled experiment: exhaustive search under the max_bits budget ----
    # for a in range(1, 3):
    #     conv1_w = a
    #     for b in range(1, 9):
    #         conv2_w = b
    #         for c in range(1, 9):
    #             fc1_w = c
    #             for d in range(1, 9):
    #                 fc2_w = d
    #                 for e in range(1, 9):
    #                     fc3_w = e
    #                     totBits = (a * 4704 + b * 1600 + c * 48000
    #                                + d * 10080 + e * 840)
    #                     if totBits < max_bits:
    #                         runNet()
    #
    # ---- Disabled experiments: one-layer sweeps ----
    # Four further sweeps have the same shape as the enabled loop below but
    # iterate over range(1, 9) and vary conv1_w, conv2_w, fc1_w or fc2_w in
    # turn while the remaining four widths stay at 8.

    for bits in range(2, 9):
        conv1_w = conv2_w = fc1_w = fc2_w = 8
        fc3_w = bits
        runNet()

def runNet():
    """Train, evaluate and log one network for the current bit-width globals.

    Builds a network of two ``QConv2d`` and three ``QLinear`` layers whose
    ``num_bits_weight`` arguments come from the module-level ``conv1_w``,
    ``conv2_w``, ``fc1_w``, ``fc2_w`` and ``fc3_w`` values, trains it on
    ``trainloader`` for 10 epochs, measures overall and per-class accuracy
    on ``testloader``, prints the results and appends one row to
    ``output_file``.

    The CSV row holds 17 columns: ``max_bits``, the five bit-widths, the
    overall accuracy and one accuracy per entry of ``classes`` (all
    accuracies in percent with three decimals).
    """
    class Net(nn.Module):
        """Two quantized conv layers followed by three quantized linear layers."""

        def __init__(self):
            super().__init__()

            self.conv1 = QConv2d(3, 6, 5, num_bits_weight=conv1_w)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = QConv2d(6, 16, 5, num_bits_weight=conv2_w)
            self.fc1 = QLinear(16 * 5 * 5, 120, num_bits_weight=fc1_w)
            self.fc2 = QLinear(120, 84, num_bits_weight=fc2_w)
            self.fc3 = QLinear(84, 10, num_bits_weight=fc3_w)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            # Flatten to 16 * 5 * 5 features per sample for the linear layers.
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.fc3(x)

    print("Starting")
    net = Net()

    ########################################################################
    # Train the network
    _train(net, trainloader)
    print('Finished Training')

    ########################################################################
    # Evaluate (single pass yields both overall and per-class counts)
    # NOTE(review): the network is left in training mode here, as before.
    # Whether QConv2d/QLinear behave differently under net.eval() cannot be
    # seen from this file -- confirm before switching modes.
    correct, total, class_correct, class_total = _evaluate(
        net, testloader, len(classes))

    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))
    for name, hits, seen in zip(classes, class_correct, class_total):
        print('Accuracy of %5s : %2d %%' % (name, 100 * hits / seen))

    ########################################################################
    # Append CSV
    # This message repeats the post-training one; kept so the console log
    # format stays unchanged.
    print('Finished Training')
    row = [str(value)
           for value in (max_bits, conv1_w, conv2_w, fc1_w, fc2_w, fc3_w)]
    row.append("{:.3f}".format(100 * correct / total))
    row.extend("{:.3f}".format(100 * hits / seen)
               for hits, seen in zip(class_correct, class_total))
    _append_csv_row(output_file, row)


def _train(net, loader, epochs=10, report_every=2000):
    """Train ``net`` in place with Adam (lr=0.001) and cross-entropy loss.

    Prints the mean loss of every ``report_every`` consecutive mini-batches.
    """
    criterion = nn.CrossEntropyLoss()
    # Alternative (disabled): optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(loader):
            optimizer.zero_grad()

            # forward + backward + optimize
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % report_every == report_every - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / report_every))
                running_loss = 0.0


def _evaluate(net, loader, num_classes):
    """Count overall and per-class correct predictions of ``net`` on ``loader``.

    Returns ``(correct, total, class_correct, class_total)`` where the last
    two are lists of length ``num_classes`` indexed by label.
    """
    correct = 0
    total = 0
    class_correct = [0.0] * num_classes
    class_total = [0.0] * num_classes
    with torch.no_grad():
        for images, labels in loader:
            _, predicted = torch.max(net(images), 1)
            hits = predicted == labels
            total += labels.size(0)
            correct += hits.sum().item()
            # Walk the actual batch instead of assuming exactly 4 samples,
            # so any batch size (including a short final batch) is counted.
            for label, hit in zip(labels.tolist(), hits.tolist()):
                class_correct[label] += hit
                class_total[label] += 1
    return correct, total, class_correct, class_total


def _append_csv_row(path, row):
    """Append ``row`` to the CSV file at ``path``, closing it even on error."""
    with open(path, 'a', newline='') as outfile:
        csv.writer(outfile).writerow(row)


# Start the sweep only when this file is executed directly, not on import.
if __name__ == '__main__':
    main()

Starting
[1,  2000] loss: 1.909
[1,  4000] loss: 1.652
[1,  6000] loss: 1.547
[1,  8000] loss: 1.493
[1, 10000] loss: 1.449
[1, 12000] loss: 1.409
[2,  2000] loss: 1.354
[2,  4000] loss: 1.337
[2,  6000] loss: 1.378
[2,  8000] loss: 1.310
[2, 10000] loss: 1.349
[2, 12000] loss: 1.337
[3,  2000] loss: 1.258
[3,  4000] loss: 1.261
[3,  6000] loss: 1.251
[3,  8000] loss: 1.255
[3, 10000] loss: 1.245
[3, 12000] loss: 1.262
[4,  2000] loss: 1.181
[4,  4000] loss: 1.201
[4,  6000] loss: 1.193
[4,  8000] loss: 1.191
[4, 10000] loss: 1.211
[4, 12000] loss: 1.222
[5,  2000] loss: 1.148
[5,  4000] loss: 1.152
[5,  6000] loss: 1.161
[5,  8000] loss: 1.202
[5, 10000] loss: 1.194
[5, 12000] loss: 1.206
[6,  2000] loss: 1.154
[6,  4000] loss: 1.162
[6,  6000] loss: 1.183
[6,  8000] loss: 1.170
[6, 10000] loss: 1.180
[6, 12000] loss: 1.172
[7,  2000] loss: 1.098
[7,  4000] loss: 1.139
[7,  6000] loss: 1.130
[7,  8000] loss: 1.166
[7, 10000] loss: 1.177
[7, 12000] loss: 1.196
[8,  2000] loss: 1.123
[8

[10,  2000] loss: 0.939
[10,  4000] loss: 0.943
[10,  6000] loss: 0.959
[10,  8000] loss: 0.959
[10, 10000] loss: 0.980
[10, 12000] loss: 0.985
Finished Training
Accuracy of the network on the 10000 test images: 60 %
Accuracy of plane : 59 %
Accuracy of   car : 76 %
Accuracy of  bird : 38 %
Accuracy of   cat : 38 %
Accuracy of  deer : 63 %
Accuracy of   dog : 54 %
Accuracy of  frog : 71 %
Accuracy of horse : 58 %
Accuracy of  ship : 79 %
Accuracy of truck : 69 %
Finished Training
Starting
[1,  2000] loss: 1.856
[1,  4000] loss: 1.617
[1,  6000] loss: 1.527
[1,  8000] loss: 1.465
[1, 10000] loss: 1.428
[1, 12000] loss: 1.402
[2,  2000] loss: 1.313
[2,  4000] loss: 1.306
[2,  6000] loss: 1.289
[2,  8000] loss: 1.276
[2, 10000] loss: 1.264
[2, 12000] loss: 1.246
[3,  2000] loss: 1.190
[3,  4000] loss: 1.191
[3,  6000] loss: 1.190
[3,  8000] loss: 1.169
[3, 10000] loss: 1.175
[3, 12000] loss: 1.170
[4,  2000] loss: 1.093
[4,  4000] loss: 1.122
[4,  6000] loss: 1.117
[4,  8000] loss: 1.107
