In [1]:
# imports
import operator
import random
import numpy

from deap import base, benchmarks, creator, tools

from sklearn.svm import SVR
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.datasets import load_boston
from sklearn import metrics

import matplotlib.pyplot as plt

In [2]:
# Data set and hyper-parameter search bounds; presumably C in [cmin, cmax]
# and gamma in [gmin, gmax] for an RBF SVR.
# NOTE(review): `load_boston` was removed in scikit-learn 1.2, so this cell
# only runs on older scikit-learn releases.
db = load_boston()
# Exact powers of two via `**`: the notebook never imports `math`, so the
# previous `math.pow(...)` calls raised NameError on a fresh kernel.
# NOTE(review): the libsvm practical guide grids C over 2^-5..2^15 and gamma
# over 2^-15..2^3; the C and gamma ranges below look swapped -- confirm.
cmin = 2.0 ** -15
cmax = 2.0 ** 3
gmin = 2.0 ** -5
gmax = 2.0 ** 15

In [3]:
# Single-objective fitness to maximise (weight +1.0).
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
# Particle: a list of coordinates carrying PSO state.  DEAP instantiates
# class-valued attributes (fitness, speed) for every particle; plain values
# become class-level defaults.  `best`, `smin` and `smax` must exist because
# main() reads part.best before first assigning it and updateParticle reads
# part.smin / part.smax.  cmin..gmax come from the config cell.
creator.create(
    "Particle",
    list,
    fitness=creator.FitnessMax,
    speed=list,
    smin=None,
    smax=None,
    best=None,
    cmin=cmin,
    cmax=cmax,
    gmin=gmin,
    gmax=gmax,
)

In [4]:
def generate(size, pmin, pmax, smin, smax):
    """Create one randomly initialised particle.

    Args:
        size: number of coordinates in the particle.
        pmin, pmax: bounds of the uniform draw for each initial coordinate.
        smin, smax: bounds of the uniform draw for each initial speed
            component; also stored on the particle because updateParticle
            clamps speeds to ``[part.smin, part.smax]``.

    Returns:
        A ``creator.Particle`` with ``speed``, ``smin`` and ``smax`` set.
    """
    # Parameter names match both the body and the keywords passed by
    # toolbox.register (pmin, pmax, smin, smax); the old names
    # (cmin, cmax, gmin, gmax) made every call fail.
    part = creator.Particle(random.uniform(pmin, pmax) for _ in range(size))
    part.speed = [random.uniform(smin, smax) for _ in range(size)]
    part.smin = smin
    part.smax = smax
    return part

In [5]:
def updateParticle(part, best, phi1, phi2):
    """Move a particle one PSO step towards its own and the swarm's best.

    For every dimension the velocity gains a cognitive pull towards
    ``part.best`` and a social pull towards ``best``, each scaled by a fresh
    uniform random factor; velocities are then clipped to
    ``[part.smin, part.smax]`` and added to the coordinates.

    Args:
        part: list-like particle with ``speed``, ``best``, ``smin`` and
            ``smax`` attributes; its coordinates and ``speed`` are replaced.
        best: best position found by the swarm so far.
        phi1: upper bound of the cognitive random factor.
        phi2: upper bound of the social random factor.
    """
    raw_speed = []
    for velocity, coord, own_best, swarm_best in zip(part.speed, part, part.best, best):
        r_cog = random.uniform(0, phi1)  # drawn before the social factor
        r_soc = random.uniform(0, phi2)
        pull = r_cog * (own_best - coord) + r_soc * (swarm_best - coord)
        raw_speed.append(velocity + pull)
    part.speed = raw_speed

    # Clip each velocity component into [part.smin, part.smax].
    for idx in range(len(part.speed)):
        if part.speed[idx] < part.smin:
            part.speed[idx] = part.smin
        elif part.speed[idx] > part.smax:
            part.speed[idx] = part.smax

    part[:] = [coord + velocity for coord, velocity in zip(part, part.speed)]

