# 🎯 VT-MOT Far-View Gated Mid-Fusion Training

**Strategy:** Transfer Learning V2 (Near→Far Domain Adaptation)  
**Model:** YOLOv11x-RGBT with GatedSpatialFusion_V3  
**Dataset:** vtmot_far (284k images, single class: person)  

## Prerequisites
1. Upload dataset zips as Kaggle Datasets:
   - `vtmot-far-train`: vtmot_far_train_part{1..4}.zip
   - `vtmot-far-valtest`: vtmot_far_val_test_part{1..2}.zip
2. Upload `vtmot_weights.zip` to Google Drive, set sharing to 'Anyone with link'
3. Enable **GPU** accelerator
4. Set persistence to **Files**

## Cell 1: Clone Repository & Install Dependencies

In [None]:
import os, subprocess, sys

WORK_DIR = "/kaggle/working"
REPO_DIR = os.path.join(WORK_DIR, "tracking-weapon")

# Clone repo
if not os.path.exists(REPO_DIR):
    !git clone --depth 1 https://github.com/zok213/tracking-weapon.git {REPO_DIR}
    print("‚úÖ Repository cloned.")
else:
    print("‚úÖ Repository already exists.")

# Install modified ultralytics (with GatedSpatialFusion_V3)
!pip install -e {REPO_DIR}/YOLOv11-RGBT --quiet --no-deps
!pip install einops>=0.7 timm>=0.9 efficientnet-pytorch>=0.7.1 albumentations>=1.0.3 thop psutil gdown --quiet

# Verify
import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

from ultralytics import YOLO
from ultralytics.nn.modules.block import GatedSpatialFusion_V3
print("‚úÖ GatedSpatialFusion_V3 loaded successfully!")

## Cell 2: Download Weights from Google Drive

**Instructions:** Upload `vtmot_weights.zip` (395MB) to Google Drive → Share → Anyone with link → Copy the file ID from the URL.

In [None]:
import gdown, zipfile

# ============================================================
# ⚠️  PASTE YOUR GOOGLE DRIVE FILE ID HERE
# From URL: https://drive.google.com/file/d/<FILE_ID>/view
# ============================================================
GDRIVE_FILE_ID = "PASTE_YOUR_FILE_ID_HERE"  # <-- CHANGE THIS!
# ============================================================

WEIGHTS_DIR = os.path.join(WORK_DIR, "weights")
os.makedirs(WEIGHTS_DIR, exist_ok=True)

weights_zip = os.path.join(WEIGHTS_DIR, "vtmot_weights.zip")
best_pt = os.path.join(WEIGHTS_DIR, "best_near_gated_phase1.pt")

if not os.path.exists(best_pt):
    # Fail early with an actionable message instead of letting the download
    # (or the later zip open) fail on the untouched placeholder.
    if GDRIVE_FILE_ID == "PASTE_YOUR_FILE_ID_HERE":
        raise ValueError(
            "GDRIVE_FILE_ID is still the placeholder; paste the Google Drive "
            "file ID of vtmot_weights.zip before running this cell."
        )

    print("üì• Downloading weights from Google Drive...")
    downloaded = gdown.download(id=GDRIVE_FILE_ID, output=weights_zip, quiet=False)
    # A failed download may be reported by a None return rather than an
    # exception, so also confirm that the archive really exists on disk.
    if downloaded is None or not os.path.exists(weights_zip):
        raise RuntimeError(
            f"Download of Google Drive file '{GDRIVE_FILE_ID}' failed; check "
            "the file ID and that sharing is set to 'Anyone with link'."
        )

    print("üì¶ Extracting weights...")
    with zipfile.ZipFile(weights_zip, 'r') as zf:
        zf.extractall(WEIGHTS_DIR)
    # The archive is no longer needed once extracted; free the disk space.
    os.remove(weights_zip)
    print("‚úÖ Weights extracted.")

    if not os.path.exists(best_pt):
        # The archive layout is not under this notebook's control; surface a
        # mismatch here rather than as a silent fallback later on.
        print(f"[WARN] Expected checkpoint not found after extraction: {best_pt}")
else:
    print("‚úÖ Weights already exist.")

# List weights (sorted so the report is stable between runs)
for f in sorted(os.listdir(WEIGHTS_DIR)):
    fp = os.path.join(WEIGHTS_DIR, f)
    if f.endswith('.pt'):
        print(f"   {f}: {os.path.getsize(fp) / (1024**2):.0f} MB")

