In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import random
import copy
from IPython.display import display, clear_output

# Load the target image

In [None]:
# Load the target image (BGR channel order, as returned by cv2.imread), resize it
# to IMG_SHAPE and move it to the GPU.
# NOTE(review): `cp` and IMG_SHAPE are not defined in the visible cells; presumably
# `cp` is CuPy and both come from an earlier cell — confirm.
TARGET_IMAGE_PATH = './res/firefox1.png'
target_bgr = cv2.imread(TARGET_IMAGE_PATH)
if target_bgr is None:
    # cv2.imread signals a missing/unreadable file by returning None instead of raising.
    raise FileNotFoundError(f"Could not read target image: {TARGET_IMAGE_PATH}")
# cv2.resize takes dsize as (width, height), hence (IMG_SHAPE[1], IMG_SHAPE[0]).
# Resizing on the host first avoids a GPU -> host -> GPU round trip.
target_environment = cp.asarray(cv2.resize(target_bgr, (IMG_SHAPE[1], IMG_SHAPE[0])))

# Individual

In [None]:
# Individual: a single RGB colour genome with a finite lifespan.
class Individual:
    """A colour individual with an integer genome ``(r, g, b)`` and a lifespan.

    Channels that are not supplied are drawn uniformly from ``[0, 255]`` with
    ``random.randint``.
    """

    def __init__(self, r=None, g=None, b=None, lifespan=100):
        self.r = r if r is not None else random.randint(0, 255)
        self.g = g if g is not None else random.randint(0, 255)
        self.b = b if b is not None else random.randint(0, 255)
        # Decremented by the decay step; the individual is removed once it reaches 0.
        self.lifespan = lifespan

    def dna(self):
        """Return the genome as a NumPy array ``[r, g, b]``."""
        return np.array([self.r, self.g, self.b])

    def fitness(self, target):
        """Mean squared error between ``target`` and this genome (lower is better).

        Args:
            target: the colour to match — a list/tuple, a NumPy array, or any
                array exposing ``tolist()`` (e.g. a CuPy array, which refuses
                implicit conversion through ``np.array``).

        Returns:
            The MSE over the channels; 0 means a perfect match.
        """
        if hasattr(target, "tolist"):
            # Explicit conversion to plain Python numbers works for host and
            # device arrays alike.
            target = target.tolist()
        return np.square(np.subtract(np.array(target), self.dna())).mean()

    def crossover(self, other):
        """Uniform crossover: each channel is copied from a randomly chosen parent.

        The child gets the default lifespan.
        """
        child_r = random.choice([self.r, other.r])
        child_g = random.choice([self.g, other.g])
        child_b = random.choice([self.b, other.b])
        return Individual(child_r, child_g, child_b)

    def mutate(self, mutation_rate, sigma=10):
        """With probability ``mutation_rate``, apply Gaussian noise of std ``sigma`` to all channels."""
        if random.random() < mutation_rate:
            self.gaussian_mutation(sigma)

    def gaussian_mutation(self, sigma=10):
        """Add N(0, sigma) noise to each channel, clip to [0, 255] and truncate to int."""
        self.r = int(np.clip(self.r + random.gauss(0, sigma), 0, 255))
        self.g = int(np.clip(self.g + random.gauss(0, sigma), 0, 255))
        self.b = int(np.clip(self.b + random.gauss(0, sigma), 0, 255))

# Environment