In [6]:
def evaluate(part):
    """Fitness of a particle read as RBF-SVR hyper-parameters ``(C, gamma)``.

    Draft objective for tuning an ``SVR`` on the data in ``db``.  It is not
    registered with the toolbox (``toolbox.evaluate`` is ``benchmarks.h1``).

    Args:
        part: particle whose first two coordinates are C and gamma.

    Returns:
        1-tuple holding the mean 5-fold cross-validation score; DEAP
        fitness values are sequences.
    """
    X = db.data
    y = db.target

    # PSO positions are unbounded, but SVR rejects C <= 0, so clamp each
    # coordinate into the search bounds stored on the Particle class.
    C = min(max(part[0], part.cmin), part.cmax)
    gamma = min(max(part[1], part.gmin), part.gmax)

    # The commented-out draft referenced an undefined name `individual`.
    clf = SVR(kernel='rbf', C=C, gamma=gamma)

    # With no `scoring` argument cross_val_score uses the estimator's own
    # score method (R^2 for a regressor), so this is not an accuracy.
    scores = cross_val_score(clf, X, y, cv=5)
    return scores.mean(),

In [7]:
# Wire the PSO pieces into a DEAP toolbox.  The fitness is DEAP's `h1`
# benchmark function, not an SVR objective.
toolbox = base.Toolbox()
toolbox.register(
    "particle",
    generate,
    size=2,            # coordinates per particle
    pmin=-6, pmax=6,   # initial position range
    smin=-3, smax=3,   # initial speed range, later used as the speed clamp
)
toolbox.register("population", tools.initRepeat, list, toolbox.particle)
toolbox.register("update", updateParticle, phi1=2.0, phi2=2.0)
toolbox.register("evaluate", benchmarks.h1)

In [8]:
def main(n=5, ngen=10, verbose=True):
    """Run particle swarm optimisation using the registered toolbox.

    Args:
        n: swarm size (number of particles); previously hard-coded to 5.
        ngen: number of generations; previously hard-coded to 10.
        verbose: when True (default, the original behaviour) also print the
            initial population and each particle's fitness as it is computed.
            The per-generation logbook line is always printed.

    Returns:
        ``(pop, logbook, best)``: the final swarm, the statistics logbook and
        a copy of the best particle seen.
    """
    pop = toolbox.population(n=n)
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", numpy.mean)
    stats.register("std", numpy.std)
    stats.register("min", numpy.min)
    stats.register("max", numpy.max)

    logbook = tools.Logbook()
    logbook.header = ["gen", "evals"] + stats.fields

    best = None

    if verbose:
        print(pop)

    for g in range(ngen):
        for part in pop:
            part.fitness.values = toolbox.evaluate(part)
            if verbose:
                print(part.fitness.values)
            # Personal best.  getattr guards against a Particle class created
            # without a `best` attribute.
            own_best = getattr(part, "best", None)
            if not own_best or own_best.fitness < part.fitness:
                part.best = creator.Particle(part)
                part.best.fitness.values = part.fitness.values
            # Swarm-wide best.
            if not best or best.fitness < part.fitness:
                best = creator.Particle(part)
                best.fitness.values = part.fitness.values
        for part in pop:
            toolbox.update(part, best)

        # Stats use the fitness values computed at the start of this
        # generation, i.e. before the update step above moved the particles.
        logbook.record(gen=g, evals=len(pop), **stats.compile(pop))
        print(logbook.stream)

    return pop, logbook, best

In [9]:
if __name__ == "__main__":
    # main() returns (population, logbook, best); print the whole logbook
    # table once the run has finished.
    results = main()
    print(str(results[1]))

[[-0.44240759548298847, -5.44068361236971], [-5.373749965727063, -2.877599694865553], [-4.437546446649334, -5.208042459564461], [-4.703897931836714, -4.9714819028986135], [5.563250024249161, 4.837731952094112]]
(0.03428927518009593,)
(0.05918637809166855,)
(0.03241764906522235,)
(0.05800409105877631,)
(0.3001421512557164,)
gen	evals	avg      	std     	min      	max     
0  	5    	0.0968079	0.102294	0.0324176	0.300142
(0.03419296007432597,)
(0.035737659567612066,)
(0.0901108859341814,)
(0.11350187296495871,)
(0.056313080820304715,)
1  	5    	0.0659713	0.031174	0.034193 	0.113502
(0.0836306924356786,)
(0.005897257792327472,)
(0.16346206016847514,)
(0.16477828076129727,)
(0.07070827060008987,)
2  	5    	0.0976953	0.0602966	0.00589726	0.164778
(0.1346016158767149,)
(0.1219070215508084,)
(0.25280644031612903,)
(0.217566057346511,)
(0.2341250545755314,)
3  	5    	0.192201 	0.0535406	0.121907  	0.252806
(0.21189357329633215,)
(0.1183368701309577,)
(0.5428566708799217,)
(0.4161835796817031,)
(