
bug in DataFeeder constructor #1  #7

@dididy

Description


System information

centos 7.0
python 3.4
tensorflow 1.2.1

Describe the problem

When I call tensorflow.contrib.learn.DNNRegressor.fit(x_train_dict, y_train, steps=1000), where x_train_dict is a dict and y_train is an array, the program throws the following exception:

  File "/home/star/yuce.ddxq.mobi/zhuge/management/commands/forecast_product_sale.py", line 148, in tflearn_dnn_train2
    regressor.fit(x_train_dict, y_train, steps=10000, batch_size=10)
  File "/usr/lib/python3.4/site-packages/tensorflow/python/util/deprecation.py", line 289, in new_func
    return func(*args, **kwargs)
  File "/usr/lib/python3.4/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 439, in fit
    SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors)
  File "/usr/lib/python3.4/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1340, in fit
    epochs=None)
  File "/usr/lib/python3.4/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 137, in _get_input_fn
    epochs=epochs)
  File "/usr/lib/python3.4/site-packages/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py", line 152, in setup_train_data_feeder
    x, y, n_classes, batch_size, shuffle=shuffle, epochs=epochs)
  File "/usr/lib/python3.4/site-packages/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py", line 326, in __init__
    dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if x_is_dict else check_array(y, y.dtype)
AttributeError: 'numpy.ndarray' object has no attribute 'items'

The related code is:

    x_is_dict, y_is_dict = isinstance(x, dict), y is not None and isinstance(
        y, dict)
    if isinstance(y, list):
      y = np.array(y)

    self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items())
                   ]) if x_is_dict else check_array(x, x.dtype)
    self._y = None if y is None else \
      dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if x_is_dict else check_array(y, y.dtype)
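
To isolate the suspected bug, here is a minimal sketch that reproduces only the label-conversion expression in plain Python. It is not TensorFlow code. check_array is a hypothetical stand-in that only copies its input, whereas the real helper in data_feeder.py also handles dtypes. convert_labels_buggy is a hypothetical name, and a plain list stands in for the numpy.ndarray label array. Because the expression branches on x_is_dict, passing a dict for x and a non-dict for y sends execution into y.items(), which raises an AttributeError like the one in the traceback.

```python
def check_array(array, dtype=None):
    """Hypothetical stand-in for data_feeder.check_array: just returns a list copy."""
    return list(array)


def convert_labels_buggy(x, y):
    """Mirrors the original expression, which (wrongly) branches on x_is_dict."""
    x_is_dict = isinstance(x, dict)
    # Parses as: None if y is None else (dict(...) if x_is_dict else check_array(y))
    return None if y is None else \
        dict((k, check_array(v)) for k, v in y.items()) if x_is_dict else check_array(y)


x_train_dict = {"feature_a": [1.0, 2.0], "feature_b": [3.0, 4.0]}
y_train = [10.0, 20.0]  # stands in for the numpy.ndarray passed to fit()

try:
    convert_labels_buggy(x_train_dict, y_train)
except AttributeError as err:
    # A list, like an ndarray, has no .items(), so the dict branch fails here.
    print("AttributeError:", err)
```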

The last line of the code looks wrong. Shouldn't it use y_is_dict instead of x_is_dict? I changed the code to:

    dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if y_is_dict else check_array(y, y.dtype)

With this change, it works.
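Here is a minimal sketch of the patched expression, under the same assumptions as a plain-Python illustration rather than TensorFlow code. check_array is a hypothetical stand-in that only copies its input, lists stand in for numpy arrays, and convert_labels_fixed is a hypothetical name. The sketch branches on y_is_dict, as in the reporter's fix, and checks that every combination of dict or array features with dict, array, or missing labels converts without error.

```python
def check_array(array, dtype=None):
    """Hypothetical stand-in for data_feeder.check_array: just returns a list copy."""
    return list(array)


def convert_labels_fixed(x, y):
    """Applies the reported fix: the branch depends only on whether y is a dict."""
    y_is_dict = y is not None and isinstance(y, dict)
    return None if y is None else \
        dict((k, check_array(v)) for k, v in y.items()) if y_is_dict else check_array(y)


x_dict = {"feature_a": [1.0, 2.0]}
x_array = [[1.0], [2.0]]
y_dict = {"target": [10.0, 20.0]}
y_array = [10.0, 20.0]  # stands in for a numpy.ndarray of labels

# Every combination of dict/array features and dict/array/None labels converts cleanly.
for x in (x_dict, x_array):
    for y in (y_dict, y_array, None):
        print(type(x).__name__, type(y).__name__, convert_labels_fixed(x, y))
```

Note that x is no longer consulted when converting the labels. That is the point of the fix: the shape of y alone decides whether it is treated as a dict of label arrays or as a single array.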
