In [1]:
# Run this first, in the very first cell: show the interpreter's module search path.
import sys
print(sys.path)

# Check whether a file such as jax.py sits in the working directory
# (case-insensitive match on the substring "jax").
import os
jax_like_entries = [entry for entry in os.listdir('.') if 'jax' in entry.lower()]
for entry in jax_like_entries:
    print(f"Found: {entry}")

['/scratch/e1729a03/발사', '/opt/conda/lib/python38.zip', '/opt/conda/lib/python3.8', '/opt/conda/lib/python3.8/lib-dynload', '', '/home01/e1729a03/.local/lib/python3.8/site-packages', '/opt/conda/lib/python3.8/site-packages', '/opt/conda/lib/python3.8/site-packages/torchtext-0.11.0a0-py3.8-linux-x86_64.egg', '/opt/conda/lib/python3.8/site-packages/certifi-2022.9.14-py3.8.egg', '/opt/conda/lib/python3.8/site-packages/functorch-0.3.0a0-py3.8-linux-x86_64.egg', '/home01/e1729a03/.local/lib/python3.8/site-packages/setuptools/_vendor']


In [2]:
import torch
import gc

def clear_gpu_memory():
    """
    Release cached PyTorch GPU memory and report what is still in use.

    Steps:
    - force a Python garbage-collection pass
    - release the CUDA cache
    - reset the peak-memory statistics (useful when debugging)

    Prints the allocated and reserved totals in GB, or a notice when CUDA is
    not available. Returns None.
    """
    # Collect unreachable Python objects first.
    gc.collect()

    # Guard clause: without CUDA there is nothing else to release.
    if not torch.cuda.is_available():
        print("CUDA not available. Running on CPU.")
        return

    torch.cuda.empty_cache()
    torch.cuda.synchronize()  # wait for pending operations to finish
    torch.cuda.reset_peak_memory_stats()  # optional: reset peak statistics
    print(f"GPU memory cleared. Current allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

# Usage example
clear_gpu_memory()

GPU memory cleared. Current allocated: 0.00 GB
GPU memory reserved: 0.00 GB


In [None]:
import os
import math
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as nnF
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.animation import FuncAnimation
from scipy.ndimage import label

# Optional mesh-export dependencies: MESH_EXPORT records whether both imports succeeded.
try:
    from skimage.measure import marching_cubes
    import trimesh
    MESH_EXPORT = True
except ImportError:
    MESH_EXPORT = False

# Seed every RNG used in this cell (random, NumPy, torch).
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Set before any tensor is created below so floating-point tensors default to float64.
torch.set_default_dtype(torch.float64)
# This cell is pinned to the CPU.
device = torch.device("cpu")
OUT_DIR = "./out_track_b_3d_latest"
os.makedirs(OUT_DIR, exist_ok=True)

# Model constants, described by how the code below uses them.
# Units are not stated in this file -- presumably SI; TODO confirm.
# gamma, R, Tc: enter mdot_choked, exhaust_velocity and the dPdt update
#   (presumably ratio of specific heats, specific gas constant, chamber temperature).
# rho_p: factor in mdot_gen = rho_p * Ab * r_dot.
# a, n: coefficient and exponent of r_dot = a * Pc**n.
# pa: reference pressure in exhaust_velocity (presumably ambient).
# At: area factor in mdot_choked (presumably nozzle throat area).
gamma = 1.22
R = 355.0
Tc = 3200.0
rho_p = 1700.0
a = 5.0e-5
n = 0.35
pa = 101325.0
At = 3.0e-4

# Domain: extent L along z, half-width R_case along x and y, N samples per axis.
L = 0.20
R_case = 0.10
N = 32
# NOTE(review): dx is span/N, but the linspace calls below include both end
# points, so neighbouring samples are span/(N-1) apart -- confirm which spacing
# the finite-difference and FFT operators are meant to use.
dx = max(L, 2*R_case) / N
# eps weights the gradient and double-well terms in interface_measure and the phase-field update.
eps = 4.0 * dx

x = torch.linspace(-R_case, R_case, N, device=device)
y = torch.linspace(-R_case, R_case, N, device=device)
z = torch.linspace(-L/2, L/2, N, device=device)
# indexing="ij" over (z, y, x): the grids are indexed [z, y, x].
Z, Y, X = torch.meshgrid(z, y, x, indexing="ij")
# One (x, y, z) row per grid point, N**3 rows in total.
coords_flat = torch.stack([X.flatten(), Y.flatten(), Z.flatten()], dim=-1)

num_freqs = 6
freq_bands = 2.0 ** torch.linspace(0., num_freqs-1, num_freqs, device=device)
def fourier_encode(coords):
    """Expand the first three coordinate columns into Fourier features.

    For each column the output holds the raw column followed by
    sin(f * c), cos(f * c) for every f in ``freq_bands``, in that order.

    Args:
        coords: 2-D tensor whose first three columns are encoded.

    Returns:
        Tensor with 3 * (1 + 2 * len(freq_bands)) columns and one row per input row.
    """
    def _axis_features(axis):
        # Slice with a range so the column keeps its second dimension.
        column = coords[:, axis:axis+1]
        waves = [fn(freq * column) for freq in freq_bands for fn in (torch.sin, torch.cos)]
        return [column, *waves]

    features = [feat for axis in range(3) for feat in _axis_features(axis)]
    return torch.cat(features, dim=-1)
# Encode the whole grid once; the feature width sizes the network input.
coords_encoded = fourier_encode(coords_flat)
base_features = coords_encoded.shape[1]

# Squared wavenumber magnitude on the N^3 grid, used by the spectral update
# in forward_motor_trackB3D.
kx = 2 * math.pi * torch.fft.fftfreq(N, d=dx).to(device)
k2 = kx[None,None,:]**2 + kx[None,:,None]**2 + kx[:,None,None]**2

dt = 0.001
t_end = 3.5
# Build the grid from an integer step count: torch.arange with a fractional
# step can gain or lose its final sample to floating-point rounding, which
# would make NT (and every per-step history length) rounding-dependent.
NT = int(round(t_end / dt)) + 1
time_grid = torch.arange(NT, device=device, dtype=torch.get_default_dtype()) * dt

# Target chamber pressure: constant 6e6 until t = 2.2, then multiplied by
# exp(-0.8 * (t - 2.2)). The thrust target is a fixed multiple of it.
decay_mask = time_grid > 2.2
Pc_target = torch.ones_like(time_grid) * 6.0e6
Pc_target[decay_mask] *= torch.exp(-0.8 * (time_grid[decay_mask] - 2.2))
F_target = Pc_target * At * 100.0

def mdot_choked(Pc):
    """Nozzle mass flow for chamber pressure ``Pc``.

    Returns At * clamp(Pc, 0, 1e8) * C, where C is a constant built from
    gamma, R and Tc.
    """
    exponent = (gamma + 1.0) / (2.0 * (gamma - 1.0))
    C = math.sqrt(gamma / (R * Tc)) * (2.0 / (gamma + 1.0)) ** exponent
    Pc_limited = torch.clamp(Pc, min=0.0, max=1e8)
    return At * Pc_limited * C

def exhaust_velocity(Pc):
    """Exhaust velocity for chamber pressure ``Pc``.

    ``Pc`` is clamped to [1.1 * pa, 1e8] first, which keeps the term under
    the square root positive.
    """
    Pc_eff = torch.clamp(Pc, min=pa * 1.1, max=1e8)
    pressure_ratio = pa / Pc_eff
    expansion = 1.0 - pressure_ratio**((gamma - 1.0)/gamma)
    return torch.sqrt(2.0 * gamma * R * Tc / (gamma - 1.0) * expansion)

# Optional conditioning: when enabled, the target curves are turned into a
# 3-channel image that GrainField3D can embed. Disabled by default.
use_conditioning = False
cond_img = None
if use_conditioning:
    def time_series_to_image(time_series, pressure_series, thrust_series, img_size=32):
        """Turn the target curves into a (3, img_size, img_size) image.

        Channels are outer products of (a) a 0..1 ramp with one entry per
        time sample, (b) the standardised pressure series and (c) the
        standardised thrust series, resized with bilinear interpolation.

        Args:
            time_series: only its length is used (for the ramp).
            pressure_series: 1-D tensor, standardised to zero mean / unit std.
            thrust_series: 1-D tensor, standardised to zero mean / unit std.
            img_size: side length of the output image.
        """
        # The 1e-8 guards against division by zero for a constant series.
        P_norm = (pressure_series - pressure_series.mean()) / (pressure_series.std() + 1e-8)
        F_norm = (thrust_series - thrust_series.mean()) / (thrust_series.std() + 1e-8)
        idx = torch.linspace(0, 1, len(time_series), device=device)
        Rt = torch.outer(idx, idx)
        Rp = torch.outer(P_norm, P_norm)
        Rf = torch.outer(F_norm, F_norm)
        # Add a batch dimension for interpolate, then drop it again.
        img = torch.stack([Rt, Rp, Rf], dim=0).unsqueeze(0)
        img = nnF.interpolate(img, size=(img_size, img_size), mode='bilinear', align_corners=False)
        return img.squeeze(0)
    cond_img = time_series_to_image(time_grid, Pc_target, F_target)

class GrainField3D(nn.Module):
    """MLP plus a learnable per-voxel bias that yields logits on the fixed N^3 grid.

    The bias is initialised from the module-level X/Y grids, so ``forward``
    expects one encoded row per grid point (as in ``coords_encoded``).
    """

    def __init__(self, hidden_size=192, num_layers=7, use_cond=False, cond_emb_dim=64):
        """Build the network.

        Args:
            hidden_size: width of every hidden layer.
            num_layers: number of Linear + Softplus pairs before the output layer.
            use_cond: when True, add a linear projection of the conditioning image.
            cond_emb_dim: size of that projection.
        """
        super().__init__()
        self.use_cond = use_cond
        in_features = base_features
        if use_cond:
            # The projection input is sized for a flattened 3 x 32 x 32 image.
            cond_size = 3*32*32 if cond_img is not None else 0
            self.cond_proj = nn.Linear(cond_size, cond_emb_dim)
            in_features += cond_emb_dim

        widths = [in_features] + [hidden_size] * num_layers
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Softplus(beta=10))
        layers.append(nn.Linear(hidden_size, 1))
        self.net = nn.Sequential(*layers)

        # Initial bias: positive where sqrt(X^2 + Y^2) < core_radius, negative
        # outside, with Gaussian noise added, limited to [-10, 10].
        core_radius = 0.055
        distance_from_axis = torch.sqrt(X**2 + Y**2)
        seed_logits = 10.0 * (core_radius - distance_from_axis) / eps
        seed_logits = seed_logits + 0.2 * torch.randn_like(seed_logits)
        seed_logits = seed_logits.clamp(-10.0, 10.0)
        self.logit_bias = nn.Parameter(seed_logits.flatten(), requires_grad=True)

    def forward(self, encoded_coords, cond=None):
        """Return per-point logits in [-20, 20] for the encoded coordinates."""
        features = encoded_coords
        if self.use_cond and cond is not None:
            # Give every point the same conditioning embedding.
            embedding = self.cond_proj(cond.flatten())
            tiled = embedding.unsqueeze(0).expand(features.shape[0], -1)
            features = torch.cat([features, tiled], dim=-1)
        raw = self.net(features).squeeze(-1) + self.logit_bias
        return raw.clamp(-20.0, 20.0)

    def phi(self, cond=None):
        """Evaluate the field on the whole grid: sigmoid of the logits, shape (N, N, N)."""
        occupancy = torch.sigmoid(self.forward(coords_encoded, cond))
        return occupancy.reshape(N, N, N).clamp(0.0, 1.0)

