# In [None]:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import time
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
from joblib import Parallel, delayed
from scipy.fft import fft2, ifft2, fftshift, ifftshift
from deap import base, creator, tools, algorithms

"""## Estruturas para utilizar o Simulador e o Gerador"""

class ResBlock(nn.Module):
    """
    Basic two-convolution residual block.

    main_path:     conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN
    shortcut_path: identity (empty Sequential), or conv1x1(stride) -> BN
                   when the stride or the channel count changes.
    The two paths are summed and passed through a final ReLU.
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Submodule names and Sequential indices form the state_dict keys,
        # so their order must stay fixed.
        main_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.main_path = nn.Sequential(*main_layers)

        needs_projection = stride != 1 or in_channels != out_channels
        shortcut_layers = []
        if needs_projection:
            shortcut_layers = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        self.shortcut_path = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        summed = self.main_path(x) + self.shortcut_path(x)
        return F.relu(summed)

# N_OUTPUTS: number of real values the simulator predicts. It is the default
# `n_outputs` of ResNetSimulator (width of its final Linear layer);
# calculate_batch_phase_loss reads the 4 outputs as TM re/im, TE re/im.
N_OUTPUTS = 4

class ResNetSimulator(nn.Module):
    """
    ResNet-style regressor ("simulator") mapping an image to ``n_outputs``
    real values.

    Layout: 3x3 stem conv (64 ch) -> four ResBlocks (64, 128, 256, 256 ch;
    the last three downsample by 2) -> adaptive average pool to 2x2 ->
    MLP head (1024 -> 128 -> n_outputs).
    """
    def __init__(self, in_channels=1, n_outputs=N_OUTPUTS):
        super().__init__()

        stem_width = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.ReLU(inplace=True),
        )

        # Attribute names (conv1, layer1..4, avgpool, head) become the
        # state_dict keys used by load_state_dict, so they must not change.
        self.layer1 = ResBlock(64, 64, stride=1)
        self.layer2 = ResBlock(64, 128, stride=2)
        self.layer3 = ResBlock(128, 256, stride=2)
        self.layer4 = ResBlock(256, 256, stride=2)

        pooled_hw = (2, 2)
        self.avgpool = nn.AdaptiveAvgPool2d(pooled_hw)

        flat_features = 256 * pooled_hw[0] * pooled_hw[1]
        self.head = nn.Sequential(
            nn.Linear(flat_features, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, n_outputs),
        )

    def forward(self, x):
        features = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.avgpool(features)
        flat = pooled.view(len(pooled), -1)
        return self.head(flat)

class Sin(nn.Module):
    """Element-wise sine activation: x -> sin(x)."""

    def forward(self, x):
        return x.sin()

class Gaussian(nn.Module):
    """Element-wise Gaussian bump activation: x -> exp(-x**2)."""

    def forward(self, x):
        return x.pow(2).neg().exp()

def make_coordinate_grid(size, device='cpu'):
    """
    Build a (size*size, 2) float32 tensor of (x, y) coordinates in [-1, 1].

    Row k corresponds to image position (k // size, k % size); x varies
    fastest (np.meshgrid 'xy' indexing followed by a row-major flatten).
    """
    axis_vals = np.linspace(-1, 1, size)
    grid_x, grid_y = np.meshgrid(axis_vals, axis_vals)
    pairs = np.stack((grid_x.reshape(-1), grid_y.reshape(-1)), axis=1)
    as_tensor = torch.from_numpy(pairs.astype(np.float32))
    return as_tensor.to(device)

class CPPN_Generator(nn.Module):
    """
    CPPN generator: maps (x, y) coordinates plus a latent vector to pixels.

    Each coordinate row of ``coords`` is concatenated with a per-sample latent
    vector ``v`` and passed through an MLP with mixed activations (tanh,
    ReLU, sin, leaky-ReLU, Gaussian) ending in a sigmoid, so every pixel is
    in (0, 1). The per-coordinate outputs are reshaped into square images.

    Only the ``in_coords`` coordinate columns supplied by the caller are used
    (2 by default: x and y); no radius channel is computed here.
    """
    def __init__(self, latent_dim, in_coords=2, out_channels=1, hidden_dim=64):
        super().__init__()
        self.latent_dim = latent_dim
        self.in_coords = in_coords
        self.out_channels = out_channels

        input_dim = in_coords + latent_dim

        # Layer order/types define the state_dict keys (net.0, net.2, ...)
        # used by load_state_dict; do not reorder.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),

            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),

            nn.Linear(hidden_dim, hidden_dim),
            Sin(),

            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),

            nn.Linear(hidden_dim, hidden_dim),
            Gaussian(),

            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),

            nn.Linear(hidden_dim, hidden_dim),
            Sin(),

            nn.Linear(hidden_dim, hidden_dim),
            Gaussian(),

            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),

            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),

            nn.Linear(hidden_dim, hidden_dim),
            Sin(),

            nn.Linear(hidden_dim, out_channels),
            nn.Sigmoid()
        )

    def forward(self, coords, v):
        """
        Render one image per latent vector.

        Parameters
        ----------
        coords : (N, in_coords) tensor. N must be a perfect square
            (N = img_size**2); rows are laid out row-major over the image.
        v : (B, latent_dim) tensor of latent vectors.

        Returns
        -------
        (B, out_channels, img_size, img_size) tensor with values in (0, 1).

        Raises
        ------
        ValueError
            If N is not a perfect square. Previously such input could be
            silently reshaped into a wrong number of images.
        """
        batch_size = v.size(0)
        n_points = coords.size(0)

        img_size = int(round(np.sqrt(n_points)))
        if img_size * img_size != n_points:
            raise ValueError(
                f"coords must contain a square number of points, got {n_points}"
            )

        # expand() broadcasts without copying, so torch.cat below performs the
        # only allocation of the (B, N, in_coords + latent_dim) input; repeat()
        # would first materialize both tiled tensors as well.
        v_tiled = v.unsqueeze(1).expand(batch_size, n_points, v.size(1))
        coords_tiled = coords.unsqueeze(0).expand(batch_size, n_points, coords.size(1))
        combined_input = torch.cat([coords_tiled, v_tiled], dim=-1)

        output = self.net(combined_input)  # (B, N, out_channels)

        return output.permute(0, 2, 1).reshape(-1, self.out_channels, img_size, img_size)

"""# Mapas de Fase

Cria√É¬ß√É¬£o do mapa de fase para a polariza√É¬ß√É¬£o x:
"""

def load_and_preprocess_image(image_path, target_size=(450, 450)):
    """
    Load an image as grayscale, resize it and normalize it to [0, 1].

    Parameters
    ----------
    image_path : path of the image file.
    target_size : (width, height) passed to PIL's ``Image.resize``; the
        returned array therefore has shape (height, width).

    Returns
    -------
    float64 ndarray of shape (height, width), divided by its maximum. An
    all-black image is returned as zeros instead of NaNs.

    If the file does not exist, a synthetic test pattern (three filled
    rectangles) with the same (height, width) is returned instead. The
    rectangle coordinates are defined on a 450x450 reference canvas and
    scaled to ``target_size``. For 450x450 the pattern is unchanged; for
    smaller sizes such as (90, 90) it is no longer blank.
    """
    width, height = target_size
    try:
        # Context manager closes the underlying file once pixels are converted.
        with Image.open(image_path) as img:
            image = img.convert('L')
    except FileNotFoundError:
        print(f"Atenção: Imagem '{image_path}' não encontrada. Criando imagem de teste...")
        target_image = np.zeros((height, width))
        reference = 450
        for r0, r1, c0, c1 in ((150, 300, 100, 200),
                               (150, 200, 200, 350),
                               (250, 300, 200, 350)):
            target_image[r0 * height // reference:r1 * height // reference,
                         c0 * width // reference:c1 * width // reference] = 1.0
        return target_image

    image = image.resize(target_size, Image.LANCZOS)
    image_array = np.array(image, dtype=np.float64)
    peak = image_array.max()
    if peak > 0:  # avoid 0/0 -> NaN for an all-black image
        image_array /= peak
    return image_array

def apply_zero_padding(image, padding_factor=2):
    """
    Embed a 2-D image in the centre of a larger all-zero complex canvas.

    The canvas is ``padding_factor`` times larger along each axis. Returns
    ``(padded_image, original_shape)``; pass original_shape to
    ``extract_center`` to undo the padding.
    """
    src_rows, src_cols = image.shape[0], image.shape[1]
    dst_rows = src_rows * padding_factor
    dst_cols = src_cols * padding_factor

    top = (dst_rows - src_rows) // 2
    left = (dst_cols - src_cols) // 2

    canvas = np.zeros((dst_rows, dst_cols), dtype=complex)
    canvas[top:top + src_rows, left:left + src_cols] = image
    return canvas, image.shape

def create_low_pass_filter(shape, wavelength, dx, NA):
    """
    Binary circular pupil in unshifted FFT frequency order.

    Returns a float64 array of ``shape`` that is 1.0 where the spatial
    frequency radius sqrt(fx**2 + fy**2) is <= NA / wavelength and 0.0
    elsewhere. Both axes use the same sampling interval ``dx``.
    """
    n_rows, n_cols = shape
    freq_rows = np.fft.fftfreq(n_rows, dx)[:, np.newaxis]
    freq_cols = np.fft.fftfreq(n_cols, dx)[np.newaxis, :]

    cutoff = NA / wavelength
    radius = np.sqrt(freq_rows**2 + freq_cols**2)
    return np.where(radius <= cutoff, 1.0, 0.0)

def angular_spectrum_propagation(U, wavelength, z, dx, filter_mask=None):
    """
    Propagate a 2-D complex field a distance ``z`` (angular-spectrum method).

    Transfer function: H = exp(i k z sqrt(1 - (lambda fx)^2 - (lambda fy)^2))
    for propagating frequencies. Evanescent frequencies (negative root term)
    are set to H = 0. Previously they were clipped to a zero root, which gave
    H = 1 and let them pass unattenuated; for backward propagation (z < 0),
    zeroing also avoids exponential growth. With the NA-limited
    ``filter_mask`` used elsewhere in this file (NA < 1), those frequencies
    were already removed, so those results are unchanged.

    Parameters
    ----------
    U : 2-D complex field sampled with pitch ``dx``.
    wavelength, z, dx : same length units.
    filter_mask : optional array (same shape as U, unshifted FFT order)
        multiplied into the transfer function.

    Returns
    -------
    Complex ndarray, same shape as U.
    """
    k = 2 * np.pi / wavelength
    nx, ny = U.shape

    fx = np.fft.fftfreq(nx, dx)
    fy = np.fft.fftfreq(ny, dx)
    FX, FY = np.meshgrid(fx, fy, indexing='ij')

    root_term = 1 - (wavelength * FX)**2 - (wavelength * FY)**2
    evanescent = root_term < 0
    root_term[evanescent] = 0

    H = np.exp(1j * k * z * np.sqrt(root_term))
    H[evanescent] = 0

    if filter_mask is not None:
        H = H * filter_mask

    U_freq = fft2(U)
    U_prop_freq = U_freq * H
    U_prop = ifft2(U_prop_freq)

    return U_prop

def calculate_correlation(target, reconstructed):
    """
    Pearson correlation between the real parts of two arrays.

    Both inputs are flattened. Returns a Python float, or 0.0 when the
    coefficient is undefined (NaN, e.g. a constant input).
    """
    a = np.real(target).ravel()
    b = np.real(reconstructed).ravel()

    r = np.corrcoef(a, b)[0, 1]
    return 0.0 if np.isnan(r) else float(r)

def extract_center(image, original_size):
    """
    Return the central ``original_size`` window of a 2-D array.

    This is the inverse of ``apply_zero_padding``: the same floor-division
    offsets are used, so a padded image maps back onto its original pixels.
    """
    rows, cols = original_size
    top = (image.shape[0] - rows) // 2
    left = (image.shape[1] - cols) // 2
    window = (slice(top, top + rows), slice(left, left + cols))
    return image[window]

def reconstruct_image(phase_map, wavelength, z, dx, NA):
    """
    Simulate the image produced by a phase-only hologram.

    The unit-amplitude field exp(i * phase_map) is zero-padded, propagated a
    distance ``z`` with the NA-limited angular-spectrum method, and the
    amplitude of the central (unpadded) region is returned as a real array.
    """
    hologram_field = np.exp(1j * phase_map)
    padded_field, unpadded_shape = apply_zero_padding(hologram_field)

    pupil = create_low_pass_filter(padded_field.shape, wavelength, dx, NA)
    image_plane = angular_spectrum_propagation(padded_field, wavelength, z, dx, pupil)

    amplitude = extract_center(np.abs(image_plane), unpadded_shape)
    return np.real(amplitude)

def reconstruct_image_from_field(complex_field, wavelength, z, dx, NA):
    """
    Simulate the image produced by an arbitrary complex field.

    The field (amplitude and phase) is zero-padded, propagated a distance
    ``z`` with the NA-limited angular-spectrum method, cropped back to its
    original size, and its amplitude is returned as a real array.
    """
    padded_field, unpadded_shape = apply_zero_padding(complex_field)

    pupil = create_low_pass_filter(padded_field.shape, wavelength, dx, NA)
    image_plane = angular_spectrum_propagation(padded_field, wavelength, z, dx, pupil)

    central_field = extract_center(image_plane, unpadded_shape)
    return np.real(np.abs(central_field))

def gerchberg_saxton_angular_spectrum(target, wavelength, z, dx, NA, num_iter=50):
    """
    Gerchberg-Saxton phase retrieval with angular-spectrum propagation.

    Searches for a phase-only hologram whose propagated (NA-limited)
    amplitude matches ``target``. Each iteration propagates to the image
    plane, replaces the amplitude with the target, propagates back (-z),
    and forces unit amplitude in the hologram plane.

    Parameters
    ----------
    target : 2-D real amplitude image.
    wavelength, z, dx : same length units; NA : numerical aperture.
    num_iter : number of GS iterations.

    Returns
    -------
    (target, reconstructed_image, phase_final, correlations)
        target : the input, unchanged.
        reconstructed_image : reconstruct_image(phase_final, ...).
        phase_final : hologram phase (np.angle) cropped to target's shape.
        correlations : Pearson correlation per iteration (image plane).

    The initial random phase is drawn from NumPy's global RNG.
    """
    target_padded, original_size = apply_zero_padding(target)
    nx_pad, ny_pad = target_padded.shape

    filter_mask = create_low_pass_filter((nx_pad, ny_pad), wavelength, dx, NA)

    # Loop invariant: the target region used for the convergence metric.
    target_region = extract_center(target_padded, original_size)

    phase = np.random.rand(nx_pad, ny_pad) * 2 * np.pi
    U = target_padded * np.exp(1j * phase)

    correlations = []

    for i in range(num_iter):
        # 1. Propagate to the image plane
        U_image = angular_spectrum_propagation(U, wavelength, z, dx, filter_mask)

        # 2. Keep the phase; monitor convergence on the amplitude
        amplitude_image = np.abs(U_image)
        phase_image = np.angle(U_image)

        recon_region = extract_center(amplitude_image, original_size)
        corr = calculate_correlation(target_region, recon_region)
        correlations.append(corr)

        # Impose the target amplitude in the image plane
        U_image_updated = target_padded * np.exp(1j * phase_image)

        # 3. Propagate back to the hologram plane
        U = angular_spectrum_propagation(U_image_updated, wavelength, -z, dx, filter_mask)

        # 4. Keep the phase, force unit amplitude (phase-only hologram)
        U = np.exp(1j * np.angle(U))

        if (i + 1) % 10 == 0:
            print(f" \tIteração GS (X) {i+1}/{num_iter}, Correlação: {corr:.4f}")

    phase_final = extract_center(np.angle(U), original_size)

    reconstructed_image = reconstruct_image(phase_final, wavelength, z, dx, NA)

    return target, reconstructed_image, phase_final, correlations

# --- Parameters for the x-polarization (Gerchberg-Saxton) phase map ---
wavelength = 1064e-9  # 1064 nm
z = 380e-6           # 380 μm propagation distance
dx = 520e-9          # pixel pitch
NA = 0.65            # numerical aperture
num_iter = 200       # number of GS iterations

print("Carregando e pré-processando imagem...")
target_original = load_and_preprocess_image('/content/HJV.png', target_size=(90, 90))

print("Executando algoritmo de Gerchberg-Saxton...")

img_original, img_reconstruida, mapa_de_fase, correlations = gerchberg_saxton_angular_spectrum(
    target_original,
    wavelength,
    z,
    dx,
    NA,
    num_iter
)
# Guard: with num_iter == 0 the correlation list is empty.
if correlations:
    print(f"\nCorrelação final: {correlations[-1]:.4f}")

print(f"Dimensões da Imagem Original: {img_original.shape}")
print(f"Dimensões da Imagem Reconstruída: {img_reconstruida.shape}")
print(f"Dimensões do Mapa de Fase: {mapa_de_fase.shape}")

# np.savetxt('phase_map_x_polarization.txt', mapa_de_fase)
# np.savetxt('correlations.txt', correlations)

plt.figure(figsize=(15, 10))

plt.subplot(2, 3, 1)
plt.imshow(img_original, cmap='gray')
plt.title('Imagem Original')

plt.subplot(2, 3, 2)
plt.imshow(mapa_de_fase, cmap='hsv')
plt.title('Mapa de Fase')

plt.subplot(2, 3, 3)
plt.imshow(img_reconstruida, cmap='gray')
plt.title('Imagem Reconstruída')

plt.subplot(2, 3, 5)
plt.plot(correlations)
plt.title('Convergência')
plt.xlabel('Iteração')
plt.ylabel('Correlação')

plt.tight_layout()
#plt.show()

"""## Polariza√É¬ß√É¬£o y:"""

def generate_dammann_phase_map(
    P: float = 520e-9,
    wavelength: float = 1064e-9,
    supercell_pixels: int = 45,
    n_supercells: int = 2,
    iters_gs: int = 400,
    random_seed: int = 0,
    verbose: bool = True
) -> tuple[np.ndarray, dict, list]:
    """
    Generate the phase map of a Dammann-type ("spot cloud") grating with GS.

    One supercell of ``supercell_pixels`` x ``supercell_pixels`` pixels
    (pitch ``P``) is optimized by Gerchberg-Saxton. The target is a far-field
    amplitude that is uniform inside the disc |f| <= min(1/wavelength, 1/(2P))
    and zero outside. The supercell is then tiled ``n_supercells`` times
    along each axis.

    Returns
    -------
    full_phase : ndarray of shape (supercell_pixels * n_supercells,) * 2,
        phase in radians (np.angle range, [-pi, pi]).
    metrics : empty dict (placeholder; only the phase is produced here).
    errors : per-iteration RMS error between the max-normalized far-field
        amplitude and the target mask.

    Side effect: calls ``np.random.seed(random_seed)``, which resets NumPy's
    global random state.
    """
    np.random.seed(random_seed)

    N_super = supercell_pixels
    dx = P

    # --- Spatial-frequency grid (DC-centred) and target mask ---
    # Uses the np.fft namespace, independent of the scipy.fft names imported
    # at module level.
    kx_shift = np.fft.fftshift(np.fft.fftfreq(N_super, d=dx))
    ky_shift = np.fft.fftshift(np.fft.fftfreq(N_super, d=dx))
    KX, KY = np.meshgrid(kx_shift, ky_shift)
    K_rad = np.sqrt(KX**2 + KY**2)
    # Disc limited by the propagation cutoff 1/wavelength and by the grid's
    # Nyquist frequency 1/(2*dx).
    target_radius = min(1.0 / wavelength, 1.0 / (2.0 * dx))
    target_amp = (K_rad <= target_radius).astype(float)

    # --- GS iterations (phase-only constraint in the metasurface plane) ---
    plane_field = np.exp(1j * 2.0 * np.pi * np.random.rand(N_super, N_super))
    errors = []

    gs_iterator = range(iters_gs)
    if verbose:
        # Progress bar only when verbose
        gs_iterator = tqdm(gs_iterator, desc="  Iterações GS (Y)", leave=False)

    for _ in gs_iterator:
        far_shift = np.fft.fftshift(np.fft.fft2(plane_field))

        amp_current = np.abs(far_shift)
        err = np.sqrt(np.mean((amp_current / (amp_current.max() + 1e-9) - target_amp)**2))
        errors.append(err)

        # Impose the target amplitude, keep the far-field phase
        far_shift = target_amp * np.exp(1j * np.angle(far_shift))

        # Back to the metasurface plane; keep only the phase
        plane_field = np.fft.ifft2(np.fft.ifftshift(far_shift))
        plane_field = np.exp(1j * np.angle(plane_field))

    supercell_phase = np.angle(plane_field)

    # --- Full metasurface: tile the supercell ---
    full_phase = np.tile(supercell_phase, (n_supercells, n_supercells))

    if verbose:
        print(f"  Mapa Dammann (Y) gerado: {full_phase.shape} pixels")

    metrics = {}

    return full_phase, metrics, errors

def analyze_and_plot_results(
    full_phase: np.ndarray,
    errors: list,
    P: float,
    wavelength: float,
    supercell_pixels: int
) -> pd.DataFrame:
    """
    Analyze the metasurface phase map, save diagnostic plots and return the
    propagating diffraction orders.

    Files written to the current directory: plot_mapa_fase.png,
    plot_erro_gs.png, plot_far_field.png and, when at least one order is
    found, plot_ordens_scatter.png. Figures are not closed.

    Parameters
    ----------
    full_phase : 2-D phase map in radians; rows are y, columns are x.
    errors : GS error history to plot.
    P : pixel pitch, also the far-field sampling interval.
    wavelength : free-space wavelength (same length unit as P).
    supercell_pixels : supercell period in pixels; the grating period is
        d = supercell_pixels * P.

    Returns
    -------
    DataFrame with columns p, q, intensity, k_p, k_q: one row per order
    (p, q) with sqrt(k_p**2 + k_q**2) <= 1/wavelength, where k = order/d.
    The intensity is sampled at the nearest far-field grid point. The
    DataFrame is empty (no columns) when no order qualifies.
    """

    print("Iniciando análise e plotagem...")

    # --- Plot 1: phase map ---
    print("Plot 1: Gerando plot_mapa_fase.png...")
    plt.figure(figsize=(7, 6))
    plt.imshow(full_phase, cmap='twilight', extent=None)
    plt.colorbar(label="Fase (rad)")
    plt.title(f"Mapa de Fase ({full_phase.shape[0]}x{full_phase.shape[1]} pixels)")
    plt.xlabel("Pixels (X)")
    plt.ylabel("Pixels (Y)")
    plt.tight_layout()
    plt.savefig("plot_mapa_fase.png", dpi=150)

    # --- Plot 2: GS error history ---
    print("Plot 2: Gerando plot_erro_gs.png...")
    plt.figure(figsize=(7, 5))
    plt.plot(errors)
    plt.xlabel("Iteração GS")
    plt.ylabel("Erro RMSE")
    plt.title("Evolução do Erro GS")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("plot_erro_gs.png", dpi=150)

    # --- Far field of the full metasurface ---
    print("Análise: Calculando Far-Field da metassuperfície completa...")
    # Rows are y and columns are x: kx comes from the column count and ky
    # from the row count. Both previously used shape[0], which was only
    # correct for square maps.
    n_rows, n_cols = full_phase.shape
    dx_full = P

    kx_full_shift = np.fft.fftshift(np.fft.fftfreq(n_cols, d=dx_full))
    ky_full_shift = np.fft.fftshift(np.fft.fftfreq(n_rows, d=dx_full))

    plane_field_full = np.exp(1j * full_phase)
    far_field_full_shift = np.fft.fftshift(np.fft.fft2(plane_field_full))
    far_field_intensity = np.abs(far_field_full_shift)**2

    print("Análise: Extraindo ordens de difração...")
    d_supercell = supercell_pixels * P
    # An order p can only propagate if |p| / d <= 1 / wavelength.
    max_order = int(np.floor(d_supercell / wavelength))
    order_range = np.arange(-max_order, max_order + 1)
    order_data = []

    for p in order_range:
        for q in order_range:
            k_p = p / d_supercell
            k_q = q / d_supercell

            if np.sqrt(k_p**2 + k_q**2) > (1.0 / wavelength):
                continue  # evanescent order

            # Nearest far-field sample; index order is [row (ky), col (kx)].
            idx_p = np.argmin(np.abs(kx_full_shift - k_p))
            idx_q = np.argmin(np.abs(ky_full_shift - k_q))
            intensity = far_field_intensity[idx_q, idx_p]
            order_data.append({'p': p, 'q': q, 'intensity': intensity, 'k_p': k_p, 'k_q': k_q})

    df_orders = pd.DataFrame(order_data)
    print(f"Análise concluída. Encontradas {len(df_orders)} ordens propagantes.")

    # --- Plot 3: far field (log scale, 4 decades) ---
    print("Plot 3: Gerando plot_far_field.png...")
    log_intensity = np.log10(far_field_intensity + 1e-9)
    vmax_val = log_intensity.max()
    vmin_val = vmax_val - 4

    plt.figure(figsize=(7, 6))
    plt.imshow(
        log_intensity,
        cmap='hot',
        extent=[kx_full_shift.min(), kx_full_shift.max(), ky_full_shift.min(), ky_full_shift.max()],
        vmin=vmin_val,
        vmax=vmax_val,
        origin='lower'
    )

    k_max_plot = 1.0 / wavelength
    circle = plt.Circle((0, 0), k_max_plot, color='cyan', fill=False, linestyle='--', linewidth=1.5, label=r'$1/\lambda$')
    plt.gca().add_artist(circle)
    plt.legend(handles=[circle])

    plt.xlim(-k_max_plot * 1.5, k_max_plot * 1.5)
    plt.ylim(-k_max_plot * 1.5, k_max_plot * 1.5)
    plt.colorbar(label="Intensidade (log10)")
    plt.title("Far-Field (Intensidade Logarítmica)")
    plt.xlabel(r"$k_x$ (1/m)")
    plt.ylabel(r"$k_y$ (1/m)")
    plt.gca().set_aspect('equal')
    plt.tight_layout()
    plt.savefig("plot_far_field.png", dpi=150)

    # --- Plot 4: normalized intensity of each propagating order ---
    print("Plot 4: Gerando plot_ordens_scatter.png...")
    if not df_orders.empty:
        plt.figure(figsize=(7, 6))
        norm_intensity = df_orders['intensity'] / (df_orders['intensity'].max() + 1e-9)

        sc = plt.scatter(df_orders['p'], df_orders['q'], c=norm_intensity, cmap='viridis', s=50, edgecolor="k", vmin=0)
        plt.colorbar(sc, label="Intensidade Normalizada")
        plt.xlabel("Ordem p")
        plt.ylabel("Ordem q")
        plt.title(f"Mapa das Ordens Propagantes ({len(df_orders)} ordens)")
        plt.grid(True)
        plt.gca().set_aspect('equal')
        plt.tight_layout()
        plt.savefig("plot_ordens_scatter.png", dpi=150)
    else:
        print("Plot 4: DataFrame de ordens vazio, pulando o scatter plot.")

    print("\nTodos os plots foram salvos como arquivos .png.")

    return df_orders

# --- Parameters for the y-polarization (Dammann) phase map ---
P = 520e-9
wavelength = 1064e-9
supercell_pixels = 45
n_supercells = 2
iters_gs = 400
random_seed = 0

print("Iniciando Geração de Fase...")
full_phase, _, errors = generate_dammann_phase_map(
    P=P,
    wavelength=wavelength,
    supercell_pixels=supercell_pixels,
    n_supercells=n_supercells,
    iters_gs=iters_gs,
    random_seed=random_seed,
    verbose=True
)
print("Geração de Fase Concluída.")

df_ordens_propagantes = analyze_and_plot_results(
    full_phase=full_phase,
    errors=errors,
    P=P,
    wavelength=wavelength,
    supercell_pixels=supercell_pixels
)

print("\n--- Resumo das Ordens ---")
print(df_ordens_propagantes.head())

"""# Carregamento dos Modelos Treinados"""

IMG_SIZE = 64
LATENT_DIM = 128  # latent vector size expected by the trained generator

GENERATOR_PATH = "/content/generator_teste_5_final.pth"
SIMULATOR_PATH = "/content/simulador_NG_teste_1.pth"

# --- Global components ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): torch.load unpickles the checkpoint file; only load
# checkpoints from trusted sources.
# On failure the script stops with SystemExit(1). `exit()` is the
# interactive site helper (absent under `python -S`; in an IPython kernel it
# may terminate the kernel) and exited with status 0 even on error.

# Load the generator
try:
    generator = CPPN_Generator(latent_dim=LATENT_DIM, in_coords=2).to(device)
    generator.load_state_dict(torch.load(GENERATOR_PATH, map_location=device))
    generator.eval()
    print(f"Gerador '{GENERATOR_PATH}' carregado em {device}.")
except Exception as e:
    print(f"Erro ao carregar Gerador: {e}")
    raise SystemExit(1) from e

# Load the simulator (ResNet architecture defined above)
try:
    simulator = ResNetSimulator(in_channels=1, n_outputs=N_OUTPUTS).to(device)
    simulator.load_state_dict(torch.load(SIMULATOR_PATH, map_location=device))
    simulator.eval()
    print(f"Simulador '{SIMULATOR_PATH}' carregado em {device}.")
except Exception as e:
    print(f"Erro ao carregar Simulador: {e}")
    raise SystemExit(1) from e

# Coordinate-grid helper, re-defined for the model-loading section.
# NOTE(review): this shadows the make_coordinate_grid defined earlier in the
# file; it is kept so existing positional calls keep working.
def make_coordinate_grid(size, dev='cpu'):
    """
    Return a (size*size, 2) float32 tensor of (x, y) coordinates in [-1, 1].

    x varies fastest (np.meshgrid 'xy' indexing, then row-major ravel).
    ``dev`` defaults to 'cpu' to match the earlier definition's
    ``device='cpu'`` default. Without it, single-argument calls made after
    this redefinition would raise TypeError.
    """
    xs = np.linspace(-1, 1, size)
    ys = np.linspace(-1, 1, size)
    xx, yy = np.meshgrid(xs, ys)
    coords = np.stack([xx.ravel(), yy.ravel()], axis=-1).astype(np.float32)
    return torch.from_numpy(coords).to(dev)

# Build the (IMG_SIZE*IMG_SIZE, 2) coordinate grid once on `device`; it is
# reused (as coords_grid_tensor) for every generator call in the sections below.
coords_grid = make_coordinate_grid(IMG_SIZE, device)
print("Grade de coordenadas criada.")

"""
============================================================================
## Otimiza√ß√£o por Gradiente (Substitui√ß√£o do GA)
============================================================================
"""

# ---------------------------------------------------------------------------
# Gradient-optimization settings
# ---------------------------------------------------------------------------
N_STEPS = 500       # optimization passes over all pixels (analogous to the GA's NGEN)
LR = 0.01           # Adam learning rate
BATCH_PIXELS = 1024 # latent vectors (pixels) per mini-batch
# NOTE(review): each batch runs the generator on BATCH_PIXELS * IMG_SIZE**2
# coordinates with autograd enabled; lower BATCH_PIXELS if GPU memory runs out.
PRINT_EVERY = 20    # print progress every N steps

# Output paths (separate directory so results do not mix with the GA run)
save_dir = "resultados_otimizacao_gradiente"
os.makedirs(save_dir, exist_ok=True)
latent_vectors_path = os.path.join(save_dir, "latent_vectors_optimized.npy")

# ---------------------------------------------------------------
# 1. Target phase maps and tensors
# ---------------------------------------------------------------
phase_map_x = mapa_de_fase  # from the GS (x-polarization) section
phase_map_y = full_phase    # from the Dammann (y-polarization) section
if phase_map_y.shape != phase_map_x.shape:
    # Both maps are flattened and indexed with the same pixel indices below.
    raise ValueError(
        f"Phase maps must have the same shape: x={phase_map_x.shape}, y={phase_map_y.shape}"
    )
rows, cols = phase_map_x.shape
# `LATENT_DIM_GA` is not defined anywhere in this script (NameError). The
# generator was built with LATENT_DIM, so the latent vectors must match it.
latent_dim = LATENT_DIM

print(f"\n⚙️ Iniciando otimização por Gradiente em {rows}x{cols} pixels ({rows*cols} totais)...\n")

# Make sure both models are on the device and in eval mode
generator.to(device)
simulator.to(device)
generator.eval()
simulator.eval()
# Only the latent vectors are optimized. Freezing the weights stops
# loss.backward() from computing .grad for every model parameter. Those
# gradients would also keep accumulating, because the optimizer only zeroes
# the latent gradients.
generator.requires_grad_(False)
simulator.requires_grad_(False)

# Target tensors on the device
targets_phase_x = torch.tensor(phase_map_x, dtype=torch.float32, device=device)
targets_phase_y = torch.tensor(phase_map_y, dtype=torch.float32, device=device)

# Coordinate grid (already created as 'coords_grid')
coords_grid_tensor = coords_grid.to(device)

# ---------------------------------------------------------------
# 2. Trainable latent vectors
# ---------------------------------------------------------------
# A single (rows, cols, latent_dim) tensor holds one latent vector per pixel;
# requires_grad=True makes it the optimization variable.
latent_vectors_tensor = torch.randn(
    (rows, cols, latent_dim),
    dtype=torch.float32,
    device=device,
    requires_grad=True
)

# ---------------------------------------------------------------
# 3. Optimizer and loss constants
# ---------------------------------------------------------------
optimizer = torch.optim.Adam([latent_vectors_tensor], lr=LR)
# NOTE: despite its name, pi_tensor holds 2*pi (the phase-wrapping period).
pi_tensor = torch.tensor(2 * np.pi, device=device)

def calculate_batch_phase_loss(sim_outputs, target_phase_x_batch, target_phase_y_batch):
    """
    Mean wrapped L1 phase error over a batch of pixels.

    Parameters
    ----------
    sim_outputs : (batch, 4) simulator outputs, read as
        [TM real, TM imag, TE real, TE imag].
    target_phase_x_batch, target_phase_y_batch : (batch,) target phases in
        radians. The x target is compared with the TE phase and the y target
        with the TM phase.

    Returns
    -------
    Scalar tensor: the batch mean of (err_x + err_y). Each absolute phase
    difference d is wrapped as min(d, 2*pi - d).

    The 2*pi period is a local constant rather than the module-level global
    ``pi_tensor``, so the function is self-contained; the result is unchanged.
    """
    two_pi = 2 * np.pi

    sim_phase_y = torch.atan2(sim_outputs[:, 1], sim_outputs[:, 0])  # TM
    sim_phase_x = torch.atan2(sim_outputs[:, 3], sim_outputs[:, 2])  # TE

    error_x = torch.abs(sim_phase_x - target_phase_x_batch)
    error_y = torch.abs(sim_phase_y - target_phase_y_batch)

    # Phase wrapping: offsets d and 2*pi - d are the same phase error.
    error_x = torch.min(error_x, two_pi - error_x)
    error_y = torch.min(error_y, two_pi - error_y)

    return (error_x + error_y).mean()

def get_per_pixel_errors(z_tensor, target_x_tensor, target_y_tensor):
    """
    Compute the wrapped phase error for every pixel (for the metrics).

    Used to fill all_initial_avg_fitness and all_best_fitness. Runs without
    gradient tracking, in batches of BATCH_PIXELS, through the module-level
    ``generator``, ``simulator`` and ``coords_grid_tensor``.

    Parameters
    ----------
    z_tensor : latent vectors whose last dimension is the latent size
        (e.g. the (rows, cols, latent_dim) optimization tensor).
    target_x_tensor, target_y_tensor : target phases (radians), one per
        latent vector.

    Returns
    -------
    list with one error per pixel, in flattened (row-major) order. Each entry
    is err_x + err_y, where each absolute difference d is wrapped as
    min(d, 2*pi - d).
    """
    two_pi = 2 * np.pi
    print("    Calculando erros por pixel (para métricas)...")
    # Use the tensor's own latent size rather than the global latent_dim.
    z_flat = z_tensor.reshape(-1, z_tensor.shape[-1])
    target_x_flat = target_x_tensor.reshape(-1)
    target_y_flat = target_y_tensor.reshape(-1)

    all_errors = []
    with torch.no_grad():
        for i in range(0, z_flat.shape[0], BATCH_PIXELS):
            z_batch = z_flat[i:i + BATCH_PIXELS]
            target_x_batch = target_x_flat[i:i + BATCH_PIXELS]
            target_y_batch = target_y_flat[i:i + BATCH_PIXELS]

            imgs = generator(coords_grid_tensor, z_batch)
            outputs = simulator(imgs)

            # Same formula as calculate_batch_phase_loss, without the mean.
            # Outputs are read as [TM re, TM im, TE re, TE im].
            sim_phase_x = torch.atan2(outputs[:, 3], outputs[:, 2])
            sim_phase_y = torch.atan2(outputs[:, 1], outputs[:, 0])
            error_x = torch.abs(sim_phase_x - target_x_batch)
            error_y = torch.abs(sim_phase_y - target_y_batch)
            error_x = torch.min(error_x, two_pi - error_x)
            error_y = torch.min(error_y, two_pi - error_y)

            all_errors.extend((error_x + error_y).cpu().numpy())

    return all_errors

# ---------------------------------------------------------------
# 4. Main optimization loop
# ---------------------------------------------------------------

# --- Metrics (kept compatible with the GA script) ---
# 1. Initial per-pixel error (equivalent to the GA's "generation 0")
all_initial_avg_fitness = get_per_pixel_errors(latent_vectors_tensor, targets_phase_x, targets_phase_y)

print("\nIniciando loop de otimização por gradiente...")
start_time = time.time()

# Indices of all pixels (e.g. 0..8099 for a 90x90 map)
total_pixels = rows * cols
pixel_indices = torch.arange(total_pixels, device=device)

# Flat views for batching. z_flat shares storage with latent_vectors_tensor,
# so gradients taken through it reach the tensor the optimizer updates.
z_flat = latent_vectors_tensor.view(-1, latent_dim)
target_x_flat = targets_phase_x.view(-1)
target_y_flat = targets_phase_y.view(-1)


for step in range(N_STEPS):
    # Reshuffle the pixel order on every pass
    permuted_indices = pixel_indices[torch.randperm(total_pixels)]

    total_loss_epoch = 0.0
    num_batches = 0

    for batch_start in range(0, total_pixels, BATCH_PIXELS):
        batch_indices = permuted_indices[batch_start:batch_start + BATCH_PIXELS]

        # Batch data, indexed from the flattened tensors
        z_batch = z_flat[batch_indices]
        target_x_batch = target_x_flat[batch_indices]
        target_y_batch = target_y_flat[batch_indices]

        optimizer.zero_grad()

        # Forward: latent vectors -> images -> simulated responses
        imgs_batch = generator(coords_grid_tensor, z_batch)
        outputs_raw = simulator(imgs_batch)

        loss = calculate_batch_phase_loss(outputs_raw, target_x_batch, target_y_batch)

        # Backward and update (only the latent vectors are in the optimizer)
        loss.backward()
        optimizer.step()

        total_loss_epoch += loss.item()
        num_batches += 1

    # Progress report
    avg_loss = total_loss_epoch / num_batches
    if (step + 1) % PRINT_EVERY == 0 or step == 0:
        print(f"  Passo [{step+1}/{N_STEPS}], Loss Média: {avg_loss:.6f}")

# ---------------------------------------------------------------
# 5. Final result
# ---------------------------------------------------------------
total_execution_time = time.time() - start_time
print("\n🎉 Otimização por Gradiente concluída!")
print(f"Tempo total: {total_execution_time:.2f} s")

# --- Metrics (kept compatible with the GA script) ---
# 2. Final per-pixel error (equivalent to the GA's "best fitness")
print("Calculando métricas finais...")
all_best_fitness = get_per_pixel_errors(latent_vectors_tensor, targets_phase_x, targets_phase_y)

# 3. NFE (number of function evaluations): each step evaluates every pixel
# once, so total NFE = N_STEPS * total_pixels.
total_nfe = N_STEPS * total_pixels

# 4. Save the optimized latent vectors (as the GA script did)
latent_vectors = latent_vectors_tensor.detach().cpu().numpy()
np.save(latent_vectors_path, latent_vectors)
print(f"Resultados salvos em:\n{latent_vectors_path}\n")


"""
============================================================================
## C√°lculo e Salvamento de M√©tricas (Gradiente)
============================================================================
"""

print("\nCalculando métricas finais da otimização (Gradiente)...")

num_pixels = rows * cols
# Per-pixel final errors produced by the gradient-optimization section
fitness_array = np.array(all_best_fitness)

# 1. Save the error data for later comparison
np.save('gradient_errors.npy', fitness_array)

# 2. Histogram of the gradient method's errors only
print("Gerando histograma de erros do Gradiente...")
plt.figure(figsize=(10, 6))
plt.hist(fitness_array, bins=50, alpha=0.7, color='green')  # green = gradient method
plt.title(f'Histograma de Erro (Fitness) - Método Gradiente (N={len(fitness_array)} pixels)')
plt.xlabel('Erro Total por Pixel (Fitness Final)')
plt.ylabel('Contagem de Pixels')
plt.grid(True, linestyle='--', alpha=0.5)
plt.savefig('histograma_Gradiente.png', dpi=150)
plt.show()

# ----------------------------
# 3. Statistics
# ----------------------------

# The fitness IS the error, so its mean is the mean error.
mean_error = np.mean(fitness_array)
min_best_fitness = np.min(fitness_array)  # best single pixel

std_dev_fitness = np.std(fitness_array)

# Success rate: percentage of pixels whose error is below the threshold
SUCCESS_THRESHOLD = 0.1  # adjustable
successful_pixels = np.sum(fitness_array < SUCCESS_THRESHOLD)
success_rate = (successful_pixels / num_pixels) * 100

# Convergence: mean error before vs. after optimization
# ('all_initial_avg_fitness' comes from the gradient-optimization section)
avg_initial_fitness = np.mean(all_initial_avg_fitness)
avg_final_fitness = mean_error
convergence_improvement = (avg_initial_fitness - avg_final_fitness)

# total_execution_time and total_nfe come from the gradient-optimization section

# --- Write the .txt report ---
metrics_filename = "optimization_metrics_GRADIENTE.txt"
try:
    with open(metrics_filename, "w", encoding="utf-8") as f:
        f.write("--- Métricas da Otimização por GRADIENTE da Metassuperfície ---\n\n")
        f.write("Parâmetros do Otimizador (Adam):\n")
        f.write(f"  Número de Passos (Steps): {N_STEPS}\n")
        f.write(f"  Taxa de Aprendizado (LR): {LR}\n")
        f.write(f"  Tamanho do Lote (Batch Size): {BATCH_PIXELS}\n")
        f.write(f"  Total de Meta-Átomos Otimizados: {num_pixels} ({rows}x{cols})\n\n")

        f.write("--- Performance Computacional ---\n")
        f.write(f"Tempo de Execução Total: {total_execution_time:.2f} segundos\n")
        f.write(f"Número Total de Avaliações (NFE): {total_nfe}\n")
        f.write(f"NFE por pixel (média): {total_nfe / num_pixels:.1f}\n\n")

        f.write("--- Qualidade da Otimização (Fitness/Erro) ---\n")
        f.write(f"Erro Médio Final (Média do Best Fitness): {mean_error:.6f}\n")
        f.write(f"Melhor Fitness Individual (pixel único): {min_best_fitness:.6f}\n")
        f.write(f"Desvio Padrão (Fitness): {std_dev_fitness:.6f}\n\n")

        f.write("--- Métricas de Convergência ---\n")
        f.write(f"Erro Médio Inicial (Média da Pop. Step 0): {avg_initial_fitness:.6f}\n")
        f.write(f"Melhoria Média (Inicial - Final): {convergence_improvement:.6f}\n")
        f.write(f"Taxa de Sucesso (Erro < {SUCCESS_THRESHOLD}): {success_rate:.2f} %\n")

    print(f"Métricas salvas com sucesso em '{metrics_filename}'")

except OSError as e:
    # Best effort: a failed report write should not abort the notebook.
    print(f"Erro ao salvar arquivo de métricas: {e}")
"""# Verificação de Performance"""

"""
============================================================================
## Verifica√ß√£o de Performance (Gradiente)
============================================================================
"""

print("\nVerificando performance final (em lotes) da metassuperf√≠cie otimizada...")

# Per-meta-atom result maps, filled one grid row at a time.
final_phase_x, final_phase_y, final_amp_x, final_amp_y = (
    np.zeros((rows, cols)) for _ in range(4)
)

# Shared coordinate grid handed to the generator for every row.
coords_grid_tensor = coords_grid

with torch.no_grad():
    for row_idx in tqdm(range(rows), desc="Verificando linhas"):
        # Latent codes of this row, batched along the first axis.
        # `latent_vectors` is produced by the gradient block.
        row_latents = np.array([latent_vectors[row_idx, col_idx] for col_idx in range(cols)])
        z_batch = torch.from_numpy(row_latents).float().to(device)

        # Decode every tile of the row, binarize at 0.5, simulate in one pass.
        tiles_binary = (generator(coords_grid_tensor, z_batch) > 0.5).float()
        sim_out = simulator(tiles_binary)

        # Simulator columns: 0/1 = TM real/imag, 2/3 = TE real/imag.
        # X polarization is read from TE, Y polarization from TM.
        field_x = torch.complex(sim_out[:, 2], sim_out[:, 3])
        field_y = torch.complex(sim_out[:, 0], sim_out[:, 1])

        final_phase_x[row_idx, :] = torch.angle(field_x).cpu().numpy()
        final_amp_x[row_idx, :] = torch.abs(field_x).cpu().numpy()
        final_phase_y[row_idx, :] = torch.angle(field_y).cpu().numpy()
        final_amp_y[row_idx, :] = torch.abs(field_y).cpu().numpy()

print("Verifica√ß√£o em lote conclu√≠da.")

# --- Plot targets vs. obtained results ---
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
plt.suptitle("Resultados Finais da Otimiza√ß√£o (GRADIENTE)", fontsize=16)

phase_style = {"cmap": 'hsv', "vmin": -np.pi, "vmax": np.pi}
amp_style = {"cmap": 'viridis', "vmin": 0, "vmax": 1}

# (subplot position, data, imshow style, title).
# Column 0 = targets, column 1 = obtained phases, column 2 = obtained amplitudes.
result_panels = [
    ((0, 0), phase_map_x, phase_style, "Alvo - Fase X (rad)"),
    ((1, 0), phase_map_y, phase_style, "Alvo - Fase Y (rad)"),
    ((0, 1), final_phase_x, phase_style, "Obtido - Fase X (rad)"),
    ((1, 1), final_phase_y, phase_style, "Obtido - Fase Y (rad)"),
    ((0, 2), final_amp_x, amp_style, "Obtido - Amplitude X (Efici√™ncia)"),
    ((1, 2), final_amp_y, amp_style, "Obtido - Amplitude Y (Efici√™ncia)"),
]
for panel_pos, panel_data, panel_style, panel_title in result_panels:
    panel_ax = axes[panel_pos]
    panel_img = panel_ax.imshow(panel_data, **panel_style)
    panel_ax.set_title(panel_title)
    fig.colorbar(panel_img, ax=panel_ax)

plt.tight_layout()

plt.savefig("Resultados_Otimizacao_GRADIENTE.png", dpi=300, bbox_inches="tight")
plt.show()

"""
============================================================================
## Gerando Imagem da Metassuperf√≠cie Completa (Gradiente)
============================================================================
"""

print("\n--- Gerando Imagem da Metassuperf√≠cie Completa ---")
print("Isso pode demorar alguns minutos e consumir bastante RAM...")

try:
    # `latent_vectors` comes from the gradient block. It is indexed as
    # [row, col]; presumably (rows, cols, latent_dim) -- TODO confirm.
    rows, cols = latent_vectors.shape[0], latent_vectors.shape[1]
    img_size = IMG_SIZE  # kept for later cells; stitching uses the real tile height

    full_surface_rows_list = []

    # Inference only: no autograd graph is needed for any row.
    with torch.no_grad():
        for i in tqdm(range(rows), desc="Gerando Linhas da Metassuperf√≠cie"):

            # 1. All latent vectors of this row as one batch.
            z_row_list = [latent_vectors[i, j] for j in range(cols)]
            z_row_tensor = torch.from_numpy(np.array(z_row_list)).float().to(device)

            # 2. Decode the whole row at once and binarize at 0.5.
            img_row_batch = generator(coords_grid, z_row_tensor)
            img_row_binary = (img_row_batch > 0.5).float()

            # 3. Stitch the tiles horizontally. Assuming the generator returns
            #    (cols, 1, H, W):
            #      squeeze(1)   -> (cols, H, W)
            #      transpose    -> (H, cols, W)
            #      reshape      -> (H, cols * W)
            #    Each output row is then the same row of every tile, left to right.
            #    H is read from the tensor rather than assumed to equal IMG_SIZE,
            #    so a size mismatch cannot silently scramble the tiles.
            img_row_binary = img_row_binary.squeeze(1).transpose(0, 1)
            final_row_img = img_row_binary.reshape(img_row_binary.shape[0], -1)

            # 4. Move to CPU right away (frees GPU memory) and keep the strip.
            full_surface_rows_list.append(final_row_img.cpu().numpy())

    # 5. Stack the row strips vertically.
    print("Costurando imagem final...")
    final_metasurface_image = np.concatenate(full_surface_rows_list, axis=0)

    print(f"Imagem final gerada com dimens√µes: {final_metasurface_image.shape}")

    # 6. Save. vmin/vmax are pinned so 0 -> black and 1 -> white even for a
    #    uniform surface. Autoscaling maps a constant image (e.g. all ones) to the
    #    bottom of the colormap, i.e. black.
    output_filename = "metasurface_completa_geometria_GRADIENTE.png"
    plt.imsave(output_filename, final_metasurface_image, cmap='gray', vmin=0, vmax=1)
    print(f"Imagem da metassuperf√≠cie completa salva em: {output_filename}")

except Exception as e:
    # Best-effort export: report the failure and let the remaining cells run.
    print(f"Ocorreu um erro ao gerar a metassuperf√≠cie completa: {e}")

"""
============================================================================
## Reconstru√ß√£o Final da Imagem (Gradiente)
============================================================================
"""

print("\nReconstruindo imagem final a partir do campo otimizado pelo GRADIENTE...")

# Complex Pol-X field (hologram) assembled from the amplitude and phase maps
# produced by the performance-verification step: A * exp(i * phi).
campo_complexo_x_final_grad = final_amp_x * np.exp(1j * final_phase_x)

# Propagate with the same optical parameters used in the GS section.
img_reconstruida_final_grad = reconstruct_image_from_field(
    campo_complexo_x_final_grad, wavelength, z, dx, NA
)
print("Reconstru√ß√£o final (Gradiente) conclu√≠da.")


print("Gerando gr√°fico de compara√ß√£o de reconstru√ß√£o (Gradiente)...")

fig_grad_recon = plt.figure(figsize=(18, 6))
plt.suptitle("Compara√ß√£o da Imagem Reconstru√≠da (Pol-X)", fontsize=16)

# Left to right: the original target, the ideal phase-only reconstruction from
# the GS section (`img_reconstruida`), and the reconstruction obtained from the
# gradient-optimized, simulator-evaluated field.
comparison_panels = (
    (img_original, "1. Imagem Alvo Original"),
    (img_reconstruida, "2. Reconstru√ß√£o Ideal (GS - Fase Pura)"),
    (img_reconstruida_final_grad, "3. Reconstru√ß√£o Real (Gradiente + Simulador)"),
)
for panel_idx, (panel_img, panel_title) in enumerate(comparison_panels, start=1):
    plt.subplot(1, 3, panel_idx)
    plt.imshow(panel_img, cmap='gray')
    plt.title(panel_title)
    plt.xlabel("Pixel X")
    plt.ylabel("Pixel Y")

plt.tight_layout()
plt.savefig("comparacao_reconstrucao_GRADIENTE.png", dpi=300, bbox_inches='tight')
plt.show()

print("Gr√°fico 'comparacao_reconstrucao_GRADIENTE.png' salvo.")