# Remap MobileNetV2 `.safetensors` → Rust naming ⚙️

Notebook ini akan **membaca** file `.safetensors` (beragam skema penamaan: TIMM klasik, atau pola `blocks.i.j.*` seperti yang kamu punya), lalu **mencoba memetakan** kunci-kuncinya ke skema Rust: 

- `base.stem.*`
- `base.ir_{i}_{j}.(expand|dw|project).(weight|bn.*)`
- `base.last.*`

Hasilnya disimpan ke file baru (misal: `mobilenet_v2_1_0_base_rust_mapped.safetensors`).

In [3]:

# === Path sumber & tujuan ===
from pathlib import Path

# Ubah jika perlu
CANDIDATES = [
    "./mobilenet_v2_1_0_imagenet.safetensors",
    "./model.safetensors",
    "./mobilenet_v2_1_0_imagenet.safetensors",
    "./model.safetensors",
]
SRC = None
for c in CANDIDATES:
    if Path(c).exists():
        SRC = c
        break

DST = "./mobilenet_v2_1_0_base_rust_mapped.safetensors"
print("Source:", SRC)
print("Dest  :", DST)
if SRC is None:
    raise SystemExit("⚠ File sumber tidak ditemukan. Set SRC secara manual.")


Source: ./mobilenet_v2_1_0_imagenet.safetensors
Dest  : ./mobilenet_v2_1_0_base_rust_mapped.safetensors


In [4]:

# === Imports ===
try:
    from safetensors.numpy import load_file, save_file
except Exception as e:
    raise SystemExit(
        "Tidak bisa import safetensors.numpy. Install dulu:\n"
        "  pip install safetensors\n"
        f"Detail error: {e}"
    )

import numpy as np
import pandas as pd
import re
from collections import defaultdict


In [5]:

# === Load file ===
sd = load_file(SRC)  # dict-like: key -> np.ndarray
print("Loaded tensors:", len(sd))
keys = sorted(sd.keys())
pd.DataFrame({"key": keys}).head(30)


Loaded tensors: 314


Unnamed: 0,key
0,blocks.0.0.bn1.bias
1,blocks.0.0.bn1.num_batches_tracked
2,blocks.0.0.bn1.running_mean
3,blocks.0.0.bn1.running_var
4,blocks.0.0.bn1.weight
5,blocks.0.0.bn2.bias
6,blocks.0.0.bn2.num_batches_tracked
7,blocks.0.0.bn2.running_mean
8,blocks.0.0.bn2.running_var
9,blocks.0.0.bn2.weight


## Konfigurasi blok (harus match dengan model Rust)


In [6]:

# cfg: (t, c, n, s) — mengikuti implementasi Rust kamu
CFG = [
    (1, 16, 1, 1),
    (6, 24, 2, 2),
    (6, 32, 3, 2),
    (6, 64, 4, 2),
    (6, 96, 3, 1),
    (6,160, 3, 2),
    (6,320, 1, 1),
]

def k_to_ij(k, cfg=CFG):
    kk = int(k)
    for i, (_t,_c,n,_s) in enumerate(cfg):
        if kk < n:
            return i, kk
        kk -= n
    raise ValueError(f"k out of range: {k}")


## Mapper untuk 2 skema umum: TIMM klasik dan `blocks.i.j.*`


In [7]:

