Notes
* The model does well and learns quickly when training on single characters
* It does a decent job with multiple characters but isn't perfect
    * It doesn't always distinguish between double letters (treats "aa" as "a")
    * The letter e also seems to give it trouble. It was predicting it at a much
      later position than I'd expect.

Things to try
* Normalization
* Train on longer strings
* Increase weight of BETWEEN a bit so it doesn't get drowned out by double letters

# Energy-Based models and structured prediction

In this assignment we're going to work with structured prediction. Structured prediction broadly refers to any problem involving predicting structured values, as opposed to plain scalars. Examples of structured outputs include graphs and text.

We're going to work with text. The task is to transcribe a word from an image. The difficulty here is that different words have different lengths, so we can't just have a fixed number of outputs.

## Dataset
As always, the first thing to do is to implement the dataset. We're going to create a dataset that generates images of random words. We'll also include some augmentations, such as jitter (moving the characters horizontally).

In [None]:
! mkdir fonts
! curl --output fonts/font.zip https://www.fontsquirrel.com/fonts/download/Anonymous
! unzip -n fonts/font.zip -d fonts

In [None]:
from PIL import ImageDraw, ImageFont
import string
import random
import torch
import torchvision
from torchvision import transforms
from PIL import Image # PIL is a library to process images
from matplotlib import pyplot as plt

CHAR_WIDTH = 18
CHAR_HEIGHT = 32

simple_transforms = transforms.Compose([
                                    transforms.ToTensor(), 
                                ])

class SimpleWordsDataset(torch.utils.data.IterableDataset):

    def __init__(self, max_length, len=100, jitter=False, noise=False):
        self.max_length = max_length
        self.transforms = transforms.ToTensor()
        self.len = len
        self.jitter = jitter
        self.noise = noise
  
    def __len__(self):
        return self.len

    def __iter__(self):
        for _ in range(self.len):
            text = ''.join([random.choice(string.ascii_lowercase) for i in range(self.max_length)])
            img = self.draw_text(text, jitter=self.jitter, noise=self.noise)
            yield img, text

    def draw_text(self, text, length=None, jitter=False, noise=False):
        if length is None:
            length = CHAR_WIDTH * len(text)
        img = Image.new('L', (length, CHAR_HEIGHT))
        fnt = ImageFont.truetype("fonts/Anonymous.ttf", 20)

        d = ImageDraw.Draw(img)
        # With jitter, shift the text right by a random number of pixels
        pos = (random.randint(0, 7), 5) if jitter else (0, 5)
        d.text(pos, text, fill=1, font=fnt)

        img = self.transforms(img)
        img[img > 0] = 1 

        if noise:
            img += torch.bernoulli(torch.ones_like(img) * 0.1)
            img = img.clamp(0, 1)


        return img[0]

sds = SimpleWordsDataset(1, jitter=True, noise=False)
img = next(iter(sds))[0]
print(img.shape)
plt.imshow(img)

We can look at what the entire alphabet looks like in this dataset.

In [None]:
fig, ax = plt.subplots(3, 9, figsize=(12, 6), dpi=200)

max_letter_width = 0
for i, c in enumerate(string.ascii_lowercase):
    row = i // 9
    col = i % 9
    letter_img = sds.draw_text(c)
    col_maxes = letter_img.max(axis=0).values
    assert col_maxes.numel() == CHAR_WIDTH
    letter_width = col_maxes.nonzero()[-1] + 1
    max_letter_width = max(max_letter_width, letter_width.item())
    ax[row][col].imshow(letter_img)
    ax[row][col].axis('off')
ax[2][8].axis('off')

plt.show()

We can also put the entire alphabet in one image.

In [None]:
alphabet = sds.draw_text(string.ascii_lowercase, 14*26)
plt.figure(dpi=200)
plt.imshow(alphabet)
plt.axis('off')

## Model definition
Before we define the model, we define the size of our alphabet. Our alphabet consists of lowercase English letters, and additionally a special character used for space between symbols or before and after the word. For the first part of this assignment, we don't need that extra character.

Our end goal is to learn to transcribe words of arbitrary length. However, first, we pre-train our simple convolutional neural net to recognize single characters. In order to be able to use the same model for one character and for entire words, we are going to design the model in a way that makes sure that the output size for one character (or when the input image size is 32x18) is 1x27, and Kx27 whenever the input image is wider. K here will depend on the particular architecture of the network, and is affected by strides and poolings, among other things.
A little bit more formally, our model $f_\theta$, for an input image $x$ gives output energies $l = f_\theta(x)$. If $x \in \mathbb{R}^{32 \times 18}$, then $l \in \mathbb{R}^{1 \times 27}$.
If $x \in \mathbb{R}^{32 \times 100}$ for example, our model may output $l \in \mathbb{R}^{10 \times 27}$, where $l_i$ corresponds to a particular window in $x$, for example from $x_{0, 9i}$ to $x_{32, 9i + 18}$ (again, this will depend on the particular architecture).

Below is a drawing that explains the sliding window concept. We use the same neural net with the same weights to get $l_1, l_2, l_3$; the only difference is the receptive field. $l_1$ looks at the leftmost part, at character 'c', $l_2$ looks at 'a', and $l_3$ looks at 't'. The receptive fields may or may not overlap, depending on how you design your convolutions.

![cat.png](https://i.imgur.com/JByfyKh.png)
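
To get a feel for how $K$ depends on the architecture, the output width can be worked out by hand with the standard output-size formula for convolutions and poolings, $\lfloor (n + 2p - k) / s \rfloor + 1$. The sketch below is illustrative only: the helper names `out_size` and `output_width` are made up here, and the layer hyperparameters are assumed to match the `SimpleNet` definition in the following cells.

```python
def out_size(n, kernel, stride, padding):
    """Output length of a conv/pool layer along one dimension."""
    return (n + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) along the width axis, assumed to match SimpleNet
WIDTH_LAYERS = [
    (3, 1, 1),  # Conv2d(kernel_size=3, stride=1, padding=1)
    (5, 2, 2),  # MaxPool2d(kernel_size=5, stride=2, padding=2)
    (5, 2, 2),  # Conv2d(kernel_size=5, stride=2, padding=2)
    (1, 1, 0),  # MaxPool2d(kernel_size=(8, 1)): the width is untouched
    (1, 1, 0),  # Conv2d(kernel_size=1)
]


def output_width(image_width):
    """Number of window positions produced for an image of the given width."""
    w = image_width
    for kernel, stride, padding in WIDTH_LAYERS:
        w = out_size(w, kernel, stride, padding)
    return w


for width in (18, 36, 14 * 26):
    print(width, "->", output_width(width))
```

Under these assumptions an 18-pixel-wide image yields 5 window positions rather than 1, which is why an extra pooling over the width is still needed in the single-character case.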

In [None]:
# constants for number of classes in total, and for the special extra character for empty space
ALPHABET_SIZE = 27
BETWEEN = 26

In [None]:
from torch import nn

class SimpleNet(torch.nn.Module):   
    def __init__(self):
        """
        The output width varies with the width of the input. Calling
        forward() with single_char=True pools it down to size 1 (the last
        dimension).
        """
        super().__init__()
        # TODO: try normalization
        self.cnn_block = torch.nn.Sequential(
            # (batch_size, num_channels, img_height, img_width)
            # (B, 1, 32, 18 | 10)
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=10,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            # (B, 10, 32, 18 | 10)
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(
                kernel_size=5,
                stride=2,
                padding=2,
            ),
            # (B, 10, 16, 9 | 5)
            torch.nn.Conv2d(
                in_channels=10,
                out_channels=30,
                kernel_size=5,
                stride=2,
                padding=2,
            ),
            # (B, 30, 8, 5 | 3)
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(
                kernel_size=(8, 1),
            ),
            # (B, 30, 1, 5 | 3)
            # Last layer is another convolution rather than a linear layer
            # because the width isn't known.
            torch.nn.Conv2d(
                in_channels=30,
                out_channels=ALPHABET_SIZE,
                kernel_size=1,
            )
            # (B, alphabet_size, 1, 5 | 3)
        )

    def forward(self, x, single_char: bool = False, verbose: bool = False):
        if verbose: print(f"{x.shape=}")
        # Insert new dimension for the number of channels
        x = x.unsqueeze(dim=1)
        """
        if not x.ndim == 3:
            raise ValueError(
                "Input must have shape "
                "(batch_size, img_height, img_width). "
                f"Actual shape: {x.shape}."
            )
        batch_size = x.size(0)
        # Insert new dimension for the number of channels
        x = x.unsqueeze(dim=1)
        # Break original input into sub-images that act like different elements
        # of a single batch. This is equivalent to convolving with a kernel of
        # all 1s. Would that be more efficient?
        # FIXME: Is this valid?
        sub_images = tuple(
            x[b, :, :, w:w + self.window_width]
            for b in range(batch_size)
            for w in range(max(x.size(3) - self.window_width + 1, 1))
        )
        if verbose: print(
            f"{len(sub_images)=}",
            f"{sub_images[0].shape=}",
            f"{sub_images[-1].shape=}"
        )
        # Make sure each sub-image is the same size by padding with 0s
        sub_images = torch.nn.utils.rnn.pad_sequence(
            sub_images, batch_first=True
        )
        if verbose:
            print(f"{sub_images.shape=}")
            for b in sub_images:
                plt.imshow(sub_images[b])
        z0 = self.cnn_block(sub_images)
        # Resize to be (batch_size, alphabet_size, width)
        z1 = torch.stack(z0.split(batch_size, dim=0), dim=0)
        if verbose: print(f"{z1.shape=}")
        return z1.permute(1, 2, 0)
        """
        z0 = self.cnn_block(x)
        if verbose: print(f"{z0.shape=}")
        if single_char:
            z0 = nn.functional.adaptive_max_pool2d(z0, (1, 1))
            if verbose: print(f"pooled shape={z0.shape}")
        # z0 has shape (batch_size, alphabet_size, height, width)
        # Squeeze out the height since it's not useful
        z1 = z0.squeeze(2)
        if verbose: print(f"{z1.shape=}")
        return z1

Let's initialize the model and apply it to the alphabet image:

In [None]:
model = SimpleNet()
with torch.no_grad():
    alphabet_energies = model(alphabet.unsqueeze(0), verbose=True)

def plot_energies(ce):
    fig=plt.figure(dpi=200)
    ax = plt.axes()
    im = ax.imshow(ce.cpu())
    
    ax.set_xlabel('window locations →')
    ax.set_ylabel('← classes')
    ax.xaxis.set_label_position('top') 
    ax.set_xticks([])
    ax.set_yticks([])
    
    cax = fig.add_axes([ax.get_position().x1+0.01,ax.get_position().y0,0.02,ax.get_position().height])
    plt.colorbar(im, cax=cax) 
    
plot_energies(alphabet_energies[0].detach())

In [None]:
letter = sds.draw_text("a").unsqueeze(0)
_ = model(letter, single_char=False, verbose=True)
print()
_ = model(letter, single_char=True, verbose=True)

So far we only see random outputs, because the classifier is untrained.

## Train with one character

Now we train the model we've created on a dataset where images contain only single characters. Note the changed cross_entropy function.
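
To see why flipping the sign is enough, here is a small plain-Python sketch (no PyTorch; `soft_argmin` is a name chosen for this illustration). Applying a softmax to the negated energies gives the lowest energy the highest probability.

```python
import math


def soft_argmin(energies):
    """Softmax of the negated energies: low energy -> high probability."""
    # Subtracting the minimum energy keeps the exponentials numerically stable
    lowest = min(energies)
    weights = [math.exp(-(e - lowest)) for e in energies]
    total = sum(weights)
    return [w / total for w in weights]


energies = [2.0, 0.5, 3.0]
probs = soft_argmin(energies)
print(probs)
# The most probable class is the one with the lowest energy
print(probs.index(max(probs)) == energies.index(min(energies)))
```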

In [None]:
def cross_entropy(energies, *args, **kwargs):
    """ We use energies, and therefore we need to use log soft arg min instead
        of log soft arg max. To do that we just multiply energies by -1. """
    return nn.functional.cross_entropy(-1 * energies, *args, **kwargs)

def simple_collate_fn(samples):
    images, annotations = zip(*samples)
    images = list(images)
    annotations = list(annotations)
    annotations = list(map(lambda c : torch.tensor(ord(c) - ord('a')), annotations))
    # This code was provided with the notebook, but this section could be
    # simplified by using torch.nn.utils.rnn.pad_sequence().
    m_width = max(CHAR_WIDTH, max([i.shape[1] for i in images]))
    for i in range(len(images)):
        images[i] = torch.nn.functional.pad(images[i], (0, m_width - images[i].shape[-1]))

    if len(images) == 1:
        return images[0].unsqueeze(0), torch.stack(annotations)
    else:
        return torch.stack(images), torch.stack(annotations)

In [None]:
BATCH_SIZE = 100
NUM_WORKERS = 0
DATASET_LENGTH = 10000

sds = SimpleWordsDataset(1, len=DATASET_LENGTH, jitter=True, noise=False)
dataloader = torch.utils.data.DataLoader(
    sds,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=simple_collate_fn,
)

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = SimpleNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
accuracies = []
for i, (images, labels) in enumerate(dataloader):
    images = images.to(device)
    labels = labels.unsqueeze(dim=1).to(device)
    optimizer.zero_grad()
    energies = model(images, single_char=True)
    loss = cross_entropy(energies, labels)
    losses.append(loss.item())
    predictions = energies.argmin(dim=1)
    correct = (predictions == labels).sum()
    # Store a plain Python float so plotting also works when training on GPU
    accuracies.append(correct.item() / BATCH_SIZE)
    loss.backward()
    optimizer.step()

fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True)
axes[0].plot(losses)
axes[1].plot(accuracies)
axes[0].set_ylabel("Training Loss")
axes[1].set_ylabel("Training Accuracy")
axes[1].set_xlabel("Batch Iteration")

In [None]:
def get_accuracy(model, dataset):
    cnt = 0
    for i, l in dataset:
        with torch.no_grad():
            energies = model(i.unsqueeze(0).to(device), single_char=True)
        x = energies.argmin(dim=1)[0, 0]
        cnt += int(x == (ord(l[0]) - ord('a')))
    return cnt / len(dataset)

tds = SimpleWordsDataset(1, len=1000)
accuracy = get_accuracy(model, tds)
print(f"Accuracy: {accuracy:.3%}")
assert accuracy == 1.0, 'Your model doesn\'t achieve 100% accuracy for 1 character'

Now, to see how our model would work with more than one character, we apply the model to a bigger input - the image of the alphabet we saw earlier. We extract the energies for each window and show them.

In [None]:
with torch.no_grad():
    alphabet_energies_post_train = model(alphabet.to(device).view(1, *alphabet.shape))
print(f"{alphabet_energies_post_train.shape=}")
plot_energies(alphabet_energies_post_train[0])

Explain any classes that are lit up. What is still missing to be able to use it for transcription of words?

Answer: The final class (`BETWEEN` == 26, where `"a"` through `"z"` are 0 through 25) has a high energy across the entire image. That class corresponds to "no valid character". None of the training examples included it, so the model learned to always assign it a high energy. We need to train on multicharacter examples which include class 26 so that the model better learns how to divide a string into characters.

## Training with multiple characters

Now, we want to train our model to not only recognize the letters, but also to recognize space in-between so that we can use it for transcription later.

This is where complications begin. When transcribing a word from an image, we don't know beforehand how long the word is going to be. We can use the convolutional neural network we've pretrained on single characters to predict character probabilities at every window position of the new input image, but we don't know beforehand how to match those predictions with the target label. Training with an incorrect matching can lead to a wrong model, so in order to train a network to transcribe words, we need a way to find these pairings.

![dl.png](https://i.imgur.com/7pnodfV.png)

The importance of pairings can be demonstrated by the drawing above. If we map $l_1, l_2, l_3, l_4$ to 'c', 'a', 't', '_' respectively, we'll correctly train the system, but if we put $l_1, l_2, l_3, l_4$ with 'a', 'a', 't', 't', we'd have a very wrong classifier.

To formalize this, we use the framework of energy-based models. Let's define the energy $E(x, y, z)$ of a particular pairing $z$ between the energies our model outputs for the input image $x$ and the text transcription $y$. The pairing is a function $z : \{1, 2, \dots, \vert l \vert \} \to \{1, 2, \dots, \vert y \vert \}$, where $l = f_\theta(x)$ is the energy output of our convolutional neural net. $z$ maps each energy vector in $l$ to an element of the output sequence $y$. We want the mappings to make sense, so $z$ should be a non-decreasing function, $z(i) \leq z(i+1)$, and it shouldn't skip characters, i.e. $\forall_i \exists_j \; z(j)=i$.

The energy is then $E(x, y, z) = C(z) + \sum_{i=1}^{\vert l \vert} l_i[z(i)]$, where $C(z)$ is an extra term that allows us to penalize certain pairings, and $l_i[z(i)]$ is the energy of the $z(i)$-th symbol at position $i$.

In this particular context, we define $C(z)$ to be infinity for impossible pairings:
$$C(z) = \begin{cases}
\infty \; \text{if} \; z(1) \neq 1 \vee z(\vert l \vert) \neq \vert y \vert \vee \exists_{i, 1 \leq i \leq \vert l \vert - 1} \; z(i) > z(i+1) \vee z(i) < z(i+1) - 1\\
0 \; \text{otherwise}
\end{cases}
$$


Then, the free energy is $F(x, y) = \min_z E(x, y, z)$. In other words, the free energy is the energy of the best pairing between the energies provided by our model and the target labels.
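
For intuition, the free energy of a tiny example can be found by brute force: enumerate every pairing that $C(z)$ allows and keep the one with the lowest energy. The sketch below is plain Python with 0-based indices; `valid_pairings` and `free_energy` are names made up for this illustration, and `pm[i][k]` is assumed to hold the energy of target symbol $k$ at window position $i$. The search is exponential in $\vert l \vert$, so it is only useful as a reference for small inputs.

```python
from itertools import product


def valid_pairings(L, T):
    """Yield every z with z[0] == 0, z[-1] == T - 1 and steps of 0 or 1."""
    for steps in product((0, 1), repeat=L - 1):
        if sum(steps) == T - 1:
            z = [0]
            for step in steps:
                z.append(z[-1] + step)
            yield tuple(z)


def free_energy(pm):
    """Return (minimum energy, best pairing) by exhaustive search."""
    L, T = len(pm), len(pm[0])
    return min(
        (sum(pm[i][z[i]] for i in range(L)), z) for z in valid_pairings(L, T)
    )


# 5 window positions (rows), 4 target symbols (columns)
pm = [
    [0, 100, 100, 100],
    [100, 0, 100, 0],
    [100, 100, 0, 100],
    [100, 0, 100, 0],
    [100, 0, 100, 0],
]
print(free_energy(pm))
```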

When training, we are going to use cross-entropies along the best path: $\ell(x, y, z) = \sum_{i=1}^{\vert l \vert}H(y_{z(i)}, \sigma(l_i))$, where $H$ is the cross-entropy and $\sigma$ is the soft-argmin needed to convert energies to a distribution.

First, let's write functions that would calculate the needed cross entropies $H(y_{z(i)}, \sigma(l_i))$, and energies for us.

In [None]:
def build_path_matrix(energies, targets):
    # inputs: 
    #    energies, shape is BATCH_SIZE x ALPHABET_SIZE x L
    #    targets, shape is BATCH_SIZE x T, elements in range [0, ALPHABET_SIZE)
    # L is |l|, i.e. the width of the image in pixels (minus the window size)
    # T is |y|, i.e. the number of symbols in the target
    # 
    # outputs:
    #    a matrix of shape BATCH_SIZE x L x T
    #    where output[i, j, k] = energies[i, targets[i, k], j]
    #
    # Note: you're not allowed to use for loops. The calculation has to be vectorized.
    # you may want to use repeat and repeat_interleave.
    B, A, L = energies.shape
    # assert A == ALPHABET_SIZE, f"{energies.size(2)} != {ALPHABET_SIZE}"
    _, T = targets.shape
    i, j, k = torch.meshgrid(
        torch.arange(B), torch.arange(L), torch.arange(T), indexing="ij"
    )
    output = energies[i, targets[i, k], j]
    return output


def build_ce_matrix(energies, targets, weight=None, verbose=False):
    # inputs: 
    #    energies, shape is BATCH_SIZE x ALPHABET_SIZE x L
    #    targets, shape is BATCH_SIZE x T
    # L is |l|
    # T is |y|
    # 
    # outputs:
    #    a matrix ce of shape BATCH_SIZE x L x T
    #    where ce[i, j, k] = cross_entropy(energies[i, :, j], targets[i, k])
    #
    # Note: you're not allowed to use for loops. The calculation has to be vectorized.
    # you may want to use repeat and repeat_interleave.
    B, A, L = energies.shape
    # assert A == ALPHABET_SIZE, f"{energies.size(2)} != {ALPHABET_SIZE}"
    _, T = targets.shape
    i, j, k = torch.meshgrid(
        torch.arange(B), torch.arange(L), torch.arange(T), indexing="ij"
    )
    assert i.shape == (B, L, T), f"{i.shape=} != {(B, L, T)}"
    if verbose: print(f"{energies.shape=}")
    if verbose: print(f"{i.shape=}")
    if verbose: print(f"{energies[i, :, j].shape=}")
    # if verbose: print(f"{energies=}")
    # if verbose: print(f"{energies[i, :, j]=}")
    if verbose: print(f"{targets.shape=}")
    if verbose: print(f"{targets[i, k].shape=}")
    output = cross_entropy(
        energies[i, :, j].permute(0, 3, 1, 2),
        targets[i, k],
        weight=weight,
        reduction="none",
    )
    # TODO: clean this up
    if verbose: print(f"{output.shape=}")
    output_naive = torch.full_like(output, torch.nan)
    for i in range(B):
        for j in range(L):
            for k in range(T):
                output_naive[i, j, k] = cross_entropy(
                    energies[i, :, j], targets[i, k], weight=weight, reduction="none",
                )
    if verbose: print(f"{output=}")
    if verbose: print(f"{output_naive=}")
    assert torch.allclose(output, output_naive, atol=1e-6), (
        f"{output=} != {output_naive=}"
    )
    
    return output


In [None]:
# Unit tests for build_path_matrix() and build_ce_matrix()
Bt = 2
Lt = 4
Tt = 5
At = 3
test_energies = torch.randn((Bt, At, Lt), requires_grad=False)
test_targets = torch.tensor([
    [0, 1, 2, 1, 0],
    [2, 1, 0, 1, 2],
])
test_pm = build_path_matrix(test_energies, test_targets)
assert test_pm.shape == (Bt, Lt, Tt), f"{test_pm.shape=} != {(Bt, Lt, Tt)}"
for b in range(test_targets.size(0)):
    for t in range(test_targets.size(1)):
        assert torch.allclose(
            test_pm[b, :, t], test_energies[b, test_targets[b,t], :]
        ), f"{test_pm[b, :, t]=} != {test_energies[b, test_targets[b,t], :]=}"

test_energies = torch.tensor([[
    [0.0, 100.0, 100.0, 100.0, 100.0],
    [100.0, 100.0, 0.0, 100.0, 100.0],
    [100.0, 0.0, 100.0, 0.0, 0.0],
]])
print(f"{test_energies.shape=}")
test_targets = torch.tensor([[0, 2, 1, 2]])
test_cem = build_ce_matrix(test_energies, test_targets, verbose=True)
print(f"{test_cem.shape=}\n{test_cem=}")
for b in range(test_cem.size(0)):
    for l in range(test_cem.size(1)):
        for t in range(test_cem.size(2)):
            expected_loss = cross_entropy(
                test_energies[b, :, l],
                test_targets[b, t],
            )
            assert torch.allclose(
                test_cem[b, l, t], expected_loss, atol=1e-6,
            ), f"{test_cem[b, l, t]=} != {expected_loss=}"

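# NOTE: find_path() and plot_pm() are defined in later cells, so those cells
# must be run before this part of the test.
test_pm = build_path_matrix(test_energies, test_targets)[0]
test_path_energy, test_path, _ = find_path(test_pm)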
test_path_tensor = torch.tensor(test_path)
print(f"{test_path_tensor.shape=}\n{test_path_tensor=}")
print(f"{test_path_energy=}")
plot_pm(test_pm.T, test_path)
test_path_loss = test_cem[0, test_path_tensor[:,0], test_path_tensor[:,1]].sum()
print(f"{test_path_loss=}")

In [None]:
B, L, T = 16, 300, 25
dummy_energies = torch.ones((B, ALPHABET_SIZE, L))
dummy_targets = torch.ones((B, T), dtype=torch.int64)
print(f"{dummy_energies.shape=}, {dummy_targets.shape=}")
o1 = build_path_matrix(dummy_energies, dummy_targets)
print(f"{o1.shape=}")
assert o1.shape == (B, L, T), f"{o1.shape=} != {(B, L, T)}"

o2 = build_ce_matrix(dummy_energies, dummy_targets)
print(f"{o2.shape=}")
assert o2.shape == (B, L, T), F"{o2.shape=} != {(B, L, T)}"

Another thing we will need is a transformation for our label $y$. We don't want to use it as is; we want to insert a special label after each character, so that, for example, `cat` becomes `c_a_t_`. This extra `_` models the separation between characters, allowing our model to distinguish between the strings `aa` and `a` in its output. This is then used in inference: we can just get the most likely character for each position from $l = f_\theta(x)$ (for example `aa_bb_ccc_`), then remove duplicate characters (`a_b_c_`), and then remove `_` (`abc`).
Let's implement a function that would change the string in this manner, and then map all characters to values from 0 to 26, with 0 to 25 corresponding to a-z, and 26 corresponding to _:
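
The inference step described above (collapse repeated characters, then drop `_`) can be sketched in a few lines of plain Python. The function name `decode` is hypothetical, and the input is assumed to already be the string of most likely characters, one per window position.

```python
from itertools import groupby


def decode(best_chars):
    """Collapse runs of repeated characters, then drop the '_' separator."""
    collapsed = "".join(char for char, _ in groupby(best_chars))
    return collapsed.replace("_", "")


print(decode("aa_bb_ccc_"))  # abc
print(decode("aa_a_"))       # aa
```

The second call shows why the separator matters: without the `_` between them, the two `a` runs would be merged into a single character.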

In [None]:
def char_to_num(char: str) -> int:
    if char == "_":
        return 26
    return ord(char) - ord("a")


def num_to_char(num: int) -> str:
    char = chr(num + ord("a"))
    if char <= "z":
        return char
    return "_"


def transform_word(s):
    # input: a string
    # output: a tensor of shape 2*len(s)
    # Append "_" after every character; the empty string maps to an empty tensor
    underscored = "".join(char + "_" for char in s)
    return torch.tensor(
        [char_to_num(char) for char in underscored], dtype=torch.int64
    )

    
# Unit tests
assert char_to_num("a") == 0, f"{char_to_num('a')} != 0"
assert char_to_num("b") == 1, f"{char_to_num('b')} != 1"
assert char_to_num("z") == 25, f"{char_to_num('z')} != 25"
assert char_to_num("_") == 26, f"{char_to_num('_')} != 26"
assert num_to_char(0) == "a", f"{num_to_char(0)} != 'a'"
assert num_to_char(1) == "b", f"{num_to_char(1)} != 'b'"
assert num_to_char(25) == "z", f"{num_to_char(25)} != 'z'"
assert num_to_char(26) == "_", f"{num_to_char(26)} != '_'"
assert (transform_word("") == torch.tensor([])).all()
assert (transform_word("a") == torch.tensor([0, 26])).all()
assert (transform_word("abc") == torch.tensor([0, 26, 1, 26, 2, 26])).all()

In [None]:
def plot_pm(energies, path=None):
    fig=plt.figure(dpi=200)
    ax = plt.axes()
    # matplotlib needs the data on the CPU
    im = ax.imshow(energies.cpu())
    
    ax.set_xlabel('window locations →')
    ax.set_ylabel('← label characters')
    ax.xaxis.set_label_position('top') 
    ax.set_xticks([])
    ax.set_yticks([])
    
    if path is not None:
        for i in range(len(path) - 1):
            ax.plot(*path[i], *path[i+1], marker = 'o', markersize=0.5, linewidth=10, color='r', alpha=1)

    cax = fig.add_axes([ax.get_position().x1+0.01,ax.get_position().y0,0.02,ax.get_position().height])
    plt.colorbar(im, cax=cax) 

with torch.no_grad():
    energies = model(alphabet.to(device).unsqueeze(dim=0))
targets = transform_word(string.ascii_lowercase).unsqueeze(0)

pm = build_path_matrix(energies, targets)
plot_pm(energies[0].detach())

What do you see? What does the model classify correctly, and what does it have problems with?

Answer: The heatmap is identical to the plot above since we reran the model on the same input (the full alphabet image) without doing any retraining in between.

Searching for a good pairing $z$ is the same as searching for a trajectory with a small sum of its values in this `pm` matrix. Where does the trajectory start, and where does it end? What other properties does the trajectory have? Can you see where an optimal trajectory would be passing through in the plot above?

Answer: The ideal path for this input image (the full alphabet) would start in the top-left at a character label of `a` (0) for window location 0. It would generally trend down and to the right, ending up at a label of `z` (25) for the largest window location. Between characters, the path would drop down to a character label of `_` (26) to represent the fact that the window isn't aligned with a character at those locations.

Now let's implement a function that would tell us the energy of a particular path (i.e. pairing).

Recall that the energy is $E(x, y, z) = C(z) + \sum_{i=1}^{\vert l \vert} l_i[z(i)]$, where $C(z)$ is an extra term that allows us to penalize certain pairings, and $l_i[z(i)]$ is the energy of the $z(i)$-th symbol at position $i$.

In this particular context, we define $C(z)$ to be infinity for impossible pairings:
$$C(z) = \begin{cases}
\infty \; \text{if} \; z(1) \neq 1 \vee z(\vert l \vert) \neq \vert y \vert \vee \exists_{i, 1 \leq i \leq \vert l \vert - 1} \; z(i) > z(i+1) \vee z(i) < z(i+1) - 1\\
0 \; \text{otherwise}
\end{cases}
$$

In [None]:
def path_energy(pm, path, verbose=False):
    # inputs:
    #   pm - a matrix of energies 
    #    L - energies length
    #    T - targets length
    #   path - list of length L that maps each energy vector to an element in T
    # returns:
    #   energy - sum of energies on the path, or 2**30 if the mapping is invalid
    L, T = pm.shape
    assert len(path) == L, f"{len(path)} != {L}"
    INVALID = torch.tensor(2**30)
    if path[0] != 0:
        if verbose: print(f"Starts at {path[0]=} != 0")
        return INVALID
    if path[-1] != T - 1:
        if verbose: print(f"Ends at {path[-1]=} != {T - 1}")
        return INVALID
    if any(path[i] > path[i+1] for i in range(L - 1)):
        if verbose: print(f"Decreases")
        return INVALID
    if any(path[i] < path[i+1] - 1 for i in range(L - 1)):
        if verbose: print(f"Skips")
        return INVALID
    if verbose: print("Valid path")
    w = pm[torch.arange(L), path]
    return w.sum()


In [None]:
path = torch.zeros(energies.shape[2] - 1)
path[:targets.shape[1] - 1] = 1
path = [0] + list(map(lambda x : x.int().item(), path[torch.randperm(path.shape[0])].cumsum(dim=-1)))
points = list(zip(range(energies.shape[2]), path))

# TODO: why is this striped, alternating between high and low values every row?
# It's because the BETWEEN character comes after each letter and has a very high
# energy.
plot_pm(pm[0].T.detach(), points)
energy = path_energy(pm[0], path, verbose=True)
print(f"energy is {energy.item()}")

Now, generate two paths with the worst possible energy, print their energies and plot them.

In [None]:
# The worst possible energy is the one for an invalid path,
# such as one that's not monotonic.
t = torch.arange(energies.size(2))
jump_path = torch.zeros_like(t)
jump_path[energies.size(2) // 2:] = targets.size(1) - 1
jump_points = list(zip(t, jump_path))
plot_pm(pm[0].T.detach(), jump_points)
jump_energy = path_energy(pm[0], jump_path, verbose=True).item()
print(f"Jump path energy: {jump_energy}")

sin_path = targets.size(1) * (torch.sin(2 * torch.pi * t / energies.size(2)) / 2 + 0.5)
sin_points = list(zip(t, sin_path))
plot_pm(pm[0].T.detach(), sin_points)
sin_energy = path_energy(pm[0], sin_path, verbose=True).item()
print(f"Sine wave energy: {sin_energy}")

### Optimal path finding
Now, we're going to implement the search for the optimal path. To do that, we're going to use the Viterbi algorithm, which in this context is a simple dynamic programming algorithm: for each pair i, j, it calculates the minimum cost of a path that goes from the 0-th index in the energies and the 0-th index in the target to the i-th index in the energies and the j-th index in the target. We can store these values in a 2-dimensional array, let's call it `dp`. Then we have the following transitions:
```
dp[0, 0] = pm[0, 0]
dp[i, j] = min(dp[i - 1, j], dp[i - 1, j - 1]) + pm[i, j]
```

The optimal path can be recovered if we memorize which cell we came from for each `dp[i, j]`.
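
As a point of reference, the recurrence and the back-pointer idea can be sketched iteratively in plain Python. This is only an illustration under stated assumptions, not the notebook's solution: `viterbi` and `back` are names made up here, `pm` is a list of lists with `L` rows and `T` columns, and cells that no valid path can reach are kept at infinity.

```python
import math


def viterbi(pm):
    """Return (minimum path energy, list of (i, j) points on the best path)."""
    L, T = len(pm), len(pm[0])
    if L < T:
        raise ValueError("need at least as many positions as target symbols")
    dp = [[math.inf] * T for _ in range(L)]
    back = [[None] * T for _ in range(L)]  # column we came from in row i - 1
    dp[0][0] = pm[0][0]
    for i in range(1, L):
        for j in range(min(i, T - 1) + 1):
            stay = dp[i - 1][j]
            advance = dp[i - 1][j - 1] if j > 0 else math.inf
            if stay <= advance:
                best, prev = stay, j
            else:
                best, prev = advance, j - 1
            dp[i][j] = best + pm[i][j]
            back[i][j] = prev
    # Walk the back-pointers from the bottom-right cell to recover the path
    path = []
    j = T - 1
    for i in range(L - 1, -1, -1):
        path.append((i, j))
        if i > 0:
            j = back[i][j]
    path.reverse()
    return dp[L - 1][T - 1], path


pm = [
    [0, 100, 100, 100],
    [100, 0, 100, 0],
    [100, 100, 0, 100],
    [100, 0, 100, 0],
    [100, 0, 100, 0],
]
energy, path = viterbi(pm)
print(energy, path)
```

The table has `L * T` cells and each is filled in constant time, so the cost grows with the product of the number of window positions and the target length rather than with the number of possible paths.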

Below, you'll need to implement this algorithm:

In [None]:
# path_energy() is used below only as a sanity check on each partial path;
# the dynamic program itself does not need it.

def energy_to_point(pm, i, j, dp):
    if (i, j) in dp:
        return dp[(i, j)]
    if i < j:
        raise Exception(f"Invalid path destination: ({i=}, {j=})")
    if i == 0 and j == 0:
        energy = 0
        path = ()
    elif j == 0:
        energy, path = energy_to_point(pm, i - 1, j, dp)
    elif i == j:
        energy, path = energy_to_point(pm, i - 1, j - 1, dp)
    else:
        assert i != 0 and j != 0
        assert i > j
        energy, path = min(
            energy_to_point(pm, i - 1, j, dp),
            energy_to_point(pm, i - 1, j - 1, dp),
        )
    energy = energy + pm[i, j]
    path = path + ((i, j),)
    path_js = [p[1] for p in path]
    valid_energy = path_energy(pm[:i + 1, :j + 1], path_js)
    if not torch.isclose(energy, valid_energy, atol=1e-3):
        print(f"Found path to ({i}, {j}) is invalid")
        print("Naive energy:", energy)
        print("Validated energy", valid_energy)
        print(path_js)
        raise Exception("Energy discrepancy")
    dp[i, j] = (energy, path)
    return energy, path


def find_path(pm):
    # inputs:
    #   pm - a tensor of shape LxT with energies
    #     L is length of energies array
    #     T is target sequence length
    # NOTE: this is slow because it's not vectorized to work with batches.
    #  output:
    #     a tuple of three elements:
    #         1. sum of energies on the best path,
    #         2. list of tuples - points of the best path in the pm matrix 
    #         3. the dp array
    dp_dict = {}
    min_energy, best_path = energy_to_point(
        pm, pm.size(0) - 1, pm.size(1) - 1, dp_dict
    )
    dp_array = torch.full_like(pm, torch.nan)
    for (i, j) in dp_dict:
        dp_array[i, j] = dp_dict[(i, j)][0]
    return min_energy, best_path, dp_array


Let's take a look at the best path:

In [None]:
free_energy, path, d = find_path(pm[0])
plot_pm(pm[0].T.detach(), path)
print('free energy is', free_energy.item())

We can also visualize the dp array. You may need to tune clamping to see what it looks like.

In [None]:
plt.figure(dpi=200)
plt.imshow(d.cpu().detach().T)
plt.axis('off')

### Training loop
Now it is time to train the network using our best path finder. We're going to use the energy loss function:
$$\ell(x, y) = \sum_i H(y_{z(i)}, \sigma(l_i))$$
where $z$ is the best path we've found and $\sigma$ is the soft-argmin. This is akin to pushing down on the free energy $F(x, y)$, while pushing up everywhere else by nature of cross-entropy.
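
To make the loss concrete, here is a plain-Python sketch of the sum of cross-entropies along a given path. It assumes one-hot targets, so each cross-entropy reduces to the negative log-probability of the target class under the soft-argmin distribution. The names `log_soft_argmin` and `path_loss` are made up for this illustration, and indices are 0-based.

```python
import math


def log_soft_argmin(energies):
    """Log-probabilities from energies: log softmax of the negated energies."""
    lowest = min(energies)
    log_norm = math.log(sum(math.exp(-(e - lowest)) for e in energies))
    return [-(e - lowest) - log_norm for e in energies]


def path_loss(l, y, z):
    """Sum over positions i of the cross-entropy for target y[z[i]] at l[i].

    l: list of per-position energy vectors (one energy per class)
    y: list of target class indices
    z: pairing, where z[i] is the index into y matched to position i
    """
    return sum(-log_soft_argmin(l[i])[y[z[i]]] for i in range(len(l)))


# Three positions, four classes, all energies equal: every class is equally
# likely, so each position contributes log(4) to the loss.
l = [[0.0] * 4 for _ in range(3)]
loss = path_loss(l, y=[2, 1], z=[0, 0, 1])
print(loss, 3 * math.log(4))
```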

In [None]:
def collate_fn(samples):
    """ A function to collate samples into batches for multi-character case"""
    images, annotations = zip(*samples)
    images = list(images)
    annotations = list(annotations)
    annotations = list(map(transform_word, annotations))
    m_width = max(CHAR_WIDTH, max([i.shape[1] for i in images]))
    m_length = max(3, max([s.shape[0] for s in annotations]))
    for i in range(len(images)):
        images[i] = torch.nn.functional.pad(images[i], (0, m_width - images[i].shape[-1]))
        annotations[i] = torch.nn.functional.pad(annotations[i], (0, m_length - annotations[i].shape[0]), value=BETWEEN)
    if len(images) == 1:
        return images[0].unsqueeze(0), torch.stack(annotations)
    else:
        return torch.stack(images), torch.stack(annotations)

WORD_LENGTH = 2
sds = SimpleWordsDataset(WORD_LENGTH, 30_000)

BATCH_SIZE = 100
dataloader = torch.utils.data.DataLoader(
    sds, batch_size=BATCH_SIZE, num_workers=0, collate_fn=collate_fn
)

# TODO: train the model
# note: remember that our best path finding algorithm is not batched, so you'll
# need a for loop to do loss calculation. 
# This is not ideal, as for loops are very slow, but for 
# demonstration purposes it will suffice. In practice, this will be
# unusable for any real problem unless it handles batching.

# also: remember that the loss is the sum of cross_entropies along the path, not 
# energies!

model = SimpleNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
losses = []
# Down-weight the BETWEEN character since it appears so often
loss_weights = torch.ones(ALPHABET_SIZE)
loss_weights[BETWEEN] = 1 / ALPHABET_SIZE

# model.register_full_backward_hook(
#     lambda model, grad_input, grad_output: print(grad_input, grad_output, "", sep="\n")
# )

for i, (images, targets) in enumerate(dataloader):
    assert images.shape == (
        BATCH_SIZE, CHAR_HEIGHT, CHAR_WIDTH * WORD_LENGTH
    ), (
        f"{images.shape=} != {(BATCH_SIZE, CHAR_HEIGHT, CHAR_WIDTH * WORD_LENGTH)=}"
    )
    assert (
        targets.shape == (BATCH_SIZE, 2 * WORD_LENGTH)
    ), (
        f"{targets.shape=} != {(BATCH_SIZE, 2 * WORD_LENGTH)=}"
    )
    images = images.to(device)
    targets = targets.to(device)
    optimizer.zero_grad()
    energies = model(images)
    assert energies.ndim == 3 and energies.shape[:2] == (
        BATCH_SIZE, ALPHABET_SIZE
    ), f"{energies.shape=} != ({BATCH_SIZE=}, {ALPHABET_SIZE=}, W)"
    path_matrix = build_path_matrix(energies, targets)
    cross_entropy_matrix = build_ce_matrix(energies, targets, loss_weights)
    assert path_matrix.shape == cross_entropy_matrix.shape
    assert path_matrix.shape == (
        BATCH_SIZE, energies.size(2), 2 * WORD_LENGTH
    ), (
        f"{path_matrix.shape=} != "
        f"{(BATCH_SIZE, energies.size(2), 2 * WORD_LENGTH)=}"
    )
    # Plain accumulator: gradients flow in through the path losses added below
    loss = torch.zeros((), device=device)
    for b in range(path_matrix.size(0)):
        _, path, _ = find_path(path_matrix[b])
        path_tensor = torch.tensor(path)
        cem = cross_entropy_matrix[b]
        assert (
            cem.shape == (energies.size(2), 2 * WORD_LENGTH)
        ), (
            f"{cem.shape=} != {(energies.size(2), 2 * WORD_LENGTH)=}"
        )
        path_loss = cem[path_tensor[:,0], path_tensor[:,1]]
        loss = loss + path_loss.sum()
    loss = loss / BATCH_SIZE
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    print(f"Batch iteration: {i}, loss: {loss:.3f}     ", end="\r")

plt.plot(losses)
plt.xlabel("Batch iteration")
plt.ylabel("Loss")

In [None]:
plot_pm(path_matrix[-1].detach().T, path)
print(images.shape)
print(path_matrix.shape)

In [None]:
energy, path2, heatmap = find_path(path_matrix[-1])
plt.imshow(heatmap.detach().T)

In [None]:
letter_image_full, letter_label = next(iter(SimpleWordsDataset(2, 1)))
letter_image = letter_image_full[:, :]
plt.imshow(letter_image)

In [None]:
with torch.no_grad():
    # Move the image to the model's device, and the result back for printing
    letter_energies = model(letter_image.unsqueeze(0).to(device)).squeeze(0).squeeze(-1).cpu()
print(f"{letter_energies.shape=}")
# print(f"{letter_energies=}")
# sorted_indices = torch.argsort(letter_energies[:,0])
# for ind in sorted_indices:
#     print(num_to_char(ind), letter_energies[ind,0])
min_energy_indices = torch.argmin(letter_energies, 0)
for x in range(letter_energies.size(1)):
    print(
        x,
        num_to_char(min_energy_indices[x].item()),
        letter_energies[min_energy_indices[x], x],
    )

In [None]:
example_image, example_label = next(iter(SimpleWordsDataset(WORD_LENGTH, 1)))
print(f"{example_image.shape=}, {example_label=}")
plt.imshow(example_image)

In [None]:
example_target = transform_word(example_label)
print(f"{example_target=}")
with torch.no_grad():
    # Move the image to the model's device before the forward pass
    example_energies = model(example_image.unsqueeze(0).to(device))
print(f"{example_energies.shape=}")
plt.imshow(example_energies[0].detach().cpu())

In [None]:
example_pm = build_path_matrix(example_energies, example_target.unsqueeze(0))
print(f"{example_pm.shape=}")
example_min_energy, example_path, example_dp = find_path(example_pm[0])
print(f"{example_min_energy=}, {example_path=}")
plot_pm(example_pm[0].detach().T, example_path)

Let's check what the energy matrix looks like for the alphabet image now.

In [None]:
energies = model(alphabet.unsqueeze(0).to(device))
targets = transform_word(string.ascii_lowercase)
pm = build_path_matrix(energies, targets.unsqueeze(0))

free_energy, path, _ = find_path(pm[0])
plot_pm(pm[0].detach().T, path)
print('free energy is', free_energy.item())

Explain how the free energy changed, and why.

Answer: #TODO

We can also look at the raw energies output by the model:

In [None]:
alphabet_energy_post_train_viterbi = model(alphabet.to(device).view(1, *alphabet.shape))

plt.figure(dpi=200, figsize=(40, 10))
plt.imshow(alphabet_energy_post_train_viterbi.cpu().data[0])
plt.axis('off')

How does this compare to the energies we had after training only on the one-character dataset?

Answer: #TODO

## Decoding

Now we can use the model to decode a word from an image. Let's pick a word, apply the model to it, and look at the energies.

In [None]:
img = sds.draw_text('hello')
energies = model(img.to(device).unsqueeze(0))
plt.imshow(img)
plot_energies(energies[0].detach().cpu())

You should see some characters light up. Now let's implement a simple decoding algorithm. To decode, first take the most likely class (the one with the lowest energy) at every position, and then do two things:
1. Segment the string at the divisors (our special character with index 26), and replace each segment with the most common character in that segment. Example: aaab_bab_ -> a_b. If several characters are equally common, you can pick one at random.
2. Remove all special divisor characters: a_b -> ab
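
The two steps above can be sketched on plain Python lists as follows. This is a minimal illustration, not the notebook's implementation: `decode_indices_sketch` is a hypothetical name, index `k` is assumed to map to the `k`-th lowercase letter, and ties are broken by first occurrence rather than at random.

```python
from collections import Counter
from itertools import groupby

BETWEEN = 26  # index of the divisor character, as in the text


def decode_indices_sketch(indices):
    """Decode a list of class indices into a string."""
    chars = []
    # groupby splits the sequence into alternating runs of divisors and letters
    for is_divisor, run in groupby(indices, key=lambda k: k == BETWEEN):
        if is_divisor:
            continue  # step 2: divisor runs contribute nothing to the output
        # step 1: keep only the most common index of the segment
        winner, _ = Counter(run).most_common(1)[0]
        chars.append(chr(ord("a") + winner))  # assumed index-to-letter mapping
    return "".join(chars)
```

For the example in the text, `aaab_bab_` corresponds to `[0, 0, 0, 1, 26, 1, 0, 1, 26]`, which this sketch decodes to `ab`.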


In [None]:
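def indices_to_str(indices):
    # inputs: indices - a 1-D tensor of most likely class indices
    # outputs: decoded string
    chunk_start = 0
    characters = []
    for i, ind in enumerate(indices):
        if ind == BETWEEN:
            # A divisor closes the current chunk; empty chunks are skipped
            if chunk_start < i:
                most_common = indices[chunk_start:i].mode().values.item()
                characters.append(num_to_char(most_common))
            chunk_start = i + 1
    # Flush a trailing chunk that is not terminated by a divisor, using its
    # most common character rather than just its last one
    if chunk_start < len(indices):
        most_common = indices[chunk_start:].mode().values.item()
        characters.append(num_to_char(most_common))
    return "".join(characters)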

# Unit tests for indices_to_str()
test_input = torch.tensor([])
test_output = ""
actual_output = indices_to_str(test_input)
assert actual_output ==  test_output, (
    f"{actual_output!r} != {test_output!r}"
)
test_input = torch.tensor([0])
test_output = num_to_char(0)
actual_output = indices_to_str(test_input)
assert actual_output ==  test_output, (
    f"{actual_output!r} != {test_output!r}"
)
test_input = torch.tensor([BETWEEN])
test_output = ""
actual_output = indices_to_str(test_input)
assert actual_output ==  test_output, (
    f"{actual_output!r} != {test_output!r}"
)
test_input = torch.tensor([0, BETWEEN, BETWEEN])
test_output = num_to_char(0)
actual_output = indices_to_str(test_input)
assert actual_output ==  test_output, (
    f"{actual_output!r} != {test_output!r}"
)
test_input = torch.tensor([0, BETWEEN, BETWEEN, 1])
test_output = num_to_char(0) + num_to_char(1)
actual_output = indices_to_str(test_input)
assert actual_output ==  test_output, (
    f"{actual_output!r} != {test_output!r}"
)

In [None]:
min_indices = energies[0].argmin(dim=0)
print(min_indices)
print(indices_to_str(min_indices))