Skip to content

Commit c72ec9b

Browse files
committed
Add initializer_list-based constructors
1 parent 6d62f0f commit c72ec9b

File tree

7 files changed

+66
-13
lines changed

7 files changed

+66
-13
lines changed

.appveyor.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ install:
2424
- conda info -a
2525
- conda install pytest -c conda-forge
2626
- cd test
27-
- conda install xtensor==0.5.0 pytest numpy pybind11==2.0.1 -c conda-forge
27+
- conda install xtensor==0.6.0 pytest numpy pybind11==2.0.1 -c conda-forge
2828
- xcopy /S %APPVEYOR_BUILD_FOLDER%\include %MINICONDA%\include
2929

3030
build_script:

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ install:
5959
# Useful for debugging any issues with conda
6060
- conda info -a
6161
- cd test
62-
- conda install xtensor==0.5.0 pytest numpy pybind11==2.0.1 -c conda-forge
62+
- conda install xtensor==0.6.0 pytest numpy pybind11==2.0.1 -c conda-forge
6363
- cp -r $TRAVIS_BUILD_DIR/include/* $HOME/miniconda/include/
6464

6565
script:

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ conda install -c conda-forge xtensor-python
2929

3030
| `xtensor-python` | `xtensor` | `pybind11` |
3131
|-------------------|------------|-------------|
32-
| master | 0.5.0 | ^2.0.0 |
32+
| master | 0.6.0 | ^2.0.0 |
33+
| 0.5.0 | 0.5.0 | ^2.0.0 |
3334
| 0.4.0 | 0.5.0 | ^2.0.0 |
3435
| 0.3.0 | ^0.4.1 | ^2.0.0 |
3536
| 0.2.0 | ^0.2.1 | ^1.8.1 |

include/xtensor-python/pyarray.hpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ namespace xt
108108
using semantic_base = xcontainer_semantic<self_type>;
109109
using base_type = pycontainer<self_type>;
110110
using container_type = typename base_type::container_type;
111+
using value_type = typename base_type::value_type;
111112
using pointer = typename base_type::pointer;
112113
using size_type = typename base_type::size_type;
113114
using shape_type = typename base_type::shape_type;
@@ -117,6 +118,12 @@ namespace xt
117118
using inner_strides_type = typename base_type::inner_strides_type;
118119

119120
pyarray() = default;
121+
pyarray(const value_type& t);
122+
pyarray(nested_initializer_list_t<T, 1> t);
123+
pyarray(nested_initializer_list_t<T, 2> t);
124+
pyarray(nested_initializer_list_t<T, 3> t);
125+
pyarray(nested_initializer_list_t<T, 4> t);
126+
pyarray(nested_initializer_list_t<T, 5> t);
120127

121128
pyarray(pybind11::handle h, pybind11::object::borrowed_t);
122129
pyarray(pybind11::handle h, pybind11::object::stolen_t);
@@ -179,6 +186,48 @@ namespace xt
179186
* pyarray implementation *
180187
**************************/
181188

189+
template <class T>
190+
inline pyarray<T>::pyarray(const value_type& t)
191+
{
192+
base_type::reshape(xt::shape<shape_type>(t), layout::row_major);
193+
nested_copy(m_data.begin(), t);
194+
}
195+
196+
template <class T>
197+
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 1> t)
198+
{
199+
base_type::reshape(xt::shape<shape_type>(t), layout::row_major);
200+
nested_copy(m_data.begin(), t);
201+
}
202+
203+
template <class T>
204+
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 2> t)
205+
{
206+
base_type::reshape(xt::shape<shape_type>(t), layout::row_major);
207+
nested_copy(m_data.begin(), t);
208+
}
209+
210+
template <class T>
211+
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 3> t)
212+
{
213+
base_type::reshape(xt::shape<shape_type>(t), layout::row_major);
214+
nested_copy(m_data.begin(), t);
215+
}
216+
217+
template <class T>
218+
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 4> t)
219+
{
220+
base_type::reshape(xt::shape<shape_type>(t), layout::row_major);
221+
nested_copy(m_data.begin(), t);
222+
}
223+
224+
template <class T>
225+
inline pyarray<T>::pyarray(nested_initializer_list_t<T, 5> t)
226+
{
227+
base_type::reshape(xt::shape<shape_type>(t), layout::row_major);
228+
nested_copy(m_data.begin(), t);
229+
}
230+
182231
template <class T>
183232
inline pyarray<T>::pyarray(pybind11::handle h, pybind11::object::borrowed_t)
184233
: base_type(h, pybind11::object::borrowed)

