diff --git a/CMakeLists.txt b/CMakeLists.txt index d12e7f4..b21051b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,7 @@ file(GLOB_RECURSE EXT_FILES ext/*) file(GLOB_RECURSE PROJ_FILES cmake/*) set(SRC_FILES + src/md_atomic_infer.c src/md_csv.c src/md_csv.h src/md_cube.c diff --git a/benchmark/bench_gro.c b/benchmark/bench_gro.c index 8a77b3b..e098135 100644 --- a/benchmark/bench_gro.c +++ b/benchmark/bench_gro.c @@ -37,7 +37,7 @@ UBENCH_EX(gro, postprocess) { md_molecule_t mol = {0}; md_gro_molecule_api()->init_from_file(&mol, path, 0, alloc); - md_util_molecule_postprocess(&mol, alloc, MD_UTIL_POSTPROCESS_ELEMENT_BIT | MD_UTIL_POSTPROCESS_RESIDUE_BIT); + md_util_molecule_postprocess(&mol, alloc, MD_UTIL_POSTPROCESS_RESIDUE_BIT); // Element inference now happens during parsing size_t reset_pos = md_linear_allocator_get_pos(alloc); UBENCH_DO_BENCHMARK() { diff --git a/src/core/md_atomic.h b/src/core/md_atomic.h new file mode 100644 index 0000000..19c8a6e --- /dev/null +++ b/src/core/md_atomic.h @@ -0,0 +1,166 @@ +#pragma once + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// Atomic number type: 0 = unknown, 1-118 = element atomic numbers +typedef uint8_t md_atomic_number_t; + +// Legacy alias for compatibility +typedef md_atomic_number_t md_element_t; + +// Forward declaration of molecule type +struct md_molecule_t; + +// Atomic number constants for all elements (Z values) +enum { + MD_Z_X = 0, // Unknown + MD_Z_H = 1, // Hydrogen + MD_Z_He = 2, // Helium + MD_Z_Li = 3, // Lithium + MD_Z_Be = 4, // Beryllium + MD_Z_B = 5, // Boron + MD_Z_C = 6, // Carbon + MD_Z_N = 7, // Nitrogen + MD_Z_O = 8, // Oxygen + MD_Z_F = 9, // Fluorine + MD_Z_Ne = 10, // Neon + MD_Z_Na = 11, // Sodium + MD_Z_Mg = 12, // Magnesium + MD_Z_Al = 13, // Aluminium + MD_Z_Si = 14, // Silicon + MD_Z_P = 15, // Phosphorus + MD_Z_S = 16, // Sulfur + MD_Z_Cl = 17, // Chlorine + MD_Z_Ar = 18, // Argon + MD_Z_K = 19, // Potassium + MD_Z_Ca = 20, // Calcium 
+ MD_Z_Sc = 21, // Scandium + MD_Z_Ti = 22, // Titanium + MD_Z_V = 23, // Vanadium + MD_Z_Cr = 24, // Chromium + MD_Z_Mn = 25, // Manganese + MD_Z_Fe = 26, // Iron + MD_Z_Co = 27, // Cobalt + MD_Z_Ni = 28, // Nickel + MD_Z_Cu = 29, // Copper + MD_Z_Zn = 30, // Zinc + MD_Z_Ga = 31, // Gallium + MD_Z_Ge = 32, // Germanium + MD_Z_As = 33, // Arsenic + MD_Z_Se = 34, // Selenium + MD_Z_Br = 35, // Bromine + MD_Z_Kr = 36, // Krypton + MD_Z_Rb = 37, // Rubidium + MD_Z_Sr = 38, // Strontium + MD_Z_Y = 39, // Yttrium + MD_Z_Zr = 40, // Zirconium + MD_Z_Nb = 41, // Niobium + MD_Z_Mo = 42, // Molybdenum + MD_Z_Tc = 43, // Technetium + MD_Z_Ru = 44, // Ruthenium + MD_Z_Rh = 45, // Rhodium + MD_Z_Pd = 46, // Palladium + MD_Z_Ag = 47, // Silver + MD_Z_Cd = 48, // Cadmium + MD_Z_In = 49, // Indium + MD_Z_Sn = 50, // Tin + MD_Z_Sb = 51, // Antimony + MD_Z_Te = 52, // Tellurium + MD_Z_I = 53, // Iodine + MD_Z_Xe = 54, // Xenon + MD_Z_Cs = 55, // Caesium + MD_Z_Ba = 56, // Barium + MD_Z_La = 57, // Lanthanum + MD_Z_Ce = 58, // Cerium + MD_Z_Pr = 59, // Praseodymium + MD_Z_Nd = 60, // Neodymium + MD_Z_Pm = 61, // Promethium + MD_Z_Sm = 62, // Samarium + MD_Z_Eu = 63, // Europium + MD_Z_Gd = 64, // Gadolinium + MD_Z_Tb = 65, // Terbium + MD_Z_Dy = 66, // Dysprosium + MD_Z_Ho = 67, // Holmium + MD_Z_Er = 68, // Erbium + MD_Z_Tm = 69, // Thulium + MD_Z_Yb = 70, // Ytterbium + MD_Z_Lu = 71, // Lutetium + MD_Z_Hf = 72, // Hafnium + MD_Z_Ta = 73, // Tantalum + MD_Z_W = 74, // Tungsten + MD_Z_Re = 75, // Rhenium + MD_Z_Os = 76, // Osmium + MD_Z_Ir = 77, // Iridium + MD_Z_Pt = 78, // Platinum + MD_Z_Au = 79, // Gold + MD_Z_Hg = 80, // Mercury + MD_Z_Tl = 81, // Thallium + MD_Z_Pb = 82, // Lead + MD_Z_Bi = 83, // Bismuth + MD_Z_Po = 84, // Polonium + MD_Z_At = 85, // Astatine + MD_Z_Rn = 86, // Radon + MD_Z_Fr = 87, // Francium + MD_Z_Ra = 88, // Radium + MD_Z_Ac = 89, // Actinium + MD_Z_Th = 90, // Thorium + MD_Z_Pa = 91, // Protactinium + MD_Z_U = 92, // Uranium + MD_Z_Np = 93, // Neptunium 
+ MD_Z_Pu = 94, // Plutonium + MD_Z_Am = 95, // Americium + MD_Z_Cm = 96, // Curium + MD_Z_Bk = 97, // Berkelium + MD_Z_Cf = 98, // Californium + MD_Z_Es = 99, // Einsteinium + MD_Z_Fm = 100, // Fermium + MD_Z_Md = 101, // Mendelevium + MD_Z_No = 102, // Nobelium + MD_Z_Lr = 103, // Lawrencium + MD_Z_Rf = 104, // Rutherfordium + MD_Z_Db = 105, // Dubnium + MD_Z_Sg = 106, // Seaborgium + MD_Z_Bh = 107, // Bohrium + MD_Z_Hs = 108, // Hassium + MD_Z_Mt = 109, // Meitnerium + MD_Z_Ds = 110, // Darmstadtium + MD_Z_Rg = 111, // Roentgenium + MD_Z_Cn = 112, // Copernicium + MD_Z_Nh = 113, // Nihonium + MD_Z_Fl = 114, // Flerovium + MD_Z_Mc = 115, // Moscovium + MD_Z_Lv = 116, // Livermorium + MD_Z_Ts = 117, // Tennessine + MD_Z_Og = 118, // Oganesson +}; + +// New preferred API names + +// Element symbol and name lookup functions +md_atomic_number_t md_atomic_number_from_symbol(str_t sym); +md_atomic_number_t md_atomic_number_from_symbol_icase(str_t sym); +str_t md_symbol_from_atomic_number(md_atomic_number_t z); +str_t md_name_from_atomic_number(md_atomic_number_t z); + +// Element property functions +float md_atomic_mass(md_atomic_number_t z); +float md_vdw_radius(md_atomic_number_t z); +float md_covalent_radius(md_atomic_number_t z); +int md_max_valence(md_atomic_number_t z); +uint32_t md_cpk_color(md_atomic_number_t z); + +// Per-atom inference from labels (atom name + residue) +md_atomic_number_t md_atom_infer_atomic_number(str_t atom_name, str_t res_name); + +// Batch form wired to molecule structure +bool md_atoms_infer_atomic_numbers(md_atomic_number_t out[], size_t n, const struct md_molecule_t* mol); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/src/core/md_str.c b/src/core/md_str.c index 7f7342c..5d5395c 100644 --- a/src/core/md_str.c +++ b/src/core/md_str.c @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -42,12 +42,6 @@ bool str_eq_n_ignore_case(const str_t str_a, const str_t str_b, size_t n) { return true; } -bool 
str_eq_cstr(str_t str, const char* cstr) { - if (!str.ptr || !str.len || !cstr) return false; - if (str.ptr[0] != cstr[0]) return false; - return (strncmp(str.ptr, cstr, str.len) == 0) && cstr[str.len] == '\0'; -} - // Compare str and cstr only up to n characters bool str_eq_cstr_n(str_t str, const char* cstr, size_t n) { if (!n) return false; diff --git a/src/core/md_str.h b/src/core/md_str.h index 232a738..68c11ef 100644 --- a/src/core/md_str.h +++ b/src/core/md_str.h @@ -1,4 +1,4 @@ -#pragma once +#pragma once #include #include @@ -111,7 +111,12 @@ bool str_eq_n_ignore_case(const str_t str_a, const str_t str_b, size_t n); // Lexicographical comparison int str_cmp_lex(str_t a, str_t b); -bool str_eq_cstr(str_t str, const char* cstr); +static inline bool str_eq_cstr(str_t str, const char* cstr) { + if (!str.ptr || !str.len || !cstr) return false; + if (str.ptr[0] != cstr[0]) return false; + return (strncmp(str.ptr, cstr, str.len) == 0) && cstr[str.len] == '\0'; +} + bool str_eq_cstr_n(str_t str, const char* cstr, size_t n); bool str_eq_cstr_ignore_case(str_t str, const char* cstr); bool str_eq_cstr_n_ignore_case(str_t str, const char* cstr, size_t n); diff --git a/src/md_atomic_infer.c b/src/md_atomic_infer.c new file mode 100644 index 0000000..a28b5b8 --- /dev/null +++ b/src/md_atomic_infer.c @@ -0,0 +1,108 @@ +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +// Core atomic number functions using existing md_util tables +md_atomic_number_t md_atomic_number_from_symbol(str_t sym) { + return md_util_element_lookup(sym); +} + +md_atomic_number_t md_atomic_number_from_symbol_icase(str_t sym) { + return md_util_element_lookup_ignore_case(sym); +} + +str_t md_symbol_from_atomic_number(md_atomic_number_t z) { + return md_util_element_symbol(z); +} + +str_t md_name_from_atomic_number(md_atomic_number_t z) { + return md_util_element_name(z); +} + +float md_atomic_mass(md_atomic_number_t z) { + return 
md_util_element_atomic_mass(z); +} + +float md_vdw_radius(md_atomic_number_t z) { + return md_util_element_vdw_radius(z); +} + +float md_covalent_radius(md_atomic_number_t z) { + return md_util_element_covalent_radius(z); +} + +int md_max_valence(md_atomic_number_t z) { + return md_util_element_max_valence(z); +} + +uint32_t md_cpk_color(md_atomic_number_t z) { + return md_util_element_cpk_color(z); +} + +// Inference functions +md_atomic_number_t md_atom_infer_atomic_number(str_t atom_name, str_t res_name) { + + // Special case: if atom name is empty but residue name is an element (ion case) + if (atom_name.len == 0 && res_name.len > 0) { + md_atomic_number_t res_element = md_atomic_number_from_symbol_icase(res_name); + if (res_element != MD_Z_X) { + return res_element; + } + return MD_Z_X; + } + + if (atom_name.len == 0) return MD_Z_X; + + // First try residue+atom combination + if (res_name.len > 0) { + // Special case: if water, amino acid or nucleotide: Match against first character only + if (md_util_resname_water(res_name) || md_util_resname_amino_acid(res_name) || md_util_resname_nucleotide(res_name)) { + str_t first_char = str_substr(atom_name, 0, 1); + md_atomic_number_t z = md_atomic_number_from_symbol_icase(first_char); + if (z != MD_Z_X) return z; + } + + // If residue name itself is an element (ion case) + md_atomic_number_t res_element = md_atomic_number_from_symbol_icase(res_name); + if (res_element != MD_Z_X) { + // If atom name is empty or equals residue, return that element + if (atom_name.len == 0 || str_eq_ignore_case(atom_name, res_name)) { + return res_element; + } + } + } + + // Try two-letter element heuristic (e.g., CL12 => Cl, BR1 => Br) + if (atom_name.len >= 2) { + str_t two_letter_str = str_substr(atom_name, 0, 2); + md_atomic_number_t two_z = md_atomic_number_from_symbol_icase(two_letter_str); + if (two_z != MD_Z_X) { + return two_z; + } + } + + // Final fallback: first-letter element mapping + str_t first_letter_str = 
str_substr(atom_name, 0, 1); + return md_atomic_number_from_symbol_icase(first_letter_str); +} + +bool md_atoms_infer_atomic_numbers(md_atomic_number_t out[], size_t n, const struct md_molecule_t* mol) { + if (!out || !mol || n == 0) return false; + + size_t count = MIN(n, mol->atom.count); + for (size_t i = 0; i < count; ++i) { + str_t atom_name = LBL_TO_STR(mol->atom.type[i]); + str_t res_name = mol->atom.resname ? LBL_TO_STR(mol->atom.resname[i]) : (str_t){0}; + out[i] = md_atom_infer_atomic_number(atom_name, res_name); + } + + return true; +} \ No newline at end of file diff --git a/src/md_gro.c b/src/md_gro.c index 33c6114..d266c3f 100644 --- a/src/md_gro.c +++ b/src/md_gro.c @@ -169,6 +169,8 @@ bool md_gro_molecule_init(struct md_molecule_t* mol, const md_gro_data_t* data, mol->atom.y = md_array_create(float, capacity, alloc); mol->atom.z = md_array_create(float, capacity, alloc); mol->atom.type = md_array_create(md_label_t, capacity, alloc); + mol->atom.type_idx = md_array_create(md_atom_type_idx_t, capacity, alloc); + mol->atom.element = md_array_create(md_element_t, capacity, alloc); mol->atom.resid = md_array_create(md_residue_id_t, capacity, alloc); mol->atom.resname = md_array_create(md_label_t, capacity, alloc); @@ -198,6 +200,8 @@ bool md_gro_molecule_init(struct md_molecule_t* mol, const md_gro_data_t* data, mol->atom.y[i] = y; mol->atom.z[i] = z; mol->atom.type[i] = make_label(atom_name); + mol->atom.type_idx[i] = -1; // Initialize to -1, will be set after populating atom type table + mol->atom.element[i] = 0; // Initialize to unknown, will be filled below mol->atom.resid[i] = res_id; mol->atom.resname[i] = make_label(res_name); mol->atom.flags[i] = flags; @@ -214,6 +218,21 @@ bool md_gro_molecule_init(struct md_molecule_t* mol, const md_gro_data_t* data, mol->unit_cell = md_util_unit_cell_from_matrix(box); + // Use hash-backed inference to assign elements (GRO typically lacks explicit element information) + md_util_element_guess(mol->atom.element, 
mol->atom.count, mol); + + // Now populate the atom type table and assign type indices + for (size_t i = 0; i < mol->atom.count; ++i) { + md_label_t type_name = mol->atom.type[i]; + md_element_t element = mol->atom.element[i]; + float mass = md_util_element_atomic_mass(element); + float radius = md_util_element_vdw_radius(element); + + // Find or add the atom type + md_atom_type_idx_t type_idx = md_atom_type_find_or_add(&mol->atom_type, type_name, element, mass, radius, alloc); + mol->atom.type_idx[i] = type_idx; + } + return true; } diff --git a/src/md_lammps.c b/src/md_lammps.c index 514a35f..b39d4d0 100644 --- a/src/md_lammps.c +++ b/src/md_lammps.c @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -437,7 +437,7 @@ static bool parse_dihedrals(md_lammps_dihedral_t out_dihedrals[], size_t dihedra return true; } -static size_t parse_masses(float* mass_type_table, size_t mass_type_capacity, size_t expected_count, md_buffered_reader_t* reader) { +static size_t parse_masses(double* mass_type_table, size_t mass_type_capacity, size_t expected_count, md_buffered_reader_t* reader) { str_t tok[4]; str_t line; size_t extracted_count = 0; @@ -448,13 +448,13 @@ static size_t parse_masses(float* mass_type_table, size_t mass_type_capacity, si } const size_t num_tok = extract_tokens(tok, ARRAY_SIZE(tok), &line); if (num_tok < 2) { - MD_LOG_ERROR("Failed to parse mass line, expected 2 tokens, got %i", (int)num_tok); + MD_LOG_ERROR("Failed to parse mass line, expected 2 tokens, got %zu", num_tok); return 0; } int type = (int)parse_int(tok[0]); - float mass = (float)parse_float(tok[1]); - if (type >= (int)mass_type_capacity) { - MD_LOG_ERROR("Invalid atom type: %i", (int)num_tok); + double mass = parse_float(tok[1]); + if (type < 0 || (size_t)type >= mass_type_capacity) { + MD_LOG_ERROR("Invalid atom type index in Masses: %d (capacity %zu)", type, mass_type_capacity); return 0; } mass_type_table[type] = mass; @@ -527,7 +527,7 @@ static bool 
md_lammps_data_parse(md_lammps_data_t* data, md_buffered_reader_t* r MEMSET(data, 0, sizeof(md_lammps_data_t)); - float mass_table[256] = {0}; + double atom_type_mass_table[512] = {0}; str_copy_to_char_buf(data->title, sizeof(data->title), str_trim(line)); @@ -547,11 +547,9 @@ static bool md_lammps_data_parse(md_lammps_data_t* data, md_buffered_reader_t* r // Sort atoms by id qsort(data->atoms, data->num_atoms, sizeof(md_lammps_atom_t), compare_atom); - if (mass_table) { - for (size_t i = 0; i < data->num_atoms; ++i) { - int32_t type = data->atoms[i].type; - data->atoms[i].mass = type < (int)ARRAY_SIZE(mass_table) ? mass_table[type] : 0.0f; - } + for (size_t i = 0; i < data->num_atoms; ++i) { + int32_t type = data->atoms[i].type; + data->atoms[i].mass = type < (int)ARRAY_SIZE(atom_type_mass_table) ? atom_type_mass_table[type] : 0.0f; } } else if (num_tok > 0 && str_eq(tok[0], STR_LIT("Bonds"))) { if (!data->num_bonds) { @@ -589,7 +587,7 @@ static bool md_lammps_data_parse(md_lammps_data_t* data, md_buffered_reader_t* r return false; } md_buffered_reader_skip_line(reader); - if (parse_masses(mass_table, ARRAY_SIZE(mass_table), data->num_atom_types, reader) != data->num_atom_types) { + if (parse_masses(atom_type_mass_table, ARRAY_SIZE(atom_type_mass_table), data->num_atom_types, reader) != data->num_atom_types) { MD_LOG_ERROR("Number of masses in table did not match the number of atom types"); return false; } @@ -702,11 +700,12 @@ bool md_lammps_molecule_init(md_molecule_t* mol, const md_lammps_data_t* data, m MEMSET(mol, 0, sizeof(md_molecule_t)); const size_t capacity = ROUND_UP(data->num_atoms, 16); - md_array_resize(mol->atom.type, capacity, alloc); - md_array_resize(mol->atom.x, capacity, alloc); - md_array_resize(mol->atom.y, capacity, alloc); - md_array_resize(mol->atom.z, capacity, alloc); - md_array_resize(mol->atom.mass, capacity, alloc); + md_array_resize(mol->atom.type, capacity, alloc); + md_array_resize(mol->atom.type_idx, capacity, alloc); + 
md_array_resize(mol->atom.x, capacity, alloc); + md_array_resize(mol->atom.y, capacity, alloc); + md_array_resize(mol->atom.z, capacity, alloc); + md_array_resize(mol->atom.mass, capacity, alloc); bool has_resid = false; if (data->num_atoms > 0 && data->atoms[0].resid != -1) { @@ -714,8 +713,32 @@ bool md_lammps_molecule_init(md_molecule_t* mol, const md_lammps_data_t* data, m has_resid = true; } + // Prepare for mass→element mapping first + md_array_resize(mol->atom.element, capacity, alloc); + bool mass_to_element_success = false; + + // Build atom type table from LAMMPS types - first pass + for (int lammps_type = 1; lammps_type <= (int)data->num_atom_types; ++lammps_type) { + md_label_t type_name = {0}; + type_name.len = (uint8_t)snprintf(type_name.buf, sizeof(type_name.buf), "%i", lammps_type); + + md_element_t element = 0; + float mass = 0.0f; + float radius = 0.0f; + + // Add to atom type table (will be updated after mass-to-element mapping) + md_atom_type_find_or_add(&mol->atom_type, type_name, element, mass, radius, alloc); + } + for (size_t i = 0; i < data->num_atoms; ++i) { - mol->atom.type[i].len = (uint8_t)snprintf(mol->atom.type[i].buf, sizeof(mol->atom.type[i].buf), "%i", data->atoms[i].type); + int lammps_type = data->atoms[i].type; + + // Set legacy per-atom type name for backward compatibility + mol->atom.type[i].len = (uint8_t)snprintf(mol->atom.type[i].buf, sizeof(mol->atom.type[i].buf), "%i", lammps_type); + + // Set atom type index (LAMMPS types are 1-indexed, array is 0-indexed) + mol->atom.type_idx[i] = lammps_type - 1; + mol->atom.x[i] = data->atoms[i].x - data->cell.xlo; mol->atom.y[i] = data->atoms[i].y - data->cell.ylo; mol->atom.z[i] = data->atoms[i].z - data->cell.zlo; @@ -725,10 +748,23 @@ bool md_lammps_molecule_init(md_molecule_t* mol, const md_lammps_data_t* data, m } } - //Set elements - md_array_resize(mol->atom.element, capacity, alloc); - if (!md_util_element_from_mass(mol->atom.element, mol->atom.mass, data->num_atoms)) { - 
MD_LOG_ERROR("One or more masses are missing matching element"); + // Try mass-to-element mapping using per-atom masses + mass_to_element_success = md_util_lammps_element_from_mass(mol->atom.element, mol->atom.mass, data->num_atoms); + + if (mass_to_element_success) { + // Update atom type table with masses and elements + for (size_t i = 0; i < data->num_atoms; ++i) { + int lammps_type = data->atoms[i].type; + md_atom_type_idx_t type_idx = lammps_type - 1; + if (type_idx >= 0 && (size_t)type_idx < mol->atom_type.count) { + // Update mass and element in atom type table + mol->atom_type.mass[type_idx] = mol->atom.mass[i]; + mol->atom_type.element[type_idx] = mol->atom.element[i]; + } + } + } else { + // CG/reduced-units detected or mapping failed, leave elements as 0 + MD_LOG_DEBUG("LAMMPS data appears to be coarse-grained or reduced-units, elements left unassigned"); } mol->atom.count = data->num_atoms; @@ -814,12 +850,12 @@ md_molecule_loader_i* md_lammps_molecule_api(void) { //Reads data that is useful later when we want to parse a frame from the trajectory bool lammps_get_header(struct md_trajectory_o* inst, md_trajectory_header_t* header) { - lammps_trajectory_t* dataPtr = (lammps_trajectory_t*)inst; - ASSERT(dataPtr); - ASSERT(dataPtr->magic == MD_LAMMPS_TRAJ_MAGIC); + lammps_trajectory_t* traj = (lammps_trajectory_t*)inst; + ASSERT(traj); + ASSERT(traj->magic == MD_LAMMPS_TRAJ_MAGIC); ASSERT(header); - *header = dataPtr->header; + *header = traj->header; return true; } diff --git a/src/md_lammps.h b/src/md_lammps.h index 9fe640c..9324a68 100644 --- a/src/md_lammps.h +++ b/src/md_lammps.h @@ -1,4 +1,4 @@ -#pragma once +#pragma once #include #include diff --git a/src/md_mmcif.c b/src/md_mmcif.c index 29df116..9f9896c 100644 --- a/src/md_mmcif.c +++ b/src/md_mmcif.c @@ -148,8 +148,13 @@ static bool mmcif_parse_atom_site(md_atom_data_t* atom, md_buffered_reader_t* re int32_t entity_id = -1; str_t sym = tok[table[ATOM_SITE_TYPE_SYMBOL]]; - uint32_t* entry = 
md_hashmap_get(&elem_map, md_hash64(sym.ptr, sym.len, 0)); - md_element_t elem = entry ? (md_element_t)*entry : 0; + md_element_t elem = 0; + + // Prefer _atom_site.type_symbol if not '.' + if (sym.len > 0 && sym.ptr[0] != '.') { + uint32_t* entry = md_hashmap_get(&elem_map, md_hash64(sym.ptr, sym.len, 0)); + elem = entry ? (md_element_t)*entry : 0; + } md_label_t type = make_label(tok[table[ATOM_SITE_LABEL_ATOM_ID]]); @@ -204,6 +209,7 @@ static bool mmcif_parse_atom_site(md_atom_data_t* atom, md_buffered_reader_t* re md_array_push(atom->resid, res_id, alloc); md_array_push(atom->resname, resname, alloc); md_array_push(atom->chainid, chain_id, alloc); + md_array_push(atom->type_idx, -1, alloc); // Initialize to -1, will be set after populating atom type table num_atoms += 1; next: @@ -217,6 +223,7 @@ static bool mmcif_parse_atom_site(md_atom_data_t* atom, md_buffered_reader_t* re size_t capacity = ROUND_UP(num_atoms, 16); md_array_ensure(atom->element, capacity, alloc); md_array_ensure(atom->type, capacity, alloc); + md_array_ensure(atom->type_idx, capacity, alloc); md_array_ensure(atom->x, capacity, alloc); md_array_ensure(atom->y, capacity, alloc); md_array_ensure(atom->z, capacity, alloc); @@ -226,6 +233,8 @@ static bool mmcif_parse_atom_site(md_atom_data_t* atom, md_buffered_reader_t* re md_array_ensure(atom->chainid, capacity, alloc); } + + atom->count = num_atoms; done: md_temp_set_pos_back(temp_pos); @@ -299,16 +308,64 @@ static bool mmcif_parse(md_molecule_t* mol, md_buffered_reader_t* reader, md_all return false; } atom_site_found = true; + // Sub-parser already consumed its lines, continue without skipping + continue; } else if (str_eq_cstr_n(line, "_cell.", 6)) { if (!mmcif_parse_cell(&mol->unit_cell, reader)) { MD_LOG_ERROR("Failed to parse _cell"); return false; } + // Sub-parser already consumed its lines, continue without skipping + continue; } } md_buffered_reader_skip_line(reader); } + // Fill in missing elements using hash-backed inference if 
atoms were found + if (atom_site_found && mol->atom.count > 0) { + for (size_t i = 0; i < mol->atom.count; ++i) { + if (mol->atom.element[i] == 0) { + // Use hash-backed inference for missing elements + str_t atom_label = LBL_TO_STR(mol->atom.type[i]); + + // Try direct element lookup first + md_element_t elem = md_util_element_lookup_ignore_case(atom_label); + + // If that fails, use the hash-backed inference from md_util_element_guess + if (elem == 0) { + // Create a temporary molecule structure for the single atom to use element_guess + md_molecule_t temp_mol = {0}; + temp_mol.atom.count = 1; + temp_mol.atom.type = &mol->atom.type[i]; + temp_mol.atom.resname = &mol->atom.resname[i]; + temp_mol.atom.flags = mol->atom.flags ? &mol->atom.flags[i] : NULL; + + md_element_t temp_element = 0; + if (md_util_element_guess(&temp_element, 1, &temp_mol)) { + elem = temp_element; + } + } + + mol->atom.element[i] = elem; + } + } + } + + // Populate atom type table and assign type indices if atoms were found + if (atom_site_found && mol->atom.count > 0) { + for (size_t i = 0; i < mol->atom.count; ++i) { + md_label_t type_name = mol->atom.type[i]; + md_element_t element = mol->atom.element[i]; + float mass = md_util_element_atomic_mass(element); + float radius = md_util_element_vdw_radius(element); + + // Find or add the atom type + md_atom_type_idx_t type_idx = md_atom_type_find_or_add(&mol->atom_type, type_name, element, mass, radius, alloc); + mol->atom.type_idx[i] = type_idx; + } + } + return atom_site_found; } diff --git a/src/md_molecule.c b/src/md_molecule.c index 8d4bd03..18fcd4f 100644 --- a/src/md_molecule.c +++ b/src/md_molecule.c @@ -15,14 +15,24 @@ void md_molecule_free(md_molecule_t* mol, struct md_allocator_i* alloc) { if (mol->atom.x) md_array_free(mol->atom.x, alloc); if (mol->atom.y) md_array_free(mol->atom.y, alloc); if (mol->atom.z) md_array_free(mol->atom.z, alloc); + if (mol->atom.type_idx) md_array_free(mol->atom.type_idx, alloc); if (mol->atom.radius) 
md_array_free(mol->atom.radius, alloc); if (mol->atom.mass) md_array_free(mol->atom.mass, alloc); if (mol->atom.element) md_array_free(mol->atom.element, alloc); + if (mol->atom.type) md_array_free(mol->atom.type, alloc); if (mol->atom.resid) md_array_free(mol->atom.resid, alloc); if (mol->atom.resname) md_array_free(mol->atom.resname, alloc); if (mol->atom.chainid) md_array_free(mol->atom.chainid, alloc); + if (mol->atom.res_idx) md_array_free(mol->atom.res_idx, alloc); + if (mol->atom.chain_idx) md_array_free(mol->atom.chain_idx, alloc); if (mol->atom.flags) md_array_free(mol->atom.flags, alloc); + // Atom Type + if (mol->atom_type.name) md_array_free(mol->atom_type.name, alloc); + if (mol->atom_type.element) md_array_free(mol->atom_type.element, alloc); + if (mol->atom_type.mass) md_array_free(mol->atom_type.mass, alloc); + if (mol->atom_type.radius) md_array_free(mol->atom_type.radius, alloc); + // Residue if (mol->residue.name) md_array_free(mol->residue.name, alloc); if (mol->residue.id) md_array_free(mol->residue.id, alloc); @@ -35,18 +45,26 @@ void md_molecule_free(md_molecule_t* mol, struct md_allocator_i* alloc) { // Backbone if (mol->protein_backbone.range.offset) md_array_free(mol->protein_backbone.range.offset, alloc); + if (mol->protein_backbone.range.chain_idx) md_array_free(mol->protein_backbone.range.chain_idx, alloc); if (mol->protein_backbone.atoms) md_array_free(mol->protein_backbone.atoms, alloc); if (mol->protein_backbone.angle) md_array_free(mol->protein_backbone.angle, alloc); if (mol->protein_backbone.secondary_structure) md_array_free(mol->protein_backbone.secondary_structure, alloc); if (mol->protein_backbone.ramachandran_type) md_array_free(mol->protein_backbone.ramachandran_type, alloc); if (mol->protein_backbone.residue_idx) md_array_free(mol->protein_backbone.residue_idx, alloc); + // Nucleic Backbone + if (mol->nucleic_backbone.range.offset) md_array_free(mol->nucleic_backbone.range.offset, alloc); + if 
(mol->nucleic_backbone.range.chain_idx) md_array_free(mol->nucleic_backbone.range.chain_idx, alloc); + if (mol->nucleic_backbone.atoms) md_array_free(mol->nucleic_backbone.atoms, alloc); + if (mol->nucleic_backbone.residue_idx) md_array_free(mol->nucleic_backbone.residue_idx, alloc); + // Bonds if (mol->bond.pairs) md_array_free(mol->bond.pairs, alloc); if (mol->bond.order) md_array_free(mol->bond.order, alloc); if (mol->bond.conn.atom_idx) md_array_free(mol->bond.conn.atom_idx, alloc); if (mol->bond.conn.bond_idx) md_array_free(mol->bond.conn.bond_idx, alloc); + if (mol->bond.conn.offset) md_array_free(mol->bond.conn.offset, alloc); md_index_data_free(&mol->structure); md_index_data_free(&mol->ring); @@ -81,6 +99,7 @@ void md_molecule_copy(md_molecule_t* dst, const md_molecule_t* src, struct md_al ARRAY_PUSH(atom, x); ARRAY_PUSH(atom, y); ARRAY_PUSH(atom, z); + ARRAY_PUSH(atom, type_idx); ARRAY_PUSH(atom, radius); ARRAY_PUSH(atom, mass); ARRAY_PUSH(atom, element); @@ -89,6 +108,13 @@ void md_molecule_copy(md_molecule_t* dst, const md_molecule_t* src, struct md_al ARRAY_PUSH(atom, resid); ARRAY_PUSH(atom, resname); ARRAY_PUSH(atom, chainid); + ARRAY_PUSH(atom, res_idx); + ARRAY_PUSH(atom, chain_idx); + + ARRAY_PUSH(atom_type, name); + ARRAY_PUSH(atom_type, element); + ARRAY_PUSH(atom_type, mass); + ARRAY_PUSH(atom_type, radius); ARRAY_PUSH(protein_backbone, atoms); ARRAY_PUSH(protein_backbone, angle); @@ -97,22 +123,41 @@ void md_molecule_copy(md_molecule_t* dst, const md_molecule_t* src, struct md_al ARRAY_PUSH(protein_backbone, residue_idx); md_array_push_array(dst->protein_backbone.range.offset, src->protein_backbone.range.offset, src->protein_backbone.range.count, alloc); + md_array_push_array(dst->protein_backbone.range.chain_idx, src->protein_backbone.range.chain_idx, src->protein_backbone.range.count, alloc); + + ARRAY_PUSH(nucleic_backbone, atoms); + ARRAY_PUSH(nucleic_backbone, residue_idx); + + md_array_push_array(dst->nucleic_backbone.range.offset, 
src->nucleic_backbone.range.offset, src->nucleic_backbone.range.count, alloc); + md_array_push_array(dst->nucleic_backbone.range.chain_idx, src->nucleic_backbone.range.chain_idx, src->nucleic_backbone.range.count, alloc); ARRAY_PUSH(chain, id); ARRAY_PUSH(chain, res_range); ARRAY_PUSH(chain, atom_range); md_array_push_array(dst->bond.pairs, src->bond.pairs, src->bond.count, alloc); + md_array_push_array(dst->bond.order, src->bond.order, src->bond.count, alloc); + md_array_push_array(dst->bond.conn.atom_idx, src->bond.conn.atom_idx, src->bond.conn.count, alloc); + md_array_push_array(dst->bond.conn.bond_idx, src->bond.conn.bond_idx, src->bond.conn.count, alloc); + md_array_push_array(dst->bond.conn.offset, src->bond.conn.offset, src->bond.conn.offset_count, alloc); ARRAY_PUSH(residue, name); ARRAY_PUSH(residue, id); ARRAY_PUSH(residue, atom_offset); + ARRAY_PUSH(residue, flags); dst->atom.count = src->atom.count; + dst->atom_type.count = src->atom_type.count; dst->protein_backbone.count = src->protein_backbone.count; dst->protein_backbone.range.count = src->protein_backbone.range.count; + dst->nucleic_backbone.count = src->nucleic_backbone.count; + dst->nucleic_backbone.range.count = src->nucleic_backbone.range.count; dst->chain.count = src->chain.count; dst->residue.count = src->residue.count; + dst->bond.count = src->bond.count; + dst->bond.conn.count = src->bond.conn.count; + dst->bond.conn.offset_count = src->bond.conn.offset_count; + dst->unit_cell = src->unit_cell; } #undef ARRAY_PUSH diff --git a/src/md_molecule.h b/src/md_molecule.h index 4e6b1cb..d2b2791 100644 --- a/src/md_molecule.h +++ b/src/md_molecule.h @@ -22,6 +22,9 @@ typedef struct md_atom_data_t { float* y; float* z; + // Atom Type Index (NEW: references into atom_type table) + md_atom_type_idx_t* type_idx; + // Atom Type Specific (@TODO: Compress this into a smaller set) float* radius; float* mass; @@ -125,6 +128,7 @@ typedef struct md_bond_iter_t { typedef struct md_molecule_t { md_unit_cell_t 
unit_cell; md_atom_data_t atom; + md_atom_type_data_t atom_type; md_residue_data_t residue; md_chain_data_t chain; md_protein_backbone_data_t protein_backbone; @@ -154,6 +158,85 @@ static inline vec3_t md_atom_coord(md_atom_data_t atom_data, size_t atom_idx) { return vec3_set(atom_data.x[atom_idx], atom_data.y[atom_idx], atom_data.z[atom_idx]); } +// Atom type table helper functions +static inline md_atom_type_idx_t md_atom_type_find_or_add(md_atom_type_data_t* atom_type, md_label_t name, md_element_t element, float mass, float radius, struct md_allocator_i* alloc) { + ASSERT(atom_type); + ASSERT(alloc); + + // First try to find existing atom type + for (size_t i = 0; i < atom_type->count; ++i) { + if (MEMCMP(&atom_type->name[i], &name, sizeof(md_label_t)) == 0 && + atom_type->element[i] == element && + atom_type->mass[i] == mass && + atom_type->radius[i] == radius) { + return (md_atom_type_idx_t)i; + } + } + + // Add new atom type + md_array_push(atom_type->name, name, alloc); + md_array_push(atom_type->element, element, alloc); + md_array_push(atom_type->mass, mass, alloc); + md_array_push(atom_type->radius, radius, alloc); + atom_type->count++; + + return (md_atom_type_idx_t)(atom_type->count - 1); +} + +static inline md_element_t md_atom_get_element(const md_molecule_t* mol, size_t atom_idx) { + ASSERT(mol); + ASSERT(atom_idx < mol->atom.count); + + // Try atom type table first if type_idx is available + if (mol->atom.type_idx && mol->atom.type_idx[atom_idx] >= 0 && + (size_t)mol->atom.type_idx[atom_idx] < mol->atom_type.count) { + return mol->atom_type.element[mol->atom.type_idx[atom_idx]]; + } + + // Fallback to per-atom element if available + if (mol->atom.element) { + return mol->atom.element[atom_idx]; + } + + return 0; +} + +static inline float md_atom_get_mass(const md_molecule_t* mol, size_t atom_idx) { + ASSERT(mol); + ASSERT(atom_idx < mol->atom.count); + + // Try atom type table first if type_idx is available + if (mol->atom.type_idx && 
mol->atom.type_idx[atom_idx] >= 0 && + (size_t)mol->atom.type_idx[atom_idx] < mol->atom_type.count) { + return mol->atom_type.mass[mol->atom.type_idx[atom_idx]]; + } + + // Fallback to per-atom mass if available + if (mol->atom.mass) { + return mol->atom.mass[atom_idx]; + } + + return 0.0f; +} + +static inline float md_atom_get_radius(const md_molecule_t* mol, size_t atom_idx) { + ASSERT(mol); + ASSERT(atom_idx < mol->atom.count); + + // Try atom type table first if type_idx is available + if (mol->atom.type_idx && mol->atom.type_idx[atom_idx] >= 0 && + (size_t)mol->atom.type_idx[atom_idx] < mol->atom_type.count) { + return mol->atom_type.radius[mol->atom.type_idx[atom_idx]]; + } + + // Fallback to per-atom radius if available + if (mol->atom.radius) { + return mol->atom.radius[atom_idx]; + } + + return 0.0f; +} + static inline md_range_t md_residue_atom_range(md_residue_data_t res, size_t res_idx) { md_range_t range = {0}; if (res.atom_offset && res_idx < res.count) { @@ -206,6 +289,37 @@ static inline size_t md_chain_atom_count(md_chain_data_t chain, size_t chain_idx return count; } +// Convenience functions to extract atom properties into arrays +static inline void md_atom_extract_radii(float out_radii[], size_t n, const md_molecule_t* mol) { + ASSERT(out_radii); + ASSERT(mol); + ASSERT(n <= mol->atom.count); + + for (size_t i = 0; i < n; ++i) { + out_radii[i] = md_atom_get_radius(mol, i); + } +} + +static inline void md_atom_extract_masses(float out_masses[], size_t n, const md_molecule_t* mol) { + ASSERT(out_masses); + ASSERT(mol); + ASSERT(n <= mol->atom.count); + + for (size_t i = 0; i < n; ++i) { + out_masses[i] = md_atom_get_mass(mol, i); + } +} + +static inline void md_atom_extract_elements(md_element_t out_elements[], size_t n, const md_molecule_t* mol) { + ASSERT(out_elements); + ASSERT(mol); + ASSERT(n <= mol->atom.count); + + for (size_t i = 0; i < n; ++i) { + out_elements[i] = md_atom_get_element(mol, i); + } +} + static inline md_bond_iter_t 
md_bond_iter(const md_bond_data_t* bond_data, size_t atom_idx) { md_bond_iter_t it = {0}; if (bond_data && bond_data->conn.offset && atom_idx < bond_data->conn.offset_count) { diff --git a/src/md_pdb.c b/src/md_pdb.c index f7aea57..8ddc9cb 100644 --- a/src/md_pdb.c +++ b/src/md_pdb.c @@ -479,6 +479,7 @@ bool md_pdb_molecule_init(md_molecule_t* mol, const md_pdb_data_t* data, md_pdb_ md_array_ensure(mol->atom.z, capacity, alloc); md_array_ensure(mol->atom.element, capacity, alloc); md_array_ensure(mol->atom.type, capacity, alloc); + md_array_ensure(mol->atom.type_idx, capacity, alloc); md_array_ensure(mol->atom.resid, capacity, alloc); md_array_ensure(mol->atom.resname, capacity, alloc); md_array_ensure(mol->atom.flags, capacity, alloc); @@ -550,6 +551,9 @@ bool md_pdb_molecule_init(md_molecule_t* mol, const md_pdb_data_t* data, md_pdb_ md_array_push_no_grow(mol->atom.flags, flags); md_array_push_no_grow(mol->atom.resname, make_label(res_name)); md_array_push_no_grow(mol->atom.resid, res_id); + + // Set type_idx to -1 initially, will be updated after populating atom type table + md_array_push_no_grow(mol->atom.type_idx, -1); } if (chain_ids) { @@ -666,6 +670,47 @@ bool md_pdb_molecule_init(md_molecule_t* mol, const md_pdb_data_t* data, md_pdb_ ASSERT(md_array_size(mol->instance.label) == mol->instance.count); ASSERT(md_array_size(mol->instance.atom_range) == mol->instance.count); + // Fill in missing elements using hash-backed inference + for (size_t i = 0; i < mol->atom.count; ++i) { + if (mol->atom.element[i] == 0) { + // Use hash-backed inference for missing elements + str_t atom_name = LBL_TO_STR(mol->atom.type[i]); + + // Try direct element lookup first + md_element_t elem = md_util_element_lookup_ignore_case(atom_name); + + // If that fails, use the hash-backed inference from md_util_element_guess + // This handles cases like CA (carbon alpha vs calcium) using residue context + if (elem == 0) { + // Create a temporary molecule structure for the single atom to 
use element_guess + md_molecule_t temp_mol = {0}; + temp_mol.atom.count = 1; + temp_mol.atom.type = &mol->atom.type[i]; + temp_mol.atom.resname = &mol->atom.resname[i]; + temp_mol.atom.flags = mol->atom.flags ? &mol->atom.flags[i] : NULL; + + md_element_t temp_element = 0; + if (md_util_element_guess(&temp_element, 1, &temp_mol)) { + elem = temp_element; + } + } + + mol->atom.element[i] = elem; + } + } + + // Now populate the atom type table and assign type indices + for (size_t i = 0; i < mol->atom.count; ++i) { + md_label_t type_name = mol->atom.type[i]; + md_element_t element = mol->atom.element[i]; + float mass = md_util_element_atomic_mass(element); + float radius = md_util_element_vdw_radius(element); + + // Find or add the atom type + md_atom_type_idx_t type_idx = md_atom_type_find_or_add(&mol->atom_type, type_name, element, mass, radius, alloc); + mol->atom.type_idx[i] = type_idx; + } + result = true; done: md_vm_arena_destroy(temp_alloc); diff --git a/src/md_types.h b/src/md_types.h index 1364e62..0f52d92 100644 --- a/src/md_types.h +++ b/src/md_types.h @@ -75,6 +75,7 @@ typedef int32_t md_backbone_idx_t; typedef int32_t md_residue_id_t; typedef int32_t md_chain_idx_t; typedef int32_t md_bond_idx_t; +typedef int32_t md_atom_type_idx_t; typedef uint32_t md_secondary_structure_t; typedef uint32_t md_flags_t; typedef uint8_t md_element_t; diff --git a/src/md_util.c b/src/md_util.c index 15a6129..02715a7 100644 --- a/src/md_util.c +++ b/src/md_util.c @@ -357,7 +357,7 @@ static const char* acidic[] = { "ASP", "GLU" }; static const char* basic[] = { "ARG", "HIS", "LYS" }; static const char* neutral[] = { "VAL", "PHE", "GLN", "TYR", "HIS", "CYS", "MET", "TRP", "ASX", "GLX", "PCA", "HYP" }; -static const char* water[] = { "H2O", "HHO", "OHH", "HOH", "OH2", "SOL", "WAT", "TIP", "TIP2", "TIP3", "TIP4", "W", "DOD", "D30" }; +static const char* water[] = { "H2O", "HHO", "OHH", "HOH", "OH2", "SOL", "WAT", "TIP", "TIP2", "TIP3", "TIP4", "TIP5", "W", "DOD", "D30", "SPC" 
}; static const char* hydrophobic[] = { "ALA", "VAL", "ILE", "LEU", "MET", "PHE", "TYR", "TRP", "CYX" }; // Taken from here @@ -1013,7 +1013,7 @@ bool md_util_resname_dna(str_t str) { return find_str_in_cstr_arr(NULL, str, dna, ARRAY_SIZE(dna)); } -bool md_util_resname_nucleic_acid(str_t str) { +bool md_util_resname_nucleotide(str_t str) { str = trim_label(str); return find_str_in_cstr_arr(NULL, str, rna, ARRAY_SIZE(rna)) || find_str_in_cstr_arr(NULL, str, dna, ARRAY_SIZE(dna)); } @@ -1240,143 +1240,8 @@ static inline bool is_organic(char c) { } bool md_util_element_guess(md_element_t element[], size_t capacity, const struct md_molecule_t* mol) { - ASSERT(capacity > 0); - ASSERT(mol); - ASSERT(mol->atom.count > 0); - - md_hashmap32_t map = { .allocator = md_get_temp_allocator() }; - md_hashmap_reserve(&map, 256); - - // Just for pure elements which have not been salted with resname - md_hashmap32_t elem_map = { .allocator = md_get_temp_allocator() }; - md_hashmap_reserve(&elem_map, 256); - - typedef struct { - str_t name; - md_element_t elem; - } entry_t; - - // Extra table for predefined atom types - entry_t entries[] = { - {STR_LIT("SOD"), Na}, - {STR_LIT("OW"), O}, - {STR_LIT("HW"), H}, - }; - - for (size_t i = 0; i < ARRAY_SIZE(entries); ++i) { - md_hashmap_add(&elem_map, md_hash64(entries[i].name.ptr, entries[i].name.len, 0), entries[i].elem); - } - - const size_t count = MIN(capacity, mol->atom.count); - for (size_t i = 0; i < count; ++i) { - if (element[i] != 0) continue; - - str_t original = LBL_TO_STR(mol->atom.type[i]); - - // Trim whitespace, digits and 'X's - str_t name = trim_label(original); - - if (name.len > 0) { - md_element_t elem = 0; - - str_t resname = STR_LIT(""); - uint64_t res_key = 0; - if (mol->atom.resname) { - resname = LBL_TO_STR(mol->atom.resname[i]); - res_key = md_hash64_str(resname, 0); - } - uint64_t key = md_hash64_str(name, res_key); - uint32_t* ptr = md_hashmap_get(&map, key); - if (ptr) { - element[i] = (md_element_t)*ptr; - 
continue; - } else { - uint64_t elem_key = md_hash64(name.ptr, name.len, 0); - ptr = md_hashmap_get(&elem_map, elem_key); - if (ptr) { - elem = (md_element_t)*ptr; - goto done; - } - } - - if ((elem = md_util_element_lookup(name)) != 0) goto done; - - // If amino acid, try to deduce the element from that - if (mol->atom.flags) { - if (mol->atom.flags[i] & (MD_FLAG_AMINO_ACID | MD_FLAG_NUCLEOTIDE)) { - // Try to match against the first character - name.len = 1; - elem = md_util_element_lookup_ignore_case(name); - goto done; - } - } - - // This is the same logic as above but more general, for the natural organic elements - if (is_organic(name.ptr[0]) && name.len > 1) { - if (name.ptr[1] - 'A' < 5) { - if (mol->residue.count > 0 && mol->atom.res_idx) { - int32_t res_idx = mol->atom.res_idx[i]; - uint32_t res_beg = mol->residue.atom_offset[res_idx]; - uint32_t res_end = mol->residue.atom_offset[res_idx+1]; - uint32_t res_len = res_end - res_beg; - if (res_len > 3) { - name.len = 1; - elem = md_util_element_lookup_ignore_case(name); - goto done; - } - } - } - } - - // Heuristic cases - - // This can be fishy... - if (str_eq_cstr(name, "HOH")) { - elem = H; - goto done; - } - if (str_eq_cstr(name, "HS")) { - elem = H; - goto done; - } - - size_t num_alpha = 0; - while (num_alpha < str_len(original) && is_alpha(original.ptr[num_alpha])) ++num_alpha; - - size_t num_digits = 0; - str_t digits = str_substr(original, num_alpha, SIZE_MAX); - while (num_digits < str_len(digits) && is_digit(digits.ptr[num_digits])) ++num_digits; - - // 2-3 letters + 1-2 digit (e.g. 
HO(H)[0-99]) usually means just look at the first letter - if ((num_alpha == 2 || num_alpha == 3) && (num_digits == 1 || num_digits == 2)) { - name.len = 1; - elem = md_util_element_lookup_ignore_case(name); - goto done; - } - - // Try to match against several characters but ignore the case - if (name.len > 1) { - name.len = 2; - elem = md_util_element_lookup_ignore_case(name); - } - - // Last resort, try to match against single first character - if (elem == 0) { - name.len = 1; - elem = md_util_element_lookup_ignore_case(name); - } - - done: - element[i] = elem; - if (elem != 0) { - md_hashmap_add(&map, key, elem); - } - } - } - - md_hashmap_free(&map); - - return true; + // Delegate to the new hash-backed atomic number inference system + return md_atoms_infer_atomic_numbers(element, capacity, mol); } bool md_util_element_from_mass(md_element_t element[], const float mass[], size_t count) { @@ -1424,6 +1289,138 @@ bool md_util_element_from_mass(md_element_t element[], const float mass[], size_ } } +bool md_util_lammps_element_from_mass(md_element_t out_element[], const float in_mass[], size_t count) { + if (!out_element) { + MD_LOG_ERROR("out_element is null"); + return false; + } + if (count > 0 && !in_mass) { + MD_LOG_ERROR("in_mass is null"); + return false; + } + if (count == 0) { + return true; + } + + // Initialize output to zero (unknown element) + for (size_t i = 0; i < count; ++i) { + out_element[i] = 0; + } + + // Standard atomic masses for common elements used in all-atom simulations + typedef struct { + md_element_t element; + float mass; + float tolerance; + } element_mass_entry_t; + + static const element_mass_entry_t standard_masses[] = { + {H, 1.008f, 0.2f}, // Hydrogen: Z <= 10, tolerance 0.2 amu + {C, 12.011f, 0.2f}, // Carbon + {N, 14.007f, 0.2f}, // Nitrogen + {O, 15.999f, 0.2f}, // Oxygen + {F, 18.998f, 0.2f}, // Fluorine + {Na, 22.990f, 0.3f}, // Sodium: Z > 10, <= 20, tolerance 0.3 amu + {Mg, 24.305f, 0.3f}, // Magnesium + {P, 30.974f, 
0.3f}, // Phosphorus + {S, 32.06f, 0.3f}, // Sulfur + {Cl, 35.45f, 0.3f}, // Chlorine + {K, 39.098f, 0.3f}, // Potassium + {Ca, 40.078f, 0.3f}, // Calcium + {Fe, 55.845f, 0.5f}, // Iron: Z > 20, tolerance 0.5 amu + {Zn, 65.38f, 0.5f}, // Zinc + {Br, 79.904f, 0.5f}, // Bromine + {I, 126.904f, 0.5f}, // Iodine + }; + static const size_t num_standard_masses = sizeof(standard_masses) / sizeof(standard_masses[0]); + + // CG/reduced-units detection heuristics + + // Count how many unique masses we have + float unique_masses[256]; + size_t unique_count = 0; + + for (size_t i = 0; i < count; ++i) { + const float mass = in_mass[i]; + if (mass <= 0.0f) continue; // Skip invalid masses + + // Check if this mass is already in our unique list + bool found = false; + for (size_t j = 0; j < unique_count; ++j) { + if (fabsf(mass - unique_masses[j]) < 1e-6f) { + found = true; + break; + } + } + + if (!found && unique_count < sizeof(unique_masses)/sizeof(unique_masses[0])) { + unique_masses[unique_count++] = mass; + } + } + + // Heuristic 1: Too few unique masses for the number of atoms suggests CG + if (unique_count < 3 && count > 10) { + return false; // Likely CG, skip mapping + } + + // Heuristic 2: Check for CG-like mass values (e.g., 72.0, 1.0, etc.) 
+ for (size_t i = 0; i < unique_count; ++i) { + const float mass = unique_masses[i]; + // Check for typical CG mass values - be careful not to catch legitimate hydrogen masses + if (fabsf(mass - 72.0f) < 0.001f || // Common CG mass + fabsf(mass - 36.0f) < 0.001f || // Common CG mass + mass > 200.0f) { // Unrealistically heavy for typical atoms + return false; // Likely CG/reduced units, skip mapping + } + // Check for reduced units (exactly 1.0, not hydrogen mass around 1.008) + if (fabsf(mass - 1.0f) < 1e-6f) { // Much stricter tolerance for exactly 1.0 + return false; // Likely reduced units + } + } + + // Perform conservative mass→element mapping + size_t successful_mappings = 0; + + for (size_t i = 0; i < count; ++i) { + const float mass = in_mass[i]; + if (mass <= 0.0f) continue; + + md_element_t best_element = 0; + float best_diff = FLT_MAX; + size_t num_candidates = 0; + + // Find the best matching element + for (size_t j = 0; j < num_standard_masses; ++j) { + const float diff = fabsf(mass - standard_masses[j].mass); + if (diff <= standard_masses[j].tolerance) { + num_candidates++; + if (diff < best_diff) { + best_diff = diff; + best_element = standard_masses[j].element; + } + } + } + + // Only assign if we have exactly one candidate (no ambiguity) + if (num_candidates == 1) { + out_element[i] = best_element; + successful_mappings++; + } + // If multiple elements match within tolerance, leave as 0 (ambiguous) + } + + // If we couldn't map a reasonable fraction, might be CG + if (successful_mappings == 0 || successful_mappings < (count + 2) / 3) { // At least 33% success rate + // Reset all to 0 and return false + for (size_t i = 0; i < count; ++i) { + out_element[i] = 0; + } + return false; + } + + return true; +} + const str_t* md_util_element_symbols(void) { return element_symbols; } @@ -8213,20 +8210,8 @@ bool md_util_molecule_postprocess(md_molecule_t* mol, md_allocator_i* alloc, md_ } } - if (flags & MD_UTIL_POSTPROCESS_ELEMENT_BIT) { -#ifdef PROFILE - 
md_timestamp_t t0 = md_time_current(); -#endif - if (!mol->atom.element) { - md_array_resize(mol->atom.element, mol->atom.count, alloc); - MEMSET(mol->atom.element, 0, md_array_bytes(mol->atom.element)); - } - md_util_element_guess(mol->atom.element, mol->atom.count, mol); -#ifdef PROFILE - md_timestamp_t t1 = md_time_current(); - MD_LOG_DEBUG("Postprocess: guess elements %.3f ms\n", md_time_as_milliseconds(t1-t0)); -#endif - } + // Element inference is now handled within each parser during parsing, not in postprocessing + // This ensures atom type table is populated correctly and eliminates dependency on per-atom element field if (flags & MD_UTIL_POSTPROCESS_RADIUS_BIT) { #ifdef PROFILE diff --git a/src/md_util.h b/src/md_util.h index 2a98aa3..7e730dd 100644 --- a/src/md_util.h +++ b/src/md_util.h @@ -1,6 +1,7 @@ -#pragma once +#pragma once #include +#include #include #include @@ -31,20 +32,16 @@ enum { typedef uint32_t md_util_postprocess_flags_t; -// This assumes the string exactly matches the value within the look up table -// The match is case sensitive and expects elements to be formatted with Big first letter and small second letter: -// E.g. 
H, He, Fe, Na, C -md_element_t md_util_element_lookup(str_t element_str); -md_element_t md_util_element_lookup_ignore_case(str_t element_str); - -// Access to the static arrays +// Access to the static arrays (preserved for direct access) const str_t* md_util_element_symbols(void); const str_t* md_util_element_names(void); const float* md_util_element_vdw_radii(void); +// Element functions (now calling new atomic number API internally) +md_element_t md_util_element_lookup(str_t element_str); +md_element_t md_util_element_lookup_ignore_case(str_t element_str); str_t md_util_element_symbol(md_element_t element); str_t md_util_element_name(md_element_t element); - float md_util_element_vdw_radius(md_element_t element); float md_util_element_covalent_radius(md_element_t element); float md_util_element_atomic_mass(md_element_t element); @@ -52,24 +49,28 @@ int md_util_element_max_valence(md_element_t element); uint32_t md_util_element_cpk_color(md_element_t element); bool md_util_resname_dna(str_t str); +bool md_util_resname_rna(str_t str); bool md_util_resname_acidic(str_t str); bool md_util_resname_basic(str_t str); bool md_util_resname_neutral(str_t str); bool md_util_resname_water(str_t str); bool md_util_resname_hydrophobic(str_t str); bool md_util_resname_amino_acid(str_t str); +bool md_util_resname_nucleotide(str_t str); static inline bool md_util_backbone_atoms_valid(md_protein_backbone_atoms_t prot) { return (prot.ca != prot.c) && (prot.ca != prot.o) && (prot.c != prot.o); } -// This operation tries to deduce the element from the atom type/name which usually contains alot of cruft. -// It also tries resolve some ambiguities: Such as CA, is that Carbon Alpha or is it calcium? -// We can resolve that by looking at the residue name and in the case of Carbon Alpha, the residue name should be matched to an amino acid. 
+// Element guess function - now delegates to the new inference system bool md_util_element_guess(md_element_t element[], size_t capacity, const struct md_molecule_t* mol); bool md_util_element_from_mass(md_element_t out_element[], const float in_mass[], size_t count); +// Conservative mass→element mapping for LAMMPS with CG/reduced-units detection +// Returns false if data appears to be CG or reduced-units (skips mapping) +bool md_util_lammps_element_from_mass(md_element_t out_element[], const float in_mass[], size_t count); + // Computes secondary structures from backbone atoms // Does not allocate any data, it assumes that secondary_structures has the same length as mol.backbone.count bool md_util_backbone_secondary_structure_compute(md_secondary_structure_t secondary_structures[], size_t capacity, const struct md_molecule_t* mol); diff --git a/src/md_xyz.c b/src/md_xyz.c index c9a0fe3..07fb8d9 100644 --- a/src/md_xyz.c +++ b/src/md_xyz.c @@ -739,6 +739,7 @@ bool md_xyz_molecule_init(md_molecule_t* mol, const md_xyz_data_t* data, struct md_array_ensure(mol->atom.y, num_atoms, alloc); md_array_ensure(mol->atom.z, num_atoms, alloc); md_array_ensure(mol->atom.element, num_atoms, alloc); + md_array_ensure(mol->atom.type_idx, num_atoms, alloc); for (size_t i = beg_coord_index; i < end_coord_index; ++i) { float x = data->coordinates[i].x; @@ -754,6 +755,19 @@ bool md_xyz_molecule_init(md_molecule_t* mol, const md_xyz_data_t* data, struct md_array_push(mol->atom.element, element, alloc); md_array_push(mol->atom.type, make_label(atom_type), alloc); md_array_push(mol->atom.flags, 0, alloc); + md_array_push(mol->atom.type_idx, -1, alloc); // Will be set after populating atom type table + } + + // Populate atom type table and assign type indices + for (size_t i = 0; i < mol->atom.count; ++i) { + md_label_t type_name = mol->atom.type[i]; + md_element_t element = mol->atom.element[i]; + float mass = md_util_element_atomic_mass(element); + float radius = 
md_util_element_vdw_radius(element); + + // Find or add the atom type + md_atom_type_idx_t type_idx = md_atom_type_find_or_add(&mol->atom_type, type_name, element, mass, radius, alloc); + mol->atom.type_idx[i] = type_idx; } mol->unit_cell = md_util_unit_cell_from_matrix(data->models[0].cell); diff --git a/test_data b/test_data index ea52f2c..1c7b296 160000 --- a/test_data +++ b/test_data @@ -1 +1 @@ -Subproject commit ea52f2cbcf57ed6b52aacc53219bcc5dcf817b6a +Subproject commit 1c7b296307ac33a895f3c4d2fe68364ec9096365 diff --git a/unittest/CMakeLists.txt b/unittest/CMakeLists.txt index d78f2e3..3093144 100644 --- a/unittest/CMakeLists.txt +++ b/unittest/CMakeLists.txt @@ -11,9 +11,12 @@ target_link_libraries(md_unittest PRIVATE mdlib ${MD_LIBS}) set (SRC_FILES test_allocator.c + test_api_equivalence.c test_array.c + test_atomic.c test_bitop.c test_bitfield.c + test_element_guess_compat.c test_str.c test_hash.c test_edr.c diff --git a/unittest/test_api_equivalence.c b/unittest/test_api_equivalence.c new file mode 100644 index 0000000..8d48d3a --- /dev/null +++ b/unittest/test_api_equivalence.c @@ -0,0 +1,53 @@ +// Test to verify API equivalence between old and new systems +#include "utest.h" +#include +#include + +UTEST(api_equivalence, symbol_lookup_consistency) { + // Test that old and new APIs return the same results + for (int i = 1; i <= 118; ++i) { + str_t symbol_old = md_util_element_symbol(i); + str_t symbol_new = md_symbol_from_atomic_number(i); + + EXPECT_TRUE(str_eq(symbol_old, symbol_new)); + + // Test reverse lookup + md_atomic_number_t old_lookup = md_util_element_lookup(symbol_old); + md_atomic_number_t new_lookup = md_atomic_number_from_symbol(symbol_old); + + EXPECT_EQ(old_lookup, new_lookup); + EXPECT_EQ(old_lookup, i); + } +} + +UTEST(api_equivalence, property_consistency) { + // Test a few key elements for property consistency + md_atomic_number_t test_elements[] = {MD_Z_H, MD_Z_C, MD_Z_N, MD_Z_O, MD_Z_Ca, MD_Z_Fe}; + + for (size_t i = 0; i < 
ARRAY_SIZE(test_elements); ++i) { + md_atomic_number_t z = test_elements[i]; + + // Test masses + float mass_old = md_util_element_atomic_mass(z); + float mass_new = md_atomic_mass(z); + EXPECT_EQ(mass_old, mass_new); + + // Test radii + float vdw_old = md_util_element_vdw_radius(z); + float vdw_new = md_vdw_radius(z); + EXPECT_EQ(vdw_old, vdw_new); + + float cov_old = md_util_element_covalent_radius(z); + float cov_new = md_covalent_radius(z); + EXPECT_EQ(cov_old, cov_new); + + // Test valence and color + int val_old = md_util_element_max_valence(z); + int val_new = md_max_valence(z); + EXPECT_EQ(val_old, val_new); + + uint32_t color_old = md_util_element_cpk_color(z); + uint32_t color_new = md_cpk_color(z); + EXPECT_EQ(color_old, color_new); + } +} \ No newline at end of file diff --git a/unittest/test_atomic.c b/unittest/test_atomic.c new file mode 100644 index 0000000..ffe9e99 --- /dev/null +++ b/unittest/test_atomic.c @@ -0,0 +1,122 @@ +#include "utest.h" +#include +#include +#include +#include + +UTEST(atomic, enum_constants) { + // Test that enum constants are correct + EXPECT_EQ(MD_Z_X, 0); // Unknown + EXPECT_EQ(MD_Z_H, 1); // Hydrogen + EXPECT_EQ(MD_Z_He, 2); // Helium + EXPECT_EQ(MD_Z_C, 6); // Carbon + EXPECT_EQ(MD_Z_N, 7); // Nitrogen + EXPECT_EQ(MD_Z_O, 8); // Oxygen + EXPECT_EQ(MD_Z_P, 15); // Phosphorus + EXPECT_EQ(MD_Z_S, 16); // Sulfur + EXPECT_EQ(MD_Z_Ca, 20); // Calcium + EXPECT_EQ(MD_Z_Cl, 17); // Chlorine + EXPECT_EQ(MD_Z_Br, 35); // Bromine + EXPECT_EQ(MD_Z_Na, 11); // Sodium + EXPECT_EQ(MD_Z_Fe, 26); // Iron + EXPECT_EQ(MD_Z_Og, 118); // Oganesson +} + +UTEST(atomic, symbol_lookup) { + // Test symbol lookup functions + EXPECT_EQ(md_atomic_number_from_symbol(STR_LIT("H")), MD_Z_H); + EXPECT_EQ(md_atomic_number_from_symbol(STR_LIT("C")), MD_Z_C); + EXPECT_EQ(md_atomic_number_from_symbol(STR_LIT("He")), MD_Z_He); + EXPECT_EQ(md_atomic_number_from_symbol(STR_LIT("Ca")), MD_Z_Ca); + EXPECT_EQ(md_atomic_number_from_symbol(STR_LIT("Unknown")), 
MD_Z_X); + + // Test case insensitive lookup + EXPECT_EQ(md_atomic_number_from_symbol_icase(STR_LIT("h")), MD_Z_H); + EXPECT_EQ(md_atomic_number_from_symbol_icase(STR_LIT("ca")), MD_Z_Ca); + EXPECT_EQ(md_atomic_number_from_symbol_icase(STR_LIT("HE")), MD_Z_He); +} + +UTEST(atomic, symbol_from_number) { + // Test reverse lookup + str_t h_symbol = md_symbol_from_atomic_number(MD_Z_H); + EXPECT_TRUE(str_eq_cstr(h_symbol, "H")); + + str_t c_symbol = md_symbol_from_atomic_number(MD_Z_C); + EXPECT_TRUE(str_eq_cstr(c_symbol, "C")); + + str_t ca_symbol = md_symbol_from_atomic_number(MD_Z_Ca); + EXPECT_TRUE(str_eq_cstr(ca_symbol, "Ca")); +} + +UTEST(atomic, element_properties) { + // Test that we can get basic properties + EXPECT_GT(md_atomic_mass(MD_Z_H), 0.0f); + EXPECT_GT(md_atomic_mass(MD_Z_C), 0.0f); + EXPECT_GT(md_vdw_radius(MD_Z_H), 0.0f); + EXPECT_GT(md_covalent_radius(MD_Z_C), 0.0f); + EXPECT_GT(md_max_valence(MD_Z_C), 0); + EXPECT_GT(md_cpk_color(MD_Z_C), 0); +} + +UTEST(atomic, inference_water) { + // Test water atom inference + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("O"), STR_LIT("HOH")), MD_Z_O); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("OW"), STR_LIT("HOH")), MD_Z_O); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("OH2"), STR_LIT("WAT")), MD_Z_O); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("H"), STR_LIT("HOH")), MD_Z_H); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("H1"), STR_LIT("TIP3")), MD_Z_H); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("HW"), STR_LIT("SPC")), MD_Z_H); +} + +UTEST(atomic, inference_amino_acid) { + // Test amino acid atom inference + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("CA"), STR_LIT("ALA")), MD_Z_C); // Alpha carbon, not calcium + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("N"), STR_LIT("GLY")), MD_Z_N); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("C"), STR_LIT("SER")), MD_Z_C); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("O"), STR_LIT("TRP")), MD_Z_O); + 
EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("OXT"), STR_LIT("PHE")), MD_Z_O); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("OG"), STR_LIT("SER")), MD_Z_O); // Serine hydroxyl + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("SG"), STR_LIT("CYS")), MD_Z_S); // Cysteine sulfur +} + +UTEST(atomic, inference_nucleic_acid) { + // Test nucleic acid atom inference + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("P"), STR_LIT("DA")), MD_Z_P); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("OP1"), STR_LIT("DG")), MD_Z_O); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("O2P"), STR_LIT("A")), MD_Z_O); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("P"), STR_LIT("U")), MD_Z_P); +} + +UTEST(atomic, inference_ions) { + // Test ion inference (residue name is element) + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("NA"), STR_LIT("NA")), MD_Z_Na); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT(""), STR_LIT("K")), MD_Z_K); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("MG"), STR_LIT("MG")), MD_Z_Mg); + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("CL"), STR_LIT("CL")), MD_Z_Cl); +} + +UTEST(atomic, inference_fallbacks) { + // Test fallback mechanisms + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("CL12"), STR_LIT("")), MD_Z_Cl); // Two-letter heuristic + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("BR1"), STR_LIT("")), MD_Z_Br); // Two-letter heuristic + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("H123"), STR_LIT("")), MD_Z_H); // First letter fallback + EXPECT_EQ(md_atom_infer_atomic_number(STR_LIT("C99"), STR_LIT("")), MD_Z_C); // First letter fallback +} + +UTEST(atomic, backward_compatibility) { + // Test that old API still works through wrappers + EXPECT_EQ(md_util_element_lookup(STR_LIT("H")), MD_Z_H); + EXPECT_EQ(md_util_element_lookup_ignore_case(STR_LIT("ca")), MD_Z_Ca); + + str_t symbol = md_util_element_symbol(MD_Z_C); + EXPECT_TRUE(str_eq_cstr(symbol, "C")); + + str_t name = md_util_element_name(MD_Z_O); + EXPECT_TRUE(str_eq_cstr(name, "Oxygen")); + + 
EXPECT_GT(md_util_element_atomic_mass(MD_Z_C), 0.0f); + EXPECT_GT(md_util_element_vdw_radius(MD_Z_H), 0.0f); + EXPECT_GT(md_util_element_covalent_radius(MD_Z_N), 0.0f); + EXPECT_GT(md_util_element_max_valence(MD_Z_C), 0); + EXPECT_GT(md_util_element_cpk_color(MD_Z_O), 0); +} \ No newline at end of file diff --git a/unittest/test_element_guess_compat.c b/unittest/test_element_guess_compat.c new file mode 100644 index 0000000..787c5c9 --- /dev/null +++ b/unittest/test_element_guess_compat.c @@ -0,0 +1,143 @@ +#include "utest.h" +#include +#include +#include +#include +#include +#include +#include + +// Test backward compatibility with md_util_element_guess +UTEST(element_guess_compat, basic_inference) { + md_allocator_i* alloc = md_vm_arena_create(MEGABYTES(1)); + + // Create a simple molecule structure + md_molecule_t mol = {0}; + + // Atom data + const size_t atom_count = 5; + mol.atom.count = atom_count; + + // Allocate arrays + mol.atom.type = md_alloc(alloc, sizeof(md_label_t) * atom_count); + mol.atom.resname = md_alloc(alloc, sizeof(md_label_t) * atom_count); + mol.atom.element = md_alloc(alloc, sizeof(md_element_t) * atom_count); + + // Set up atom types and residue names + // HOH water oxygen + strncpy(mol.atom.type[0].buf, "O", sizeof(mol.atom.type[0].buf)); + mol.atom.type[0].len = 1; + strncpy(mol.atom.resname[0].buf, "HOH", sizeof(mol.atom.resname[0].buf)); + mol.atom.resname[0].len = 3; + mol.atom.element[0] = 0; // Start unknown + + // HOH water hydrogen + strncpy(mol.atom.type[1].buf, "H1", sizeof(mol.atom.type[1].buf)); + mol.atom.type[1].len = 2; + strncpy(mol.atom.resname[1].buf, "HOH", sizeof(mol.atom.resname[1].buf)); + mol.atom.resname[1].len = 3; + mol.atom.element[1] = 0; // Start unknown + + // Alanine alpha carbon + strncpy(mol.atom.type[2].buf, "CA", sizeof(mol.atom.type[2].buf)); + mol.atom.type[2].len = 2; + strncpy(mol.atom.resname[2].buf, "ALA", sizeof(mol.atom.resname[2].buf)); + mol.atom.resname[2].len = 3; + mol.atom.element[2] = 0; 
// Start unknown + + // Sodium ion + strncpy(mol.atom.type[3].buf, "NA", sizeof(mol.atom.type[3].buf)); + mol.atom.type[3].len = 2; + strncpy(mol.atom.resname[3].buf, "NA", sizeof(mol.atom.resname[3].buf)); + mol.atom.resname[3].len = 2; + mol.atom.element[3] = 0; // Start unknown + + // Generic carbon + strncpy(mol.atom.type[4].buf, "C1", sizeof(mol.atom.type[4].buf)); + mol.atom.type[4].len = 2; + strncpy(mol.atom.resname[4].buf, "", sizeof(mol.atom.resname[4].buf)); + mol.atom.resname[4].len = 0; + mol.atom.element[4] = 0; // Start unknown + + // Call the element guess function + bool result = md_util_element_guess(mol.atom.element, atom_count, &mol); + + // Verify results + EXPECT_TRUE(result); + EXPECT_EQ(mol.atom.element[0], MD_Z_O); // Water oxygen + EXPECT_EQ(mol.atom.element[1], MD_Z_H); // Water hydrogen + EXPECT_EQ(mol.atom.element[2], MD_Z_C); // Alanine alpha carbon (not calcium!) + EXPECT_EQ(mol.atom.element[3], MD_Z_Na); // Sodium ion + EXPECT_EQ(mol.atom.element[4], MD_Z_C); // Generic carbon from C1 + + md_vm_arena_destroy(alloc); +} + +// Test element inference against all PDB files with explicit element symbols +UTEST(element_guess_compat, all_pdb_validation) { + md_allocator_i* alloc = md_vm_arena_create(MEGABYTES(16)); + + // List of PDB files to test + const char* pdb_files[] = { + MD_UNITTEST_DATA_DIR"/1a64.pdb", + MD_UNITTEST_DATA_DIR"/1k4r.pdb", + MD_UNITTEST_DATA_DIR"/c60.pdb", + MD_UNITTEST_DATA_DIR"/ciprofloxacin.pdb", + MD_UNITTEST_DATA_DIR"/tryptophan.pdb", + MD_UNITTEST_DATA_DIR"/1ALA-560ns.pdb", + MD_UNITTEST_DATA_DIR"/dppc64.pdb" + }; + const size_t num_files = sizeof(pdb_files) / sizeof(pdb_files[0]); + + // Aggregate statistics across all files + size_t total_explicit_elements = 0; + size_t total_correct_inferences = 0; + size_t files_processed = 0; + + for (size_t file_idx = 0; file_idx < num_files; ++file_idx) { + str_t path = {pdb_files[file_idx], strlen(pdb_files[file_idx])}; + md_pdb_data_t pdb_data = {0}; + bool parse_result 
= md_pdb_data_parse_file(&pdb_data, path, alloc); + + if (!parse_result) { + // Some files might not exist, skip gracefully + continue; + } + + md_molecule_t mol = {0}; + bool mol_result = md_pdb_molecule_init(&mol, &pdb_data, MD_PDB_OPTION_NONE, alloc); + + if (mol_result && mol.atom.count > 0) { + files_processed++; + size_t file_explicit_elements = 0; + size_t file_correct_inferences = 0; + + for (size_t i = 0; i < mol.atom.count && i < pdb_data.num_atom_coordinates; ++i) { + const char* explicit_element = pdb_data.atom_coordinates[i].element; + if (explicit_element[0] != '\0' && explicit_element[0] != ' ') { + file_explicit_elements++; + total_explicit_elements++; + + str_t explicit_str = {explicit_element, strlen(explicit_element)}; + explicit_str = str_trim(explicit_str); + md_element_t expected_element = md_util_element_lookup_ignore_case(explicit_str); + md_element_t inferred_element = mol.atom.element[i]; + + if (expected_element != 0 && inferred_element == expected_element) { + file_correct_inferences++; + total_correct_inferences++; + } + } + } + + md_molecule_free(&mol, alloc); + } + md_pdb_data_free(&pdb_data, alloc); + } + + // Validate that we processed at least some files + EXPECT_GT(files_processed, 0); + EXPECT_EQ(total_correct_inferences, total_explicit_elements); + + md_vm_arena_destroy(alloc); +} \ No newline at end of file diff --git a/unittest/test_mmcif.c b/unittest/test_mmcif.c index e9b8636..62cc5fe 100644 --- a/unittest/test_mmcif.c +++ b/unittest/test_mmcif.c @@ -78,3 +78,43 @@ UTEST(mmcif, 8g7u) { md_molecule_free(&mol, md_get_heap_allocator()); } + +// Comprehensive test for all CIF files in test_data +UTEST(mmcif, all_cif_files) { + const char* cif_files[] = { + MD_UNITTEST_DATA_DIR"/1fez.cif", + MD_UNITTEST_DATA_DIR"/2or2.cif", + MD_UNITTEST_DATA_DIR"/8g7u.cif", + }; + const size_t expected_atom_counts[] = { + 4097, // 1fez.cif + 5382, // 2or2.cif + 14229, // 8g7u.cif + }; + const size_t num_files = sizeof(cif_files) / 
sizeof(cif_files[0]); + + for (size_t i = 0; i < num_files; ++i) { + str_t path = str_from_cstr(cif_files[i]); + md_molecule_t mol; + bool result = md_mmcif_molecule_api()->init_from_file(&mol, path, NULL, md_get_heap_allocator()); + + EXPECT_TRUE(result); + if (result) { + EXPECT_EQ(expected_atom_counts[i], mol.atom.count); + + // Verify all atoms have valid elements + size_t zero_element_count = 0; + for (size_t j = 0; j < mol.atom.count; ++j) { + if (mol.atom.element[j] == 0) { + zero_element_count++; + } + } + + // Allow some missing elements but ensure most are filled + double missing_ratio = (double)zero_element_count / (double)mol.atom.count; + EXPECT_LT(missing_ratio, 0.15); // Less than 15% missing + + md_molecule_free(&mol, md_get_heap_allocator()); + } + } +} diff --git a/unittest/test_util.c b/unittest/test_util.c index ac70f98..5e9976e 100644 --- a/unittest/test_util.c +++ b/unittest/test_util.c @@ -814,3 +814,105 @@ UTEST(util, radix_sort) { EXPECT_LE(arr[i], arr[i+1]); } } + +UTEST(util, lammps_mass_element_mapping) { + // Exercises md_util_lammps_element_from_mass: real atomic masses are expected to map to atomic numbers, while coarse-grained or reduced-unit style masses are expected to be rejected with the output left at 0 + + { + // All-atom case: standard atomic masses of H, C, N, O, P and S are expected to map to their atomic numbers + float masses[] = {1.008f, 12.011f, 14.007f, 15.999f, 30.974f, 32.06f}; + md_element_t elements[6] = {0}; + size_t count = ARRAY_SIZE(masses); + + EXPECT_TRUE(md_util_lammps_element_from_mass(elements, masses, count)); + EXPECT_EQ(elements[0], 1); // H + EXPECT_EQ(elements[1], 6); // C + EXPECT_EQ(elements[2], 7); // N + EXPECT_EQ(elements[3], 8); // O + EXPECT_EQ(elements[4], 15); // P + EXPECT_EQ(elements[5], 16); // S + } + + { + // Masses that deviate slightly from the standard values are still expected to map + float masses[] = {1.01f, 12.1f, 14.1f, 15.9f}; // Within tolerance, but 1.01 instead of 1.0 to avoid reduced units filter + md_element_t elements[4] = {0}; + size_t count = ARRAY_SIZE(masses); + + EXPECT_TRUE(md_util_lammps_element_from_mass(elements, masses, count)); + EXPECT_EQ(elements[0], 1); // H + EXPECT_EQ(elements[1], 6); 
// C + EXPECT_EQ(elements[2], 7); // N + EXPECT_EQ(elements[3], 8); // O + } + + { + // Coarse-grained (CG) style masses: the call is expected to fail and return false + float masses[] = {72.0f, 72.0f, 72.0f, 72.0f, 72.0f}; // Typical CG mass + md_element_t elements[5] = {0}; + size_t count = ARRAY_SIZE(masses); + + EXPECT_FALSE(md_util_lammps_element_from_mass(elements, masses, count)); + // All elements should remain 0 (unassigned) + for (size_t i = 0; i < count; ++i) { + EXPECT_EQ(elements[i], 0); + } + } + + { + // Reduced units (every mass is 1.0): the call is expected to fail and return false + float masses[] = {1.0f, 1.0f, 1.0f, 1.0f}; // Reduced units + md_element_t elements[4] = {0}; + size_t count = ARRAY_SIZE(masses); + + EXPECT_FALSE(md_util_lammps_element_from_mass(elements, masses, count)); + // All elements should remain 0 (unassigned) + for (size_t i = 0; i < count; ++i) { + EXPECT_EQ(elements[i], 0); + } + } + + { + // Partial mapping case: some masses match an element, some do not + float masses[] = {12.011f, 45.0f}; // Carbon matches, 45.0f doesn't match any element well + md_element_t elements[2] = {0}; + size_t count = ARRAY_SIZE(masses); + + bool result = md_util_lammps_element_from_mass(elements, masses, count); + // Should succeed because at least one element can be mapped + // The first should map to Carbon, second should remain 0 + EXPECT_TRUE(result); + EXPECT_EQ(elements[0], 6); // C + EXPECT_EQ(elements[1], 0); // Unknown/unmatched + } + + { + // Test too few unique masses for the number of atoms (CG heuristic) + float masses[20]; + for (int i = 0; i < 20; ++i) { + masses[i] = (i < 10) ? 
36.0f : 72.0f; // Only 2 unique masses for 20 atoms + } + md_element_t elements[20] = {0}; + size_t count = ARRAY_SIZE(masses); + + EXPECT_FALSE(md_util_lammps_element_from_mass(elements, masses, count)); + // All elements should remain 0 (unassigned) + for (size_t i = 0; i < count; ++i) { + EXPECT_EQ(elements[i], 0); + } + } + + { + // Empty input: a count of 0 is expected to succeed, even with a NULL masses pointer + md_element_t elements[1] = {0}; + EXPECT_TRUE(md_util_lammps_element_from_mass(elements, NULL, 0)); + } + + { + // Invalid input: a NULL output array or a NULL masses array with a non-zero count is expected to fail + float masses[] = {12.011f}; + EXPECT_FALSE(md_util_lammps_element_from_mass(NULL, masses, 1)); + md_element_t elements[1] = {0}; + EXPECT_FALSE(md_util_lammps_element_from_mass(elements, NULL, 1)); + } +}