# Shift SIMLIB MJDs to 2023

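This cell shifts `PEAKMJD` and all `S:` observation MJDs into the target window (2022-12-31 through 2023-12-31, end date inclusive) while preserving each observation's offset from its block's `PEAKMJD`, up to rounding at the file's precision. Each rewritten MJD token keeps the number of decimals and the width of the token it replaces, and the cell reports basic validation statistics.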


In [None]:
import re
import math
import hashlib
import datetime as dt
from pathlib import Path

# -----------------------
# config
# -----------------------
# SIMLIB file that is read, and the file the shifted copy is written to.
input_path  = Path("LOWZ_REDSHIFT_LT015_FROM_SIMDATA.SIMLIB")
output_path = Path("LOWZ_REDSHIFT_LT015_FROM_SIMDATA_SHIFTED_TO_2023.SIMLIB")

# Target window as ISO dates. The end date is inclusive: the main section adds
# one day to it to form an exclusive upper bound, and also takes the target
# year from its leading "YYYY" field.
start_date_str = "2022-12-31"
end_date_str   = "2023-12-31"

# -----------------------
# MJD <-> datetime
# -----------------------
# Naive datetime that corresponds to MJD 0 in the conversions below.
MJD_EPOCH = dt.datetime(1858, 11, 17)

def mjd_to_datetime(mjd: float) -> dt.datetime:
    """Return the naive datetime that lies `mjd` days after MJD_EPOCH."""
    elapsed = dt.timedelta(days=float(mjd))
    return MJD_EPOCH + elapsed

def datetime_to_mjd(d: dt.datetime) -> float:
    """Return the (fractional) number of days between MJD_EPOCH and `d`."""
    since_epoch = d - MJD_EPOCH
    return since_epoch.total_seconds() / 86400.0

def replace_year_safe(d: dt.datetime, year: int) -> dt.datetime:
    """Return `d` moved to `year`, keeping month, day and time where possible.

    Feb 29 becomes Feb 28 when `year` rejects it. For any other failing date
    with a day above 28, earlier days down to the 28th are tried in turn.
    Raises ValueError (the original one) when no replacement succeeds.
    """
    try:
        moved = d.replace(year=year)
    except ValueError:
        if (d.month, d.day) == (2, 29):
            return d.replace(year=year, day=28)
        # Walk back one day at a time, never going below the 28th.
        for fallback_day in range(d.day - 1, 27, -1):
            try:
                return d.replace(year=year, day=fallback_day)
            except ValueError:
                pass
        raise
    return moved

# -----------------------
# formatting helpers
# -----------------------
def decimals_in_token(tok: str) -> int:
    """Count the characters after the last "." in `tok`; 0 when there is no "."."""
    _, dot, fraction = tok.rpartition(".")
    return len(fraction) if dot else 0

def format_like(val: float, template: str) -> str:
    """Render `val` with as many decimals as `template`, right-justified.

    The result is padded on the left with spaces up to len(template); it is
    longer than the template when the formatted number needs more characters.
    """
    n_decimals = decimals_in_token(template)
    rendered = format(val, f".{n_decimals}f")
    return rendered.rjust(len(template))

def round_to(val: float, decs: int) -> float:
    """Round `val` to `decs` decimals by fixed-point formatting and parsing back."""
    as_text = format(val, f".{decs}f")
    return float(as_text)

def ceil_to(val: float, decs: int) -> float:
    """Round `val` up to a multiple of 10**-decs.

    A 1e-12 slack is subtracted from the step count before taking the ceiling.
    """
    quantum = 10 ** (-decs)
    n_steps = math.ceil(val / quantum - 1e-12)
    return n_steps * quantum

def floor_to(val: float, decs: int) -> float:
    """Round `val` down to a multiple of 10**-decs.

    A 1e-12 slack is added to the step count before taking the floor.
    """
    quantum = 10 ** (-decs)
    n_steps = math.floor(val / quantum + 1e-12)
    return n_steps * quantum

# -----------------------
# regex
# -----------------------
# NOTE: these are raw strings, so optional whitespace is written with a single
# backslash (`\s*`). A doubled backslash (`\\s*`) would instead require a
# literal "\" character after the keyword, and no SIMLIB line would match.
#
# peak_token_re / s_prefix_re:
#   group(1) = keyword plus any whitespace that follows it
#   group(2) = the unsigned decimal MJD token (digits, optional ".digits")
# s_prefix_re is anchored, so it only matches lines that start with "S:".
peak_token_re = re.compile(r"(PEAKMJD:\s*)([0-9]+(?:\.[0-9]*)?)")
s_prefix_re   = re.compile(r"^(S:\s*)([0-9]+(?:\.[0-9]*)?)")
# nobs_re: group(1) = the integer that follows "NOBS:".
nobs_re       = re.compile(r"NOBS:\s*([0-9]+)")

def update_peak_line(line: str, new_peak: float) -> str:
    """Return `line` with its PEAKMJD value replaced by `new_peak`.

    The new value is formatted like the token it replaces (see format_like).
    Lines on which peak_token_re finds nothing are returned unchanged.
    """
    hit = peak_token_re.search(line)
    if hit is None:
        return line
    tok_start, tok_end = hit.span(2)
    replacement = format_like(new_peak, hit.group(2))
    return "".join((line[:tok_start], replacement, line[tok_end:]))

def update_s_line(line: str, new_mjd: float) -> str:
    """Return an "S:" `line` with its leading MJD value replaced by `new_mjd`.

    The new value is formatted like the token it replaces (see format_like);
    everything after the old token is kept as is. Lines that s_prefix_re does
    not match are returned unchanged.
    """
    hit = s_prefix_re.match(line)
    if hit is None:
        return line
    prefix, old_token = hit.group(1), hit.group(2)
    tail = line[hit.end(2):]
    return prefix + format_like(new_mjd, old_token) + tail

# -----------------------
# parse blocks
# -----------------------
def parse_blocks(lines):
    """Split `lines` into a prefix and a list of LIBID blocks.

    Returns (prefix, blocks): `prefix` holds every line seen before the first
    line starting with "LIBID:"; each entry of `blocks` is a list of lines that
    begins with a "LIBID:" line and runs up to (not including) the next one.
    """
    header = []
    groups = []
    for text_line in lines:
        if text_line.startswith("LIBID:"):
            # A new block opens here; later lines attach to it.
            groups.append([text_line])
        elif groups:
            groups[-1].append(text_line)
        else:
            header.append(text_line)
    return header, groups

# -----------------------
# main shifting
# -----------------------
# Window bounds. The lower bound is midnight at the start of start_date_str;
# the upper bound is midnight of the day AFTER end_date_str and is exclusive.
start_dt = dt.datetime.fromisoformat(start_date_str)
one_day = dt.timedelta(days=1)
end_dt_excl = dt.datetime.fromisoformat(end_date_str) + one_day
start_mjd = datetime_to_mjd(start_dt)
end_mjd_excl = datetime_to_mjd(end_dt_excl)

# Year into which each PEAKMJD is moved: the "YYYY" field of end_date_str.
target_year = int(end_date_str.partition("-")[0])

# Read the whole input and split it into LIBID blocks, keeping line endings so
# the file can be re-assembled by plain concatenation.
text = input_path.read_text()
lines = text.splitlines(True)
prefix_lines, blocks = parse_blocks(lines)

# Counters for how each block's new peak was chosen / adjusted.
clamp_low = 0
clamp_high = 0
keep_monthday = 0
round_adjust = 0

# Validation accumulators filled in while the blocks are processed.
max_dt_err = 0.0
bad_range, bad_sorted, bad_nobs = [], [], []

# Blocks in output order (shifted, or passed through unchanged).
out_blocks = []

# Process each LIBID block independently: locate its PEAKMJD and "S:" rows,
# choose a new PEAKMJD, and rewrite every MJD token so that its offset from the
# peak is unchanged (up to rounding at the token's own precision).
for blk in blocks:
    # find PEAKMJD line/token (only the first match in the block is used)
    old_peak = None
    peak_idx = None
    peak_tok = None
    for i, line in enumerate(blk):
        m = peak_token_re.search(line)
        if m:
            old_peak = float(m.group(2))
            peak_idx = i
            peak_tok = m.group(2)
            break
    if old_peak is None:
        # no PEAKMJD in this block: pass it through unchanged
        out_blocks.append(blk)
        continue

    # decimals of the original PEAKMJD token; the new peak is quantized to this
    peak_decs = decimals_in_token(peak_tok)

    # collect S lines: their index in the block, the MJD value, and the raw
    # token text (kept so each row can be re-formatted at its own precision)
    s_idxs = []
    mjds = []
    mjd_toks = []
    for i, line in enumerate(blk):
        m = s_prefix_re.match(line)
        if m:
            s_idxs.append(i)
            mjd_toks.append(m.group(2))
            mjds.append(float(m.group(2)))

    if not mjds:
        # no observation rows: pass the block through unchanged (PEAKMJD too)
        out_blocks.append(blk)
        continue

    # relative times to peak (keep unchanged)
    dts = [m - old_peak for m in mjds]
    dt_min, dt_max = min(dts), max(dts)

    # allowed new peak range so that all obs are within [start, end_excl)
    # (the 1e-9 keeps the latest observation strictly below the exclusive bound)
    allowed_min = start_mjd - dt_min
    allowed_max = (end_mjd_excl - 1e-9) - dt_max

    # candidate new peak: same month/day/time but in target_year
    cand_dt = replace_year_safe(mjd_to_datetime(old_peak), target_year)
    cand_mjd = datetime_to_mjd(cand_dt)

    # Clamp the candidate into [allowed_min, allowed_max] and count which case
    # applied. `method` records the case; nothing below reads it.
    if cand_mjd < allowed_min:
        new_peak_raw = allowed_min
        clamp_low += 1
        method = "clamp_low"
    elif cand_mjd > allowed_max:
        new_peak_raw = allowed_max
        clamp_high += 1
        method = "clamp_high"
    else:
        new_peak_raw = cand_mjd
        keep_monthday += 1
        method = "monthday_keep"

    # quantize peak to file precision, then ensure still within allowed range:
    # if rounding pushed it outside, snap to the nearest representable value
    # inside the range instead (ceil at the low edge, floor at the high edge)
    new_peak = round_to(new_peak_raw, peak_decs)
    if new_peak < allowed_min - 1e-12:
        new_peak = ceil_to(allowed_min, peak_decs)
        round_adjust += 1
    if new_peak > allowed_max + 1e-12:
        new_peak = floor_to(allowed_max, peak_decs)
        round_adjust += 1

    # update PEAKMJD line (work on a copy so the parsed input block is untouched)
    blk2 = list(blk)
    blk2[peak_idx] = update_peak_line(blk2[peak_idx], new_peak)

    # update S lines
    for (idx, old_mjd, tok) in zip(s_idxs, mjds, mjd_toks):
        mjd_decs = decimals_in_token(tok)
        dt_old = old_mjd - old_peak
        # new MJD = new peak + original offset, rounded to this row's decimals
        new_mjd = round_to(new_peak + dt_old, mjd_decs)
        blk2[idx] = update_s_line(blk2[idx], new_mjd)

        # range check (1e-6 day tolerance on both bounds)
        if not (start_mjd - 1e-6 <= new_mjd < end_mjd_excl + 1e-6):
            bad_range.append(("MJD", blk2[0].strip(), new_mjd))

        # dt check (file-level): largest change in (MJD - PEAK) over all rows
        dt_new = new_mjd - new_peak
        max_dt_err = max(max_dt_err, abs(dt_new - dt_old))

    # sorted check: re-parse the rewritten rows and flag the block if any MJD
    # is larger than the one on the following S row
    new_mjds = [float(s_prefix_re.match(blk2[i]).group(2)) for i in s_idxs]
    if any(new_mjds[i] > new_mjds[i+1] + 1e-9 for i in range(len(new_mjds)-1)):
        bad_sorted.append(blk2[0].strip())

    # NOBS consistency check: first declared NOBS vs. number of S rows found
    nobs_decl = None
    for line in blk2:
        m = nobs_re.search(line)
        if m and nobs_decl is None:
            nobs_decl = int(m.group(1))
            break
    if nobs_decl is not None and nobs_decl != len(s_idxs):
        bad_nobs.append((blk2[0].strip(), nobs_decl, len(s_idxs)))

    # PEAK range check (same tolerance as the per-row check)
    if not (start_mjd - 1e-6 <= new_peak < end_mjd_excl + 1e-6):
        bad_range.append(("PEAK", blk2[0].strip(), new_peak))

    out_blocks.append(blk2)

# write output
# Re-assemble the file in its original order: the prefix lines first, then
# every block. Lines still carry their own line endings.
out_lines = list(prefix_lines)
for out_block in out_blocks:
    out_lines.extend(out_block)
out_text = "".join(out_lines)
output_path.write_text(out_text)

# checksums
def md5(path: Path) -> str:
    """Return the hex MD5 digest of the file at `path`, read in 1 MiB chunks."""
    chunk_size = 1 << 20
    digest = hashlib.md5()
    with path.open("rb") as stream:
        while True:
            piece = stream.read(chunk_size)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()

# Summary report: paths, the window actually used, how each block's new peak
# was chosen, and the validation results gathered while shifting.
print("=== DONE ===")
print(f"Input : {input_path}")
print(f"Output: {output_path}")
print()
# Window is shown half-open: inclusive start, exclusive end (day after end_date_str).
print(f"Time window: [{start_date_str} 00:00:00, {end_dt_excl.date()} 00:00:00)  (MJD [{start_mjd:.1f}, {end_mjd_excl:.1f}))")
# Total counts every parsed block, including those passed through unchanged;
# the mapping counters only cover blocks that were actually shifted.
print(f"Blocks total: {len(blocks)}")
print(f"Peak mapping: keep_monthday={keep_monthday}, clamp_low={clamp_low}, clamp_high={clamp_high}, round_adjust={round_adjust}")
print()
# Largest change in any observation's offset from its peak.
print(f"max |(MJD-PEAK)_new - (MJD-PEAK)_old| = {max_dt_err:.3e} days")
print(f"bad_range  = {len(bad_range)}")
print(f"bad_sorted = {len(bad_sorted)}")
print(f"bad_nobs   = {len(bad_nobs)}")
print()
# Checksums are computed from the files on disk.
print(f"MD5(input) = {md5(input_path)}")
print(f"MD5(output)= {md5(output_path)}")
