# Continue Training xPatch Models from W&B

Simple notebook to continue training 3 specific models from Weights & Biases.

## Steps:
1. Set your W&B run paths (`entity/project/run_id`) below
2. Download configs and checkpoints from W&B  
3. Continue training for more epochs
4. Results logged back to W&B

In [9]:
# Import libraries (same as finetune.ipynb)
# --- Standard library ---
import argparse
import math
import os
import sys
import time
import warnings

# Make the project root importable. This has to run before the project-local
# imports further down, otherwise it cannot help them resolve.
project_root = os.path.abspath('./')
if project_root not in sys.path:
    sys.path.append(project_root)

# --- Third-party ---
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import wandb
from sklearn.preprocessing import StandardScaler
from torch import optim
from torch.utils.data import DataLoader

# --- Project-local modules (looked up relative to project_root) ---
from data_provider.data_factory import data_provider
from data_provider.data_loader import Dataset_Custom
from exp.exp_main import Exp_Main
from models import xPatch
from utils.metrics import metric
from utils.tools import EarlyStopping, adjust_learning_rate, visual

# Silence library warnings so the training log stays readable.
warnings.filterwarnings('ignore')

print("✅ Libraries imported!")

✅ Libraries imported!


In [10]:
# Set random seed
# Seed PyTorch first, then NumPy, with the same fixed value.
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(2021)
print("✅ Random seed set!")

✅ Random seed set!


In [11]:
# 🔧 EDIT THESE: Your W&B Run IDs
# Each entry is a full W&B run path: "<entity>/<project>/<run_id>".
wandb_runs = ["xplstm/CS7643-GroupProject/hzvg0y5w"]

# Training settings
additional_epochs = 20  # extra epochs to train each model for
new_project = "continued_training_xpatch"  # W&B project that receives the new runs

# Echo the active settings in a single call, one message per line.
print(
    f"📋 Will continue training {len(wandb_runs)} models",
    f"🎯 Additional epochs: {additional_epochs}",
    "📈 Will use original learning rate from W&B config",
    "⚠️  Update the wandb_runs list above with your actual run IDs!",
    sep="\n",
)

📋 Will continue training 1 models
🎯 Additional epochs: 20
📈 Will use original learning rate from W&B config
⚠️  Update the wandb_runs list above with your actual run IDs!


In [12]:
# Download config and checkpoint from W&B
def download_from_wandb(run_path):
    """Fetch a W&B run's config and locate a checkpoint for it.

    The files attached to the run are scanned for the first one whose name
    contains 'checkpoint.pth'; that file is downloaded (overwriting any
    existing copy). If the run has no such file, the function falls back to
    ``./checkpoints/<model_id>/checkpoint.pth`` when the config has a
    ``model_id`` key and that file exists locally.

    Args:
        run_path: W&B run path, e.g. "entity/project/run_id".

    Returns:
        Tuple ``(config, model_path, run_name)``: the run config as a plain
        dict, the path of the checkpoint file, and the run's display name.

    Raises:
        FileNotFoundError: if no checkpoint is found in W&B or locally.
        Exception: any other error (e.g. from the W&B API) is printed and
            re-raised unchanged.
    """
    print(f"📡 Downloading from: {run_path}")

    try:
        api = wandb.Api()
        run = api.run(run_path)
        config = dict(run.config)

        print(f"   Run name: {run.name}")
        print(f"   Config keys: {len(config)} parameters")

        # Try to find checkpoint - first in files, then local
        model_path = None

        # Look for checkpoint files in W&B
        print("   🔍 Looking for checkpoint files...")
        for file in run.files():
            if 'checkpoint.pth' in file.name:
                print(f"   📥 Downloading: {file.name}")
                # download() hands back an open file object; only its path is
                # needed, so close the handle immediately instead of leaking it.
                with file.download(replace=True) as downloaded:
                    model_path = downloaded.name
                print(f"   ✅ Downloaded: {file.name}")
                break

        # If not found, try local checkpoint
        if model_path is None and 'model_id' in config:
            local_path = f"./checkpoints/{config['model_id']}/checkpoint.pth"
            if os.path.exists(local_path):
                model_path = local_path
                print(f"   ✅ Using local: {local_path}")

        if model_path is None:
            raise FileNotFoundError(f"No checkpoint found for {run_path}")

        return config, model_path, run.name

    except Exception as e:
        # Report the failure next to the progress output, then let the caller
        # decide how to handle it.
        print(f"   ❌ Error: {str(e)}")
        raise