In [None]:
class Environment:
    """Rectangular grid of cells, each holding an ``Individual`` or ``None``.

    Cells are addressed as ``population[x][y]`` with ``0 <= x < width`` and
    ``0 <= y < height``; ``scores`` and ``convert_to_map()`` use the same
    ``[x, y]`` order.
    """

    def __init__(self, width, height, init_population_size):
        # Outer list is indexed by x (width), inner by y (height), matching every
        # method below and the (width, height) shape of `scores`.
        self.population = [[None for _ in range(height)] for _ in range(width)]
        self.width = width
        self.height = height
        # Latest fitness (MSE) per cell, filled by update_scores().
        self.scores = np.zeros((width, height))

        # Drop individuals onto random cells. Positions may repeat (a later draw
        # overwrites an earlier one), so fewer than init_population_size cells
        # may end up occupied.
        for _ in range(init_population_size):
            x = random.randint(0, width - 1)
            y = random.randint(0, height - 1)
            self.population[x][y] = Individual()

    def convert_to_map(self):
        """Return a ``(width, height, 3)`` uint8 array of genomes; empty cells are 0."""
        population_map = np.zeros((self.width, self.height, 3), dtype=np.uint8)
        for x in range(self.width):
            for y in range(self.height):
                individual = self.population[x][y]
                if individual is not None:
                    population_map[x, y] = individual.dna()
        return population_map

    def add_individuals(self, size, possibility=0.5):
        """Fill each empty cell with a new random ``Individual`` with probability ``possibility``.

        ``size`` is not used; it is kept so existing callers keep working.
        """
        for x in range(self.width):
            for y in range(self.height):
                if self.population[x][y] is None and np.random.rand() < possibility:
                    self.population[x][y] = Individual()

    def decay_population(self):
        """Decrement every individual's lifespan and clear cells whose lifespan reaches 0."""
        for x in range(self.width):
            for y in range(self.height):
                individual = self.population[x][y]
                if individual is None:
                    continue
                individual.lifespan -= 1
                if individual.lifespan <= 0:
                    self.population[x][y] = None

    def update_scores(self, img):
        """Store each occupied cell's fitness against the pixel ``img[x, y]`` in ``scores``.

        Args:
            img: array indexed ``img[x, y]`` (NumPy, or any array with ``tolist()``
                such as CuPy); presumably shaped (width, height, channels) — confirm.
        """
        for x in range(self.width):
            for y in range(self.height):
                individual = self.population[x][y]
                if individual is not None:
                    # fitness() takes the whole pixel as one argument; tolist() gives
                    # plain Python values so device (CuPy) arrays work too.
                    self.scores[x, y] = individual.fitness(img[x, y].tolist())



In [None]:
# Population decay: age every individual, then cull the worst-matching ones.
def decay_population(population, environment, eliminate_rate=0.2):
    """Age the population, remove expired and worst-fitting individuals, then refill.

    Every individual's lifespan is decremented and those reaching 0 are removed.
    Of the survivors, the ``eliminate_rate`` fraction with the *highest* fitness
    score is also removed — ``Individual.fitness`` is a mean squared error, so
    higher means a worse match. All vacated cells are handed to
    ``add_random_individuals``.

    Args:
        population: 2-D grid (list of lists) of Individual or None; modified in place.
        environment: target image indexed ``[i][j]``; only the first three channels
            of each pixel are compared. NumPy or any array exposing ``tolist()``.
        eliminate_rate: fraction of surviving individuals to cull (0.2 = 1/5).

    Returns:
        The same ``population`` object.
    """
    scores = {}
    empty_positions = []
    # One bulk conversion to nested Python lists: cheaper than per-pixel indexing
    # and works for device (CuPy) arrays, which refuse implicit np.array().
    target_rows = environment.tolist()
    for i, row in enumerate(population):
        for j, individual in enumerate(row):
            if individual is None:
                continue
            individual.lifespan -= 1
            if individual.lifespan <= 0:
                empty_positions.append((i, j))
                row[j] = None
            else:
                scores[(i, j)] = individual.fitness(target_rows[i][j][:3])

    # Cull the worst fraction: fitness is an error, so worst = largest score first.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    eliminate_num = int(len(ranked) * eliminate_rate)
    for position, _ in ranked[:eliminate_num]:
        population[position[0]][position[1]] = None
        empty_positions.append(position)

    add_random_individuals(population, empty_positions)
    return population

