Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP

We’re showing branches in this repository, but you can also compare across forks.

base fork: pavlov99/pmll
...
head fork: pavlov99/pmll
compare: 0681a237d1
  • 2 commits
  • 2 files changed
  • 0 commit comments
  • 1 contributor
Showing with 29 additions and 2 deletions.
  1. +26 −0 python/examples/demo_irls.py
  2. +3 −2 python/pmll/classification.py
26 python/examples/demo_irls.py
View
@@ -0,0 +1,26 @@
+import numpy as np
+import scipy
+from scipy.io import loadmat
+import matplotlib.pyplot as plt
+
+import sys
+sys.path.append('..') # parent directory with library
+from pmll import classification
+
+if __name__ == '__main__':
+ data = scipy.io.loadmat('../../data/iris.mat')
+ x = data['X'][50:]
+ y = data['Y'][50:] - 1
+
+ number_random_features = 3
+ x = np.hstack([x, np.random.randn(x.shape[0], number_random_features)])
+
+ model_irls = classification.IrlsModel()
+ model_irls.train(x, y, regularization=1e-3, max_iterations=500)
+
+ plt.plot(np.hstack(model_irls._IrlsModel__history['weights']).T)
+ plt.plot(model_irls._IrlsModel__history['weight_change'])
+ plt.show()
+
+ classifier_irls = classification.IrlsClassifier(model_irls)
+ print classifier_irls.classify(x)
5 python/pmll/classification.py
View
@@ -105,7 +105,8 @@ def train(self, objects, labels, object_weights=None, max_iterations=100,
np.ones([objects.shape[0], 1]),
)))
labels = np.asmatrix(labels)
- object_weights = object_weights or np.array([[1]] * objects.shape[0])
+ if object_weights is None:
+ object_weights = np.array([[1]] * objects.shape[0])
I = regularization * np.eye(objects.shape[1])
# Initialize weights
@@ -117,7 +118,7 @@ def train(self, objects, labels, object_weights=None, max_iterations=100,
probability = classifier.classify(objects[:, :-1])
object_weights_new = np.multiply(
probability - np.power(probability, 2),
- object_weights or np.array([[1]] * objects.shape[0]),
+ object_weights,
)
X = objects

No commit comments for this range

Something went wrong with that request. Please try again.