## Exercise 2 - Bigram Matrix
(6 points)

You should finish the implementation of the given `BigramMatrix` class. Its constructor takes a list of tokens (the same tokens your `preprocess` method from the previous exercise should create) and generates a count matrix of bigrams.

Let $c(w_i,w_j)$ be the number of bigrams $(w_i, w_j)$, i.e., the number of times the word $w_i$ directly precedes the word $w_j$. Your bigram matrix should contain a matrix of counts where `counts[i][j]` contains the value of $c(w_i,w_j)$.

For the input tokens 
```
[<s>, she, said, i, know, that, she, likes, english, food, </s>]
```
the matrix looks like the following table (rows are $w_i$ and columns are $w_j$):

| $w_i$ \ $w_j$ | `<s>` | `</s>` | `english` | `food` | `i` | `know` | `likes` | `said` | `she` | `that` |
|---|---|---|---|---|---|---|---|---|---|---|
| `<s>` | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
| `</s>` | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| `english` | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| `food` | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| `i` | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
| `know` | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
| `likes` | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| `said` | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 |
| `she` | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 |
| `that` | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
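The table can be reproduced with a dense `int[][]` matrix indexed by a vocabulary map. This is a hypothetical stand-alone sketch (`DenseBigramCounts`, `buildIndex` and `countBigrams` are illustrative names, not part of the exercise). It orders the vocabulary with a `TreeSet`, so `</s>` sorts before `<s>`, unlike in the table. The row and column order does not affect the counts themselves.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.TreeSet;

/** Hypothetical helper that builds the dense count matrix counts[i][j] = c(w_i, w_j). */
public class DenseBigramCounts {
    /** Assigns each distinct token a row/column index, in sorted (TreeSet) order. */
    public static Map<String, Integer> buildIndex(List<String> tokens) {
        Map<String, Integer> index = new TreeMap<>();
        for (String word : new TreeSet<>(tokens)) {
            index.put(word, index.size());
        }
        return index;
    }

    /** counts[i][j] = number of positions t with tokens[t] = w_i and tokens[t + 1] = w_j. */
    public static int[][] countBigrams(List<String> tokens, Map<String, Integer> index) {
        int v = index.size();
        int[][] counts = new int[v][v];
        for (int t = 0; t + 1 < tokens.size(); t++) {
            counts[index.get(tokens.get(t))][index.get(tokens.get(t + 1))]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        List<String> tokens = Arrays.asList("<s>", "she", "said", "i", "know", "that",
                "she", "likes", "english", "food", "</s>");
        Map<String, Integer> index = buildIndex(tokens);
        int[][] counts = countBigrams(tokens, index);
        // Print the row of "she" together with the column labels.
        int she = index.get("she");
        for (Map.Entry<String, Integer> column : index.entrySet()) {
            System.out.println(column.getKey() + ": " + counts[she][column.getValue()]);
        }
    }
}
```

A sequence of $n$ tokens always yields $n - 1$ bigrams, so the entries of the matrix sum to $10$ in this example.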

* `create(List<String> tokens)` should initialize the matrix.
* `normalize()` should normalize the counts of the matrix (note that it must still be possible to apply Laplace smoothing after normalizing).
* `performLaplaceSmoothing()` should perform Laplace smoothing on the counts in the matrix (and on the normalized values, if the matrix has already been normalized).
* `getCount(String word1, String word2)` should return the (possibly smoothed) count for the bigram `(word1, word2)`.
* `getNormalizedCount(String word1, String word2)` should return the (possibly smoothed) normalized count for the bigram `(word1, word2)`, i.e., the probability $P(word2|word1)$.
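Both estimates can be derived from the raw counts without overwriting them. That is one way (an assumption, not the required design) to keep Laplace smoothing possible after normalizing. In this sketch, `normalize` divides each cell by its row sum $s_i = \sum_j c(w_i,w_j)$, and `laplace` computes $(c(w_i,w_j)+1)/(s_i+V)$ with $V$ the vocabulary size. The smoothed count itself is simply $c(w_i,w_j)+1$. The class name `BigramProbabilities` and the toy matrix are hypothetical.

```java
/** Hypothetical sketch: maximum-likelihood and Laplace estimates from a dense count matrix. */
public class BigramProbabilities {
    /** Row sum s_i = sum over j of c(w_i, w_j). */
    public static int rowSum(int[] row) {
        int sum = 0;
        for (int c : row) {
            sum += c;
        }
        return sum;
    }

    /** P(w_j | w_i) = c(w_i, w_j) / s_i; an all-zero row stays all zero. */
    public static double[][] normalize(int[][] counts) {
        double[][] p = new double[counts.length][];
        for (int i = 0; i < counts.length; i++) {
            int s = rowSum(counts[i]);
            p[i] = new double[counts[i].length];
            for (int j = 0; j < counts[i].length; j++) {
                p[i][j] = s == 0 ? 0.0 : (double) counts[i][j] / s;
            }
        }
        return p;
    }

    /** Laplace (add-one) estimate (c(w_i, w_j) + 1) / (s_i + V), with V = number of columns. */
    public static double[][] laplace(int[][] counts) {
        double[][] p = new double[counts.length][];
        for (int i = 0; i < counts.length; i++) {
            int v = counts[i].length;
            int s = rowSum(counts[i]);
            p[i] = new double[v];
            for (int j = 0; j < v; j++) {
                p[i][j] = (double) (counts[i][j] + 1) / (s + v);
            }
        }
        return p;
    }

    public static void main(String[] args) {
        // Toy 3 x 3 count matrix (not taken from the exercise).
        int[][] counts = { { 1, 1, 0 }, { 0, 0, 2 }, { 0, 0, 0 } };
        // The raw counts are never modified, so both estimates can be computed in any order.
        double[][] mle = normalize(counts);
        double[][] smoothed = laplace(counts);
        System.out.println(mle[0][0] + " " + smoothed[0][0] + " " + smoothed[2][1]);
    }
}
```

Each smoothed row still sums to $1$, because the $V$ added ones are matched by the $+V$ in the denominator.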

Additionally, you should handle unknown words (i.e., words that do not occur in the given text) in a meaningful way by implementing the following rules:
* If smoothing has not been applied and one of the words of a bigram is unknown, return $0$.
* If the matrix has been smoothed, `word1` is known, and `word2` was not part of the given list of tokens, treat `word2` as if it were part of the matrix: return $1$ as its count and $1/s$ as its normalized count, where $s$ is the sum of the (smoothed) row of `word1`, not including the $1$ added for `word2`.
* If the matrix has been smoothed and `word1` was not part of the given list of tokens, assume that its row was empty before smoothing and contains only $1$s after smoothing.
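The three rules can be checked against a small map-based sketch. It is a hypothetical stand-alone class (`UnknownWordRules`, built on nested `HashMap`s) rather than the notebook's `BigramMatrix`, and it assumes the tokens are already preprocessed (no lowercasing). It illustrates one observation: once smoothing is on, all three rules are instances of the single formula $(c+1)/(s+V)$. Here $c = 0$ for an unseen bigram and $s = 0$ for an unknown `word1`, provided that an unknown `word2` is not added to the vocabulary size $V$.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical sketch of the unknown-word rules; names are illustrative only. */
public class UnknownWordRules {
    private final Map<String, Map<String, Integer>> counts = new HashMap<>();
    private final Map<String, Integer> rowSums = new HashMap<>();
    private final Set<String> vocabulary = new HashSet<>();
    private boolean smoothed = false;

    public UnknownWordRules(List<String> tokens) {
        vocabulary.addAll(tokens);
        for (int t = 0; t + 1 < tokens.size(); t++) {
            String w1 = tokens.get(t);
            String w2 = tokens.get(t + 1);
            counts.computeIfAbsent(w1, k -> new HashMap<>()).merge(w2, 1, Integer::sum);
            rowSums.merge(w1, 1, Integer::sum); // s(w1) grows by one per outgoing bigram
        }
    }

    public void performLaplaceSmoothing() {
        smoothed = true; // raw counts are kept; smoothing is applied at lookup time
    }

    private int rawCount(String w1, String w2) {
        return counts.getOrDefault(w1, Map.of()).getOrDefault(w2, 0);
    }

    public double getCount(String w1, String w2) {
        // Unsmoothed: unknown words have count 0. Smoothed: every cell gets +1.
        return smoothed ? rawCount(w1, w2) + 1 : rawCount(w1, w2);
    }

    public double getNormalizedCount(String w1, String w2) {
        int s = rowSums.getOrDefault(w1, 0); // 0 for an unknown w1
        if (!smoothed) {
            return s == 0 ? 0.0 : (double) rawCount(w1, w2) / s;
        }
        int v = vocabulary.size(); // an unknown w2 is NOT added to V
        // (c + 1) / (s + V) covers every case: an unknown w2 gives 1 / (s + V),
        // an unknown w1 (s = 0, c = 0) gives 1 / V.
        return (double) (rawCount(w1, w2) + 1) / (s + v);
    }

    public static void main(String[] args) {
        List<String> tokens = Arrays.asList("<s>", "she", "said", "i", "know", "that",
                "she", "likes", "english", "food", "</s>");
        UnknownWordRules m = new UnknownWordRules(tokens);
        m.performLaplaceSmoothing();
        System.out.println(m.getNormalizedCount("she", "chips") + " " + m.getNormalizedCount("he", "said"));
    }
}
```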

Your implementation will be tested in 5 different scenarios:
* The matrix is created and the pure counts are checked.
* The matrix is normalized and the normalized counts are checked.
* The matrix is smoothed and the smoothed counts are checked.
* The matrix is normalized and smoothed and the normalized counts are checked.
* Unlike in the 4 cases above, the matrix is queried with words that do not occur in the text (again in all 4 previous scenarios).
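For orientation, several expected values in the tests follow directly from these definitions. In the first example, $V = 10$ and the row of `she` sums to $s = 2$. In the second example, $V = 19$ and the row of `london` sums to $s = 4$:

$$
\begin{aligned}
P(\text{said} \mid \text{she}) &= \frac{1}{2}, &
P_{\text{Laplace}}(\text{said} \mid \text{she}) &= \frac{1+1}{2+10} = \frac{1}{6},\\
P_{\text{Laplace}}(\text{food} \mid \text{english}) &= \frac{1+1}{1+10} = \frac{2}{11}, &
P_{\text{Laplace}}(\text{food} \mid \text{likes}) &= \frac{0+1}{1+10} = \frac{1}{11},\\
P_{\text{Laplace}}(\text{underground} \mid \text{london}) &= \frac{0+1}{4+19} = \frac{1}{23}, &
P_{\text{Laplace}}(\text{city} \mid \text{small}) &= \frac{1}{19}.
\end{aligned}
$$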

#### Notes

- Do not add additional external libraries.
- Interface
  - You can use _[TAB]_ for autocompletion and _[SHIFT]_+_[TAB]_ for code inspection.
  - Use _Menu_ -> _View_ -> _Toggle Line Numbers_ for debugging.
  - Check _Menu_ -> _Help_ -> _Keyboard Shortcuts_.
- Finish
  - Save your solution by clicking on the _disk icon_.
  - Finally, choose _Menu_ -> _File_ -> _Close and Halt_.
  - Do not forget to _Submit_ your solution in the _Assignments_ view.

In [1]:
import java.util.stream.Collectors;

/**
 * Simple implementation of a bi-gram matrix which can normalize itself and
 * apply Laplace smoothing to its inner counts.
 */
public class BigramMatrix {
    /**
     * Token used for the start of a sentence.
     */
    public static final String SENTENCE_START = "<s>";
    /**
     * Token used for the end of a sentence.
     */
    public static final String SENTENCE_END = "</s>";
    
    // YOUR CODE HERE
    
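    // Bigram counts c(w1, w2), keyed by "w1 w2" (the two tokens joined by one space).
    HashMap<String, Integer> bigramCount = new HashMap<String, Integer>();
    // Unigram counts; the key set doubles as the vocabulary (its size is V).
    HashMap<String, Integer> unigramCount = new HashMap<String, Integer>();
    // Never read by the getters below.
    HashMap<String, Integer> normalizedCount = new HashMap<String, Integer>();
    // Probabilities keyed by "w1 w2"; after smoothing, a key "w1" (no space) holds the
    // probability of any unseen bigram that starts with the known word w1.
    HashMap<String, Double> normalized = new HashMap<String, Double>();
    // true as long as Laplace smoothing has NOT been applied.
    boolean flag = true;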

    /**
     * Constructor.
     */
    public BigramMatrix(List<String> tokens) {
        create(tokens);
    }

    /**
     * Internal method for creating the bi-gram matrix from a given list of tokens.
     * 
     * @param tokens
     *            the tokens of an input text for which the matrix should be
     *            initialized.
     */
    protected void create(List<String> tokens) {
        // YOUR CODE HERE

        // Count each adjacent pair (tokens[i], tokens[i + 1]) under the key "w1 w2".
        // Both maps store trimmed, lowercased tokens.
        for (int i = 0; i + 1 < tokens.size(); i++) {
            String key = tokens.get(i).trim().toLowerCase() + " " + tokens.get(i + 1).trim().toLowerCase();
            bigramCount.merge(key, 1, Integer::sum);
        }

        // Count each token; the distinct tokens form the vocabulary.
        for (String token : tokens) {
            unigramCount.merge(token.trim().toLowerCase(), 1, Integer::sum);
        }
    }

    /**
     * Transforms the internal count matrix into a normalized counts matrix.
     */
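    public void normalize() {
        // YOUR CODE HERE

        // After smoothing, performLaplaceSmoothing() has already stored the smoothed
        // probabilities, so normalizing again must not overwrite them.
        if (!flag) {
            return;
        }
        // Row sums s(w1) = sum over w2 of c(w1, w2). These differ from the unigram count
        // for the word that ends the token list (e.g. the final "</s>").
        HashMap<String, Integer> rowSums = new HashMap<String, Integer>();
        for (Map.Entry<String, Integer> entry : bigramCount.entrySet()) {
            rowSums.merge(entry.getKey().split(" ")[0], entry.getValue(), Integer::sum);
        }
        // P(w2 | w1) = c(w1, w2) / s(w1); unseen bigrams are simply absent (probability 0).
        for (Map.Entry<String, Integer> entry : bigramCount.entrySet()) {
            String word1 = entry.getKey().split(" ")[0];
            normalized.put(entry.getKey(), (double) entry.getValue() / rowSums.get(word1));
        }
    }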

    /**
     * Performs the Laplace smoothing on the bi-gram matrix.
     */
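    public void performLaplaceSmoothing() {
        // YOUR CODE HERE

        int vocabularySize = unigramCount.size();
        // Row sums s(w1) = sum over w2 of c(w1, w2), taken from the bigram counts.
        HashMap<String, Integer> rowSums = new HashMap<String, Integer>();
        for (Map.Entry<String, Integer> entry : bigramCount.entrySet()) {
            rowSums.merge(entry.getKey().split(" ")[0], entry.getValue(), Integer::sum);
        }
        // Seen bigrams: P(w2 | w1) = (c(w1, w2) + 1) / (s(w1) + V).
        for (Map.Entry<String, Integer> entry : bigramCount.entrySet()) {
            String word1 = entry.getKey().split(" ")[0];
            normalized.put(entry.getKey(),
                    (double) (entry.getValue() + 1) / (rowSums.get(word1) + vocabularySize));
        }
        // Unseen bigrams starting with a known word w1: 1 / (s(w1) + V), stored under the key w1.
        for (String word1 : unigramCount.keySet()) {
            normalized.put(word1, 1.0 / (rowSums.getOrDefault(word1, 0) + vocabularySize));
        }
        flag = false;
    }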

    /**
     * Returns the count of the bi-gram matrix for the bi-gram (word1, word2).
     */
    public double getCount(String word1, String word2) {
        // YOUR CODE HERE

        // Keys are built as in create(): trimmed tokens, lowercased, joined by one space.
        String key = word1.trim().toLowerCase() + " " + word2.trim().toLowerCase();
        int count = bigramCount.getOrDefault(key, 0);
        // Without smoothing, unseen bigrams and unknown words count 0;
        // with Laplace smoothing, every cell (including those of unknown words) gets +1.
        return flag ? count : count + 1;
    }

    /**
     * Returns the normalized count of the bi-gram matrix for the bi-gram (word1, word2) (i.e., P(word2 | word1)).
     */
    public double getNormalizedCount(String word1, String word2) {
        // YOUR CODE HERE

        String w1 = word1.trim().toLowerCase();
        String key = w1 + " " + word2.trim().toLowerCase();

        if (flag) {
            // Not smoothed: stored P(w2 | w1), or 0 for unseen bigrams and unknown words.
            return normalized.getOrDefault(key, 0.0);
        }
        if (normalized.containsKey(key)) {
            // Seen bigram: (c(w1, w2) + 1) / (s(w1) + V).
            return normalized.get(key);
        }
        if (unigramCount.containsKey(w1)) {
            // Known w1, unseen bigram (w2 may be unknown): the value that
            // performLaplaceSmoothing() stored under the key w1.
            return normalized.get(w1);
        }
        // Unknown w1: its smoothed row contains only 1s, so P(w2 | w1) = 1 / V.
        return 1.0 / unigramCount.size();
    }
}

// This line should make sure that compile errors are directly identified when executing this cell
// (the line itself does not produce any meaningful result)
(new BigramMatrix(Arrays.asList("a", "b"))).normalize();

# Evaluation

- Run the following cell to test your implementation.
- You can ignore the cells afterwards.

In [2]:
%maven org.junit.jupiter:junit-jupiter-api:5.3.1
import org.junit.jupiter.api.Assertions;
import org.opentest4j.AssertionFailedError;

public static final double DELTA = 0.000001;

public static void checkMatrix(BigramMatrix matrix, String[][] testCases, double[] expectedValues,
        boolean checkNormalizedCounts) throws Exception {
    try {
        double value, diff;
        for (int i = 0; i < testCases.length; i++) {
            value = checkNormalizedCounts ? matrix.getNormalizedCount(testCases[i][0], testCases[i][1])
                    : matrix.getCount(testCases[i][0], testCases[i][1]);
            diff = Math.abs(value - expectedValues[i]);
            Assertions.assertTrue(diff < DELTA, "Your solution returned "
                    + (checkNormalizedCounts ? ("P(\"" + testCases[i][1] + "\"|\"" + testCases[i][0] + "\")=")
                            : ("c(\"" + testCases[i][0] + "\",\"" + testCases[i][1] + "\")="))
                    + value + " while " + expectedValues[i] + " has been expected.");
        }
        System.out.println("Test(s) successfully completed.");
    } catch (AssertionFailedError e) {
        throw e;
    } catch (RuntimeException e) {
        System.err.println("Your solution caused an unexpected error:");
        throw e;
    }
}

System.out.println("----- 1st example -----");
List<String> tokens = Arrays.asList("<s>", "she", "said", "i", "know", "that", "she", "likes",
                                        "english", "food", "</s>");
BigramMatrix m = new BigramMatrix(tokens);

System.out.print("Check counts: ");
checkMatrix(m,
        new String[][] { { "she", "said" }, { "english", "food" }, { "likes", "food" } },
        new double[] { 1, 1, 0 }, false);

m.normalize();
System.out.print("Check normalized counts: ");
checkMatrix(m,
        new String[][] { { "she", "said" }, { "english", "food" }, { "likes", "food" } },
        new double[] { 0.5, 1, 0 }, true);

m.performLaplaceSmoothing();
System.out.print("Check smoothed counts: ");
checkMatrix(m,
        new String[][] { { "she", "said" }, { "english", "food" }, { "likes", "food" } },
        new double[] { 2, 2, 1 }, false);

System.out.print("Check normalized, smoothed counts: ");
checkMatrix(m,
        new String[][] { { "she", "said" }, { "english", "food" }, { "likes", "food" } },
        new double[] { 1.0 / 6.0, 2.0 / 11.0, 1.0 / 11.0 }, true);

System.out.println("----- 2nd example -----");
// Apply the solution to a longer example
tokens = Arrays.asList("<s>", "london", "is", "the", "capital", "and", "largest", "city", "of",
                      "england", "</s>", "<s>", "million", "people", "live", "in", "london",
                      "</s>", "<s>", "the", "river", "thames", "is", "in", "london", "</s>",
                      "<s>", "london", "is", "the", "largest", "city", "in", "western",
                      "europe", "</s>");
m = new BigramMatrix(tokens);

System.out.print("Check counts: ");
checkMatrix(m, new String[][] { { "london", "</s>" }, { "largest", "city" }, { "river", "thames" },
        { "city", "river" } }, new double[] { 2, 2, 1, 0 }, false);

m.normalize();
System.out.print("Check normalized counts: ");
checkMatrix(m, new String[][] { { "london", "</s>" }, { "largest", "city" }, { "river", "thames" },
        { "city", "river" } }, new double[] { 0.5, 1.0, 1.0, 0 }, true);

m.performLaplaceSmoothing();
System.out.print("Check smoothed counts: ");
checkMatrix(m, new String[][] { { "london", "</s>" }, { "largest", "city" }, { "river", "thames" },
        { "city", "river" } }, new double[] { 3, 3, 2, 1 }, false);

System.out.print("Check normalized, smoothed counts: ");
checkMatrix(m, new String[][] { { "london", "</s>" }, { "largest", "city" }, { "river", "thames" },
        { "city", "river" } }, new double[] { 3.0 / 23.0, 1.0 / 7.0 , 0.1, 1.0 / 21 }, true);

System.out.println("----- Test with unknown words -----");
m = new BigramMatrix(tokens); // set matrix back
// Check unknown words
System.out.print("Check counts: ");
checkMatrix(m, new String[][] { { "london", "underground" }, { "small", "city" }, { "sky", "scraper" } }, 
            new double[] { 0, 0, 0 }, false);

m.normalize();
System.out.print("Check normalized counts: ");
checkMatrix(m, new String[][] { { "london", "underground" }, { "small", "city" }, { "sky", "scraper" } }, 
            new double[] { 0, 0, 0 }, true);

m.performLaplaceSmoothing();
System.out.print("Check smoothed counts: ");
checkMatrix(m, new String[][] { { "london", "underground" }, { "small", "city" }, { "sky", "scraper" } }, 
            new double[] { 1, 1, 1 }, false);

System.out.print("Check normalized, smoothed counts: ");
checkMatrix(m, new String[][] { { "london", "underground" }, { "small", "city" }, { "sky", "scraper" } }, 
            new double[] { 1.0 / 23.0, 1.0 / 19.0, 1.0 / 19.0 }, true);

----- 1st example -----
Check counts: Test(s) successfully completed.
Check normalized counts: Test(s) successfully completed.
Check smoothed counts: Test(s) successfully completed.
Check normalized, smoothed counts: Test(s) successfully completed.
----- 2nd example -----
Check counts: Test(s) successfully completed.
Check normalized counts: Test(s) successfully completed.
Check smoothed counts: Test(s) successfully completed.
Check normalized, smoothed counts: Test(s) successfully completed.
----- Test with unknown words -----
Check counts: Test(s) successfully completed.
Check normalized counts: Test(s) successfully completed.
Check smoothed counts: Test(s) successfully completed.
Check normalized, smoothed counts: Test(s) successfully completed.


In [None]:
// Ignore this cell

In [None]:
// Ignore this cell

In [None]:
// Ignore this cell

In [None]:
// Ignore this cell

In [None]:
// Ignore this cell