NUTS sampler needs more than 6 GB RAM on model with just 26 unknowns #1420

@bjornsing

I'm trying to build a relatively simple model in pymc3 (version 3.0rc1), but my script (see below) hangs when instantiating the NUTS sampler. According to top, the python process needs more than 5.9 GB of RAM, starts paging to disk, and grinds to a halt. This seems like an extraordinary amount of memory given that the model contains only 26 unknowns.

The code:

import numpy as np
import theano
assert(theano.config.device == 'cpu'), "This code should probably run on the cpu..."
assert(theano.config.floatX == 'float64'), "This code needs double precision floats..."
from pymc3 import Model, Uniform, Normal, Gamma, StudentT

# Length of core in cm:
L = 100

# Radiocarbon dates (depth in cm and age in years):
RCDs = [{'depth': 45.0, 'age': 3200.0, 'sd': 130},
        {'depth': 60.0, 'age': 5700.0, 'sd': 130},
        {'depth': 90.0, 'age': 7500.0, 'sd': 130},]

# Number of segments:
K = 25

# Year core was retrieved:
A = 2016

# MODEL SETUP - Essentially an attempt to implement https://projecteuclid.org/euclid.ba/1339616472

# Standard deviation in top layer age:
sigma_A = 10

# Gamma parameters:
# TBD: Shouldn't these parameters be "learnt" from the data?
a_alpha = 2
b_alpha = 2.0/1.1

bacon_model = Model()

with bacon_model:

    # TBD: Shouldn't this be a truncated normal distribution?
    theta = Normal('theta', mu=A, sd=sigma_A)

    # TODO: It seems prior on w shouldn't be uniform, but what should it be...?
    w = Uniform('w', lower=0.0, upper=1.0)

    alpha = [Gamma('a_' + str(j), alpha=a_alpha, beta=b_alpha) for j in range(K)]

    def x(j):
        if j <= K:
            return w * x(j+1) + (1.0 - w) * alpha[j - 1]
        else:
            # TBD: Is this correct...?
            return 0

    delta_c = float(L)/float(K)

    def t(z):
        depth = 0
        age   = theta

        j = 1
        while depth < (z - delta_c):
            depth += delta_c
            age   += x(j) * delta_c
            j     += 1

        age += x(j) * (z - depth)

        return age

    for RCD in RCDs:
        # TODO: Switch from Normal to StudentT...
        RCD['variable'] = Normal('RCD_' + str(RCD['depth']), mu=t(RCD['depth']), sd=RCD['sd'], observed=RCD['age'])

# INFERENCE - Copy-pasted from https://peerj.com/preprints/1686.pdf

from scipy import optimize
from pymc3 import find_MAP, NUTS, sample
import datetime

with bacon_model:

    print "*** Finding MAP (%s) ***" % datetime.datetime.now()

    start = find_MAP(fmin=optimize.fmin_powell)

    print "MAP:", start

    print "*** Instantiating NUTS sampler (%s) ***" % datetime.datetime.now()

    step = NUTS(scaling=start)

    # NOT REACHED,
    # because NUTS.__init__ seems to need more than 5.9 GB RAM and the system starts swapping...
    #
    # Relevant output line from top:
    #   PID USER      PR  NI    VIRT    RES    SHR S  %CPU %MEM     TIME+ COMMAND
    #26290 xxxxx     20   0 10.233g 5.901g   2600 D  12.6 76.8  48:50.91 python

    print "*** Ready to sample (%s) ***" % datetime.datetime.now()

    trace = sample(2000, step, start=start)

    print "*** Done sampling (%s) ***" % datetime.datetime.now()

Since I've never managed to run the model it may not be correct, but in any case I find it rather strange that it needs so much memory.
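
Since the model has not been run, its age-depth logic can at least be checked outside of pymc3. The following is a minimal NumPy sketch, assuming fixed illustrative values for w and the alpha values (50 years per cm and w = 0 are made up for the check, not taken from the data). The helper names accumulation_rates and age_at_depth are hypothetical. It computes each x_j once in a single backward pass instead of recursing, and then mirrors the loop in t(z). It says nothing about why NUTS needs so much memory; it only tests the arithmetic.

```python
import numpy as np


def accumulation_rates(alpha, w):
    """Return x with x[j] = w * x[j+1] + (1 - w) * alpha[j-1] for j = K..1.

    x[K+1] = 0 is the boundary value used in the script; x[0] is unused.
    """
    K = len(alpha)
    x = np.zeros(K + 2)
    for j in range(K, 0, -1):
        x[j] = w * x[j + 1] + (1.0 - w) * alpha[j - 1]
    return x


def age_at_depth(z, theta, x, delta_c):
    """Mirror of t(z) from the script, using the precomputed x values."""
    depth = 0.0
    age = theta
    j = 1
    # Add whole segments that lie fully above depth z.
    while depth < (z - delta_c):
        depth += delta_c
        age += x[j] * delta_c
        j += 1
    # Add the remaining partial segment.
    age += x[j] * (z - depth)
    return age


# Illustrative values only (assumptions, not fitted parameters).
L, K = 100, 25
delta_c = float(L) / float(K)
alpha = np.full(K, 50.0)  # hypothetical: constant 50 years per cm
w = 0.0                   # no memory between segments, so x[j] == alpha[j-1]

x = accumulation_rates(alpha, w)
ages = [age_at_depth(z, 2016.0, x, delta_c) for z in (45.0, 60.0, 90.0)]
# ages == [4266.0, 5016.0, 6516.0], i.e. 2016 + 50 * depth
```

With w = 0 and constant alpha the result must be linear in depth, which makes an off-by-one in the segment index easy to spot.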

The version of pymc3 I'm using:

$ pip show pymc3

---
Metadata-Version: 2.0
Name: pymc3
Version: 3.0rc1
Summary: PyMC3
Home-page: http://github.com/pymc-devs/pymc
Author: Thomas Wiecki
Author-email: thomas.wiecki@gmail.com
Installer: pip
License: Apache License, Version 2.0
Location: /home/coach/local/miniconda2/lib/python2.7/site-packages
Requires: nbsphinx, CommonMark, joblib, numpy, patsy, sphinx, pandas, scipy, numpydoc, matplotlib, theano, recommonmark
Classifiers:
  Development Status :: 5 - Production/Stable
  Programming Language :: Python
  Programming Language :: Python :: 2
  Programming Language :: Python :: 3
  Programming Language :: Python :: 2.7
  Programming Language :: Python :: 3.4
  Programming Language :: Python :: 3.5
  License :: OSI Approved :: Apache Software License
  Intended Audience :: Science/Research
  Topic :: Scientific/Engineering
  Topic :: Scientific/Engineering :: Mathematics
  Operating System :: OS Independent