## Cell 3: Extract Dataset from Kaggle Inputs

In [None]:
import glob, zipfile
from tqdm import tqdm

DATASET_DIR = os.path.join(WORK_DIR, "datasets/vtmot_far")
os.makedirs(DATASET_DIR, exist_ok=True)

KAGGLE_INPUT = "/kaggle/input"


def _size_gb(path):
    """Return the size of the file at *path* in GiB."""
    return os.path.getsize(path) / (1024**3)


def _count_entries(directory):
    """Return the number of entries in *directory*, or 0 when it does not exist."""
    if not os.path.exists(directory):
        return 0
    return len(os.listdir(directory))


# Collect every zip below the Kaggle input mount, then keep those whose file
# name mentions the far-view dataset.
all_zips = glob.glob(os.path.join(KAGGLE_INPUT, "**", "*.zip"), recursive=True)
vtmot_zips = list(filter(lambda path: "vtmot_far" in os.path.basename(path), all_zips))

if len(vtmot_zips) == 0:
    # Nothing matched by name: fall back to every archive found and list them.
    vtmot_zips = all_zips
    print(f"‚ö†Ô∏è No 'vtmot_far' zips found. Found {len(all_zips)} total zips:")
    for archive in all_zips:
        print(f"   {archive} ({_size_gb(archive):.1f} GB)")

print(f"\nüì¶ Found {len(vtmot_zips)} dataset zip files:")
for archive in sorted(vtmot_zips):
    print(f"   {os.path.basename(archive)}: {_size_gb(archive):.1f} GB")

# A populated train split (more than 1000 entries) is taken as proof that the
# archives were already unpacked by an earlier run.
train_images = os.path.join(DATASET_DIR, "images", "train")
n_train_present = _count_entries(train_images)
if n_train_present > 1000:
    print(f"\n‚úÖ Dataset already extracted ({n_train_present} train images)")
else:
    for archive in tqdm(sorted(vtmot_zips), desc="Extracting"):
        print(f"   üì¶ {os.path.basename(archive)}...")
        with zipfile.ZipFile(archive, 'r') as bundle:
            bundle.extractall(DATASET_DIR)
    print("‚úÖ All zips extracted.")

# Report image/label counts per split; a split without images is flagged.
print("\nüìä Dataset Verification:")
for split in ("train", "val", "test"):
    n_imgs = _count_entries(os.path.join(DATASET_DIR, "images", split))
    n_lbls = _count_entries(os.path.join(DATASET_DIR, "labels", split))
    if n_imgs > 0:
        status = "‚úÖ"
    else:
        status = "‚ùå"
    print(f"   {status} {split}: {n_imgs:,} images, {n_lbls:,} labels")

## Cell 4: Create Dataset YAML

In [None]:
yaml_path = os.path.join(DATASET_DIR, "far_view_kaggle.yaml")

# Dataset description: root path, the three split directories (relative to the
# root) and the single class name. The trailing empty item yields the final
# newline of the file.
yaml_content = "\n".join(
    [
        f"path: {DATASET_DIR}",
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "names:",
        "  0: person",
        "",
    ]
)

with open(yaml_path, 'w') as handle:
    handle.write(yaml_content)

print(f"‚úÖ Dataset YAML created: {yaml_path}")
print("\nContents:")
print(yaml_content)

## Cell 5: Setup Training Components

MCFTrainer with Gate Supervision + Gradient Clipping + bbox_decode patch

In [None]:
import torch
from pathlib import Path
from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils import DEFAULT_CFG
from ultralytics.utils import loss as loss_module

# Add repo to path for gate_supervision, visualize_gates
# (guarded so re-running the cell does not stack duplicate entries)
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

