In [6]:
import os
import math
import random
import torch
import numpy as np 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.autograd as autograd

In [7]:
# Prefer the GPU whenever PyTorch can see one; otherwise everything runs on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

In [8]:
class Net(nn.Module):
    """Reference MNIST CNN whose trained weights are loaded from a checkpoint.

    conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> 2x2 max-pool -> channel dropout
    -> flatten (64 * 12 * 12 = 9216 features, which implies 28x28 inputs) ->
    fc(9216->128) -> ReLU -> dropout -> fc(128->10) -> log_softmax.
    Every layer is bias-free. Attribute names (conv1, conv2, fc1, fc2) are the
    state_dict keys of the checkpoint and must not change.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
        # Channel-wise dropout on the 4-D (N, C, H, W) feature maps.
        self.dropout1 = nn.Dropout2d(0.25)
        # dropout2 sees a 2-D (N, 128) tensor. Dropout2d on 2-D input is deprecated
        # in PyTorch, whose warning says to use plain dropout to retain the behavior.
        # Dropout has no parameters, so checkpoint loading is unaffected.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128, bias=False)
        self.fc2 = nn.Linear(128, 10, bias=False)

    def forward(self, x):
        """Map a batch of single-channel images to per-class log-probabilities.

        Args:
            x: tensor of shape (N, 1, 28, 28). The spatial size is implied by fc1's
                9216 input features.

        Returns:
            Tensor of shape (N, 10) holding log_softmax scores.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)

        return output

In [9]:
# Instantiate the reference CNN on the selected device; the next cell overwrites its
# randomly initialised weights with the trained checkpoint.
model = Net().to(device)

In [10]:
# Load the trained weights. map_location remaps tensors that were saved on another
# device (e.g. a GPU checkpoint opened on a CPU-only machine) onto `device`.
model.load_state_dict(torch.load('../weights/mnist_cnn.pt', map_location=device))

<All keys matched successfully>

In [11]:
# Weight-carrying layers (conv + linear) of the reference model, in registration order.
module_list = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]

# Their weight shapes; init_pop sizes each agent's score tensors from these.
module_shape = [layer.weight.shape for layer in module_list]

# The trained weight Parameter objects themselves (references, not copies).
original_weights = [layer.weight for layer in module_list]

