# Exercise 1 - Naive Bayes Classification
(10 points)

Implement a Naive Bayes classifier by completing the two given classes. The `BayesianLearner` class acts as a builder for `BayesianClassifier` instances: the learner receives the set of classes when it is created, and its `learnExample` method is called once for each document of the training set. While processing the training examples, the learner should gather all statistics that the classifier needs.
After the learner has seen all training documents, the `createClassifier` method is called. It creates an instance of the `BayesianClassifier` class and initializes it with the statistics gathered before.
The classification itself is carried out by the `classify` method, which takes an unknown document and assigns it one of the classes learned before.
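
The builder flow described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the reference solution: the class names `SketchLearner` and `SketchClassifier`, the simple tokenizer, and the choice of add-one (Laplace) smoothing with summed log-probabilities are all made up for this example.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/** Hypothetical learner: gathers counts, then hands them to a classifier. */
class SketchLearner {
    private final Map<String, Integer> docsPerClass = new HashMap<>();
    private final Map<String, Map<String, Integer>> wordsPerClass = new HashMap<>();
    private final Set<String> vocabulary = new HashSet<>();
    private int totalDocs = 0;

    SketchLearner(Set<String> classes) {
        for (String c : classes) {
            docsPerClass.put(c, 0);
            wordsPerClass.put(c, new HashMap<>());
        }
    }

    /** Assumed tokenizer: lower-case the text and split on non-letters. */
    static String[] tokenize(String text) {
        return text.toLowerCase().split("[^a-z]+");
    }

    void learnExample(String clazz, String text) {
        totalDocs++;
        docsPerClass.merge(clazz, 1, Integer::sum);
        Map<String, Integer> counts = wordsPerClass.computeIfAbsent(clazz, k -> new HashMap<>());
        for (String token : tokenize(text)) {
            if (token.isEmpty()) continue; // split may yield a leading empty token
            vocabulary.add(token);
            counts.merge(token, 1, Integer::sum);
        }
    }

    SketchClassifier createClassifier() {
        return new SketchClassifier(docsPerClass, wordsPerClass, vocabulary.size(), totalDocs);
    }
}

/** Hypothetical classifier: picks the class with the highest log-score. */
class SketchClassifier {
    private final Map<String, Integer> docsPerClass;
    private final Map<String, Map<String, Integer>> wordsPerClass;
    private final int vocabularySize;
    private final int totalDocs;

    SketchClassifier(Map<String, Integer> docsPerClass, Map<String, Map<String, Integer>> wordsPerClass,
            int vocabularySize, int totalDocs) {
        this.docsPerClass = docsPerClass;
        this.wordsPerClass = wordsPerClass;
        this.vocabularySize = vocabularySize;
        this.totalDocs = totalDocs;
    }

    String classify(String text) {
        String best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (String clazz : docsPerClass.keySet()) {
            int docs = docsPerClass.get(clazz);
            if (docs == 0) continue; // class never seen during training
            Map<String, Integer> counts = wordsPerClass.get(clazz);
            int totalWords = 0;
            for (int n : counts.values()) totalWords += n;
            // log P(class) + sum of log P(word | class) with add-one smoothing
            double score = Math.log(docs / (double) totalDocs);
            for (String token : SketchLearner.tokenize(text)) {
                if (token.isEmpty()) continue;
                int n = counts.getOrDefault(token, 0);
                score += Math.log((n + 1.0) / (totalWords + vocabularySize));
            }
            if (score > bestScore) {
                bestScore = score;
                best = clazz;
            }
        }
        return best;
    }
}
```

Summing logarithms instead of multiplying probabilities keeps the scores in a range that plain `double` values can represent, which is one common way to avoid arbitrary-precision arithmetic here.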

#### Hints

- Please do not forget to preprocess your documents. What exactly the preprocessing does is up to you.
- The evaluation will measure the accuracy of your classifier.
- The evaluation in the hidden tests has three stages:
  1. Your solution will get 4 points as soon as it is better than the baselines. The baselines are:
     - For each class, a classifier that always returns this class.
     - A random guesser that returns a random class.
  2. If your solution has an accuracy >= 0.7, you will get 3 more points.
  3. If your solution has an accuracy >= 0.8, you will get 3 more points.
- You can download the [single-class-train.tsv](https://hobbitdata.informatik.uni-leipzig.de/teaching/SNLP/classification/single-class-train.tsv) file. It comprises one document per line. The first word is the class, followed by a tab character (`\t`). The remaining content of the line is the text of the document.
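
The file format described in the last hint can be parsed by splitting each line at its first tab only. The following is a small sketch, assuming the lines have already been read into memory; the class name `TsvExamples` and its methods are hypothetical and not part of the exercise skeleton.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper for the "class<TAB>text" line format. */
class TsvExamples {

    /**
     * Splits one line into {class, text} at the first tab.
     * Returns null if the line has no tab, an empty class, or an empty text.
     */
    static String[] parseLine(String line) {
        int tab = line.indexOf('\t');
        if (tab <= 0 || tab == line.length() - 1) {
            return null;
        }
        // Everything after the first tab is the document text, further tabs included.
        return new String[] { line.substring(0, tab), line.substring(tab + 1) };
    }

    /** Parses all lines and silently skips the malformed ones. */
    static List<String[]> parseLines(List<String> lines) {
        List<String[]> examples = new ArrayList<>();
        for (String line : lines) {
            String[] example = parseLine(line);
            if (example != null) {
                examples.add(example);
            }
        }
        return examples;
    }
}
```

Splitting only at the first tab keeps a document intact even if its text happens to contain further tab characters.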

#### Notes

- Do not add additional external libraries.
- Interface
  - You can use _[TAB]_ for autocompletion and _[SHIFT]_+_[TAB]_ for code inspection.
  - Use _Menu_ -> _View_ -> _Toggle Line Numbers_ for debugging.
  - Check _Menu_ -> _Help_ -> _Keyboard Shortcuts_.
- Finish
  - Save your solution by clicking on the _disk icon_.
  - Finally, choose _Menu_ -> _File_ -> _Close and Halt_.
  - Do not forget to _Submit_ your solution in the _Assignments_ view.

In [1]:

/**
 * Classifier implementing naive Bayes classification.
 */
public class BayesianClassifier {
    // YOUR CODE HERE
    Set<String> uniqueWords = null;
	HashMap<String, Integer> updatedClassCount = null;
	HashMap<String, BigDecimal> updatedClassProb = null;
	HashMap<String, HashMap<String, Integer>> UpdatedClassWordCount = null;
	int globalTotalCount = 0;

	HashMap<String, Double> classProbabilities = new HashMap<>();

	public final static String URL_REGEX = "((www\\.[^\\s]+)|(https?://[^\\s]+))";
	public final static String CONSECUTIVE_CHARS = "([a-z])\\1{1,}";
	public final static String STARTS_WITH_NUMBER = "[1-9]\\s*(\\w+)";
    
    
    BayesianClassifier(HashMap<String, BigDecimal> classProb, HashMap<String, HashMap<String, Integer>> classWordCount,
			HashMap<String, Integer> classCount, int totalCount, Set<String> uniqueWords) {
		this.updatedClassCount = classCount;
		this.updatedClassProb = classProb;
		this.UpdatedClassWordCount = classWordCount;
		this.globalTotalCount = totalCount;
		this.uniqueWords = uniqueWords;

	}

	public String preprocess(String text) {
		// Drop a non-zero digit together with the word characters that follow it.
		text = text.replaceAll(STARTS_WITH_NUMBER, "");
		// Drop everything except whitespace, word characters and apostrophes.
		text = text.replaceAll("[^\\s\\w']+", "");
		// Drop a fixed list of lower-case stop words.
		text = text.replaceAll("\\b(the|and|a|is|its|from|it|for|in|to|of|has|had|have|was|are|at)\\b", "");
		// Drop any remaining digits.
		text = text.replaceAll("[0-9]+", "");

		return text;
	}


    /**
     * Classifies the given document and returns the class name.
     */
    public String classify(String text) {
        String clazz = null;
        // YOUR CODE HERE
        
        text = preprocess(text);
		String[] words = text.toLowerCase().split("[^a-z0-9']+");
		HashMap<String, Integer> map = new HashMap<String, Integer>();

		int uniqueWordsSize = this.uniqueWords.size();

		for (String word : words) {

			this.uniqueWords.add(word);

			if (map.containsKey(word)) {
				int value = map.get(word);
				value++;
				map.put(word, value);
			} else {
				map.put(word, 1);
			}

		}

		// Despite its name, this holds the highest score seen so far.
		BigDecimal minimum = BigDecimal.ZERO;

		for (String key : this.updatedClassCount.keySet()) {
			BigDecimal finalProbability = BigDecimal.ONE;

			BigDecimal classProb = BigDecimal.ZERO;

			if (updatedClassProb.containsKey(key)) {
				// System.out.println("Class "+key+ " has probability
				// "+updatedClassProb.get(key));
				classProb = updatedClassProb.get(key);
			}

			HashMap<String, Integer> temp = UpdatedClassWordCount.get(key);
			// System.out.println("Temp "+temp);
			int totalWordsInHapMapForThatClass = 0;

			for (int value : temp.values()) {
				totalWordsInHapMapForThatClass += value;
			}

			for (String k : map.keySet()) {
				BigDecimal wordFrequency = BigDecimal.ZERO;

				if (temp.containsKey(k)) {
					wordFrequency = new BigDecimal(temp.get(k));
					// System.out.println(" word frequ for word "+k+" is
					// "+wordFreq);
					BigDecimal denominator = new BigDecimal(totalWordsInHapMapForThatClass)
							.add(new BigDecimal(uniqueWordsSize));
					BigDecimal probByClass = (wordFrequency.add(BigDecimal.ONE)).divide(denominator,
							MathContext.DECIMAL128);

					BigDecimal power = probByClass.pow(map.get(k),MathContext.DECIMAL128);
					finalProbability = finalProbability.multiply(power,MathContext.DECIMAL128);

				} else {

					BigDecimal denominator = (new BigDecimal(totalWordsInHapMapForThatClass))
							.add(new BigDecimal(uniqueWordsSize));
					BigDecimal probByClass = (BigDecimal.ONE).divide(denominator, MathContext.DECIMAL128);

					BigDecimal power = probByClass.pow(map.get(k),MathContext.DECIMAL128);
					finalProbability = finalProbability.multiply(power,MathContext.DECIMAL128);

				}

			}

			finalProbability = finalProbability.multiply(classProb,MathContext.DECIMAL128);

			if (finalProbability.compareTo(minimum) > 0) {
				clazz = key;
				minimum = finalProbability;
			}

		}

		// Fold the predicted class back into the class counts and priors.
		globalTotalCount++;

		if (updatedClassCount.containsKey(clazz)) {
			int value = updatedClassCount.get(clazz);
			value++;
			updatedClassCount.put(clazz, value);

			for (String key : updatedClassCount.keySet()) {
				BigDecimal value1 = (new BigDecimal(updatedClassCount.get(key)))
						.divide(new BigDecimal(globalTotalCount), MathContext.DECIMAL128);
				updatedClassProb.put(key, value1);
			}

		}

		if (UpdatedClassWordCount.containsKey(clazz)) {
			for (String key : map.keySet()) {
				if (UpdatedClassWordCount.get(clazz).containsKey(key)) {
					UpdatedClassWordCount.get(clazz).put(key, UpdatedClassWordCount.get(clazz).get(key) + map.get(key));
				} else {
					UpdatedClassWordCount.get(clazz).put(key, map.get(key));
				}
			}
		}

        
        
        return clazz;
    }
}

/**
 * Learner (or Builder) class for a naive Bayes classifier.
 */
public class BayesianLearner {
    // YOUR CODE HERE
    
    Set<String> uniqueWords = new HashSet<>();
	HashMap<String, Integer> classCount = new HashMap<String, Integer>();
	HashMap<String, BigDecimal> classProb = new HashMap<String, BigDecimal>();
	HashMap<String, HashMap<String, Integer>> classWordCount = new HashMap<String, HashMap<String, Integer>>();
	int totalCount = 0;

	public final static String URL_REGEX = "((www\\.[^\\s]+)|(https?://[^\\s]+))";
	public final static String CONSECUTIVE_CHARS = "([a-z])\\1{1,}";
	public final static String STARTS_WITH_NUMBER = "[1-9]\\s*(\\w+)";

    /**
     * Constructor taking the set of classes the classifier should be able to
     * distinguish.
     */
    public BayesianLearner(Set<String> classes) {
        // YOUR CODE HERE
        
        for (String c : classes) {
			this.classCount.put(c, 0);
			this.classProb.put(c, BigDecimal.ZERO);
		}
    }

    /**
     * The method used to learn the training examples. It takes the name of the
     * class as well as the text of the training document.
     */
    public void learnExample(String clazz, String text) {
        // YOUR CODE HERE
        totalCount++;
		// line = line.toLowerCase().replaceAll("[^a-z0-9\\s]","");

		text = preprocess(text);
		String[] words = text.toLowerCase().split("[^a-z0-9']+");

		// Fetch (or create) the word counts of this class and add the document's words.
		HashMap<String, Integer> wordCount = classWordCount.computeIfAbsent(clazz, k -> new HashMap<>());
		for (String c : words) {
			uniqueWords.add(c);
			wordCount.merge(c, 1, Integer::sum);
		}

		// Update the document count of the class; a class seen for the first time starts at 1.
		classCount.merge(clazz, 1, Integer::sum);

    }
    
    public String preprocess(String text) {
		// Drop a non-zero digit together with the word characters that follow it.
		text = text.replaceAll(STARTS_WITH_NUMBER, "");
		// Drop everything except whitespace, word characters and apostrophes.
		text = text.replaceAll("[^\\s\\w']+", "");
		// Drop a fixed list of lower-case stop words.
		text = text.replaceAll("\\b(the|and|a|is|its|from|it|for|in|to|of|has|had|have|was|are|at)\\b", "");
		// Drop any remaining digits.
		text = text.replaceAll("[0-9]+", "");

		return text;
	}

    /**
     * Creates a BayesianClassifier instance based on the statistics gathered from
     * the training example.
     */
    public BayesianClassifier createClassifier() {
        BayesianClassifier classifier = null;
        // YOUR CODE HERE
        
        for (String key : this.classCount.keySet()) {

			BigDecimal value = (new BigDecimal(classCount.get(key))).divide(new BigDecimal(totalCount),
					MathContext.DECIMAL128);
			this.classProb.put(key, value);
		}

		classifier = new BayesianClassifier(this.classProb, this.classWordCount, this.classCount, this.totalCount,
				this.uniqueWords);
        
        return classifier;
    }
}
// This line should make sure that compile errors are directly identified when executing this cell
// (the line itself does not produce any meaningful result)
new BayesianLearner(new HashSet<>(Arrays.asList("good","bad")));
System.out.println("compiled");

compiled
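
The `preprocess` methods in the cell above apply regular expressions to the raw text before it is lower-cased and split. A possible alternative, sketched here under assumptions, is to tokenize first and then filter the tokens against a stop-word set. The class name `SketchPreprocessor` is hypothetical, and the sketch is not behavior-identical to the code above: it lower-cases before filtering (so "The" is dropped as well) and it keeps words that follow a digit.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/** Hypothetical tokenizer that filters stop words with a set lookup. */
class SketchPreprocessor {

    // Same stop words as in the cell above.
    static final Set<String> STOP_WORDS = new HashSet<>(Arrays.asList(
            "the", "and", "a", "is", "its", "from", "it", "for", "in",
            "to", "of", "has", "had", "have", "was", "are", "at"));

    /** Lower-cases the text, splits on anything but letters and apostrophes, drops stop words. */
    static List<String> tokens(String text) {
        List<String> result = new ArrayList<>();
        for (String token : text.toLowerCase().split("[^a-z']+")) {
            if (token.isEmpty() || STOP_WORDS.contains(token)) {
                continue;
            }
            result.add(token);
        }
        return result;
    }
}
```

A set lookup per token avoids scanning the whole document once per stop word, and extending the stop-word list only requires adding an entry to `STOP_WORDS`.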


# Evaluation

- Run the following cell to test your implementation.
- You can ignore the cells afterwards.

In [2]:
%maven org.junit.jupiter:junit-jupiter-api:5.3.1
import org.junit.jupiter.api.Assertions;
import org.opentest4j.AssertionFailedError;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import java.util.Map.Entry;

/**
 * Simple method for reading classification examples from a file as a list of (class, text) pairs.
 */
public static List<String[]> readClassData(String filename) throws IOException {
    return FileUtils.readLines(new File(filename), "utf-8").stream().map(s -> s.split("\t"))
            .filter(s -> s.length > 1).collect(Collectors.toList());
}

public static void checkClassifier(List<String[]> trainingCorpus, List<String[]> evaluationCorpus,
        double minAccuracy) {
    try {
        System.out.print("Training corpus size: ");
        System.out.println(trainingCorpus.size());
        System.out.print("Eval. corpus size   : ");
        System.out.println(evaluationCorpus.size());
        // Determine the classes
        Set<String> classes = Arrays.asList(trainingCorpus, evaluationCorpus).stream().flatMap(l -> l.stream())
                .map(d -> d[0]).distinct().collect(Collectors.toSet());
        // Determine the number of instances per class in the evaluation set
        Map<String, Long> evalClassCounts = evaluationCorpus.stream()
                .collect(Collectors.groupingBy(d -> d[0], Collectors.counting()));
        for (String clazz : classes) {
            if (!evalClassCounts.containsKey(clazz)) {
                evalClassCounts.put(clazz, 0L);
            }
        }

        // Determine the expected accuracies of the baselines
        Map<String, Double> accForClassGuessers = new HashMap<>();
        for (Entry<String, Long> e : evalClassCounts.entrySet()) {
            accForClassGuessers.put(e.getKey(), e.getValue() / (double) evaluationCorpus.size());
        }
        double accRandomGuesser = 1.0 / accForClassGuessers.size();

        // Train the classifier
        long time1 = System.currentTimeMillis();
        BayesianLearner learner = new BayesianLearner(classes);
        for (String[] trainingExample : trainingCorpus) {
            learner.learnExample(trainingExample[0], trainingExample[1]);
        }
        BayesianClassifier classifier = learner.createClassifier();
        time1 = System.currentTimeMillis() - time1;
        System.out.println("Training took       : " + time1 + "ms");

        // Classify the evaluation corpus
        long time2 = System.currentTimeMillis();
        int tp = 0, errors = 0, id = 0;
        String result;
        List<String[]> fpDetails = new ArrayList<>();
        for (String[] evalExample : evaluationCorpus) {
            result = classifier.classify(evalExample[1]);
            if (evalExample[0].equals(result)) {
                ++tp;
            } else {
                ++errors;
                fpDetails.add(new String[] { Integer.toString(id), evalExample[0], result });
            }
            ++id;
        }
        time2 = System.currentTimeMillis() - time2;
        System.out.println("Classification took : " + time2 + "ms");
        double accuracy = tp / (double) (tp + errors);

        System.out.println("Baseline classifiers: ");
        for (Entry<String, Double> baseResult : accForClassGuessers.entrySet()) {
            System.out.println(String.format("Always %-13s: %-7.5f", baseResult.getKey(), baseResult.getValue()));
        }
        System.out.println(String.format("Random guesser      : %-7.5f", accRandomGuesser));
        System.out.println(String.format("Your solution       : %-7.5f (%d tp, %d errors)", accuracy, tp, errors));
        if (fpDetails.size() > 0) {
            System.out.println("  Wrong classifications are:");
            for (int i = 0; i < Math.min(fpDetails.size(), 20); ++i) {
                System.out.print("    id=");
                System.out.print(fpDetails.get(i)[0]);
                System.out.print(" expected=");
                System.out.print(fpDetails.get(i)[1]);
                System.out.print(" result=");
                System.out.println(fpDetails.get(i)[2]);
            }
            if (fpDetails.size() > 20) {
                System.out.println("    ...");
            }
        }

        // Make sure that the student's solution is better than all baselines
        for (Entry<String, Double> baseResult : accForClassGuessers.entrySet()) {
            if (baseResult.getValue() >= accuracy) {
                StringBuilder builder = new StringBuilder();
                builder.append("Your solution is not better than a classifier that always chooses the \"");
                builder.append(baseResult.getKey());
                builder.append("\" class.");
                Assertions.fail(builder.toString());
            }
        }
        if (accRandomGuesser >= accuracy) {
            Assertions.fail("Your solution is not better than a random guesser.");
        }
        if ((minAccuracy > 0) && (minAccuracy > accuracy)) {
            Assertions.fail("Your solution did not reach the expected accuracy of " + minAccuracy);
        }
        System.out.println("Test successfully completed.");
    } catch (AssertionFailedError e) {
        throw e;
    } catch (Throwable e) {
        System.err.println("Your solution caused an unexpected error:");
        throw e;
    }
}

System.out.println("---------- Simple example corpus ----------");
List<String[]> exampleCorpusTrain = Arrays.asList(
        new String[] {"chess", "white king, black rook, black queen, white pawn, black knight, white bishop." },
        new String[] {"history", "knight person granted honorary title knighthood" },
        new String[] {"history", "knight order eligibility, knighthood, head of state, king, prelate, middle ages." },
        new String[] {"chess", "Defense knight pawn opening game opponent." },
        new String[] {"literature", "Knights Round Table. King Arthur. literary cycle Matter of Britain."}
        );
List<String[]> exampleCorpusTest = Arrays.asList(
        new String[] {"history", "Knighthood Middle Ages." },
        new String[] {"chess", "player king knight opponent king checkmate game draw." },
        // document with unknown words
        new String[] {"literature", "britain king arthur. Sir Galahad." }
        );
checkClassifier(exampleCorpusTrain, exampleCorpusTest, 0);

System.out.println();
System.out.println("---------- Larger example corpus ----------");
List<String[]> classificationData =readClassData("/srv/distribution/single-class-train.tsv");
checkClassifier(classificationData.subList(0, 750), classificationData.subList(750, classificationData.size()), 0);

---------- Simple example corpus ----------
Training corpus size: 5
Eval. corpus size   : 3
Training took       : 12ms
Classification took : 8ms
Baseline classifiers: 
Always literature   : 0.33333
Always chess        : 0.33333
Always history      : 0.33333
Random guesser      : 0.33333
Your solution       : 1.00000 (3 tp, 0 errors)
Test successfully completed.

---------- Larger example corpus ----------
Training corpus size: 750
Eval. corpus size   : 260
Training took       : 1736ms
Classification took : 1146ms
Baseline classifiers: 
Always gold         : 0.03462
Always money-fx     : 0.19231
Always trade        : 0.16923
Always interest     : 0.11538
Always coffee       : 0.06154
Always money-supply : 0.10000
Always ship         : 0.05000
Always sugar        : 0.06538
Always crude        : 0.21154
Random guesser      : 0.11111
Your solution       : 0.85769 (223 tp, 37 errors)
  Wrong classifications are:
    id=5 expected=money-fx result=interest
    id=10 expected=money-supply resu

In [None]:
// Ignore this cell

In [None]:
// Ignore this cell

In [None]:
// Ignore this cell