def double_well(phi):
    """Double-well potential phi^2 * (1 - phi)^2, which is zero at phi = 0 and phi = 1."""
    complement = 1.0 - phi
    return phi**2 * complement**2

def grad_norm_squared(phi):
    """Squared gradient magnitude of ``phi`` from central differences.

    Neighbours are taken with ``roll``, so the stencil wraps around at the
    array edges. A 1e-12 offset is added to the result.
    """
    def _central_difference(axis):
        return (phi.roll(-1, axis) - phi.roll(1, axis)) / (2 * dx)

    gx, gy, gz = (_central_difference(axis) for axis in (2, 1, 0))
    return gx**2 + gy**2 + gz**2 + 1e-12

def interface_measure(phi):
    """Sum over the grid of eps * |grad phi|^2 + double_well(phi) / eps, times dx**3."""
    gradient_term = eps * grad_norm_squared(phi)
    well_term = double_well(phi) / eps
    return torch.sum(gradient_term + well_term) * dx**3

def propellant_volume(phi):
    """Sum of ``phi`` (clamped to [0, 1]) multiplied by the cell volume dx**3."""
    occupancy = torch.clamp(phi, 0.0, 1.0)
    return occupancy.sum() * dx**3

def total_volume():
    """Volume of a cylinder with radius R_case and length L."""
    cross_section = math.pi * R_case**2
    return cross_section * L

