You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
{{ message }}
This repository has been archived by the owner on Aug 31, 2021. It is now read-only.
This is a problem with older versions of TensorFlow.
Please try updating to TensorFlow 0.7 or later.
Also note that skflow will be bundled in the next release of TensorFlow (as tf.contrib.skflow).
import random
import numpy as np
from sklearn import datasets
from sklearn.metrics import accuracy_score, mean_squared_error
import tensorflow as tf
import skflow
# Seeds only Python's `random` module; NumPy and TensorFlow RNGs are not
# seeded by this call.
random.seed(42)
# Toy training set: 4 samples with 5 feature columns each.
data = np.array([[2, 1, 2, 2, 3],
                 [2, 2, 3, 4, 5],
                 [3, 3, 1, 2, 1],
                 [2, 4, 5, 4, 1]], dtype=np.float32)
# labels for classification
labels = np.array([1, 0, 1, 0], dtype=np.float32)
# targets for regression
targets = np.array([10, 16, 10, 16], dtype=np.float32)
# Two unseen samples for predict(); same dtype as the training data so all
# inputs are consistent.
test_data = np.array([[1, 3, 3, 2, 1],
                      [2, 3, 4, 5, 6]], dtype=np.float32)
def input_fn(X):
    """Split the input tensor into 5 tensors along dimension 1.

    Passed as ``input_op_fn`` to the skflow RNN estimators below. The
    training data has 5 feature columns, so this yields one tensor per
    column (presumably consumed as one RNN step each — TODO confirm).

    Args:
        X: Input tensor; presumably shape (batch, 5) — TODO confirm.

    Returns:
        The list of 5 tensors produced by ``tf.split``.
    """
    # Pre-1.0 TensorFlow argument order: tf.split(split_dim, num_split, value).
    return tf.split(1, 5, X)
# Classification
# LSTM classifier: fit on the toy data, then check predictions on test_data.
classifier = skflow.TensorFlowRNNClassifier(
    rnn_size=2, cell_type='lstm', n_classes=2, input_op_fn=input_fn)
classifier.fit(data, labels)
# Bare attribute access: fails if the fitted estimator does not expose them.
classifier.weights_
classifier.bias_
predictions = classifier.predict(test_data)
# This was `self.assertAllClose(...)` inside a tf.test.TestCase; there is no
# `self` at module level, so use NumPy's equivalent with matching tolerances.
np.testing.assert_allclose(predictions, np.array([1, 0]), rtol=1e-6, atol=1e-6)

# Two-layer vanilla RNN classifier; only checks that fitting succeeds.
classifier = skflow.TensorFlowRNNClassifier(
    rnn_size=2, cell_type='rnn', n_classes=2, input_op_fn=input_fn,
    num_layers=2)
classifier.fit(data, labels)

# An unknown cell_type is expected to make fit() raise ValueError.
classifier = skflow.TensorFlowRNNClassifier(
    rnn_size=2, cell_type='invalid_cell_type', n_classes=2,
    input_op_fn=input_fn, num_layers=2)
try:
    classifier.fit(data, labels)
except ValueError:
    pass  # expected
else:
    raise AssertionError(
        "fit() should raise ValueError for an invalid cell_type")
# Regression
# GRU regressor: fit on the toy targets, then predict on test_data.
regressor = skflow.TensorFlowRNNRegressor(
    rnn_size=2, cell_type='gru', input_op_fn=input_fn)
regressor.fit(data, targets)
# Bare attribute access: fails if the fitted estimator does not expose them.
regressor.weights_
regressor.bias_
# NOTE(review): predictions are computed but never checked against anything.
predictions = regressor.predict(test_data)
The text was updated successfully, but these errors were encountered: