## 0. Infrastructure Setup

### 0.1 Utils Module
All helpful methods including validate local path, local logging, serialise and deserialise json file, read and write files, create and delete path

# Minecraft Voxel World LLM Training

## Project Overview
This notebook implements a complete pipeline for training Large Language Models (LLMs) on Minecraft voxel-based sequential prediction tasks:
- **Frame Prediction**: Given current voxel state + action → predict next voxel state
- **Action Recognition**: Given current voxel state + next voxel state → predict action taken

## Dataset
- **Source**: Minecraft gameplay data with 3D voxel representations
- **Format**: Sequential .npy files containing voxel grids and actions
- **Structure**: Each frame contains a 3×3×3 block grid and action vector

## Models
- **Qwen 3 0.6B**: Small-scale model for efficient training
- **Qwen 3 4B**: Larger model for improved performance

## Methods
- In-context learning (few-shot prompting with training examples)
- Supervised fine-tuning with LoRA for frame reconstruction
- Supervised fine-tuning with LoRA for action recognition

---

## Table of Contents

### 0. Infrastructure Setup
- **0.1** Utils Module - File I/O, logging, JSON operations
- **0.2** Model Wrapper Class - Training, evaluation, checkpoint management
- **0.3** Plot Evaluation Class - Conference-quality visualizations
- **0.4** Hyperparameter Configuration - Grid search support

### 1. Setup
- **1.1** Load Models - Qwen 3 0.6B and 4B configuration
- **1.2** Load Minecraft Dataset - Sequential voxel frames with actions
- **1.3** Split Data - Train/val/test split (70%/15%/15%)

### 2. In-Context Learning Evaluation
- **2.1** Frame Reconstruction - Input: x+y, Output: z (with 3 training examples as context)
- **2.2** Frame Reconstruction Plots - Visualization of results
- **2.3** Action Recognition - Input: x+z, Output: y (with 3 training examples as context)
- **2.4** Action Recognition Plots - Visualization of results

### 3. Supervised Fine-Tuning (LoRA) for Frame Reconstruction 
- **3.1** Fine-tune Frame Reconstruction - LoRA adaptation with W&B monitoring
- **3.2** Evaluate Fine-tuned Models - Test set performance
- **3.3** Plot Fine-tuning Results - Compare in-context vs fine-tuned

### 4. Supervised Fine-Tuning (LoRA) for Action Recognition
- **4.1** Fine-tune Action Recognition - Train LoRA adapters to predict discrete actions
- **4.2** Evaluate Action Recognition - Test set metrics and JSON export
- **4.3** Plot Action Recognition Comparison - Bar charts versus zero-shot baseline

---

In [1]:
import json
import logging
import os
from pathlib import Path
from typing import Any, Union
from datetime import datetime

class Utils:
    """Utility class for file operations, logging, and path management."""
    
    @staticmethod
    def validate_path(path: Union[str, Path], create: bool = False) -> Path:
        """Validate and optionally create a path."""
        path = Path(path)
        if create:
            path.mkdir(parents=True, exist_ok=True)
        return path
    
    @staticmethod
    def setup_logging(log_dir: Union[str, Path], name: str = "experiment") -> logging.Logger:
        """Setup logging to file and console."""
        log_dir = Utils.validate_path(log_dir, create=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = log_dir / f"{name}_{timestamp}.log"
        
        logger = logging.getLogger(name)
        logger.setLevel(logging.INFO)
        
        # Clear existing handlers
        logger.handlers.clear()
        
        # File handler
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.INFO)
        
        # Console handler
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        
        # Formatter
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)
        
        logger.addHandler(fh)
        logger.addHandler(ch)
        
        return logger
    
    @staticmethod
    def save_json(data: Any, path: Union[str, Path]) -> None:
        """Save data to JSON file."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w') as f:
            json.dump(data, f, indent=2)
    
    @staticmethod
    def load_json(path: Union[str, Path]) -> Any:
        """Load data from JSON file."""
        path = Path(path)
        if not path.exists():
            return None
        with open(path, 'r') as f:
            return json.load(f)
    
    @staticmethod
    def read_file(path: Union[str, Path]) -> str:
        """Read text file."""
        with open(path, 'r') as f:
            return f.read()
    
    @staticmethod
    def write_file(path: Union[str, Path], content: str) -> None:
        """Write text file."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w') as f:
            f.write(content)
    
    @staticmethod
    def delete_path(path: Union[str, Path]) -> None:
        """Delete file or directory."""
        path = Path(path)
        if path.is_file():
            path.unlink()
        elif path.is_dir():
            import shutil
            shutil.rmtree(path)

print("Utils module loaded successfully")

Utils module loaded successfully


### 0.2 Model Wrapper Class
Including loading with name method, train with dataloaders method, and evaluate method.

With loaded data train and stop in val and monitor via W&B. Do not pass model parameters to W&B. Keep them in local dir `checkpoints/` with proper naming and also keep a log in the dir `logs/`.  Create a JSON file with proper name of task in the working dir given the match between the run folder path under checkpoints and the run log path.

The checkpoint resume from latest feature should be implemented - we do not want to train repeatedly.

In [2]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, PeftModel
import wandb
from tqdm import tqdm
from datetime import datetime
import glob
from pathlib import Path

class ModelWrapper:
    """Wrapper class for loading, training, and evaluating LLM models."""

    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """Initialize model wrapper."""
        self.model_name = model_name
        self.device = device
        self.model = None
        self.tokenizer = None
        self.is_lora = False
        self.checkpoint_dir = None
        self.log_dir = None

    def load_model(self, use_lora: bool = False, lora_config: dict = None):
        """Load model and tokenizer."""
        print(f"Loading model: {self.model_name}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto" if self.device == "cuda" else None
        )

        if use_lora:
            default_config = {
                "r": 8,
                "lora_alpha": 32,
                "target_modules": ["q_proj", "v_proj"],
                "lora_dropout": 0.1,
                "bias": "none",
                "task_type": "CAUSAL_LM"
            }
            if lora_config:
                default_config.update(lora_config)

            lora_config_obj = LoraConfig(**default_config)
            self.model = get_peft_model(self.model, lora_config_obj)
            self.is_lora = True
            print("LoRA adapter applied")
            print("NOTE: Only LoRA adapter weights will be saved (saves disk space)")

        if self.device != "cuda":
            self.model = self.model.to(self.device)

        print(f"Model loaded on {self.device}")
        return self

    def setup_training(self, task_name: str, run_name: str = None):
        """Setup directories and logging for training."""
        if run_name is None:
            run_name = f"{task_name}_{self.model_name.split('/')[-1]}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        self.checkpoint_dir = Utils.validate_path(f"checkpoints/{run_name}", create=True)
        self.log_dir = Utils.validate_path(f"logs/{run_name}", create=True)

        metadata = {
            "task_name": task_name,
            "run_name": run_name,
            "model_name": self.model_name,
            "checkpoint_dir": str(self.checkpoint_dir),
            "log_dir": str(self.log_dir),
            "is_lora": self.is_lora,
            "created_at": datetime.now().isoformat()
        }

        return run_name, metadata

    def find_latest_checkpoint(self):
        """Find latest checkpoint for resuming training."""
        if not self.checkpoint_dir or not Path(self.checkpoint_dir).exists():
            return None

        checkpoints = glob.glob(str(Path(self.checkpoint_dir) / "checkpoint_epoch_*.pt"))
        if not checkpoints:
            return None

        epochs = [int(Path(c).stem.split('_')[-1]) for c in checkpoints]
        latest_epoch = max(epochs)
        return Path(self.checkpoint_dir) / f"checkpoint_epoch_{latest_epoch}.pt"

    def train(self, train_loader: DataLoader, val_loader: DataLoader,
              config: dict, task_name: str, use_wandb: bool = True):
        """Train model with LoRA - saves only adapter weights to save disk space."""
        run_name, metadata = self.setup_training(task_name)
        logger = Utils.setup_logging(self.log_dir, task_name)

        latest_checkpoint = self.find_latest_checkpoint()
        start_epoch = 0

        if latest_checkpoint:
            print(f"Found checkpoint: {latest_checkpoint}")
            checkpoint = torch.load(latest_checkpoint)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            print(f"Resuming from epoch {start_epoch}")

        if use_wandb:
            wandb.init(
                project=config.get('wandb_project', 'minecraft-llm'),
                name=run_name,
                config={k: v for k, v in config.items() if not k.startswith('wandb')}
            )

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=config['learning_rate'])
        num_training_steps = len(train_loader) * config['num_epochs']
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=config.get('warmup_steps', 0),
            num_training_steps=num_training_steps
        )

        best_val_loss = float('inf')

        for epoch in range(start_epoch, config['num_epochs']):
            self.model.train()
            train_loss = 0

            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']}")
            for batch in pbar:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                loss = outputs.loss
                loss.backward()

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), config.get('max_grad_norm', 1.0))

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                train_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})

                if use_wandb:
                    wandb.log({'train_loss_step': loss.item()})

            avg_train_loss = train_loss / len(train_loader)
            val_loss = self.evaluate_loss(val_loader)

            logger.info(f"Epoch {epoch+1}: train_loss={avg_train_loss:.4f}, val_loss={val_loss:.4f}")

            if use_wandb:
                wandb.log({
                    'epoch': epoch + 1,
                    'train_loss': avg_train_loss,
                    'val_loss': val_loss
                })

            try:
                if self.is_lora:
                    adapter_path = Path(self.checkpoint_dir) / f"adapter_epoch_{epoch}"
                    self.model.save_pretrained(adapter_path)
                    state_path = Path(self.checkpoint_dir) / f"state_epoch_{epoch}.pt"
                    torch.save({
                        'epoch': epoch,
                        'optimizer_state_dict': optimizer.state_dict(),
                        'train_loss': avg_train_loss,
                        'val_loss': val_loss
                    }, state_path)
                else:
                    checkpoint_path = Path(self.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pt"
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'train_loss': avg_train_loss,
                        'val_loss': val_loss
                    }, checkpoint_path)

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    if self.is_lora:
                        best_adapter_path = Path(self.checkpoint_dir) / "best_lora_adapter"
                        self.model.save_pretrained(best_adapter_path)
                        logger.info(f"Best LoRA adapter saved with val_loss={val_loss:.4f}")
                    else:
                        best_model_path = Path(self.checkpoint_dir) / "best_model.pt"
                        torch.save(self.model.state_dict(), best_model_path)
                        logger.info(f"Best model saved with val_loss={val_loss:.4f}")

            except RuntimeError as e:
                logger.error(f"Failed to save checkpoint: {e}")
                print(f"WARNING: Could not save checkpoint due to disk space. Error: {e}")

        if use_wandb:
            wandb.finish()

        metadata['training_completed'] = datetime.now().isoformat()
        metadata['best_val_loss'] = best_val_loss

        return metadata

    def evaluate_loss(self, dataloader: DataLoader):
        """Evaluate loss on a dataloader."""
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                total_loss += outputs.loss.item()

        return total_loss / len(dataloader)

    def evaluate(self, test_loader: DataLoader, max_new_tokens: int = 256):
        """Evaluate model and return predictions."""
        self.model.eval()
        predictions = []
        targets = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Evaluating"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)

                outputs = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    pad_token_id=self.tokenizer.pad_token_id,
                    do_sample=False
                )

                for i, output in enumerate(outputs):
                    input_len = input_ids[i].shape[0]
                    generated = output[input_len:]
                    pred_text = self.tokenizer.decode(generated, skip_special_tokens=True)
                    predictions.append(pred_text)

                if 'target_text' in batch:
                    targets.extend(batch['target_text'])

        return predictions, targets

    def load_checkpoint(self, checkpoint_path: str):
        """Load a specific checkpoint."""
        checkpoint_path = Path(checkpoint_path)

        if self.is_lora and checkpoint_path.is_dir():
            self.model = PeftModel.from_pretrained(self.model, checkpoint_path)
            print(f"LoRA adapter loaded from {checkpoint_path}")
        else:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
            print(f"Checkpoint loaded from {checkpoint_path}")
        return self

