## Domain_walls

In [None]:
from ase import Atoms
from ase.build import cut, rotate, stack
from ase.io import read
from ase.neighborlist import NeighborList
from ase.spacegroup import get_spacegroup
from ase.visualize import view
from numpy.linalg import norm
import io
import numpy as np

class DomainWallSystem:
    """Build a ferroelectric domain-wall supercell from a single-domain structure.

    Two slabs (domain A and domain B) are cut from the parent structure along
    per-domain cut vectors ``a``/``b``/``c``, rotated into a common frame,
    optionally cleaned of overlapping atoms, and stacked along ``stack_axis``.

    The cut vectors come either from :meth:`predefined_systems`, keyed by the
    parent's international space-group symbol and ``domain_wall_tag``, or --
    when ``domain_wall_tag == "customed"`` -- from keyword arguments.

    Args:
        atoms: Parent (single-domain) ASE ``Atoms`` object.

    Keyword Args:
        domain_wall_tag: Wall-type key, e.g. "71", "90", "109", "180",
            "120_HH_TT", "120_HT" or "customed" (default "180").
        domain_size: Multiplier inserted into the predefined cut vectors
            (default 3). Not used by "customed" walls.
        domainA_a, domainA_b, domainA_c, domainB_a, domainB_b, domainB_c:
            Cut vectors for each domain; required when tag is "customed".
        stack_axis: Cell axis (0/1/2) to stack the two slabs along; required
            when tag is "customed".
        cutoff: Atoms closer than this distance to an already kept atom are
            removed as duplicates (default 0.0, i.e. nothing is removed).

    Raises:
        ValueError: If a "customed" parameter is missing, or if the
            (space group, tag) combination has no predefined definition.
    """

    _CUSTOM_PARAMS = ('domainA_a', 'domainA_b', 'domainA_c',
                      'domainB_a', 'domainB_b', 'domainB_c', 'stack_axis')

    def __init__(self, atoms, **kwargs):
        self.atoms = atoms
        self.international_symbol, _ = self.identify_system_symmetry()
        self.domain_wall_tag = kwargs.get('domain_wall_tag', "180")
        # Set in both branches so predefined_systems() is always callable.
        self.domain_size = kwargs.get('domain_size', 3)

        if self.domain_wall_tag == "customed":
            for param in self._CUSTOM_PARAMS:
                if kwargs.get(param) is None:
                    raise ValueError(f"Parameter '{param}' must be provided.")

            self.domainA_a = kwargs['domainA_a']
            self.domainA_b = kwargs['domainA_b']
            self.domainA_c = kwargs['domainA_c']
            self.domainB_a = kwargs['domainB_a']
            self.domainB_b = kwargs['domainB_b']
            self.domainB_c = kwargs['domainB_c']
            self.stack_axis = kwargs['stack_axis']  # 0/1/2
        else:
            predefined_systems = self.predefined_systems()
            tag = str(self.domain_wall_tag)
            if self.international_symbol not in predefined_systems or tag not in predefined_systems[self.international_symbol]:
                raise ValueError(f"The combination of international symbol '{self.international_symbol}' and domain wall tag '{self.domain_wall_tag}' is not defined in predefined systems.")

            domain_a, domain_b = predefined_systems[self.international_symbol][tag]
            # Look values up by key instead of relying on dict value order.
            self.domainA_a, self.domainA_b, self.domainA_c = domain_a['a'], domain_a['b'], domain_a['c']
            self.stack_axis = domain_a['stack_axis']
            self.domainB_a, self.domainB_b, self.domainB_c = domain_b['a'], domain_b['b'], domain_b['c']

        self.cutoff = kwargs.get('cutoff', 0.0)

    def identify_system_symmetry(self):
        """Return ``(international_symbol, spacegroup_number)`` of ``self.atoms``.

        Uses ASE's ``get_spacegroup`` with a tight symmetry tolerance (1e-5).
        """
        symmetry = get_spacegroup(self.atoms, symprec=1e-5)
        return symmetry.symbol, symmetry.no

    def predefined_systems(self):
        """Return the table of built-in domain-wall definitions.

        Structure: ``{space_group_symbol: {tag: [domain_A, domain_B]}}``.
        Each domain is a dict with cut vectors ``'a'``, ``'b'``, ``'c'`` and
        the ``'stack_axis'`` index. Entries scale with ``self.domain_size``.
        """
        predefined_domain_wall_dict = {
            "R 3 m": {
                "71": [
                    {'a': [self.domain_size, self.domain_size, 0], 'b': [0, 0, 1], 'c': [1, -1, 0], 'stack_axis': 0},
                    {'a': [self.domain_size, self.domain_size, 0], 'b': [0, 0, -1], 'c': [-1, 1, 0], 'stack_axis': 0}
                ],
                "109": [
                    {'a': [self.domain_size, 0, 0], 'b': [0, 1, 0], 'c': [0, 0, 1], 'stack_axis': 0},
                    {'a': [self.domain_size, 0, 0], 'b': [0, -1, 0], 'c': [0, 0, -1], 'stack_axis': 0}
                ],
                "180": [
                    {'a': [1, 1, 0], 'b': [0, 0, 1], 'c': [self.domain_size, -self.domain_size, 0], 'stack_axis': 2},
                    {'a': [-1, -1, 0], 'b': [0, 0, -1], 'c': [self.domain_size, -self.domain_size, 0], 'stack_axis': 2}
                ]
            },
            "R 3 c": {
                "71": [
                    {'a': [1, -1, 0], 'b': [0, 0, 1], 'c': [self.domain_size, self.domain_size, 0], 'stack_axis': 2},
                    {'a': [-1, 1, 0], 'b': [0, 0, -1], 'c': [self.domain_size, self.domain_size, 0], 'stack_axis': 2}
                ],
                "109": [
                    {'a': [0.0, 1, 0], 'b': [1, 0, 0], 'c': [0, 0, self.domain_size], 'stack_axis': 2},
                    {'a': [0.0, -1, 0], 'b': [-1, 0, 0], 'c': [0, 0, self.domain_size], 'stack_axis': 2}
                ],
                "180": [
                    {'a': [1, 1, 0], 'b': [0, 0, -1], 'c': [-self.domain_size, self.domain_size, 0], 'stack_axis': 2},
                    {'a': [-1, -1, 0], 'b': [0, 0, 1], 'c': [-self.domain_size, self.domain_size, 0], 'stack_axis': 2}
                ]
            },
            "P 4 m m": {
                "90": [
                    {'a': [0, 1, 0], 'b': [-1, 0, 1], 'c': [self.domain_size, 0, self.domain_size], 'stack_axis': 2},
                    {'a': [0, -1, 0], 'b': [1, 0, -1], 'c': [self.domain_size, 0, self.domain_size], 'stack_axis': 2}
                ],
                # NOTE(review): 1.01 instead of 1 looks deliberate; intent not visible here.
                "180": [
                    {'a': [1.01, 0, 0], 'b': [0, self.domain_size, 0], 'c': [0, 0, 1.01], 'stack_axis': 1},
                    {'a': [-1.01, 0, 0], 'b': [0, self.domain_size, 0], 'c': [0, 0, -1.01], 'stack_axis': 1}
                ]
            },
            "P m c 21": {
                "90": [
                    {'a': [0, 1*self.domain_size, 0], 'b': [-1, 0, 0], 'c': [0, 0, 1], 'stack_axis': 0},
                    {'a': [-1*self.domain_size, 0, 0], 'b': [0, -1, 0], 'c': [0, 0, 1], 'stack_axis': 0}
                ],
                "120_HH_TT": [
                    {'a': [-1, -2, 0], 'b': [2.0*self.domain_size, -1.0*self.domain_size, 0], 'c': [0, 0, 1], 'stack_axis': 1},
                    {'a': [-1, 2, 0], 'b': [-2.0*self.domain_size, -1.0*self.domain_size, 0], 'c': [0, 0, 1], 'stack_axis': 1}
                ],
                "120_HT": [
                    {'a': [-1*self.domain_size, -2*self.domain_size, 0], 'b': [2.0, -1.0, 0], 'c': [0, 0, 1], 'stack_axis': 0},
                    {'a': [-1*self.domain_size, 2*self.domain_size, 0], 'b': [-2.0, -1.0, 0], 'c': [0, 0, 1], 'stack_axis': 0}
                ],
                "180": [
                    {'a': [1, 0, 0], 'b': [0, self.domain_size, 0], 'c': [0, 0, 1], 'stack_axis': 1},
                    {'a': [-1, 0, 0], 'b': [0, -self.domain_size, 0], 'c': [0, 0, 1], 'stack_axis': 1}
                ]
            },
        }
        return predefined_domain_wall_dict

    def get_domain_wall_dict(self):
        """Return the active wall definition in the same shape as predefined_systems().

        Built from the attributes resolved in ``__init__``, so it works for both
        predefined and "customed" walls.
        """
        return {
            self.international_symbol: {
                str(self.domain_wall_tag): [
                    {'a': self.domainA_a, 'b': self.domainA_b, 'c': self.domainA_c, 'stack_axis': self.stack_axis},
                    {'a': self.domainB_a, 'b': self.domainB_b, 'c': self.domainB_c, 'stack_axis': self.stack_axis}
                ]
            }
        }

    def cut_and_rotate(self, atoms, a, b, c):
        """Cut a slab from ``atoms`` along ``a``/``b``/``c`` and rotate it.

        ``ase.build.rotate`` turns the slab so that its first cell vector points
        along (0, 1, 0) and its second lies toward (1, 0, 0). This gives both
        domains a common orientation before stacking.
        """
        slab = cut(atoms, a=a, b=b, c=c)
        rotate(slab, slab.cell[0], (0, 1, 0), slab.cell[1], (1, 0, 0))
        return slab

    def remove_close_atoms(self, atoms, cutoff=0.0):
        """Return a copy of ``atoms`` with near-duplicate atoms removed.

        Atoms are scanned in index order. An atom that is still kept deletes
        every higher-index atom closer than ``cutoff``, including periodic
        images. Atoms that were already removed do not trigger further
        removals, so chains of close atoms are not over-pruned.

        Args:
            atoms: ASE ``Atoms`` object (not modified).
            cutoff: Distance threshold; ``<= 0`` disables removal.

        Returns:
            A new ``Atoms`` object.
        """
        if cutoff <= 0 or len(atoms) == 0:
            return atoms.copy()

        nl = NeighborList([cutoff / 2.0] * len(atoms), self_interaction=False, bothways=True)
        nl.update(atoms)
        positions = atoms.get_positions()
        cell = np.asarray(atoms.get_cell())

        removed = set()
        for i in range(len(atoms)):
            if i in removed:
                continue
            indices, offsets = nl.get_neighbors(i)
            for j, offset in zip(indices, offsets):
                j = int(j)
                if j <= i or j in removed:
                    continue
                # The neighbour may be a periodic image: apply its cell offset.
                distance = norm(positions[j] + offset @ cell - positions[i])
                if distance < cutoff:
                    removed.add(j)

        keep = np.array([k not in removed for k in range(len(atoms))], dtype=bool)
        return atoms[keep]

    def calculate_lattice_strain(self, slab1, slab2):
        """Print and return the percent length mismatch of slab1 vs slab2 per axis.

        Strain along each axis is ``(|v1| - |v2|) / |v2| * 100``.

        Returns:
            Tuple ``(strain_a, strain_b, strain_c)`` in percent.
        """
        strains = tuple(
            (norm(v1) - norm(v2)) / norm(v2) * 100
            for v1, v2 in zip(slab1.cell, slab2.cell)
        )
        strain_a, strain_b, strain_c = strains
        print("Lattice strain for DW:")
        print(f"strain along a (%): {strain_a:.2f}")
        print(f"strain along b (%): {strain_b:.2f}")
        print(f"strain along c (%): {strain_c:.2f}")
        return strains

    def build_domain_wall(self):
        """Cut, clean and stack the two domains; return the domain-wall supercell.

        Each slab is cleaned with ``self.cutoff`` before stacking, and the
        stacked result is cleaned again to drop atoms duplicated at the
        interface. ``stack`` is called with ``maxstrain=None`` so mismatched
        slabs are not rejected.
        """
        tag = str(self.domain_wall_tag)
        domains = self.get_domain_wall_dict()[self.international_symbol][tag]
        stack_axis = domains[0]['stack_axis']

        # Keep the cleaned slabs; the original loop discarded the result.
        slabs = [
            self.remove_close_atoms(
                self.cut_and_rotate(self.atoms, a=d['a'], b=d['b'], c=d['c']),
                self.cutoff,
            )
            for d in domains
        ]

        stacked_slab = stack(*slabs, axis=stack_axis, maxstrain=None)
        return self.remove_close_atoms(stacked_slab, self.cutoff)