# Reproduction: every individual breeds with a neighbour; the offspring fills
# nearby empty cells or replaces the weakest neighbour.
def reproduce(population, target_environment, iteration):
    """Produce the next generation's grid.

    For every occupied cell (i, j):
      1. cross it with a random individual returned by ``lookup`` (or with
         itself when ``lookup`` returns none);
      2. on every ``MUTATION_INTERVAL``-th iteration, mutate the child with the
         rate given by ``dynamic_mutation_rate`` of the child's fitness;
      3. put an independent copy of the child into every empty neighbouring
         cell or, when there are none, replace the individual chosen by
         ``find_weak_individual``.

    Reads only ``population``; all writes go to a deep copy, which is returned.
    """
    radius = 1  # neighbourhood radius, passed to lookup() and reused below
    next_population = copy.deepcopy(population)
    # Fetch the whole target once instead of one device->host transfer per pixel.
    target_rows = target_environment.tolist()
    n_rows = len(population)
    for i in range(n_rows):
        for j in range(len(population[i])):
            parent = population[i][j]
            if parent is None:
                continue
            individuals, empty_positions = lookup(population, (i, j), radius)
            partner = random.choice(individuals) if individuals else parent
            child = parent.crossover(partner)

            target_color = target_rows[i][j][:3]
            mutation_rate = dynamic_mutation_rate(child.fitness(target_color))
            if iteration % MUTATION_INTERVAL == 0:
                # mutate()'s second parameter is the noise sigma, not a target colour.
                child.mutate(mutation_rate)

            if empty_positions:
                # Each cell gets its own copy so later lifespan decay and mutation
                # affect exactly one cell.
                for x, y in empty_positions:
                    next_population[x][y] = copy.deepcopy(child)
            else:
                weak_individual = find_weak_individual(individuals, target_color)
                if weak_individual is not None:
                    # The weak individual came from lookup(), so it lies within
                    # `radius` of (i, j); search only that window.
                    for x in range(max(0, i - radius), min(n_rows, i + radius + 1)):
                        for y in range(max(0, j - radius), min(len(population[x]), j + radius + 1)):
                            if population[x][y] is weak_individual:
                                next_population[x][y] = child
    return next_population

# Neighbourhood lookup around a grid cell.
def lookup(population, position, radius=1, include_center=False):
    """Collect the neighbours of ``position`` within a square window.

    Scans every in-bounds cell (x, y) with ``|x - px| <= radius`` and
    ``|y - py| <= radius``. The cell at ``position`` itself is skipped unless
    ``include_center`` is True, so an individual is not reported as its own
    neighbour.

    Args:
        population: 2-D grid (list of lists) of individuals or None.
        position: ``(px, py)`` index of the centre cell.
        radius: half-width of the window.
        include_center: also report the centre cell.

    Returns:
        ``(individuals, empty_positions)``: occupants of non-empty cells and
        ``(x, y)`` tuples of empty cells, both in row-major scan order.
    """
    px, py = position
    n_rows = len(population)
    n_cols = len(population[0]) if n_rows else 0
    individuals = []
    empty_positions = []
    for x in range(max(0, px - radius), min(n_rows, px + radius + 1)):
        for y in range(max(0, py - radius), min(n_cols, py + radius + 1)):
            if (x, y) == (px, py) and not include_center:
                continue
            occupant = population[x][y]
            if occupant is None:
                empty_positions.append((x, y))
            else:
                individuals.append(occupant)
    return individuals, empty_positions

# Find the weakest (worst-matching) individual in a neighbourhood.
def find_weak_individual(individuals, environment):
    """Return the individual whose colour matches ``environment`` worst.

    ``Individual.fitness`` is a mean squared error (lower is better), so the
    weakest individual is the one with the largest score; ties go to the first
    such individual. Returns None when ``individuals`` is empty.

    Args:
        individuals: candidates exposing ``fitness(target)``.
        environment: target colour — a list/tuple, or an array exposing
            ``tolist()`` (NumPy or CuPy), converted once up front.
    """
    target = environment.tolist() if hasattr(environment, "tolist") else environment
    return max(individuals, key=lambda individual: individual.fitness(target), default=None)

