This repository has been archived by the owner on Aug 31, 2021. It is now read-only.

negative indices are currently unsupported in rnn example #161

Closed
vinayakumarr opened this issue Apr 5, 2016 · 2 comments

@vinayakumarr

[Attached image: untitled]

import random
import numpy as np
from sklearn import datasets
from sklearn.metrics import accuracy_score, mean_squared_error

import tensorflow as tf

import skflow

random.seed(42)
data = np.array(list([[2, 1, 2, 2, 3],
                      [2, 2, 3, 4, 5],
                      [3, 3, 1, 2, 1],
                      [2, 4, 5, 4, 1]]), dtype=np.float32)

# labels for classification
labels = np.array(list([1, 0, 1, 0]), dtype=np.float32)

# targets for regression
targets = np.array(list([10, 16, 10, 16]), dtype=np.float32)
test_data = np.array(list([[1, 3, 3, 2, 1], [2, 3, 4, 5, 6]]))

def input_fn(X):
    return tf.split(1, 5, X)

# Classification
classifier = skflow.TensorFlowRNNClassifier(rnn_size=2, cell_type='lstm', n_classes=2, input_op_fn=input_fn)
classifier.fit(data, labels)
classifier.weights_
classifier.bias_
predictions = classifier.predict(test_data)
# Expected classes for the two test rows
np.testing.assert_allclose(predictions, np.array([1, 0]))

classifier = skflow.TensorFlowRNNClassifier(rnn_size=2, cell_type='rnn', n_classes=2, input_op_fn=input_fn, num_layers=2)
classifier.fit(data, labels)

# An unknown cell_type is expected to raise ValueError
classifier = skflow.TensorFlowRNNClassifier(rnn_size=2, cell_type='invalid_cell_type', n_classes=2, input_op_fn=input_fn, num_layers=2)
try:
    classifier.fit(data, labels)
except ValueError:
    pass
else:
    raise AssertionError("expected ValueError for an invalid cell_type")

# Regression
regressor = skflow.TensorFlowRNNRegressor(rnn_size=2, cell_type='gru', input_op_fn=input_fn)
regressor.fit(data, targets)
regressor.weights_
regressor.bias_
predictions = regressor.predict(test_data)
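`input_fn` in the script turns each row into a sequence for the RNN: with the TensorFlow 0.x argument order `tf.split(split_dim, num_split, value)`, `tf.split(1, 5, X)` cuts the (batch, 5) input into 5 tensors of shape (batch, 1), one per time step. The NumPy sketch that follows mirrors that split on the same `data` array so the shapes can be checked without TensorFlow; `np.split` is only a stand-in for illustration, not what skflow runs internally.

```python
import numpy as np

data = np.array([[2, 1, 2, 2, 3],
                 [2, 2, 3, 4, 5],
                 [3, 3, 1, 2, 1],
                 [2, 4, 5, 4, 1]], dtype=np.float32)

# NumPy analogue of tf.split(1, 5, X): split along axis 1 into 5 equal pieces
steps = np.split(data, 5, axis=1)

print(len(steps))      # number of time steps fed to the RNN
print(steps[0].shape)  # each step holds one feature per example
```

Each element of `steps` corresponds to one time step, so the LSTM/GRU cell sees a single scalar feature per example at each step.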

@ilblackdragon
Contributor

This is a problem with older versions of TensorFlow.
Please try updating to TensorFlow 0.7+.
Also note that skflow will be bundled in the next release of TensorFlow (as tf.contrib.skflow).
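One way to act on the 0.7+ suggestion is to confirm which TensorFlow version the running interpreter actually imports (via `tf.__version__`), since an older installation elsewhere on the path could explain an error that persists after upgrading. The helper that follows is a hypothetical sketch, not part of TensorFlow or skflow: `meets_minimum_version` parses the leading major/minor numbers of a version string and compares them as integers.

```python
import re


def meets_minimum_version(version, minimum=(0, 7)):
    """Hypothetical helper: True if `version` (e.g. "0.7.1") is at least
    `minimum`, comparing only the major and minor numbers."""
    # Extract the leading "major.minor" part, e.g. "0.7.1rc0" -> ("0", "7")
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("unrecognized version string: %r" % version)
    major, minor = int(match.group(1)), int(match.group(2))
    # Compare as integer tuples so that 0.12 ranks above 0.7
    return (major, minor) >= tuple(minimum)


# Usage (assuming TensorFlow is installed):
#   import tensorflow as tf
#   print(tf.__version__, meets_minimum_version(tf.__version__))
```

Comparing integers rather than raw strings matters here: as plain strings, "0.12" sorts before "0.7".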

@vinayakumarr
Author

I am using TensorFlow 0.7.1 and skflow 0.1.0, and it still shows the same error. Please have a look at the attached image:
[Attached image: untitled]