In [None]:
# Example: build a "180" domain wall from the BaTiO3 R3m cell shipped with the
# examples and open it in ASE's viewer. The path is relative to the notebook's
# working directory. DomainWallSystem raises ValueError if the detected space
# group / tag pair is not in its predefined table.
atoms = read("../examples/Domain_walls/R3m/BaTiO3_R3m.vasp")
# domain_size defaults to 3 and cutoff to 0.0 (no overlapping-atom removal).
system = DomainWallSystem(atoms, domain_wall_tag="180")
new_structure = system.build_domain_wall()
view(new_structure)  # launches ASE's interactive viewer

### Q-API 接口

In [None]:
def domain_wall(**kwargs):
    """Build a domain-wall structure from flat form values and return it as CIF text.

    For ``domain_wall_tag == "customed"`` the form supplies each cut vector as
    three scalar fields (``domainA_a_1`` .. ``domainB_c_3``). These are
    collapsed into the list-valued ``domainA_a`` .. ``domainB_c`` arguments
    that ``DomainWallSystem`` expects. All other kwargs are forwarded
    unchanged. The ``atoms`` kwarg is passed straight to ``DomainWallSystem``.

    NOTE(review): the form declares ``output_format='cif'`` for ``atoms``.
    Confirm the platform converts it to an ASE ``Atoms`` object before this call.

    Returns:
        The generated structure serialised as a CIF string.

    Raises:
        KeyError: If a "customed" wall is missing one of the scalar fields.
    """
    # `write` is not imported at the top of the notebook (only `read` is).
    from ase.io import write

    # A missing tag falls through to DomainWallSystem's own default ("180").
    if kwargs.get('domain_wall_tag') == "customed":
        for domain in ('domainA', 'domainB'):
            for axis in ('a', 'b', 'c'):
                kwargs[f'{domain}_{axis}'] = [kwargs.pop(f'{domain}_{axis}_{i}') for i in range(1, 4)]

    system = DomainWallSystem(**kwargs)
    new_structure = system.build_domain_wall()

    # ASE's CIF writer emits bytes, hence BytesIO + decode.
    cif_output = io.BytesIO()
    write(cif_output, new_structure, format='cif')
    return cif_output.getvalue().decode('utf-8')