print("✅ Download function ready!")

✅ Download function ready!


In [13]:
# Create Args class (similar to finetune.ipynb)
def _coerce_bool(val):
    """Convert a config value to bool, understanding textual booleans.

    Strings such as 'False', 'no' or '0' map to False (a plain ``bool()`` on
    any non-empty string would give True). Numeric values keep the
    ``bool(int(val))`` behaviour; everything else uses ``bool(val)``.
    """
    if isinstance(val, str):
        text = val.strip().lower()
        if text in ('true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', 'off', '', 'none', 'null'):
            return False
        try:
            # Numeric strings such as '1', '0' or '1.0'.
            return bool(int(float(text)))
        except (ValueError, OverflowError):
            return bool(val)
    if isinstance(val, (int, float)):
        return bool(int(val))
    return bool(val)


def create_args_from_config(config, model_idx, epochs=None, project=None):
    """Build an args object for continued training from a W&B config.

    Every config entry becomes an attribute; values exposing ``.item()`` are
    unwrapped first. Training/W&B settings are then overridden and a fixed set
    of known parameters is coerced to int, float or bool.

    Args:
        config: mapping of parameter name to value (e.g. ``dict(run.config)``).
        model_idx: zero-based index of the model; used in the new ``model_id``
            suffix (``_continued_<model_idx + 1>``) and as the fallback id
            ``model_<model_idx>`` when the config has no ``model_id``.
        epochs: number of epochs to train; defaults to the notebook-level
            ``additional_epochs``.
        project: W&B project for the new run; defaults to the notebook-level
            ``new_project``.

    Returns:
        An ``Args`` instance carrying the config values as attributes.
    """
    # Fall back to the notebook-level settings only when not given explicitly.
    train_epochs = additional_epochs if epochs is None else epochs
    wandb_project = new_project if project is None else project

    class Args:
        def __init__(self):
            # Copy all config parameters
            for key, value in config.items():
                # Handle tensor values from W&B
                if hasattr(value, 'item'):
                    value = value.item()
                setattr(self, key, value)

            # Override with continue training settings
            self.is_training = 1
            self.train_epochs = train_epochs
            # Keep original learning_rate from W&B config (don't override)

            # Update model_id for continued training
            original_id = getattr(self, 'model_id', f'model_{model_idx}')
            self.model_id = f"{original_id}_continued_{model_idx+1}"

            # W&B settings for new run
            self.use_wandb = True
            self.wandb_project = wandb_project
            self.wandb_entity = 'xplstm'
            self.experiment_notes = 'Continued training with additional epochs'

            # Type conversions (critical for LSTM parameters)
            int_params = ['lstm_hidden_size', 'lstm_layers', 'train_epochs', 'batch_size',
                          'seq_len', 'pred_len', 'patch_len', 'stride', 'd_model', 'd_ff', 'e_layers', 'k']
            for param in int_params:
                current = getattr(self, param, None)
                # Unset (None) values are left alone rather than crashing the
                # conversion; going through float accepts inputs like '96.0'.
                if current is not None:
                    setattr(self, param, int(float(str(current))))

            float_params = ['learning_rate', 'dropout', 'alpha', 'beta', 'lstm_dropout',
                            'directional_alpha', 'directional_beta', 'directional_gamma']
            for param in float_params:
                current = getattr(self, param, None)
                if current is not None:
                    setattr(self, param, float(str(current)))

            bool_params = ['use_lstm', 'lstm_bidirectional', 'revin', 'decomp']
            for param in bool_params:
                if hasattr(self, param):
                    setattr(self, param, _coerce_bool(getattr(self, param)))

    return Args()


print("✅ Args creation function ready!")

✅ Args creation function ready!


