In [1]:
import dds

## Custom library integration - PyTorch example

DDS can store arbitrary picklable Python objects, which covers a wide range of use cases. Pickling is not enough, however, for specialized libraries that have their own storage formats. This notebook takes a simple PyTorch example and shows how to add support for custom objects in DDS.
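
As an illustration of such a library-specific format, the sketch below round-trips the parameters of a small module through PyTorch's own `torch.save` and `torch.load`, using an in-memory buffer. It only shows the PyTorch side; how DDS is told to use this format is not covered by this snippet, and the layer sizes are arbitrary choices for the illustration.

```python
import io

import torch

model = torch.nn.Linear(4, 1)

# Serialize the parameters with PyTorch's own format into an in-memory buffer.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Rewind the buffer and restore the parameters into a fresh module of the same shape.
buffer.seek(0)
restored = torch.nn.Linear(4, 1)
restored.load_state_dict(torch.load(buffer))

# The restored module holds exactly the same weights and bias.
print(torch.equal(model.weight, restored.weight))
print(torch.equal(model.bias, restored.bias))
```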

The original example is at https://github.com/pytorch/examples/blob/master/regression/main.py .

In [2]:
from itertools import count

import torch
import torch.nn.functional as F
import functools

torch.manual_seed(1)

POLY_DEGREE = 4

# torch.randn draws new values on every call, so each target is wrapped in a
# cached function to ensure it is generated only once.
@functools.lru_cache(maxsize=None)
def W_target():
    return torch.randn(POLY_DEGREE, 1) * 5


@functools.lru_cache(maxsize=None)
def b_target():
    return torch.randn(1) * 5


def make_features(x):
    """Builds features i.e. a matrix with columns [x, x^2, x^3, x^4]."""
    x = x.unsqueeze(1)
    return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)


def f(x):
    """Approximated function."""
    return x.mm(W_target()) + b_target().item()


def poly_desc(W, b):
    """Creates a string description of a polynomial."""
    result = 'y = '
    for i, w in enumerate(W):
        result += '{:+.2f} x^{} '.format(w, i + 1)
    result += '{:+.2f}'.format(b[0])
    return result


def get_batch(batch_size=32):
    """Builds a batch i.e. (x, f(x)) pair."""
    random = torch.randn(batch_size)
    x = make_features(random)
    y = f(x)
    return x, y


The first difference from the original example is that all the PyTorch variables must be created inside a function. If they are left as top-level variables, they are ignored when the checksum of the final outcome is calculated.
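
The `W_target` and `b_target` helpers earlier in this notebook follow the same idea: they are functions rather than top-level variables, and `functools.lru_cache` makes sure the random draw inside each one happens only once. Below is a minimal sketch of that caching behaviour. It uses the standard library's `random.gauss` as a stand-in for `torch.randn` (an assumption, so that the snippet runs without PyTorch), and the function names are hypothetical.

```python
import functools
import random

random.seed(1)


def uncached_target():
    # Draws a fresh value on every call, like an unwrapped torch.randn.
    return random.gauss(0.0, 1.0) * 5


@functools.lru_cache(maxsize=None)
def cached_target():
    # The body runs on the first call only; later calls return the stored value.
    return random.gauss(0.0, 1.0) * 5


# Repeated calls to the cached function agree with each other.
print(cached_target() == cached_target())

# Repeated calls to the uncached function produce different draws.
print(uncached_target() == uncached_target())
```

Without the cache, every call to `f(x)` would be evaluated against a different target polynomial, so the training loop would have nothing stable to fit.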

In [3]:
def train_function():
    
    # Define model
    fc = torch.nn.Linear(W_target().size(0), 1)

    for batch_idx in count(1):
        # Get data
        batch_x, batch_y = get_batch()

        # Reset gradients
        fc.zero_grad()

        # Forward pass
        output = F.smooth_l1_loss(fc(batch_x), batch_y)
        loss = output.item()

        # Backward pass
        output.backward()

        # Apply gradients
        for param in fc.parameters():
            param.data.add_(-0.1 * param.grad)

        # Stop criterion
        if loss < 1e-3:
            return loss, batch_idx, fc

# Let's check first that the function works correctly:

loss, batch_idx, fc = train_function()

print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target().view(-1), b_target()))

Loss: 0.000630 after 301 batches
==> Learned function:	y = +3.28 x^1 +1.33 x^2 +0.35 x^3 +3.12 x^4 +2.92
==> Actual function:	y = +3.31 x^1 +1.33 x^2 +0.31 x^3 +3.11 x^4 +2.92
