Skip to content

Commit

Permalink
ENH: R poly compatibility
Browse files Browse the repository at this point in the history
Travis fixes
  • Loading branch information
thequackdaddy committed Sep 15, 2016
1 parent 8b6c712 commit d290dd3
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 0 deletions.
5 changes: 5 additions & 0 deletions doc/API-reference.rst
Expand Up @@ -198,6 +198,11 @@ Spline regression
.. autofunction:: cc
.. autofunction:: te

Orthogonal Polynomial
---------------------

.. autofunction:: poly

Working with formulas programmatically
--------------------------------------

Expand Down
3 changes: 3 additions & 0 deletions patsy/__init__.py
Expand Up @@ -113,5 +113,8 @@ def _reexport(mod):
import patsy.mgcv_cubic_splines
_reexport(patsy.mgcv_cubic_splines)

import patsy.poly
_reexport(patsy.poly)

# XX FIXME: we aren't exporting any of the explicit parsing interface
# yet. Need to figure out how to do that.
179 changes: 179 additions & 0 deletions patsy/poly.py
@@ -0,0 +1,179 @@
# This file is part of Patsy
# Copyright (C) 2012-2013 Nathaniel Smith <njs@pobox.com>
# See file LICENSE.txt for license information.

# R-compatible poly function

# These are made available in the patsy.* namespace
__all__ = ["poly"]

import numpy as np

from patsy.util import have_pandas, no_pickling, assert_no_pickling
from patsy.state import stateful_transform

if have_pandas:
import pandas

class Poly(object):
"""poly(x, degree=1, raw=False)
Generates an orthogonal polynomial transformation of x of degree.
Generic usage is something along the lines of::
y ~ 1 + poly(x, 4)
to fit ``y`` as a function of ``x``, with a 4th degree polynomial.
:arg degree: The number of degrees for the polynomial expansion.
:arg raw: When raw is False (the default), will return orthogonal
polynomials.
.. versionadded:: 0.4.1
"""
def __init__(self):
self._tmp = {}
self._degree = None
self._raw = None

def memorize_chunk(self, x, degree=3, raw=False):
args = {"degree": degree,
"raw": raw
}
self._tmp["args"] = args
# XX: check whether we need x values before saving them
x = np.atleast_1d(x)
if x.ndim == 2 and x.shape[1] == 1:
x = x[:, 0]
if x.ndim > 1:
raise ValueError("input to 'poly' must be 1-d, "
"or a 2-d column vector")
# There's no better way to compute exact quantiles than memorizing
# all data.
x = np.array(x, dtype=float)
self._tmp.setdefault("xs", []).append(x)

def memorize_finish(self):
tmp = self._tmp
args = tmp["args"]
del self._tmp

if args["degree"] < 1:
raise ValueError("degree must be greater than 0 (not %r)"
% (args["degree"],))
if int(args["degree"]) != args["degree"]:
raise ValueError("degree must be an integer (not %r)"
% (self._degree,))

# These are guaranteed to all be 1d vectors by the code above
scores = np.concatenate(tmp["xs"])
scores_mean = scores.mean()
# scores -= scores_mean
self.scores_mean = scores_mean
n = args['degree']
self.degree = n
raw_poly = scores.reshape((-1, 1)) ** np.arange(n + 1).reshape((1, -1))
raw = args['raw']
self.raw = raw
if not raw:
q, r = np.linalg.qr(raw_poly)
# Q is now orthognoal of degree n. To match what R is doing, we
# need to use the three-term recurrence technique to calculate
# new alpha, beta, and norm.

self.alpha = (np.sum(scores.reshape((-1, 1)) * q[:, :n] ** 2,
axis=0) /
np.sum(q[:, :n] ** 2, axis=0))

# For reasons I don't understand, the norms R uses are based off
# of the diagonal of the r upper triangular matrix.

self.norm = np.linalg.norm(q * np.diag(r), axis=0)
self.beta = (self.norm[1:] / self.norm[:n]) ** 2

def transform(self, x, degree=3, raw=False):
if have_pandas:
if isinstance(x, (pandas.Series, pandas.DataFrame)):
to_pandas = True
idx = x.index
else:
to_pandas = False
else:
to_pandas = False
x = np.array(x, ndmin=1).flatten()

if self.raw:
n = self.degree
p = x.reshape((-1, 1)) ** np.arange(n + 1).reshape((1, -1))
else:
# This is where the three-term recurrance technique is unwound.

p = np.empty((x.shape[0], self.degree + 1))
p[:, 0] = 1

for i in np.arange(self.degree):
p[:, i + 1] = (x - self.alpha[i]) * p[:, i]
if i > 0:
p[:, i + 1] = (p[:, i + 1] -
(self.beta[i - 1] * p[:, i - 1]))
p /= self.norm

p = p[:, 1:]
if to_pandas:
p = pandas.DataFrame(p)
p.index = idx
return p

__getstate__ = no_pickling

poly = stateful_transform(Poly)