def loading_fraction(phi):
    """Ratio of propellant_volume(phi) to total_volume()."""
    solid = propellant_volume(phi)
    return solid / total_volume()

def compactness_penalty(phi):
    """Squared gap between interface_measure(phi) and the surface of an equal-volume sphere.

    Returns a constant 1000.0 tensor (no link to ``phi``) when the volume is
    below 1e-8 or above 1.05 * total_volume().
    """
    vol = propellant_volume(phi)
    too_small = vol < 1e-8
    too_large = vol > total_volume() * 1.05
    if too_small or too_large:
        return torch.tensor(1000.0, device=device)
    surface = interface_measure(phi)
    # Surface area of the sphere whose volume equals `vol`.
    ideal = 4 * math.pi * (3*vol/(4*math.pi))**(2/3)
    return (surface - ideal).square()

def connectedness_penalty(phi):
    """Penalise propellant that is split into several connected components.

    Thresholds ``phi`` at 0.5, labels the connected components and returns
    (1 - largest_component / total_voxels)**2, or 1000.0 when nothing
    exceeds the threshold. The value is rebuilt with ``torch.tensor`` from
    NumPy numbers, so it has no autograd link to ``phi``.
    """
    solid_mask = (phi > 0.5).cpu().numpy()
    component_ids, component_count = label(solid_mask)
    if component_count == 0:
        return torch.tensor(1000.0, device=device)
    # bincount index 0 is the background label; drop it.
    component_sizes = np.bincount(component_ids.ravel())[1:]
    largest = component_sizes.max()
    total = solid_mask.sum()
    if total > 0:
        penalty = (1.0 - largest / total)**2
    else:
        penalty = 1000.0
    return torch.tensor(penalty, device=device)

def forward_motor_trackB3D(field, cond=None, store_history=True):
    """Integrate chamber pressure and the phase field over the NT time steps.

    Each step:
      1. takes Ab = interface_measure(phi) and r_dot = a * Pc**n,
      2. advances Pc with an explicit Euler step of
         dP/dt = (mdot_gen - mdot_noz) * R * Tc / Vg,
      3. advances phi with a semi-implicit spectral step (double-well term
         explicit, eps * k2 term implicit) and re-sharpens it with a sigmoid.

    The per-step histories stay torch tensors (stacked at the end), so a
    loss built on ``out["Pc"]`` back-propagates into ``field``. Converting
    them with ``.item()`` would cut that gradient path. Call under
    ``torch.no_grad()`` for analysis runs to get grad-free outputs.

    Args:
        field: object exposing ``phi(cond)`` that returns the (N, N, N) field.
        cond: optional conditioning tensor forwarded to ``field.phi``.
        store_history: when True, also keep a detached CPU copy of phi
            every 30 steps under ``"phi_keyframes"``.

    Returns:
        dict with 1-D tensors "Pc", "F", "Ab", "vol", "load" (one entry per
        step), "phi_final", and optionally "phi_keyframes".
    """
    phi = field.phi(cond).clone()
    Pc = torch.tensor(2.0e6, device=device)
    history = {"Pc": [], "F": [], "Ab": [], "vol": [], "load": []}
    phi_keyframes = []
    for step in range(NT):
        # --- chamber pressure (explicit Euler) ---
        Ab = interface_measure(phi)
        Pc_pos = torch.clamp(Pc, min=1.0, max=1e8)
        r_dot = a * Pc_pos**n
        mdot_gen = rho_p * Ab * r_dot
        mdot_noz = mdot_choked(Pc_pos)
        Vg = torch.clamp(torch.sum(1.0 - phi) * dx**3, min=1e-6, max=total_volume() * 1.1)
        dPdt = (mdot_gen - mdot_noz) * R * Tc / Vg
        Pc = Pc + dPdt * dt
        Pc = torch.clamp(Pc, min=0.0, max=1e8)

        # --- phase field (semi-implicit spectral step) ---
        W_prime = 2.0 * phi * (1.0 - phi) * (1.0 - 2.0 * phi)
        phi_fft = torch.fft.fftn(phi)
        W_prime_fft = torch.fft.fftn(W_prime)
        rhs_fft = phi_fft - (dt * r_dot / eps) * W_prime_fft
        denom = 1.0 + dt * r_dot * eps * k2
        denom = torch.clamp(denom, min=1e-6)
        phi_fft_new = rhs_fft / denom
        phi = torch.fft.ifftn(phi_fft_new).real
        phi = torch.clamp(phi, -0.2, 1.2)
        # NOTE(review): the + 0.5 sits inside the sigmoid argument, which moves
        # the crossover to phi = 0.475 rather than 0.5 -- confirm intended.
        phi = torch.sigmoid(20.0 * (phi - 0.5) + 0.5)

        ve = exhaust_velocity(Pc)
        F = mdot_noz * ve

        # Keep tensors (not Python floats) so autograd can reach `field`.
        history["Pc"].append(Pc)
        history["F"].append(F)
        history["Ab"].append(Ab)
        history["vol"].append(propellant_volume(phi))
        history["load"].append(loading_fraction(phi))
        if store_history and step % 30 == 0:
            phi_keyframes.append(phi.detach().cpu())
    out = {
        k: (torch.stack(v) if v else torch.empty(0, device=device))
        for k, v in history.items()
    }
    out["phi_final"] = phi
    if store_history:
        out["phi_keyframes"] = phi_keyframes
    return out

