In [None]:
import csv
import operator
import random
import math
from deap import base
from deap import benchmarks
from deap import creator
from deap import tools

import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
from numpy import genfromtxt

import torchvision
import torchvision.transforms as transforms

# Compute device: the first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

# Preprocessing applied to every image: convert to a tensor, then normalise
# each of the three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 128

# CIFAR-10 training split (downloaded to ./data when missing), reshuffled on
# every pass over the loader.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# CIFAR-10 test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Human-readable label for each of the ten class indices.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# set up the network
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut branch.

    Args:
        in_planes: Number of channels of the incoming feature map.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut when one is needed). Defaults to 1.
    """

    # Output channels are ``expansion * planes``; ResNet reads this attribute
    # when it sizes consecutive blocks and the final linear layer.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch: conv (carries the stride) -> BN -> conv -> BN.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut branch: identity when the shapes already agree, otherwise
        # a 1x1 convolution + BN that matches channels and spatial size.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        return F.relu(residual + self.shortcut(x))
    
class ResNet(nn.Module):
    """ResNet with a 3x3 stem, four residual stages and a linear classifier.

    Args:
        block: Block class used inside every stage. It is called as
            ``block(in_planes, planes, stride)`` and must expose an
            ``expansion`` class attribute.
        num_blocks: Sequence of four integers, the number of blocks per stage.
        num_classes: Size of the classifier output. Defaults to 10.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer advances it after every block.
        self.in_planes = 64

        # Stem: 3 input channels -> 64 feature maps, spatial size unchanged.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Four stages; every stage after the first starts with stride 2.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use 1."""
        per_block_strides = [stride]
        per_block_strides.extend(1 for _ in range(num_blocks - 1))

        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map a batch of images to class scores of shape (batch, num_classes)."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        # 4x4 average pooling, then flatten everything except the batch axis.
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

# Cross-entropy criterion with default settings, kept as a module-level object.
loss_func = nn.CrossEntropyLoss()

In [None]:
#Initialise Network
# Seed the random sources used below (Python's `random` drives particle
# creation and velocity updates, torch drives layer re-initialisation) so a
# Restart & Run All reproduces the same run. Change SEED for a different run.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

ResNet18 = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10)
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; fitness() moves the model to `device` before using it.
ResNet18.load_state_dict(torch.load('GD.pkl', map_location='cpu'))
# Discard the loaded classifier: only this last layer is searched by the PSO.
ResNet18.linear.reset_parameters()

# Number of parameters in the last layer of network
num_of_weights = sum(p.numel() for p in ResNet18.linear.parameters())

# Average training loss of every fitness evaluation, in call order.
loss_values = []

# Fitness function setup (minimize loss)
creator.create("FitnessMin", base.Fitness, weights=(-1.0,)) # -1 is for minimise
creator.create("Particle", list, fitness=creator.FitnessMin, speed=list, smin=None, smax=None, best=None)

# Swarm configuration
posMinInit      = -2               # lower bound of the initial positions
posMaxInit      = +2               # upper bound of the initial positions
VMaxInit        = 1.5              # upper bound of the initial speeds
VMinInit        = 0.5              # lower bound of the initial speeds
populationSize  = 30               # number of particles
dimension       = num_of_weights   # one coordinate per last-layer parameter
interval        = 10               # statistics are logged every `interval` generations
iterations      = 50               # number of generations

#Parameter setup
wmax = 0.9 # inertia weight used for the first generation
wmin = 0.4 # value the inertia weight decays towards
c1   = 2.0 # coefficient of the pull towards a particle's own best position
c2   = 2.0 # coefficient of the pull towards the swarm's best position

def generate(size, smin, smax):
    """Create one particle with random position and speed.

    Args:
        size: Number of coordinates of the particle.
        smin: Value stored on the particle as ``smin``.
        smax: Value stored on the particle as ``smax``.

    Returns:
        A ``creator.Particle`` whose coordinates are drawn uniformly from
        [posMinInit, posMaxInit] and whose ``speed`` entries are drawn
        uniformly from [VMinInit, VMaxInit].
    """
    # Positions are drawn first, speeds second, one value per coordinate.
    positions = [random.uniform(posMinInit, posMaxInit) for _ in range(size)]
    part = creator.Particle(positions)

    velocities = []
    for _ in range(size):
        velocities.append(random.uniform(VMinInit, VMaxInit))
    part.speed = velocities

    part.smin, part.smax = smin, smax
    return part