def test_poly_compat():
from patsy.test_state import check_stateful
from patsy.test_poly_data import (R_poly_test_x,
R_poly_test_data,
R_poly_num_tests)
lines = R_poly_test_data.split("\n")
tests_ran = 0
start_idx = lines.index("--BEGIN TEST CASE--")
while True:
if not lines[start_idx] == "--BEGIN TEST CASE--":
break
start_idx += 1
stop_idx = lines.index("--END TEST CASE--", start_idx)
block = lines[start_idx:stop_idx]
test_data = {}
for line in block:
key, value = line.split("=", 1)
test_data[key] = value
# Translate the R output into Python calling conventions
kwargs = {
# integer
"degree": int(test_data["degree"]),
# boolen
"raw": (test_data["raw"] == 'TRUE')
}
# Special case: in R, setting intercept=TRUE increases the effective
# dof by 1. Adjust our arguments to match.
# if kwargs["df"] is not None and kwargs["include_intercept"]:
# kwargs["df"] += 1
output = np.asarray(eval(test_data["output"]))
# Do the actual test
check_stateful(Poly, False, R_poly_test_x, output, **kwargs)
tests_ran += 1
# Set up for the next one
start_idx = stop_idx + 1
assert tests_ran == R_poly_num_tests


def test_poly_errors():
from nose.tools import assert_raises
x = np.arange(27)
# Invalid input shape
assert_raises(ValueError, poly, x.reshape((3, 3, 3)))
assert_raises(ValueError, poly, x.reshape((3, 3, 3)), raw=True)
# Invalid degree
assert_raises(ValueError, poly, x, degree=-1)
assert_raises(ValueError, poly, x, degree=0)
assert_raises(ValueError, poly, x, degree=3.5)
37 changes: 37 additions & 0 deletions patsy/test_poly_data.py
@@ -0,0 +1,37 @@
# This file auto-generated by tools/get-R-bs-test-vectors.R
# Using: R version 3.2.4 Revised (2016-03-16 r70336)
import numpy as np
R_poly_test_x = np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, ])
R_poly_test_data = """
--BEGIN TEST CASE--
degree=1
raw=TRUE
output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, ]).reshape((20, 1, ), order="F")
--END TEST CASE--
--BEGIN TEST CASE--
degree=1
raw=FALSE
output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, ]).reshape((20, 1, ), order="F")
--END TEST CASE--
--BEGIN TEST CASE--
degree=3
raw=TRUE
output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, 1, 2.25, 5.0625, 11.390625, 25.62890625, 57.6650390625, 129.746337890625, 291.92926025390625, 656.84083557128906, 1477.8918800354004, 3325.2567300796509, 7481.8276426792145, 16834.112196028233, 37876.752441063523, 85222.692992392927, 191751.05923288409, 431439.8832739892, 970739.73736647563, 2184164.4090745705, 4914369.9204177829, 1, 3.375, 11.390625, 38.443359375, 129.746337890625, 437.89389038085938, 1477.8918800354004, 4987.8850951194763, 16834.112196028233, 56815.128661595285, 191751.05923288409, 647159.82491098379, 2184164.4090745705, 7371554.8806266747, 24878997.722115029, 83966617.312138215, 283387333.4284665, 956432250.32107437, 3227958844.8336263, 10894361101.313488, ]).reshape((20, 3, ), order="F")
--END TEST CASE--
--BEGIN TEST CASE--
degree=3
raw=FALSE
output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, 0.11682670564764953, 0.11622774987820758, 0.1153299112445243, 0.11398449209008393, 0.11196937564961051, 0.10895347864407183, 0.10444488285989936, 0.097716301062945765, 0.087700630095951776, 0.072850827534442664, 0.05096695744238839, 0.019020528242278005, -0.026920519697452645, -0.091380250921070119, -0.1780532062130448, -0.28552519567824058, -0.39602393206231051, -0.44767622905753701, -0.26843910749340033, 0.57802660073100254, -0.11560888340228653, -0.11436481217184656, -0.11250218782662975, -0.10971608089390825, -0.10555451667328646, -0.099351692934324679, -0.090136150925525155, -0.076511614727544461, -0.05651941299388475, -0.027522538371457093, 0.013772191900731716, 0.070864671671751547, 0.14593497036033168, 0.23591981919395397, 0.32391016867398448, 0.36336942185480259, 0.25890497941187346, -0.11572025100301592, -0.66076386903314166, 0.27159578788942196, ]).reshape((20, 3, ), order="F")
--END TEST CASE--
--BEGIN TEST CASE--
degree=5
raw=TRUE
output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, 1, 2.25, 5.0625, 11.390625, 25.62890625, 57.6650390625, 129.746337890625, 291.92926025390625, 656.84083557128906, 1477.8918800354004, 3325.2567300796509, 7481.8276426792145, 16834.112196028233, 37876.752441063523, 85222.692992392927, 191751.05923288409, 431439.8832739892, 970739.73736647563, 2184164.4090745705, 4914369.9204177829, 1, 3.375, 11.390625, 38.443359375, 129.746337890625, 437.89389038085938, 1477.8918800354004, 4987.8850951194763, 16834.112196028233, 56815.128661595285, 191751.05923288409, 647159.82491098379, 2184164.4090745705, 7371554.8806266747, 24878997.722115029, 83966617.312138215, 283387333.4284665, 956432250.32107437, 3227958844.8336263, 10894361101.313488, 1, 5.0625, 25.62890625, 129.746337890625, 656.84083557128906, 3325.2567300796509, 16834.112196028233, 85222.692992392927, 431439.8832739892, 2184164.4090745705, 11057332.320940012, 55977744.87475881, 283387333.4284665, 1434648375.4816115, 7262907400.875659, 36768468716.933022, 186140372879.47342, 942335637702.33411, 4770574165868.0674, 24151031714707.086, 1, 7.59375, 57.6650390625, 437.89389038085938, 3325.2567300796509, 25251.168294042349, 191751.05923288409, 1456109.6060497134, 11057332.320940012, 83966617.31213823, 637621500.21404958, 4841938267.2504387, 36768468716.933022, 279210559319.21014, 2120255184830.252, 16100687809804.727, 122264598055704.64, 928446791485507, 7050392822843070, 53538920498464552, ]).reshape((20, 5, ), order="F")
--END TEST CASE--
--BEGIN TEST CASE--
degree=5
raw=FALSE
output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, 0.11682670564764953, 0.11622774987820758, 0.1153299112445243, 0.11398449209008393, 0.11196937564961051, 0.10895347864407183, 0.10444488285989936, 0.097716301062945765, 0.087700630095951776, 0.072850827534442664, 0.05096695744238839, 0.019020528242278005, -0.026920519697452645, -0.091380250921070119, -0.1780532062130448, -0.28552519567824058, -0.39602393206231051, -0.44767622905753701, -0.26843910749340033, 0.57802660073100254, -0.11560888340228653, -0.11436481217184656, -0.11250218782662975, -0.10971608089390825, -0.10555451667328646, -0.099351692934324679, -0.090136150925525155, -0.076511614727544461, -0.05651941299388475, -0.027522538371457093, 0.013772191900731716, 0.070864671671751547, 0.14593497036033168, 0.23591981919395397, 0.32391016867398448, 0.36336942185480259, 0.25890497941187346, -0.11572025100301592, -0.66076386903314166, 0.27159578788942196, 0.11925766326375063, 0.11701962699862156, 0.11367531238125347, 0.10868744714732725, 0.10126981942884175, 0.090287103769210786, 0.074134201646975206, 0.050620044131431986, 0.016933017097416861, -0.030116712154368355, -0.093138533517390085, -0.17160263551697441, -0.25618209006285081, -0.3183631162695052, -0.29707753517866498, -0.10102478727647804, 0.30185248746535442, 0.55289166632880227, -0.46108564710186972, 0.081962667419115426, -0.12626707822019206, -0.12250155553682644, -0.11689136915447108, -0.10856147160045609, -0.096257598068575617, -0.078227654788373013, -0.052128116579684983, -0.015063001240831148, 0.035988153544508683, 0.10280803884977513, 0.18263307034840112, 0.26144732880503613, 0.30325203347309243, 0.24116709207723347, -0.00082575540196283526, -0.37830141983168153, -0.42887161757203512, 0.55207091753656046, -0.17171017635275559, 0.016240179713238136, ]).reshape((20, 5, ), order="F")
--END TEST CASE--
"""
R_poly_num_tests = 6
62 changes: 62 additions & 0 deletions tools/get-R-poly-test-vectors.R
@@ -0,0 +1,62 @@
cat("# This file auto-generated by tools/get-R-bs-test-vectors.R\n")
cat(sprintf("# Using: %s\n", R.Version()$version.string))
cat("import numpy as np\n")

