## 0. Infrastructure Setup

### 0.1 Utils Module
Helper methods for validating local paths, local logging, serialising and deserialising JSON files, reading and writing files, and creating and deleting paths.

# Minecraft Voxel World LLM Training

## Project Overview
This notebook implements a complete pipeline for training Large Language Models (LLMs) on Minecraft voxel-based sequential prediction tasks:
- **Frame Prediction**: Given current voxel state + action → predict next voxel state
- **Action Recognition**: Given current voxel state + next voxel state → predict action taken

## Dataset
- **Source**: Minecraft gameplay data with 3D voxel representations
- **Format**: Sequential .npy files containing voxel grids and actions
- **Structure**: Each frame contains a 3×3×3 block grid and an action vector

## Models
- **Qwen 3 0.6B**: Small-scale model for efficient training
- **Qwen 3 4B**: Larger model for improved performance

## Methods
- In-context learning (zero-shot prompting with the episode history as context)
- Supervised fine-tuning with LoRA for frame reconstruction
- Supervised fine-tuning with LoRA for action recognition

---

## Table of Contents

### 0. Infrastructure Setup
- **0.1** Utils Module - File I/O, logging, JSON operations
- **0.2** Model Wrapper Class - Training, evaluation, checkpoint management
- **0.3** Plot Evaluation Class - Conference-quality visualizations
- **0.4** Hyperparameter Configuration - Grid search support

### 1. Setup
- **1.1** Load Models - Qwen 3 0.6B and 4B configuration
- **1.2** Load Minecraft Dataset - Sequential voxel frames with actions
- **1.3** Split Data - Train/val/test split (70%/15%/15%)

### 2. In-Context Learning Evaluation
- **2.1** Frame Reconstruction - Input: history h, Output: z (zero-shot, history as context)
- **2.2** Frame Reconstruction Plots - Visualization of results
- **2.3** Action Recognition - Input: history h, Output: y (zero-shot, history as context)
- **2.4** Action Recognition Plots - Visualization of results

### 3. Supervised Fine-Tuning (LoRA) for Frame Reconstruction 
- **3.1** Fine-tune Frame Reconstruction - LoRA adaptation with W&B monitoring
- **3.2** Evaluate Fine-tuned Models - Test set performance
- **3.3** Plot Fine-tuning Results - Compare in-context vs fine-tuned

### 4. Supervised Fine-Tuning (LoRA) for Action Recognition
- **4.1** Fine-tune Action Recognition - Train LoRA adapters to predict discrete actions
- **4.2** Evaluate Action Recognition - Test set metrics and JSON export
- **4.3** Plot Action Recognition Comparison - Bar charts versus zero-shot baseline

---

In [1]:
from utils_module import Utils

### 0.2 Model Wrapper Class
Provides methods to load a model by name, train it with dataloaders, and evaluate it.

Training runs on the loaded training data, stops early based on the validation set, and is monitored via W&B. Model parameters are not uploaded to W&B; they are kept in the local `checkpoints/` directory under descriptive run names, and a log of each run is written to `logs/`. A JSON file named after the task is created in the working directory, recording the mapping between each run folder under `checkpoints/` and its log path.
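Stopping on the validation set is commonly implemented as patience-based early stopping on the validation loss. The sketch below shows that idea in isolation; the `EarlyStopping` class, its `patience` and `min_delta` parameters, and the loss values are hypothetical and are not taken from `model_wrapper`.

```python
from dataclasses import dataclass


@dataclass
class EarlyStopping:
    """Hypothetical patience-based stopper driven by validation loss."""

    patience: int = 2         # epochs to wait without improvement
    min_delta: float = 0.0    # minimum decrease that counts as an improvement
    best_loss: float = float("inf")
    bad_epochs: int = 0

    def step(self, val_loss: float) -> bool:
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss  # new best: a wrapper would save the best checkpoint here
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate([1.20, 0.95, 0.97, 0.99, 0.80]):  # illustrative losses
    if stopper.step(loss):
        print(f"Stopping after epoch {epoch}; best val loss {stopper.best_loss}")
        break
```

With these values training stops after epoch 3, and the epoch-1 checkpoint (loss 0.95) would be kept as the best model.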

Training must support resuming from the latest checkpoint so that completed training is not repeated.
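Resuming requires locating the most recent checkpoint of a run. A minimal sketch, assuming a hypothetical layout of `checkpoints/<run_name>/checkpoint-<step>/` (the real naming used by `ModelWrapper` may differ):

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Assumed (hypothetical) layout: checkpoints/<run_name>/checkpoint-<step>/
CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(run_dir: Path) -> Optional[Path]:
    """Return the checkpoint subdirectory with the highest step number, or None."""
    if not run_dir.is_dir():
        return None
    candidates = []
    for child in run_dir.iterdir():
        match = CHECKPOINT_PATTERN.match(child.name)
        if match and child.is_dir():
            candidates.append((int(match.group(1)), child))
    if not candidates:
        return None
    # Compare steps numerically so checkpoint-1000 wins over checkpoint-900.
    return max(candidates, key=lambda item: item[0])[1]


with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp) / "frame_reconstruction_qwen3-0.6b"
    for step in (100, 900, 1000):
        (run_dir / f"checkpoint-{step}").mkdir(parents=True)
    latest = find_latest_checkpoint(run_dir)
    print(latest.name if latest else "no checkpoint: start from scratch")  # checkpoint-1000
```

Parsing the step as an integer matters: a plain string sort would rank `checkpoint-900` above `checkpoint-1000`.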

In [2]:
from model_wrapper import ModelWrapper

### 0.3 Plot Evaluation Class
Provides the plotting methods needed to produce conference-paper-quality figures.

In [3]:
from plot_utils import PlotUtils

### 0.4 Hyperparameter Configuration  
Defines all configurable hyperparameters and provides a grid search method.

Past run results are saved to a local JSON file, `grid-search-record.json`. When grid search is enabled, each full run of the notebook reads this file and continues with the next untried set of grid search values.
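The resume-from-record behaviour can be sketched with the standard library alone. The search space, the record schema (`{"runs": [{"params": ..., "metrics": ...}]}`), and the helper names below are hypothetical; `HyperparameterConfig` may store its record differently.

```python
import itertools
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional

# Hypothetical search space; the real values would come from HyperparameterConfig.
SEARCH_SPACE: Dict[str, List[Any]] = {
    "learning_rate": [5e-5, 1e-4],
    "lora_r": [8, 16],
}


def all_combinations(space: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Expand the grid into a deterministic list of parameter dicts (keys sorted)."""
    keys = sorted(space)
    return [dict(zip(keys, values)) for values in itertools.product(*(space[k] for k in keys))]


def load_record(record_path: Path) -> Dict[str, Any]:
    """Read the run record, starting empty if the file does not exist yet."""
    if record_path.exists():
        return json.loads(record_path.read_text())
    return {"runs": []}


def next_untried(space: Dict[str, List[Any]], record_path: Path) -> Optional[Dict[str, Any]]:
    """Return the first combination with no recorded run, or None when the grid is exhausted."""
    tried = [run["params"] for run in load_record(record_path)["runs"]]
    for combo in all_combinations(space):
        if combo not in tried:
            return combo
    return None


def record_result(record_path: Path, params: Dict[str, Any], metrics: Dict[str, float]) -> None:
    """Append one finished run so the next notebook execution skips it."""
    record = load_record(record_path)
    record["runs"].append({"params": params, "metrics": metrics})
    record_path.write_text(json.dumps(record, indent=2))


# Usage: each notebook run trains only the next untried combination.
with tempfile.TemporaryDirectory() as tmp:
    record_path = Path(tmp) / "grid-search-record.json"
    params = next_untried(SEARCH_SPACE, record_path)
    print(params)  # {'learning_rate': 5e-05, 'lora_r': 8}
    record_result(record_path, params, {"val_loss": 1.0})  # placeholder metric
    print(next_untried(SEARCH_SPACE, record_path))  # {'learning_rate': 5e-05, 'lora_r': 16}
```

Sorting the keys keeps the combination order stable across runs, so "next untried" means the same thing every time the notebook is re-executed.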

In [4]:
from hyperparameter_config import HyperparameterConfig

config = HyperparameterConfig({
    "learning_rate": 5e-5,
    "num_epochs": 3,
    "batch_size": 32,
    "max_length": 1024,
    "lora_r": 8,
    "lora_alpha": 32,
    "lora_dropout": 0.1,
    "warmup_steps": 100,
    "max_grad_norm": 1.0,
    "wandb_project": "minecraft-llm",
})
config.print_config()

HYPERPARAMETER CONFIGURATION
batch_size          : 32
learning_rate       : 5e-05
lora_alpha          : 32
lora_dropout        : 0.1
lora_r              : 8
max_grad_norm       : 1.0
max_length          : 1024
num_epochs          : 3
wandb_project       : minecraft-llm
warmup_steps        : 100


### 1.1 Load Models
Load the models via `transformers`; we use Qwen3-0.6B and Qwen3-4B.

In [5]:
from model_registry import (
    MODEL_PATHS,
    DEVICE,
    WANDB_ENABLED,
    get_model_wrapper,
    release_model,
    release_all_models,
)

print(f"Available models: {list(MODEL_PATHS.keys())}")
print(f"Using device: {DEVICE}")
print("Call get_model_wrapper('qwen3-0.6b') to load a model when needed.")


Available models: ['qwen3-0.6b', 'qwen3-4b']
Using device: cuda
Call get_model_wrapper('qwen3-0.6b') to load a model when needed.
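The registry loads models only when a cell asks for them and frees them with `release_model` / `release_all_models`. The caching pattern behind such an API can be sketched as follows; `LazyRegistry` is a hypothetical class, not the contents of `model_registry`, and string placeholders stand in for loaded models.

```python
from typing import Callable, Dict, Generic, List, TypeVar

T = TypeVar("T")


class LazyRegistry(Generic[T]):
    """Hypothetical cache that loads each model on first use and frees it on demand."""

    def __init__(self, loaders: Dict[str, Callable[[], T]]):
        self._loaders = loaders
        self._cache: Dict[str, T] = {}

    def get(self, key: str, force_reload: bool = False) -> T:
        # Unknown keys raise KeyError from the loaders dict.
        if force_reload or key not in self._cache:
            self._cache[key] = self._loaders[key]()
        return self._cache[key]

    def release(self, key: str) -> None:
        # Dropping the reference lets Python free the object; GPU code would
        # typically also call torch.cuda.empty_cache() afterwards.
        self._cache.pop(key, None)

    def release_all(self) -> None:
        self._cache.clear()

    def loaded(self) -> List[str]:
        return sorted(self._cache)


load_counts = {"qwen3-0.6b": 0}


def load_small() -> str:
    load_counts["qwen3-0.6b"] += 1
    return "qwen3-0.6b-placeholder"


registry = LazyRegistry({"qwen3-0.6b": load_small})
registry.get("qwen3-0.6b")
registry.get("qwen3-0.6b")  # served from cache
registry.release("qwen3-0.6b")
registry.get("qwen3-0.6b")  # loaded again after release
print(load_counts["qwen3-0.6b"])  # 2
```

Releasing each model before loading the next is what lets the 0.6B and 4B models share one GPU in the evaluation loops below.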


### 1.2 Load Minecraft Dataset
Load the custom data from the local `datasets/` directory. Each sample contains three fields, all in plain text:

- x: current frame as ASCII art
- y: current action token
- z: next frame as ASCII art
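The action token y is a three-line text block, as seen in the `Unique actions` output below. The 18 action labels printed in 2.3 are exactly the product of three `straight` values, three `pan` values, and two `jump` values, so the full action vocabulary can be enumerated. The frame renderer is a hypothetical layout: the `|air|air|` fragments in later outputs suggest pipe-delimited block names, but the exact row and layer arrangement used by `dataset_loader` is an assumption.

```python
import itertools
from typing import List, Sequence

# Axis values taken from the 18 action labels printed in 2.3.
STRAIGHT = ("backward", "forward", "noop")
PAN = ("left", "noop", "right")
JUMP = ("jump", "noop")


def format_action(straight: str, pan: str, jump: str) -> str:
    """Render one action in the three-line text form shown in the `Unique actions` output."""
    return f"straight: {straight}\npan: {pan}\njump: {jump}\n"


ALL_ACTIONS: List[str] = [format_action(*combo) for combo in itertools.product(STRAIGHT, PAN, JUMP)]


def format_frame(grid: Sequence[Sequence[Sequence[str]]]) -> str:
    """Hypothetical ASCII layout: one '|block|block|block|' row per line, blank line between layers."""
    layers = ["\n".join("|" + "|".join(row) + "|" for row in layer) for layer in grid]
    return "\n\n".join(layers)


air_grid = [[["air"] * 3 for _ in range(3)] for _ in range(3)]
print(len(ALL_ACTIONS))                        # 18
print(format_frame(air_grid).splitlines()[0])  # |air|air|air|
```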

In [6]:
from pathlib import Path
from dataset_loader import load_full_dataset, preview_dataset_example

DATA_DIR = Path("datasets/minecraft/data")
full_dataset, TOTAL_PAIRS, UNIQUE_ACTIONS = load_full_dataset(config, DATA_DIR)

preview_dataset_example(full_dataset, 0)


Loading 1000 frames from datasets/minecraft/data/seq0-49/creative:0
Loading 1000 frames from datasets/minecraft/data/seq0-49/creative:1
Loading 1000 frames from datasets/minecraft/data/seq0-49/creative:10
Loading 1000 frames from datasets/minecraft/data/seq0-49/creative:11
Loading 1000 frames from datasets/minecraft/data/seq0-49/creative:12
Loading 1000 frames from datasets/minecraft/data/seq0-49/creative:13
Loading 1000 frames from datasets/minecraft/data/seq0-49/creative:14
Loading 1000 frames from datasets/minecraft/data/seq0-49/creative:15
Loading 1000 frames from datasets/minecraft/data/seq0-49/creative:16
Loading 1000 frames from datasets/minecraft/data/seq0-49/creative:17
Loaded 10000 frames across 10 episodes
Created 9900 training pairs
Loaded 9900 sequential frame/action pairs from datasets/minecraft/data.
Unique actions: ['straight: backward\npan: left\njump: jump\n', 'straight: backward\npan: left\njump: noop\n', 'straight: backward\npan: noop\njump: jump\n', 'straight: back
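The counts above are consistent with a sliding window over each episode: with the history length of 10 printed in 1.3, a 1000-frame episode yields 1000 − 10 = 990 complete windows, and 10 episodes give 9900 pairs. The sketch below reproduces that arithmetic under the assumption that frames t−9..t form the history and frame t+1 is the target; it illustrates the counts and is not `dataset_loader`'s actual code.

```python
from typing import List, Tuple


def sliding_window_pairs(num_frames: int, history_length: int) -> List[Tuple[List[int], int]]:
    """Hypothetical pairing: (history frame indices, target frame index) for every complete window."""
    pairs = []
    for t in range(history_length - 1, num_frames - 1):
        history = list(range(t - history_length + 1, t + 1))  # frames t-h+1 .. t
        pairs.append((history, t + 1))                        # next frame is the target
    return pairs


per_episode = len(sliding_window_pairs(1000, 10))
print(per_episode, 10 * per_episode)  # 990 9900
```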

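### 1.3 Split Data
Split the loaded data into train, validation, and test sets (70%/15%/15%). `TRAIN_SUBSET_FRACTION` selects the fraction of pairs used for these splits, and a separate 5% subset provides the test indices used for inference.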

In [7]:
from dataset_utils import (
    compute_dataset_splits,
    build_dataloader,
    build_action_embedder,
)

TRAIN_SUBSET_FRACTION = 0.1
SPLITS, SELECTED_INDICES = compute_dataset_splits(full_dataset, subset_fraction=TRAIN_SUBSET_FRACTION)

INFERENCE_SPLITS, _ = compute_dataset_splits(full_dataset, subset_fraction=0.05)
INFERENCE_TEST_INDICES = INFERENCE_SPLITS['test']

TOTAL_SAMPLES = len(SELECTED_INDICES)
MAX_NEW_TOKENS = 512
ACTION_EMBEDDER = build_action_embedder()

print("Dataset split sizes:", {split: len(idxs) for split, idxs in SPLITS.items()})
print("Inference test split size:", len(INFERENCE_TEST_INDICES))
print(f"History length per sample: {full_dataset.history_length}. Prompt context disabled (using loaded history).")


Dataset split sizes: {'train': 693, 'val': 148, 'test': 149}
Inference test split size: 75
History length per sample: 10. Prompt context disabled (using loaded history).
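The printed sizes are consistent with flooring the 70% and 15% shares and giving the remainder to test: 10% of 9900 pairs is 990, which splits into 693/148/149, and the 5% inference subset of 495 splits into 346/74/75, hence the 75-sample inference test set. The sketch reproduces that rule with integer arithmetic; whether `compute_dataset_splits` shuffles, and with which seed, is an assumption.

```python
import random
from typing import Dict, List, Sequence


def split_indices(
    indices: Sequence[int], train_pct: int = 70, val_pct: int = 15, seed: int = 0
) -> Dict[str, List[int]]:
    """Shuffle, floor the train/val shares with integer arithmetic, and give the remainder to test."""
    shuffled = list(indices)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible across runs
    n = len(shuffled)
    n_train = n * train_pct // 100
    n_val = n * val_pct // 100
    return {
        "train": shuffled[:n_train],
        "val": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],
    }


for subset_size in (990, 495):  # 10% and 5% of the 9900 pairs
    splits = split_indices(range(subset_size))
    print({name: len(idxs) for name, idxs in splits.items()})
# {'train': 693, 'val': 148, 'test': 149}
# {'train': 346, 'val': 74, 'test': 75}
```

Integer percentages avoid float rounding surprises such as `int(n * 0.7)` landing one below the intended size.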


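### 2.1 Frame Reconstruction
Evaluate both models zero-shot on frame reconstruction using history-only inputs (h → z). Metrics are saved to `2.1-result.json` and raw outputs to `2.1-raw.json`; both are reloaded from cache on later runs.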


In [8]:
from pathlib import Path

frame_results_path = Path("2.1-result.json")
frame_raw_path = Path("2.1-raw.json")

if frame_results_path.exists() and frame_raw_path.exists():
    frame_results = Utils.load_json(frame_results_path) or {}
    frame_raw_outputs = Utils.load_json(frame_raw_path) or {}
    print(f"Loaded cached frame reconstruction results from {frame_results_path}")
    print(f"Loaded cached raw outputs from {frame_raw_path}")
else:
    test_indices = INFERENCE_TEST_INDICES if 'INFERENCE_TEST_INDICES' in globals() else SPLITS['test']

    print(
        f"Evaluating frame reconstruction on {len(test_indices)} samples using history length {full_dataset.history_length}."
    )

    frame_results = {}
    frame_raw_outputs = {}
    for model_key in MODEL_PATHS:
        wrapper = get_model_wrapper(model_key)
        metrics = wrapper.evaluate_task(
            full_dataset,
            test_indices,
            task_type="frame_reconstruction",
            model_key=model_key,
            batch_size=1,
            max_new_tokens=MAX_NEW_TOKENS,
        )

        predictions = metrics.get("predictions", [])
        targets = metrics.get("targets", [])
        raw_records = []
        if len(predictions) != len(test_indices):
            print(
                f"Warning: prediction count {len(predictions)} does not match test indices {len(test_indices)} for {model_key}."
            )
        for idx, pred, target in zip(test_indices, predictions, targets):
            pair = full_dataset.data_pairs[int(idx)]
            history = pair.get("history_reconstruction") or pair.get("history_action") or pair.get("x")
            raw_records.append(
                {
                    "index": int(idx),
                    "episode": pair.get("episode"),
                    "history": history,
                    "z_label": target,
                    "z_prediction": pred,
                }
            )
        frame_raw_outputs[model_key] = raw_records

        frame_results[model_key] = {k: v for k, v in metrics.items() if k not in {"predictions", "targets"}}
        release_model(model_key)

    Utils.save_json(frame_results, frame_results_path)
    Utils.save_json(frame_raw_outputs, frame_raw_path)
    print(f"Saved zero-shot frame reconstruction summary to {frame_results_path}")
    print(f"Saved frame reconstruction raw outputs to {frame_raw_path}")

    release_all_models()

frame_results


Evaluating frame reconstruction on 75 samples using history length 10.
Loading model: models/Qwen3-0.6B


`torch_dtype` is deprecated! Use `dtype` instead!


Model loaded on cuda


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Evaluating: 100%|██████████| 75/75 [06:23<00:00,  5.11s/it]


Loading model: models/Qwen3-4B



Model loaded on cuda


Evaluating: 100%|██████████| 75/75 [13:55<00:00, 11.14s/it]


Saved zero-shot frame reconstruction summary to 2.1-result.json
Saved frame reconstruction raw outputs to 2.1-raw.json


{'qwen3-0.6b': {'regex_matches': 0,
  'strict_match_accuracy': 0.0,
  'accuracy': 0.05948903146593025,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'labels': [],
  'confusion_matrix': [],
  'reconstruction_accuracy': 0.05948903146593025,
  'reconstruction_scores': [0.03258145363408521,
   0.06,
   0.06,
   0.06,
   0.06,
   0.06,
   0.0345252774352651,
   0.0012368583797155227,
   0.007416563658838072,
   0.007416563658838072,
   0.007416563658838072,
   0.007416563658838072,
   0.04522613065326633,
   0.03258145363408521,
   0.041327489041953665,
   0.041327489041953665,
   0.041327489041953665,
   0.041327489041953665,
   0.04066543438077634,
   0.001229256299938537,
   0.007371007371007371,
   0.007371007371007371,
   0.007371007371007371,
   0.007371007371007371,
   0.0225140712945591,
   0.034934497816593885,
   0.034934497816593885,
   0.059738643434972,
   0.059738643434972,
   0.059738643434972,
   0.05867970660146699,
   0.001229256299938537,
   0.0196319018404908,
   0
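The gap between `strict_match_accuracy` (0.0) and `reconstruction_accuracy` (about 0.059) for Qwen3-0.6B indicates that the latter gives partial credit. One common partial-credit choice is a character-level similarity such as `difflib.SequenceMatcher.ratio()`; the sketch below uses it as a hypothetical stand-in and is not necessarily the metric `ModelWrapper.evaluate_task` computes.

```python
from difflib import SequenceMatcher
from statistics import mean
from typing import Dict, List, Sequence


def strict_match(pred: str, target: str) -> float:
    """1.0 only if the whole frame is reproduced exactly (ignoring surrounding whitespace)."""
    return float(pred.strip() == target.strip())


def similarity(pred: str, target: str) -> float:
    """Partial credit in [0, 1]: difflib's ratio, 2*M / (len(pred) + len(target))."""
    # autojunk=False stops frequent characters (e.g. from repeated 'air' cells) being
    # discarded as junk, which difflib otherwise does for strings of 200+ characters.
    return SequenceMatcher(None, pred.strip(), target.strip(), autojunk=False).ratio()


def summarize(preds: Sequence[str], targets: Sequence[str]) -> Dict[str, object]:
    pairs = list(zip(preds, targets))
    scores: List[float] = [similarity(p, t) for p, t in pairs]
    return {
        "strict_match_accuracy": mean(strict_match(p, t) for p, t in pairs),
        "reconstruction_accuracy": mean(scores),
        "reconstruction_scores": scores,
    }


print(summarize(["|air|stone|", "|air|air|"], ["|air|dirt|", "|air|air|"]))
```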

### 2.2 Frame Reconstruction Plots
Plot the zero-shot frame reconstruction metrics (h → z).


In [9]:

plot_utils = PlotUtils()
frame_results = Utils.load_json("2.1-result.json") or {}

if not frame_results:
    print("No zero-shot frame reconstruction results found. Run cell 2.1 first.")
else:
    PlotUtils.plot_multi_metric_bar(
        frame_results,
        metric_keys=["strict_match_accuracy", "reconstruction_accuracy"],
        metric_labels=["Strict Match Accuracy", "Reconstruction Accuracy"],
        title="Zero-Shot Frame Reconstruction Metrics",
        save_path="plots/2.2-frame-metric-bars.png",
        scales=[100.0, 100.0],
        ylabel="Score (%)",
        ylim=(0, 100),
    )
    PlotUtils.plot_metrics_heatmap(
        frame_results,
        "Zero-Shot Frame Reconstruction Heatmap",
        "plots/2.2-frame-heatmap.png",
        metrics=["strict_match_accuracy", "reconstruction_accuracy"],
        metric_labels=["Strict Match Accuracy (%)", "Reconstruction Accuracy (%)"],
        scales=[100.0, 100.0],
    )

frame_results


Plot saved to plots/2.2-frame-metric-bars.png
Plot saved to plots/2.2-frame-heatmap.png



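### 2.3 Action Recognition
Evaluate both models zero-shot on action recognition using history-only inputs (h → y). Metrics are saved to `2.3-result.json` and raw outputs to `2.3-raw.json`; both are reloaded from cache on later runs.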


In [10]:
from pathlib import Path

action_results_path = Path("2.3-result.json")
action_raw_path = Path("2.3-raw.json")

if action_results_path.exists() and action_raw_path.exists():
    action_results = Utils.load_json(action_results_path) or {}
    action_raw_outputs = Utils.load_json(action_raw_path) or {}
    print(f"Loaded cached action recognition results from {action_results_path}")
    print(f"Loaded cached raw outputs from {action_raw_path}")
else:
    test_indices = INFERENCE_TEST_INDICES if 'INFERENCE_TEST_INDICES' in globals() else SPLITS['test']

    print(
        f"Evaluating action recognition on {len(test_indices)} samples using history length {full_dataset.history_length}."
    )

    action_results = {}
    action_raw_outputs = {}
    for model_key in MODEL_PATHS:
        wrapper = get_model_wrapper(model_key)
        metrics = wrapper.evaluate_task(
            full_dataset,
            test_indices,
            task_type="action_recognition",
            model_key=model_key,
            batch_size=1,
            max_new_tokens=MAX_NEW_TOKENS,
            action_embedder=ACTION_EMBEDDER,
        )

        predictions = metrics.get("predictions", [])
        targets = metrics.get("targets", [])
        raw_records = []
        if len(predictions) != len(test_indices):
            print(
                f"Warning: prediction count {len(predictions)} does not match test indices {len(test_indices)} for {model_key}."
            )
        for idx, pred, target in zip(test_indices, predictions, targets):
            pair = full_dataset.data_pairs[int(idx)]
            history = pair.get("history_action") or pair.get("history_reconstruction") or pair.get("x")
            raw_records.append(
                {
                    "index": int(idx),
                    "episode": pair.get("episode"),
                    "history": history,
                    "y_label": target,
                    "y_prediction": pred,
                }
            )
        action_raw_outputs[model_key] = raw_records

        action_results[model_key] = {k: v for k, v in metrics.items() if k not in {"predictions", "targets"}}
        release_model(model_key)

    Utils.save_json(action_results, action_results_path)
    Utils.save_json(action_raw_outputs, action_raw_path)
    print(f"Saved zero-shot action recognition summary to {action_results_path}")
    print(f"Saved action recognition raw outputs to {action_raw_path}")

    release_all_models()

action_results


Evaluating action recognition on 75 samples using history length 10.
Loading model: models/Qwen3-0.6B
Model loaded on cuda


Evaluating: 100%|██████████| 75/75 [06:15<00:00,  5.01s/it]


Loading model: models/Qwen3-4B



Model loaded on cuda


Evaluating: 100%|██████████| 75/75 [13:54<00:00, 11.12s/it]


Saved zero-shot action recognition summary to 2.3-result.json
Saved action recognition raw outputs to 2.3-raw.json


{'qwen3-0.6b': {'regex_matches': 0,
  'strict_match_accuracy': 0.0,
  'accuracy': 0.0,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'labels': ['straight: backward\npan: left\njump: jump',
   'straight: backward\npan: left\njump: noop',
   'straight: backward\npan: noop\njump: jump',
   'straight: backward\npan: noop\njump: noop',
   'straight: backward\npan: right\njump: jump',
   'straight: backward\npan: right\njump: noop',
   'straight: forward\npan: left\njump: jump',
   'straight: forward\npan: left\njump: noop',
   'straight: forward\npan: noop\njump: jump',
   'straight: forward\npan: noop\njump: noop',
   'straight: forward\npan: right\njump: jump',
   'straight: forward\npan: right\njump: noop',
   'straight: noop\npan: left\njump: jump',
   'straight: noop\npan: left\njump: noop',
   'straight: noop\npan: noop\njump: jump',
   'straight: noop\npan: noop\njump: noop',
   'straight: noop\npan: right\njump: jump',
   'straight: noop\npan: right\njump: noop',
   '|air|air|
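The label list above mixes the 18 valid action strings with frame-like text such as `'|air|air|'`, which indicates that raw generations were used directly as labels, so every malformed output becomes its own class. A minimal sketch, assuming invalid outputs should collapse into a single hypothetical `<invalid>` label before computing macro F1 (computed by hand here over every label seen in targets or predictions):

```python
from typing import List, Sequence, Set

INVALID = "<invalid>"  # hypothetical sentinel for generations outside the action vocabulary


def normalize_action(text: str, valid: Set[str]) -> str:
    """Map a raw generation onto a known action label, or INVALID."""
    cleaned = text.strip()
    return cleaned if cleaned in valid else INVALID


def macro_f1(targets: Sequence[str], preds: Sequence[str]) -> float:
    """Unweighted mean F1 over every label seen in targets or predictions."""
    labels = sorted(set(targets) | set(preds))
    f1s: List[float] = []
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(targets, preds))
        fp = sum(t != label and p == label for t, p in zip(targets, preds))
        fn = sum(t == label and p != label for t, p in zip(targets, preds))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * precision * recall / (precision + recall) if precision + recall else 0.0)
    return sum(f1s) / len(f1s) if f1s else 0.0


A = "straight: forward\npan: noop\njump: noop"
B = "straight: noop\npan: left\njump: jump"
VALID = {A, B}  # in practice, the 18 stripped action strings

targets = [A, B, A]
raw_outputs = [A + "\n", "|air|air|air|", B]  # the second is a frame-like generation
preds = [normalize_action(text, VALID) for text in raw_outputs]
print(preds[1])                            # <invalid>
print(round(macro_f1(targets, preds), 4))  # 0.2222
```

With normalization, the confusion matrix stays at most 19 × 19 (18 actions plus `<invalid>`) instead of growing with every distinct malformed output.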

### 2.4 Action Recognition Plots
Plot the zero-shot action recognition metrics (h → y), including a confusion matrix for each model.


In [12]:
import numpy as np
plot_utils = PlotUtils()
action_results = Utils.load_json("2.3-result.json") or {}

if not action_results:
    print("No zero-shot action recognition results found. Run cell 2.3 first.")
else:
    PlotUtils.plot_multi_metric_bar(
        action_results,
        metric_keys=["strict_match_accuracy", "word2vec_cosine", "f1"],
        metric_labels=["Strict Match Accuracy", "Word2Vec Cosine", "Macro F1"],
        title="Zero-Shot Action Recognition Metrics",
        save_path="plots/2.4-action-metric-bars.png",
        scales=[100.0, 100.0, 100.0],
        ylabel="Score (%)",
        ylim=(0, 100),
    )
    PlotUtils.plot_metrics_heatmap(
        action_results,
        "Zero-Shot Action Recognition Heatmap",
        "plots/2.4-action-heatmap.png",
        metrics=["strict_match_accuracy", "word2vec_cosine", "precision", "recall", "f1"],
        metric_labels=[
            "Strict Match Accuracy (%)",
            "Word2Vec Cosine (%)",
            "Precision (%)",
            "Recall (%)",
            "Macro F1 (%)",
        ],
        scales=[100.0, 100.0, 100.0, 100.0, 100.0],
    )
    for model_key, metrics in action_results.items():
        conf = metrics.get("confusion_matrix")
        labels = metrics.get("labels", [])
        if conf and labels:
            PlotUtils.plot_confusion_matrix(
                np.array(conf),
                labels,
                f"Action Recognition Confusion Matrix ({model_key})",
                f"plots/2.4-confusion-{model_key}.png",
            )

action_results


Plot saved to plots/2.4-action-metric-bars.png
Plot saved to plots/2.4-action-heatmap.png
Plot saved to plots/2.4-confusion-qwen3-0.6b.png




Plot saved to plots/2.4-confusion-qwen3-4b.png



### 3.1 Fine-tune Frame Reconstruction
Fine-tune both models with LoRA on next-frame reconstruction (input: x and y; output: z). Training uses the loaded training data, stops early based on the validation set, and is monitored via W&B. Model parameters are not uploaded to W&B; checkpoints are kept locally in `checkpoints/` under descriptive run names, and logs are written to `logs/`. For each model, a `3.1-training-metadata-<model_key>.json` file in the working directory records the mapping between the run folder under `checkpoints/` and the run log path.
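With the configuration from 0.4 (`lora_r = 8`, `lora_alpha = 32`), standard LoRA adds a low-rank update ΔW = (α / r) · B A to each adapted weight, so the update is scaled by 32 / 8 = 4, and a d_out × d_in weight gains r · (d_out + d_in) trainable parameters. The arithmetic below uses a hypothetical 1024 × 1024 projection purely for illustration; it does not reflect Qwen3's actual layer sizes or which modules are adapted.

```python
def lora_scaling(lora_alpha: float, r: int) -> float:
    """Scale applied to the low-rank update B @ A in standard LoRA (alpha / r)."""
    return lora_alpha / r


def lora_extra_params(d_out: int, d_in: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight: B is d_out x r, A is r x d_in."""
    return r * (d_out + d_in)


# lora_r and lora_alpha from 0.4; the 1024 x 1024 layer is a hypothetical example.
scale = lora_scaling(32, 8)
extra = lora_extra_params(1024, 1024, 8)
fraction = extra / (1024 * 1024)  # share of the frozen weight's size that is trained
print(scale, extra)  # 4.0 16384
```

Raising `lora_r` grows the adapter linearly, while keeping `lora_alpha / lora_r` fixed keeps the update's overall scale comparable across ranks.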

In [None]:
from pathlib import Path

training_config = config.get_config()
training_config["batch_size"] = max(1, min(training_config["batch_size"], len(SPLITS["train"])))
lora_config = {
    "r": training_config["lora_r"],
    "lora_alpha": training_config["lora_alpha"],
    "lora_dropout": training_config["lora_dropout"],
}

ENABLE_TRAINING = False
TRAINING_MODELS = ["qwen3-0.6b", "qwen3-4b"]

training_metadata = {}

if ENABLE_TRAINING and SPLITS["train"]:
    for model_key in TRAINING_MODELS:
        print(f"Starting LoRA fine-tuning for {model_key}...")
        wrapper = get_model_wrapper(model_key, use_lora=True, lora_config=lora_config, force_reload=True)

        train_loader = build_dataloader(
            full_dataset,
            SPLITS["train"],
            wrapper.tokenizer,
            "frame_reconstruction",
            batch_size=training_config["batch_size"],
            shuffle=True,
            context_examples=None,
        )
        val_loader = build_dataloader(
            full_dataset,
            SPLITS["val"],
            wrapper.tokenizer,
            "frame_reconstruction",
            batch_size=1,
            shuffle=False,
            context_examples=None,
        )

        metadata = wrapper.train(
            train_loader,
            val_loader,
            training_config,
            task_name=f"frame_reconstruction_{model_key}",
            use_wandb=WANDB_ENABLED,
        )

        metadata_path = Path(f"3.1-training-metadata-{model_key}.json")
        Utils.save_json(metadata, metadata_path)
        training_metadata[model_key] = metadata
        release_model(model_key)
        del wrapper
else:
    print("Supervised LoRA training skipped. Set ENABLE_TRAINING = True to run fine-tuning.")

release_all_models()

training_metadata


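### 3.2 Evaluate Fine-tuned Models
Evaluate the fine-tuned frame reconstruction models on the test split and save the metrics to `3.2-result.json` (raw outputs go to `3.2-raw.json`).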

In [None]:
from pathlib import Path

fine_tuned_results_path = Path("3.2-result.json")
fine_tuned_raw_path = Path("3.2-raw.json")
test_indices = INFERENCE_TEST_INDICES if 'INFERENCE_TEST_INDICES' in globals() else SPLITS['test']

if fine_tuned_results_path.exists() and fine_tuned_raw_path.exists():
    fine_tuned_results = Utils.load_json(fine_tuned_results_path) or {}
    fine_tuned_raw_outputs = Utils.load_json(fine_tuned_raw_path) or {}
    print(f"Loaded cached fine-tuned frame reconstruction results from {fine_tuned_results_path}")
    print(f"Loaded cached raw outputs from {fine_tuned_raw_path}")
else:
    fine_tuned_results = {}
    fine_tuned_raw_outputs = {}
    for model_key in TRAINING_MODELS:
        metadata_path = Path(f"3.1-training-metadata-{model_key}.json")
        metadata = Utils.load_json(metadata_path)
        if not metadata:
            print(f"No training metadata found for {model_key}; skipping.")
            continue

        wrapper = get_model_wrapper(model_key, use_lora=True, lora_config=lora_config, force_reload=True)
        checkpoint_dir = Path(metadata["checkpoint_dir"])
        adapter_path = checkpoint_dir / "best_lora_adapter"
        model_path = checkpoint_dir / "best_model.pt"

        if adapter_path.exists():
            wrapper.load_checkpoint(str(adapter_path))
        elif model_path.exists():
            wrapper.load_checkpoint(str(model_path))
        else:
            print(f"No fine-tuned weights found for {model_key}; skipping evaluation.")
            release_model(model_key)
            del wrapper
            continue

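        # Same evaluation call as the zero-shot cell (2.1), now with the fine-tuned weights loaded.
        metrics = wrapper.evaluate_task(
            full_dataset,
            test_indices,
            task_type="frame_reconstruction",
            model_key=model_key,
            batch_size=1,
            max_new_tokens=MAX_NEW_TOKENS,
        )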

        predictions = metrics.get("predictions", [])
        targets = metrics.get("targets", [])
        raw_records = []
        if len(predictions) != len(test_indices):
            print(
                f"Warning: prediction count {len(predictions)} does not match test indices {len(test_indices)} for {model_key}."
            )
        for idx, pred, target in zip(test_indices, predictions, targets):
            pair = full_dataset.data_pairs[int(idx)]
            history = pair.get("history_reconstruction") or pair.get("history_action") or pair.get("x")
            raw_records.append(
                {
                    "index": int(idx),
                    "episode": pair.get("episode"),
                    "history": history,
                    "z_label": target,
                    "z_prediction": pred,
                }
            )
        fine_tuned_raw_outputs[model_key] = raw_records

        fine_tuned_results[model_key] = {k: v for k, v in metrics.items() if k not in {"predictions", "targets"}}
        release_model(model_key)
        del wrapper

    if fine_tuned_results:
        Utils.save_json(fine_tuned_results, fine_tuned_results_path)
        Utils.save_json(fine_tuned_raw_outputs, fine_tuned_raw_path)
        print(f"Saved fine-tuned evaluation results to {fine_tuned_results_path}")
        print(f"Saved fine-tuned raw outputs to {fine_tuned_raw_path}")
    else:
        print("No fine-tuned results to save.")

    release_all_models()

fine_tuned_results
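The record-building loop in this cell repeats almost verbatim in 2.1, 2.3, and 4.2; only the label prefix (`z` or `y`) and the order of history fields change. A small helper could factor it out. `build_raw_records` and its parameters are hypothetical names, not part of the existing modules.

```python
from typing import Any, Dict, List, Sequence


def build_raw_records(
    data_pairs: Sequence[Dict[str, Any]],
    indices: Sequence[int],
    predictions: Sequence[str],
    targets: Sequence[str],
    label_key: str,
    history_keys: Sequence[str],
) -> List[Dict[str, Any]]:
    """Hypothetical helper mirroring the raw-output loops in 2.1, 2.3, 3.2 and 4.2."""
    if len(predictions) != len(indices):
        print(f"Warning: prediction count {len(predictions)} does not match test indices {len(indices)}.")
    records = []
    for idx, pred, target in zip(indices, predictions, targets):
        pair = data_pairs[int(idx)]
        # First non-empty history field wins, like the `or` chains in the notebook.
        history = next((pair.get(k) for k in history_keys if pair.get(k)), None)
        records.append({
            "index": int(idx),
            "episode": pair.get("episode"),
            "history": history,
            f"{label_key}_label": target,
            f"{label_key}_prediction": pred,
        })
    return records
```

In this cell it would be called as `build_raw_records(full_dataset.data_pairs, test_indices, predictions, targets, "z", ("history_reconstruction", "history_action", "x"))`; the action cells would pass `"y"` and `("history_action", "history_reconstruction", "x")`.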


### 3.3 Plot Fine-tuning Results
Compare zero-shot and LoRA fine-tuned frame reconstruction metrics.

In [None]:

plot_utils = PlotUtils()

zero_shot_frame = Utils.load_json("2.1-result.json") or {}
fine_tuned_frame = Utils.load_json("3.2-result.json") or {}

method_results = {}
if zero_shot_frame:
    method_results["zero_shot"] = zero_shot_frame
if fine_tuned_frame:
    method_results["fine_tuned"] = fine_tuned_frame

if len(method_results) >= 2:
    PlotUtils.plot_method_metric_bar(
        method_results,
        metric_key="reconstruction_accuracy",
        title="Frame Reconstruction: Reconstruction Accuracy Comparison",
        save_path="plots/3.3-frame-reconstruction-accuracy.png",
        scale=100.0,
        ylabel="Reconstruction Accuracy (%)",
        method_labels={"zero_shot": "Zero-Shot", "fine_tuned": "LoRA Fine-Tuned"},
        metric_label="Reconstruction Accuracy (%)",
        ylim=(0, 100),
    )
    PlotUtils.plot_method_metric_bar(
        method_results,
        metric_key="strict_match_accuracy",
        title="Frame Reconstruction: Strict Match Accuracy Comparison",
        save_path="plots/3.3-frame-strict-accuracy.png",
        scale=100.0,
        ylabel="Strict Match Accuracy (%)",
        method_labels={"zero_shot": "Zero-Shot", "fine_tuned": "LoRA Fine-Tuned"},
        metric_label="Strict Match Accuracy (%)",
        ylim=(0, 100),
    )
else:
    print("Need results from at least two methods to plot comparisons. Run zero-shot (2.1) and fine-tuned (3.2) evaluations.")

{"zero_shot": zero_shot_frame, "fine_tuned": fine_tuned_frame}
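Alongside the bar charts, a per-model delta makes the zero-shot versus fine-tuned comparison explicit in numbers. A minimal sketch over two result dictionaries shaped like `2.1-result.json` and `3.2-result.json`; `metric_deltas` is a hypothetical helper, and the values in the usage lines are placeholders, not measured results.

```python
from typing import Dict, Mapping


def metric_deltas(
    zero_shot: Mapping[str, Mapping[str, float]],
    fine_tuned: Mapping[str, Mapping[str, float]],
    metric_key: str,
) -> Dict[str, float]:
    """Per-model change (fine-tuned minus zero-shot); skips models missing from either side."""
    deltas = {}
    for model_key in sorted(set(zero_shot) & set(fine_tuned)):
        before = zero_shot[model_key].get(metric_key)
        after = fine_tuned[model_key].get(metric_key)
        if before is not None and after is not None:
            deltas[model_key] = after - before
    return deltas


# Placeholder values, not measured results.
example = metric_deltas(
    {"model-a": {"reconstruction_accuracy": 0.25}, "model-b": {"reconstruction_accuracy": 0.1}},
    {"model-a": {"reconstruction_accuracy": 0.5}},
    "reconstruction_accuracy",
)
print(example)  # {'model-a': 0.25}
```

In this cell it would be called as `metric_deltas(zero_shot_frame, fine_tuned_frame, "reconstruction_accuracy")`.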


### 4.1 LoRA Fine-Tuning for Action Recognition

Fine-tune both Qwen models on the action recognition task using LoRA. Each run stores checkpoints in `checkpoints/`, logs in `logs/`, and records training metadata to `4.1-training-metadata.json` for downstream evaluation.

In [None]:
from pathlib import Path

action_training_config = config.get_config()
action_training_config["batch_size"] = max(1, min(action_training_config["batch_size"], len(SPLITS["train"])))
action_lora_config = {
    "r": action_training_config["lora_r"],
    "lora_alpha": action_training_config["lora_alpha"],
    "lora_dropout": action_training_config["lora_dropout"],
}

ENABLE_ACTION_TRAINING = True
ACTION_TRAINING_MODELS = ["qwen3-0.6b", "qwen3-4b"]

action_training_metadata = {}

if not SPLITS["train"]:
    print("No training samples available for action recognition. Populate SPLITS['train'] before running fine-tuning.")
elif ENABLE_ACTION_TRAINING:
    for model_key in ACTION_TRAINING_MODELS:
        print(f"Starting LoRA action recognition fine-tuning for {model_key}...")
        wrapper = get_model_wrapper(model_key, use_lora=True, lora_config=action_lora_config, force_reload=True)

        train_loader = build_dataloader(
            full_dataset,
            SPLITS["train"],
            wrapper.tokenizer,
            "action_recognition",
            batch_size=action_training_config["batch_size"],
            shuffle=True,
            context_examples=None,
        )
        val_loader = build_dataloader(
            full_dataset,
            SPLITS["val"],
            wrapper.tokenizer,
            "action_recognition",
            batch_size=1,
            shuffle=False,
            context_examples=None,
        )

        metadata = wrapper.train(
            train_loader,
            val_loader,
            action_training_config,
            task_name=f"action_recognition_{model_key}",
            use_wandb=WANDB_ENABLED,
        )
        metadata["task_type"] = "action_recognition"
        metadata["config"] = {
            key: action_training_config[key]
            for key in [
                "learning_rate",
                "num_epochs",
                "batch_size",
                "lora_r",
                "lora_alpha",
                "lora_dropout",
                "warmup_steps",
                "max_grad_norm",
            ]
            if key in action_training_config
        }

        metadata_path = Path(f"4.1-training-metadata-{model_key}.json")
        Utils.save_json(metadata, metadata_path)
        action_training_metadata[model_key] = metadata
        release_model(model_key)
        del wrapper

    Utils.save_json(action_training_metadata, "4.1-training-metadata.json")
    print("Saved aggregated training metadata to 4.1-training-metadata.json")
else:
    print("Action recognition LoRA training skipped. Set ENABLE_ACTION_TRAINING = True to run fine-tuning.")

release_all_models()

action_training_metadata


### 4.2 Evaluate Fine-Tuned Action Recognition

Load the best adapters from Section 4.1, run inference on the test split, and persist aggregated metrics to `4.2-result.json`.
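Both 3.2 and 4.2 choose weights in the same order: the `best_lora_adapter` directory first, then `best_model.pt`. A hypothetical helper, `resolve_best_weights`, captures that order and also treats a missing `checkpoint_dir` as "no weights", since `Path("")` is the current directory and therefore always exists.

```python
import tempfile
from pathlib import Path
from typing import Optional


def resolve_best_weights(checkpoint_dir: str) -> Optional[Path]:
    """Return the best LoRA adapter dir, else best_model.pt, else None."""
    if not checkpoint_dir:
        # Path("") means the current directory, which always exists, so treat it as missing.
        return None
    root = Path(checkpoint_dir)
    for candidate in (root / "best_lora_adapter", root / "best_model.pt"):
        if candidate.exists():
            return candidate
    return None


with tempfile.TemporaryDirectory() as tmp:
    print(resolve_best_weights(tmp))       # None
    (Path(tmp) / "best_model.pt").write_bytes(b"")
    print(resolve_best_weights(tmp).name)  # best_model.pt
    (Path(tmp) / "best_lora_adapter").mkdir()
    print(resolve_best_weights(tmp).name)  # best_lora_adapter
```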

In [None]:
from pathlib import Path

action_finetuned_results_path = Path("4.2-result.json")
action_finetuned_raw_path = Path("4.2-raw.json")
test_indices = INFERENCE_TEST_INDICES if 'INFERENCE_TEST_INDICES' in globals() else SPLITS['test']

if action_finetuned_results_path.exists() and action_finetuned_raw_path.exists():
    action_finetuned_results = Utils.load_json(action_finetuned_results_path) or {}
    action_finetuned_raw_outputs = Utils.load_json(action_finetuned_raw_path) or {}
    print(f"Loaded cached fine-tuned action recognition results from {action_finetuned_results_path}")
    print(f"Loaded cached raw outputs from {action_finetuned_raw_path}")
else:
    action_finetuned_results = {}
    action_finetuned_raw_outputs = {}
    metadata_index = Utils.load_json("4.1-training-metadata.json") or {}

    if not metadata_index:
        print("No action recognition training metadata found. Run cell 4.1 first.")
    elif not test_indices:
        print("No test samples available for action recognition evaluation. Populate SPLITS['test'] (or INFERENCE_TEST_INDICES) before running evaluation.")
    else:
        for model_key, metadata in metadata_index.items():
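            checkpoint_dir_str = metadata.get("checkpoint_dir")
            # Path("") resolves to the current directory, so an empty entry must count as missing.
            checkpoint_dir = Path(checkpoint_dir_str) if checkpoint_dir_str else None
            if checkpoint_dir is None or not checkpoint_dir.exists():
                print(f"Checkpoint directory {checkpoint_dir_str!r} not found for {model_key}; skipping.")
                continue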

            adapter_path = checkpoint_dir / "best_lora_adapter"
            model_path = checkpoint_dir / "best_model.pt"

            # Prefer the LoRA adapter and fall back to the full-model checkpoint.
            # Resolve the weights before loading the base model so missing checkpoints cost nothing.
            if adapter_path.exists():
                weights_path = adapter_path
            elif model_path.exists():
                weights_path = model_path
            else:
                print(f"No fine-tuned weights found for {model_key}; skipping evaluation.")
                continue

            # Reload a fresh base model with LoRA layers, then restore the fine-tuned weights.
            wrapper = get_model_wrapper(model_key, use_lora=True, lora_config=action_lora_config, force_reload=True)
            wrapper.load_checkpoint(str(weights_path))

            metrics = evaluate_wrapper(
                wrapper,
                model_key,
                "action_recognition",
                test_indices,
                context_examples=None,
            )

            predictions = metrics.get("predictions", [])
            targets = metrics.get("targets", [])
            raw_records = []
            if len(predictions) != len(test_indices):
                print(
                    f"Warning: {model_key} returned {len(predictions)} predictions for {len(test_indices)} test indices; "
                    "raw records are truncated to the shortest sequence."
                )
            for idx, pred, target in zip(test_indices, predictions, targets):
                pair = full_dataset.data_pairs[int(idx)]
                history = pair.get("history_action") or pair.get("history_reconstruction") or pair.get("x")
                raw_records.append(
                    {
                        "index": int(idx),
                        "episode": pair.get("episode"),
                        "history": history,
                        "y_label": target,
                        "y_prediction": pred,
                    }
                )
            action_finetuned_raw_outputs[model_key] = raw_records

            action_finetuned_results[model_key] = {
                k: v for k, v in metrics.items() if k not in {"predictions", "targets"}
            }
            release_model(model_key)
            del wrapper

    if action_finetuned_results:
        Utils.save_json(action_finetuned_results, action_finetuned_results_path)
        Utils.save_json(action_finetuned_raw_outputs, action_finetuned_raw_path)
        print(f"Saved action recognition fine-tuned evaluation results to {action_finetuned_results_path}")
        print(f"Saved action recognition fine-tuned raw outputs to {action_finetuned_raw_path}")
    else:
        print("No fine-tuned action recognition results to save.")

    release_all_models()

action_finetuned_results
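The metrics saved above come from `evaluate_wrapper`, which is implemented outside this section. Section 4.3 plots two of them: `strict_match_accuracy` and `f1` (labelled Macro F1). The sketch below gives one common reading of those metrics. It assumes each action is a fixed-length multi-hot 0/1 vector. The notebook's own parsing of model output into vectors and its handling of edge cases may differ:

```python
from typing import Sequence

ActionVector = Sequence[int]  # assumption: fixed-length multi-hot 0/1 action vector


def strict_match_accuracy(preds: Sequence[ActionVector], targets: Sequence[ActionVector]) -> float:
    """Fraction of samples whose full predicted vector equals the target vector."""
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    if not targets:
        return 0.0
    hits = sum(list(p) == list(t) for p, t in zip(preds, targets))
    return hits / len(targets)


def macro_f1(preds: Sequence[ActionVector], targets: Sequence[ActionVector]) -> float:
    """Unweighted mean of per-dimension binary F1 scores.

    A dimension with no positives in either preds or targets scores 0.0 here;
    other conventions exist (for example, excluding that dimension).
    """
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    if not targets or len(targets[0]) == 0:
        return 0.0
    num_dims = len(targets[0])
    scores = []
    for d in range(num_dims):
        tp = sum(1 for p, t in zip(preds, targets) if p[d] == 1 and t[d] == 1)
        fp = sum(1 for p, t in zip(preds, targets) if p[d] == 1 and t[d] == 0)
        fn = sum(1 for p, t in zip(preds, targets) if p[d] == 0 and t[d] == 1)
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN), the harmonic mean of precision and recall.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / num_dims
```

Strict match gives no partial credit, while macro F1 scores each action dimension separately. The two metrics can therefore diverge sharply for the same predictions.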


### 4.3 Plot Action Recognition Comparison

Compare in-context learning (Section 2.3, three training examples as context) with LoRA fine-tuning (Section 4.2). The bar charts are saved under `plots/`.

In [None]:
plot_utils = PlotUtils()

# Section 2.3: few-shot in-context prompting with three training examples per query.
in_context_action = Utils.load_json("2.3-result.json") or {}
# Section 4.2: LoRA fine-tuned adapters evaluated on the test split.
fine_tuned_action = Utils.load_json("4.2-result.json") or {}

method_results = {}
if in_context_action:
    method_results["in_context"] = in_context_action
if fine_tuned_action:
    method_results["fine_tuned"] = fine_tuned_action

method_labels = {"in_context": "In-Context (3-shot)", "fine_tuned": "LoRA Fine-Tuned"}

if len(method_results) >= 2:
    PlotUtils.plot_method_metric_bar(
        method_results,
        metric_key="strict_match_accuracy",
        title="Action Recognition: Strict Match Accuracy Comparison",
        save_path="plots/4.3-action-strict-accuracy.png",
        scale=100.0,
        ylabel="Strict Match Accuracy (%)",
        method_labels=method_labels,
        metric_label="Strict Match Accuracy (%)",
        ylim=(0, 100),
    )
    PlotUtils.plot_method_metric_bar(
        method_results,
        metric_key="f1",
        title="Action Recognition: Macro F1 Comparison",
        save_path="plots/4.3-action-macro-f1.png",
        scale=100.0,
        ylabel="Macro F1 (%)",
        method_labels=method_labels,
        metric_label="Macro F1 (%)",
        ylim=(0, 100),
    )
else:
    print("Need both in-context (2.3) and fine-tuned (4.2) results to plot comparisons. Run cells 2.3 and 4.2 first.")

{"in_context": in_context_action, "fine_tuned": fine_tuned_action}
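`PlotUtils.plot_method_metric_bar` is defined in Section 0.3. It receives `method_results` shaped as `{method: {model_key: {metric: value}}}`, assuming both result files map each model key to a metrics dict the way `4.2-result.json` does. The sketch below shows how that nesting pivots into one bar height per model and method, scaled to percent. `collect_metric` is a hypothetical helper, and the model and method keys in the test are placeholders:

```python
from typing import Dict, Optional

Results = Dict[str, Dict[str, Dict[str, float]]]  # {method: {model_key: {metric: value}}}


def collect_metric(
    method_results: Results,
    metric_key: str,
    scale: float = 1.0,
) -> Dict[str, Dict[str, Optional[float]]]:
    """Pivot to {model_key: {method: scaled value}}; missing metrics become None.

    Hypothetical helper, not the PlotUtils implementation.
    """
    # Union of model keys across methods, sorted for a stable bar order.
    models = sorted({model for per_model in method_results.values() for model in per_model})
    table: Dict[str, Dict[str, Optional[float]]] = {}
    for model in models:
        table[model] = {}
        for method, per_model in method_results.items():
            value = per_model.get(model, {}).get(metric_key)
            table[model][method] = None if value is None else value * scale
    return table
```

Using `None` instead of `0.0` for a missing metric keeps an absent run distinguishable from a genuine score of zero.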