In [14]:
# Continue training function (setup only first)
def continue_training_setup(run_path, model_idx):
    """Prepare one model for continued training without starting it.

    Fetches the run's config and checkpoint path, builds the args object,
    constructs an ``Exp_Main`` from it and loads the checkpoint weights into
    ``exp.model``.

    Args:
        run_path: W&B run path handed to ``download_from_wandb``.
        model_idx: zero-based index of the model (shown 1-based in messages).

    Returns:
        Tuple ``(args, exp)``.
    """
    model_number = model_idx + 1
    print(f"\n🚀 Model {model_number}: Setting up continued training")

    # Config + checkpoint location, then the args derived from that config.
    config, model_path, run_name = download_from_wandb(run_path)
    args = create_args_from_config(config, model_idx)

    print(f"   Original: {run_name}")
    print(f"   New ID: {args.model_id}")
    print(f"   Epochs: {args.train_epochs}")
    print(f"   LR: {args.learning_rate}")

    # Build the experiment object for these args.
    print("   Creating experiment...")
    exp = Exp_Main(args)

    # Restore the saved weights onto the experiment's device.
    # NOTE(review): torch.load unpickles the file - only use checkpoints from
    # runs you trust.
    print(f"   Loading weights from: {model_path}")
    state_dict = torch.load(model_path, map_location=exp.device)
    exp.model.load_state_dict(state_dict)

    print(f"   ✅ Setup completed for Model {model_number}!")
    return args, exp

def continue_training_full(run_path, model_idx):
    """Set up, train and test one model, always cleaning up afterwards.

    Args:
        run_path: W&B run path of the model to continue training.
        model_idx: zero-based index of the model (shown 1-based in messages).

    Returns:
        The ``model_id`` used for the continued run.

    Raises:
        Exception: anything raised during setup, training or testing
            propagates to the caller after cleanup has run.
    """
    try:
        args, exp = continue_training_setup(run_path, model_idx)

        # Continue training
        print(f"   🔥 Training...")
        exp.train(args.model_id)

        # Test
        print(f"   🧪 Testing...")
        exp.test(args.model_id)
    finally:
        # Cleanup must also run on failure: otherwise a crashed model could
        # leave its W&B run open (and cached GPU memory held) while the caller
        # moves on to the next model.
        torch.cuda.empty_cache()
        wandb.finish()

    print(f"   ✅ Model {model_idx+1} completed!")
    return args.model_id

print("✅ Continue training functions ready!")

✅ Continue training functions ready!


In [15]:
# 🧪 TEST: Try setup only (no training yet)
# Dry run against the first configured run: download, build args, load weights.
test_run = wandb_runs[0]
print(f"Testing setup from: {test_run}")

try:
    args, exp = continue_training_setup(test_run, 0)
    print("✅ Setup successful!")
    print(f"   Args created with LR: {args.learning_rate}")
    print("   Model created and weights loaded!")
except Exception as setup_error:
    # Report the failure and show where it came from, without aborting the notebook.
    import traceback
    print(f"❌ Setup failed: {setup_error}")
    traceback.print_exc()

Testing setup from: xplstm/CS7643-GroupProject/hzvg0y5w

🚀 Model 1: Setting up continued training
📡 Downloading from: xplstm/CS7643-GroupProject/hzvg0y5w


   Run name: sweep_toasty-sweep-21
   Config keys: 56 parameters
   🔍 Looking for checkpoint files...
   ✅ Using local: ./checkpoints/sweep_toasty-sweep-21/checkpoint.pth
   Original: sweep_toasty-sweep-21
   New ID: sweep_toasty-sweep-21_continued_1
   Epochs: 20
   LR: 0.0008475905977571697
   Creating experiment...
Use CPU


   Loading weights from: ./checkpoints/sweep_toasty-sweep-21/checkpoint.pth
   ✅ Setup completed for Model 1!
✅ Setup successful!
   Args created with LR: 0.0008475905977571697
   Model created and weights loaded!


In [16]:
# 🚀 RUN CONTINUED TRAINING
# Update the wandb_runs list above before running this cell!

# model_ids of the runs that finished training and testing without error.
completed_models = []

for i, run_path in enumerate(wandb_runs):
    try:
        model_id = continue_training_full(run_path, i)
        completed_models.append(model_id)
    except Exception as e:
        # Best effort: one failing model must not stop the remaining ones, but
        # print the full stack trace so the failure can actually be diagnosed.
        import traceback
        print(f"❌ Error with model {i+1}: {str(e)}")
        traceback.print_exc()

print(f"\n🎉 Summary: {len(completed_models)}/{len(wandb_runs)} models completed")
for model_id in completed_models:
    print(f"   ✅ {model_id}")


