
# `ModelRecorder.load_model` — Usage Examples

This notebook shows practical, **runnable** examples for loading models with `ModelRecorder.load_model`, mirroring the style of the `save_model_examples.ipynb` notebook.

We cover:
1. Loading a scikit-learn model from a `.pkl` file.
2. Loading a PyTorch model from a `.pt`/`.pth` file **into an instantiated model**.
3. Loading only the **state_dict** (when you don't have, or don't want to instantiate, a model).
4. Common error handling patterns (unsupported extension, file-like object, corrupted file).


In [5]:
import os
import tempfile
import pickle
import torch
import torch.nn as nn
from ThreeWToolkit.utils import ModelRecorder
from sklearn import linear_model

In [2]:
class SimpleTorchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)


## 1) Load a scikit-learn model (`.pkl`)

We'll create a simple (unfitted) `LogisticRegression` model, save it with `pickle`, then load it back using `ModelRecorder.load_model`.


In [7]:
# --- Save a sklearn model ---
sk_model = linear_model.LogisticRegression()
tmp_pkl = tempfile.NamedTemporaryFile(suffix=".pkl", delete=False)
tmp_pkl.close()
with open(tmp_pkl.name, "wb") as f:
    pickle.dump(sk_model, f)

print("Saved sklearn model to:", tmp_pkl.name)

# --- Load it back ---
loaded_sk_model = ModelRecorder.load_model(tmp_pkl.name)
print("Loaded type:", type(loaded_sk_model))
assert isinstance(loaded_sk_model, linear_model.LogisticRegression)

os.remove(tmp_pkl.name)


Saved sklearn model to: /tmp/tmpxp2jtidw.pkl
Loaded type: <class 'sklearn.linear_model._logistic.LogisticRegression'>
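The cell above only reaches `os.remove` if every preceding line succeeds. If the `assert` fails, the temporary file is left on disk. Below is a minimal, standard-library-only sketch of an alternative. `tempfile.TemporaryDirectory` deletes its contents when the `with` block exits, even when an exception is raised. A plain dict stands in for the model so the block runs without scikit-learn. In the notebook you would pickle `sk_model` and pass `path` to `ModelRecorder.load_model` instead of calling `pickle.load` directly.

```python
import os
import pickle
import tempfile

# Stand-in for a model object; any picklable object behaves the same way.
stand_in_model = {"coef": [0.5, -1.25], "intercept": 0.1}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pkl")

    # Save with pickle, as in the cell above.
    with open(path, "wb") as f:
        pickle.dump(stand_in_model, f)

    # In the notebook, ModelRecorder.load_model(path) would replace this read.
    with open(path, "rb") as f:
        loaded = pickle.load(f)

    assert loaded == stand_in_model

# The directory and the .pkl file inside it are removed once the block exits.
print("Cleaned up:", not os.path.exists(path))
```

Unpickling can execute arbitrary code, so only load `.pkl` files that come from a trusted source.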



## 2) Load a PyTorch model into an instantiated module (`.pt`/`.pth`)

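We'll create a simple `nn.Module`, save its `state_dict`, then **instantiate a new model** and pass it to `load_model` through the `model=` argument so the saved weights are loaded into it. The fresh instance starts with its own random weights, so parameters that match afterwards confirm that the load worked.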


In [8]:
# --- Create & save a torch state_dict ---
torch_model = SimpleTorchModel()
tmp_pt = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
tmp_pt.close()
torch.save(torch_model.state_dict(), tmp_pt.name)
print("Saved torch state_dict to:", tmp_pt.name)

# --- Load into a fresh instance ---
fresh_model = SimpleTorchModel()
restored_model = ModelRecorder.load_model(tmp_pt.name, model=fresh_model)

# Verify parameters match
for k, v in torch_model.state_dict().items():
    assert torch.allclose(restored_model.state_dict()[k], v)

print("Restored model successfully, parameters match.")


Saved torch state_dict to: /tmp/tmprf9sjp2w.pt
Restored model successfully, parameters match.
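For comparison, here is a sketch of the same round trip written directly against PyTorch, without `ModelRecorder`. It does not describe how `load_model` works internally. The sketch does three things:

- It saves to an in-memory `io.BytesIO` buffer, which `torch.save` and `torch.load` both accept (unlike `ModelRecorder.load_model`, as section 4 shows).
- It loads with `weights_only=True`, so only tensors and plain containers are unpickled.
- It copies the weights into a fresh instance with `load_state_dict`.

It assumes a PyTorch release recent enough to support `weights_only`. The error in section 4 notes that this argument became the default in 2.6.

```python
import io

import torch
import torch.nn as nn


class SimpleTorchModel(nn.Module):
    # Same architecture as the notebook's model, redefined so this block is self-contained.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)


source_model = SimpleTorchModel()

# Serialize the state_dict to an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.save(source_model.state_dict(), buffer)
buffer.seek(0)  # rewind before reading

# weights_only=True restricts unpickling to tensors and basic containers.
state_from_buffer = torch.load(buffer, weights_only=True)

# The fresh instance has different random weights; load_state_dict overwrites them.
fresh_copy = SimpleTorchModel()
fresh_copy.load_state_dict(state_from_buffer)

for name, tensor in source_model.state_dict().items():
    assert torch.equal(fresh_copy.state_dict()[name], tensor)
print("Plain PyTorch round trip OK:", sorted(state_from_buffer.keys()))
```

Passing `weights_only=True` explicitly keeps loading behavior the same across the PyTorch versions that support the argument.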



## 3) Load only the `state_dict` (no instantiated model)

If you don't pass a model instance, `load_model` returns the raw `state_dict` loaded from the file. This is useful when you only need the parameters, or when you plan to reconstruct the architecture later.


In [9]:
# Reuse the .pt file saved in section 2 (tmp_pt); run that cell first.
state = ModelRecorder.load_model(tmp_pt.name)
print("Type of returned object:", type(state))
print("Keys in state_dict:", list(state.keys())[:5], "...")
assert isinstance(state, dict)  # collections.OrderedDict is a dict subclass, so this passes either way

os.remove(tmp_pt.name)

Type of returned object: <class 'collections.OrderedDict'>
Keys in state_dict: ['linear.weight', 'linear.bias'] ...
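When you only have the `state_dict`, the tensor shapes are often enough to rebuild a matching layer. `nn.Linear` stores `weight` with shape `(out_features, in_features)`, so for the `SimpleTorchModel` layout the sizes can be read directly. This is a minimal sketch. The helper `infer_linear_sizes` is hypothetical and not part of ThreeWToolkit, and a hand-built dict stands in for the `state` returned above.

```python
import torch
import torch.nn as nn


def infer_linear_sizes(state, prefix="linear"):
    """Hypothetical helper: read (in_features, out_features) from a Linear weight."""
    out_features, in_features = state[f"{prefix}.weight"].shape
    return in_features, out_features


# Stand-in for the state_dict returned by load_model without a model argument.
saved_state = {"linear.weight": torch.randn(1, 10), "linear.bias": torch.randn(1)}

print({name: tuple(t.shape) for name, t in saved_state.items()})
# {'linear.weight': (1, 10), 'linear.bias': (1,)}

in_features, out_features = infer_linear_sizes(saved_state)
rebuilt = nn.Linear(in_features, out_features)

# A bare nn.Linear uses the keys "weight" and "bias", so strip the "linear." prefix first.
rebuilt.load_state_dict({k.removeprefix("linear."): v for k, v in saved_state.items()})
assert torch.equal(rebuilt.weight, saved_state["linear.weight"])
```

Reading shapes this way only recovers layer sizes. Anything not stored as a tensor, such as activation functions or the `forward` logic, still has to come from the original class definition.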



## 4) Error handling patterns

Below are examples demonstrating how to **catch and diagnose** common errors:
- Unsupported extension
- File-like object not accepted
- Corrupted or non-pickle `.pkl`
- Corrupted or non-torch `.pt`
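The outputs of the cell below suggest a simple pattern. Invalid inputs raise `ValueError`, and failures inside the underlying loader are wrapped in `RuntimeError`. The following hypothetical sketch, which uses only the standard library, reproduces that behavior to make the error patterns concrete. It is **not** ThreeWToolkit's code, and the name `load_model_sketch` is invented. PyTorch is imported only when a `.pt`/`.pth` path is given.

```python
import os
import pickle
import tempfile
from io import BytesIO


def load_model_sketch(path, model=None):
    """Hypothetical loader that mirrors the error patterns shown in this section."""
    # Only filesystem paths are accepted; file-like objects are rejected up front.
    if not isinstance(path, (str, os.PathLike)):
        raise ValueError(
            f"Loading from file-like object '{path}' is not supported. "
            "Please provide a valid file path."
        )

    ext = os.path.splitext(os.fspath(path))[1].lower()
    if ext == ".pkl":
        try:
            with open(path, "rb") as f:
                return pickle.load(f)
        except Exception as e:
            # Wrap low-level errors (e.g. UnpicklingError) in one predictable type.
            raise RuntimeError(f"Error loading Pickle model: {e}") from e
    if ext in (".pt", ".pth"):
        import torch  # deferred so the .pkl branch works without PyTorch installed

        try:
            state = torch.load(path, weights_only=True)
        except Exception as e:
            raise RuntimeError(f"Error loading PyTorch model: {e}") from e
        if model is None:
            return state  # no instance given: hand back the raw state_dict
        model.load_state_dict(state)
        return model
    raise ValueError(f"Unsupported file extension: {ext}")


with tempfile.TemporaryDirectory() as tmp_dir:
    good_path = os.path.join(tmp_dir, "ok.pkl")
    with open(good_path, "wb") as f:
        pickle.dump({"answer": 42}, f)
    roundtrip = load_model_sketch(good_path)

    bad_path = os.path.join(tmp_dir, "bad.pkl")
    with open(bad_path, "wb") as f:
        f.write(b"not a real pickle payload")

    # Unsupported extension, file-like object, corrupted pickle.
    caught = []
    for bad_input in (os.path.join(tmp_dir, "model.xyz"), BytesIO(b"dummy"), bad_path):
        try:
            load_model_sketch(bad_input)
        except (ValueError, RuntimeError) as e:
            caught.append(type(e).__name__)
            print(f"{type(e).__name__}: {e}")

print(roundtrip, caught)
```

Because failures raise a small, fixed set of exception types, callers can handle them with a narrow `except` clause. The `from e` chaining keeps the original low-level error available as `__cause__` for debugging.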


In [10]:
from io import BytesIO

# 4.1 Unsupported extension
try:
    with tempfile.NamedTemporaryFile(suffix=".xyz", delete=True) as fake:
        ModelRecorder.load_model(fake.name)
except ValueError as e:
    print("Caught expected error (unsupported ext):", e)

# 4.2 File-like object not accepted
try:
    fake_file = BytesIO(b"dummy")
    ModelRecorder.load_model(fake_file)
except ValueError as e:
    print("Caught expected error (file-like not supported):", e)

# 4.3 Non-pickle content in .pkl
# Create the file before the try block so bad_pkl_path is always defined in finally.
# Leaving the with block closes (and flushes) the file.
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as bad_pkl:
    bad_pkl.write(b"not a real pickle payload")
    bad_pkl_path = bad_pkl.name

try:
    ModelRecorder.load_model(bad_pkl_path)
except RuntimeError as e:
    print("Caught expected error (pickle load):", e)
finally:
    os.remove(bad_pkl_path)

# 4.4 Non-torch content in .pt
# Same pattern: create the file first so bad_pt_path is always defined in finally.
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as bad_pt:
    bad_pt.write(b"not a real torch payload")
    bad_pt_path = bad_pt.name

try:
    ModelRecorder.load_model(bad_pt_path, model=SimpleTorchModel())
except RuntimeError as e:
    print("Caught expected error (torch load):", e)
finally:
    os.remove(bad_pt_path)


Caught expected error (unsupported ext): Unsupported file extension: .xyz
Caught expected error (file-like not supported): Loading from file-like object '<_io.BytesIO object at 0x79fcee134e00>' is not supported. Please provide a valid file path.
Caught expected error (pickle load): Error loading Pickle model: invalid load key, 'n'.
Caught expected error (torch load): Error loading PyTorch model: Weights only load failed. In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Unsupported operand 110

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.