# Mean squared error between the rendered population and the target image.
def mse(population, environment, image_shape):
    """Pixel-wise mean squared error between the population map and ``environment``.

    Both operands are cast to float64 before subtracting: if both were uint8
    (what cv2.imread produces for the target), a direct subtraction would wrap
    around instead of going negative.

    NOTE(review): relies on `population_map` and `cp` from other cells; assumes
    population_map returns an array with the same shape as `environment` — confirm.
    """
    p_map = cp.asarray(population_map(population, image_shape), dtype=cp.float64)
    target = cp.asarray(environment, dtype=cp.float64)
    return cp.square(p_map - target).mean()

# Average fitness (MSE) of the living population.
def compute_population_fitness(population, environment):
    """Mean ``fitness`` over all occupied cells (lower is better).

    Each individual at ``population[i][j]`` is compared with the first three
    channels of the target pixel at ``[i][j]``.

    Args:
        population: 2-D grid (list of lists) of individuals or None.
        environment: target image — an array exposing ``tolist()`` (NumPy or
            CuPy) or an already nested list.

    Returns:
        The average score, or ``float('inf')`` when no cell is occupied.
    """
    # One bulk conversion instead of per-pixel indexing; also keeps device (CuPy)
    # slices away from np.array() inside Individual.fitness.
    target_rows = environment.tolist() if hasattr(environment, "tolist") else environment
    scores = [
        individual.fitness(target_rows[i][j][:3])
        for i, row in enumerate(population)
        for j, individual in enumerate(row)
        if individual is not None
    ]
    return sum(scores) / len(scores) if scores else float('inf')


In [None]:
# Run the evolution loop.
NUM_ITERATIONS = 1000
MSE_LOG_INTERVAL = 100  # record the MSE every this many iterations
population = initialize_population(POPULATION_SIZE, IMG_SHAPE)
mse_history = []

plt.ion()
fig, ax = plt.subplots()
frame = None  # AxesImage created on the first iteration, then updated in place

for iteration in range(NUM_ITERATIONS):
    population = decay_population(population, target_environment)
    population = reproduce(population, target_environment, iteration)
    current_mse = mse(population, target_environment, IMG_SHAPE)

    if iteration % MSE_LOG_INTERVAL == 0:
        mse_history.append(current_mse)

    p_map = population_map(population, IMG_SHAPE)
    p_map = cv2.resize(cp.asnumpy(p_map), (300, 300))
    # The target was loaded with cv2.imread (BGR); matplotlib expects RGB.
    # NOTE(review): assumes population_map keeps the target's channel order
    # (as Environment.convert_to_map does) — confirm.
    rgb_frame = cv2.cvtColor(p_map.astype(np.uint8), cv2.COLOR_BGR2RGB)

    avg_fitness = compute_population_fitness(population, target_environment)

    clear_output(wait=True)
    # Reuse a single AxesImage rather than stacking a new image on the axes each step.
    if frame is None:
        frame = ax.imshow(rgb_frame)
    else:
        frame.set_data(rgb_frame)
    ax.set_title(f"Iteration {iteration + 1}, Average Fitness: {avg_fitness:.2f}, MSE: {current_mse:.2f}")
    display(fig)
    plt.pause(0.01)

# Close the live-view figure so it is not rendered again at the end of the cell.
plt.close(fig)

# Plot the MSE over iterations
x_values = range(0, len(mse_history) * MSE_LOG_INTERVAL, MSE_LOG_INTERVAL)
plt.figure()
plt.plot(x_values, cp.asnumpy(cp.array(mse_history)))
plt.xlabel('Iteration')
plt.ylabel('MSE')
plt.title('MSE over Iterations')
plt.show()