# Brief

In the gradient-based bigram model, instead of explicitly counting bigram co-occurrences, we randomly initialize a weight matrix whose entries play the role of (log-)counts and let gradient descent fill them in automatically. Gradient descent minimizes the negative log-likelihood loss of the observed bigrams.

### Action plan:
1. Read the data and construct the inputs and outputs for the neural net, which will look something like the following:
    -   `names = ['juho', ...]`
    -   `Xs = ['<s>', 'j', 'u', 'h', 'o', ...]`
    -   `Ys = ['j', 'u', 'h', 'o', '<e>', ...]`
    -   then use `c2i` to map each character to an integer.
    -   The reason is that we want to feed the neural net the previous character and have it predict the next character.
2. Create a weight matrix `W` that plays the role of the probability matrix in the original bigram model.
    -   It will be a matrix of shape `V x V`, where `V` is the size of our character set.
    -   Once each row is normalized into a probability distribution, entry `[i, j]` can be interpreted as the probability that `i2c[j]` follows `i2c[i]`.
    -   To do so, we need to define:
        1. an appropriate shape for the matrix,
        2. appropriate functions to apply so that each row of the matrix becomes a probability distribution,
        3. and an appropriate loss function to compare the resulting probability distribution with the true one.
    -   The model output would be something like `[0.05 0.65 0.30]`, where the reference is `[0, 1, 0]`.
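As a rough illustration of step 1, the sketch below builds the input/target pairs for a tiny made-up list of names. The names, the vocabulary order, and the `toy_`-prefixed variables are assumptions for illustration only; the notebook itself takes `names` and `c2i` from `my_utils`, whose exact contents are not shown here.

```python
# Hypothetical toy data; the notebook's real names/c2i come from my_utils.
toy_names = ['juho', 'ana']
toy_vocab = ['<s>'] + sorted(set(''.join(toy_names))) + ['<e>']
toy_c2i = {c: i for i, c in enumerate(toy_vocab)}

xs_chars, ys_chars = [], []
for name in toy_names:
    tokens = ['<s>'] + list(name) + ['<e>']
    # Every (previous, next) pair becomes one training example
    for prev, nxt in zip(tokens, tokens[1:]):
        xs_chars.append(prev)
        ys_chars.append(nxt)

# Map characters to integer ids, as c2i does in the notebook
xs_ids = [toy_c2i[c] for c in xs_chars]
ys_ids = [toy_c2i[c] for c in ys_chars]

print(xs_chars)  # ['<s>', 'j', 'u', 'h', 'o', '<s>', 'a', 'n', 'a']
print(ys_chars)  # ['j', 'u', 'h', 'o', '<e>', 'a', 'n', 'a', '<e>']
print(xs_ids)    # [0, 3, 6, 2, 5, 0, 1, 4, 1]
```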

## 1. Prepare data

```python
from my_utils import names, chars, c2i, i2c
import torch
```

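```python
Xs, Ys = [], []

for name in names:
    # Pair each character with the one that follows it, padding with '<s>' and '<e>'
    bigrams = [(c2i[a], c2i[b]) for a, b in zip(['<s>'] + list(name), list(name) + ['<e>'])]

    Xs.extend([a for a, _ in bigrams])  # previous characters (inputs)
    Ys.extend([b for _, b in bigrams])  # next characters (targets)

Xs = torch.tensor(Xs)
Ys = torch.tensor(Ys)
```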

```python
Xs, Ys
```

```
(tensor([ 0,  2,  7,  ...,  5, 18, 43]),
 tensor([ 2,  7, 26,  ..., 18, 43, 46]))
```

```python
import torch.nn.functional as F

# One-hot encode the input characters so they can be multiplied with W
Xenc = F.one_hot(Xs, num_classes=len(c2i)).float()
```
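Multiplying a one-hot row vector by `W` simply picks out one row of `W`, which is why the bigram table can live directly in the weights. A small sketch on a made-up 5-character vocabulary (the size, seed, indices, and `W_toy` name are assumptions):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V = 5                                    # toy vocabulary size (assumption)
W_toy = torch.randn((V, V))
xs = torch.tensor([0, 3, 3, 1])          # toy "previous character" indices

xenc = F.one_hot(xs, num_classes=V).float()
via_matmul = xenc @ W_toy                # what the notebook computes
via_lookup = W_toy[xs]                   # row xs[k] of W_toy for every example k

print(torch.allclose(via_matmul, via_lookup))  # True
```

So `Xenc @ W` gives the same result as `W[Xs]`; the indexing form avoids materializing `Xenc` at all, which is cheaper for large vocabularies.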

## 2. Build the model

```python
W = torch.randn((len(c2i), len(c2i)))
logits = Xenc @ W

# The two lines below implement softmax,
# turning the logits into a probability distribution per row
probs = logits.exp()
probs /= probs.sum(dim=1, keepdim=True)
```

This gives, for every input character, a probability distribution over the next character. We then compare this distribution with the one-hot encoding of the actual next character using a loss function. Note that the gradient can be propagated through every operation applied here, since all of them are differentiable.
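To check that backpropagation through these operations gives the expected result, the sketch below compares autograd's gradient with the closed-form gradient of softmax followed by the mean negative log-likelihood, $\partial L / \partial\,\text{logits} = (\text{probs} - \text{onehot}(y)) / N$. The sizes, seed, random indices, and the name `W_toy` are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, N = 6, 10                                  # toy sizes (assumption)
xs = torch.randint(0, V, (N,))                # random "previous character" indices
ys = torch.randint(0, V, (N,))                # random "next character" indices
xenc = F.one_hot(xs, num_classes=V).float()

# Same forward pass as the notebook: matmul, exp, row-normalize, mean NLL
W_toy = torch.randn((V, V), requires_grad=True)
logits = xenc @ W_toy
counts = logits.exp()
probs = counts / counts.sum(dim=1, keepdim=True)
loss = -probs[torch.arange(N), ys].log().mean()
loss.backward()                               # autograd fills W_toy.grad

# Closed form: dL/dlogits = (probs - onehot(ys)) / N, then dL/dW = xenc^T @ dL/dlogits
with torch.no_grad():
    dlogits = (probs - F.one_hot(ys, num_classes=V).float()) / N
    dW = xenc.T @ dlogits

print(torch.allclose(W_toy.grad, dW, atol=1e-6))  # True
```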

The loss function used here is the cross-entropy loss:

$$\textbf L(\hat y, y) = -\sum_{c \in \textbf{C}} y_c \log{\hat y_c}$$

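For this task the target $y$ is one-hot, so only the term for the correct next character survives and the loss reduces to:

$$\textbf L(\hat y, y) = -\log{\hat y_c}$$

where $c$ is the index of the correct next character `i2c[c]` (so $y_c = 1$ and every other entry of $y$ is 0), and $\hat y_c$ is the predicted probability of that character. This is the negative log-likelihood of the observed bigram.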
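Plugging the example distribution from the action plan into both forms of the loss gives a quick sanity check. This is a small pure-Python sketch using `math.log` (the natural log, matching `.log()` in PyTorch); the values are the illustrative ones quoted earlier, not real model outputs.

```python
import math

y_hat = [0.05, 0.65, 0.30]   # illustrative predicted distribution
y = [0, 1, 0]                # one-hot reference: correct class c = 1

# Full cross-entropy: sum over all classes
full = -sum(yc * math.log(pc) for yc, pc in zip(y, y_hat))
# Reduced form: only the correct class contributes
reduced = -math.log(y_hat[y.index(1)])

print(round(full, 4), round(reduced, 4))  # 0.4308 0.4308
```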

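Then, to turn the per-example losses of a batch into a single scalar, we aggregate them with a function such as the sum or the mean. The training loop uses the mean over every bigram in the dataset, so each step is full-batch gradient descent.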
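A minimal sketch of the two aggregation choices on random made-up logits (the sizes and seed are assumptions). It also checks them against PyTorch's built-in `torch.nn.functional.cross_entropy`, which takes raw logits, applies log-softmax internally, and averages by default (`reduction='mean'`).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, N = 6, 8                                        # toy sizes (assumption)
logits = torch.randn((N, V))
ys = torch.randint(0, V, (N,))

probs = logits.softmax(dim=1)
per_example = -probs[torch.arange(N), ys].log()    # shape (N,): one loss per bigram

mean_loss = per_example.mean()
sum_loss = per_example.sum()

# F.cross_entropy works on logits directly; 'mean' is its default reduction
print(torch.allclose(mean_loss, F.cross_entropy(logits, ys), atol=1e-5))                   # True
print(torch.allclose(sum_loss, F.cross_entropy(logits, ys, reduction='sum'), atol=1e-5))  # True
```

In practice `F.cross_entropy(logits, Ys)` could replace the manual exp, normalize, and log steps; its fused log-softmax is also numerically safer when logits grow large.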

```python
# requires_grad=True makes autograd track gradients with respect to W
W = torch.randn((len(c2i), len(c2i)), requires_grad=True)
```

```python
for i in range(100):

    # Forward pass: logits -> softmax -> probabilities
    logits = Xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(dim=1, keepdim=True)

    # Per-example negative log-likelihood,
    # reduced to a scalar by averaging over all bigrams
    loss = -probs[torch.arange(len(Xenc)), Ys].log()
    reduced_loss = loss.mean()

    if i % 10 == 0:
        print(reduced_loss.item())

    # Backward pass
    W.grad = None
    reduced_loss.backward()

    # Gradient descent step with learning rate 20
    with torch.no_grad():
        W -= 20 * W.grad
```

```
2.5159099102020264
2.515619993209839
2.5153369903564453
2.5150606632232666
2.514791488647461
2.514528274536133
2.5142714977264404
2.5140209197998047
2.513775587081909
2.513536214828491
```
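Because each row of `W` only ever sees examples that share the same previous character, the best this model can do is reproduce the row-normalized bigram counts (that is, `W[i]` matching the log-counts up to a per-row constant). Its loss should therefore approach the count-based model's negative log-likelihood from above. The sketch below checks this on a made-up three-name corpus; the corpus, the step count, and the learning rate of 5 are illustrative assumptions, and everything is wrapped in a hypothetical helper `count_vs_gradient` so it does not overwrite the notebook's `W` or `c2i`.

```python
import torch
import torch.nn.functional as F

def count_vs_gradient(names, steps=3000, lr=5.0, seed=0):
    """Hypothetical helper: count-based bigram NLL vs. a gradient-trained W on the same data."""
    vocab = ['<s>'] + sorted(set(''.join(names))) + ['<e>']
    c2i = {c: i for i, c in enumerate(vocab)}
    V = len(vocab)

    xs, ys = [], []
    for name in names:
        tokens = ['<s>'] + list(name) + ['<e>']
        for prev, nxt in zip(tokens, tokens[1:]):
            xs.append(c2i[prev])
            ys.append(c2i[nxt])
    xs, ys = torch.tensor(xs), torch.tensor(ys)

    # Count-based model: accumulate bigram counts, then normalize each row
    counts = torch.zeros((V, V))
    counts.index_put_((xs, ys), torch.ones(len(xs)), accumulate=True)
    P = counts / counts.sum(dim=1, keepdim=True)  # the '<e>' row is 0/0 (nan) but never indexed
    count_nll = -P[xs, ys].log().mean().item()

    # Gradient-based model trained on the same bigrams
    torch.manual_seed(seed)
    W = torch.randn((V, V), requires_grad=True)
    xenc = F.one_hot(xs, num_classes=V).float()
    for _ in range(steps):
        loss = F.cross_entropy(xenc @ W, ys)      # mean NLL, same quantity as the notebook's loss
        W.grad = None
        loss.backward()
        with torch.no_grad():
            W -= lr * W.grad
    return count_nll, loss.item()

count_nll, gd_nll = count_vs_gradient(['ana', 'anna', 'ava'])
print(f"count-based NLL:    {count_nll:.4f}")  # 0.6137
print(f"gradient-based NLL: {gd_nll:.4f}")
```

On this toy corpus the gradient-trained loss should sit slightly above the count-based value and shrink toward it as `steps` grows.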


## 3. Try generating

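```python
with torch.no_grad():  # no gradients are needed for sampling
    for _ in range(5):
        tokens = [c2i['<s>']]

        while True:
            # One-hot encode the most recent character and run it through the model
            recent_token = torch.tensor([tokens[-1]])
            input_vector = F.one_hot(recent_token, num_classes=len(c2i)).float()
            logits_output = input_vector @ W
            counts_output = logits_output.exp()
            probs_output = counts_output / counts_output.sum(1, keepdim=True)

            # Sample the next character from the predicted distribution
            next_idx = torch.multinomial(probs_output, num_samples=1).item()
            tokens.append(next_idx)

            if i2c[next_idx] == '<e>':
                break

        # Drop the '<s>' and '<e>' markers before printing
        print(''.join(i2c[i] for i in tokens[1:-1]))
```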

```
kidren
kaimikebosvars
an
jatr
ary
```
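The generation loop multiplies a one-hot vector by `W`, which just selects one row. Below is a hedged sketch of the same sampler written with a direct row lookup and an optional `torch.Generator` for reproducible samples; the helper name `sample_name`, the `max_len` cap, and the toy vocabulary and `toy_W` are assumptions, not part of the notebook.

```python
import torch

def sample_name(W, c2i, i2c, generator=None, max_len=50):
    """Hypothetical helper: sample one name from a V x V weight matrix."""
    idx = c2i['<s>']
    out = []
    with torch.no_grad():
        for _ in range(max_len):                   # cap length in case '<e>' is never drawn
            probs = W[idx].softmax(dim=0)          # row lookup, same as one_hot(idx) @ W
            idx = torch.multinomial(probs, num_samples=1, generator=generator).item()
            if i2c[idx] == '<e>':
                break
            out.append(i2c[idx])
    return ''.join(out)

# Usage with a toy, untrained W (made-up vocabulary)
toy_vocab = ['<s>', 'a', 'b', '<e>']
toy_c2i = {c: i for i, c in enumerate(toy_vocab)}
toy_i2c = {i: c for c, i in toy_c2i.items()}
toy_W = torch.randn((len(toy_vocab), len(toy_vocab)))
g = torch.Generator().manual_seed(42)
samples = [sample_name(toy_W, toy_c2i, toy_i2c, generator=g) for _ in range(3)]
print(samples)
```

With the notebook's trained model, the equivalent call would be `sample_name(W, c2i, i2c)`; passing a freshly seeded generator makes repeated runs produce the same names.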