In [None]:
@guide_register_func(ModelTag('构建铁电材料畴壁结构', '钙钛矿', note='需要指定极化相基体，并基于输入结构的空间群和预设的极化方向构建铁电材料畴壁结构').identifier)
def Domain_wall(data):
    """Q-API entry point for the ferroelectric (perovskite) domain-wall builder.

    ``data['mode']`` selects the action:

    * ``'init'``: return the rendered form widgets. These are the parent
      structure, wall type, 18 integer cut-vector components, stack axis,
      domain size and overlap cutoff.
    * ``'generate'``: build the structure from ``data['value']`` and return
      ``{'file_content': <cif text>, 'file_format': 'cif'}``.

    Any other mode yields ``None``.
    """
    mode = data.get('mode')

    if mode == 'init':
        widgets = [
            Description(name='title', note='基于输入结构的空间群和预设的极化方向构建铁电材料畴壁结构'),
            StructureFromList(name='atoms', note='选择极化相基体结构，注意预设的极化方向为：R3c(1-11)、R3m(001)、P4mm(001)、Pmc21(user-defined)、customed(user-defined)',
                              id='atoms', structure_type='crystal', output_format='cif'),
            SingleFromList(name='domain_wall_tag', note='根据提示选择畴壁类型或自定义畴壁类型customed',
                           id='domain_wall_tag', default_value='180',
                           list_value=[{'label': '适用于R 3 m、R 3 c空间群：71', 'value': '71'},
                                       {'label': '适用于P 4 m m、P m c 21空间群：90', 'value': '90'},
                                       {'label': '适用于R 3 m、R 3 c空间群：109', 'value': '109'},
                                       {'label': '适用于P m c 21空间群：120_HH_TT', 'value': '120_HH_TT'},
                                       {'label': '适用于P m c 21空间群：120_HT', 'value': '120_HT'},
                                       {'label': '适用于R 3 m、R 3 c、P 4 m m、P m c 21空间群：180', 'value': '180'},
                                       {'label': '适用于任意空间群：customed', 'value': 'customed'}], is_required=1),
        ]

        # Default cut vectors for the "customed" inputs. Each component becomes
        # one integer field named <domain>_<axis>_<1..3>, in this order.
        default_vectors = {
            'domainA': {'a': (1, 1, 0), 'b': (0, 0, 1), 'c': (1, -1, 0)},
            'domainB': {'a': (-1, -1, 0), 'b': (0, 0, -1), 'c': (1, -1, 0)},
        }
        for domain, axes in default_vectors.items():
            for axis, components in axes.items():
                for position, default in enumerate(components, start=1):
                    field = f'{domain}_{axis}_{position}'
                    widgets.append(SingleInput(name=field, note=field, id=field, input_type='int',
                                               default_value=default, min=-10, max=10, is_required=0))

        widgets.append(SingleInput(name='stack_axis', note='堆叠轴', id='stack_axis', input_type='int', default_value=0, min=0, max=2, is_required=0))
        widgets.append(SingleInput(name='domain_size', note='畴壁尺寸', id='domain_size', input_type='float', default_value=3, min=1, max=10, is_required=0))
        widgets.append(SingleInput(name='cutoff', note='根据截断删除重叠原子', id='cutoff', input_type='float', default_value=0, min=0, max=2, is_required=1))

        # All widgets are constructed first, then rendered in the same order.
        return [widget() for widget in widgets]

    if mode == 'generate':
        cif_string = domain_wall(**data.get('value'))
        return {'file_content': cif_string, 'file_format': 'cif'}

    return None