field = GrainField3D(hidden_size=192, num_layers=7, use_cond=use_conditioning, cond_emb_dim=64).to(device)
optimizer = optim.Adam(field.parameters(), lr=3e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=100, factor=0.5)

target_loading = 0.70
min_loading = 0.35
max_loading = 1.00
num_iters = 1000
cond = cond_img if use_conditioning else None

for it in range(num_iters):
    optimizer.zero_grad()
    out = forward_motor_trackB3D(field, cond, store_history=False)
    Pc = out["Pc"]
    phi = field.phi(cond)
    # Pc carries the autograd graph of the whole simulation, so this term
    # now actually trains the field.
    # NOTE(review): pressures are ~1e6, so this un-normalised MSE dwarfs the
    # other terms; consider scaling it -- left unchanged here.
    fit_loss = nnF.mse_loss(Pc, Pc_target)
    load_frac = loading_fraction(phi)
    load_loss = (target_loading - load_frac)**2 + \
                10.0 * nnF.relu(min_loading - load_frac)**2 + \
                50.0 * nnF.relu(load_frac - max_loading)**2
    compact_loss = compactness_penalty(phi)
    connect_loss = connectedness_penalty(phi)
    smooth_loss = torch.mean(double_well(phi))
    loss = 20.0 * fit_loss + 10.0 * load_loss + 0.5 * compact_loss + 3.0 * connect_loss + 0.01 * smooth_loss
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(field.parameters(), max_norm=0.5)
    # Back-propagating through NT steps can overflow; never let a NaN/inf
    # gradient reach the optimizer state.
    if math.isfinite(float(grad_norm)):
        optimizer.step()
        scheduler.step(fit_loss.item())
    else:
        print(f"Iter {it:4d} | non-finite gradient norm, optimizer step skipped")
    if it % 50 == 0:
        with torch.no_grad():
            rel_err = torch.mean(torch.abs(Pc - Pc_target) / (Pc_target + 1e-6))
        print(f"Iter {it:4d} | Loss {loss.item():.3e} | Fit {fit_loss.item():.4e} | RelErr {rel_err:.3f} | Load {load_frac:.3f}")

# Analysis run: no graph is needed, and grad-free tensors can be converted
# to NumPy by the plotting code below.
with torch.no_grad():
    final_out = forward_motor_trackB3D(field, cond, store_history=True)
Pc_final = final_out["Pc"]
phi_keyframes = final_out["phi_keyframes"]
phi_final = final_out["phi_final"]

def plot_3d_slices(phi_tensor, title, fname_prefix):
    """Save a three-panel figure of the mid-index slices of a 3-D field.

    Args:
        phi_tensor: (N, N, N) tensor; may live on any device and may track
            gradients (it is detached before the NumPy conversion, which
            would otherwise raise for tensors that require grad).
        title: figure title.
        fname_prefix: the file is written to OUT_DIR as
            "<fname_prefix>_slices.png".
    """
    phi_np = phi_tensor.detach().cpu().numpy()
    fig, axs = plt.subplots(1, 3, figsize=(15,5))
    mid = N // 2
    # One panel per axis: fix the mid index along axis 0, 1 and 2 in turn.
    panels = (
        (phi_np[mid, :, :], "Axial slice (mid-z)"),
        (phi_np[:, mid, :], "Radial slice (mid-y)"),
        (phi_np[:, :, mid], "Radial slice (mid-x)"),
    )
    for ax, (image, panel_title) in zip(axs, panels):
        ax.imshow(image, cmap=cm.gray, origin="lower")
        ax.set_title(panel_title)
    fig.suptitle(title)
    plt.tight_layout()
    plt.savefig(os.path.join(OUT_DIR, f"{fname_prefix}_slices.png"), dpi=200)

# Slice plots of the optimised field at t = 0 and after the simulated run.
# Tensors are detached first: they may still be attached to the autograd graph.
plot_3d_slices(field.phi(cond).detach(), "Initial Grain", "initial")
plot_3d_slices(phi_final.detach(), "Final Grain", "final")

if MESH_EXPORT:
    phi_np = phi_final.detach().cpu().numpy()
    # marching_cubes raises when the iso-level lies outside the data range.
    if phi_np.min() < 0.5 < phi_np.max():
        verts, faces, _, _ = marching_cubes(phi_np, level=0.5)
        mesh = trimesh.Trimesh(verts, faces)
        mesh.export(os.path.join(OUT_DIR, "final_grain.obj"))
    else:
        print("Skipping mesh export: phi never crosses the 0.5 iso-level.")

# Animation of the mid-index slice along axis 0 over the stored keyframes
# (one keyframe every 30 steps).
fig, ax = plt.subplots(figsize=(6,6))
mid = N // 2
# Fix the colour scale to [0, 1]: set_data() does not rescale, so without it
# every frame would reuse the limits inferred from the first one.
im = ax.imshow(phi_keyframes[0][mid, :, :].numpy(), cmap=cm.gray, origin="lower", vmin=0.0, vmax=1.0)
def update(i):
    """Show keyframe ``i`` and update the title with its approximate time."""
    im.set_data(phi_keyframes[i][mid, :, :].numpy())
    ax.set_title(f"Axial Mid-Slice t ≈ {i*30*dt:.2f}s")
    return [im]
anim = FuncAnimation(fig, update, frames=len(phi_keyframes), interval=200)
anim.save(os.path.join(OUT_DIR, "burnback_axial.gif"), writer="pillow")

def _as_numpy(tensor):
    """Detach, move to the CPU and convert to NumPy.

    Tensor.numpy() raises for tensors on another device or tensors that
    require grad, so normalise both before converting.
    """
    return tensor.detach().cpu().numpy()

