# Weighting Extended XYZ Datasets for NEP / GPUMD

This notebook applies **per-dataset relative weights** to every frame of the extended XYZ files
(GPUMD / NEP format) found in `DATASET_DIR`, and writes the weighted copies to `DATASET_DIR/weighted/`.

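Each frame has the following structure:
- Line 1: number of atoms, `N`
- Line 2: comment (metadata) line, which may already contain a `weight=` key
- Next `N` lines: one line per atom

Weights are applied **per frame**, not per atom. `MODE` controls how each dataset's weight combines
with a frame's existing `weight=` value (`set` replaces it, `multiply` scales it, `add` increments it);
frames without a `weight=` key get one appended.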


## Cell 1 — Imports

In [6]:
from pathlib import Path
import re
import pandas as pd


## Cell 2 — User Configuration (EDIT THIS CELL ONLY)

In [7]:
# ============================================================
# INPUT DIRECTORY (contains *.xyz datasets)
# ============================================================

DATASET_DIR = Path("/blue/ypchen/emir.bilgili/NEP-TEST/proj/data/xyz/all_data_10-20")

# Fail here with a clear message; otherwise the first error would come from
# the mkdir below and name the "weighted" subdirectory instead.
if not DATASET_DIR.is_dir():
    raise FileNotFoundError(f"DATASET_DIR is not an existing directory: {DATASET_DIR}")

# Weighted copies are written to <DATASET_DIR>/weighted/ (created if missing).
OUTPUT_DIR = DATASET_DIR / "weighted"
OUTPUT_DIR.mkdir(exist_ok=True)

# ============================================================
# MODE: how each dataset weight combines with a frame's weight=
#   "set"      -> replace the existing weight
#   "multiply" -> existing weight * dataset weight
#   "add"      -> existing weight + dataset weight
# ============================================================

MODE = "set"

# ============================================================
# PER-DATASET WEIGHTS (keyed by file name inside DATASET_DIR)
# ============================================================

DATASET_WEIGHTS = {
    "aimd.xyz": 1.0,
    "alloy.xyz": 1.0,
    "elastic_tensor.xyz": 1.0,
    "hetero.xyz": 6.0,
    "interface_gasp.xyz": 1.0,
    "mos2.xyz": 2.0,
    "pressure.xyz": 2.0,
    "saph.xyz": 2.0,
    "phonons.xyz": 3.0,
}

# Weight for *.xyz files not listed above; None means those files are skipped.
DEFAULT_WEIGHT = None  # set to e.g. 1.0 to force all files


## Cell 3 — Internal Logic (do not edit)

In [8]:
# Matches a numeric frame weight on an extended-XYZ comment line.
#   group 2 = the number itself (the only text that gets replaced).
# Accepts "1", "1.0", "1.", ".5", "1e-3", "1.e-3", optionally wrapped in quotes,
# and the lookahead requires the number to end at a quote, whitespace or end
# of line, so a partially-matched value such as "1.e-3" is never truncated.
WEIGHT_RE = re.compile(
    r"""(?i)(^|\s)weight\s*=\s*["']?"""
    r"""([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)"""
    r"""(?=["']?(?:\s|$))"""
)

# Detects a weight key whether or not its value parses; used to refuse
# appending a second weight= key next to an unparseable one.
WEIGHT_KEY_RE = re.compile(r"(?i)(^|\s)weight\s*=")

# Effective weight of a frame that has no weight= key (GPUMD/NEP default is 1).
MISSING_FRAME_WEIGHT = 1.0

# How each MODE combines a frame's existing weight with the dataset weight.
_WEIGHT_OPS = {
    "set": lambda existing, w: w,
    "multiply": lambda existing, w: existing * w,
    "add": lambda existing, w: existing + w,
}


def _combine_op(mode):
    """Return the weight-combining function for ``mode``; raise ValueError if unknown."""
    try:
        return _WEIGHT_OPS[mode]
    except (KeyError, TypeError):
        raise ValueError(
            f"unknown mode {mode!r}; expected one of {sorted(_WEIGHT_OPS)}"
        ) from None


def update_comment_line(line, dataset_weight, mode):
    """Return the extended-XYZ comment ``line`` with its frame weight updated.

    If the line carries a numeric ``weight=`` value (key matched
    case-insensitively, value optionally quoted), only that number is
    replaced by the combined weight. Otherwise `` weight=<new>`` is appended,
    treating the missing weight as ``MISSING_FRAME_WEIGHT``.

    Parameters
    ----------
    line : str
        Comment (second) line of a frame, usually ending in a newline.
    dataset_weight : float
        Weight assigned to the dataset this frame belongs to.
    mode : {"set", "multiply", "add"}
        How ``dataset_weight`` combines with the frame's existing weight.

    Returns
    -------
    str
        The rewritten line. Values are formatted with ``.16g``.

    Raises
    ------
    ValueError
        Unknown ``mode``, or a ``weight=`` key whose value is not a plain
        number (appending would produce a duplicate key).
    """
    combine = _combine_op(mode)
    match = WEIGHT_RE.search(line)
    if match:
        new = combine(float(match.group(2)), dataset_weight)
        start, end = match.span(2)
        return line[:start] + f"{new:.16g}" + line[end:]
    if WEIGHT_KEY_RE.search(line):
        raise ValueError(f"unparseable weight on comment line: {line.rstrip()!r}")
    new = combine(MISSING_FRAME_WEIGHT, dataset_weight)
    return line.rstrip("\n") + f" weight={new:.16g}\n"


def _parse_atom_count(line, path, lineno):
    """Parse a frame's first line as a non-negative atom count, or raise ValueError."""
    try:
        n_atoms = int(line)
    except ValueError:
        n_atoms = -1
    if n_atoms < 0:
        raise ValueError(
            f"{path}:{lineno}: expected a non-negative atom count, got {line.rstrip()!r}"
        )
    return n_atoms


def _next_line(lines, path, lineno, what):
    """Return the next ``(lineno, line)`` pair, or raise ValueError if the file ended early."""
    item = next(lines, None)
    if item is None:
        raise ValueError(f"{path}: unexpected end of file after line {lineno} (expected {what})")
    return item


def process_xyz(infile, outfile, dataset_weight, mode):
    """Copy one extended-XYZ file, rewriting the weight on every frame's comment line.

    Blank lines between frames are copied through unchanged. The output is
    written to a temporary sibling file and moved onto ``outfile`` only after
    the whole input has been processed, so an error never leaves a partial
    ``outfile`` behind (this also makes ``infile == outfile`` safe).

    Parameters
    ----------
    infile, outfile : str or Path
        Source and destination file paths.
    dataset_weight : float
        Weight applied to every frame of this file.
    mode : {"set", "multiply", "add"}
        How ``dataset_weight`` combines with each frame's existing weight.

    Returns
    -------
    tuple of (int, int)
        ``(frames, modified)``: frames read, and frames whose comment line text changed.

    Raises
    ------
    ValueError
        Unknown ``mode``, a bad atom-count line, an unparseable ``weight=``
        value, or a file that ends in the middle of a frame.
    """
    _combine_op(mode)  # validate before any output file is created
    infile, outfile = Path(infile), Path(outfile)
    tmp_path = outfile.with_name(outfile.name + ".tmp")
    frames = 0
    modified = 0
    try:
        with open(infile, "r", encoding="utf-8", errors="replace") as fin, \
                open(tmp_path, "w", encoding="utf-8") as fout:
            lines = enumerate(fin, start=1)
            for lineno, line1 in lines:
                if not line1.strip():
                    fout.write(line1)
                    continue
                n_atoms = _parse_atom_count(line1, infile, lineno)
                lineno, line2 = _next_line(lines, infile, lineno, "a comment line")
                new_line2 = update_comment_line(line2, dataset_weight, mode)
                frames += 1
                if new_line2 != line2:
                    modified += 1
                fout.write(line1)
                fout.write(new_line2)
                for _ in range(n_atoms):
                    lineno, atom_line = _next_line(
                        lines, infile, lineno, f"{n_atoms} atom lines"
                    )
                    fout.write(atom_line)
    except BaseException:
        # Do not leave a half-written file that could be mistaken for valid output.
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    tmp_path.replace(outfile)
    return frames, modified


## Cell 4 — Apply Weights

In [9]:
# One (file name, frames, frames modified, applied weight) tuple per processed file.
summary = []

xyz_files = sorted(DATASET_DIR.glob("*.xyz"))
if not xyz_files:
    # Usually means DATASET_DIR points at the wrong directory.
    print(f"[WARN] no *.xyz files found in {DATASET_DIR}")

for xyz in xyz_files:
    name = xyz.name
    if name in DATASET_WEIGHTS:
        w = DATASET_WEIGHTS[name]
    elif DEFAULT_WEIGHT is not None:
        w = DEFAULT_WEIGHT
    else:
        print(f"[SKIP] {name}")
        continue
    out_path = OUTPUT_DIR / name
    frames, modified = process_xyz(xyz, out_path, w, MODE)
    summary.append((name, frames, modified, w))
    print(f"[OK] {name}: frames={frames}, modified={modified}, weight={w}")

# A weight keyed on a misspelled or absent file name would otherwise be ignored silently.
unused_weights = sorted(set(DATASET_WEIGHTS) - {p.name for p in xyz_files})
if unused_weights:
    print(f"[WARN] no file found for configured weights: {unused_weights}")


[OK] aimd.xyz: frames=8708, modified=8708, weight=1.0
[OK] alloy.xyz: frames=6394, modified=6394, weight=1.0
[OK] elastic_tensor.xyz: frames=41481, modified=41481, weight=1.0
[OK] hetero.xyz: frames=6906, modified=6906, weight=6.0
[OK] interface_gasp.xyz: frames=152, modified=152, weight=1.0
[OK] mos2.xyz: frames=2681, modified=2681, weight=2.0
[OK] phonons.xyz: frames=9273, modified=9273, weight=3.0
[OK] pressure.xyz: frames=4168, modified=4168, weight=2.0
[OK] saph.xyz: frames=1810, modified=1810, weight=2.0


## Cell 5 — Summary

In [10]:
# Column order mirrors the tuples appended in the apply-weights cell.
summary_columns = ["File", "Frames", "Frames Modified", "Applied Weight"]
pd.DataFrame(summary, columns=summary_columns)


Unnamed: 0,File,Frames,Frames Modified,Applied Weight
0,aimd.xyz,8708,8708,1.0
1,alloy.xyz,6394,6394,1.0
2,elastic_tensor.xyz,41481,41481,1.0
3,hetero.xyz,6906,6906,6.0
4,interface_gasp.xyz,152,152,1.0
5,mos2.xyz,2681,2681,2.0
6,phonons.xyz,9273,9273,3.0
7,pressure.xyz,4168,4168,2.0
8,saph.xyz,1810,1810,2.0