def remap(sd):
    out = {}
    misses = []

    # --- STEM / LAST Candidates ---
    # TimM klasik:
    if "conv_stem.weight" in sd: out["base.stem.weight"] = sd["conv_stem.weight"]
    for p in ("weight","bias","running_mean","running_var","num_batches_tracked"):
        k = f"bn1.{p}"
        if k in sd: out[f"base.stem.bn.{p}"] = sd[k]

    if "conv_head.weight" in sd: out["base.last.weight"] = sd["conv_head.weight"]
    for p in ("weight","bias","running_mean","running_var","num_batches_tracked"):
        k = f"bn2.{p}"
        if k in sd: out[f"base.last.bn.{p}"] = sd[k]

    # Skema lain (TorchVision-ish) untuk stem/last (heuristik):
    # cari sesuatu seperti 'features.0.0.weight' / 'first_conv.weight'
    stem_candidates = [k for k in sd.keys() if re.search(r"(conv_stem|first_conv|features\.0\.\d+\.weight|conv1\.weight)", k)]
    if "base.stem.weight" not in out and stem_candidates:
        out["base.stem.weight"] = sd[stem_candidates[0]]
    # cari bn stem
    bn_stem_candidates = [k for k in sd.keys() if re.search(r"(bn1\.|features\.0\.\d+\.bn\.)", k)]
    # tidak selalu bisa di-map sempurna; skip jika tidak ketemu

    # --- BLOCKS: TIMM klasik: blocks.<k>.conv.xxx ---
    for k in list(sd.keys()):
        if k.startswith("blocks.") and ".conv." in k:
            # ex: blocks.4.conv.dw_bn.weight
            parts = k.split(".")
            # ["blocks","<flat_k>","conv", "<piece>", "<rest...>"]
            if len(parts) >= 4 and parts[2] == "conv":
                flat_k = int(parts[1])
                i, j = k_to_ij(flat_k)
                piece = parts[3]
                tail = ".".join(parts[4:]) if len(parts) > 4 else ""

                if piece == "pw":
                    dst = f"base.ir_{i}_{j}.expand.{tail or 'weight'}"
                elif piece == "pw_bn":
                    dst = f"base.ir_{i}_{j}.expand.bn.{tail}"
                elif piece == "dw":
                    dst = f"base.ir_{i}_{j}.dw.{tail or 'weight'}"
                elif piece == "dw_bn":
                    dst = f"base.ir_{i}_{j}.dw.bn.{tail}"
                elif piece == "pw_linear":
                    dst = f"base.ir_{i}_{j}.project.{tail or 'weight'}"
                elif piece == "pw_linear_bn":
                    dst = f"base.ir_{i}_{j}.project.bn.{tail}"
                else:
                    dst = None

                if dst:
                    out[dst] = sd[k]

    # --- BLOCKS: skema `blocks.i.j.*` (yang kamu kirim cuplikannya) ---
    # Pola yang dicari per (i,j):
    # - conv_dw.weight, bn? (bn1/bn2/... yang channel-nya == hidden)
    # - conv_pw(.weight) untuk expand / project (tergantung out_channels)
    # - mungkin ada conv_pw_1.weight untuk pointwise kedua
    # - bn1, bn2, bn3 ... (ditentukan via shape match)
    # Strategi: cocokkan berdasarkan shape.
    # Kumpulkan per blok (i,j)
    by_block = defaultdict(list)
    for k in sd.keys():
        m = re.match(r"blocks\.(\d+)\.(\d+)\.(.+)", k)
        if m:
            i, j, tail = int(m.group(1)), int(m.group(2)), m.group(3)
            by_block[(i,j)].append((k, tail))

    def get_bn_channels(key_prefix):
        # return channels if bn.* has weight; else None
        for p in ("weight","bias","running_mean","running_var"):
            kk = f"{key_prefix}.{p}"
            if kk in sd:
                return sd[kk].shape[0]
        return None

    for (i,j), items in by_block.items():
        # ambil kandidat conv & bn
        convs = {name: k for k, name in items if name.startswith("conv") and k.endswith(".weight")}
        bns   = {name: f"blocks.{i}.{j}.{name}" for _, name in items if name.startswith("bn")}

        # shapes
        dw_key = f"blocks.{i}.{j}.conv_dw.weight"
        dw_out = sd[dw_key].shape[0] if dw_key in sd else None

        # kumpulkan bn channels
        bn_ch = {}
        for bnname, bkey in bns.items():
            ch = get_bn_channels(bkey)
            if ch is not None:
                bn_ch[bnname] = ch

        # pointwise conv candidates
        pw_candidates = [n for n in convs if n.startswith("conv_pw")]
        # mapping sementara
        used_bn = set()

        # --- map depthwise ---
        if dw_out is not None:
            # cari bn dengan channels == dw_out
            bn_dw = None
            # prefer bn2 untuk kasus 2-BN (t=1), atau bn2 setelah bn1 (t>1)
            for pref in ("bn2","bn1","bn3","bn0","bn"):
                if pref in bn_ch and bn_ch[pref] == dw_out:
                    bn_dw = pref
                    break
            # fallback: ambil BN manapun yang match
            if not bn_dw:
                for bnname, ch in bn_ch.items():
                    if ch == dw_out:
                        bn_dw = bnname
                        break
            # set map dw
            out[f"base.ir_{i}_{j}.dw.weight"] = sd[dw_key]
            if bn_dw:
                used_bn.add(bn_dw)
                for p in ("weight","bias","running_mean","running_var","num_batches_tracked"):
                    src = f"blocks.{i}.{j}.{bn_dw}.{p}"
                    if src in sd:
                        out[f"base.ir_{i}_{j}.dw.bn.{p}"] = sd[src]

        # --- map pointwise convs (expand/project) ---
        # deteksi out_ch untuk masing-masing pw
        pw_info = []
        for name in pw_candidates:
            kk = f"blocks.{i}.{j}.{name}.weight"
            if kk in sd:
                pw_info.append((name, sd[kk].shape[0], kk))
        # sort agar stabil
        pw_info.sort()

        # heuristik: yang out_ch == dw_out → expand, sisanya → project
        for name, out_ch, kk in pw_info:
            if dw_out is not None and out_ch == dw_out:
                # expand
                out[f"base.ir_{i}_{j}.expand.weight"] = sd[kk]
                # cari bn utk expand (prefer bn1 yg match ch)
                bn_exp = None
                for pref in ("bn1","bn0","bn3"):
                    if pref in bn_ch and bn_ch.get(pref) == out_ch and pref not in used_bn:
                        bn_exp = pref; break
                if not bn_exp:
                    for bnname, ch in bn_ch.items():
                        if ch == out_ch and bnname not in used_bn:
                            bn_exp = bnname; break
                if bn_exp:
                    used_bn.add(bn_exp)
                    for p in ("weight","bias","running_mean","running_var","num_batches_tracked"):
                        src = f"blocks.{i}.{j}.{bn_exp}.{p}"
                        if src in sd:
                            out[f"base.ir_{i}_{j}.expand.bn.{p}"] = sd[src]
            else:
                # project
                out[f"base.ir_{i}_{j}.project.weight"] = sd[kk]
                # cari bn utk project (channels == out_ch & belum terpakai)
                bn_prj = None
                for pref in ("bn3","bn2","bn1","bn0","bn"):
                    if pref in bn_ch and bn_ch.get(pref) == out_ch and pref not in used_bn:
                        bn_prj = pref; break
                if not bn_prj:
                    for bnname, ch in bn_ch.items():
                        if ch == out_ch and bnname not in used_bn:
                            bn_prj = bnname; break
                if bn_prj:
                    used_bn.add(bn_prj)
                    for p in ("weight","bias","running_mean","running_var","num_batches_tracked"):
                        src = f"blocks.{i}.{j}.{bn_prj}.{p}"
                        if src in sd:
                            out[f"base.ir_{i}_{j}.project.bn.{p}"] = sd[src]

    return out


