In [1]:
# === 1) Setup + legacy imports (with your pip install), save versions ===

# Install the third-party dependencies used by this notebook. Only transformers
# carries a lower bound; nothing is pinned, so the resolved versions can differ
# between runs (the versions actually in use are written to env_setup.json below).
!pip install -q "transformers>=4.40" timm accelerate torchmetrics albumentations

# --- Imports (legacy-style) ---
import os, io, re, gc, math, json, time, random, shutil, hashlib, glob, pathlib
from dataclasses import dataclass
from typing import List, Tuple, Optional

import numpy as np
from PIL import Image
from tqdm import tqdm
from tqdm.auto import tqdm as tqdm_auto

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Albumentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# HF Transformers (legacy-style)
from transformers import (
    AutoImageProcessor,
    SegformerImageProcessor,
    SegformerForSemanticSegmentation,
)
from transformers import logging as hf_logging

# Every artifact is written to Google Drive, so mount it before anything else.
from google.colab import drive
drive.mount("/content/drive")

# Single root directory for all artifacts of this experiment series.
ROOT_DIR = "/content/drive/MyDrive/human_segmentation_experiments_article0_final_flow"
os.makedirs(ROOT_DIR, exist_ok=True)
print("ROOT_DIR:", ROOT_DIR)

# Seed the Python, NumPy and PyTorch (CPU) generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE == "cuda":
    # Seed all CUDA devices and switch cuDNN from autotuning to deterministic kernels.
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
print("Device:", DEVICE)

# Only show errors from the Hugging Face libraries.
hf_logging.set_verbosity_error()

# --- Save environment/package versions to JSON (for reproducibility audit) ---
def _safe_ver(modname):
    """Return the ``__version__`` of the module named ``modname``.

    Returns ``"unknown"`` when the module imports but has no ``__version__``
    attribute, and ``None`` when importing it (or reading the attribute)
    raises any ``Exception``.
    """
    try:
        module = __import__(modname)
        version = getattr(module, "__version__", "unknown")
    except Exception:
        version = None
    return version

def _pip_freeze():
    """Best-effort ``pip freeze`` for the interpreter running this notebook.

    Returns the stripped output split into lines, or an empty list if
    launching pip raises any ``Exception``.
    """
    try:
        import subprocess
        import sys
        command = [sys.executable, "-m", "pip", "freeze"]
        frozen = subprocess.check_output(command, text=True)
    except Exception:
        return []
    return frozen.strip().splitlines()

# OpenCV is optional: record its version when it imports, otherwise None.
try:
    import cv2
    cv2_ver = cv2.__version__
except Exception:
    cv2_ver = None

# Snapshot of the runtime; key order is the order written to env_setup.json.
ENV_INFO = {
    "python": __import__("sys").version.splitlines()[0],
    "device": DEVICE,
    "seed": SEED,
    "numpy": np.__version__,
    "pillow": _safe_ver("PIL") or _safe_ver("Pillow"),
    "torch": torch.__version__,
}
ENV_INFO.update({
    pkg: _safe_ver(pkg)
    for pkg in ("transformers", "timm", "accelerate", "torchmetrics", "albumentations", "sklearn")
})
ENV_INFO["opencv-python"] = cv2_ver
ENV_INFO["pip_freeze"] = _pip_freeze()

with open(os.path.join(ROOT_DIR, "env_setup.json"), "w") as f:
    json.dump(ENV_INFO, f, indent=2)