options(digits=20)
library(splines)
x <- (1.5)^(0:19)

MISSING <- "MISSING"

is.missing <- function(obj) {
length(obj) == 1 && obj == MISSING
}

pyprint <- function(arr) {
if (is.missing(arr)) {
cat("None\n")
} else {
cat("np.array([")
for (val in arr) {
cat(val)
cat(", ")
}
cat("])")
if (!is.null(dim(arr))) {
cat(".reshape((")
for (size in dim(arr)) {
cat(sprintf("%s, ", size))
}
cat("), order=\"F\")")
}
cat("\n")
}
}

num.tests <- 0
dump.poly <- function(degree, raw) {
cat("--BEGIN TEST CASE--\n")
cat(sprintf("degree=%s\n", degree))
cat(sprintf("raw=%s\n", raw))

args <- list(x=x, degree=degree, raw=raw)

result <- do.call(poly, args)

cat("output=")
pyprint(result)
cat("--END TEST CASE--\n")
assign("num.tests", num.tests + 1, envir=.GlobalEnv)
}

cat("R_poly_test_x = ")
pyprint(x)
cat("R_poly_test_data = \"\"\"\n")

for (degree in c(1, 3, 5)) {
for (raw in c(TRUE, FALSE)) {
dump.poly(degree, raw)
}
}
cat("\"\"\"\n")
cat(sprintf("R_poly_num_tests = %s\n", num.tests))

0 comments on commit d290dd3

Please sign in to comment.