In [12]:
class GetSubnet(autograd.Function):
    """Binary supermask that keeps the top-``k`` fraction of ``scores``.

    Forward: returns a tensor shaped like ``scores`` with 1 where the entry ranks
    among the largest ``k`` fraction and 0 elsewhere.
    Backward: straight-through estimator; the incoming gradient is passed to
    ``scores`` unchanged and ``k`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, scores, k):
        """Build the mask.

        Args:
            scores: tensor of scores (any shape and memory layout).
            k: fraction of entries to keep. The lowest ``int((1 - k) * numel)``
                entries become 0 and the remainder become 1.
        """
        flat_scores = scores.reshape(-1)
        _, idx = flat_scores.sort()
        j = int((1 - k) * scores.numel())

        # Write into a freshly allocated contiguous buffer. Writing through
        # `scores.clone().flatten()` only works when the clone is contiguous: for a
        # non-contiguous `scores`, flatten() returns a copy, the writes are lost and
        # the raw scores come back instead of a 0/1 mask.
        mask = scores.new_zeros(scores.numel())
        mask[idx[j:]] = 1

        return mask.view(scores.shape)

    @staticmethod
    def backward(ctx, g):
        # Send the gradient g straight through to `scores`; `k` gets none.
        return g, None

In [13]:
class SupermaskConv(nn.Conv2d):
    """Conv2d whose frozen weights are gated by a binary supermask.

    ``scores`` has the weight's shape. On every forward pass the weights whose
    ``|scores|`` fall in the top ``sparsity`` fraction are kept and the rest are
    zeroed. ``sparsity`` is a module-level global, passed to ``GetSubnet`` as ``k``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Scores: same shape as the weight, Kaiming-uniform initialised (a=sqrt(5)).
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))

        # The weights are re-drawn Kaiming-normal (fan_in, ReLU gain).
        nn.init.kaiming_normal_(self.weight, mode="fan_in", nonlinearity="relu")

        # Neither tensor is updated by autograd.
        for frozen in (self.weight, self.scores):
            frozen.requires_grad = False

    def forward(self, x):
        """Convolve ``x`` using ``weight * mask(|scores|)`` as the effective kernel."""
        mask = GetSubnet.apply(self.scores.abs(), sparsity)
        return F.conv2d(
            x,
            self.weight * mask,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

In [14]:
class SupermaskLinear(nn.Linear):
    """Linear layer whose frozen weights are gated by a binary supermask.

    ``scores`` has the weight's shape. On every forward pass the weights whose
    ``|scores|`` fall in the top ``sparsity`` fraction are kept and the rest are
    zeroed. ``sparsity`` is a module-level global, passed to ``GetSubnet`` as ``k``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Scores: same shape as the weight, Kaiming-uniform initialised (a=sqrt(5)).
        self.scores = nn.Parameter(torch.Tensor(self.weight.size()))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))

        # The weights are re-drawn Kaiming-normal (fan_in, ReLU gain).
        nn.init.kaiming_normal_(self.weight, mode="fan_in", nonlinearity="relu")

        # Neither tensor is updated by autograd.
        for frozen in (self.weight, self.scores):
            frozen.requires_grad = False

    def forward(self, x):
        """Apply the layer with ``weight * mask(|scores|)`` as the effective weight."""
        effective_weight = self.weight * GetSubnet.apply(self.scores.abs(), sparsity)
        return F.linear(x, effective_weight, self.bias)

In [15]:
class GANet(nn.Module):
    """Same architecture as ``Net``, with supermasked conv/linear layers.

    Attribute names mirror ``Net`` (conv1, conv2, fc1, fc2) so the two models'
    layers line up by position. Each masked layer computes with
    ``weight * mask(|scores|)``; in this notebook the ``scores`` are installed by the
    genetic algorithm (``change_scores``).
    """

    def __init__(self):
        super(GANet, self).__init__()
        self.conv1 = SupermaskConv(1, 32, 3, 1, bias=False)
        self.conv2 = SupermaskConv(32, 64, 3, 1, bias=False)
        # Channel-wise dropout on the 4-D (N, C, H, W) feature maps.
        self.dropout1 = nn.Dropout2d(0.25)
        # dropout2 sees a 2-D (N, 128) tensor. Dropout2d on 2-D input is deprecated
        # in PyTorch, whose warning says to use plain dropout to retain the behavior.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = SupermaskLinear(9216, 128, bias=False)
        self.fc2 = SupermaskLinear(128, 10, bias=False)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) log-probabilities (shape implied by fc1)."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)

        return output

In [16]:
# MNIST test split (downloaded on first use), converted to tensors and normalised
# with mean 0.1307 / std 0.3081. It is served in shuffled batches of up to 10000
# images, presumably the whole split in one batch — verify against the dataset size.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        os.path.join("../data", "mnist"),
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    shuffle=True,
    batch_size=10000,
)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [17]:
# `sparsity` is handed to GetSubnet as `k`, the fraction of weights each supermask
# KEEPS (top-k by |score|). So 0.5 keeps half of every layer, despite the name.
# `seed` initialises the random number generators in the next cell.
sparsity, seed = 0.5, 1507

In [18]:
# The GA draws from three RNG sources (Python `random`, NumPy and torch). Seeding
# only torch left selection/crossover/mutation unreproducible, so seed all of them.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7fcfe00e9610>

In [19]:
# Materialise the entire test split on `device`. Concatenating every batch keeps this
# correct for any `batch_size`; reassigning inside a loop would silently keep only
# the last batch.
data_batches, target_batches = zip(*test_loader)
test_data = torch.cat(data_batches).to(device)
test_target = torch.cat(target_batches).to(device)
del data_batches, target_batches  # drop the per-batch CPU copies

In [20]:
# NOTE(review): `criterion` is not referenced by any of the cells shown. The GA scores
# agents by accuracy via `test`, not by a loss. Kept in case other cells use it.
criterion = nn.CrossEntropyLoss().to(device)

In [21]:
def test(model, device):
    model.eval()
    
    with torch.no_grad():
        output = model(test_data)
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct = pred.eq(test_target.view_as(pred)).sum().item()
        out = correct / len(test_data)
    return out

In [22]:
# GA model: same layout as `Net`, but every conv/linear layer is a supermasked variant
# whose effective weight is `weight * mask(|scores|)`.
ga_model = GANet().to(device)

In [23]:
scores = []  # NOTE(review): not read by any cell shown; init_pop uses its own local.

# Supermasked layers of the GA model in registration order (conv1, conv2, fc1, fc2),
# which matches the order of `module_list` / `original_weights` from `Net`.
ga_module_list = [
    m for m in ga_model.modules() if isinstance(m, (SupermaskConv, SupermaskLinear))
]

In [24]:
# Load the trained weights into the GA model. Copying in place, instead of rebinding
# `.weight` to the reference model's Parameter objects, keeps the supermask weights
# frozen (requires_grad=False, as set in their constructors) and independent of
# `model`. The checks catch any architecture mismatch between Net and GANet.
if len(ga_module_list) != len(original_weights):
    raise ValueError(
        f"layer count mismatch: {len(ga_module_list)} GA layers vs "
        f"{len(original_weights)} trained weights"
    )
with torch.no_grad():
    for ga_module, weight in zip(ga_module_list, original_weights):
        if ga_module.weight.shape != weight.shape:
            raise ValueError(
                f"shape mismatch: {tuple(ga_module.weight.shape)} vs {tuple(weight.shape)}"
            )
        ga_module.weight.copy_(weight)

In [25]:
class Agent:
    """One individual of the genetic algorithm.

    Holds the score tensors it contributes (init_pop builds one per masked layer)
    and the fitness assigned when it is evaluated.
    """

    def __init__(self, params):
        # Unscored agents start with fitness 0.
        self.params, self.fitness = params, 0

    def set_fitness(self, fitness):
        """Store the evaluated fitness on this agent."""
        self.fitness = fitness

In [26]:
# Init population

def init_pop(pop_size=100, shapes=None):
    """Create the initial GA population.

    Each agent receives one freshly initialised score tensor per entry of ``shapes``,
    using Kaiming-uniform with ``a=sqrt(5)`` (the scheme the supermask layers use for
    their own scores).

    Args:
        pop_size: number of agents to create.
        shapes: iterable of tensor shapes (each with at least 2 dimensions, as
            required by Kaiming init), one per masked layer. Defaults to the
            notebook-level ``module_shape``.

    Returns:
        A list of ``pop_size`` ``Agent`` objects, each with fitness 0.
    """
    if shapes is None:
        shapes = module_shape
    # Materialise once so a generator argument can be reused for every agent.
    shapes = [tuple(shape) for shape in shapes]

    population = []
    for _ in range(pop_size):
        params = []
        for shape in shapes:
            # torch.empty treats `shape` as a size even when it is a plain tuple
            # (torch.Tensor(tuple) would interpret the tuple as data).
            scores = nn.Parameter(torch.empty(shape))
            nn.init.kaiming_uniform_(scores, a=math.sqrt(5))
            params.append(scores)
        population.append(Agent(params=params))
    return population

In [27]:
def change_scores(module_list, agent):
    """Install ``agent``'s score tensors on the masked modules, position by position.

    ``agent.params[pos]`` replaces ``module_list[pos].scores``. An IndexError is
    raised if the agent carries more tensors than there are modules.
    """
    for pos in range(len(agent.params)):
        module_list[pos].scores = agent.params[pos]

In [28]:
def mutation(agent, mut_rate=0.1):
    """Return a mutated copy of ``agent``; the parent is left untouched.

    Each score entry is independently replaced, with probability ``mut_rate``, by a
    fresh draw from U(-1, 1).

    Args:
        agent: the agent to mutate.
        mut_rate: per-entry replacement probability, in [0, 1].

    Returns:
        A new ``Agent`` (fitness 0). Its params are new ``nn.Parameter`` tensors with
        the parent's shapes, on the parent's devices.
    """
    params = []
    for param in agent.params:
        # Everything is derived from `param` itself, on its own device. Moving a
        # flattened view to a global `device` is unsafe: `.to(device)` on a CPU tensor
        # returns a separate GPU copy, so in-place edits on it never reach the result.
        flat = param.detach().reshape(-1)
        # torch RNG (seeded by torch.manual_seed) makes mutations reproducible.
        mutate = torch.rand(flat.numel(), device=flat.device) < mut_rate
        replacement = torch.empty_like(flat).uniform_(-1, 1)
        mutated = torch.where(mutate, replacement, flat)
        params.append(nn.Parameter(mutated.view(param.shape)))
    return Agent(params=params)

In [30]:
def recombine_agent(agent_1, agent_2):
    """Single-point crossover between two agents, applied layer by layer.

    For each pair of corresponding score tensors a cut point ``c`` is drawn uniformly
    from ``0..numel`` (inclusive) over the flattened tensors, and::

        child_1 = parent_2[:c] ++ parent_1[c:]
        child_2 = parent_1[:c] ++ parent_2[c:]

    Each child is reshaped back to the parent's shape.

    Returns:
        Two new ``Agent`` objects (fitness 0). Their tensors live on the device of
        ``agent_1``'s corresponding tensors.

    Raises:
        ValueError: if the agents differ in layer count or per-layer size.
    """
    if len(agent_1.params) != len(agent_2.params):
        raise ValueError("agents must carry the same number of score tensors")

    params_1 = []
    params_2 = []
    for param_a, param_b in zip(agent_1.params, agent_2.params):
        flat_a = param_a.detach().reshape(-1)
        # Keep both parents, and therefore the children, on param_a's device.
        # Children built with a bare torch.zeros(...) would always land on the CPU.
        flat_b = param_b.detach().reshape(-1).to(flat_a.device)
        if flat_a.numel() != flat_b.numel():
            raise ValueError(
                f"layer size mismatch: {flat_a.numel()} vs {flat_b.numel()}"
            )
        cross_pt = random.randint(0, flat_a.numel())
        child_1 = torch.cat((flat_b[:cross_pt], flat_a[cross_pt:]))
        child_2 = torch.cat((flat_a[:cross_pt], flat_b[cross_pt:]))
        params_1.append(nn.Parameter(child_1.view(param_a.shape)))
        params_2.append(nn.Parameter(child_2.view(param_a.shape)))

    return Agent(params_1), Agent(params_2)

In [31]:
from tqdm import tqdm

def evaluate_population(pop):
    """Score every agent by the GA model's test accuracy with its masks installed.

    For each agent, its score tensors are installed into ``ga_module_list``, the
    model is evaluated with ``test``, and the result is stored on ``agent.fitness``.

    Returns:
        ``(pop, avg_fit, best_fit)``: the same list (agents updated in place), the
        mean fitness, and the highest fitness (0 if none exceeds 0).
    """
    total = 0
    best = 0
    for member in tqdm(pop):
        change_scores(ga_module_list, member)
        member.fitness = test(ga_model.to(device), device)
        total += member.fitness
        best = max(best, member.fitness)
    return pop, total / len(pop), best

In [32]:
def next_generation(pop, size=100, mut_rate=0.01):
    """Breed a new population of exactly ``size`` agents from the scored ``pop``.

    Parents are drawn (with replacement) with probability proportional to the square
    of their fitness. Each pair is crossed over with ``recombine_agent``, and both
    children are mutated with ``mutation``.

    Args:
        pop: non-empty list of evaluated agents.
        size: number of agents to return.
        mut_rate: per-entry mutation probability passed to ``mutation``.

    Returns:
        A list of ``size`` new agents.
    """
    # Selection weights do not change while breeding, so compute them once.
    weights = [agent.fitness ** 2 for agent in pop]
    # random.choices rejects an all-zero weight vector; fall back to uniform choice.
    if not any(weights):
        weights = None

    new_pop = []
    while len(new_pop) < size:
        parent_1, parent_2 = random.choices(pop, weights=weights, k=2)
        child_1, child_2 = recombine_agent(parent_1, parent_2)
        new_pop.append(mutation(child_1, mut_rate=mut_rate))
        new_pop.append(mutation(child_2, mut_rate=mut_rate))
    # Two children are produced per iteration, which overshoots by one for odd size.
    return new_pop[:size]

In [33]:
num_generations = 300
population_size = 50
mutation_rate = 0.15  # probability that each score entry is resampled (15%)

# Adaptive schedule of (avg-fitness threshold, population size, mutation rate), in
# ascending threshold order. After each generation, the settings of the highest
# threshold exceeded are adopted. They are not rolled back if the average fitness
# later drops.
fitness_schedule = [
    (0.80, 100, 0.1),
    (0.85, 150, 0.05),
    (0.90, 300, 0.025),
    (0.95, 600, 0.01),
]

pop_fit = []  # population average fitness per generation

pop = init_pop(population_size)  # initial population, built exactly once

for gen in range(num_generations):
    # Evaluate the current generation (sets agent.fitness on every agent).
    pop, avg_fit, best_fit = evaluate_population(pop)
    for threshold, size, rate in fitness_schedule:
        if avg_fit > threshold:
            population_size, mutation_rate = size, rate
    print('Generation {} with pop_fit {} | best_fit {}'.format(gen, avg_fit, best_fit))
    pop_fit.append(avg_fit)
    pop = next_generation(pop, size=population_size, mut_rate=mutation_rate)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
100%|██████████| 50/50 [00:02<00:00, 18.55it/s]


Generation 0 with pop_fit 0.7981119999999998 | best_fit 0.9572


100%|██████████| 50/50 [00:02<00:00, 24.07it/s]


Generation 1 with pop_fit 0.815992 | best_fit 0.9597


KeyboardInterrupt: 