# In [None]:
import torch
import numpy as np

# The fitness function must accept a PyTorch tensor and return a scalar.
# For demonstration, we use the Rosenbrock function, whose global minimum is 0 at x = (1, ..., 1).
def frosenbrock(x):
    """Return the Rosenbrock function of ``x``, summed over all entries.

    Consecutive entries along the first axis are paired:
    ``sum(100 * (x[i+1] - x[i]**2)**2 + (1 - x[i])**2)``.
    The result is a 0-d tensor.
    """
    head = x[:-1]
    tail = x[1:]
    valley = tail - head ** 2
    return torch.sum(100.0 * valley ** 2 + (1 - head) ** 2)



def purecmaes_pytorch(N=20, stopfitness=1e-10, stopeval=None, fitness=None):
    """Minimize ``fitness`` with a basic (mu/mu_w, lambda)-CMA-ES.

    Follows the structure of Hansen's ``purecmaes`` reference implementation:
    weighted recombination, cumulative step-size adaptation and rank-one plus
    rank-mu covariance updates.

    Args:
        N: Problem dimension.
        stopfitness: Stop once the best fitness of a generation is <= this value.
        stopeval: Maximum number of fitness evaluations. ``None`` (default)
            means ``1e3 * N**2``.
        fitness: Callable taking a 1-D tensor of length ``N`` and returning a
            scalar (anything ``float()`` accepts). ``None`` (default) uses
            ``frosenbrock``.

    Returns:
        1-D tensor of length ``N``: the best sample of the last generation.
        If no generation was run (``stopeval <= 0``), it is the initial mean.
    """
    import math

    if fitness is None:
        fitness = frosenbrock
    if stopeval is None:
        # Computed here: a default value in the signature cannot refer to N.
        stopeval = 1e3 * N**2

    # --- Initialization -----------------------------------------------------
    xmean = torch.rand(N, 1)  # initial mean, column vector (N, 1)
    sigma = 0.3  # global step size

    # Selection parameters.
    lambda_ = int(4 + math.floor(3 * math.log(N)))  # population size
    mu = lambda_ // 2  # number of parents used for recombination
    weights = math.log(mu + 0.5) - torch.log(
        torch.arange(1, mu + 1, dtype=xmean.dtype)
    )
    weights = weights / weights.sum()
    # Plain Python float, so the constants below are floats, not tensors.
    mueff = float(weights.sum() ** 2 / (weights**2).sum())

    # Adaptation constants.
    cc = (4 + mueff / N) / (N + 4 + 2 * mueff / N)
    cs = (mueff + 2) / (N + mueff + 5)
    c1 = 2 / ((N + 1.3) ** 2 + mueff)
    cmu = min(1 - c1, 2 * (mueff - 2 + 1 / mueff) / ((N + 2) ** 2 + mueff))
    damps = 1 + 2 * max(0.0, math.sqrt((mueff - 1) / (N + 1)) - 1) + cs

    # Dynamic state. D stays 1-D (N,) for the whole run, so sampling
    # broadcasts the same way before and after each eigendecomposition.
    pc = torch.zeros(N, 1)
    ps = torch.zeros(N, 1)
    B = torch.eye(N)
    D = torch.ones(N)
    C = B @ torch.diag(D**2) @ B.T
    invsqrtC = B @ torch.diag(1 / D) @ B.T
    eigeneval = 0  # evaluation count at the last eigendecomposition
    chiN = N**0.5 * (1 - 1 / (4 * N) + 1 / (21 * N**2))  # approx. E||N(0, I)||

    best = xmean[:, 0]
    counteval = 0
    # --- Generation loop ----------------------------------------------------
    while counteval < stopeval:
        # Offspring are the columns of arx (N, lambda_):
        # x_k = m + sigma * B @ (D * z_k), with z_k ~ N(0, I).
        z = torch.randn(N, lambda_)
        arx = xmean + sigma * (B @ (D.unsqueeze(1) * z))
        # float64 keeps close fitness values distinguishable when ranking.
        arfitness = torch.tensor(
            [float(fitness(arx[:, k])) for k in range(lambda_)],
            dtype=torch.float64,
        )
        counteval += lambda_

        arindex = torch.argsort(arfitness)
        best = arx[:, int(arindex[0])]
        xold = xmean
        selected = arx[:, arindex[:mu]]  # (N, mu), best first
        xmean = selected @ weights.unsqueeze(1)  # weighted recombination, (N, 1)

        # Evolution paths.
        y = (xmean - xold) / sigma
        ps = (1 - cs) * ps + math.sqrt(cs * (2 - cs) * mueff) * (invsqrtC @ y)
        # As a float (0.0 / 1.0); torch cannot subtract bool tensors.
        hsig = float(
            ps.norm() / math.sqrt(1 - (1 - cs) ** (2 * counteval / lambda_)) / chiN
            < 1.4 + 2 / (N + 1)
        )
        pc = (1 - cc) * pc + hsig * math.sqrt(cc * (2 - cc) * mueff) * y

        # Covariance matrix update: rank-one term plus rank-mu term.
        artmp = (selected - xold) / sigma
        C = (
            (1 - c1 - cmu) * C
            + c1 * (pc @ pc.T + (1 - hsig) * cc * (2 - cc) * C)
            + cmu * artmp @ torch.diag(weights) @ artmp.T
        )

        # Cumulative step-size adaptation.
        sigma = sigma * math.exp((cs / damps) * (float(ps.norm()) / chiN - 1))

        # Decompose C only every few generations, so the O(N^3) eigh cost
        # is amortized over several generations.
        if counteval - eigeneval > lambda_ / (c1 + cmu) / N / 10:
            eigeneval = counteval
            C = torch.triu(C) + torch.triu(C, 1).T  # enforce symmetry
            eigvals, B = torch.linalg.eigh(C)
            D = eigvals.sqrt()
            invsqrtC = B @ torch.diag(1 / D) @ B.T

        if arfitness[arindex[0]] <= stopfitness or D.max() / D.min() > 1e7:
            break

    return best

# Example: run the optimizer in 20 dimensions with its remaining defaults
# and keep the best sample it returns.
xmin = purecmaes_pytorch(N=20)

    