🚀 Model 1: Setting up continued training
📡 Downloading from: xplstm/CS7643-GroupProject/hzvg0y5w
   Run name: sweep_toasty-sweep-21
   Config keys: 56 parameters
   🔍 Looking for checkpoint files...
   ✅ Using local: ./checkpoints/sweep_toasty-sweep-21/checkpoint.pth
   Original: sweep_toasty-sweep-21
   New ID: sweep_toasty-sweep-21_continued_1
   Epochs: 20
   LR: 0.0008475905977571697
   Creating experiment...
Use CPU
   Run name: sweep_toasty-sweep-21
   Config keys: 56 parameters
   🔍 Looking for checkpoint files...
   ✅ Using local: ./checkpoints/sweep_toasty-sweep-21/checkpoint.pth
   Original: sweep_toasty-sweep-21
   New ID: sweep_toasty-sweep-21_continued_1
   Epochs: 20
   LR: 0.0008475905977571697
   Creating experiment...
Use CPU


   Loading weights from: ./checkpoints/sweep_toasty-sweep-21/checkpoint.pth
   ✅ Setup completed for Model 1!
   🔥 Training...
train 2620
val 388
test 774
	iters: 100, epoch: 1 | loss: 0.1467405
	speed: 0.0091s/iter; left time: 58.8782s
	iters: 100, epoch: 1 | loss: 0.1467405
	speed: 0.0091s/iter; left time: 58.8782s
	iters: 200, epoch: 1 | loss: 0.0703050
	speed: 0.0081s/iter; left time: 51.2294s
	iters: 200, epoch: 1 | loss: 0.0703050
	speed: 0.0081s/iter; left time: 51.2294s
	iters: 300, epoch: 1 | loss: 0.0859839
	speed: 0.0080s/iter; left time: 49.6437s
	iters: 300, epoch: 1 | loss: 0.0859839
	speed: 0.0080s/iter; left time: 49.6437s
Epoch: 1 cost time: 2.754502534866333
Epoch: 1 cost time: 2.754502534866333
Epoch: 1, Steps: 327 | Train Loss: 0.1161316 Vali Loss: 0.1521475 Test Loss: 0.0727933
Validation loss decreased (inf --> 0.152147).  Saving model ...
Updating learning rate to 0.0008475905977571697
Epoch: 1, Steps: 327 | Train Loss: 0.1161316 Vali Loss: 0.1521475 Test Loss: 0

0,1
batch,▇▁▂▃▆▅▂▂▃▄▁▆▂▂▄█▁▁▆▁▂▄▁▇▃▇█▆█▂▆▇▂▅▂▄▅▆▇█
batch_loss,▆▆▃▆▃▆▅▅▄▃█▅▄▅▃▅▅▃▃▅▃▃▃▄▃▆▅▄▅█▄▃▅▂▄▆▆▅▁▅
best_sample_mse,▁
epoch,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇████
epoch_time,▁▃▄▅▄▄▄▃▃▄▁▃▃▃▇█▅▃▂▄
final_test_mae,▁
final_test_mse,▁
learning_rate,█████████▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mse_std,▁
test_loss,█▂▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
batch,320.0
batch_loss,0.14087
best_sample_mse,0.00017
epoch,20.0
epoch_time,3.65981
final_test_mae,0.15448
final_test_mse,0.04784
learning_rate,0.0
mse_std,0.12552
test_loss,0.04761


   ✅ Model 1 completed!

🎉 Summary: 1/1 models completed
   ✅ sweep_toasty-sweep-21_continued_1


## Example Usage

### 1. Update Run IDs
Edit the `wandb_runs` list in cell 4 with your actual W&B run IDs:
```python
wandb_runs = [
    "xplstm/CS7643-GroupProject/3abc123d",
    "xplstm/CS7643-GroupProject/4def456e", 
    "xplstm/CS7643-GroupProject/5ghi789f"
]
```

### 2. Adjust Training Settings
Modify these variables in cell 4:
- `additional_epochs`: How many more epochs to train
- Learning rate: always taken from the original W&B config (this notebook has no `new_learning_rate` setting)
- `new_project`: W&B project name for continued training runs

### 3. Run Cells
Execute all cells in order, from top to bottom. Each model will:
- Download config and checkpoint from W&B
- Continue training for the specified epochs
- Test the final model
- Log results to new W&B runs

The new models will be saved in `./checkpoints/` with names like `original_name_continued_1`.