# L14b: The Skip-Gram Embedding Model
In this lab, we'll look at the Skip-Gram model, which is a neural network model for learning word embeddings. This is the second text embedding model we'll cover in this course. 
* __Continuous Bag of Words (CBOW)__: This architecture predicts the target word based on its context words. It uses a shallow neural network to learn the embeddings of words in a given context. No positional information is used, and the model is trained to minimize the loss between the predicted and actual target word.
* __Skip-Gram__: A skip-gram model consists of a single hidden layer that transforms a one-hot encoded input word into a dense vector representation, optimizing the embedding so that words appearing in similar contexts have similar vector representations. Imagine you're reading a sentence and can guess the words that come before and after a particular word.

See section 2: [Rong, X. (2014). word2vec Parameter Learning Explained. ArXiv, abs/1411.2738.](https://arxiv.org/abs/1411.2738)

### Tasks
Before we start, execute the `Run All Cells` command to check if you (or your neighbor) have any code or setup issues. Code issues, then raise your hands - and let's get those fixed!
* __Task 1: Setup, Data, Prerequisites (10 min)__: In this task, we'll load a public dataset of headlines curated as either sarcastic or not sarcastic. Our dataset is available on [Kaggle](https://www.kaggle.com/datasets/rmisra/news-headlines-dataset-for-sarcasm-detection). After loading the data, we'll tokenize the data (convert text strings to numerical arrays).
* __Task 2: Build and Train a HiPPO-LegS model instance (15 min)__: In this task, we will build and train a HiPPO-S4-LegS model instance on the sample input sequence we selected above. We start by creating a model instance, and the we train this instance for different hidden state sizes.
* __Task 3: Does the S4 model generalize? (25 min)__: In this task, we'll explore how the S4-LegS model performs when we give input sequences that are _similar_ but not the same as the training data. We'll take the training data, perturb some words, and feed the perturbed sequence into the model.

Let's get started!
___

## Task 1: Setup, Data, Prerequisites
In this task, we'll set up the environment, load the data, and prepare it for training. We'll also install the required libraries and load the dataset. 

In [1]:
include("Include.jl")

Next, let's specify an example sentence, tokenize it and create a vocabulary. We'll also create a mapping from words to indices and vice versa. This will help us convert the text data into numerical arrays that can be fed into the model.

In [2]:
words, vocabulary, inverse_vocabulary = let 
    
    # initialize -
    vocabulary = Dict{String, Int}();
    inverse_vocabulary = Dict{Int, String}();

    # TDOD: specify a sample sentence -
    sample_sentence = "The quick brown fox jumps over the lazy dog"; # Classical pangram!

    # split -
    words = split(sample_sentence, " ") .|> lowercase |> unique; # no external ordering

    # build the vocabulary -
    for (i, word) in enumerate(words)
        vocabulary[word] = i;
        inverse_vocabulary[i] = word;
    end

    # return -
    words, vocabulary, inverse_vocabulary
end;

In [3]:
words

8-element Vector{String}:
 "the"
 "quick"
 "brown"
 "fox"
 "jumps"
 "over"
 "lazy"
 "dog"

__Constants__: Let's set up some constants for the model. These constants will be used throughout the example codes below. See the comments in the code for more details.

In [4]:
N = length(words); # size of the vocabulary
windowsize = 3; # size of the context window
number_of_epochs = 100; # number of epochs
number_digit_array = range(1, stop=N, step=1) |> collect; # list of numbers from 1 to N

__CBOW training datatset__: Fill me in.

In [5]:
cbow_training_dataset = let

    # initialize -
    training_dataset = Vector{Tuple{Vector{Float32}, OneHotVector{UInt32}}}();
    C = windowsize - 1; # number of context words

    # build the training data -
    for i ∈ 2:(N-1)
        
        targetword = words[i]; # target word
        contextwords = words[(i-1):(i+1)] |> v-> [v[1], v[3]] # context words
        
        # proces the target word -
        targetword_index = vocabulary[targetword]; # index of the target word
        y = onehot(targetword_index, number_digit_array); # one-hot encoding of the target word

        # process the context words -
        tmp = Array{Float32,2}(undef, N, C); # temporary array
        for (j,word) in enumerate(contextwords)
            contextword_index = vocabulary[word]; # index of the context word
            x = onehot(contextword_index, number_digit_array) .|> Float32; # one-hot encoding of the context word
            tmp[:, j] .= x; # store the context word
        end
        x = sum(tmp, dims=2) |> vec .|> Float32; # average of the context words
        
        # store the training data -
        push!(training_dataset, (x, y)); # store the training data
    end

    # return -
    training_dataset;
end;

__Skip Gram training datatset__: Fill me in.

In [None]:
skipgram_training_dataset = let

    # initialize -
    training_dataset = Vector{Tuple{Vector{Float32}, Vector{Float32}}}();
    C = windowsize - 1; # number of context words

    # build the training data -
    for i ∈ 2:(N-1)
        
        contextword = words[i]; # target word
        targetwords = words[(i-1):(i+1)] |> v-> [v[1], v[3]] # context words
        
        # proces the context word -
        contextword_index = vocabulary[targetword]; # index of the target word
        x = onehot(contextwork_index, number_digit_array) |> Float32; # one-hot encoding of the target word

        # process the targets words -
        tmp = Array{Float32,2}(undef, N, C); # temporary array
        for (j,word) in enumerate(contextwords)
            contextword_index = vocabulary[word]; # index of the context word
            x = onehot(contextword_index, number_digit_array) .|> Float32; # one-hot encoding of the context word
            tmp[:, j] .= x; # store the context word
        end
        x = sum(tmp, dims=2) |> vec .|> Float32; # average of the context words
        
        # store the training data -
        push!(training_dataset, (x, y)); # store the training data
    end

    # return -
    training_dataset;
end;

In [6]:
cbow_training_dataset[1]

(Float32[1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], Bool[0, 1, 0, 0, 0, 0, 0, 0])

## Task 2: Build and Train a CBOW model instance
In this task, we will build and train a CBOW model instance on the sample input sequence we selected above. We start by creating a model instance, and the we train this instance for a few epochs.

In [7]:
cbow_model = let

    # TODO: Uncomment the code below to build the model!
    Flux.@layer MyFluxNeuralNetworkModel  trainable=(input, hidden); # create a "namespaced" of sorts
    MyModel() = MyFluxNeuralNetworkModel( # a strange type of constructor
        Chain(
            input = Dense(N, windowsize, identity),  # layer 1
            hidden = Dense(windowsize, N, identity), # layer 2
            output = NNlib.softmax) # layer 3 (output layer)
    );
    cbow_model = MyModel().chain;
end

Chain(
  input = Dense(8 => 3),                [90m# 27 parameters[39m
  hidden = Dense(3 => 8),               [90m# 32 parameters[39m
  output = NNlib.softmax,
) [90m                  # Total: 4 arrays, [39m59 parameters, 444 bytes.

__Training__: Fill me in.

In [10]:
trained_cbow_model = let

    localmodel = cbow_model; # make a local copy of the model

    # setup the loss function -
    loss(ŷ, y) = Flux.Losses.logitcrossentropy(ŷ, y; agg = mean); # loss for training multiclass classifiers, what is the agg?

    # setup the optimizer
    λ = 0.64; # TODO: maybe change the learning rate (default: 0.61)?
    β = 0.10; # TODO: maybe change the momentum parameter (default: 0.10)?
    opt_state = Flux.setup(Momentum(λ,β), localmodel);

    # training loop -
    for i ∈ 1:number_of_epochs
        # train the model - check out the do block notion: https://docs.julialang.org/en/v1/base/base/#do
        Flux.train!(localmodel, cbow_training_dataset, opt_state) do m, x, y
            loss(m(x), y) # loss function
        end

        if (rem(i,10) == 0)
            @show "Epoch $i of $number_of_epochs completed" # print the epoch number
        end
    end

    # return the trained model -
    localmodel;
end

"Epoch $(i) of $(number_of_epochs) completed" = "Epoch 10 of 100 completed"
"Epoch $(i) of $(number_of_epochs) completed" = "Epoch 20 of 100 completed"
"Epoch $(i) of $(number_of_epochs) completed" = "Epoch 30 of 100 completed"
"Epoch $(i) of $(number_of_epochs) completed" = "Epoch 40 of 100 completed"
"Epoch $(i) of $(number_of_epochs) completed" = "Epoch 50 of 100 completed"
"Epoch $(i) of $(number_of_epochs) completed" = "Epoch 60 of 100 completed"
"Epoch $(i) of $(number_of_epochs) completed" = "Epoch 70 of 100 completed"
"Epoch $(i) of $(number_of_epochs) completed" = "Epoch 80 of 100 completed"
"Epoch $(i) of $(number_of_epochs) completed" = "Epoch 90 of 100 completed"
"Epoch $(i) of $(number_of_epochs) completed" = "Epoch 100 of 100 completed"


Chain(
  input = Dense(8 => 3),                [90m# 27 parameters[39m
  hidden = Dense(3 => 8),               [90m# 32 parameters[39m
  output = NNlib.softmax,
) [90m                  # Total: 4 arrays, [39m59 parameters, 444 bytes.

Let's give the model a few inputs and see what it predicts. If we give it the original context, it should return the original target word. 
* _What get's returned?_ The network will return $p(w_{i}|\mathbf{x})$, the probability of each word in the vocabulary being the target word. 

In [40]:
(x,y,word) = let
    
    x = cbow_training_dataset[1][1]; # first training data
    y = trained_cbow_model(x);
    word = y |> argmax |> i-> inverse_vocabulary[i]; # index of the word

    (x,y,word) # return the values
end;

In [41]:
x |> x-> findall(x-> x!= 0.0, x) .|> i-> inverse_vocabulary[i] # find the words in the context

2-element Vector{String}:
 "the"
 "brown"

In [42]:
word

"quick"

What happens if we give it a perturbed context? Does it still return the original target word?

In [44]:
(x,y,word) = let

    # Let's add two contexts together, and see what happens -
    x₁ = cbow_training_dataset[1][1]; # first context
    x₂ = cbow_training_dataset[4][1]; # second context
    x = (x₁ + x₂); # sum of the two contexts

    y = trained_cbow_model(x); # run the model with the sum of the two contexts
    word = y |> argmax |> i-> inverse_vocabulary[i]; # get the word from the index

    
    # return -
    x, y, word
end;

In [45]:
x |> x-> findall(x-> x!= 0.0, x) .|> i-> inverse_vocabulary[i] # find the words in the context

4-element Vector{String}:
 "the"
 "brown"
 "fox"
 "over"

In [46]:
word

"jumps"