Skip to content

Commit

Permalink
Working Online Viterbi
Browse files Browse the repository at this point in the history
  • Loading branch information
sbos committed Aug 15, 2011
1 parent cce2759 commit 6b71a5e
Showing 1 changed file with 156 additions and 48 deletions.
Original file line number Diff line number Diff line change
@@ -1,60 +1,104 @@
package org.apache.mahout.classifier.sequencelearning.hmm;

import java.io.DataInputStream;
import java.io.FileInputStream;
import org.apache.mahout.math.DenseVector;

import java.io.IOException;
import java.util.*;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

public class HmmOnlineViterbi {
/**
 * A node of the survivor-path tree used by the online Viterbi decoder.
 *
 * <p>Each node records a sequence position and a hidden state, a link to its
 * parent node, a count of its children, and {@code next}/{@code previous}
 * links that chain the nodes into a doubly linked list.
 *
 * <p>NOTE(review): this span comes from a commit diff view that carries no
 * +/- markers, so declarations from both revisions appear side by side (two
 * class headers, {@code parentNode} and {@code parent}, {@code setParentNode}
 * and {@code setParent}). Presumably the {@code parentNode}/{@code setParentNode}
 * lines are the removed ones -- confirm against the repository.
 */
public static class Node {
static class Node {
// Sequence index and hidden-state index of this node; -1 until assigned.
public int position, state;
public Node parentNode;
// Parent node in the tree; null when the node has no parent.
public Node parent;
// Neighbours in the doubly linked node list.
public Node next;
public Node previous;
// Number of nodes that currently name this node as their parent.
public int childNumber;

/** Creates a detached node: position and state are -1, all links are null, no children. */
public Node() {
position = state = -1;
parentNode = null;
parent = null;
next = null;
previous = null;
childNumber = 0;
}

// Variant using the parentNode field: decrements the old parent's child count
// only when the parent actually changes, and increments the new parent's count
// without a null check.
public void setParentNode(Node parent) {
if (parentNode != parent && parentNode != null)
--parentNode.childNumber;
parentNode = parent;
++parentNode.childNumber;
/**
 * Re-parents this node and keeps both child counters in step: the current
 * parent (if any) loses one child, the new parent (if non-null) gains one.
 *
 * @param parent the new parent; may be null to detach the node
 */
public void setParent(Node parent) {
if (this.parent != null)
--this.parent.childNumber;
this.parent = parent;
if (this.parent != null)
++this.parent.childNumber;
}
}

static void compress(LinkedList<Node> tree) {
ListIterator<Node> iterator = tree.listIterator();
Node node = iterator.next();
while (iterator.hasNext()) {
if (node.childNumber < 1) {
if (node.parentNode != null)
--node.parentNode.childNumber;
iterator.remove();
static class Tree {
Node first, last;

/** Creates an empty node list: there is neither a first nor a last node. */
public Tree() {
  this.first = null;
  this.last = null;
}

/**
 * Appends {@code node} at the tail of the list.
 *
 * <p>The previous tail (if any) is linked forward to the node, the node is
 * linked back to it, and the node becomes the new tail. When the list was
 * empty the node also becomes the head. The node's own {@code next} link is
 * not modified.
 *
 * @param node the node to append
 */
public void addLast(Node node) {
  Node tail = last;
  if (tail != null) {
    tail.next = node;
  }
  if (first == null) {
    first = node;
  }
  node.previous = tail;
  last = node;
}

/**
 * Unlinks {@code node} from the list.
 *
 * <p>The neighbours are joined to each other and {@code first}/{@code last}
 * are moved when the node was at either end. The node's own {@code next} and
 * {@code previous} links are left untouched, so a caller that is iterating
 * can still step from the removed node to its former successor.
 *
 * @param node the node to unlink
 */
public void remove(Node node) {
  Node before = node.previous;
  Node after = node.next;

  if (before != null) {
    before.next = after;
  }
  if (after != null) {
    after.previous = before;
  }

  if (node == first) {
    first = after;
  }
  if (node == last) {
    last = before;
  }
}

/**
 * Scans the list from {@code first} and returns the first node that has at
 * least two children.
 *
 * @return the first node with {@code childNumber >= 2}, or null when the
 *         list is empty or no node has two or more children
 */
public Node getRoot() {
Node node = first;
// Skip nodes with fewer than two children.
while (node != null && node.childNumber < 2) {
node = node.next;
}
// NOTE(review): the next three lines use parentNode/setParentNode and do not
// parse inside this method; presumably they are removed lines from the older
// compress(LinkedList<Node>) that the diff view interleaved here -- confirm.
else if (node.parentNode != null) {
while (node.parentNode.childNumber == 1) {
node.setParentNode(node.parentNode.parentNode);
return node;
}

/**
 * Prunes and shortens the tree in one pass over the node list.
 *
 * <p>For each node, in list order: a node with no children is unlinked from
 * the list and its parent's child count is decremented; any other node has
 * its parent replaced by the grandparent for as long as the parent has
 * exactly one child, unlinking each bypassed parent from the list.
 */
public void compress() {
Node node = first;
while (node != null) {
if (node.childNumber < 1) {
// Childless node: drop it and release its claim on the parent.
if (node.parent != null) {
--node.parent.childNumber;
}
remove(node);
}
else {
// Bypass single-child parents: unlink the parent, then hang this node
// directly under the grandparent.
// NOTE(review): setParent increments the grandparent's childNumber, but the
// bypassed parent's own contribution to that counter is not decremented
// here -- looks like the grandparent ends up over-counted; confirm.
while (node.parent != null && node.parent.childNumber == 1) {
remove(node.parent);
node.setParent(node.parent.parent);
}
}
// remove() does not clear the removed node's own next link, so stepping
// forward from a just-removed node still works.
node = node.next;
}
// NOTE(review): no `iterator` is declared in this method; the next line is
// presumably a removed line from the older list-iterator based compress that
// the diff view interleaved here -- confirm.
node = iterator.next();
}
}

// NOTE(review): two revisions of traceback are interleaved below because the
// diff view carries no +/- markers. The callers visible in this file pass four
// arguments, so the three-argument int[] header and its first lines are
// presumably the removed version -- confirm.
static int[] traceback(List<int[]> backpointers, int i, int state) {
int[] result = new int[backpointers.size()+1];
result[--i] = state;
while (i > 0) {
--i;
/**
 * Follows the stored backpointers from index {@code i} down to 0, consuming
 * them, and prints the resulting state sequence to standard output.
 *
 * @param backpointers one int[] per step; entry {@code [s]} is read as the
 *                     predecessor state of state {@code s}. Visited entries
 *                     are removed from the list.
 * @param i            index of the last position to decode; the printed
 *                     sequence has {@code i + 1} entries
 * @param state        the state placed at position {@code i}
 * @param last         when false, the entry at index {@code i} is removed
 *                     before the walk starts; when true it is not
 */
static void traceback(LinkedList<int[]> backpointers, int i, int state, boolean last) {
int[] result = new int[i+1];
result[i] = state;
if (!last) backpointers.remove(i);
--i;
// Walk backwards: each step looks up the predecessor of the state already
// chosen for position i+1, then discards that backpointer row.
while (i >= 0) {
int[] optimalStates = backpointers.get(i);
backpointers.remove(i);
result[i] = optimalStates[result[i+1]];
--i;
}
// Print the decoded states separated by spaces, followed by a newline.
for (int aResult : result) System.out.print(aResult + " ");
System.out.println();
// NOTE(review): a value return does not fit the void signature above;
// presumably this line belongs to the removed int[] version -- confirm.
return result;
}

private static double getTransitionProbability(HmmModel model, int i, int j) {
Expand All @@ -72,23 +116,24 @@ private static double[] getInitialProbabilities(HmmModel model, int startObserva
Math.log(model.getEmissionMatrix().getQuick(h, startObservation));
return probs;
}

public static Iterable<Integer> onlineViterbi(HmmModel model, Iterable<Integer> observations) {
Iterator<Integer> iterator = observations.iterator();

double[] probs = getInitialProbabilities(model, iterator.next());
LinkedList<Node> tree = new LinkedList<Node>();
Tree tree = new Tree();
Node[] leaves = new Node[model.getNrOfHiddenStates()];
Node root = null;
for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
Node node = new Node();
node.position = 0;
node.state = i;
tree.push(node);
tree.addLast(node);
leaves[i] = node;
}

List<int[]> backpointers = new ArrayList<int[]>();
LinkedList<int[]> backpointers = new LinkedList<int[]>();
int i = 1;
int lastOutput = 0;
while (iterator.hasNext()) {
int observation = iterator.next();
double[] newProbs = new double[model.getNrOfHiddenStates()];
Expand All @@ -110,21 +155,22 @@ public static Iterable<Integer> onlineViterbi(HmmModel model, Iterable<Integer>
Node node = new Node();
node.position = i;
node.state = k;
node.setParentNode(leaves[optimalStates[k]]);
node.setParent(leaves[maxState]);
newLeaves[k] = node;
tree.push(node);
//tree.addLast(node);
}
backpointers.add(optimalStates);

Node oldRoot = tree.getLast();
compress(tree);
Node newRoot = tree.getLast();
if (newRoot != oldRoot) {
traceback(backpointers, newRoot.position - lastOutput, newRoot.state);
lastOutput = i;
tree.compress();
Node newRoot = tree.getRoot();
if (root != newRoot && newRoot != null) {
traceback(backpointers, i - newRoot.position - 1, newRoot.state, false);
leaves = newLeaves;
root = newRoot;
}

for (Node leave: newLeaves)
tree.addLast(leave);
probs = newProbs;
++i;
}
Expand All @@ -135,16 +181,78 @@ public static Iterable<Integer> onlineViterbi(HmmModel model, Iterable<Integer>
maxState = k;
}

if (backpointers.size() > 0)
traceback(backpointers, i - lastOutput, maxState);
if (backpointers.size() > 0) {
traceback(backpointers, backpointers.size(), maxState, true);
}
return null;
}

/**
 * Builds a hard-coded test model with 4 hidden states and 3 output symbols.
 *
 * <p>The initial distribution is uniform. The transition matrix consists of
 * two 2x2 blocks (states 0-1 and states 2-3) with every in-block probability
 * equal to 0.5. Symbol 2 can only be emitted by states 2 and 3.
 *
 * <p>The two-argument {@code HmmModel} constructor initializes the model with
 * random parameters (see the Mahout HmmModel Javadoc), so every transition
 * entry is assigned explicitly here. The cross-block entries are set to 0;
 * otherwise they would keep their random values and the transition rows
 * would not sum to 1.
 *
 * @return the fully specified test model
 */
private static HmmModel createBad() {
  final int hiddenStates = 4;
  final int outputStates = 3;
  HmmModel model = new HmmModel(hiddenStates, outputStates);
  double e = 0.3, f = 0.2;

  model.setInitialProbabilities(new DenseVector(new double[] {0.25, 0.25, 0.25, 0.25}));

  // One row per hidden state, one column per output symbol; each row sums to 1.
  double[][] emissions = {
      {e, 1.0-e, 0},
      {1.0-e, e, 0},
      {e-f, 1.0-e-f, 2.0*f},
      {1.0-e-f, e-f, 2.0*f}
  };
  for (int state = 0; state < hiddenStates; ++state) {
    for (int symbol = 0; symbol < outputStates; ++symbol) {
      model.getEmissionMatrix().set(state, symbol, emissions[state][symbol]);
    }
  }

  // States {0,1} and {2,3} form two blocks: 0.5 inside a block, 0 across blocks.
  for (int from = 0; from < hiddenStates; ++from) {
    for (int to = 0; to < hiddenStates; ++to) {
      boolean sameBlock = (from / 2) == (to / 2);
      model.getTransitionMatrix().set(from, to, sameBlock ? 0.5 : 0.0);
    }
  }

  return model;
}

/**
 * Ad-hoc driver: builds a model, generates an observation sequence from it,
 * prints the sequence and the result of {@code HmmAlgorithms.viterbiAlgorithm},
 * then runs {@code onlineViterbi} on the same observations.
 *
 * <p>NOTE(review): this span comes from a diff view without +/- markers, so
 * lines from both revisions are interleaved. The call that deserializes
 * "../hmm.model" sits in the middle of the 2x2 model set-up and is presumably
 * a removed line -- confirm against the repository.
 *
 * @param args unused
 * @throws IOException declared by the signature
 */
public static void main(String[] args) throws IOException {
// Two-state, two-symbol model; its parameters are set below.
HmmModel model = new HmmModel(2, 2);
double e = 0.01;
model.setInitialProbabilities(new DenseVector(new double[] {e, 1.0-e}));

HmmOnlineViterbi.onlineViterbi(LossyHmmSerializer.deserialize(new DataInputStream(new FileInputStream("../hmm.model"))),
Arrays.asList(1,1,1,0,0,0,0,0,0,1,1,1,0,0,2,2,2,
2, 2, 2,2, 2, 2, 3, 3, 3, 3, 3, 3 ,3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 3, 3, 3,
3, 3, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3));
model.getEmissionMatrix().set(0, 0, e);
model.getEmissionMatrix().set(0, 1, 1.0-e);
model.getEmissionMatrix().set(1, 0, 1.0-e);
model.getEmissionMatrix().set(1, 1, e);
model.getTransitionMatrix().set(0, 0, 0.5);
model.getTransitionMatrix().set(0, 1, 0.5);
model.getTransitionMatrix().set(1, 0, 0.5);
model.getTransitionMatrix().set(1, 1, 0.5);

//model = LossyHmmSerializer.deserialize(new DataInputStream(new FileInputStream("../hmm.model")));
// The 2x2 model configured above is discarded here in favour of createBad().
model = createBad();

// Obtain 27 observations from the model (presumably sampled -- verify
// against HmmEvaluator.predict) and print them.
int[] data = HmmEvaluator.predict(model, 27);
System.out.print("Seq: ");
for (int x: data)
System.out.print(x + " ");
System.out.println();
//System.out.print("Std: ");
// Print the reference decoding for comparison with the online output.
int[] dec = HmmAlgorithms.viterbiAlgorithm(model, data, true);
for (int x: dec)
System.out.print(x + " ");
System.out.println();
// Box the observations into a List<Integer>, the Iterable that onlineViterbi accepts.
List<Integer> shit = new ArrayList<Integer>();
for (int x: data)
shit.add(x);
// onlineViterbi prints its decoded segments itself; its return value is ignored.
HmmOnlineViterbi.onlineViterbi(model,
shit);
}
}

0 comments on commit 6b71a5e

Please sign in to comment.