from pywarpx import picmi
import os

# Switch the working directory to the folder containing this script, so that
# relative output paths (such as the generated input file) resolve next to the
# script rather than in whatever directory it was launched from.
script_dir = os.path.abspath(__file__)
script_dir = os.path.dirname(script_dir)
os.chdir(script_dir)
print(os.getcwd())





# Number of time steps to advance the simulation.
max_steps = 100

# Physical constants provided by PICMI (PICMI uses SI units).
c = picmi.constants.c      # speed of light
q_e = picmi.constants.q_e  # elementary charge

# Cell counts per direction: base resolution (32, 32, 16) scaled by 4.
nx, ny, nz = 32 * 4, 32 * 4, 16 * 4

# Physical extent of the simulation box (SI, i.e. metres): x and y are
# symmetric about 0, z runs from 0 to 16 mm.
xmin, xmax = -16e-3, 16e-3
ymin, ymax = -16e-3, 16e-3
zmin, zmax = 0, 16e-3

# Domain-decomposition controls passed to WarpX/AMReX via the grid below.
max_grid_size, blocking_factor = 256, 16

# Embedded boundary (EB).
# Potential assigned to the EB (a WarpX parser expression; here a constant).
potential_function = "-10000"

# Implicit function describing the EB geometry. Assuming numpy-style
# heaviside(arg, value_at_zero), this evaluates to +1 inside a disk of radius
# 10 mm spanning 1 mm <= z <= 2 mm (edges included) and to -1 elsewhere.
# NOTE(review): WarpX's convention is believed to be "positive = inside the EB
# (covered), negative = simulated region, zero = EB surface" -- confirm.
implicit_function = r"-1 + 2 * ( heaviside( z - 1e-3, 1 ) - heaviside( z - 2e-3, 0 ) ) * heaviside( (10e-3)**2 - x**2 - y**2, 1 )"
# Alternative geometries (small boxes), kept for reference:
#implicit_function = r"-1 + 2 * ( heaviside( z - 1e-3, 0.5 ) - heaviside( z - 2e-3, 0.5 ) ) * ( heaviside( x - 1e-3, 0.5 ) - heaviside( x - 2e-3, 0.5 ) ) * ( heaviside( y - 1e-3, 0.5 ) - heaviside( y - 2e-3, 0.5 ) )"
#implicit_function =  r"-1 + 2 * ( heaviside( z - 2e-3, 1 ) - heaviside( z - 3e-3, 0 ) ) * ( heaviside( x + 2e-3, 1 ) - heaviside( x - 2e-3, 0 ) ) * ( heaviside( y + 2e-3, 1 ) - heaviside( y - 2e-3, 0 ) )"

embedded_boundary = picmi.EmbeddedBoundary(
    potential=potential_function,
    implicit_function=implicit_function,
    # cover_multiple_cuts=1,  # disabled option, kept for reference
)



# Mesh refinement: a single level-1 patch refined by 2 in each direction,
# covering x, y in [-10 mm, 10 mm] and z in [0.5 mm, 5 mm]. That span encloses
# both the EB disk (z 1-2 mm) and the electron bunch (z 3-4 mm).
# Entry layout (PICMI refined_regions): [level, lower corner, upper corner,
# refinement factor].
refinement = [
    [1, [-10e-3, -10e-3, 0.5e-3], [10e-3, 10e-3, 5e-3], [2, 2, 2]],
]

# 3D Cartesian grid. Each boundary list is built separately, so no list object
# is shared between arguments.
grid = picmi.Cartesian3DGrid(
    number_of_cells=[nx, ny, nz],
    lower_bound=[xmin, ymin, zmin],
    upper_bound=[xmax, ymax, zmax],
    # Field boundary conditions on all six faces.
    lower_boundary_conditions=["dirichlet"] * 3,
    upper_boundary_conditions=["dirichlet"] * 3,
    # Particles crossing any domain face are absorbed (removed).
    lower_boundary_conditions_particles=["absorbing"] * 3,
    upper_boundary_conditions_particles=["absorbing"] * 3,
    warpx_max_grid_size=max_grid_size,
    warpx_blocking_factor=blocking_factor,
    refined_regions=refinement,
)

# Disabled domain-boundary potential settings, kept for reference. As written
# they would set class attributes only after `grid` has already been built.
# picmi.Cartesian3DGrid.warpx_potential_lo_z = 0
# picmi.Cartesian3DGrid.warpx_potential_hi_z = 16000

# Electron bunch: uniform number density (SI, m^-3) inside a 2 mm x 2 mm x 1 mm
# box centred on the z axis at 3 mm <= z <= 4 mm. No drift velocity is given.
density = 1e8
lower_bound, upper_bound = [-1e-3, -1e-3, 3e-3], [1e-3, 1e-3, 4e-3]

uniform_bunch_distribution = picmi.UniformDistribution(
    density=density, lower_bound=lower_bound, upper_bound=upper_bound
)

beam = picmi.Species(
    name="beam",
    particle_type="electron",
    initial_distribution=uniform_bunch_distribution,
)

# Field solver: electromagnetic Yee (FDTD) scheme at a CFL number of 0.95,
# with E-divergence cleaning disabled.
# NOTE(review): per the WarpX documentation, the EB potential (the
# `potential_function` passed to EmbeddedBoundary above) is only applied by
# the electrostatic solvers. It is therefore presumably ignored by this
# electromagnetic solver -- confirm this is intended. The electrostatic
# alternative is kept, commented out, below.
solver = picmi.ElectromagneticSolver(
    grid=grid,
    method="Yee",
    cfl=0.95,
    divE_cleaning=0,
)
#solver = picmi.ElectrostaticSolver( grid=grid, method='Multigrid', warpx_cfl = 0.95, warpx_dt_update_interval = 10)

# Output: particle and field data every 10 steps. Both diagnostics share the
# name "diag1" and the same period.
# NOTE(review): pywarpx appears to combine same-named field/particle
# diagnostics into one WarpX diagnostic -- confirm before changing either.
diag_field_list = ["E", "rho"]  # electric field components and charge density

particle_diag = picmi.ParticleDiagnostic(name="diag1", period=10)
field_diag = picmi.FieldDiagnostic(
    name="diag1", grid=grid, period=10, data_list=diag_field_list
)

# Simulation object tying together the solver, run length and WarpX options.
sim = picmi.Simulation(
    solver=solver,
    max_steps=max_steps,
    verbose=1,
    particle_shape="cubic",
    # WarpX-specific settings.
    warpx_embedded_boundary=embedded_boundary,
    warpx_use_filter=1,
    warpx_serialize_initial_conditions=1,
    warpx_do_dynamic_scheduling=0,
)

# Load the bunch on a regular lattice of 8 x 8 x 8 macroparticles per cell,
# without initialising its own space-charge field (initialize_self_field=0).
# The variant with the self-field enabled is kept commented out.
layout = picmi.GriddedLayout(n_macroparticle_per_cell=[8, 8, 8], grid=grid)
sim.add_species(beam, layout=layout, initialize_self_field=0)
#sim.add_species(beam, layout=layout, initialize_self_field=1)

# Register both diagnostics with the simulation, particle diagnostic first.
for diagnostic in (particle_diag, field_diag):
    sim.add_diagnostic(diagnostic)

# Dump the WarpX inputs file generated from this PICMI setup.
sim.write_input_file(file_name="input_file.txt")

# Set up the WarpX inputs and instance, then advance max_steps time steps.
sim.initialize_inputs()
sim.initialize_warpx()
sim.step(max_steps)