# v2.7 bbox_decode device patch
# The patch is installed only once. Without the guard, re-running this cell
# would rebind `_original_bbox_decode` to the already-patched function, which
# looks that global up at call time and would therefore call itself forever.
if not getattr(loss_module.v8DetectionLoss.bbox_decode, "_device_patched", False):
    _original_bbox_decode = loss_module.v8DetectionLoss.bbox_decode

    def _patched_bbox_decode(self, anchor_points, pred_dist):
        """Move ``self.proj`` to the device of *pred_dist*, then run the original decode.

        Only acts when ``self.use_dfl`` is truthy and the two devices differ;
        arguments and return value are those of the original ``bbox_decode``.
        """
        if self.use_dfl:
            if self.proj.device != pred_dist.device:
                self.proj = self.proj.to(pred_dist.device)
        return _original_bbox_decode(self, anchor_points, pred_dist)

    # Marker read by the guard above on subsequent runs of this cell.
    _patched_bbox_decode._device_patched = True
    loss_module.v8DetectionLoss.bbox_decode = _patched_bbox_decode
print("[OK] bbox_decode device patch applied.")


class MCFTrainer(DetectionTrainer):
    """Custom trainer with Gated Fusion + Gate Supervision.

    Compared with the base ``DetectionTrainer`` it can
    * train an already-built network passed as ``mcf_model`` instead of
      building one from ``overrides['model']``,
    * switch on ``export_gates`` for every ``GatedSpatialFusion_V3`` layer, and
    * wrap the detection loss in ``GatedDetectionLoss`` when the
      ``gate_supervision`` module is importable.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None, mcf_model=None):
        """Store the optional pre-built model, then run the base constructor.

        Args:
            cfg: Base configuration, forwarded to ``DetectionTrainer``.
            overrides: Configuration overrides, forwarded to ``DetectionTrainer``.
            _callbacks: Callbacks, forwarded to ``DetectionTrainer``.
            mcf_model: Optional network to train as-is; ``None`` lets the base
                class set the model up.
        """
        self._mcf_model = mcf_model
        super().__init__(cfg, overrides, _callbacks)

    @staticmethod
    def _enable_gate_export(model):
        """Set ``export_gates = True`` on each GatedSpatialFusion_V3 layer of *model*.

        Returns the number of layers touched; 0 when the fusion block cannot
        be imported or *model* exposes no module tree.
        """
        try:
            from ultralytics.nn.modules.block import GatedSpatialFusion_V3
        except ImportError:
            # Best effort, but no longer silent: say why no gates are exported.
            print("[MCFTrainer] GatedSpatialFusion_V3 not importable; gate export skipped.")
            return 0

        if hasattr(model, 'modules'):
            modules = model.modules()
        elif hasattr(model, 'model') and hasattr(model.model, 'modules'):
            modules = model.model.modules()
        else:
            # e.g. None or a checkpoint dict: nothing to walk.
            return 0

        count = 0
        for m in modules:
            if isinstance(m, GatedSpatialFusion_V3):
                m.export_gates = True
                count += 1
        return count

    def setup_model(self):
        """Install the model to train and enable gate export on it.

        Returns:
            The injected network when ``mcf_model`` was given (unchanged from
            the previous behaviour); otherwise whatever the base
            ``setup_model()`` returns.
        """
        if self._mcf_model is not None:
            self.model = self._mcf_model
            self.model.to(self.device)
            self.model.args = self.args
            result = self.model
            gated = self.model
        else:
            result = super().setup_model()
            # Upstream Ultralytics stores the network on self.model and returns
            # the checkpoint (dict or None), so the return value must not be
            # treated as the network. Prefer self.model and fall back to the
            # return value only if self.model is not a module.
            # NOTE(review): confirm the bundled YOLOv11-RGBT fork keeps this contract.
            gated = self.model if isinstance(self.model, torch.nn.Module) else result

        count = self._enable_gate_export(gated)
        print(f"[MCFTrainer] Gate export enabled on {count} layers.")
        return result

    def get_loss(self):
        """Return the base loss, wrapped in ``GatedDetectionLoss`` when available.

        Falls back to the unwrapped base loss if ``gate_supervision`` cannot
        be imported.
        """
        # NOTE(review): relies on the base trainer providing get_loss();
        # presumably an addition of the bundled fork — verify.
        loss = super().get_loss()
        try:
            from gate_supervision import GatedDetectionLoss
        except ImportError:
            print("[MCFTrainer] gate_supervision not found, standard loss.")
            return loss
        # Only the import is guarded, so errors raised while building the
        # wrapper are not mistaken for a missing module.
        print("[MCFTrainer] Gate supervision loss active.")
        return GatedDetectionLoss(self.model, loss)


def on_train_start(trainer):
    """Gradient clipping callback.

    Replaces ``trainer.optimizer.step`` with a wrapper that clips the gradient
    norm of ``trainer.model.parameters()`` to 10.0 before delegating to the
    original ``step``. Does nothing when the trainer has no optimizer, and
    wraps at most once per optimizer, so a repeated invocation of the callback
    does not stack wrappers.

    Args:
        trainer: Object exposing ``optimizer`` and ``model`` attributes.
    """
    optimizer = getattr(trainer, 'optimizer', None)
    if optimizer is None:
        return
    if getattr(optimizer.step, '_grad_clip_wrapped', False):
        # Already wrapped by an earlier call; wrapping again would only add a
        # redundant clipping pass per step.
        return

    orig_step = optimizer.step

    def clipped_step(closure=None):
        # trainer.model is looked up at call time, as in the original callback.
        torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), max_norm=10.0)
        return orig_step(closure)

    # Marker read by the idempotency check above.
    clipped_step._grad_clip_wrapped = True
    optimizer.step = clipped_step
    print("‚ö° Gradient Clipping (norm=10.0) applied.")

print("‚úÖ MCFTrainer + callbacks defined.")

## Cell 6: Load Model Weights

In [None]:
device = 0
resume_flag = False

runs_dir = os.path.join(WORK_DIR, "runs")
project_dir = os.path.join(runs_dir, "far_gated_deployment")
run_name = "far_view_gated_kaggle"
last_ckpt = os.path.join(project_dir, run_name, "weights", "last.pt")

# Architecture definition used whenever the network is built from scratch
# (FLIR fallback and random init).
model_yaml = os.path.join(REPO_DIR, "YOLOv11-RGBT/ultralytics/cfg/models/11-RGBT/yolo11x-RGBT-gated-v3.yaml")

# Weight selection priority: resume > near best.pt > FLIR > random
best_pt_path = os.path.join(WEIGHTS_DIR, "best_near_gated_phase1.pt")
# os.listdir returns entries in arbitrary order; sorting makes the choice
# reproducible when several FLIR checkpoints are present. As before, the last
# matching file wins.
flir_candidates = sorted(
    name for name in os.listdir(WEIGHTS_DIR) if "FLIR" in name and name.endswith(".pt")
)
flir_pt_path = os.path.join(WEIGHTS_DIR, flir_candidates[-1]) if flir_candidates else None

if os.path.exists(last_ckpt):
    print(f"[RESUME] Loading checkpoint: {last_ckpt}")
    model = YOLO(last_ckpt)
    resume_flag = True
elif os.path.exists(best_pt_path):
    print(f"[TRANSFER] Loading near-view best.pt: {best_pt_path}")
    model = YOLO(best_pt_path)
    print("‚úÖ Transfer Learning V2: Near‚ÜíFar domain adaptation")
    print("   Gated Fusion layers pre-trained from near-view!")
elif flir_pt_path and os.path.exists(flir_pt_path):
    print(f"[FALLBACK] Using FLIR pretrained: {flir_pt_path}")
    model = YOLO(model_yaml)
    # SECURITY: this unpickles arbitrary objects, so only load checkpoints from
    # a trusted source. weights_only=False is explicit because the checkpoint
    # is read as a pickled model object (``ckpt['model'].state_dict()`` below),
    # which the weights_only=True default of PyTorch >= 2.6 rejects.
    ckpt = torch.load(flir_pt_path, map_location='cpu', weights_only=False)
    source = ckpt.get('model') if isinstance(ckpt, dict) else None
    filtered = {}
    if source is not None:
        chk_sd = source.state_dict()
        mdl_sd = model.model.state_dict()
        # Keep only tensors whose name and shape both match the new network.
        filtered = {k: v for k, v in chk_sd.items() if k in mdl_sd and v.shape == mdl_sd[k].shape}
    if filtered:
        model.model.load_state_dict(filtered, strict=False)
        print(f"‚úÖ Transferred {len(filtered)} FLIR layers.")
    else:
        # Previously silent: make it visible that the fallback transferred nothing.
        print("[WARN] No compatible FLIR tensors found; model keeps its random initialization.")
else:
    print("[INIT] No pretrained weights found. Random init.")
    model = YOLO(model_yaml)

print(f"\nModel ready. Resume: {resume_flag}")

## Cell 7: 🚀 Launch Training

- **30 epochs**, SGD + Cosine LR
- **batch=8** (Kaggle 16GB VRAM safe)
- Strong augmentations for far-view small objects
- Checkpoints saved every 5 epochs

In [None]:
# Training overrides handed to MCFTrainer. When not resuming, 'model' is only a
# name recorded in the run arguments: the network itself is injected through
# ``mcf_model`` below.
config = {
    'model': last_ckpt if resume_flag else 'yolo11x.pt',
    'data': yaml_path,
    'epochs': 30,
    'imgsz': 640,
    'batch': 8,
    'device': device,
    'use_simotm': 'RGBRGB6C',
    'channels': 6,
    'pairs_rgb_ir': ['_rgb_', '_ir_'],
    'optimizer': 'SGD',
    'lr0': 0.005,
    'lrf': 0.01,
    'cos_lr': True,
    'momentum': 0.937,
    'weight_decay': 0.0005,
    'warmup_epochs': 2,
    'warmup_bias_lr': 0.05,
    'mosaic': 1.0,
    'mixup': 0.15,
    'copy_paste': 0.1,
    'scale': 0.7,
    'close_mosaic': 10,
    'patience': 15,
    'save_period': 5,
    'freeze': None,
    'project': project_dir,
    'name': run_name,
    'exist_ok': True,
    'cache': 'disk',
    'workers': 2,
    'resume': resume_flag,
}

# Guarded like the CUDA check in Cell 1, so the banner itself cannot raise
# when no GPU is visible.
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'
# Derived from config so the banner cannot drift from the values actually used.
lr_schedule = 'cosine' if config['cos_lr'] else 'linear'

print("=" * 70)
print("üöÄ FAR-VIEW GATED MID-FUSION ‚Äî DEPLOYMENT TRAINING")
print("=" * 70)
print(f"  GPU:       {gpu_name}")
print(f"  Dataset:   {yaml_path}")
print(f"  Epochs:    {config['epochs']}")
print(f"  Batch:     {config['batch']}")
print(f"  Optimizer: {config['optimizer']} (lr={config['lr0']}, {lr_schedule})")
print(f"  Output:    {project_dir}/{run_name}")
print("=" * 70)

# On resume the trainer sets the model up itself from 'model' (last.pt);
# otherwise the network loaded in Cell 6 is trained directly.
trainer = MCFTrainer(
    overrides=config,
    mcf_model=None if resume_flag else model.model
)
trainer.add_callback("on_train_start", on_train_start)
trainer.train()

print("\n‚úÖ Training Complete!")
print(f"   Best weights: {project_dir}/{run_name}/weights/best.pt")

## Cell 8: Save & Download Results

In [None]:
import shutil

output_dir = "/kaggle/working/output"
os.makedirs(output_dir, exist_ok=True)

# Everything produced by the training run lives below this directory.
run_dir = os.path.join(project_dir, run_name)

# Checkpoints are exported under descriptive names; missing ones are skipped.
best_weight = os.path.join(run_dir, "weights", "best.pt")
last_weight = os.path.join(run_dir, "weights", "last.pt")
exported_names = {best_weight: "best_far_gated.pt", last_weight: "last_far_gated.pt"}

for checkpoint, export_name in exported_names.items():
    if not os.path.exists(checkpoint):
        continue
    dest = os.path.join(output_dir, export_name)
    shutil.copy2(checkpoint, dest)
    size_mb = os.path.getsize(dest) / (1024**2)
    print(f"‚úÖ {export_name}: {size_mb:.0f} MB")

# The results CSV and the training plots keep their file names, so a single
# pass handles all of them (CSV first, then the plots).
run_artifacts = ("results.csv", "results.png", "confusion_matrix.png", "PR_curve.png", "F1_curve.png")
for artifact in run_artifacts:
    artifact_path = os.path.join(run_dir, artifact)
    if not os.path.exists(artifact_path):
        continue
    shutil.copy2(artifact_path, os.path.join(output_dir, artifact))
    print(f"‚úÖ {artifact} saved.")

print(f"\nüéØ All outputs in: {output_dir}/")
print("   Download from Kaggle Output tab.")
print("\nFiles:")
for entry in sorted(os.listdir(output_dir)):
    entry_mb = os.path.getsize(os.path.join(output_dir, entry)) / (1024**2)
    print(f"   {entry}: {entry_mb:.1f} MB")