# Pressure (target vs achieved), thrust, Ab and loading fraction over time.
t_np = _as_numpy(time_grid)
fig, axs = plt.subplots(4, 1, figsize=(10, 12))
axs[0].plot(t_np, _as_numpy(Pc_target)/1e6, "k--", label="Target")
axs[0].plot(t_np, _as_numpy(Pc_final)/1e6, "r-", label="Achieved")
axs[0].set_ylabel("Pressure [MPa]")
axs[0].legend()  # the labels above are only shown once a legend is requested
axs[1].plot(t_np, _as_numpy(final_out["F"])/1e3, "b-")
axs[1].set_ylabel("Thrust [kN]")
axs[2].plot(t_np, _as_numpy(final_out["Ab"]), "g-")
axs[2].set_ylabel("Ab [m²]")
axs[3].plot(t_np, _as_numpy(final_out["load"]), "m-")
axs[3].set_ylabel("Loading")
axs[3].set_xlabel("Time [s]")
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "performance_curves.png"), dpi=200)

In [None]:
import os
import math
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as nnF
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.animation import FuncAnimation
from scipy.ndimage import label

# NOTE(review): this cell repeats the previous cell almost line for line; the
# visible differences are the device selection, num_iters, and the model
# construction further down.

# Optional mesh-export dependencies: MESH_EXPORT records whether both imports succeeded.
try:
    from skimage.measure import marching_cubes
    import trimesh
    MESH_EXPORT = True
except ImportError:
    MESH_EXPORT = False

# Seed every RNG used in this cell (random, NumPy, torch).
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Set before any tensor is created below so floating-point tensors default to float64.
torch.set_default_dtype(torch.float64)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

OUT_DIR = "./out_track_b_3d_latest"
os.makedirs(OUT_DIR, exist_ok=True)

# Model constants, described by how the code below uses them.
# Units are not stated in this file -- presumably SI; TODO confirm.
# gamma, R, Tc: enter mdot_choked, exhaust_velocity and the dPdt update
#   (presumably ratio of specific heats, specific gas constant, chamber temperature).
# rho_p: factor in mdot_gen = rho_p * Ab * r_dot.
# a, n: coefficient and exponent of r_dot = a * Pc**n.
# pa: reference pressure in exhaust_velocity (presumably ambient).
# At: area factor in mdot_choked (presumably nozzle throat area).
gamma = 1.22
R = 355.0
Tc = 3200.0
rho_p = 1700.0
a = 5.0e-5
n = 0.35
pa = 101325.0
At = 3.0e-4

# Domain: extent L along z, half-width R_case along x and y, N samples per axis.
L = 0.20
R_case = 0.10
N = 32
# NOTE(review): dx is span/N, but the linspace calls below include both end
# points, so neighbouring samples are span/(N-1) apart -- confirm which spacing
# the finite-difference and FFT operators are meant to use.
dx = max(L, 2*R_case) / N
# eps weights the gradient and double-well terms in interface_measure and the phase-field update.
eps = 4.0 * dx

z = torch.linspace(-L/2, L/2, N, device=device)
y = torch.linspace(-R_case, R_case, N, device=device)
x = torch.linspace(-R_case, R_case, N, device=device)
# indexing="ij" over (z, y, x): the grids are indexed [z, y, x].
Z, Y, X = torch.meshgrid(z, y, x, indexing="ij")
# One (x, y, z) row per grid point, N**3 rows in total.
coords_flat = torch.stack([X.flatten(), Y.flatten(), Z.flatten()], dim=-1)

num_freqs = 6
freq_bands = 2.0 ** torch.linspace(0., num_freqs-1, num_freqs, device=device)
def fourier_encode(coords):
    """Expand the first three coordinate columns into Fourier features.

    For each column the output holds the raw column followed by
    sin(f * c), cos(f * c) for every f in ``freq_bands``, in that order.

    Args:
        coords: 2-D tensor whose first three columns are encoded.

    Returns:
        Tensor with 3 * (1 + 2 * len(freq_bands)) columns and one row per input row.
    """
    def _axis_features(axis):
        # Slice with a range so the column keeps its second dimension.
        column = coords[:, axis:axis+1]
        waves = [fn(freq * column) for freq in freq_bands for fn in (torch.sin, torch.cos)]
        return [column, *waves]

    features = [feat for axis in range(3) for feat in _axis_features(axis)]
    return torch.cat(features, dim=-1)
# Encode the whole grid once; the feature width sizes the network input.
coords_encoded = fourier_encode(coords_flat)
base_features = coords_encoded.shape[1]

# Squared wavenumber magnitude on the N^3 grid, used by the spectral update
# in forward_motor_trackB3D.
kx = 2 * math.pi * torch.fft.fftfreq(N, d=dx).to(device)
k2 = kx[None,None,:]**2 + kx[None,:,None]**2 + kx[:,None,None]**2

dt = 0.001
t_end = 3.5
# Build the grid from an integer step count: torch.arange with a fractional
# step can gain or lose its final sample to floating-point rounding, which
# would make NT (and every per-step history length) rounding-dependent.
NT = int(round(t_end / dt)) + 1
time_grid = torch.arange(NT, device=device, dtype=torch.get_default_dtype()) * dt

# Target chamber pressure: constant 6e6 until t = 2.2, then multiplied by
# exp(-0.8 * (t - 2.2)). The thrust target is a fixed multiple of it.
decay_mask = time_grid > 2.2
Pc_target = torch.ones_like(time_grid) * 6.0e6
Pc_target[decay_mask] *= torch.exp(-0.8 * (time_grid[decay_mask] - 2.2))
F_target = Pc_target * At * 100.0

def mdot_choked(Pc):
    """Nozzle mass flow for chamber pressure ``Pc``.

    Returns At * clamp(Pc, 0, 1e8) * C, where C is a constant built from
    gamma, R and Tc.
    """
    exponent = (gamma + 1.0) / (2.0 * (gamma - 1.0))
    C = math.sqrt(gamma / (R * Tc)) * (2.0 / (gamma + 1.0)) ** exponent
    Pc_limited = torch.clamp(Pc, min=0.0, max=1e8)
    return At * Pc_limited * C