# One-line summary of the most relevant versions.
print(
    "Versions →",
    *(
        piece
        for label, key in (
            ("numpy", "numpy"),
            ("| torch", "torch"),
            ("| transformers", "transformers"),
            ("| timm", "timm"),
            ("| accelerate", "accelerate"),
            ("| albumentations", "albumentations"),
            ("| sklearn", "sklearn"),
            ("| opencv", "opencv-python"),
        )
        for piece in (label, ENV_INFO[key])
    ),
)


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/983.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━[0m [32m870.4/983.2 kB[0m [31m27.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[?25hMounted at /content/drive
ROOT_DIR: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow
Device: cuda
Versions → numpy 2.0.2 | torch 2.8.0+cu126 | transformers 4.56.1 | timm 1.0.19 | accelerate 1.10.1 | albumentations 2.0.8 | sklearn 1.6.1 | opencv 4.12.0


In [2]:
# === 2) Loading data (Control) → split into JSON → datasets and loaders ===
import os, tarfile, json, random, re, glob
from pathlib import Path
from typing import List, Tuple
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from transformers import AutoImageProcessor

# --- Paths (Supervisely archive and where to unpack it) ---
DATA_AR_PATH = "/content/drive/MyDrive/human_segmentation_vit_experements/data/supervisely_dataset.tar.gz"
EXTRACT_DIR  = "/content/ds_extracted"           # isolated unpacking directory
RAW_ROOT     = Path(EXTRACT_DIR)                  # images/masks are searched for under this root
SPLIT_JSON   = os.path.join(ROOT_DIR, "train_val_split.json")

IMG_SIZE = 512
IGNORE_INDEX = 255
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# --- Unpack the archive ---
# Extraction is skipped whenever EXTRACT_DIR already contains anything, so a
# partially extracted directory is treated as complete.
os.makedirs(EXTRACT_DIR, exist_ok=True)
if not any(Path(EXTRACT_DIR).iterdir()):
    assert os.path.isfile(DATA_AR_PATH), f"Archive not found: {DATA_AR_PATH}"
    with tarfile.open(DATA_AR_PATH, "r:*") as tar:
        if hasattr(tarfile, "data_filter"):
            # The "data" filter rejects members that would land outside
            # EXTRACT_DIR (absolute paths, "..", escaping links) and avoids the
            # warning newer Python versions emit for unfiltered extraction.
            tar.extractall(EXTRACT_DIR, filter="data")
        else:
            # Older Python without extraction filters.
            tar.extractall(EXTRACT_DIR)
    print("Распаковано в:", EXTRACT_DIR)
else:
    print("Архив уже распакован →", EXTRACT_DIR)

# --- 2.2 Search for folders with images and masks (flexibly for Supervisely) ---
def _find_candidate_dirs(root: Path, name_patterns):
    """Return directories under ``root`` (recursively) whose name matches a pattern.

    A directory is kept when its lower-cased name fully matches
    (``re.fullmatch``) at least one regex in ``name_patterns``.
    """
    return [
        entry
        for entry in root.rglob("*")
        if entry.is_dir()
        and any(re.fullmatch(pattern, entry.name.lower()) for pattern in name_patterns)
    ]

# Directory names (lower-cased, full match) that are taken to hold images / masks.
img_dir_candidates = _find_candidate_dirs(
    root=RAW_ROOT,
    name_patterns=[r"(img|imgs|images|image|jpegimages)"],
)
mask_dir_candidates = _find_candidate_dirs(
    root=RAW_ROOT,
    name_patterns=[r"(ann|anns|masks|mask|segmentationclass|labels?)"],
)

def _has_exts(p: Path, exts):
    """Return True if some entry under ``p`` (recursively) ends with a suffix in ``exts``.

    ``Path.rglob`` returns a lazy generator, and a generator object is truthy
    even when it yields nothing, so each one is advanced by a single step to
    find out whether a match actually exists.
    """
    return any(next(p.rglob(f"*{e}"), None) is not None for e in exts)

IMG_EXTS = [".jpg", ".jpeg", ".png", ".bmp", ".webp"]
MSK_EXTS = [".png", ".bmp"]

# Fallback when no directory matched by name: keep every directory under
# RAW_ROOT for which _has_exts() is true for the relevant extensions.
img_dir_candidates = img_dir_candidates or [
    d for d in RAW_ROOT.rglob("*") if d.is_dir() and _has_exts(d, IMG_EXTS)
]
mask_dir_candidates = mask_dir_candidates or [
    d for d in RAW_ROOT.rglob("*") if d.is_dir() and _has_exts(d, MSK_EXTS)
]

assert img_dir_candidates, "No image directories found."
assert mask_dir_candidates, "No directories with masks found."

def _collect_files(dirs, exts):
    """Recursively glob every directory in ``dirs`` for each suffix in ``exts``.

    Results are concatenated directory by directory, suffix by suffix. A file
    reachable from several directories in ``dirs`` appears once per directory.
    Each entry of ``dirs`` must support ``/`` with a string (e.g. ``Path``).
    """
    return [
        hit
        for folder in dirs
        for suffix in exts
        for hit in glob.glob(str(folder / f"**/*{suffix}"), recursive=True)
    ]

# Sort the file lists so that, when several files share a stem, the one that
# is kept does not depend on filesystem listing order.
img_files = sorted(_collect_files(img_dir_candidates, IMG_EXTS))
msk_files = sorted(_collect_files(mask_dir_candidates, MSK_EXTS))
assert img_files and msk_files, "No image/mask files found."

def _stem(path: str) -> str:
    """Return the file name of ``path`` without its final suffix."""
    return Path(path).stem

def _index_by_stem(paths):
    """Map each stem to the first path in ``paths`` carrying it; later duplicates are ignored."""
    index = {}
    for path in paths:
        index.setdefault(_stem(path), path)
    return index

# An image and a mask form a pair when their stems are equal.
img_map = _index_by_stem(img_files)
msk_map = _index_by_stem(msk_files)
stems = sorted(set(img_map.keys()) & set(msk_map.keys()))
assert stems, "The intersection of image names and masks is empty."

pairs_all: List[Tuple[str, str]] = [(img_map[s], msk_map[s]) for s in stems]
print(f"Image/mask pairs found: {len(pairs_all)}")

# --- 2.3 Reuse the split stored in SPLIT_JSON, or create and store a new one ---
if not os.path.isfile(SPLIT_JSON):
    # No stored split: shuffle pairs_all in place (module-level `random`,
    # seeded above) and cut it 90% / 10%.
    random.shuffle(pairs_all)
    split = int(0.9 * len(pairs_all))
    train_pairs = pairs_all[:split]
    val_pairs = pairs_all[split:]
    with open(SPLIT_JSON, "w") as f:
        json.dump({"train": train_pairs, "val": val_pairs}, f, indent=2)
    print(f"Split created and saved → train={len(train_pairs)} | val={len(val_pairs)}")
else:
    with open(SPLIT_JSON, "r") as f:
        split_payload = json.load(f)
    # JSON stores each pair as a 2-element list; turn them back into tuples.
    train_pairs = list(map(tuple, split_payload["train"]))
    val_pairs = list(map(tuple, split_payload["val"]))
    print(f"Loaded split from JSON → train={len(train_pairs)} | val={len(val_pairs)}")
print("SPLIT_JSON:", SPLIT_JSON)

# --- 2.4 Augmentations (like teacher; student can reuse) ---
import inspect

# Albumentations 2.x names Affine's padding arguments border_mode / fill /
# fill_mask; older releases use mode / cval / cval_mask. Pick whichever set
# the installed version accepts, so the values are applied rather than
# rejected as unknown arguments.
if "fill" in inspect.signature(A.Affine.__init__).parameters:
    _AFFINE_PAD_KWARGS = dict(border_mode=0, fill=0, fill_mask=0)   # 0 == cv2.BORDER_CONSTANT
else:
    _AFFINE_PAD_KWARGS = dict(mode=0, cval=0, cval_mask=0)

train_tf = A.Compose([
    A.LongestMaxSize(max_size=IMG_SIZE, interpolation=1),                 # cv2.INTER_LINEAR
    A.PadIfNeeded(min_height=IMG_SIZE, min_width=IMG_SIZE, border_mode=0),# cv2.BORDER_CONSTANT
    A.RandomCrop(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.ColorJitter(p=0.3, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    # Small random shift / zoom / rotation; uncovered pixels are filled with 0
    # in both the image and the mask.
    A.Affine(translate_percent=(-0.05, 0.05), scale=(0.9, 1.1), rotate=(-15, 15),
             fit_output=False, p=0.5, **_AFFINE_PAD_KWARGS),
], is_check_shapes=False)

# Validation: the same resize + pad, then a center crop; no random transforms.
val_tf = A.Compose([
    A.LongestMaxSize(max_size=IMG_SIZE, interpolation=1),
    A.PadIfNeeded(min_height=IMG_SIZE, min_width=IMG_SIZE, border_mode=0),
    A.CenterCrop(IMG_SIZE, IMG_SIZE),
], is_check_shapes=False)

# --- 2.5 Dataset (mask binarization) ---
class SegPairsDataset(Dataset):
    """Binary-segmentation dataset built from ``(image_path, mask_path)`` pairs.

    Each item is a dict with:
      * ``pixel_values`` - the image processor's output for the image, with the
        batch dimension squeezed away;
      * ``labels`` - an int64 ``[H, W]`` tensor holding 0 or 1.

    Args:
        pairs: list of ``(image_path, mask_path)`` tuples.
        image_processor: callable used as
            ``image_processor(images=..., return_tensors="pt", do_resize=True, size=...)``.
        size: side length requested from the processor (height == width).
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"``.
        ignore_index: stored on the instance; no label produced here ever takes
            this value (see ``_binarize_mask``).

    Only the image goes through the processor's resize. The mask keeps whatever
    spatial size ``transform`` produced (or its original size when ``transform``
    is None).
    """

    def __init__(self, pairs: List[Tuple[str,str]], image_processor, size=512, transform=None, ignore_index=255):
        self.pairs = pairs
        self.proc = image_processor
        self.size = size
        self.transform = transform
        self.ignore_index = ignore_index

    def __len__(self): return len(self.pairs)

    def _binarize_mask(self, mask_np: np.ndarray) -> np.ndarray:
        """Map a raw mask to uint8 {0, 1}: every non-zero pixel becomes 1.

        This covers {0, 255} masks, {0, 1} masks and any other encoding in one
        vectorized comparison. A pixel equal to ``ignore_index`` is non-zero
        and therefore also becomes 1.
        """
        return (mask_np > 0).astype(np.uint8)

    def __getitem__(self, idx):
        img_path, mask_path = self.pairs[idx]
        img  = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path)
        if mask.mode != "L": mask = mask.convert("L")

        img_np  = np.array(img)
        mask_np = np.array(mask, dtype=np.uint8)

        # Joint image/mask augmentation (if configured).
        if self.transform is not None:
            out = self.transform(image=img_np, mask=mask_np)
            img_np, mask_np = out["image"], out["mask"]

        # Binarize after augmentation, so fill values written by the transform
        # go through the same mapping as the original mask pixels.
        mask_np = self._binarize_mask(mask_np)

        # The processor turns the image into the tensor the model consumes.
        enc = self.proc(
            images=Image.fromarray(img_np),
            return_tensors="pt",
            do_resize=True,
            size={"height": self.size, "width": self.size},
        )
        pixel_values = enc["pixel_values"].squeeze(0)  # drop the batch dimension
        labels = torch.from_numpy(mask_np).long()      # [H,W]
        return {"pixel_values": pixel_values, "labels": labels}

# --- 2.6 Image processor shared by both datasets ---
processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

# --- 2.7 Datasets: identical settings apart from the pair list and the transform ---
train_ds = SegPairsDataset(
    pairs=train_pairs,
    image_processor=processor,
    size=IMG_SIZE,
    transform=train_tf,
    ignore_index=IGNORE_INDEX,
)
val_ds = SegPairsDataset(
    pairs=val_pairs,
    image_processor=processor,
    size=IMG_SIZE,
    transform=val_tf,
    ignore_index=IGNORE_INDEX,
)

def seg_collate(batch):
    """Stack each sample's ``pixel_values`` and ``labels`` along a new leading batch dimension."""
    return {
        key: torch.stack([sample[key] for sample in batch], dim=0)
        for key in ("pixel_values", "labels")
    }

NUM_WORKERS = 0  # load in the main process; more reliable in Colab
BATCH_SIZE  = 4 if torch.cuda.is_available() else 2

# Training loader: shuffled, incomplete final batch dropped.
train_loader_student = DataLoader(
    dataset=train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=NUM_WORKERS,
    pin_memory=False,
    collate_fn=seg_collate,
)
# Validation loader: fixed order, every sample kept.
val_loader_student = DataLoader(
    dataset=val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
    collate_fn=seg_collate,
)

print(f"Done. train_batches={len(train_loader_student)} | val_batches={len(val_loader_student)}")
print("Example of a couple:", train_pairs[0])

# Shared mIoU helper for the binary (background=0 / person=1) setup.
def compute_miou_from_logits(logits: torch.Tensor, labels_np: np.ndarray, ignore_index=255):
    """Compute mean IoU and per-class IoU from logits and integer labels.

    ``logits`` is bilinearly resized to the last two dimensions of
    ``labels_np`` and arg-maxed over dim 1. Pixels whose label equals
    ``ignore_index`` are excluded. A class whose union is empty gets IoU 0.0.

    Returns:
        ``(mean_iou, iou_class1, iou_class0)`` as Python floats.
    """
    height, width = labels_np.shape[-2], labels_np.shape[-1]
    with torch.no_grad():
        resized = F.interpolate(logits, size=(height, width), mode="bilinear", align_corners=False)
        preds = resized.argmax(dim=1).cpu().numpy()

    keep = (labels_np != ignore_index)

    def _count(pred_cls, true_cls):
        # Number of valid pixels predicted as pred_cls whose label is true_cls.
        return np.logical_and.reduce((preds == pred_cls, labels_np == true_cls, keep)).sum()

    tp, fp, fn = _count(1, 1), _count(1, 0), _count(0, 1)
    tp0, fp0, fn0 = _count(0, 0), _count(0, 1), _count(1, 0)

    union1 = tp + fp + fn
    union0 = tp0 + fp0 + fn0
    iou1 = tp / union1 if union1 > 0 else 0.0
    iou0 = tp0 / union0 if union0 > 0 else 0.0
    return float((iou0+iou1)/2.0), float(iou1), float(iou0)



  tar.extractall(EXTRACT_DIR)


Распаковано в: /content/ds_extracted
Image/mask pairs found: 5711
Loaded split from JSON → train=5139 | val=572
SPLIT_JSON: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/train_val_split.json


  A.Affine(translate_percent=(-0.05, 0.05), scale=(0.9, 1.1), rotate=(-15, 15),


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/271 [00:00<?, ?B/s]

Done. train_batches=1284 | val_batches=143
Example of a couple: ('/content/ds_extracted/sp_dataset/images/pexels-photo-593561.jpeg', '/content/ds_extracted/sp_dataset/masks/pexels-photo-593561.png')


  image_processor = cls(**image_processor_dict)


In [4]:
# === 3) Basic directories for experiments (without compat/alias) ===
import os, json


# One sub-directory of ROOT_DIR per experiment phase.
TEACHER_RUNS_DIR = os.path.join(ROOT_DIR, "teacher_runs")
STUDENT_P1_ROOT  = os.path.join(ROOT_DIR, "student_phase1_alora_delayed")
STUDENT_P2_ROOT  = os.path.join(ROOT_DIR, "student_phase2_kr_ema")

for _experiment_dir in (TEACHER_RUNS_DIR, STUDENT_P1_ROOT, STUDENT_P2_ROOT):
    os.makedirs(_experiment_dir, exist_ok=True)

# Write the resolved locations to paths_meta.json.
META_PATHS = os.path.join(ROOT_DIR, "paths_meta.json")
_paths_payload = {
    "TEACHER_RUNS_DIR": TEACHER_RUNS_DIR,
    "STUDENT_P1_ROOT":  STUDENT_P1_ROOT,
    "STUDENT_P2_ROOT":  STUDENT_P2_ROOT,
    "SPLIT_JSON":       os.path.join(ROOT_DIR, "train_val_split.json"),
}
with open(META_PATHS, "w") as f:
    json.dump(_paths_payload, f, indent=2)

print("Catalogs have been created.")
print("TEACHER_RUNS_DIR:", TEACHER_RUNS_DIR)
print("STUDENT_P1_ROOT :", STUDENT_P1_ROOT)
print("STUDENT_P2_ROOT :", STUDENT_P2_ROOT)
print("META_PATHS     →", META_PATHS)


Catalogs have been created.
TEACHER_RUNS_DIR: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/teacher_runs
STUDENT_P1_ROOT : /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed
STUDENT_P2_ROOT : /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema
META_PATHS     → /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/paths_meta.json


In [None]:
# === 4) Teacher training (SegFormer-B0, binary) ===
import os, time, json, numpy as np, torch, torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation, TrainingArguments, Trainer, EarlyStoppingCallback

# --- labels / constants ---
id2label = {0: "background", 1: "person"}
label2id = {v: k for k, v in id2label.items()}
IGNORE_INDEX = 255
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- model ---
# The checkpoint's classification head has a different number of classes, so
# mismatched head weights are re-initialised for the 2-class task.
MODEL_NAME = "nvidia/segformer-b0-finetuned-ade-512-512"
teacher_model = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_NAME, token=None,
    num_labels=2, id2label=id2label, label2id=label2id,
    ignore_mismatched_sizes=True,
)
# SegFormer's built-in loss reads `semantic_loss_ignore_index` from the config;
# a plain `ignore_index` attribute is not consulted by it.
teacher_model.config.semantic_loss_ignore_index = IGNORE_INDEX
# Kept so that saved configs still carry the attribute earlier runs wrote.
teacher_model.config.ignore_index = IGNORE_INDEX
teacher_model.to(DEVICE)

# --- collator ---
def seg_collate(batch):
    """Collate dataset samples into one batch dict of stacked tensors."""
    images, masks = [], []
    for sample in batch:
        images.append(sample["pixel_values"])
        masks.append(sample["labels"])
    return {
        "pixel_values": torch.stack(images, dim=0),
        "labels": torch.stack(masks, dim=0),
    }

# --- metrics (mIoU, IoU per class, F1(person), pixel acc) ---
def compute_metrics(eval_pred):
    """Compute binary-segmentation metrics from ``(logits, labels)``.

    Args:
        eval_pred: pair ``(logits, labels)`` of NumPy arrays. ``logits`` may
            also be a list/tuple, in which case its first element is used.
            Logits are laid out ``[N, C, h, w]`` and labels ``[N, H, W]``.

    Returns:
        dict with float values for ``miou``, ``iou_person``, ``iou_bg``,
        ``f1_person`` and ``pixel_acc``. Pixels labelled ``IGNORE_INDEX``
        (module-level constant) are excluded; any ratio with an empty
        denominator is reported as 0.0.
    """
    logits, labels = eval_pred
    if isinstance(logits, (list, tuple)):
        logits = logits[0]
    labels = labels.astype(np.int64)
    H, W = labels.shape[-2], labels.shape[-1]

    with torch.no_grad():
        # Cast to float32 first: logits may arrive in half precision, which
        # CPU bilinear interpolation does not reliably support.
        t = torch.from_numpy(np.asarray(logits)).float()
        t = F.interpolate(t, size=(H, W), mode="bilinear", align_corners=False)
        preds = t.argmax(dim=1).cpu().numpy()

    valid = (labels != IGNORE_INDEX)

    # Confusion counts with "person" (1) as the positive class.
    tp = np.logical_and.reduce((preds == 1, labels == 1, valid)).sum()
    fp = np.logical_and.reduce((preds == 1, labels == 0, valid)).sum()
    fn = np.logical_and.reduce((preds == 0, labels == 1, valid)).sum()

    # Same counts with "background" (0) as the positive class.
    tp_bg = np.logical_and.reduce((preds == 0, labels == 0, valid)).sum()
    fp_bg = np.logical_and.reduce((preds == 0, labels == 1, valid)).sum()
    fn_bg = np.logical_and.reduce((preds == 1, labels == 0, valid)).sum()

    iou_person = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0.0
    iou_bg     = tp_bg / (tp_bg + fp_bg + fn_bg) if (tp_bg + fp_bg + fn_bg) > 0 else 0.0
    miou       = (iou_person + iou_bg) / 2.0

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall    = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_person = 2*precision*recall / (precision+recall) if (precision+recall) > 0 else 0.0
    pixel_acc = (np.logical_and(preds == labels, valid).sum() / valid.sum()) if valid.sum() > 0 else 0.0

    return {
        "miou": float(miou),
        "iou_person": float(iou_person),
        "iou_bg": float(iou_bg),
        "f1_person": float(f1_person),
        "pixel_acc": float(pixel_acc),
    }

# --- training args ---
# Each run gets its own timestamped directory under TEACHER_RUNS_DIR.
RUN_NAME = time.strftime("segformer_b0_binary_teacher_%Y%m%d_%H%M%S")
SAVE_DIR = os.path.join(TEACHER_RUNS_DIR, RUN_NAME)
os.makedirs(SAVE_DIR, exist_ok=True)

BS = 12 if DEVICE == "cuda" else 2
EPOCHS = 20
LR = 4e-5

args = TrainingArguments(
    output_dir=os.path.join(SAVE_DIR, "checkpoints"),
    learning_rate=LR,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BS,
    per_device_eval_batch_size=BS,
    gradient_accumulation_steps=2,        # effective batch size = 2 * BS per device
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="miou",         # key returned by compute_metrics
    greater_is_better=True,
    # Mixed precision only on CUDA: bf16 when the device supports it, fp16
    # otherwise. Both flags are gated on DEVICE so the CUDA capability query
    # is never made on a CPU-only runtime.
    fp16=(DEVICE == "cuda" and not torch.cuda.is_bf16_supported()),
    bf16=(DEVICE == "cuda" and torch.cuda.is_bf16_supported()),
    logging_steps=50,
    report_to="none",
    dataloader_num_workers=2,
    dataloader_pin_memory=True,
    max_grad_norm=1.0,
)

# --- trainer ---
import inspect

# Newer transformers releases take the image processor as `processing_class`
# and deprecate `tokenizer`; older ones only know `tokenizer`. Use whichever
# keyword the installed Trainer declares.
_trainer_init_params = inspect.signature(Trainer.__init__).parameters
_processor_kwarg = "processing_class" if "processing_class" in _trainer_init_params else "tokenizer"

teacher_trainer = Trainer(
    model=teacher_model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=seg_collate,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
    **{_processor_kwarg: processor},
)

print("Starting training teacher…")
teacher_trainer.train()

# Directory for the final export of the teacher (model + processor).
TEACHER_DIR = os.path.join(SAVE_DIR, "teacher_binary_segformer_b0")

# If the trainer recorded no best checkpoint, fall back to TEACHER_DIR.
# args.output_dir is only the parent folder of the checkpoint-* directories
# and cannot itself be loaded with from_pretrained().
best_cp = teacher_trainer.state.best_model_checkpoint or TEACHER_DIR
print("Best checkpoint (by mIoU):", best_cp)

# --- final save ---
teacher_trainer.save_model(TEACHER_DIR)
processor.save_pretrained(TEACHER_DIR)
print("Teacher saved in:", TEACHER_DIR)

# --- metadata for the following phases ---
BEST_TEACHER_META = os.path.join(ROOT_DIR, "best_teacher.json")
with open(BEST_TEACHER_META, "w") as f:
    json.dump({"best_cp": best_cp, "teacher_dir": TEACHER_DIR}, f, indent=2)
print("BEST_TEACHER_META →", BEST_TEACHER_META)


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/15.0M [00:00<?, ?B/s]

  teacher_trainer = Trainer(


Starting training teacher…
{'loss': 0.6366, 'grad_norm': 2.5181469917297363, 'learning_rate': 9.116279069767443e-06, 'epoch': 0.2331002331002331}
{'loss': 0.5175, 'grad_norm': 1.9426133632659912, 'learning_rate': 1.8418604651162793e-05, 'epoch': 0.4662004662004662}
{'loss': 0.36, 'grad_norm': 1.837462067604065, 'learning_rate': 2.7720930232558143e-05, 'epoch': 0.6993006993006993}
{'loss': 0.2527, 'grad_norm': 1.446797490119934, 'learning_rate': 3.702325581395349e-05, 'epoch': 0.9324009324009324}
{'eval_loss': 0.17239660024642944, 'eval_miou': 0.8720695250747997, 'eval_iou_person': 0.7906738755099967, 'eval_iou_bg': 0.9534651746396028, 'eval_f1_person': 0.883102039208348, 'eval_pixel_acc': 0.9604219289926382, 'eval_runtime': 73.1, 'eval_samples_per_second': 7.825, 'eval_steps_per_second': 0.657, 'epoch': 1.0}
{'loss': 0.19, 'grad_norm': 1.5755938291549683, 'learning_rate': 3.999316326552322e-05, 'epoch': 1.1631701631701632}
{'loss': 0.1597, 'grad_norm': 1.5375727415084839, 'learning_rat

In [None]:
# === 5) Student Phase-1 (Fused-Atom + L0-HC, delayed sparsity) ===
import os, math, json, time, gc, random
from dataclasses import dataclass
from typing import Optional
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from transformers import SegformerForSemanticSegmentation, AutoImageProcessor

# ---- preconditions: train_loader_student / val_loader_student and best_teacher.json ----
assert 'train_loader_student' in globals() and 'val_loader_student' in globals(), "Нет train/val лоадеров"
with open(os.path.join(ROOT_DIR, "best_teacher.json"), "r") as f:
    best_meta = json.load(f)
best_cp = best_meta["best_cp"]

# ---- clearing memory/GPU ----
# Drop `student` / `teacher` left over from an earlier run of this cell so the
# models they reference can be freed. Objects bound to other names (for
# example teacher_model / teacher_trainer from the teacher cell) are untouched.
for v in ['student','teacher']:
    globals().pop(v, None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# ---- reproducibility ----
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Query bf16 support only when a CUDA device is actually in use.
BF16 = (DEVICE == "cuda") and torch.cuda.is_bf16_supported()
AMP_DTYPE = torch.bfloat16 if BF16 else torch.float16
IGNORE_INDEX = 255

# ====== config ======
@dataclass
class StudentCfg:
    """Hyper-parameters for student phase 1.

    Defaults are evaluated when the class is defined, so ``teacher_ckpt`` and
    ``out_root`` capture the values of ``best_cp`` and ``STUDENT_P1_ROOT`` at
    that moment. How most fields are consumed is not visible in this part of
    the notebook; comments marked "presumably" are to be confirmed against
    the training loop.
    """
    teacher_ckpt: str = best_cp          # checkpoint both teacher and student are loaded from
    out_root: str = STUDENT_P1_ROOT      # output directory, created right after instantiation
    # atoms: passed to inject_fused_atoms() for the q/k/v, attention-output and MLP linears
    n_atoms_attn_qkv: int = 8
    n_atoms_attn_out: int = 12
    n_atoms_ffn: int = 8
    rank_per_atom: int = 4
    # learning (presumably optimizer / loop settings - confirm in the training loop)
    epochs: int = 30
    lr_atoms: float = 2e-4
    lr_gates: float = 1e-3
    weight_decay: float = 1e-3
    warmup_ratio: float = 0.05
    grad_clip: float = 1.0
    heartbeat_steps: int = 200
    init_log_alpha: float = 1.5   # initial gate log_alpha; with the gate defaults prob_open() is about 0.96
    # KD (names match the kd_* arguments of make_delayed_schedules / T of kd_loss_seg_masked - presumably fed there)
    kd_alpha: float = 0.70
    kd_temp: float  = 4.0
    kd_warmup_frac: float = 0.10
    # sparsity: presumably no sparsity pressure before epoch `sparsity_start_epoch`
    # (names match the arguments of make_delayed_schedules)
    sparsity_start_epoch: int = 5
    rho_end: float = 0.65
    rho_portion_after_start: float = 0.50
    l0_lambda_final: float = 2e-3
    l0_portion_after_start: float = 0.60
    # Resume (optional)
    resume_dir: Optional[str] = None

cfg = StudentCfg()
os.makedirs(cfg.out_root, exist_ok=True)

# ====== HardConcrete gate (one gate per atom, vectorized over K atoms) ======
class HardConcreteVectorGate(nn.Module):
    """A vector of ``K`` independent hard-concrete gates.

    The learnable ``log_alpha`` holds one logit per gate. ``beta`` is the
    temperature and ``(gamma, zeta)`` the stretch interval; ``gamma`` must be
    negative and ``zeta`` positive because ``log(-gamma / zeta)`` is taken.
    """

    def __init__(self, K, init_log_alpha=1.5, beta=2./3., gamma=-0.1, zeta=1.1, device=None, dtype=torch.float32):
        super().__init__()
        initial_logits = torch.full((K,), float(init_log_alpha), device=device, dtype=dtype)
        self.log_alpha = nn.Parameter(initial_logits)
        self.beta = beta
        self.gamma = gamma
        self.zeta = zeta
        # Constant log(-gamma / zeta) used by prob_open(); kept as a buffer so
        # it follows the module across devices.
        log_ratio = math.log(-gamma / zeta)
        self.register_buffer("hc_log_term", torch.tensor(log_ratio, dtype=dtype, device=device))

    def _stretch(self, s):
        """Map ``s`` from (0, 1) onto the interval (gamma, zeta)."""
        return s * (self.zeta - self.gamma) + self.gamma

    def prob_open(self):
        """Per-gate value ``sigmoid(log_alpha - beta * log(-gamma / zeta))``."""
        return torch.sigmoid(self.log_alpha - self.beta * self.hc_log_term)

    def sample(self, training: bool):
        """Return gate values in [0, 1].

        With ``training`` true, draws a stochastic relaxed sample; otherwise
        returns a hard 0/1 vector obtained by thresholding ``prob_open()`` at 0.5.
        """
        if not training:
            return (self.prob_open() > 0.5).to(self.log_alpha.dtype)
        u = torch.rand_like(self.log_alpha)
        logistic_noise = torch.log(u) - torch.log(1-u)
        relaxed = torch.sigmoid((logistic_noise + self.log_alpha) / self.beta)
        return torch.clamp(self._stretch(relaxed), 0, 1)

# ====== Fused-Atom Linear (no Python loop over atoms) ======
class FusedAtomLinear(nn.Module):
    """Low-rank linear layer made of ``n_atoms`` gated rank-``rank`` atoms.

    Computes ``y = A @ (z_rep * (B @ x))`` (+ ``extra_bias`` if present), where
    ``B`` is ``[R, in_f]``, ``A`` is ``[out_f, R]``, ``R = n_atoms * rank`` and
    ``z_rep`` repeats each atom's gate value ``rank`` times, so one gate scales
    one contiguous group of ``rank`` latent channels.

    Gate values per mode:
      * training and ``use_expected_gates``: ``gate.prob_open()``;
      * training otherwise: a stochastic ``gate.sample(True)``;
      * eval: the hard 0/1 ``gate.sample(False)``.
    While training, if ``target_open_frac`` is set, only the
    ``ceil(target_open_frac * n_atoms)`` largest gate values (at least one)
    are kept and the rest are zeroed.
    """

    def __init__(self, in_f, out_f, n_atoms=8, rank=4, add_bias=False, device=None, dtype=torch.float32, init_log_alpha=1.5):
        super().__init__()
        self.in_f = in_f
        self.out_f = out_f
        self.n_atoms = n_atoms
        self.rank = rank
        self.R = n_atoms * rank
        self.block_size = rank
        factory = dict(device=device, dtype=dtype)
        self.A = nn.Parameter(torch.zeros(out_f, self.R, **factory))
        self.B = nn.Parameter(torch.zeros(self.R,  in_f, **factory))
        self.gate = HardConcreteVectorGate(n_atoms, init_log_alpha=init_log_alpha, device=device, dtype=dtype)
        if add_bias:
            self.extra_bias = nn.Parameter(torch.zeros(out_f, **factory))
        else:
            self.extra_bias = None
        self.use_expected_gates = True
        self.target_open_frac: Optional[float] = None  # fraction of atoms kept by top-k during training

    def forward(self, x):  # x:[*, in]
        if self.use_expected_gates and self.training:
            z = self.gate.prob_open()                 # [K]
        else:
            z = self.gate.sample(self.training)       # [K]
        z = z.to(dtype=self.A.dtype, device=self.A.device)

        if self.training and (self.target_open_frac is not None):
            # Keep only the strongest `keep` atoms for this step.
            keep = max(1, int(math.ceil(self.target_open_frac * self.n_atoms)))
            winners = torch.topk(z, keep, sorted=False).indices
            keep_mask = torch.zeros_like(z).scatter_(0, winners, 1.0)
            z = z * keep_mask

        per_channel_gate = z.repeat_interleave(self.block_size)   # [R]
        latent = F.linear(x, self.B) * per_channel_gate           # [*, R], gated per rank group
        y = F.linear(latent, self.A)                              # [*, out]
        if self.extra_bias is not None:
            y = y + self.extra_bias
        return y

@torch.no_grad()
def init_atoms_from_pretrained_linear(atom_layer: FusedAtomLinear, pretrained_linear: nn.Linear):
    """Initialise ``atom_layer.A`` / ``atom_layer.B`` from a pretrained weight matrix.

    Greedy scheme: for each of the ``K`` atoms, take an SVD of the current
    residual, keep its leading ``r`` singular triplets as one atom (splitting
    ``sqrt(S)`` evenly between the ``A`` and ``B`` factors) and subtract that
    atom's contribution from the residual.

    Only ``pretrained_linear.weight`` is read; its bias is not used here.
    Side effects: overwrites ``atom_layer.A`` and ``atom_layer.B`` in place and
    resets every gate logit to ``cfg.init_log_alpha`` - ``cfg`` is the
    module-level config object, not a parameter of this function.
    """
    # Work on a CPU float32 copy of the weight, whatever the layer's device/dtype.
    W = pretrained_linear.weight.detach().cpu().float()
    out_f, in_f = W.shape
    K, r = atom_layer.n_atoms, atom_layer.rank
    R = K * r
    A_full = torch.zeros(out_f, R)
    B_full = torch.zeros(R, in_f)
    # Clone so the in-place subtraction below cannot touch W (which may share
    # storage with the pretrained weight when it is already CPU float32).
    R_res = W.clone()
    col = 0  # number of latent columns filled so far
    for _ in range(K):
        U, S, Vh = torch.linalg.svd(R_res, full_matrices=False)
        rr = min(r, S.shape[0])
        # Stop once the residual is (numerically) zero; remaining atoms stay zero.
        if rr == 0 or float(S[0]) < 1e-8: break
        U_i, S_i, V_i = U[:, :rr], S[:rr], Vh[:rr, :]
        s = torch.sqrt(S_i)
        A_i = U_i * s.unsqueeze(0)   # scale columns of U by sqrt(S)
        B_i = s.unsqueeze(1) * V_i   # scale rows of Vh by sqrt(S)
        A_full[:, col:col+rr] = A_i
        B_full[col:col+rr, :] = B_i
        col += rr
        R_res -= (A_i @ B_i)
    # Write the filled part into the layer; everything beyond `col` is left at zero.
    atom_layer.A.data.zero_(); atom_layer.B.data.zero_()
    atom_layer.A.data[:, :col] = A_full[:, :col].to(atom_layer.A.dtype).to(atom_layer.A.device)
    atom_layer.B.data[:col, :] = B_full[:col, :].to(atom_layer.B.dtype).to(atom_layer.B.device)
    # NOTE(review): this overrides whatever init_log_alpha the layer was built with.
    atom_layer.gate.log_alpha.data.fill_(cfg.init_log_alpha)

def replace_linear_with_fused_atoms(module, name, n_atoms, rank, keep_bias=False):
    """Replace ``module.<name>`` with a ``FusedAtomLinear`` if it is an ``nn.Linear``.

    The replacement is created on the old layer's device/dtype and initialised
    from its weight via ``init_atoms_from_pretrained_linear``.

    Args:
        module: object owning the attribute.
        name: attribute name to replace.
        n_atoms, rank: atom layout of the replacement layer.
        keep_bias: when True and the old layer has a bias, the replacement gets
            an ``extra_bias`` initialised to that bias. The default (False)
            keeps the historical behaviour, in which the pretrained bias is
            dropped and the replacement has no bias term.
            NOTE(review): with the default, the replaced layer loses the
            pretrained bias. Opting in adds an ``extra_bias`` entry to the
            state dict, so checkpoints saved without it would need
            non-strict loading.

    Returns:
        True if the attribute was replaced, False if it is missing or not an
        ``nn.Linear``.
    """
    old = getattr(module, name, None)
    if not isinstance(old, nn.Linear):
        return False

    carry_bias = bool(keep_bias) and (old.bias is not None)
    atom = FusedAtomLinear(
        in_f=old.in_features, out_f=old.out_features,
        n_atoms=n_atoms, rank=rank, add_bias=carry_bias,
        device=old.weight.device, dtype=old.weight.dtype,
        init_log_alpha=cfg.init_log_alpha
    )
    init_atoms_from_pretrained_linear(atom, old)
    if carry_bias:
        with torch.no_grad():
            atom.extra_bias.copy_(old.bias)
    setattr(module, name, atom)
    return True

def maybe_replace(module, candidate_names, n_atoms, rank):
    """Replace the first attribute in ``candidate_names`` that is an ``nn.Linear``.

    Names are tried in order; the first attribute of ``module`` that exists and
    is an ``nn.Linear`` is handed to ``replace_linear_with_fused_atoms`` and the
    search stops.

    Returns:
        1 if a layer was handed over for replacement, otherwise 0.
    """
    for attr_name in candidate_names:
        target = getattr(module, attr_name, None)
        if isinstance(target, nn.Linear):
            replace_linear_with_fused_atoms(module, attr_name, n_atoms, rank)
            return 1
    return 0

def iter_segformer_blocks(model):
    """Yield every block of every stage of ``model.segformer.encoder``.

    Stages are read from the encoder's ``block`` attribute, falling back to
    ``blocks`` when ``block`` is missing or empty.

    Raises:
        RuntimeError: when neither attribute provides stages (raised on first
            iteration, since this is a generator).
    """
    encoder = model.segformer.encoder
    stages = getattr(encoder, "block", None)
    if not stages:
        stages = getattr(encoder, "blocks", None)
    if stages is None:
        raise RuntimeError("SegFormer encoder stages not found.")
    for stage in stages:
        yield from stage

def inject_fused_atoms(model: SegformerForSemanticSegmentation,
                       n_atoms_qkv=8, n_atoms_out=12, n_atoms_ffn=8, rank=4):
    """Swap the attention and MLP linears of every encoder block for fused-atom layers.

    Per block, up to six layers are replaced: query/key/value (``n_atoms_qkv``
    atoms each), the attention output projection (``n_atoms_out``) and the two
    MLP linears (``n_atoms_ffn`` each). Several attribute names are tried for
    each slot; the first one that is an ``nn.Linear`` wins.

    Returns:
        Total number of layers replaced.
    """
    replaced = 0
    for blk in iter_segformer_blocks(model):
        attention = blk.attention
        # Use the `self` / `output` sub-modules when present, else the attention module itself.
        self_attn = getattr(attention, "self", attention)
        out_proj = getattr(attention, "output", attention)
        attention_slots = (
            (self_attn, ["query","q"], n_atoms_qkv),
            (self_attn, ["key","k"], n_atoms_qkv),
            (self_attn, ["value","v"], n_atoms_qkv),
            (out_proj, ["dense","proj","out"], n_atoms_out),
        )
        for owner, names, n_atoms in attention_slots:
            replaced += maybe_replace(owner, names, n_atoms, rank)

        feed_forward = blk.mlp
        for names in (["dense1","fc1"], ["dense2","fc2"]):
            replaced += maybe_replace(feed_forward, names, n_atoms_ffn, rank)
    return replaced

# ----- losses/metrics -----
def ce_loss_from_logits(logits, labels, ignore_index=255):
    """Cross-entropy between ``logits`` and integer ``labels``.

    When ``logits`` is ``[B, C, h, w]`` and ``labels`` is ``[B, H, W]`` with a
    different spatial size, the logits are first bilinearly resized to
    ``(H, W)`` - the same convention ``kd_loss_seg_masked`` and
    ``compute_miou_from_logits`` use. Inputs whose shapes already agree are
    passed to ``F.cross_entropy`` unchanged. Targets equal to ``ignore_index``
    do not contribute.
    """
    if logits.dim() == 4 and labels.dim() == 3 and logits.shape[-2:] != labels.shape[-2:]:
        target_hw = (labels.shape[-2], labels.shape[-1])
        logits = F.interpolate(logits, size=target_hw, mode="bilinear", align_corners=False)
    return F.cross_entropy(logits, labels.long(), ignore_index=ignore_index)

def kd_loss_seg_masked(student_logits, teacher_logits, labels, ignore_index=255, T=4.0):
    """Temperature-scaled KL(teacher || student), averaged over non-ignored pixels.

    Both logit maps are bilinearly resized to the spatial size of ``labels``,
    softened with temperature ``T``, and the per-pixel KL divergence (summed
    over the class dimension) is averaged over pixels whose label differs from
    ``ignore_index``. The result is multiplied by ``T*T``. With no valid pixel
    the denominator is clamped to 1, giving 0.
    """
    out_hw = (labels.shape[-2], labels.shape[-1])

    def _to_label_res(logits):
        return F.interpolate(logits, size=out_hw, mode="bilinear", align_corners=False)

    log_p_student = F.log_softmax(_to_label_res(student_logits) / T, dim=1)
    p_teacher = F.softmax(_to_label_res(teacher_logits) / T, dim=1)
    per_pixel_kl = F.kl_div(log_p_student, p_teacher, reduction="none").sum(dim=1)  # [B,H,W]

    valid = (labels != ignore_index).float()
    masked_mean = (per_pixel_kl * valid).sum() / valid.sum().clamp_min(1.0)
    return masked_mean * (T*T)

def l0_penalty(model, lam):
    """Return ``lam`` times the sum of ``prob_open()`` over all ``FusedAtomLinear`` gates.

    If ``model`` contains no ``FusedAtomLinear`` the result is the plain
    number ``lam * 0.0`` rather than a tensor.
    """
    open_mass = sum(
        (layer.gate.prob_open().sum()
         for layer in model.modules()
         if isinstance(layer, FusedAtomLinear)),
        0.0,
    )
    return lam * open_mass

@torch.no_grad()
def mean_p_open(model):
    """Average of the per-layer mean ``prob_open()`` over all ``FusedAtomLinear`` layers.

    Each layer contributes equally, whatever its number of atoms. Returns 0.0
    when the model has no such layer.
    """
    per_layer = [
        layer.gate.prob_open().detach().mean()
        for layer in model.modules()
        if isinstance(layer, FusedAtomLinear)
    ]
    if not per_layer:
        return 0.0
    return float(torch.stack(per_layer).mean())

def estimate_linear_flops_ratio(model, rho=None):
    """Estimate the cost of the fused-atom layers relative to their dense counterparts.

    For every ``FusedAtomLinear`` the dense cost is ``in_f * out_f`` and the
    factorised cost is ``kept_atoms * rank * (in_f + out_f)``, where
    ``kept_atoms`` is all atoms when ``rho`` is None and
    ``max(1, ceil(rho * n_atoms))`` otherwise.

    Returns:
        Sum of factorised costs divided by sum of dense costs, or 1.0 when the
        model has no ``FusedAtomLinear``.
    """
    dense_cost = 0.0
    atom_cost = 0.0
    for layer in model.modules():
        if not isinstance(layer, FusedAtomLinear):
            continue
        fan_in, fan_out = layer.in_f, layer.out_f
        n_atoms, rank = layer.n_atoms, layer.rank
        dense_cost += fan_in * fan_out
        if rho is None:
            kept = n_atoms
        else:
            kept = max(1, int(math.ceil(rho * n_atoms)))
        atom_cost += kept * rank * (fan_in + fan_out)
    if dense_cost > 0:
        return atom_cost / dense_cost
    return 1.0

# ----- schedules with delays sparsity -----
def make_delayed_schedules(total_steps, steps_per_epoch, start_epoch, rho_end, rho_portion_after_start, l0_final, l0_portion_after_start, kd_alpha, kd_warmup_frac):
    """Build the step-indexed schedules for sparsity, L0 weight and KD weight.

    Sparsity begins at `start_step = start_epoch * steps_per_epoch`.
    Returns `(rho_sched, l0_sched, kd_sched, start_step)` where, for a step:
      * rho_sched: 1.0 before `start_step`, then a linear ramp down to `rho_end`
        over `rho_portion_after_start` of the remaining steps, then constant.
      * l0_sched: 0.0 before `start_step`, then a linear ramp up to `l0_final`
        over `l0_portion_after_start` of the remaining steps, then constant.
      * kd_sched: linear ramp from 0 to `kd_alpha` over the first
        `kd_warmup_frac` of all steps (independent of `start_step`).
    Every ramp length is at least one step.
    """
    start_step = start_epoch * steps_per_epoch
    steps_after_start = max(1, total_steps - start_step)
    rho_ramp_len = int(max(1, rho_portion_after_start * steps_after_start))
    l0_ramp_len = int(max(1, l0_portion_after_start * steps_after_start))
    kd_ramp_len = int(max(1, kd_warmup_frac * total_steps))

    def _progress(step, origin, length):
        # Fraction of a ramp completed at `step`, capped at 1.0.
        return min(1.0, (step - origin) / length)

    def rho_sched(step):
        if step >= start_step:
            return 1.0 - _progress(step, start_step, rho_ramp_len) * (1.0 - rho_end)
        return 1.0

    def l0_sched(step):
        if step >= start_step:
            return l0_final * _progress(step, start_step, l0_ramp_len)
        return 0.0

    def kd_sched(step):
        return kd_alpha * _progress(step, 0, kd_ramp_len)

    return rho_sched, l0_sched, kd_sched, start_step

# ----- Teacher / Student -----
print("Loading teacher:", cfg.teacher_ckpt)
processor_student = AutoImageProcessor.from_pretrained(cfg.teacher_ckpt)

# Frozen teacher: eval mode and no gradients on any of its parameters.
teacher = SegformerForSemanticSegmentation.from_pretrained(cfg.teacher_ckpt).to(DEVICE)
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)

# The student starts from the same checkpoint; its encoder linears are then
# swapped for fused-atom layers by inject_fused_atoms.
student = SegformerForSemanticSegmentation.from_pretrained(cfg.teacher_ckpt).to(DEVICE)
wrapped_cnt = inject_fused_atoms(
    student,
    n_atoms_qkv=cfg.n_atoms_attn_qkv,
    n_atoms_out=cfg.n_atoms_attn_out,
    n_atoms_ffn=cfg.n_atoms_ffn,
    rank=cfg.rank_per_atom,
)

# A parameter stays trainable only if its name CONTAINS one of these substrings.
# NOTE(review): "A"/"B" are plain substring tests, so any parameter name with a
# capital A or B anywhere would also match — verify against the model's names.
for n, p in student.named_parameters():
    is_trainable = any(k in n for k in ["A","B","gate","extra_bias"])
    p.requires_grad_(is_trainable)

n_train = 0
n_total = 0
for p in student.parameters():
    n_total += p.numel()
    if p.requires_grad:
        n_train += p.numel()
print(f"Linear layers replaced: {wrapped_cnt} | Trainable params: {n_train/1e6:.3f}M / {n_total/1e6:.3f}M")

# ----- Optim, sched, amp -----
# Split the trainable parameters into gates (own LR, no weight decay) and the
# remaining trainable parameters (atom LR, cfg.weight_decay).
trainable = [(n, p) for n, p in student.named_parameters() if p.requires_grad]
gate_params = [p for n, p in trainable if "gate" in n]
atom_params = [p for n, p in trainable if "gate" not in n]

optimizer = torch.optim.AdamW(
    [
        {"params": atom_params, "lr": cfg.lr_atoms, "weight_decay": cfg.weight_decay},
        {"params": gate_params, "lr": cfg.lr_gates, "weight_decay": 0.0},
    ]
)

# Step bookkeeping used by the LR schedule and the sparsity/KD schedules.
steps_per_epoch = len(train_loader_student)
total_steps = cfg.epochs * steps_per_epoch
warmup_steps = max(1, int(cfg.warmup_ratio * total_steps))

def lr_lambda(step):
    """LR multiplier for LambdaLR: linear warmup, then cosine decay to zero.

    Reads the module-level `warmup_steps` and `total_steps`.
      * step < warmup_steps: rises linearly from 0 to 1.
      * afterwards: 0.5 * (1 + cos(pi * t)), t being the fraction of the
        post-warmup steps already done.
    `t` is clamped to 1.0 so the multiplier stays at 0 when the scheduler is
    stepped past `total_steps`, instead of following the cosine back up.
    """
    if step < warmup_steps:
        return step / float(warmup_steps)
    decay_span = max(1, total_steps - warmup_steps)
    t = min(1.0, (step - warmup_steps) / decay_span)
    return 0.5 * (1.0 + math.cos(math.pi * t))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# Loss scaling is enabled only on CUDA when bf16 is not in use (i.e. fp16 autocast).
use_loss_scaling = (DEVICE == "cuda") and (not BF16)
scaler = torch.amp.GradScaler('cuda', enabled=use_loss_scaling)

rho_sched, l0_sched, kd_sched, START_SPARSE_STEP = make_delayed_schedules(
    total_steps=total_steps,
    steps_per_epoch=steps_per_epoch,
    # Start AFTER the preceding epoch has finished (4 -> from the 5th).
    start_epoch=cfg.sparsity_start_epoch - 1,
    rho_end=cfg.rho_end,
    rho_portion_after_start=cfg.rho_portion_after_start,
    l0_final=cfg.l0_lambda_final,
    l0_portion_after_start=cfg.l0_portion_after_start,
    kd_alpha=cfg.kd_alpha,
    kd_warmup_frac=cfg.kd_warmup_frac,
)

def set_target_open_frac(model, rho: Optional[float]):
    """Write `rho` into the `target_open_frac` attribute of every FusedAtomLinear in `model`."""
    fused_layers = (mod for mod in model.modules() if isinstance(mod, FusedAtomLinear))
    for layer in fused_layers:
        layer.target_open_frac = rho

# ----- Resume -----
# Resume from a per-epoch snapshot directory (model weights + training_state.pt).
# The snapshots written by this cell hold weights and epoch/step counters only:
# optimizer moments and GradScaler state are not saved, so they start fresh.
START_EPOCH = 0; global_step = 0
if cfg.resume_dir:
    try:
        print("Resuming from:", cfg.resume_dir)
        # Load the snapshot INTO the existing student instead of rebuilding it.
        # A fresh from_pretrained() model has not gone through inject_fused_atoms
        # (so it has no fused-atom layers to receive the saved A/B/gate tensors),
        # and `optimizer` already holds references to the current student's parameters.
        _st_file = os.path.join(cfg.resume_dir, "model.safetensors")
        _pt_file = os.path.join(cfg.resume_dir, "pytorch_model.bin")
        if os.path.isfile(_st_file):
            from safetensors.torch import load_file as _load_safetensors
            _resume_sd = _load_safetensors(_st_file)
        elif os.path.isfile(_pt_file):
            _resume_sd = torch.load(_pt_file, map_location="cpu")
        else:
            raise FileNotFoundError(f"no model.safetensors / pytorch_model.bin in {cfg.resume_dir}")
        _missing, _unexpected = student.load_state_dict(_resume_sd, strict=False)
        if _missing or _unexpected:
            print(f"Resume: missing keys={list(_missing)[:8]} | unexpected keys={list(_unexpected)[:8]}")

        # map_location="cpu" keeps the small state dict loadable on any device.
        st = torch.load(os.path.join(cfg.resume_dir, "training_state.pt"), map_location="cpu")
        _epoch = int(st.get("epoch", 0))
        # Prefer the step counter stored in the snapshot; fall back to whole epochs.
        _step = int(st.get("global_step", _epoch * steps_per_epoch))
        if _step > 0:
            # Rebuild the LR scheduler at the resumed step. The first LambdaLR
            # construction already stored 'initial_lr' in each param group, which
            # is what a scheduler created with last_epoch != -1 requires.
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda, last_epoch=_step - 1)
        # Commit the counters only after everything above succeeded.
        START_EPOCH, global_step = _epoch, _step
    except Exception as e:
        print("Resume failed:", e)

# ----- Train -----
RUN_DIR = os.path.join(cfg.out_root, time.strftime("run_%Y%m%d_%H%M%S"))
BEST_DIR = os.path.join(cfg.out_root, "best_student_fused_atoms_delayed")
# Both directories must exist before anything is written; cfg.json goes into
# RUN_DIR immediately below.
for _dir in (RUN_DIR, BEST_DIR):
    os.makedirs(_dir, exist_ok=True)

# Keep a copy of the run configuration next to the run's outputs.
cfg_json_path = os.path.join(RUN_DIR, "cfg.json")
with open(cfg_json_path, "w") as f:
    json.dump(vars(cfg), f, indent=2)

# Phase-1 training: each epoch runs one pass over train_loader_student with
# CE + KD + L0 losses, validates on val_loader_student, saves the best model
# (by validation mIoU) to BEST_DIR and a per-epoch snapshot under cfg.out_root.
best_miou = -1.0
print("Начинаю обучение student (fused-atom, delayed sparsity)…")
for epoch in range(START_EPOCH, cfg.epochs):
    student.train()
    running = {"loss":0.0,"ce":0.0,"kd":0.0,"l0":0.0}
    pbar = tqdm(train_loader_student, desc=f"Epoch {epoch+1}/{cfg.epochs}")

    for it, batch in enumerate(pbar):
        pixel_values = batch["pixel_values"].to(DEVICE, non_blocking=True)
        labels       = batch["labels"].to(DEVICE, non_blocking=True)

        # --- schedules: current keep-fraction, L0 weight and KD weight ---
        rho_now   = rho_sched(global_step)
        lam_now   = l0_sched(global_step)
        alpha_now = kd_sched(global_step)
        set_target_open_frac(student, rho_now if global_step >= START_SPARSE_STEP else 1.0)

        # --- KD teacher (no autograd graph is built for the teacher pass) ---
        with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=AMP_DTYPE, enabled=(DEVICE=="cuda")):
            t_logits_raw = teacher(pixel_values).logits

        # --- forward/backward ---
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type="cuda", dtype=AMP_DTYPE, enabled=(DEVICE=="cuda")):
            s_logits_raw = student(pixel_values).logits
            # CE is computed at label resolution; the KD loss resizes internally.
            s_logits = F.interpolate(s_logits_raw, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            loss_ce = ce_loss_from_logits(s_logits, labels, ignore_index=IGNORE_INDEX)
            loss_kd = kd_loss_seg_masked(s_logits_raw, t_logits_raw, labels, ignore_index=IGNORE_INDEX, T=cfg.kd_temp) * alpha_now
            loss_l0 = l0_penalty(student, lam_now)
            loss = loss_ce + loss_kd + loss_l0

        scaler.scale(loss).backward()
        # After a scaled backward the gradients are still multiplied by the loss
        # scale. Unscale them first so cfg.grad_clip is applied to the true
        # gradient norm. This is a no-op when the scaler is disabled (bf16 / CPU).
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_([p for p in student.parameters() if p.requires_grad], cfg.grad_clip)
        scaler.step(optimizer); scaler.update()
        scheduler.step()

        # Running sums feed the per-iteration averages shown on the progress bar.
        running["loss"] += loss.item(); running["ce"] += loss_ce.item(); running["kd"] += loss_kd.item(); running["l0"] += loss_l0.item()
        pbar.set_postfix({
            "loss": f"{running['loss']/(it+1):.3f}",
            "ce":   f"{running['ce']/(it+1):.3f}",
            "kd":   f"{running['kd']/(it+1):.3f}",
            "l0":   f"{running['l0']/(it+1):.3f}",
            "ρ": f"{rho_now:.2f}", "λ": f"{lam_now:.1e}", "αkd": f"{alpha_now:.2f}",
            "p": f"{mean_p_open(student):.3f}",
        })
        if cfg.heartbeat_steps and (global_step % cfg.heartbeat_steps == 0):
            if DEVICE == "cuda": torch.cuda.synchronize()
            print(f"[heartbeat] step={global_step}")
        global_step += 1
    pbar.close()

    # ---- Validation ----
    # Validation always uses the FINAL scheduled keep-fraction (rho at the last
    # training step), not the current one, so every epoch is scored at the same
    # target sparsity.
    student.eval(); set_target_open_frac(student, rho_sched(total_steps-1))
    with torch.no_grad():
        miou_sum=i1_sum=i0_sum=val_ce=0.0; n_batches=0
        for batch in val_loader_student:
            pv = batch["pixel_values"].to(DEVICE, non_blocking=True)
            y  = batch["labels"].to(DEVICE, non_blocking=True)
            logits_raw = student(pv).logits
            logits = F.interpolate(logits_raw, size=y.shape[-2:], mode="bilinear", align_corners=False)
            val_ce += ce_loss_from_logits(logits, y, ignore_index=IGNORE_INDEX).item()
            m,i1,i0 = compute_miou_from_logits(logits_raw, y.cpu().numpy(), ignore_index=IGNORE_INDEX)
            miou_sum += m; i1_sum += i1; i0_sum += i0; n_batches += 1
    # Metrics are averaged per batch.
    val_miou = miou_sum / max(1, n_batches)
    val_iou1 = i1_sum / max(1, n_batches)
    val_iou0 = i0_sum / max(1, n_batches)
    val_ce /= max(1, len(val_loader_student))
    flops_ratio = estimate_linear_flops_ratio(student, rho=rho_sched(total_steps-1))
    print(f"[VAL] mIoU={val_miou:.4f} | IoU(person)={val_iou1:.4f} | IoU(bg)={val_iou0:.4f} | CE={val_ce:.4f} | FLOPs(linear)≈{flops_ratio:.3f}x")

    # save best (by validation mIoU)
    if val_miou > best_miou:
        best_miou = val_miou
        os.makedirs(BEST_DIR, exist_ok=True)
        student.save_pretrained(BEST_DIR)
        processor_student.save_pretrained(BEST_DIR)
        torch.save({"best_miou": best_miou, "epoch": epoch+1}, os.path.join(BEST_DIR, "best_state.pt"))
        with open(os.path.join(cfg.out_root, "student_cfg.json"), "w") as f: json.dump(cfg.__dict__, f, indent=2)
        print("Saved BEST student to:", BEST_DIR)

    # per-epoch snapshot (weights, processor, and the counters the resume code reads)
    ep_dir = os.path.join(cfg.out_root, f"epoch_{epoch+1:03d}_miou{val_miou:.4f}_flops{flops_ratio:.3f}")
    os.makedirs(ep_dir, exist_ok=True)
    student.save_pretrained(ep_dir)
    processor_student.save_pretrained(ep_dir)
    torch.save({"epoch": epoch+1, "global_step": global_step, "val_miou": val_miou}, os.path.join(ep_dir, "training_state.pt"))
    print("Epoch snapshot saved to:", ep_dir)

print("Done. Best mIoU (student):", best_miou)

# --- meta for Phase-2 ---
# Record where the best Phase-1 student lives; the Phase-2 cell reads this file.
phase1_meta_path = os.path.join(ROOT_DIR, "best_student_phase1.json")
phase1_meta = {
    "student_phase1_best_dir": os.path.join(cfg.out_root, "best_student_fused_atoms_delayed"),
}
with open(phase1_meta_path, "w") as f:
    json.dump(phase1_meta, f, indent=2)
print("Saved meta →", phase1_meta_path)


Loading teacher: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/teacher_runs/segformer_b0_binary_teacher_20250930_214520/checkpoints/checkpoint-3010
Linear layers replaced: 48 | Trainable params: 0.623M / 2.018M
Начинаю обучение student (fused-atom, delayed sparsity)…


Epoch 1/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=0
[heartbeat] step=200
[heartbeat] step=400
[heartbeat] step=600
[heartbeat] step=800
[heartbeat] step=1000
[heartbeat] step=1200
[VAL] mIoU=0.8822 | IoU(person)=0.8078 | IoU(bg)=0.9566 | CE=0.1093 | FLOPs(linear)≈0.199x
Saved BEST student to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/best_student_fused_atoms_delayed
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_001_miou0.8822_flops0.199


Epoch 2/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=1400
[heartbeat] step=1600
[heartbeat] step=1800
[heartbeat] step=2000
[heartbeat] step=2200
[heartbeat] step=2400
[VAL] mIoU=0.8736 | IoU(person)=0.7912 | IoU(bg)=0.9560 | CE=0.1150 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_002_miou0.8736_flops0.199


Epoch 3/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=2600
[heartbeat] step=2800
[heartbeat] step=3000
[heartbeat] step=3200
[heartbeat] step=3400
[heartbeat] step=3600
[heartbeat] step=3800
[VAL] mIoU=0.8968 | IoU(person)=0.8325 | IoU(bg)=0.9610 | CE=0.0977 | FLOPs(linear)≈0.199x
Saved BEST student to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/best_student_fused_atoms_delayed
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_003_miou0.8968_flops0.199


Epoch 4/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=4000
[heartbeat] step=4200
[heartbeat] step=4400
[heartbeat] step=4600
[heartbeat] step=4800
[heartbeat] step=5000
[VAL] mIoU=0.8771 | IoU(person)=0.7968 | IoU(bg)=0.9574 | CE=0.1145 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_004_miou0.8771_flops0.199


Epoch 5/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=5200
[heartbeat] step=5400
[heartbeat] step=5600
[heartbeat] step=5800
[heartbeat] step=6000
[heartbeat] step=6200
[heartbeat] step=6400
[VAL] mIoU=0.9088 | IoU(person)=0.8515 | IoU(bg)=0.9662 | CE=0.0866 | FLOPs(linear)≈0.199x
Saved BEST student to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/best_student_fused_atoms_delayed
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_005_miou0.9088_flops0.199


Epoch 6/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=6600
[heartbeat] step=6800
[heartbeat] step=7000
[heartbeat] step=7200
[heartbeat] step=7400
[heartbeat] step=7600
[VAL] mIoU=0.9082 | IoU(person)=0.8500 | IoU(bg)=0.9664 | CE=0.0851 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_006_miou0.9082_flops0.199


Epoch 7/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=7800
[heartbeat] step=8000
[heartbeat] step=8200
[heartbeat] step=8400
[heartbeat] step=8600
[heartbeat] step=8800
[VAL] mIoU=0.9010 | IoU(person)=0.8375 | IoU(bg)=0.9645 | CE=0.0959 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_007_miou0.9010_flops0.199


Epoch 8/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=9000
[heartbeat] step=9200
[heartbeat] step=9400
[heartbeat] step=9600
[heartbeat] step=9800
[heartbeat] step=10000
[heartbeat] step=10200
[VAL] mIoU=0.9104 | IoU(person)=0.8518 | IoU(bg)=0.9690 | CE=0.0798 | FLOPs(linear)≈0.199x
Saved BEST student to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/best_student_fused_atoms_delayed
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_008_miou0.9104_flops0.199


Epoch 9/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=10400
[heartbeat] step=10600
[heartbeat] step=10800
[heartbeat] step=11000
[heartbeat] step=11200
[heartbeat] step=11400
[VAL] mIoU=0.8843 | IoU(person)=0.8091 | IoU(bg)=0.9594 | CE=0.1167 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_009_miou0.8843_flops0.199


Epoch 10/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=11600
[heartbeat] step=11800
[heartbeat] step=12000
[heartbeat] step=12200
[heartbeat] step=12400
[heartbeat] step=12600
[heartbeat] step=12800
[VAL] mIoU=0.9115 | IoU(person)=0.8539 | IoU(bg)=0.9691 | CE=0.0728 | FLOPs(linear)≈0.199x
Saved BEST student to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/best_student_fused_atoms_delayed
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_010_miou0.9115_flops0.199


Epoch 11/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=13000
[heartbeat] step=13200
[heartbeat] step=13400
[heartbeat] step=13600
[heartbeat] step=13800
[heartbeat] step=14000
[VAL] mIoU=0.9035 | IoU(person)=0.8402 | IoU(bg)=0.9668 | CE=0.0791 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_011_miou0.9035_flops0.199


Epoch 12/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=14200
[heartbeat] step=14400
[heartbeat] step=14600
[heartbeat] step=14800
[heartbeat] step=15000
[heartbeat] step=15200
[heartbeat] step=15400
[VAL] mIoU=0.8793 | IoU(person)=0.8004 | IoU(bg)=0.9583 | CE=0.1169 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_012_miou0.8793_flops0.199


Epoch 13/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=15600
[heartbeat] step=15800
[heartbeat] step=16000
[heartbeat] step=16200
[heartbeat] step=16400
[heartbeat] step=16600
[VAL] mIoU=0.8794 | IoU(person)=0.8007 | IoU(bg)=0.9581 | CE=0.1097 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_013_miou0.8794_flops0.199


Epoch 14/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=16800
[heartbeat] step=17000
[heartbeat] step=17200
[heartbeat] step=17400
[heartbeat] step=17600
[heartbeat] step=17800
[VAL] mIoU=0.8612 | IoU(person)=0.7693 | IoU(bg)=0.9530 | CE=0.1306 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_014_miou0.8612_flops0.199


Epoch 15/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=18000
[heartbeat] step=18200
[heartbeat] step=18400
[heartbeat] step=18600
[heartbeat] step=18800
[heartbeat] step=19000
[heartbeat] step=19200
[VAL] mIoU=0.7863 | IoU(person)=0.6436 | IoU(bg)=0.9289 | CE=0.2353 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_015_miou0.7863_flops0.199


Epoch 16/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=19400
[heartbeat] step=19600
[heartbeat] step=19800
[heartbeat] step=20000
[heartbeat] step=20200
[heartbeat] step=20400
[VAL] mIoU=0.8451 | IoU(person)=0.7428 | IoU(bg)=0.9475 | CE=0.1557 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_016_miou0.8451_flops0.199


Epoch 17/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=20600
[heartbeat] step=20800
[heartbeat] step=21000
[heartbeat] step=21200
[heartbeat] step=21400
[heartbeat] step=21600
[heartbeat] step=21800
[VAL] mIoU=0.8616 | IoU(person)=0.7715 | IoU(bg)=0.9518 | CE=0.1160 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_017_miou0.8616_flops0.199


Epoch 18/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=22000
[heartbeat] step=22200
[heartbeat] step=22400
[heartbeat] step=22600
[heartbeat] step=22800
[heartbeat] step=23000
[VAL] mIoU=0.8426 | IoU(person)=0.7386 | IoU(bg)=0.9466 | CE=0.1459 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_018_miou0.8426_flops0.199


Epoch 19/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=23200
[heartbeat] step=23400
[heartbeat] step=23600
[heartbeat] step=23800
[heartbeat] step=24000
[heartbeat] step=24200
[VAL] mIoU=0.8369 | IoU(person)=0.7297 | IoU(bg)=0.9441 | CE=0.1648 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_019_miou0.8369_flops0.199


Epoch 20/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=24400
[heartbeat] step=24600
[heartbeat] step=24800
[heartbeat] step=25000
[heartbeat] step=25200
[heartbeat] step=25400
[heartbeat] step=25600
[VAL] mIoU=0.8270 | IoU(person)=0.7152 | IoU(bg)=0.9387 | CE=0.1642 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_020_miou0.8270_flops0.199


Epoch 21/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=25800
[heartbeat] step=26000
[heartbeat] step=26200
[heartbeat] step=26400
[heartbeat] step=26600
[heartbeat] step=26800
[VAL] mIoU=0.8284 | IoU(person)=0.7154 | IoU(bg)=0.9414 | CE=0.1660 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_021_miou0.8284_flops0.199


Epoch 22/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=27000
[heartbeat] step=27200
[heartbeat] step=27400
[heartbeat] step=27600
[heartbeat] step=27800
[heartbeat] step=28000
[heartbeat] step=28200
[VAL] mIoU=0.8218 | IoU(person)=0.7052 | IoU(bg)=0.9384 | CE=0.1785 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_022_miou0.8218_flops0.199


Epoch 23/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=28400
[heartbeat] step=28600
[heartbeat] step=28800
[heartbeat] step=29000
[heartbeat] step=29200
[heartbeat] step=29400
[VAL] mIoU=0.8009 | IoU(person)=0.6699 | IoU(bg)=0.9320 | CE=0.2044 | FLOPs(linear)≈0.199x
Epoch snapshot saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase1_alora_delayed/epoch_023_miou0.8009_flops0.199


Epoch 24/30:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=29600
[heartbeat] step=29800


In [5]:
# === 6) Student Phase-2 (KR-hard + EMA) ===
import os, json
from pathlib import Path
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from transformers import SegformerForSemanticSegmentation, AutoImageProcessor

# Runtime constants for the Phase-2 cell.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Query bf16 support only when a CUDA device is present: this cell explicitly
# supports a CPU fallback, and the bf16 query may raise on a CPU-only host with
# some PyTorch builds (TODO confirm against the pinned version).
BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
# Autocast dtype: bf16 when supported, otherwise fp16.
AMP_DTYPE = torch.bfloat16 if BF16 else torch.float16
# Label value excluded from the losses below.
IGNORE_INDEX = 255

# --------- Read paths from meta ---------
# Teacher checkpoint path comes from the "best_cp" entry of best_teacher.json.
best_teacher_meta = os.path.join(ROOT_DIR, "best_teacher.json")
with open(best_teacher_meta) as f:
    TEACHER_CKPT = json.load(f)["best_cp"]

# Phase-1 student directory: taken from the meta file when it exists, otherwise
# the conventional location under STUDENT_P1_ROOT.
meta_p1 = os.path.join(ROOT_DIR, "best_student_phase1.json")
if not os.path.isfile(meta_p1):
    STUDENT_DIR = os.path.join(STUDENT_P1_ROOT, "best_student_fused_atoms_delayed")
else:
    with open(meta_p1) as f:
        STUDENT_DIR = json.load(f)["student_phase1_best_dir"]
assert os.path.isdir(STUDENT_DIR), f"Не найден каталог Phase-1: {STUDENT_DIR}"

# Output locations for this phase.
OUT_ROOT = os.path.join(STUDENT_P2_ROOT, "kr_hard_ema")
SAVE_BEST_DIR = os.path.join(STUDENT_P2_ROOT, "best")
for _dir in (OUT_ROOT, SAVE_BEST_DIR):
    os.makedirs(_dir, exist_ok=True)

# --------- KR/EMA hparams ----------
@dataclass
class KRConf:
    """Hyper-parameters for the Phase-2 (KR-hard + EMA) student run defined in this cell."""
    epochs: int = 15
    lr: float = 6e-5
    weight_decay: float = 5e-3
    ema_decay: float = 0.999  # presumably the `decay` handed to update_ema — confirm at the call site
    kd_alpha: float = 0.85    # presumably the weight of the KD loss term — confirm at the call site
    kd_temp: float = 2.0      # presumably the temperature `T` handed to kd_loss_masked — confirm
# NOTE(review): this rebinds the notebook-global `cfg`. The Phase-1 cell reads
# different fields from `cfg` (e.g. cfg.teacher_ckpt), so re-running Phase-1
# code after this cell without re-creating its config would fail.
cfg = KRConf()

# --------- Minimal fused-atom modules (matching Phase-1) ----------
class HardConcreteGate(nn.Module):
    """A single scalar gate with a learnable `log_alpha`, stretched to [gamma, zeta].

    State-dict entries: parameter `log_alpha` and buffer `hc_log_term`
    (= log(-gamma / zeta), stored as float32).
    """
    def __init__(self, init_log_alpha=1.5, beta=2./3., gamma=-0.1, zeta=1.1, device=None, dtype=torch.float32):
        super().__init__()
        self.beta = beta
        self.gamma = gamma
        self.zeta = zeta
        self.log_alpha = nn.Parameter(torch.tensor(float(init_log_alpha), device=device, dtype=dtype))
        log_ratio = float(torch.log(torch.tensor(-gamma / zeta)))
        self.register_buffer("hc_log_term", torch.tensor(log_ratio, dtype=torch.float32, device=device))

    def prob_open(self):
        """Return sigmoid(log_alpha - beta * log(-gamma / zeta)) as a 0-dim tensor."""
        shift = self.beta * self.hc_log_term
        return torch.sigmoid(self.log_alpha - shift)

    def sample(self, training: bool):
        """Return a gate value in [0, 1].

        training=True: draw one uniform sample, push it through the
        temperature-`beta` sigmoid, stretch to [gamma, zeta] and clamp to [0, 1].
        training=False: deterministic 0/1 value, `prob_open() > 0.5`.
        """
        if not training:
            return (self.prob_open() > 0.5).to(self.log_alpha.dtype)
        u = torch.rand((), device=self.log_alpha.device, dtype=self.log_alpha.dtype)
        logistic_noise = torch.log(u) - torch.log(1-u)
        s = torch.sigmoid((logistic_noise + self.log_alpha) / self.beta)
        stretched = s * (self.zeta - self.gamma) + self.gamma
        return torch.clamp(stretched, 0, 1)

class FusedAtomLinear(nn.Module):
    def __init__(self, in_f, out_f, n_atoms=8, rank=4, add_bias=False, device=None, dtype=torch.float32):
        super().__init__()
        self.in_f, self.out_f = in_f, out_f
        self.n_atoms, self.rank = n_atoms, rank
        self.R = n_atoms * rank
        self.A = nn.Parameter(torch.zeros(out_f, self.R, device=device, dtype=dtype))
        self.B = nn.Parameter(torch.zeros(self.R, in_f,  device=device, dtype=dtype))
        self.gates = nn.ModuleList([HardConcreteGate(device=device, dtype=dtype) for _ in range(n_atoms)])
        self.extra_bias = nn.Parameter(torch.zeros(out_f, device=device, dtype=dtype)) if add_bias else None
        self.block_size = rank
    def forward(self, x):
        zs = []
        training = self.training
        for g in self.gates:
            z = g.prob_open() if not training else g.sample(True)
            zs.append(z)
        z = torch.stack(zs, dim=0).to(self.A.dtype).to(self.A.device)      # [K]
        z = z.repeat_interleave(self.block_size)                            # [R]
        h = F.linear(x, self.B)                                            # [*, R]
        h = h * z                                                          # mask rank channels
        y = F.linear(h, self.A)                                            # [*, out]
        if self.extra_bias is not None:
            y = y + self.extra_bias
        return y

def _maybe_replace(module, name, n_atoms, rank):
    old = getattr(module, name, None)
    if not isinstance(old, nn.Linear): return 0
    fused = FusedAtomLinear(old.in_features, old.out_features, n_atoms=n_atoms, rank=rank,
                            add_bias=False, device=old.weight.device, dtype=old.weight.dtype)
    # weights will be loaded from Phase-1 checkpoint, so no SVD init here
    setattr(module, name, fused)
    return 1

def inject_fused_atoms_into_segformer(model, n_atoms_qkv=8, n_atoms_out=12, n_atoms_ffn=8, rank=4):
    """Replace the attention and MLP linears of every encoder block with fused-atom layers.

    Stages are read from `model.segformer.encoder.block` (or `.blocks`). In each
    block the first matching attribute name per slot is replaced:
    query/q, key/k, value/v, attention-output dense/proj/out, MLP dense1/fc1 and
    dense2/fc2. Returns the number of layers replaced.

    Raises:
        RuntimeError: if the encoder exposes neither `block` nor `blocks`.
    """
    encoder = model.segformer.encoder
    stages = getattr(encoder, "block", None) or getattr(encoder, "blocks", None)
    if stages is None:
        raise RuntimeError("SegFormer encoder stages not found.")

    def _replace_first(owner, candidates, n_atoms):
        # Try the candidate names in order and stop at the first successful swap.
        for attr in candidates:
            if _maybe_replace(owner, attr, n_atoms, rank):
                return 1
        return 0

    replaced = 0
    for stage in stages:
        for blk in stage:
            attn = blk.attention
            # Fall back to the attention module itself when it has no self/output sub-modules.
            attn_self = getattr(attn, "self", attn)
            attn_out = getattr(attn, "output", attn)
            replaced += _replace_first(attn_self, ("query", "q"), n_atoms_qkv)
            replaced += _replace_first(attn_self, ("key", "k"), n_atoms_qkv)
            replaced += _replace_first(attn_self, ("value", "v"), n_atoms_qkv)
            replaced += _replace_first(attn_out, ("dense", "proj", "out"), n_atoms_out)
            mlp = blk.mlp
            replaced += _replace_first(mlp, ("dense1", "fc1"), n_atoms_ffn)
            replaced += _replace_first(mlp, ("dense2", "fc2"), n_atoms_ffn)
    return replaced

# --------- KD / utils ----------
def upsample_to(labels_hw, logits):
    """Bilinearly resize `logits` to the spatial size `labels_hw` (align_corners=False)."""
    resized = F.interpolate(input=logits, size=labels_hw, mode="bilinear", align_corners=False)
    return resized

def ce_loss(logits, labels):
    """Cross-entropy against int64-cast `labels`, skipping the module-level IGNORE_INDEX."""
    targets = labels.long()
    return F.cross_entropy(input=logits, target=targets, ignore_index=IGNORE_INDEX)

def kd_loss_masked(s_logits, t_logits, labels, T=2.0):
    """KL distillation loss at the STUDENT's logit resolution, masked by valid labels.

    Teacher logits are bilinearly resized to the student's spatial size when
    they differ. Labels are resized with nearest-neighbour to the same size,
    and pixels equal to the module-level IGNORE_INDEX are masked out. The
    class-summed KL is averaged over valid pixels (denominator at least 1) and
    multiplied by T*T.
    """
    student_hw = s_logits.shape[-2:]
    if t_logits.shape[-2:] != student_hw:
        t_logits = F.interpolate(t_logits, size=student_hw, mode="bilinear", align_corners=False)

    log_p_student = F.log_softmax(s_logits / T, dim=1)
    p_teacher = F.softmax(t_logits / T, dim=1)
    per_pixel_kl = F.kl_div(log_p_student, p_teacher, reduction="none").sum(dim=1)  # [B,h,w]

    # The mask is a constant with respect to the loss, so it is built without autograd.
    with torch.no_grad():
        labels_small = labels
        if labels.shape[-2:] != student_hw:
            labels_small = F.interpolate(labels.float().unsqueeze(1), size=student_hw, mode="nearest").squeeze(1).long()
        valid = (labels_small != IGNORE_INDEX).float()

    n_valid = valid.sum().clamp_min(1.0)
    return (per_pixel_kl * valid).sum() / n_valid * (T*T)

@torch.no_grad()
def init_ema_from_model(model):
    """Return `{name: detached copy}` for every parameter of `model` that requires grad."""
    return {
        name: param.detach().clone()
        for name, param in model.named_parameters()
        if param.requires_grad
    }

@torch.no_grad()
def update_ema(ema, model, decay):
    """In-place EMA step: ema[name] = decay * ema[name] + (1 - decay) * param.

    Only parameters whose name is already a key of `ema` are touched.
    """
    blend = 1.0 - decay
    for name, param in model.named_parameters():
        if name not in ema:
            continue
        shadow = ema[name]
        shadow.mul_(decay).add_(param.detach(), alpha=blend)

@torch.no_grad()
def swap_in_ema(model, ema):
    """Overwrite the EMA-tracked parameters of `model` with their EMA values.

    Returns a dict of copies of the overwritten live values (same keys), which
    `restore_backup` can write back.
    """
    saved_live = {}
    for name, param in model.named_parameters():
        if name not in ema:
            continue
        saved_live[name] = param.detach().clone()
        param.data.copy_(ema[name])
    return saved_live

@torch.no_grad()
def restore_backup(model, backup):
    """Copy each tensor in `backup` back into the same-named parameter of `model`.

    Names in `backup` that `model` does not have are ignored.
    """
    live = dict(model.named_parameters())
    for name, saved in backup.items():
        if name in live:
            live[name].data.copy_(saved)

# --------- Read Phase-1 fused-atom hparams ----------
# student_cfg.json (if present under STUDENT_P1_ROOT) overrides the built-in
# defaults; any key absent from the file falls back to the same default.
cfg_path = Path(STUDENT_P1_ROOT) / "student_cfg.json"
if not cfg_path.exists():
    n_qkv, n_out, n_ffn, rnk = 8, 12, 8, 4
else:
    with open(cfg_path) as f:
        _d = json.load(f)
    n_qkv, n_out, n_ffn, rnk = (
        int(_d.get(_key, _default))
        for _key, _default in (
            ("n_atoms_attn_qkv", 8),
            ("n_atoms_attn_out", 12),
            ("n_atoms_ffn", 8),
            ("rank_per_atom", 4),
        )
    )

print(f"[KR+EMA] Using fused-atom hparams: qkv={n_qkv}, out={n_out}, ffn={n_ffn}, rank={rnk}")

# --------- Build student, inject fused atoms, load Phase-1 weights ----------
# The student starts from the TEACHER_CKPT weights; selected layers are then replaced
# by fused-atom layers, and finally the Phase-1 weights from STUDENT_DIR are loaded on top.
kr_student = SegformerForSemanticSegmentation.from_pretrained(TEACHER_CKPT).to(DEVICE)
wrap_cnt = inject_fused_atoms_into_segformer(kr_student, n_atoms_qkv=n_qkv, n_atoms_out=n_out, n_atoms_ffn=n_ffn, rank=rnk)
print(f"[KR+EMA] Injected fused layers: {wrap_cnt}")

# Load weights from STUDENT_DIR (safetensors preferred, pytorch_model.bin as fallback)
state = None
st_path = Path(STUDENT_DIR) / "model.safetensors"
pt_path = Path(STUDENT_DIR) / "pytorch_model.bin"
if st_path.exists():
    try:
        from safetensors.torch import load_file as safe_load
        state = safe_load(str(st_path))
    except Exception as exc:
        # Still best-effort (fall through to the .bin file), but no longer silent:
        # a corrupt or unreadable safetensors file should be visible in the log.
        print(f"[KR+EMA] WARNING: failed to read {st_path} ({type(exc).__name__}: {exc}); trying {pt_path.name}")
        state = None
if state is None and pt_path.exists():
    # weights_only=True restricts unpickling to tensors/containers; a state dict needs nothing more.
    state = torch.load(str(pt_path), map_location="cpu", weights_only=True)
if state is None:
    raise FileNotFoundError(f"Нет model.safetensors/pytorch_model.bin в {STUDENT_DIR}")

# strict=False: keys listed in `missing` keep the values the model had before this call,
# keys listed in `unexpected` are dropped from the checkpoint.
missing, unexpected = kr_student.load_state_dict(state, strict=False)
print(f"[KR+EMA] load_state_dict: missing={len(missing)}, unexpected={len(unexpected)}")
# Show a few offending names so a key-naming mismatch between checkpoint and model can be diagnosed.
if missing:
    print(f"[KR+EMA]   missing (first 5): {list(missing)[:5]}")
if unexpected:
    print(f"[KR+EMA]   unexpected (first 5): {list(unexpected)[:5]}")

image_processor_kr = AutoImageProcessor.from_pretrained(STUDENT_DIR)

# Freeze gates; train only A/B (+bias); hard masks via eval()
# A parameter stays trainable iff its name carries no gate marker ("gates"/"log_alpha")
# and contains "A", "B" or "extra_bias".
# NOTE(review): these are plain substring tests on the full parameter name - confirm that
# no unrelated parameter name contains an upper-case "A" or "B".
for n, p in kr_student.named_parameters():
    is_gate = ("gates" in n) or ("log_alpha" in n)
    is_factor = ("A" in n) or ("B" in n) or ("extra_bias" in n)
    p.requires_grad_(is_factor and not is_gate)

train_params = list(filter(lambda t: t.requires_grad, kr_student.parameters()))
optim = torch.optim.AdamW(
    train_params,
    lr=cfg.lr,
    weight_decay=cfg.weight_decay,
)
# Loss scaling is only enabled on CUDA when BF16 is not set.
scaler = torch.amp.GradScaler(
    'cuda',
    enabled=(DEVICE=='cuda' and not BF16),
)

# Teacher for KD: eval mode, no gradients
teacher = SegformerForSemanticSegmentation.from_pretrained(TEACHER_CKPT).to(DEVICE).eval()
for p in teacher.parameters():
    p.requires_grad_(False)

# EMA state covers exactly the parameters left trainable above
ema_state = init_ema_from_model(kr_student)

# --------- Train (eval mode: hard gates) ----------
def _save_ema_snapshot(out_dir, meta):
    """Save the student with EMA weights, the image processor and a training_state.pt.

    The EMA weights are swapped in only for the duration of the save; the live
    training weights are restored even if saving raises.

    Args:
        out_dir: Destination directory (created if missing).
        meta: Dict of run-specific fields; "stage" and "hparams" are appended.
    """
    os.makedirs(out_dir, exist_ok=True)
    live_weights = swap_in_ema(kr_student, ema_state)
    try:
        kr_student.save_pretrained(out_dir)
        image_processor_kr.save_pretrained(out_dir)
        torch.save(
            {**meta, "stage": "kr-hard+ema", "hparams": cfg.__dict__},
            os.path.join(out_dir, "training_state.pt"),
        )
    finally:
        restore_backup(kr_student, live_weights)


best_miou = -1.0
global_step = 0
print("KR-hard + EMA training started...")
# The student is deliberately kept in eval() while training (hard gates, per the section title).
kr_student.eval()

for epoch in range(1, cfg.epochs+1):
    run = {"loss":0.0, "ce":0.0, "kd":0.0}
    pbar = tqdm(train_loader_student, desc=f"[KR-hard+EMA] Epoch {epoch}/{cfg.epochs}")
    for it, batch in enumerate(pbar):
        pixel_values = batch["pixel_values"].to(DEVICE, non_blocking=True)
        labels       = batch["labels"].to(DEVICE, non_blocking=True)

        # Teacher forward: no autograd bookkeeping needed.
        with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=AMP_DTYPE, enabled=(DEVICE=='cuda')):
            t_logits_raw = teacher(pixel_values).logits

        optim.zero_grad(set_to_none=True)
        with torch.autocast(device_type="cuda", dtype=AMP_DTYPE, enabled=(DEVICE=='cuda')):
            s_logits_raw = kr_student(pixel_values).logits
            # CE is computed at label resolution, KD at the raw logit resolution.
            s_logits = upsample_to(labels.shape[-2:], s_logits_raw)
            loss_ce = ce_loss(s_logits, labels)
            loss_kd = kd_loss_masked(s_logits_raw, t_logits_raw, labels, T=cfg.kd_temp) * cfg.kd_alpha
            loss = loss_ce + loss_kd

        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale here; unscale them first so the
        # max-norm of 1.0 applies to the true gradients (no-op when the scaler is disabled).
        scaler.unscale_(optim)
        torch.nn.utils.clip_grad_norm_(train_params, 1.0)
        scaler.step(optim); scaler.update()
        update_ema(ema_state, kr_student, cfg.ema_decay)

        run["loss"] += loss.item(); run["ce"] += loss_ce.item(); run["kd"] += loss_kd.item()
        global_step += 1
        if global_step % 200 == 0:
            print(f"[heartbeat] step={global_step}")
        # Running means over the batches seen so far in this epoch.
        pbar.set_postfix({
            "loss": f"{run['loss']/(it+1):.3f}",
            "ce":   f"{run['ce']/(it+1):.3f}",
            "kd":   f"{run['kd']/(it+1):.3f}",
        })
    pbar.close()

    # ----- Validation on EMA -----
    val_ce, miou_sum, n_batches = 0.0, 0.0, 0
    backup = swap_in_ema(kr_student, ema_state)
    try:
        with torch.no_grad():
            for batch in val_loader_student:
                pv = batch["pixel_values"].to(DEVICE, non_blocking=True)
                y  = batch["labels"].to(DEVICE, non_blocking=True)
                logits_raw = kr_student(pv).logits
                logits = upsample_to(y.shape[-2:], logits_raw)
                val_ce += ce_loss(logits, y).item()
                m, _, _ = compute_miou_from_logits(logits_raw, y.cpu().numpy(), ignore_index=IGNORE_INDEX)
                miou_sum += m; n_batches += 1
    finally:
        # Always put the live training weights back, even if validation fails.
        restore_backup(kr_student, backup)
    # Both metrics are per-batch means over the batches actually evaluated.
    val_ce /= max(1, n_batches)
    val_miou = miou_sum / max(1, n_batches)

    print(f"[KR-hard+EMA] Val CE(EMA): {val_ce:.4f} | Val mIoU(EMA): {val_miou:.4f}")

    # Save best (EMA weights)
    if val_miou > best_miou:
        best_miou = val_miou
        _save_ema_snapshot(SAVE_BEST_DIR, {
            "global_step": global_step,
            "best_miou": best_miou,
        })
        print("Saved BEST (EMA, by mIoU) →", SAVE_BEST_DIR)

    # Per-epoch snapshot (EMA weights)
    ep_dir = os.path.join(OUT_ROOT, f"epoch_{epoch:03d}_miou{val_miou:.4f}")
    _save_ema_snapshot(ep_dir, {
        "epoch": epoch,
        "global_step": global_step,
        "val_miou": val_miou,
    })
    print("Epoch checkpoint (EMA) saved to:", ep_dir)

print("KR-hard+EMA finished. Best Val mIoU (EMA):", best_miou)


[KR+EMA] Using fused-atom hparams: qkv=8, out=12, ffn=8, rank=4
[KR+EMA] Injected fused layers: 48
[KR+EMA] load_state_dict: missing=832, unexpected=96
KR-hard + EMA training started...


[KR-hard+EMA] Epoch 1/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=200
[heartbeat] step=400
[heartbeat] step=600
[heartbeat] step=800
[heartbeat] step=1000
[heartbeat] step=1200
[KR-hard+EMA] Val CE(EMA): 0.0619 | Val mIoU(EMA): 0.9287
Saved BEST (EMA, by mIoU) → /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/best
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_001_miou0.9287


[KR-hard+EMA] Epoch 2/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=1400
[heartbeat] step=1600
[heartbeat] step=1800
[heartbeat] step=2000
[heartbeat] step=2200
[heartbeat] step=2400
[KR-hard+EMA] Val CE(EMA): 0.0587 | Val mIoU(EMA): 0.9315
Saved BEST (EMA, by mIoU) → /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/best
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_002_miou0.9315


[KR-hard+EMA] Epoch 3/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=2600
[heartbeat] step=2800
[heartbeat] step=3000
[heartbeat] step=3200
[heartbeat] step=3400
[heartbeat] step=3600
[heartbeat] step=3800
[KR-hard+EMA] Val CE(EMA): 0.0568 | Val mIoU(EMA): 0.9332
Saved BEST (EMA, by mIoU) → /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/best
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_003_miou0.9332


[KR-hard+EMA] Epoch 4/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=4000
[heartbeat] step=4200
[heartbeat] step=4400
[heartbeat] step=4600
[heartbeat] step=4800
[heartbeat] step=5000
[KR-hard+EMA] Val CE(EMA): 0.0554 | Val mIoU(EMA): 0.9345
Saved BEST (EMA, by mIoU) → /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/best
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_004_miou0.9345


[KR-hard+EMA] Epoch 5/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=5200
[heartbeat] step=5400
[heartbeat] step=5600
[heartbeat] step=5800
[heartbeat] step=6000
[heartbeat] step=6200
[heartbeat] step=6400
[KR-hard+EMA] Val CE(EMA): 0.0553 | Val mIoU(EMA): 0.9349
Saved BEST (EMA, by mIoU) → /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/best
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_005_miou0.9349


[KR-hard+EMA] Epoch 6/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=6600
[heartbeat] step=6800
[heartbeat] step=7000
[heartbeat] step=7200
[heartbeat] step=7400
[heartbeat] step=7600
[KR-hard+EMA] Val CE(EMA): 0.0546 | Val mIoU(EMA): 0.9356
Saved BEST (EMA, by mIoU) → /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/best
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_006_miou0.9356


[KR-hard+EMA] Epoch 7/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=7800
[heartbeat] step=8000
[heartbeat] step=8200
[heartbeat] step=8400
[heartbeat] step=8600
[heartbeat] step=8800
[KR-hard+EMA] Val CE(EMA): 0.0563 | Val mIoU(EMA): 0.9335
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_007_miou0.9335


[KR-hard+EMA] Epoch 8/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=9000
[heartbeat] step=9200
[heartbeat] step=9400
[heartbeat] step=9600
[heartbeat] step=9800
[heartbeat] step=10000
[heartbeat] step=10200
[KR-hard+EMA] Val CE(EMA): 0.0562 | Val mIoU(EMA): 0.9342
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_008_miou0.9342


[KR-hard+EMA] Epoch 9/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=10400
[heartbeat] step=10600
[heartbeat] step=10800
[heartbeat] step=11000
[heartbeat] step=11200
[heartbeat] step=11400
[KR-hard+EMA] Val CE(EMA): 0.0554 | Val mIoU(EMA): 0.9356
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_009_miou0.9356


[KR-hard+EMA] Epoch 10/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=11600
[heartbeat] step=11800
[heartbeat] step=12000
[heartbeat] step=12200
[heartbeat] step=12400
[heartbeat] step=12600
[heartbeat] step=12800
[KR-hard+EMA] Val CE(EMA): 0.0560 | Val mIoU(EMA): 0.9348
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_010_miou0.9348


[KR-hard+EMA] Epoch 11/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=13000
[heartbeat] step=13200
[heartbeat] step=13400
[heartbeat] step=13600
[heartbeat] step=13800
[heartbeat] step=14000
[KR-hard+EMA] Val CE(EMA): 0.0552 | Val mIoU(EMA): 0.9354
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_011_miou0.9354


[KR-hard+EMA] Epoch 12/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=14200
[heartbeat] step=14400
[heartbeat] step=14600
[heartbeat] step=14800
[heartbeat] step=15000
[heartbeat] step=15200
[heartbeat] step=15400
[KR-hard+EMA] Val CE(EMA): 0.0546 | Val mIoU(EMA): 0.9364
Saved BEST (EMA, by mIoU) → /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/best
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_012_miou0.9364


[KR-hard+EMA] Epoch 13/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=15600
[heartbeat] step=15800
[heartbeat] step=16000
[heartbeat] step=16200
[heartbeat] step=16400
[heartbeat] step=16600
[KR-hard+EMA] Val CE(EMA): 0.0547 | Val mIoU(EMA): 0.9363
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_013_miou0.9363


[KR-hard+EMA] Epoch 14/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=16800
[heartbeat] step=17000
[heartbeat] step=17200
[heartbeat] step=17400
[heartbeat] step=17600
[heartbeat] step=17800
[KR-hard+EMA] Val CE(EMA): 0.0558 | Val mIoU(EMA): 0.9357
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_014_miou0.9357


[KR-hard+EMA] Epoch 15/15:   0%|          | 0/1284 [00:00<?, ?it/s]

[heartbeat] step=18000
[heartbeat] step=18200
[heartbeat] step=18400
[heartbeat] step=18600
[heartbeat] step=18800
[heartbeat] step=19000
[heartbeat] step=19200
[KR-hard+EMA] Val CE(EMA): 0.0551 | Val mIoU(EMA): 0.9361
Epoch checkpoint (EMA) saved to: /content/drive/MyDrive/human_segmentation_experiments_article0_final_flow/student_phase2_kr_ema/kr_hard_ema/epoch_015_miou0.9361
KR-hard+EMA finished. Best Val mIoU (EMA): 0.9363926633325608