def updateParticle(part, best, weight):
    """Advance one particle by a single PSO step, in place.

    The new speed of every coordinate is
    ``0.7 * (weight*v + c1*r1*(own_best - x) + c2*r2*(swarm_best - x))``
    with ``r1`` and ``r2`` drawn uniformly from [0, 1] per coordinate; the
    position is then shifted by the new speed.

    Args:
        part: Particle to update; its ``speed`` and coordinates are replaced
            and its ``best`` attribute is read.
        best: Best position found by the whole swarm.
        weight: Inertia weight applied to the current speed.
    """
    count = len(part)
    # All draws for the personal-best term come before those for the
    # swarm-best term.
    own_draws = [random.uniform(0, 1) for _ in range(count)]
    swarm_draws = [random.uniform(0, 1) for _ in range(count)]

    new_speed = []
    for v, x, own_best, swarm_best, u1, u2 in zip(
            part.speed, part, part.best, best, own_draws, swarm_draws):
        inertia = weight * v
        to_own_best = c1 * (u1 * (own_best - x))        # local best
        to_swarm_best = c2 * (u2 * (swarm_best - x))    # global best
        new_speed.append(0.7 * (inertia + (to_own_best + to_swarm_best)))
    part.speed = new_speed

    # update position with speed
    part[:] = [x + v for x, v in zip(part, new_speed)]
    
def fitness(part):
    """Score a particle as the average training loss of the network it encodes.

    The particle is a flat vector: the first ``in_features * out_features``
    entries become the weight matrix of ``ResNet18.linear`` and the following
    ``out_features`` entries become its bias. The network is then run over
    every batch of ``trainloader`` and the per-batch cross-entropy losses are
    averaged.

    The forward passes run in evaluation mode with gradient tracking disabled:
    nothing is back-propagated here, and in training mode every call would
    also update the BatchNorm running statistics of the network. The previous
    train/eval mode is restored before returning.

    Args:
        part: Sequence of floats holding the last-layer weights then biases.

    Returns:
        A one-element tuple ``(avg_loss,)``. The value is also appended to the
        module-level ``loss_values`` list.
    """
    head = ResNet18.linear
    n_out, n_in = head.out_features, head.in_features
    n_weights = n_out * n_in

    weights = np.asarray(part)
    # The flat vector is read as an (in_features, out_features) matrix and
    # transposed to the (out_features, in_features) layout of nn.Linear.
    head.weight = torch.nn.Parameter(torch.from_numpy(weights[0:n_weights].reshape(n_in, n_out).T).float())  # Update last layer weights
    head.bias = torch.nn.Parameter(torch.from_numpy(weights[n_weights:n_weights + n_out].flatten()).float())
    ResNet18.to(device)

    was_training = ResNet18.training
    ResNet18.eval()
    total_loss = 0.0
    try:
        with torch.no_grad():
            for x, y in trainloader:
                x = x.to(device)
                y = y.to(device)
                output = ResNet18(x)
                loss = F.cross_entropy(output, y)
                total_loss += loss.item()
    finally:
        # Leave the model in whatever mode the caller had it in.
        ResNet18.train(was_training)

    avg_loss = total_loss / len(trainloader)
    loss_values.append(avg_loss)
    return (avg_loss,)

# Fitness of the swarm's best particle, one entry appended per generation.
fitness_best = []

# DEAP toolbox holding the aliases that main() calls.
toolbox = base.Toolbox()
# Operators: how a particle is scored and how it is moved.
toolbox.register("evaluate", fitness)
toolbox.register("update", updateParticle)
# Factories: a single particle with `dimension` coordinates, and a list of
# such particles ("particle" has to exist before "population" refers to it).
toolbox.register(
    "particle",
    generate,
    size=dimension,
    smin=-3,
    smax=3,
)
toolbox.register("population", tools.initRepeat, list, toolbox.particle)

