In [1]:
import os
import random
import sys

import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader

sys.path.append("../pytorch/")
sys.path.append("../evolution/")
from evolution import FitnessBasedSamplingEvolution
from net import EvolutionMNIST, GenericDataset, load_mnist

In [2]:
# Load the first N_TRAIN_ROWS training rows. The second return value is discarded;
# presumably it is the (unused) test split given train_only=True — TODO confirm.
# The data directory can be overridden via the MNIST_PATH environment variable
# instead of editing this machine-specific absolute path.
MNIST_PATH = os.environ.get("MNIST_PATH", "/home/marlon/songstmep")
N_TRAIN_ROWS = 10000
train, _ = load_mnist(path=MNIST_PATH, train_only=True, nrows=N_TRAIN_ROWS)

In [3]:
# Seed every RNG before the network is built so that any random weight
# initialisation and the evolutionary noise are reproducible on Restart & Run All.
# NOTE(review): assumes the evolution/net modules draw only from random/numpy/torch — confirm.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# 784 = 28 * 28, presumably the flattened MNIST image size; 10 and 50 are presumably
# the class count and a hidden width — TODO confirm against EvolutionMNIST's signature.
net = EvolutionMNIST(784, 10, 50)

In [4]:
# Mini-batches in a fixed order (DataLoader's default is shuffle=False), so every
# fitness evaluation walks exactly the same batches in the same sequence.
BATCH_SIZE = 50
train_loader = DataLoader(train, batch_size=BATCH_SIZE)

In [5]:
def fitness(model, loader=None):
    """Score a candidate model by its mean per-batch inverse squared loss.

    For every mini-batch the cross-entropy loss ``L`` is computed and mapped
    to ``1 / L**2``; the mean over all batches is returned, so a lower loss
    yields a higher fitness.

    Args:
        model: callable mapping a float batch to class logits.
        loader: iterable of ``(data, target)`` batches. ``target`` holds
            per-class scores along dim 1 (presumably one-hot); its argmax is
            used as the class index. Defaults to the notebook-level
            ``train_loader``, looked up at call time.

    Returns:
        numpy float: mean of ``1 / loss**2`` over all batches.

    Raises:
        ValueError: if ``loader`` yields no batches (previously returned NaN).
        ZeroDivisionError: if a batch loss is exactly 0.
    """
    if loader is None:
        # Resolved at call time so re-running the DataLoader cell is picked up.
        loader = train_loader
    loss_func = torch.nn.CrossEntropyLoss()
    scores = []
    # Fitness evaluation never back-propagates, so skip building autograd graphs.
    with torch.no_grad():
        for data, target in loader:
            logits = model(data.float())
            labels = torch.max(target, 1)[1]
            # .item() replaces the pre-0.4 ``loss.data[0]``, which raises
            # IndexError on the 0-dim loss tensor in current PyTorch.
            batch_loss = loss_func(logits, labels).item()
            scores.append(1 / batch_loss ** 2)
    if not scores:
        raise ValueError("fitness: loader yielded no batches")
    return np.mean(scores)

In [6]:
# Evolution settings. POP_SIZE candidates per generation (the evolve() output lists
# 20 fitness scores per generation). NOISE_STD is presumably the standard deviation
# of the parameter perturbations — confirm in FitnessBasedSamplingEvolution.
POP_SIZE = 20
NOISE_STD = 0.05
evo = FitnessBasedSamplingEvolution(
    fitness_func=fitness,
    model=net,
    noise_std=NOISE_STD,
    pop_size=POP_SIZE,
)

In [7]:
# Run the evolutionary search. Judging from the printed output below, each generation
# prints: the list of candidate fitness scores; one "index | fitness share" line per
# candidate, where share = fitness / sum(fitness); an array of indices (presumably the
# sampled parents); and finally "<generation> <best fitness>".
# NOTE(review): fitness values around 3e-4 imply per-batch cross-entropy on the order
# of 1/sqrt(3e-4) ~ 58, far above a near-uniform guess — possibly unnormalised
# pixel inputs or a poor init; worth checking.
evo.evolve()

[0.00032240572137521564, 0.00048903931706987766, 0.00035551453806274689, 0.00079818937535964325, 0.000308308442481553, 0.00075522645590295808, 0.00049137244828678271, 0.00094204112085052383, 0.000384082851242184, 0.00067536497962534245, 0.00040062143032058037, 0.00043928943439722716, 0.00041043237666504249, 0.00065102708601366352, 0.00086657440721381351, 0.00069665086309977475, 0.00045175846956516027, 0.00033720666689941931, 0.00032926597145044827, 0.0005224073832853605]
0 | 0.00032 0.030
1 | 0.00049 0.046
2 | 0.00036 0.033
3 | 0.00080 0.075
4 | 0.00031 0.029
5 | 0.00076 0.071
6 | 0.00049 0.046
7 | 0.00094 0.089
8 | 0.00038 0.036
9 | 0.00068 0.064
10 | 0.00040 0.038
11 | 0.00044 0.041
12 | 0.00041 0.039
13 | 0.00065 0.061
14 | 0.00087 0.082
15 | 0.00070 0.066
16 | 0.00045 0.043
17 | 0.00034 0.032
18 | 0.00033 0.031
19 | 0.00052 0.049
[10 15 13 10 10  0  1  4 14 16  9  3  8  7 16 13  3  1 19  7]
0 0.000942041120851
-----------------
[0.00021949392479471777, 0.00021693935993319477, 0.000