In [8]:

# === Jalankan remap dan simpan ===
out = remap(sd)
print("Mapped tensors:", len(out))

# Simpan
from safetensors.numpy import save_file
save_file(out, DST)
print("Saved:", DST)

# Ringkasannya
rows = [{"dst_key": k, "shape": tuple(v.shape)} for k, v in sorted(out.items())]
pd.DataFrame(rows).head(30)


Mapped tensors: 29
Saved: ./mobilenet_v2_1_0_base_rust_mapped.safetensors


Unnamed: 0,dst_key,shape
0,base.ir_0_0.dw.weight,"(32, 1, 3, 3)"
1,base.ir_1_0.dw.weight,"(96, 1, 3, 3)"
2,base.ir_1_1.dw.weight,"(144, 1, 3, 3)"
3,base.ir_2_0.dw.weight,"(144, 1, 3, 3)"
4,base.ir_2_1.dw.weight,"(192, 1, 3, 3)"
5,base.ir_2_2.dw.weight,"(192, 1, 3, 3)"
6,base.ir_3_0.dw.weight,"(192, 1, 3, 3)"
7,base.ir_3_1.dw.weight,"(384, 1, 3, 3)"
8,base.ir_3_2.dw.weight,"(384, 1, 3, 3)"
9,base.ir_3_3.dw.weight,"(384, 1, 3, 3)"
