In [1]:
import torch
import matplotlib.pyplot as plt

from cnp.cov import (
    MeanFieldGaussianLayer,
    InnerprodGaussianLayer,
    KvvGaussianLayer
)

from cnp.cnp import (
    StandardGNP,
    StandardAGNP,
    StandardConvGNP,
    FullConvGNP
)

def warn(*args, **kwargs):
    pass

import warnings
warnings.warn = warn

In [2]:
def measure_model_forward(model_name):

    args_cov_type = 'kvv'
    args_model = model_name
    args_x_dim = 1
    args_num_basis_dim = 512
    args_noise_type = 'hetero'
    device = 'cuda:0'

    # =============================================================================
    # Create model
    # =============================================================================

    cov_types = {
        'meanfield' : MeanFieldGaussianLayer,
        'innerprod' : InnerprodGaussianLayer,
        'kvv'       : KvvGaussianLayer
    }

    if args_cov_type == 'meanfield':
        output_layer = MeanFieldGaussianLayer()

    else:
        output_layer = cov_types[args_cov_type](num_embedding=args_num_basis_dim,
                                                noise_type=args_noise_type)

    # Create model architecture
    if args_model == 'GNP':
        model = StandardGNP(input_dim=args_x_dim, output_layer=output_layer)

    elif args_model == 'AGNP':
        model = StandardAGNP(input_dim=args_x_dim, output_layer=output_layer)

    elif args_model == 'convGNP':
        model = StandardConvGNP(input_dim=args_x_dim, output_layer=output_layer)

    elif args_model == 'FullConvGNP':
        model = FullConvGNP()

    elif args_model == 'ANP':
        model = StandardANP(input_dim=args_x_dim,
                            num_samples=args_np_loss_samples)

    elif args_model == 'convNP':
        model = StandardConvNP(input_dim=args_x_dim,
                               num_samples=args_np_loss_samples)

    else:
        raise ValueError(f'Unknown model {args_model}.')

    assert torch.cuda.memory_allocated() == 0

    # Load model to appropriate device
    model = model.to(device)
    
    N = 50
    M = 100
    D = 1
    u = 8.

    torch.manual_seed(0)

    x_context = torch.linspace(-u, u, N)[None, :, None].to(device)
    y_context = torch.rand(size=(1, N, 1)).to(device)
    x_target = torch.linspace(-u, u, M)[None, :, None].to(device)
    
    model.forward(x_context, y_context, x_target)
    print(f'{model_name} Memory: ', torch.cuda.memory_allocated())

    %timeit _ = model.forward(x_context, y_context, x_target)

In [3]:
# Benchmark the fully convolutional GNP: prints its CUDA memory, then times the forward pass.
measure_model_forward(model_name='FullConvGNP')

FullConvGNP Memory:  1227776
4.59 ms ± 5.05 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [4]:
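# Results of measure_model_forward with torch.manual_seed(0), 50 context / 100 target
# points, kvv output layer (512 basis dims, noise_type='hetero').
# Columns: %timeit forward-pass time (mean ± std), then torch.cuda.memory_allocated()
# after one forward pass. Only the FullConvGNP row is produced by the cell above; the
# other rows presumably come from earlier runs with the other model names.
# GNP         584 µs ± 740 ns   , 476160  Bytes
# AGNP        1.62 ms ± 1.68 µs , 2971136 Bytes
# ConvGNP     1.74 ms ± 1.19 µs , 231424  Bytes
# FullConvGNP 4.59 ms ± 5.05 µs , 1227776 Bytes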

In [5]:
# Record which GPU(s) and driver version the benchmarks above ran on.
!nvidia-smi

Mon Nov 22 19:17:09 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 495.29.05    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:17:00.0 Off |                  N/A |
| 22%   30C    P8    10W / 250W |   1328MiB / 11019MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  On   | 00000000:B3:00.0 Off |                  N/A |
| 22%   29C    P8    21W / 250W |      3MiB / 11019MiB |      0%      Default |
|       