def exhaust_velocity(Pc):
    """Exhaust velocity for chamber pressure ``Pc``.

    ``Pc`` is clamped to [1.1 * pa, 1e8] first, which keeps the term under
    the square root positive.
    """
    Pc_eff = torch.clamp(Pc, min=pa * 1.1, max=1e8)
    pressure_ratio = pa / Pc_eff
    expansion = 1.0 - pressure_ratio**((gamma - 1.0)/gamma)
    return torch.sqrt(2.0 * gamma * R * Tc / (gamma - 1.0) * expansion)

# Optional conditioning: when enabled, the target curves are turned into a
# 3-channel image that GrainField3D can embed. Disabled by default.
use_conditioning = False
cond_img = None
if use_conditioning:
    def time_series_to_image(time_series, pressure_series, thrust_series, img_size=32):
        """Turn the target curves into a (3, img_size, img_size) image.

        Channels are outer products of (a) a 0..1 ramp with one entry per
        time sample, (b) the standardised pressure series and (c) the
        standardised thrust series, resized with bilinear interpolation.

        Args:
            time_series: only its length is used (for the ramp).
            pressure_series: 1-D tensor, standardised to zero mean / unit std.
            thrust_series: 1-D tensor, standardised to zero mean / unit std.
            img_size: side length of the output image.
        """
        # The 1e-8 guards against division by zero for a constant series.
        P_norm = (pressure_series - pressure_series.mean()) / (pressure_series.std() + 1e-8)
        F_norm = (thrust_series - thrust_series.mean()) / (thrust_series.std() + 1e-8)
        idx = torch.linspace(0, 1, len(time_series), device=device)
        Rt = torch.outer(idx, idx)
        Rp = torch.outer(P_norm, P_norm)
        Rf = torch.outer(F_norm, F_norm)
        # Add a batch dimension for interpolate, then drop it again.
        img = torch.stack([Rt, Rp, Rf], dim=0).unsqueeze(0)
        img = nnF.interpolate(img, size=(img_size, img_size), mode='bilinear', align_corners=False)
        return img.squeeze(0)
    cond_img = time_series_to_image(time_grid, Pc_target, F_target)

class GrainField3D(nn.Module):
    """MLP plus a learnable per-voxel bias that yields logits on the fixed N^3 grid.

    The bias is initialised from the module-level X/Y grids, so ``forward``
    expects one encoded row per grid point (as in ``coords_encoded``).
    """

    def __init__(self, hidden_size=192, num_layers=7, use_cond=False, cond_emb_dim=64):
        """Build the network.

        Args:
            hidden_size: width of every hidden layer.
            num_layers: number of Linear + Softplus pairs before the output layer.
            use_cond: when True, add a linear projection of the conditioning image.
            cond_emb_dim: size of that projection.
        """
        super().__init__()
        self.use_cond = use_cond
        in_features = base_features
        if use_cond:
            # The projection input is sized for a flattened 3 x 32 x 32 image.
            cond_size = 3*32*32 if cond_img is not None else 0
            self.cond_proj = nn.Linear(cond_size, cond_emb_dim)
            in_features += cond_emb_dim

        widths = [in_features] + [hidden_size] * num_layers
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Softplus(beta=10))
        layers.append(nn.Linear(hidden_size, 1))
        self.net = nn.Sequential(*layers)

        # Initial bias: positive where sqrt(X^2 + Y^2) < core_radius, negative
        # outside, with Gaussian noise added, limited to [-10, 10].
        core_radius = 0.055
        distance_from_axis = torch.sqrt(X**2 + Y**2)
        seed_logits = 10.0 * (core_radius - distance_from_axis) / eps
        seed_logits = seed_logits + 0.2 * torch.randn_like(seed_logits)
        seed_logits = seed_logits.clamp(-10.0, 10.0)
        self.logit_bias = nn.Parameter(seed_logits.flatten(), requires_grad=True)

    def forward(self, encoded_coords, cond=None):
        """Return per-point logits in [-20, 20] for the encoded coordinates."""
        features = encoded_coords
        if self.use_cond and cond is not None:
            # Give every point the same conditioning embedding.
            embedding = self.cond_proj(cond.flatten())
            tiled = embedding.unsqueeze(0).expand(features.shape[0], -1)
            features = torch.cat([features, tiled], dim=-1)
        raw = self.net(features).squeeze(-1) + self.logit_bias
        return raw.clamp(-20.0, 20.0)

    def phi(self, cond=None):
        """Evaluate the field on the whole grid: sigmoid of the logits, shape (N, N, N)."""
        occupancy = torch.sigmoid(self.forward(coords_encoded, cond))
        return occupancy.reshape(N, N, N).clamp(0.0, 1.0)

def double_well(phi):
    """Double-well potential phi^2 * (1 - phi)^2, which is zero at phi = 0 and phi = 1."""
    complement = 1.0 - phi
    return phi**2 * complement**2

def grad_norm_squared(phi):
    """Squared gradient magnitude of ``phi`` from central differences.

    Neighbours are taken with ``roll``, so the stencil wraps around at the
    array edges. A 1e-12 offset is added to the result.
    """
    def _central_difference(axis):
        return (phi.roll(-1, axis) - phi.roll(1, axis)) / (2 * dx)

    gx, gy, gz = (_central_difference(axis) for axis in (2, 1, 0))
    return gx**2 + gy**2 + gz**2 + 1e-12

def interface_measure(phi):
    """Sum over the grid of eps * |grad phi|^2 + double_well(phi) / eps, times dx**3."""
    gradient_term = eps * grad_norm_squared(phi)
    well_term = double_well(phi) / eps
    return torch.sum(gradient_term + well_term) * dx**3

def propellant_volume(phi):
    """Sum of ``phi`` (clamped to [0, 1]) multiplied by the cell volume dx**3."""
    occupancy = torch.clamp(phi, 0.0, 1.0)
    return occupancy.sum() * dx**3

def total_volume():
    """Volume of a cylinder with radius R_case and length L."""
    cross_section = math.pi * R_case**2
    return cross_section * L

def loading_fraction(phi):
    """Ratio of propellant_volume(phi) to total_volume()."""
    solid = propellant_volume(phi)
    return solid / total_volume()