print("ModelWrapper class loaded successfully")



ModelWrapper class loaded successfully


### 0.3 Plot Evaluation Class
Including all methods we need to plot conference-level paper quality plots.

In [3]:

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
from typing import Dict, List, Sequence, Optional

class PlotUtils:
    """Utility class for creating conference-quality plots."""

    def __init__(self, style='seaborn-v0_8-paper'):
        """Initialize plotting style."""
        try:
            plt.style.use(style)
        except:
            plt.style.use('seaborn-v0_8')

        sns.set_palette("husl")
        self.colors = sns.color_palette("husl", 10)

        # Ensure plots directory exists
        os.makedirs('plots', exist_ok=True)

    @staticmethod
    def _resolve_scales(metric_keys: Sequence[str], scales: Optional[Sequence[float]], default: float = 100.0) -> List[float]:
        """Resolve scaling factors for metrics."""
        if scales is None:
            return [default] * len(metric_keys)
        scale_list = list(scales)
        if len(scale_list) != len(metric_keys):
            raise ValueError("Number of scales must match number of metric keys")
        return [float(value) for value in scale_list]

    @staticmethod
    def plot_multi_metric_bar(results: Dict[str, Dict], metric_keys: Sequence[str],
                              metric_labels: Sequence[str], title: str, save_path: str,
                              scales: Optional[Sequence[float]] = None,
                              ylabel: Optional[str] = None, ylim: Optional[Sequence[float]] = None):
        """Plot grouped bar chart for multiple metrics per model."""
        os.makedirs('plots', exist_ok=True)
        models = list(results.keys())
        if not models or not metric_keys:
            print("No data to plot multi-metric bar chart.")
            return

        scale_values = PlotUtils._resolve_scales(metric_keys, scales)
        if ylabel is None:
            if all(abs(scale - 100.0) < 1e-6 for scale in scale_values):
                ylabel = 'Score (%)'
            else:
                ylabel = 'Score'

        x = np.arange(len(models))
        width = 0.8 / max(1, len(metric_keys))
        fig, ax = plt.subplots(figsize=(12, 6))

        for idx, (metric, label, scale) in enumerate(zip(metric_keys, metric_labels, scale_values)):
            values = [results[model].get(metric, 0.0) * scale for model in models]
            offset = (idx - (len(metric_keys) - 1) / 2) * width
            bars = ax.bar(x + offset, values, width, label=label)
            for bar, value in zip(bars, values):
                ax.text(bar.get_x() + bar.get_width() / 2.0, value,
                        f'{value:.1f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

        ax.set_xticks(x)
        ax.set_xticklabels(models)
        ax.set_ylabel(ylabel)
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.legend()
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        if ylim is not None:
            ax.set_ylim(ylim)
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Plot saved to {save_path}")

    @staticmethod
    def plot_metrics_heatmap(results: Dict[str, Dict], title: str, save_path: str,
                             metrics: Optional[Sequence[str]] = None,
                             metric_labels: Optional[Sequence[str]] = None,
                             scales: Optional[Sequence[float]] = None):
        """Plot metrics heatmap for multiple models."""
        os.makedirs('plots', exist_ok=True)
        models = list(results.keys())
        if not models:
            print("No data to plot heatmap.")
            return

        if metrics is None:
            metrics = ['accuracy', 'precision', 'recall', 'f1']
        if metric_labels is None:
            metric_labels = [metric.replace('_', ' ').title() for metric in metrics]
        if len(metric_labels) != len(metrics):
            raise ValueError("Metric labels length must match metrics length")

        scale_values = PlotUtils._resolve_scales(metrics, scales)
        data = []
        for model in models:
            row = []
            for metric, scale in zip(metrics, scale_values):
                value = results.get(model, {}).get(metric, 0.0)
                row.append(value * scale)
            data.append(row)
        data = np.array(data)

        fig, ax = plt.subplots(figsize=(10, 6))
        im = ax.imshow(data, cmap='YlGnBu', aspect='auto')

        ax.set_xticks(np.arange(len(metrics)))
        ax.set_yticks(np.arange(len(models)))
        ax.set_xticklabels(metric_labels, fontsize=12)
        ax.set_yticklabels(models, fontsize=12)

        for i in range(len(models)):
            for j in range(len(metrics)):
                ax.text(j, i, f'{data[i, j]:.1f}', ha='center', va='center',
                        color='black', fontsize=11, fontweight='bold')

        cbar = plt.colorbar(im, ax=ax)
        if all(abs(scale - 100.0) < 1e-6 for scale in scale_values):
            cbar.set_label('Score (%)', rotation=270, labelpad=20)
        else:
            cbar.set_label('Score', rotation=270, labelpad=20)

        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Plot saved to {save_path}")

    @staticmethod
    def plot_training_curves(train_losses: List[float], val_losses: List[float],
                             title: str, save_path: str):
        """Plot training and validation loss curves."""
        os.makedirs('plots', exist_ok=True)

        epochs = list(range(1, len(train_losses) + 1))

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(epochs, train_losses, marker='o', label='Training Loss', linewidth=2)
        ax.plot(epochs, val_losses, marker='s', label='Validation Loss', linewidth=2)

        ax.set_xlabel('Epoch', fontsize=14, fontweight='bold')
        ax.set_ylabel('Loss', fontsize=14, fontweight='bold')
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.legend(fontsize=12, loc='best')
        ax.grid(alpha=0.3, linestyle='--')

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

        print(f"Plot saved to {save_path}")

    @staticmethod
    def plot_method_metric_bar(method_results: Dict[str, Dict[str, Dict]], metric_key: str,
                                title: str, save_path: str, scale: float = 100.0,
                                ylabel: Optional[str] = None,
                                method_labels: Optional[Dict[str, str]] = None,
                                metric_label: Optional[str] = None,
                                ylim: Optional[Sequence[float]] = None):
        """Plot comparison of a single metric across methods and models."""
        os.makedirs('plots', exist_ok=True)
        if not method_results:
            print("No method data to plot.")
            return

        methods = list(method_results.keys())
        model_names = sorted({model for res in method_results.values() for model in res.keys()})
        if not model_names:
            print("No model data to plot.")
            return

        x = np.arange(len(model_names))
        width = 0.8 / max(1, len(methods))
        if ylabel is None:
            ylabel = 'Score (%)' if abs(scale - 100.0) < 1e-6 else 'Score'
        if metric_label is None:
            metric_label = metric_key.replace('_', ' ').title()

        fig, ax = plt.subplots(figsize=(12, 6))
        for idx, method in enumerate(methods):
            display_name = method_labels.get(method, method.replace('_', ' ').title()) if method_labels else method.replace('_', ' ').title()
            values = []
            for model in model_names:
                value = method_results[method].get(model, {}).get(metric_key, 0.0)
                values.append(value * scale)
            offset = (idx - (len(methods) - 1) / 2) * width
            bars = ax.bar(x + offset, values, width, label=display_name)
            for bar, value in zip(bars, values):
                ax.text(bar.get_x() + bar.get_width() / 2.0, value,
                        f'{value:.1f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

        ax.set_xticks(x)
        ax.set_xticklabels(model_names)
        ax.set_ylabel(ylabel)
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.legend()
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        if ylim is not None:
            ax.set_ylim(ylim)
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Plot saved to {save_path}")

    @staticmethod
    def plot_method_comparison(results_dict: Dict[str, Dict[str, Dict]], title: str, save_path: str,
                               methods: List[str], metric: str = 'accuracy'):
        """Compatibility wrapper for grouped method comparisons."""
        filtered = {method: results_dict[method] for method in methods if method in results_dict}
        if not filtered:
            print("No matching methods to plot.")
            return
        PlotUtils.plot_method_metric_bar(
            filtered,
            metric_key=metric,
            title=title,
            save_path=save_path,
            scale=100.0,
            ylabel=f"{metric.replace('_', ' ').title()} (%)"
        )

    @staticmethod
    def plot_confusion_matrix(confusion_matrix: np.ndarray, class_names: List[str],
                             title: str, save_path: str):
        """Plot confusion matrix."""
        os.makedirs('plots', exist_ok=True)

        fig, ax = plt.subplots(figsize=(10, 8))
        im = ax.imshow(confusion_matrix, cmap='Blues', aspect='auto')

        # Set ticks
        ax.set_xticks(np.arange(len(class_names)))
        ax.set_yticks(np.arange(len(class_names)))
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.set_yticklabels(class_names)

        # Add colorbar
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label('Count', rotation=270, labelpad=20)

        # Add text annotations
        for i in range(len(class_names)):
            for j in range(len(class_names)):
                text = ax.text(j, i, int(confusion_matrix[i, j]),
                             ha="center", va="center", 
                             color="white" if confusion_matrix[i, j] > confusion_matrix.max()/2 else "black")

        ax.set_xlabel('Predicted', fontsize=14, fontweight='bold')
        ax.set_ylabel('True', fontsize=14, fontweight='bold')
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

        print(f"Plot saved to {save_path}")

print("PlotUtils class loaded successfully")


PlotUtils class loaded successfully


### 0.4 Hyperparameter Configuration  
Define all configurable hyperparameters and provide grid search method.

Keep a local JSON called `grid-search-record.json` to save past running results. Each time we run the whole notebook, if we enable grid search, we have to read the JSON file and continue to the next grid search values.

In [4]:
import itertools
from typing import Dict, List, Any

class HyperparameterConfig:
    """Hyperparameter configuration with grid search support."""
    
    DEFAULT_CONFIG = {
        'learning_rate': 5e-5,
        'num_epochs': 3,
        'batch_size': 8,
        'max_length': 512,
        'lora_r': 8,
        'lora_alpha': 32,
        'lora_dropout': 0.1,
        'warmup_steps': 100,
        'max_grad_norm': 1.0,
        'wandb_project': 'minecraft-llm'
    }
    
    def __init__(self, config: Dict[str, Any] = None):
        """Initialize configuration."""
        self.config = self.DEFAULT_CONFIG.copy()
        if config:
            self.config.update(config)
        
        self.grid_search_file = "grid-search-record.json"
    
    def get_config(self) -> Dict[str, Any]:
        """Get current configuration."""
        return self.config.copy()
    
    def update_config(self, updates: Dict[str, Any]):
        """Update configuration with new values."""
        self.config.update(updates)
    
    def get_grid_search_params(self) -> Dict[str, List[Any]]:
        """Define grid search parameters."""
        return {
            'learning_rate': [1e-5, 5e-5, 1e-4],
            'batch_size': [4, 8, 16],
            'lora_r': [4, 8, 16],
            'lora_alpha': [16, 32, 64],
            'num_epochs': [3, 5]
        }
    
    def generate_grid_configs(self, param_grid: Dict[str, List[Any]] = None) -> List[Dict[str, Any]]:
        """Generate all combinations of hyperparameters for grid search."""
        if param_grid is None:
            param_grid = self.get_grid_search_params()
        
        # Get parameter names and values
        param_names = list(param_grid.keys())
        param_values = [param_grid[name] for name in param_names]
        
        # Generate all combinations
        configs = []
        for combination in itertools.product(*param_values):
            config = self.config.copy()
            for name, value in zip(param_names, combination):
                config[name] = value
            configs.append(config)
        
        print(f"Generated {len(configs)} grid search configurations")
        return configs
    
    def load_grid_search_record(self) -> Dict[str, Any]:
        """Load previous grid search results."""
        try:
            record = Utils.load_json(self.grid_search_file)
            if record is None:
                record = {
                    'completed_configs': [],
                    'results': [],
                    'best_config': None,
                    'best_score': -float('inf')
                }
            return record
        except:
            return {
                'completed_configs': [],
                'results': [],
                'best_config': None,
                'best_score': -float('inf')
            }
    
    def save_grid_search_record(self, record: Dict[str, Any]):
        """Save grid search results."""
        Utils.save_json(record, self.grid_search_file)
        print(f"Grid search record saved to {self.grid_search_file}")
    
    def get_next_grid_config(self, all_configs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Get the next configuration to try in grid search."""
        record = self.load_grid_search_record()
        completed = record['completed_configs']
        
        # Find first config not yet completed
        for config in all_configs:
            config_key = self._config_to_key(config)
            if config_key not in completed:
                return config
        
        return None  # All configs completed
    
    def update_grid_search_result(self, config: Dict[str, Any], result: Dict[str, Any], score: float):
        """Update grid search record with a new result."""
        record = self.load_grid_search_record()
        
        config_key = self._config_to_key(config)
        record['completed_configs'].append(config_key)
        record['results'].append({
            'config': config,
            'result': result,
            'score': score
        })
        
        # Update best config if this is better
        if score > record['best_score']:
            record['best_config'] = config
            record['best_score'] = score
            print(f"New best configuration found! Score: {score:.4f}")
        
        self.save_grid_search_record(record)
    
    def _config_to_key(self, config: Dict[str, Any]) -> str:
        """Convert config to a unique string key."""
        # Only use grid search parameters
        grid_params = self.get_grid_search_params()
        key_parts = []
        for param in sorted(grid_params.keys()):
            if param in config:
                key_parts.append(f"{param}={config[param]}")
        return "_".join(key_parts)
    
    def print_config(self):
        """Print current configuration."""
        print("=" * 80)
        print("HYPERPARAMETER CONFIGURATION")
        print("=" * 80)
        for key, value in sorted(self.config.items()):
            print(f"{key:20s}: {value}")
        print("=" * 80)
    
    def print_grid_search_summary(self):
        """Print summary of grid search results."""
        record = self.load_grid_search_record()
        
        print("=" * 80)
        print("GRID SEARCH SUMMARY")
        print("=" * 80)
        print(f"Completed configurations: {len(record['completed_configs'])}")
        print(f"Best score: {record['best_score']:.4f}")
        
        if record['best_config']:
            print("\nBest configuration:")
            for key, value in sorted(record['best_config'].items()):
                if key in self.get_grid_search_params():
                    print(f"  {key:20s}: {value}")
        print("=" * 80)

# Initialize default configuration
config = HyperparameterConfig()
config.print_config()

print("HyperparameterConfig class loaded successfully")

HYPERPARAMETER CONFIGURATION
batch_size          : 8
learning_rate       : 5e-05
lora_alpha          : 32
lora_dropout        : 0.1
lora_r              : 8
max_grad_norm       : 1.0
max_length          : 512
num_epochs          : 3
wandb_project       : minecraft-llm
warmup_steps        : 100
HyperparameterConfig class loaded successfully


1.1 load model via transformers, we pick Qwen3-0.6B and Qwen3-4B

In [5]:

import os
from typing import Dict, Optional

import torch

MODEL_PATHS: Dict[str, str] = {
    "qwen3-0.6b": "models/Qwen3-0.6B",
    "qwen3-4b": "models/Qwen3-4B",
}

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
WANDB_ENABLED = False  # Toggle to True to log to Weights & Biases when credentials are available

os.environ.setdefault("WANDB_MODE", "offline")
os.environ.setdefault("WANDB_SILENT", "true")

loaded_models: Dict[str, ModelWrapper] = {}

def get_model_wrapper(model_key: str, *, use_lora: bool = False, lora_config: Optional[dict] = None, force_reload: bool = False) -> ModelWrapper:
    """Return a ModelWrapper for the requested model, loading weights on demand."""
    if model_key not in MODEL_PATHS:
        raise KeyError(f"Unknown model key '{model_key}'. Available: {list(MODEL_PATHS.keys())}")

    wrapper = loaded_models.get(model_key)
    needs_reload = force_reload or wrapper is None or getattr(wrapper, "model", None) is None or (use_lora and not getattr(wrapper, "is_lora", False))

    if wrapper is None:
        wrapper = ModelWrapper(model_name=MODEL_PATHS[model_key], device=DEVICE)
        needs_reload = True

    if needs_reload:
        wrapper.load_model(use_lora=use_lora, lora_config=lora_config)
        loaded_models[model_key] = wrapper

    return wrapper

def release_model(model_key: str) -> None:
    """Remove a loaded model from memory to free GPU/CPU resources."""
    wrapper = loaded_models.pop(model_key, None)
    if wrapper is None:
        return

    model = getattr(wrapper, "model", None)
    try:
        if model is not None:
            try:
                model.to("cpu")
            except Exception as move_exc:
                print(f"Warning: unable to move model for {model_key} to CPU: {move_exc}")
        # Explicitly delete model reference to encourage cleanup
        del model
    except Exception as exc:
        print(f"Warning: cleanup for model {model_key} raised an exception: {exc}")

    wrapper.model = None
    wrapper.tokenizer = None
    wrapper.is_lora = False

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        if hasattr(torch.cuda, "ipc_collect"):
            torch.cuda.ipc_collect()
    try:
        import gc
        gc.collect()
    except Exception:
        pass

    wrapper = None

def release_all_models() -> None:
    """Release all cached models to ensure GPU memory is freed."""
    for cached_key in list(loaded_models.keys()):
        release_model(cached_key)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        if hasattr(torch.cuda, "ipc_collect"):
            torch.cuda.ipc_collect()
    try:
        import gc
        gc.collect()
    except Exception:
        pass

print(f"Available models: {list(MODEL_PATHS.keys())}")
print(f"Using device: {DEVICE}")
print("Call get_model_wrapper('<model_key>') to load a model when needed.")


Available models: ['qwen3-0.6b', 'qwen3-4b']
Using device: cuda
Call get_model_wrapper('<model_key>') to load a model when needed.


1.2 load custom data from local datasets dir, the data with 3 types of data, 

- x : current frame in ascii art, 
- y: current action token, 
- z: next frame ascii art, 

all in plain text format.

In [6]:
from minecraft_dataset import MinecraftDataset

DATA_DIR = Path("datasets/minecraft/data")

full_dataset = MinecraftDataset(
    data_dir=str(DATA_DIR),
    max_length=config.get_config()["max_length"]
)

TOTAL_PAIRS = len(full_dataset)
UNIQUE_ACTIONS = sorted({pair["y"] for pair in full_dataset.data_pairs})

print(f"Loaded {TOTAL_PAIRS} sequential frame/action pairs from {DATA_DIR}.")
print(f"Unique actions: {UNIQUE_ACTIONS}")


def preview_dataset_example(idx: int = 0) -> None:
    """Print a dataset example for quick inspection."""
    example = full_dataset[idx]
    print("=" * 80)
    print(f"Example {idx}")
    print("- Current Frame (x):")
    print(example["x"])
    print("- Action (y):")
    print(example["y"])
    print("- Next Frame (z):")
    print(example["z"])
    print("=" * 80)


preview_dataset_example(0)


Loading 9 frames from datasets/minecraft/data
Created 8 training pairs
Loaded 8 sequential frame/action pairs from datasets/minecraft/data.
Unique actions: ['straight: forward\npan: left\njump: jump\n']
Example 0
- Current Frame (x):
|air|grass block|dirt|
|air|grass block|dirt|
|air|grass block|dirt|
|air|air|dirt|
|air|air|dirt|
|air|air|dirt|
|air|air|grass block|
|air|air|grass block|
|air|air|grass block|

- Action (y):
straight: forward
pan: left
jump: jump

- Next Frame (z):
|air|grass block|dirt|
|air|grass block|dirt|
|air|grass block|dirt|
|air|air|dirt|
|air|air|dirt|
|air|air|dirt|
|air|air|grass block|
|air|air|grass block|
|air|air|grass block|



1.3 split loaded data to all, train, val and test.

In [7]:

import re
from difflib import SequenceMatcher
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

TOTAL_SAMPLES = len(full_dataset)
indices = list(range(TOTAL_SAMPLES))

if TOTAL_SAMPLES < 3:
    train_indices = indices[:max(1, TOTAL_SAMPLES - 2)]
    val_indices = indices[len(train_indices):len(train_indices) + (1 if TOTAL_SAMPLES - len(train_indices) > 1 else 0)]
    test_indices = indices[len(train_indices) + len(val_indices):]
else:
    train_size = max(1, int(np.floor(0.7 * TOTAL_SAMPLES)))
    val_size = max(1, int(np.floor(0.15 * TOTAL_SAMPLES)))
    remaining = TOTAL_SAMPLES - train_size - val_size

    if remaining < 1:
        deficit = 1 - remaining
        if val_size - deficit >= 1:
            val_size -= deficit
        else:
            deficit -= (val_size - 1)
            val_size = 1
            train_size = max(1, train_size - deficit)
        remaining = 1

    test_size = remaining
    train_indices = indices[:train_size]
    val_start = train_size
    val_indices = indices[val_start:val_start + val_size]
    test_indices = indices[val_start + val_size:val_start + val_size + test_size]

SPLITS = {
    "train": train_indices,
    "val": val_indices,
    "test": test_indices,
}

print("Dataset split sizes:", {split: len(idxs) for split, idxs in SPLITS.items()})

PROMPT_CONTEXT_LENGTH = 100
CONTEXT_EXAMPLES = [
    full_dataset.data_pairs[i]
    for i in SPLITS["train"][-PROMPT_CONTEXT_LENGTH:]
]
CONTEXT_EXAMPLE_COUNT = len(CONTEXT_EXAMPLES)
MAX_NEW_TOKENS = 256

print(
    f"Using {CONTEXT_EXAMPLE_COUNT} sequential training frames as context (up to {PROMPT_CONTEXT_LENGTH} frames)."
)


def create_dataset_subset(indices, *, context_examples=None):
    """Create an in-memory subset of the Minecraft dataset for custom splits."""
    subset = MinecraftDataset.__new__(MinecraftDataset)
    subset.data_dir = full_dataset.data_dir
    subset.tokenizer = None
    subset.max_length = full_dataset.max_length
    subset.context_examples = list(context_examples or [])
    subset.data_pairs = [full_dataset.data_pairs[i] for i in indices]
    return subset


def build_dataloader(indices, tokenizer, task_type, *, batch_size=1, shuffle=False, context_examples=None):
    """Create a dataloader for the requested task and split."""
    dataset_subset = create_dataset_subset(indices, context_examples=context_examples)
    collate_fn = (
        dataset_subset.collate_frame_reconstruction
        if task_type == "frame_reconstruction"
        else dataset_subset.collate_action_recognition
    )
    return DataLoader(
        dataset_subset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=lambda batch: collate_fn(batch, tokenizer),
    )


class Word2VecManager:
    """Lightweight Word2Vec-style embeddings trained on local action text."""

    def __init__(self, texts, embedding_dim=32, window_size=2, epochs=300, lr=5e-3):
        self.embedding_dim = embedding_dim
        self.window_size = window_size
        self.epochs = epochs
        self.lr = lr
        self._device = torch.device("cpu")

        self.sentences = [self.tokenize(text) for text in texts if text and text.strip()]
        self.vocab = []
        self.word_to_idx = {}
        self.idx_to_word = {}
        self.input_embeddings = None
        self.output_layer = None
        self.training_pairs = []

        if self.sentences:
            self._build_vocab()
            self._build_training_pairs()
            if self.training_pairs:
                self._train_embeddings()

    @staticmethod
    def tokenize(text):
        return [tok for tok in re.findall(r"[A-Za-z0-9_]+", text.lower())]

    def _build_vocab(self):
        vocab = sorted({token for sent in self.sentences for token in sent})
        self.vocab = vocab
        self.word_to_idx = {word: idx for idx, word in enumerate(vocab)}
        self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}

    def _build_training_pairs(self):
        pairs = []
        for sent in self.sentences:
            if not sent:
                continue
            for center_idx, token in enumerate(sent):
                center_id = self.word_to_idx[token]
                start = max(0, center_idx - self.window_size)
                end = min(len(sent), center_idx + self.window_size + 1)
                for context_pos in range(start, end):
                    if context_pos == center_idx:
                        continue
                    context_token = sent[context_pos]
                    pairs.append((center_id, self.word_to_idx[context_token]))
        self.training_pairs = pairs

    def _train_embeddings(self):
        torch.manual_seed(42)
        vocab_size = len(self.vocab)
        self.input_embeddings = torch.nn.Embedding(vocab_size, self.embedding_dim).to(self._device)
        self.output_layer = torch.nn.Linear(self.embedding_dim, vocab_size, bias=False).to(self._device)

        optimizer = torch.optim.Adam(
            list(self.input_embeddings.parameters()) + list(self.output_layer.parameters()),
            lr=self.lr,
        )
        loss_fn = torch.nn.CrossEntropyLoss()

        for epoch in range(self.epochs):
            total_loss = 0.0
            for center_idx, context_idx in self.training_pairs:
                center_tensor = torch.tensor([center_idx], dtype=torch.long, device=self._device)
                context_tensor = torch.tensor([context_idx], dtype=torch.long, device=self._device)

                logits = self.output_layer(self.input_embeddings(center_tensor))
                loss = loss_fn(logits, context_tensor)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            if total_loss < 1e-6:
                break

    def encode(self, text):
        if not self.training_pairs or self.input_embeddings is None:
            return np.zeros(self.embedding_dim, dtype=float)

        tokens = self.tokenize(text)
        ids = [self.word_to_idx[token] for token in tokens if token in self.word_to_idx]
        if not ids:
            return np.zeros(self.embedding_dim, dtype=float)

        with torch.no_grad():
            tensor = torch.tensor(ids, dtype=torch.long, device=self._device)
            embeddings = self.input_embeddings(tensor)
            vector = embeddings.mean(dim=0)
        return vector.cpu().numpy()

    def cosine_similarity(self, text_a, text_b):
        vec_a = self.encode(text_a)
        vec_b = self.encode(text_b)
        denom = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
        if denom == 0.0:
            return 0.0
        return float(np.dot(vec_a, vec_b) / denom)


def regex_fullmatch(text: str, pattern: str) -> bool:
    """Return True when text fully matches pattern using regex with fallback to literal match."""
    text = text.strip()
    pattern = pattern.strip()
    if pattern == "":
        return text == pattern

    flags = re.IGNORECASE | re.DOTALL

    try:
        compiled = re.compile(pattern, flags)
        if compiled.fullmatch(text):
            return True
    except re.error:
        compiled = None

    literal_compiled = re.compile(re.escape(pattern), flags)
    return bool(literal_compiled.fullmatch(text))


ACTION_EMBEDDER = Word2VecManager([pair["y"] for pair in full_dataset.data_pairs])


def compute_text_metrics(predictions, targets, *, task_type: str):
    """Compute metrics for text generation tasks, including regex and similarity metrics."""
    if not targets:
        return {
            "accuracy": 0.0,
            "precision": 0.0,
            "recall": 0.0,
            "f1": 0.0,
            "labels": [],
            "confusion_matrix": [],
            "regex_matches": 0,
            "strict_match_accuracy": 0.0,
            "reconstruction_accuracy": 0.0,
            "word2vec_cosine": 0.0,
        }

    regex_matches = []
    normalized_predictions = []
    reconstruction_scores = []
    cosine_scores = []

    for pred, target in zip(predictions, targets):
        match = regex_fullmatch(pred, target)
        regex_matches.append(1 if match else 0)
        normalized_predictions.append(target if match else pred)

        if task_type == "frame_reconstruction":
            score = SequenceMatcher(None, target, pred).ratio()
            reconstruction_scores.append(score)
        elif task_type == "action_recognition":
            cosine_scores.append(ACTION_EMBEDDER.cosine_similarity(pred, target))

    strict_accuracy = float(np.mean(regex_matches))

    metrics = {
        "regex_matches": int(sum(regex_matches)),
        "strict_match_accuracy": strict_accuracy,
    }

    if task_type == "frame_reconstruction":
        reconstruction_accuracy = float(np.mean(reconstruction_scores)) if reconstruction_scores else 0.0
        metrics.update({
            "accuracy": reconstruction_accuracy,
            "precision": 0.0,
            "recall": 0.0,
            "f1": 0.0,
            "labels": [],
            "confusion_matrix": [],
            "reconstruction_accuracy": reconstruction_accuracy,
            "reconstruction_scores": reconstruction_scores,
            "word2vec_cosine": 0.0,
        })
    else:
        label_space = sorted(set(targets + normalized_predictions))
        label_to_idx = {label: idx for idx, label in enumerate(label_space)}
        y_true = [label_to_idx[t] for t in targets]
        y_pred = [label_to_idx[p] for p in normalized_predictions]

        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true,
            y_pred,
            average="macro",
            zero_division=0,
        )
        conf = confusion_matrix(
            y_true,
            y_pred,
            labels=list(range(len(label_space))),
        )

        cosine_mean = float(np.mean(cosine_scores)) if cosine_scores else 0.0

        metrics.update({
            "accuracy": strict_accuracy,
            "precision": float(precision),
            "recall": float(recall),
            "f1": float(f1),
            "labels": label_space,
            "confusion_matrix": conf.astype(int).tolist(),
            "reconstruction_accuracy": 0.0,
            "word2vec_cosine": cosine_mean,
            "word2vec_scores": cosine_scores,
        })

    return metrics


def evaluate_wrapper(wrapper, model_key, task_type, indices, *, context_examples=None, batch_size=1, max_new_tokens=MAX_NEW_TOKENS):
    """Run evaluation for a pre-loaded wrapper."""
    dataloader = build_dataloader(
        indices,
        wrapper.tokenizer,
        task_type,
        batch_size=batch_size,
        shuffle=False,
        context_examples=context_examples,
    )
    predictions, targets = wrapper.evaluate(dataloader, max_new_tokens=max_new_tokens)
    metrics = compute_text_metrics(predictions, targets, task_type=task_type)
    metrics.update({
        "model": model_key,
        "num_samples": len(targets),
        "num_context_examples": len(context_examples) if context_examples else 0,
        "predictions": predictions,
        "targets": targets,
    })
    return metrics


def evaluate_model_on_task(model_key, task_type, indices, *, context_examples=None, batch_size=1, max_new_tokens=MAX_NEW_TOKENS, use_lora=False, lora_config=None):
    """Convenience wrapper that loads a model, runs evaluation, and returns metrics."""
    wrapper = get_model_wrapper(model_key, use_lora=use_lora, lora_config=lora_config)
    metrics = evaluate_wrapper(
        wrapper,
        model_key,
        task_type,
        indices,
        context_examples=context_examples,
        batch_size=batch_size,
        max_new_tokens=max_new_tokens,
    )
    return metrics


Dataset split sizes: {'train': 5, 'val': 1, 'test': 2}
Using 5 sequential training frames as context (up to 100 frames).


2.1 evaluate both model on all data, for next frame reconstraction task, input x and y, output z, save result to 2.1-result.json

In [8]:

from pathlib import Path

frame_prompt_examples = CONTEXT_EXAMPLES
print(
    f"Evaluating frame reconstruction with {len(frame_prompt_examples)} sequential prompt examples."
)

frame_results = {}
for model_key in MODEL_PATHS:
    metrics = evaluate_model_on_task(
        model_key,
        "frame_reconstruction",
        SPLITS["test"],
        context_examples=frame_prompt_examples,
        batch_size=1,
    )
    frame_results[model_key] = {
        k: v for k, v in metrics.items() if k not in {"predictions", "targets"}
    }
    release_model(model_key)

frame_results_path = Path("2.1-result.json")
Utils.save_json(frame_results, frame_results_path)
print(f"Saved zero-shot frame reconstruction results to {frame_results_path}")

release_all_models()

frame_results


Evaluating frame reconstruction with 5 sequential prompt examples.
Loading model: models/Qwen3-0.6B


`torch_dtype` is deprecated! Use `dtype` instead!


Model loaded on cuda


Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Evaluating: 100%|██████████| 2/2 [00:05<00:00,  2.65s/it]


Loading model: models/Qwen3-4B


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Model loaded on cuda


Evaluating: 100%|██████████| 2/2 [00:10<00:00,  5.45s/it]


Saved zero-shot frame reconstruction results to 2.1-result.json


{'qwen3-0.6b': {'regex_matches': 0,
  'strict_match_accuracy': 0.0,
  'accuracy': 0.055205047318611984,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'labels': [],
  'confusion_matrix': [],
  'reconstruction_accuracy': 0.055205047318611984,
  'reconstruction_scores': [0.055205047318611984, 0.055205047318611984],
  'word2vec_cosine': 0.0,
  'model': 'qwen3-0.6b',
  'num_samples': 2,
  'num_context_examples': 5},
 'qwen3-4b': {'regex_matches': 0,
  'strict_match_accuracy': 0.0,
  'accuracy': 0.07267221801665405,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'labels': [],
  'confusion_matrix': [],
  'reconstruction_accuracy': 0.07267221801665405,
  'reconstruction_scores': [0.07267221801665405, 0.07267221801665405],
  'word2vec_cosine': 0.0,
  'model': 'qwen3-4b',
  'num_samples': 2,
  'num_context_examples': 5}}

2.2 plot the result of the evalutation

In [9]:

plot_utils = PlotUtils()
frame_results = Utils.load_json("2.1-result.json") or {}

if not frame_results:
    print("No zero-shot frame reconstruction results found. Run cell 2.1 first.")
else:
    PlotUtils.plot_multi_metric_bar(
        frame_results,
        metric_keys=["strict_match_accuracy", "reconstruction_accuracy"],
        metric_labels=["Strict Match Accuracy", "Reconstruction Accuracy"],
        title="Zero-Shot Frame Reconstruction Metrics",
        save_path="plots/2.2-frame-metric-bars.png",
        scales=[100.0, 100.0],
        ylabel="Score (%)",
        ylim=(0, 100),
    )
    PlotUtils.plot_metrics_heatmap(
        frame_results,
        "Zero-Shot Frame Reconstruction Heatmap",
        "plots/2.2-frame-heatmap.png",
        metrics=["strict_match_accuracy", "reconstruction_accuracy"],
        metric_labels=["Strict Match Accuracy (%)", "Reconstruction Accuracy (%)"],
        scales=[100.0, 100.0],
    )

frame_results


Plot saved to plots/2.2-frame-metric-bars.png
Plot saved to plots/2.2-frame-heatmap.png


{'qwen3-0.6b': {'regex_matches': 0,
  'strict_match_accuracy': 0.0,
  'accuracy': 0.055205047318611984,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'labels': [],
  'confusion_matrix': [],
  'reconstruction_accuracy': 0.055205047318611984,
  'reconstruction_scores': [0.055205047318611984, 0.055205047318611984],
  'word2vec_cosine': 0.0,
  'model': 'qwen3-0.6b',
  'num_samples': 2,
  'num_context_examples': 5},
 'qwen3-4b': {'regex_matches': 0,
  'strict_match_accuracy': 0.0,
  'accuracy': 0.07267221801665405,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'labels': [],
  'confusion_matrix': [],
  'reconstruction_accuracy': 0.07267221801665405,
  'reconstruction_scores': [0.07267221801665405, 0.07267221801665405],
  'word2vec_cosine': 0.0,
  'model': 'qwen3-4b',
  'num_samples': 2,
  'num_context_examples': 5}}

2.3 evaluate both model on all data for the action recognition task, input x and z, output y, save result to 2.3-result.json

In [10]:

from pathlib import Path

action_results = {}
for model_key in MODEL_PATHS:
    metrics = evaluate_model_on_task(
        model_key,
        "action_recognition",
        SPLITS["test"],
        context_examples=CONTEXT_EXAMPLES,
        batch_size=1,
    )
    action_results[model_key] = {k: v for k, v in metrics.items() if k not in {"predictions", "targets"}}
    release_model(model_key)

action_results_path = Path("2.3-result.json")
Utils.save_json(action_results, action_results_path)
print(f"Saved zero-shot action recognition results to {action_results_path}")

release_all_models()

action_results


Loading model: models/Qwen3-0.6B
Model loaded on cuda


Evaluating: 100%|██████████| 2/2 [00:05<00:00,  2.60s/it]


Loading model: models/Qwen3-4B


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Model loaded on cuda


Evaluating: 100%|██████████| 2/2 [00:10<00:00,  5.46s/it]


Saved zero-shot action recognition results to 2.3-result.json


{'qwen3-0.6b': {'regex_matches': 0,
  'strict_match_accuracy': 0.0,
  'accuracy': 0.0,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'labels': ['irt|\n|air|air|grass block|\n|air|air|grass block|\n|air|air|grass block|\n\n\nFrame:\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|grass block|\n|air|air|grass block|\n|air|air|grass block|\n\n\nAction:\nstraight: forward\npan: left\njump: jump\n\n\nFrame:\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|grass block|\n|air|air|grass block|\n|air|air|grass block|\n\n\nFrame:\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|grass',
   'straight: forward\npan: left\njump: jump'],
  'confusion_matrix': [[0, 0], [2, 0]],
  'reconstruction_accuracy': 0.0,
  'word2vec_cosine': 0.9999999403953552,
  'wo

2.4 plot the result of the evalutation

In [11]:

plot_utils = PlotUtils()
action_results = Utils.load_json("2.3-result.json") or {}

if not action_results:
    print("No zero-shot action recognition results found. Run cell 2.3 first.")
else:
    PlotUtils.plot_multi_metric_bar(
        action_results,
        metric_keys=["strict_match_accuracy", "word2vec_cosine", "f1"],
        metric_labels=["Strict Match Accuracy", "Word2Vec Cosine", "Macro F1"],
        title="Zero-Shot Action Recognition Metrics",
        save_path="plots/2.4-action-metric-bars.png",
        scales=[100.0, 100.0, 100.0],
        ylabel="Score (%)",
        ylim=(0, 100),
    )
    PlotUtils.plot_metrics_heatmap(
        action_results,
        "Zero-Shot Action Recognition Heatmap",
        "plots/2.4-action-heatmap.png",
        metrics=["strict_match_accuracy", "word2vec_cosine", "precision", "recall", "f1"],
        metric_labels=[
            "Strict Match Accuracy (%)",
            "Word2Vec Cosine (%)",
            "Precision (%)",
            "Recall (%)",
            "Macro F1 (%)",
        ],
        scales=[100.0, 100.0, 100.0, 100.0, 100.0],
    )
    for model_key, metrics in action_results.items():
        conf = metrics.get("confusion_matrix")
        labels = metrics.get("labels", [])
        if conf and labels:
            PlotUtils.plot_confusion_matrix(
                np.array(conf),
                labels,
                f"Action Recognition Confusion Matrix ({model_key})",
                f"plots/2.4-confusion-{model_key}.png",
            )

action_results


Plot saved to plots/2.4-action-metric-bars.png
Plot saved to plots/2.4-action-heatmap.png
Plot saved to plots/2.4-confusion-qwen3-0.6b.png


  plt.tight_layout()


Plot saved to plots/2.4-confusion-qwen3-4b.png


{'qwen3-0.6b': {'regex_matches': 0,
  'strict_match_accuracy': 0.0,
  'accuracy': 0.0,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'labels': ['irt|\n|air|air|grass block|\n|air|air|grass block|\n|air|air|grass block|\n\n\nFrame:\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|grass block|\n|air|air|grass block|\n|air|air|grass block|\n\n\nAction:\nstraight: forward\npan: left\njump: jump\n\n\nFrame:\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|grass block|\n|air|air|grass block|\n|air|air|grass block|\n\n\nFrame:\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|grass',
   'straight: forward\npan: left\njump: jump'],
  'confusion_matrix': [[0, 0], [2, 0]],
  'reconstruction_accuracy': 0.0,
  'word2vec_cosine': 0.9999999403953552,
  'wo

3.1 fine tune both model with lora method, task is next frame reconstraction, input x and y, output z, with loaded data train and stop in val and monitor via w&b, do not pass model parameter to w&b, keep them in local dir checkpoints with peroper naming and also keep a log in the dir logs. and create 3.1-training-metadata.json file in the working dir given the match betwen the run folder path under checkpoints and the run log path.

In [12]:

from pathlib import Path

training_config = config.get_config()
training_config["batch_size"] = max(1, min(training_config["batch_size"], len(SPLITS["train"])))
lora_config = {
    "r": training_config["lora_r"],
    "lora_alpha": training_config["lora_alpha"],
    "lora_dropout": training_config["lora_dropout"],
}

ENABLE_TRAINING = False
TRAINING_MODELS = ["qwen3-0.6b", "qwen3-4b"]

training_metadata = {}

if ENABLE_TRAINING and SPLITS["train"]:
    for model_key in TRAINING_MODELS:
        print(f"Starting LoRA fine-tuning for {model_key}...")
        wrapper = get_model_wrapper(model_key, use_lora=True, lora_config=lora_config, force_reload=True)

        train_loader = build_dataloader(
            SPLITS["train"],
            wrapper.tokenizer,
            "frame_reconstruction",
            batch_size=training_config["batch_size"],
            shuffle=True,
            context_examples=None,
        )
        val_loader = build_dataloader(
            SPLITS["val"],
            wrapper.tokenizer,
            "frame_reconstruction",
            batch_size=1,
            shuffle=False,
            context_examples=None,
        )

        metadata = wrapper.train(
            train_loader,
            val_loader,
            training_config,
            task_name=f"frame_reconstruction_{model_key}",
            use_wandb=WANDB_ENABLED,
        )

        metadata_path = Path(f"3.1-training-metadata-{model_key}.json")
        Utils.save_json(metadata, metadata_path)
        training_metadata[model_key] = metadata
        release_model(model_key)
        del wrapper
else:
    print("Supervised LoRA training skipped. Set ENABLE_TRAINING = True to run fine-tuning.")

release_all_models()

training_metadata


Supervised LoRA training skipped. Set ENABLE_TRAINING = True to run fine-tuning.


{}

3.2 evaluate on test dataset, and save to 3.2-result.json

In [13]:
from pathlib import Path

fine_tuned_results = {}
for model_key in TRAINING_MODELS:
    metadata_path = Path(f"3.1-training-metadata-{model_key}.json")
    metadata = Utils.load_json(metadata_path)
    if not metadata:
        print(f"No training metadata found for {model_key}; skipping.")
        continue

    wrapper = get_model_wrapper(model_key, use_lora=True, lora_config=lora_config, force_reload=True)
    checkpoint_dir = Path(metadata["checkpoint_dir"])
    adapter_path = checkpoint_dir / "best_lora_adapter"
    model_path = checkpoint_dir / "best_model.pt"

    if adapter_path.exists():
        wrapper.load_checkpoint(str(adapter_path))
    elif model_path.exists():
        wrapper.load_checkpoint(str(model_path))
    else:
        print(f"No fine-tuned weights found for {model_key}; skipping evaluation.")
        release_model(model_key)
        del wrapper
        continue

    metrics = evaluate_wrapper(
        wrapper,
        model_key,
        "frame_reconstruction",
        SPLITS["test"],
        context_examples=None,
    )
    fine_tuned_results[model_key] = {k: v for k, v in metrics.items() if k not in {"predictions", "targets"}}
    release_model(model_key)
    del wrapper

if fine_tuned_results:
    Utils.save_json(fine_tuned_results, "3.2-result.json")
    print("Saved fine-tuned evaluation results to 3.2-result.json")
else:
    print("No fine-tuned results to save.")

release_all_models()

fine_tuned_results


Loading model: models/Qwen3-0.6B
LoRA adapter applied
NOTE: Only LoRA adapter weights will be saved (saves disk space)
Model loaded on cuda
No fine-tuned weights found for qwen3-0.6b; skipping evaluation.
Loading model: models/Qwen3-4B


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

LoRA adapter applied
NOTE: Only LoRA adapter weights will be saved (saves disk space)
Model loaded on cuda
No fine-tuned weights found for qwen3-4b; skipping evaluation.
No fine-tuned results to save.


{}

3.3 plot the evaluation

In [14]:

plot_utils = PlotUtils()

zero_shot_frame = Utils.load_json("2.1-result.json") or {}
fine_tuned_frame = Utils.load_json("3.2-result.json") or {}

method_results = {}
if zero_shot_frame:
    method_results["zero_shot"] = zero_shot_frame
if fine_tuned_frame:
    method_results["fine_tuned"] = fine_tuned_frame

if len(method_results) >= 2:
    PlotUtils.plot_method_metric_bar(
        method_results,
        metric_key="reconstruction_accuracy",
        title="Frame Reconstruction: Reconstruction Accuracy Comparison",
        save_path="plots/3.3-frame-reconstruction-accuracy.png",
        scale=100.0,
        ylabel="Reconstruction Accuracy (%)",
        method_labels={"zero_shot": "Zero-Shot", "fine_tuned": "LoRA Fine-Tuned"},
        metric_label="Reconstruction Accuracy (%)",
        ylim=(0, 100),
    )
    PlotUtils.plot_method_metric_bar(
        method_results,
        metric_key="strict_match_accuracy",
        title="Frame Reconstruction: Strict Match Accuracy Comparison",
        save_path="plots/3.3-frame-strict-accuracy.png",
        scale=100.0,
        ylabel="Strict Match Accuracy (%)",
        method_labels={"zero_shot": "Zero-Shot", "fine_tuned": "LoRA Fine-Tuned"},
        metric_label="Strict Match Accuracy (%)",
        ylim=(0, 100),
    )
else:
    print("Need results from at least two methods to plot comparisons. Run zero-shot (2.1) and fine-tuned (3.2) evaluations.")

{"zero_shot": zero_shot_frame, "fine_tuned": fine_tuned_frame}


Plot saved to plots/3.3-frame-reconstruction-accuracy.png
Plot saved to plots/3.3-frame-strict-accuracy.png


{'zero_shot': {'qwen3-0.6b': {'regex_matches': 0,
   'strict_match_accuracy': 0.0,
   'accuracy': 0.055205047318611984,
   'precision': 0.0,
   'recall': 0.0,
   'f1': 0.0,
   'labels': [],
   'confusion_matrix': [],
   'reconstruction_accuracy': 0.055205047318611984,
   'reconstruction_scores': [0.055205047318611984, 0.055205047318611984],
   'word2vec_cosine': 0.0,
   'model': 'qwen3-0.6b',
   'num_samples': 2,
   'num_context_examples': 5},
  'qwen3-4b': {'regex_matches': 0,
   'strict_match_accuracy': 0.0,
   'accuracy': 0.07267221801665405,
   'precision': 0.0,
   'recall': 0.0,
   'f1': 0.0,
   'labels': [],
   'confusion_matrix': [],
   'reconstruction_accuracy': 0.07267221801665405,
   'reconstruction_scores': [0.07267221801665405, 0.07267221801665405],
   'word2vec_cosine': 0.0,
   'model': 'qwen3-4b',
   'num_samples': 2,
   'num_context_examples': 5}},
 'fine_tuned': {'qwen3-0.6b': {'regex_matches': 0,
   'strict_match_accuracy': 0.0,
   'accuracy': 0.06538461538461539,
   '

### 4.1 LoRA Fine-Tuning for Action Recognition

Fine-tune both Qwen models on the action recognition task using LoRA. Each run stores checkpoints in `checkpoints/`, logs in `logs/`, and records training metadata to `4.1-training-metadata.json` for downstream evaluation.

In [15]:
from pathlib import Path

action_training_config = config.get_config()
action_training_config["batch_size"] = max(1, min(action_training_config["batch_size"], len(SPLITS["train"])))
action_lora_config = {
    "r": action_training_config["lora_r"],
    "lora_alpha": action_training_config["lora_alpha"],
    "lora_dropout": action_training_config["lora_dropout"],
}

ENABLE_ACTION_TRAINING = True
ACTION_TRAINING_MODELS = ["qwen3-0.6b", "qwen3-4b"]

action_training_metadata = {}

if not SPLITS["train"]:
    print("No training samples available for action recognition. Populate SPLITS['train'] before running fine-tuning.")
elif ENABLE_ACTION_TRAINING:
    for model_key in ACTION_TRAINING_MODELS:
        print(f"Starting LoRA action recognition fine-tuning for {model_key}...")
        wrapper = get_model_wrapper(model_key, use_lora=True, lora_config=action_lora_config, force_reload=True)

        train_loader = build_dataloader(
            SPLITS["train"],
            wrapper.tokenizer,
            "action_recognition",
            batch_size=action_training_config["batch_size"],
            shuffle=True,
            context_examples=None,
        )
        val_loader = build_dataloader(
            SPLITS["val"],
            wrapper.tokenizer,
            "action_recognition",
            batch_size=1,
            shuffle=False,
            context_examples=None,
        )

        metadata = wrapper.train(
            train_loader,
            val_loader,
            action_training_config,
            task_name=f"action_recognition_{model_key}",
            use_wandb=WANDB_ENABLED,
        )
        metadata["task_type"] = "action_recognition"
        metadata["config"] = {
            key: action_training_config[key]
            for key in [
                "learning_rate",
                "num_epochs",
                "batch_size",
                "lora_r",
                "lora_alpha",
                "lora_dropout",
                "warmup_steps",
                "max_grad_norm",
            ]
            if key in action_training_config
        }

        metadata_path = Path(f"4.1-training-metadata-{model_key}.json")
        Utils.save_json(metadata, metadata_path)
        action_training_metadata[model_key] = metadata
        release_model(model_key)
        del wrapper

    Utils.save_json(action_training_metadata, "4.1-training-metadata.json")
    print("Saved aggregated training metadata to 4.1-training-metadata.json")
else:
    print("Action recognition LoRA training skipped. Set ENABLE_ACTION_TRAINING = True to run fine-tuning.")

release_all_models()

action_training_metadata


Starting LoRA action recognition fine-tuning for qwen3-0.6b...
Loading model: models/Qwen3-0.6B
LoRA adapter applied
NOTE: Only LoRA adapter weights will be saved (saves disk space)
Model loaded on cuda


Epoch 1/3: 100%|██████████| 1/1 [00:00<00:00,  6.88it/s, loss=1.66]
2025-10-16 00:00:55,867 - action_recognition_qwen3-0.6b - INFO - Epoch 1: train_loss=1.6580, val_loss=1.6582
2025-10-16 00:00:55,912 - action_recognition_qwen3-0.6b - INFO - Best LoRA adapter saved with val_loss=1.6582
Epoch 2/3: 100%|██████████| 1/1 [00:00<00:00, 10.65it/s, loss=1.66]
2025-10-16 00:00:56,025 - action_recognition_qwen3-0.6b - INFO - Epoch 2: train_loss=1.6580, val_loss=1.6580
2025-10-16 00:00:56,062 - action_recognition_qwen3-0.6b - INFO - Best LoRA adapter saved with val_loss=1.6580
Epoch 3/3: 100%|██████████| 1/1 [00:00<00:00, 10.74it/s, loss=1.66]
2025-10-16 00:00:56,175 - action_recognition_qwen3-0.6b - INFO - Epoch 3: train_loss=1.6579, val_loss=1.6575
2025-10-16 00:00:56,209 - action_recognition_qwen3-0.6b - INFO - Best LoRA adapter saved with val_loss=1.6575


Starting LoRA action recognition fine-tuning for qwen3-4b...
Loading model: models/Qwen3-4B


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

LoRA adapter applied
NOTE: Only LoRA adapter weights will be saved (saves disk space)
Model loaded on cuda


Epoch 1/3: 100%|██████████| 1/1 [00:00<00:00,  2.64it/s, loss=1.68]
2025-10-16 00:00:58,428 - action_recognition_qwen3-4b - INFO - Epoch 1: train_loss=1.6804, val_loss=1.6809
2025-10-16 00:00:58,484 - action_recognition_qwen3-4b - INFO - Best LoRA adapter saved with val_loss=1.6809
Epoch 2/3: 100%|██████████| 1/1 [00:00<00:00,  2.59it/s, loss=1.68]
2025-10-16 00:00:58,925 - action_recognition_qwen3-4b - INFO - Epoch 2: train_loss=1.6804, val_loss=1.6809
Epoch 3/3:   0%|          | 0/1 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 574.00 MiB. GPU 0 has a total capacity of 15.57 GiB of which 246.62 MiB is free. Including non-PyTorch memory, this process has 15.26 GiB memory in use. Of the allocated memory 13.85 GiB is allocated by PyTorch, and 1.12 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

### 4.2 Evaluate Fine-Tuned Action Recognition

Load the best adapters from Section 4.1, run inference on the test split, and persist aggregated metrics to `4.2-result.json`.

In [None]:
from pathlib import Path

action_finetuned_results = {}
metadata_index = Utils.load_json("4.1-training-metadata.json") or {}

if not metadata_index:
    print("No action recognition training metadata found. Run cell 4.1 first.")
elif not SPLITS["test"]:
    print("No test samples available for action recognition evaluation. Populate SPLITS['test'] before running evaluation.")
else:
    for model_key, metadata in metadata_index.items():
        checkpoint_dir = Path(metadata.get("checkpoint_dir", ""))
        if not checkpoint_dir.exists():
            print(f"Checkpoint directory {checkpoint_dir} not found for {model_key}; skipping.")
            continue

        wrapper = get_model_wrapper(model_key, use_lora=True, lora_config=action_lora_config, force_reload=True)

        adapter_path = checkpoint_dir / "best_lora_adapter"
        model_path = checkpoint_dir / "best_model.pt"

        if adapter_path.exists():
            wrapper.load_checkpoint(str(adapter_path))
        elif model_path.exists():
            wrapper.load_checkpoint(str(model_path))
        else:
            print(f"No fine-tuned weights found for {model_key}; skipping evaluation.")
            release_model(model_key)
            del wrapper
            continue

        metrics = evaluate_wrapper(
            wrapper,
            model_key,
            "action_recognition",
            SPLITS["test"],
            context_examples=None,
        )
        action_finetuned_results[model_key] = {
            k: v for k, v in metrics.items() if k not in {"predictions", "targets"}
        }
        release_model(model_key)
        del wrapper

if action_finetuned_results:
    Utils.save_json(action_finetuned_results, "4.2-result.json")
    print("Saved action recognition fine-tuned evaluation results to 4.2-result.json")
else:
    print("No fine-tuned action recognition results to save.")

release_all_models()

action_finetuned_results


Loading model: models/Qwen3-0.6B




LoRA adapter applied
NOTE: Only LoRA adapter weights will be saved (saves disk space)
Model loaded on cuda
LoRA adapter loaded from checkpoints/action_recognition_qwen3-0.6b_Qwen3-0.6B_20251015_235814/best_lora_adapter


Evaluating: 100%|██████████| 2/2 [00:06<00:00,  3.42s/it]


Saved action recognition fine-tuned evaluation results to 4.2-result.json


{'qwen3-0.6b': {'regex_matches': 0,
  'strict_match_accuracy': 0.0,
  'accuracy': 0.0,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'labels': ['\nturn: left\nturn: right\nAction: turn: left\nAction: turn: right\nAction: jump: jump\nAction: straight: forward\nAction: pan: left\nAction: pan: right\nAction: jump: jump\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\nAction: turn: right\nAction: turn: left\

### 4.3 Plot Action Recognition Comparison

Compare zero-shot and LoRA fine-tuned performance using bar charts saved under `plots/`.

In [None]:
plot_utils = PlotUtils()

zero_shot_action = Utils.load_json("2.3-result.json") or {}
fine_tuned_action = Utils.load_json("4.2-result.json") or {}

method_results = {}
if zero_shot_action:
    method_results["zero_shot"] = zero_shot_action
if fine_tuned_action:
    method_results["fine_tuned"] = fine_tuned_action

if len(method_results) >= 2:
    PlotUtils.plot_method_metric_bar(
        method_results,
        metric_key="strict_match_accuracy",
        title="Action Recognition: Strict Match Accuracy Comparison",
        save_path="plots/4.3-action-strict-accuracy.png",
        scale=100.0,
        ylabel="Strict Match Accuracy (%)",
        method_labels={"zero_shot": "Zero-Shot", "fine_tuned": "LoRA Fine-Tuned"},
        metric_label="Strict Match Accuracy (%)",
        ylim=(0, 100),
    )
    PlotUtils.plot_method_metric_bar(
        method_results,
        metric_key="f1",
        title="Action Recognition: Macro F1 Comparison",
        save_path="plots/4.3-action-macro-f1.png",
        scale=100.0,
        ylabel="Macro F1 (%)",
        method_labels={"zero_shot": "Zero-Shot", "fine_tuned": "LoRA Fine-Tuned"},
        metric_label="Macro F1 (%)",
        ylim=(0, 100),
    )
else:
    print("Need zero-shot and fine-tuned results to plot comparisons. Run cells 2.3 and 4.2.")

{"zero_shot": zero_shot_action, "fine_tuned": fine_tuned_action}

Plot saved to plots/4.3-action-strict-accuracy.png
Plot saved to plots/4.3-action-macro-f1.png


{'zero_shot': {'qwen3-0.6b': {'regex_matches': 0,
   'strict_match_accuracy': 0.0,
   'accuracy': 0.0,
   'precision': 0.0,
   'recall': 0.0,
   'f1': 0.0,
   'labels': ['irt|\n|air|air|grass block|\n|air|air|grass block|\n|air|air|grass block|\n\n\nFrame:\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|grass block|\n|air|air|grass block|\n|air|air|grass block|\n\n\nAction:\nstraight: forward\npan: left\njump: jump\n\n\nFrame:\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|grass block|\n|air|air|grass block|\n|air|air|grass block|\n\n\nFrame:\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|grass block|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|dirt|\n|air|air|grass',
    'straight: forward\npan: left\njump: jump'],
   'confusion_matrix': [[0, 0], [2, 0]],
   'reconstruction_accuracy': 0.0,
   'word2vec_cosine': 0