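This notebook tests whether TruLens works correctly with PyTorch when multiple CUDA devices are available. It places a small model on the second device (`cuda:1`) and checks that neither importing TruLens nor wrapping the model with it moves the model's parameters to a different device. It must be run on a machine with at least 2 CUDA devices.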

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# When this notebook is run from its location inside the truera repository,
# make the repository's trulens sources importable. The relative path is
# resolved against the current working directory.
sys.path.insert(0, "../..")

# Outside the repository, install trulens instead:
# !{sys.executable} -m pip install trulens torch

import torch

# The test is only meaningful with at least two CUDA devices. The model is
# later placed on device index 1, which requires a second device.
num_cuda_devices = torch.cuda.device_count()
assert num_cuda_devices >= 2, f"need at least 2 cuda devices to run this test, have {num_cuda_devices}"

In [None]:
class Trivial(torch.nn.Module):
    """Minimal module used only to check which device its parameters live on.

    It holds a single 1x1 linear layer, so the model has at least one
    parameter whose device can be inspected. It also defines ``forward``,
    so the module can actually be called: ``torch.nn.Module`` raises
    ``NotImplementedError`` when a subclass without ``forward`` is invoked.
    """

    def __init__(self):
        super().__init__()

        # Make sure there is at least 1 parameter in this model. Despite its
        # name, this attribute is a ``Linear`` layer, not a softmax. The name is
        # kept so the parameter names (``softmax.weight``, ``softmax.bias``)
        # stay unchanged.
        self.softmax = torch.nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Apply the 1x1 linear layer to ``x``.

        Args:
            x: tensor whose last dimension has size 1 (``in_features=1``).

        Returns:
            Tensor with the same leading dimensions and a last dimension of 1.
        """
        return self.softmax(x)

# Use the second CUDA device (index 1) both as the current device and as the
# model's device. The index is presumably chosen over the default 0 so that a
# stray move to cuda:0 would be caught by the checks below — confirm intent.
CUDA_INDEX = 1
torch.cuda.set_device(CUDA_INDEX)
device = torch.device("cuda", CUDA_INDEX)
model = Trivial().to(device)

In [None]:
def check_device(module=None, expected=None):
    """
    Assert that every parameter of a model is on the expected device, then print "all good".

    Args:
        module: the ``torch.nn.Module`` to inspect. Defaults to the global
            ``model`` set up earlier in this notebook.
        expected: the ``torch.device`` the parameters must be on. Defaults to
            the global ``device`` set up earlier in this notebook.

    Raises:
        AssertionError: if ``module`` has no parameters (the check would be
            vacuous), or if any parameter's device type or index differs from
            ``expected``.
    """
    if module is None:
        module = model
    if expected is None:
        expected = device

    params = list(module.parameters())
    # A parameterless model would otherwise "pass" without anything being
    # inspected; treat that as a failure rather than printing "all good".
    assert params, "Expected model to have at least one parameter to check."

    for p in params:
        # Compare type and index explicitly so e.g. cuda:0 vs cuda:1 is caught.
        assert p.device.type == expected.type and p.device.index == expected.index, f"Expected model to be on device {expected} but got some parameters on {p.device}."
    print("all good")

def _report(stage):
    """Print the stage label, then verify the model is still on its device."""
    print(stage)
    check_device()


# Baseline: the model's device before trulens is imported or used.
_report("pre import")

# Importing trulens must not move the model's parameters. The import is
# deliberately placed here, between the checks, rather than at the top.
from trulens.nn.models import get_model_wrapper

_report("post import")

# Wrapping the model with trulens must not move its parameters either.
wrapper = get_model_wrapper(model)
_report("post wrap")