include/xtensor-python/pytensor.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <cstddef>
1313
#include <array>
1414
#include <algorithm>
15+
#include "xtensor/xutils.hpp"
1516
#include "xtensor/xsemantic.hpp"
1617
#include "xtensor/xiterator.hpp"
1718

@@ -83,6 +84,7 @@ namespace xt
8384
using semantic_base = xcontainer_semantic<self_type>;
8485
using base_type = pycontainer<self_type>;
8586
using container_type = typename base_type::container_type;
87+
using value_type = typename base_type::value_type;
8688
using pointer = typename base_type::pointer;
8789
using size_type = typename base_type::size_type;
8890
using shape_type = typename base_type::shape_type;
@@ -92,7 +94,7 @@ namespace xt
9294
using inner_strides_type = typename base_type::inner_strides_type;
9395

9496
pytensor() = default;
95-
97+
pytensor(nested_initializer_list_t<T, N> t);
9698
pytensor(pybind11::handle h, pybind11::object::borrowed_t);
9799
pytensor(pybind11::handle h, pybind11::object::stolen_t);
98100
pytensor(const pybind11::object &o);
@@ -137,6 +139,13 @@ namespace xt
137139
* pytensor implementation *
138140
***************************/
139141

142+
template <class T, std::size_t N>
143+
inline pytensor<T, N>::pytensor(nested_initializer_list_t<T, N> t)
144+
{
145+
base_type::reshape(xt::shape<shape_type>(t), layout::row_major);
146+
nested_copy(m_data.begin(), t);
147+
}
148+
140149
template <class T, std::size_t N>
141150
inline pytensor<T, N>::pytensor(pybind11::handle h, pybind11::object::borrowed_t)
142151
: base_type(h, pybind11::object::borrowed)

test/main.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,19 @@ using complex_t = std::complex<double>;
1111

1212
// Examples
1313

14-
double example1(xt::pyarray<double> &m)
14+
double example1(xt::pyarray<double>& m)
1515
{
1616
return m(0);
1717
}
1818

19-
xt::pyarray<double> example2(xt::pyarray<double> &m)
19+
xt::pyarray<double> example2(xt::pyarray<double>& m)
2020
{
2121
return m + 2;
2222
}
2323

2424
// Readme Examples
2525

26-
double readme_example1(xt::pyarray<double> &m)
26+
double readme_example1(xt::pyarray<double>& m)
2727
{
2828
auto sines = xt::sin(m);
2929
return std::accumulate(sines.begin(), sines.end(), 0.0);

test/test_pyarray.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,27 @@
1616
class ExampleTest(TestCase):
1717

1818
def test_example1(self):
19-
print("test1")
2019
self.assertEqual(4, xt.example1([4, 5, 6]))
2120

2221
def test_example2(self):
23-
print("test2")
2422
x = np.array([[0., 1.], [2., 3.]])
2523
res = np.array([[2., 3.], [4., 5.]])
2624
y = xt.example2(x)
2725
np.testing.assert_allclose(y, res, 1e-12)
2826

2927
def test_vectorize(self):
30-
print("test3")
3128
x1 = np.array([[0, 1], [2, 3]])
3229
x2 = np.array([0, 1])
3330
res = np.array([[0, 2], [2, 4]])
3431
y = xt.vectorize_example1(x1, x2)
3532
np.testing.assert_array_equal(y, res)
3633

3734
def test_readme_example1(self):
38-
print("test4")
3935
v = np.arange(15).reshape(3, 5)
4036
y = xt.readme_example1(v)
4137
np.testing.assert_allclose(y, 1.2853996391883833, 1e-12)
4238

4339
def test_readme_example2(self):
44-
print("test5")
4540
x = np.arange(15).reshape(3, 5)
4641
y = [1, 2, 3, 4, 5]
4742
z = xt.readme_example2(x, y)
@@ -51,7 +46,6 @@ def test_readme_example2(self):
5146
[-1.084323, -0.583843, 0.45342 , 1.073811, 0.706945]], 1e-5)
5247

5348
def test_rect_to_polar(self):
54-
print("test6")
5549
x = np.ones(10, dtype=complex)
5650
z = xt.rect_to_polar(x[::2]);
5751
np.testing.assert_allclose(z, (np.ones(5, dtype=float), np.zeros(5, dtype=float)), 1e-5)

0 commit comments

Comments
 (0)