def compactness_penalty(phi):
    """Squared gap between interface_measure(phi) and the surface of an equal-volume sphere.

    Returns a constant 1000.0 tensor (no link to ``phi``) when the volume is
    below 1e-8 or above 1.05 * total_volume().
    """
    vol = propellant_volume(phi)
    too_small = vol < 1e-8
    too_large = vol > total_volume() * 1.05
    if too_small or too_large:
        return torch.tensor(1000.0, device=device)
    surface = interface_measure(phi)
    # Surface area of the sphere whose volume equals `vol`.
    ideal = 4 * math.pi * (3*vol/(4*math.pi))**(2/3)
    return (surface - ideal).square()

def connectedness_penalty(phi):
    """Penalise propellant that is split into several connected components.

    Thresholds ``phi`` at 0.5, labels the connected components and returns
    (1 - largest_component / total_voxels)**2, or 1000.0 when nothing
    exceeds the threshold. The value is rebuilt with ``torch.tensor`` from
    NumPy numbers, so it has no autograd link to ``phi``.
    """
    solid_mask = (phi > 0.5).cpu().numpy()
    component_ids, component_count = label(solid_mask)
    if component_count == 0:
        return torch.tensor(1000.0, device=device)
    # bincount index 0 is the background label; drop it.
    component_sizes = np.bincount(component_ids.ravel())[1:]
    largest = component_sizes.max()
    total = solid_mask.sum()
    if total > 0:
        penalty = (1.0 - largest / total)**2
    else:
        penalty = 1000.0
    return torch.tensor(penalty, device=device)

def forward_motor_trackB3D(field, cond=None, store_history=True):
    """Integrate chamber pressure and the phase field over the NT time steps.

    Each step:
      1. takes Ab = interface_measure(phi) and r_dot = a * Pc**n,
      2. advances Pc with an explicit Euler step of
         dP/dt = (mdot_gen - mdot_noz) * R * Tc / Vg,
      3. advances phi with a semi-implicit spectral step (double-well term
         explicit, eps * k2 term implicit) and re-sharpens it with a sigmoid.

    The per-step histories stay torch tensors (stacked at the end), so a
    loss built on ``out["Pc"]`` back-propagates into ``field``. Converting
    them with ``.item()`` would cut that gradient path. Call under
    ``torch.no_grad()`` for analysis runs to get grad-free outputs.

    Args:
        field: object exposing ``phi(cond)`` that returns the (N, N, N) field.
        cond: optional conditioning tensor forwarded to ``field.phi``.
        store_history: when True, also keep a detached CPU copy of phi
            every 30 steps under ``"phi_keyframes"``.

    Returns:
        dict with 1-D tensors "Pc", "F", "Ab", "vol", "load" (one entry per
        step), "phi_final", and optionally "phi_keyframes".
    """
    phi = field.phi(cond).clone()
    Pc = torch.tensor(2.0e6, device=device)
    history = {"Pc": [], "F": [], "Ab": [], "vol": [], "load": []}
    phi_keyframes = []
    for step in range(NT):
        # --- chamber pressure (explicit Euler) ---
        Ab = interface_measure(phi)
        Pc_pos = torch.clamp(Pc, min=1.0, max=1e8)
        r_dot = a * Pc_pos**n
        mdot_gen = rho_p * Ab * r_dot
        mdot_noz = mdot_choked(Pc_pos)
        Vg = torch.clamp(torch.sum(1.0 - phi) * dx**3, min=1e-6, max=total_volume() * 1.1)
        dPdt = (mdot_gen - mdot_noz) * R * Tc / Vg
        Pc = Pc + dPdt * dt
        Pc = torch.clamp(Pc, min=0.0, max=1e8)

        # --- phase field (semi-implicit spectral step) ---
        W_prime = 2.0 * phi * (1.0 - phi) * (1.0 - 2.0 * phi)
        phi_fft = torch.fft.fftn(phi)
        W_prime_fft = torch.fft.fftn(W_prime)
        rhs_fft = phi_fft - (dt * r_dot / eps) * W_prime_fft
        denom = 1.0 + dt * r_dot * eps * k2
        denom = torch.clamp(denom, min=1e-6)
        phi_fft_new = rhs_fft / denom
        phi = torch.fft.ifftn(phi_fft_new).real
        phi = torch.clamp(phi, -0.2, 1.2)
        # NOTE(review): the + 0.5 sits inside the sigmoid argument, which moves
        # the crossover to phi = 0.475 rather than 0.5 -- confirm intended.
        phi = torch.sigmoid(20.0 * (phi - 0.5) + 0.5)

        ve = exhaust_velocity(Pc)
        F = mdot_noz * ve

        # Keep tensors (not Python floats) so autograd can reach `field`.
        history["Pc"].append(Pc)
        history["F"].append(F)
        history["Ab"].append(Ab)
        history["vol"].append(propellant_volume(phi))
        history["load"].append(loading_fraction(phi))
        if store_history and step % 30 == 0:
            phi_keyframes.append(phi.detach().cpu())
    out = {
        k: (torch.stack(v) if v else torch.empty(0, device=device))
        for k, v in history.items()
    }
    out["phi_final"] = phi
    if store_history:
        out["phi_keyframes"] = phi_keyframes
    return out

# Move the whole module to `device`: logit_bias is created from the grid
# tensors (already on `device`), but the Linear layers are created on the CPU
# and would otherwise mismatch the encoded inputs when CUDA is selected.
field = GrainField3D(hidden_size=192, num_layers=7, use_cond=use_conditioning, cond_emb_dim=64).to(device)
optimizer = optim.Adam(field.parameters(), lr=3e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=100, factor=0.5)

target_loading = 0.70
min_loading = 0.35
max_loading = 1.00
num_iters = 1800
cond = cond_img if use_conditioning else None

