Skip to content

Commit

Permalink
Tree serialization implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
sbos committed Aug 15, 2011
1 parent b537c9d commit c241203
Showing 1 changed file with 145 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -1,40 +1,35 @@
package org.apache.mahout.classifier.sequencelearning.hmm;

import com.google.common.base.Function;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.sequencelearning.hmm.mapreduce.HiddenStateProbabilitiesWritable;
import org.apache.mahout.math.DenseVector;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

public class HmmOnlineViterbi {
public class HmmOnlineViterbi implements Writable {
private HmmModel model;
private double[] probs;
private Node[] leaves;
private Node root;
private Tree tree;
private LinkedList<int[]> backpointers;
private int i;
private double lastLikelihood;
private Function<int[], Void> output;

public HmmOnlineViterbi(HmmModel model) {
this.model = model;
probs = null;
tree = new Tree();
leaves = new Node[model.getNrOfHiddenStates()];
root = null;
for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
Node node = new Node();
node.position = 0;
node.state = i;
tree.addLast(node);
leaves[i] = node;
}

backpointers = new LinkedList<int[]>();
i = 0;
clear();
}

public HmmOnlineViterbi(HmmModel model, Function<int[], Void> output) {
Expand All @@ -46,6 +41,14 @@ public void setOutput(Function<int[], Void> output) {
this.output = output;
}

public HmmModel getModel() {
return model;
}

public double getLastLogLikelihood() {
return lastLikelihood;
}

public void process(Iterable<Integer> observations) {
Iterator<Integer> iterator = observations.iterator();

Expand Down Expand Up @@ -77,13 +80,13 @@ public void process(Iterable<Integer> observations) {
node.state = k;
node.setParent(leaves[maxState]);
newLeaves[k] = node;
//tree.addLast(node);
}
backpointers.add(optimalStates);

tree.compress();
Node newRoot = tree.getRoot();
if (root != newRoot && newRoot != null) {
lastLikelihood = newProbs[newRoot.state];
traceback(i - newRoot.position - 1, newRoot.state, false);
leaves = newLeaves;
root = newRoot;
Expand All @@ -96,7 +99,26 @@ public void process(Iterable<Integer> observations) {
}
}

public void finish() {
private void clear() {
probs = null;
tree = new Tree();
leaves = new Node[model.getNrOfHiddenStates()];
root = null;
for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
Node node = new Node();
node.position = 0;
node.state = i;
tree.addLast(node);
leaves[i] = node;
}

backpointers = new LinkedList<int[]>();
i = 0;

lastLikelihood = -Double.MAX_VALUE;
}

public double finish() {
int maxState = 0;
for (int k = 1; k < model.getNrOfHiddenStates(); ++k) {
if (probs[k] > probs[maxState])
Expand All @@ -106,6 +128,10 @@ public void finish() {
if (backpointers.size() > 0) {
traceback(backpointers.size(), maxState, true);
}

double result = probs[maxState];
clear();
return result;
}

private void traceback(int i, int state, boolean last) {
Expand All @@ -123,6 +149,69 @@ private void traceback(int i, int state, boolean last) {
output.apply(result);
}

@Override
public void write(DataOutput output) throws IOException {
HiddenStateProbabilitiesWritable probs = new HiddenStateProbabilitiesWritable(this.probs);
probs.write(output);

output.write(backpointers.size());
for (int[] optimalStates: backpointers) {
for (int state: optimalStates) {
output.write(state);
}
}

output.write(i);
DoubleWritable doubleWritable = new DoubleWritable(lastLikelihood);
doubleWritable.write(output);

//serializing compressed backpointers tree
//3 * |H| - 2 is the upper bound of node number
HashBiMap<Node, Integer> nodeMap = HashBiMap.create(3 * model.getNrOfHiddenStates() - 2);
root.write(nodeMap, output);
for (Node leave: leaves)
leave.write(nodeMap, output);

output.write(tree.size);
Node node = tree.first;
while (node != null) {
node.write(nodeMap, output);
node = node.next;
}
}

@Override
public void readFields(DataInput input) throws IOException {
clear();

HiddenStateProbabilitiesWritable probs = new HiddenStateProbabilitiesWritable();
probs.readFields(input);
this.probs = probs.toProbabilityArray();

int backpointersSize = input.readInt();
for (int i = 0; i < backpointersSize; ++i) {
int[] optimalStates = new int[model.getNrOfHiddenStates()];
for (int j = 0; j < optimalStates.length; ++j)
optimalStates[j] = input.readInt();
backpointers.addLast(optimalStates);
}

i = input.readInt();
DoubleWritable doubleWritable = new DoubleWritable();
doubleWritable.readFields(input);
lastLikelihood = doubleWritable.get();

BiMap<Node, Integer> nodeMap = HashBiMap.create(3 * model.getNrOfHiddenStates() - 2);
root = Node.read(nodeMap, input);

for (int i = 0; i < leaves.length; ++i)
leaves[i] = Node.read(nodeMap, input);

tree.size = input.readInt();
for (int i = 0; i < tree.size; ++i)
tree.addLast(Node.read(nodeMap, input));
}

static class Node {
public int position, state;
public Node parent;
Expand All @@ -138,6 +227,38 @@ public Node() {
childNumber = 0;
}

public int write(BiMap<Node, Integer> map, DataOutput output) throws IOException {
Integer index = map.get(this);
if (index == null) {
output.writeBoolean(true);
index = map.put(this, map.size());
output.write(index);
output.write(position);
output.write(state);
output.write(childNumber);
parent.write(map, output);
} else {
output.writeBoolean(false);
output.write(index);
}

return index;
}

public static Node read(BiMap<Node, Integer> map, DataInput input) throws IOException {
boolean first = input.readBoolean();
int index = input.readInt();
if (first) {
Node node = new Node();
node.position = input.readInt();
node.state = input.readInt();
node.childNumber = input.readInt();
node.parent = map.inverse().get(input.readInt());
return node;
}
return map.inverse().get(index);
}

public void setParent(Node parent) {
if (this.parent != null)
--this.parent.childNumber;
Expand All @@ -148,17 +269,17 @@ public void setParent(Node parent) {
}

static class Tree {
Node first, last;
Node first;
int size;

public Tree() {
first = last = null;
first = null;
size = 0;
}

public void addLast(Node node) {
if (last != null) last.next = node;
if (first == null) first = node;
node.previous = last;
last = node;
++size;
}

public void remove(Node node) {
Expand All @@ -169,8 +290,8 @@ public void remove(Node node) {

if (first == node)
first = node.next;
if (last == node)
last = last.previous;

--size;
}

public Node getRoot() {
Expand Down

0 comments on commit c241203

Please sign in to comment.