def main():
    """Run the particle swarm optimisation.

    For each of ``iterations`` generations every particle is evaluated, the
    per-particle and swarm-wide best positions are refreshed, and all
    particles are then moved with an inertia weight that decreases linearly
    from ``wmax`` towards ``wmin``. The best fitness of each generation is
    appended to the module-level ``fitness_best`` list, and population
    statistics are logged every ``interval`` generations.

    Returns:
        Tuple ``(pop, logbook, best)``: the final population, the DEAP
        logbook and the best particle found.
    """
    swarm = toolbox.population(n=populationSize)

    stats = tools.Statistics(lambda ind: ind.fitness.values)
    for label, reducer in (("avg", np.mean), ("std", np.std),
                           ("min", np.min), ("max", np.max)):
        stats.register(label, reducer)

    logbook = tools.Logbook()
    logbook.header = ["gen", "evals"] + stats.fields

    def snapshot(particle):
        """Copy a particle's position together with its current fitness."""
        clone = creator.Particle(particle)
        clone.fitness.values = particle.fitness.values
        return clone

    best = None

    #begin main loop
    for g in range(iterations):
        w = wmax - (wmax-wmin)*g/iterations #decaying inertia weight

        for part in swarm:
            part.fitness.values = toolbox.evaluate(part)

            #update local best
            if (not part.best) or (part.best.fitness < part.fitness):
                part.best = snapshot(part)

            #update global best
            if (not best) or best.fitness < part.fitness:
                best = snapshot(part)

        # Move the particles only after the whole swarm has been evaluated.
        for part in swarm:
            toolbox.update(part, best, w)

        fitness_best.append(best.fitness.values)
        print('completed', g)

        # Statistics are gathered and printed only every `interval` generations.
        if g % interval != 0:
            continue
        logbook.record(gen=g, evals=len(swarm), **stats.compile(swarm))
        print(logbook.stream)
        print('best ', best.fitness)

    print('fitness of best is', best.fitness)
    return swarm, logbook, best

if __name__ == "__main__":
    pop, logbook, best = main()

# Write the best particle into the last layer: the first 5120 values are the
# weight matrix (read as 512x10, then transposed), the next 10 are the bias.
weights = np.asarray(best)
ResNet18.linear.weight = torch.nn.Parameter(torch.from_numpy(weights[0:5120].reshape(512, 10).T).float())  # Update last layer weights
ResNet18.linear.bias = torch.nn.Parameter(torch.from_numpy(weights[5120:5130].flatten()).float())

#save the network
torch.save(ResNet18.state_dict(), 'PSO.pkl')

# Each entry of fitness_best is a one-element fitness tuple; unwrap it to a
# plain number for plotting and for the CSV file.
best_loss_per_generation = [values[0] for values in fitness_best]

#plot the loss
plt.plot(np.array(best_loss_per_generation), 'r')
plt.xlabel('generation')
plt.ylabel('best average training loss')
# Save before show(): once show() returns the current figure may already be
# closed, in which case savefig() would write an empty page.
plt.savefig('PSO_loss.pdf')
plt.show()

# Save to a CSV file (one row, one numeric value per generation)
with open('PSO_fitness.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(best_loss_per_generation)


In [None]:
# Rebuild the architecture and load the weights found by the PSO run.
testNet = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10)
# map_location places every saved tensor on the device used in this session,
# whichever device the checkpoint was written from.
testNet.load_state_dict(torch.load('PSO.pkl', map_location=device))
testNet.to(device)
# Evaluation mode: BatchNorm uses its stored running statistics instead of
# the statistics of each test batch (and no longer updates them).
testNet.eval()

total = 0
correct = 0

with torch.no_grad():
    for x, y in testloader:
        x = x.to(device)
        y = y.to(device)
        # calculate outputs by running images through the network
        outputs = testNet(x)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs, 1)
        total += y.size(0)
        correct += (predicted == y).sum().item()


accuracy = correct/total
print('Accurarcy:', accuracy)