for it in range(num_iters):
    optimizer.zero_grad()
    out = forward_motor_trackB3D(field, cond, store_history=False)
    Pc = out["Pc"]
    phi = field.phi(cond)
    # Pc carries the autograd graph of the whole simulation, so this term
    # now actually trains the field.
    # NOTE(review): pressures are ~1e6, so this un-normalised MSE dwarfs the
    # other terms; consider scaling it -- left unchanged here.
    fit_loss = nnF.mse_loss(Pc, Pc_target)
    load_frac = loading_fraction(phi)
    load_loss = (target_loading - load_frac)**2 + \
                10.0 * nnF.relu(min_loading - load_frac)**2 + \
                50.0 * nnF.relu(load_frac - max_loading)**2
    compact_loss = compactness_penalty(phi)
    connect_loss = connectedness_penalty(phi)
    smooth_loss = torch.mean(double_well(phi))
    loss = 20.0 * fit_loss + 10.0 * load_loss + 0.5 * compact_loss + 3.0 * connect_loss + 0.01 * smooth_loss
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(field.parameters(), max_norm=0.5)
    # Back-propagating through NT steps can overflow; never let a NaN/inf
    # gradient reach the optimizer state.
    if math.isfinite(float(grad_norm)):
        optimizer.step()
        scheduler.step(fit_loss.item())
    else:
        print(f"Iter {it:4d} | non-finite gradient norm, optimizer step skipped")
    if it % 50 == 0:
        with torch.no_grad():
            rel_err = torch.mean(torch.abs(Pc - Pc_target) / (Pc_target + 1e-6))
        print(f"Iter {it:4d} | Loss {loss.item():.3e} | Fit {fit_loss.item():.4e} | RelErr {rel_err:.3f} | Load {load_frac:.3f}")

# Analysis run: no graph is needed, and grad-free tensors can be converted
# to NumPy by the plotting code below.
with torch.no_grad():
    final_out = forward_motor_trackB3D(field, cond, store_history=True)
Pc_final = final_out["Pc"]
phi_keyframes = final_out["phi_keyframes"]
phi_final = final_out["phi_final"]

def plot_3d_slices(phi_tensor, title, fname_prefix):
    """Save a three-panel figure of the mid-index slices of a 3-D field.

    Args:
        phi_tensor: (N, N, N) tensor; may live on any device and may track
            gradients (it is detached before the NumPy conversion, which
            would otherwise raise for tensors that require grad).
        title: figure title.
        fname_prefix: the file is written to OUT_DIR as
            "<fname_prefix>_slices.png".
    """
    phi_np = phi_tensor.detach().cpu().numpy()
    fig, axs = plt.subplots(1, 3, figsize=(15,5))
    mid = N // 2
    # One panel per axis: fix the mid index along axis 0, 1 and 2 in turn.
    panels = (
        (phi_np[mid, :, :], "Axial slice (mid-z)"),
        (phi_np[:, mid, :], "Radial slice (mid-y)"),
        (phi_np[:, :, mid], "Radial slice (mid-x)"),
    )
    for ax, (image, panel_title) in zip(axs, panels):
        ax.imshow(image, cmap=cm.gray, origin="lower")
        ax.set_title(panel_title)
    fig.suptitle(title)
    plt.tight_layout()
    plt.savefig(os.path.join(OUT_DIR, f"{fname_prefix}_slices.png"), dpi=200)

# Slice plots of the optimised field at t = 0 and after the simulated run.
# Tensors are detached first: they may still be attached to the autograd graph.
plot_3d_slices(field.phi(cond).detach(), "Initial Grain", "initial")
plot_3d_slices(phi_final.detach(), "Final Grain", "final")

if MESH_EXPORT:
    phi_np = phi_final.detach().cpu().numpy()
    # marching_cubes raises when the iso-level lies outside the data range.
    if phi_np.min() < 0.5 < phi_np.max():
        verts, faces, _, _ = marching_cubes(phi_np, level=0.5)
        mesh = trimesh.Trimesh(verts, faces)
        mesh.export(os.path.join(OUT_DIR, "final_grain.obj"))
    else:
        print("Skipping mesh export: phi never crosses the 0.5 iso-level.")

# Animation of the mid-index slice along axis 0 over the stored keyframes
# (one keyframe every 30 steps).
fig, ax = plt.subplots(figsize=(6,6))
mid = N // 2
# Fix the colour scale to [0, 1]: set_data() does not rescale, so without it
# every frame would reuse the limits inferred from the first one.
im = ax.imshow(phi_keyframes[0][mid, :, :].numpy(), cmap=cm.gray, origin="lower", vmin=0.0, vmax=1.0)
def update(i):
    """Show keyframe ``i`` and update the title with its approximate time."""
    im.set_data(phi_keyframes[i][mid, :, :].numpy())
    ax.set_title(f"Axial Mid-Slice t ≈ {i*30*dt:.2f}s")
    return [im]
anim = FuncAnimation(fig, update, frames=len(phi_keyframes), interval=200)
anim.save(os.path.join(OUT_DIR, "burnback_axial.gif"), writer="pillow")

def _as_numpy(tensor):
    """Detach, move to the CPU and convert to NumPy.

    Tensor.numpy() raises for CUDA tensors and for tensors that require
    grad; this cell may run on the GPU, so normalise both before converting.
    """
    return tensor.detach().cpu().numpy()

# Pressure (target vs achieved), thrust, Ab and loading fraction over time.
t_np = _as_numpy(time_grid)
fig, axs = plt.subplots(4, 1, figsize=(10, 12))
axs[0].plot(t_np, _as_numpy(Pc_target)/1e6, "k--", label="Target")
axs[0].plot(t_np, _as_numpy(Pc_final)/1e6, "r-", label="Achieved")
axs[0].set_ylabel("Pressure [MPa]")
axs[0].legend()  # the labels above are only shown once a legend is requested
axs[1].plot(t_np, _as_numpy(final_out["F"])/1e3, "b-")
axs[1].set_ylabel("Thrust [kN]")
axs[2].plot(t_np, _as_numpy(final_out["Ab"]), "g-")
axs[2].set_ylabel("Ab [m²]")
axs[3].plot(t_np, _as_numpy(final_out["load"]), "m-")
axs[3].set_ylabel("Loading")
axs[3].set_xlabel("Time [s]")
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "performance_curves.png"), dpi=200)