## Image captioning using recurrent language modelling and transfer learning based on CLIP

In [2]:
from pinnacledb.client import the_client

docs = the_client.coco.documents

In the final modelling part of this tutorial, we show that a radically different type of model can also be
created with the `create_imputation` functionality. This is possible because we use a different loss and
target, together with a model whose inference `forward` pass differs from its training pass
(`train_forward`). This is very common in AI, for example when training autoregressive models.
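
To make the difference between the two passes concrete, here is a minimal sketch in plain Python, independent of pinnacledb and PyTorch. The `next_token` function and its transition table are hypothetical stand-ins for a trained network. During training, every step is fed the ground-truth previous token (teacher forcing); at inference, every step is fed the model's own previous prediction.

```python
# Toy illustration only: `next_token` is a hypothetical stand-in for a trained network.
START, STOP = "<s>", "</s>"

# A fixed "model": maps the previous token to the predicted next token.
TRANSITIONS = {START: "a", "a": "dog", "dog": "runs", "runs": STOP}


def next_token(previous):
    """Predict the next token from the previous one (unknown inputs end the sentence)."""
    return TRANSITIONS.get(previous, STOP)


def train_forward(target):
    """Teacher forcing: each step is fed the ground-truth previous token."""
    inputs = [START] + target[:-1]
    return [next_token(tok) for tok in inputs]


def inference_forward(max_length=15):
    """Greedy decoding: each step is fed the model's own previous prediction."""
    output = []
    previous = START
    for _ in range(max_length):
        previous = next_token(previous)
        if previous == STOP:
            break
        output.append(previous)
    return output


target = ["a", "cat", "runs", STOP]
train_predictions = train_forward(target)
generated = inference_forward()
print(train_predictions)
print(generated)
```

The training pass returns one prediction per target position, so it can be compared with the target token by token, while the inference pass produces a sentence without ever seeing a target.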

The model we create takes an image as input and writes out a sentence describing that image in unconstrained English. This task is known as "captioning" in AI parlance.

This model relies on a fixed vocabulary of "allowed" words: those that appear more than five times in the
training captions. Let us first create it in the following cell:

In [3]:
import tqdm 
import collections
import re

all_captions = []
n = docs.count_documents({'_fold': 'train'})
for r in tqdm.tqdm_notebook(docs.find({'_fold': 'train'}, {'captions': 1, '_id': 0}), total=n):
    all_captions.extend(r['captions'])
    
all_captions = [re.sub('[^a-z ]', '', x.lower()).strip() for x in all_captions]
words = ' '.join(all_captions).split(' ')
counts = dict(collections.Counter(words))
vocab = sorted([w for w in counts if counts[w] > 5 and w])

  0%|          | 0/77716 [00:00<?, ?it/s]
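
As a quick sanity check of the vocabulary logic, the same steps can be run on a handful of invented captions. The captions and the lower threshold (`min_count = 1` instead of 5) are assumptions, chosen so that the toy example produces a non-empty vocabulary.

```python
import collections
import re

# Invented captions, for illustration only.
toy_captions = [
    "A dog runs on the beach.",
    "A dog sleeps.",
    "The beach is empty!",
]

min_count = 1  # the tutorial keeps words seen more than 5 times

# Lowercase, drop everything except letters and spaces, then count words.
cleaned = [re.sub('[^a-z ]', '', c.lower()).strip() for c in toy_captions]
toy_counts = collections.Counter(' '.join(cleaned).split(' '))

# Keep non-empty words that occur more often than the threshold.
toy_vocab = sorted(w for w in toy_counts if toy_counts[w] > min_count and w)
print(toy_vocab)
```

Words that occur only once ("runs", "sleeps", "empty", ...) are dropped, which is what later forces the tokenizer to fall back on `<unk>`.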

Now we can create the model. It uses a "tokenizer" to preprocess the caption data.

In `examples.models`:

```python
import re

import torch


class SimpleTokenizer:
    def __init__(self, tokens, max_length=15):
        self.tokens = tokens
        if '<unk>' not in tokens:
            tokens.append('<unk>')
        self._set_tokens = set(self.tokens)
        self.lookup = dict(zip(self.tokens, range(len(self.tokens))))
        self.dictionary = {k: i for i, k in enumerate(tokens)}
        self.max_length = max_length

    def __len__(self):
        return len(self.tokens)

    def preprocess(self, sentence):
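        # Keep only lowercase letters and spaces, as when the vocabulary was built
        sentence = re.sub('[^a-z ]', '', sentence.lower()).strip()
        words = [x for x in sentence.split(' ') if x]
        # Set membership is O(1); unknown words map to '<unk>'
        words = [x if x in self._set_tokens else '<unk>' for x in words]
        words = words[:self.max_length]
        tokenized = list(map(self.lookup.__getitem__, words))
        # Pad with the stop token, whose index is len(self) + 1
        tokenized = tokenized + [len(self) + 1 for _ in range(self.max_length - len(words))]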
        return torch.tensor(tokenized)


class ConditionalLM(torch.nn.Module):
    def __init__(self, tokenizer, n_hidden=512, max_length=15, n_condition=1024):
        super().__init__()

        self.tokenizer = tokenizer
        self.n_hidden = n_hidden
        self.embedding = torch.nn.Embedding(len(self.tokenizer) + 2, self.n_hidden)
        self.conditioning_linear = torch.nn.Linear(n_condition, self.n_hidden)
        self.rnn = torch.nn.GRU(self.n_hidden, self.n_hidden, batch_first=True)
        self.prediction = torch.nn.Linear(self.n_hidden, len(self.tokenizer) + 2)
        self.max_length = max_length

    def preprocess(self, r):
        out = {}
        if 'caption' in r:
            out['caption'] = [len(self.tokenizer)]  + self.tokenizer.preprocess(r['caption']).tolist()[:-1]
        else:
            out['caption'] = [len(self.tokenizer)]
        out['caption'] = torch.tensor(out['caption'])
        if 'img' in r:
            out['img'] = r['img']
        return out

    def train_forward(self, r):
        input_ = self.embedding(r['caption'])
        img_vectors = self.conditioning_linear(r['img']).unsqueeze(0)
        rnn_outputs = self.rnn(input_, img_vectors)[0]
        return self.prediction(rnn_outputs)

    def forward(self, r):
        hidden_states = self.conditioning_linear(r['img']).unsqueeze(0)
        predictions = \
            torch.zeros(r['caption'].shape[0], self.max_length).to(r['caption'].device).type(torch.long)
        predictions[:, 0] = r['caption'][:, 0]
        for i in range(self.max_length - 1):
            rnn_outputs, hidden_states = self.rnn(self.embedding(predictions[:, i]).unsqueeze(1),
                                                  hidden_states)
            logits = self.prediction(rnn_outputs)[:, 0, :]
            predictions[:, i + 1] = logits.topk(1, dim=1)[1][:, 0].type(torch.long)
        return predictions

    def postprocess(self, output):
        output = output.tolist()
        try:
            # Position of the first stop token (stop token index = len(self.tokenizer) + 1)
            first_end_token = next(i for i, x in enumerate(output) if x == len(self.tokenizer) + 1)
            output = output[:first_end_token]
        except StopIteration:
            pass
        output = [x for x in output if x < len(self.tokenizer)]
        return ' '.join(list(map(self.tokenizer.tokens.__getitem__, output)))

```
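
The index conventions are easy to lose track of: with a vocabulary of `V` tokens (including `<unk>`), index `V` serves as the start token and `V + 1` as the stop and padding token, which is why the embedding and prediction layers have `V + 2` entries. The sketch below reproduces the tokenize, shift and decode steps with plain Python lists instead of tensors. The helper names `encode`, `shift_right` and `decode`, the toy vocabulary and `MAX_LENGTH = 6` are illustrative assumptions and not part of `examples.models`.

```python
import re

tokens = ['a', 'beach', 'dog', 'the', '<unk>']
lookup = {tok: i for i, tok in enumerate(tokens)}
V = len(tokens)
START, STOP = V, V + 1   # index conventions used by ConditionalLM
MAX_LENGTH = 6


def encode(sentence):
    """Map words to indices, replace unknown words, and pad with STOP."""
    sentence = re.sub('[^a-z ]', '', sentence.lower()).strip()
    words = [w if w in lookup else '<unk>' for w in sentence.split(' ') if w]
    words = words[:MAX_LENGTH]
    ids = [lookup[w] for w in words]
    return ids + [STOP] * (MAX_LENGTH - len(ids))


def shift_right(ids):
    """Model input during training: START followed by all but the last target id."""
    return [START] + ids[:-1]


def decode(ids):
    """Cut at the first STOP and drop the special tokens."""
    if STOP in ids:
        ids = ids[:ids.index(STOP)]
    return ' '.join(tokens[i] for i in ids if i < V)


target = encode('The dog chases a ball!')
model_input = shift_right(target)
print(target)
print(model_input)
print(decode(target))
```

Shifting the input one step to the right means that, at every position, the network sees only the tokens that precede the one it has to predict.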

In [4]:
from examples.models import ConditionalLM, SimpleTokenizer

tokenizer = SimpleTokenizer(vocab)
m = ConditionalLM(tokenizer)

Let us now create the models required for training. One of them is fairly trivial and is only used to
create the prediction target for the learning task:

In [5]:
docs.create_model('conditional_lm', object=m, active=False, features={'img': 'clip'}, key='_base')
docs.create_model('captioning_tokenizer', tokenizer, key='caption', active=False)

We'll use a standard autoregressive loss, of the sort used as a matter of course in language modelling tasks.

In `examples.losses`:


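```python
import torch


def auto_regressive_loss(x, y):
    # start token = x.shape[2] - 2, stop_token = x.shape[2] - 1 (by convention)
    stop_token = x.shape[2] - 1
    # cross_entropy expects the class dimension second: (batch, classes, positions)
    x = x.transpose(2, 1)
    # reduction='none' keeps one loss per token (replaces the deprecated reduce=False)
    losses = torch.nn.functional.cross_entropy(x, y, reduction='none')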
    not_stops = torch.ones_like(losses)
    not_stops[:, 1:] = (y[:, :-1] != stop_token).type(torch.long)
    normalizing_factors = not_stops.sum(axis=1).unsqueeze(1)
    av_loss_per_row = (losses * not_stops).div(normalizing_factors).sum(axis=1)
    return av_loss_per_row.mean()
```
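
The masking is the subtle part of this loss: every position up to and including the first stop token contributes, while the padding that follows is ignored. A small sketch with hand-picked per-token losses shows the effect for a single caption. The numbers are invented, plain lists stand in for tensors, and the helper name `masked_row_average` is hypothetical.

```python
def masked_row_average(losses, targets, stop_token):
    """Average the per-token losses of one caption, ignoring padding after the first stop token."""
    # Position 0 always counts; position t counts if the previous target was not the stop token.
    not_stops = [1.0] + [float(prev != stop_token) for prev in targets[:-1]]
    total = sum(loss * mask for loss, mask in zip(losses, not_stops))
    return total / sum(not_stops)


STOP = 9
targets = [3, 2, 9, 9, 9]              # two words, then the stop token used as padding
losses = [0.5, 1.5, 1.0, 7.0, 7.0]     # invented per-token cross-entropy values

row_loss = masked_row_average(losses, targets, STOP)
print(row_loss)
```

Only the first three positions enter the average, so the large losses on the padding positions have no influence on the result.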

In [6]:
from examples.losses import auto_regressive_loss
docs.create_loss('autoregressive_loss', auto_regressive_loss)

Since each record in the database has several captions per image, we need a so-called "splitter" to
align the model input and the prediction target during training. As you can see, the splitter randomly
chooses one of the captions to train on in each iteration.

In `examples.splitters`:

```python
import random


def captioning_splitter(r):
    index = random.randrange(len(r['captions']))
    target = {}
    target['caption'] = r['captions'][index]
    r['caption'] = r['captions'][index]
    return r, target
```
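
Because the choice is random, two passes over the same record can yield different training captions. The sketch below runs the splitter on an invented record with the random seed fixed, so that the run is repeatable. The function is repeated here only so that the snippet runs on its own, and each call receives a copy of the record because the splitter writes the chosen caption back into its argument.

```python
import random


def captioning_splitter(r):
    index = random.randrange(len(r['captions']))
    target = {'caption': r['captions'][index]}
    r['caption'] = r['captions'][index]
    return r, target


# Invented record; real records also carry '_id', 'img', '_fold' and '_outputs'.
record = {'captions': ['a long table', 'a table with chairs', 'wooden chairs']}

random.seed(0)
first_input, first_target = captioning_splitter(dict(record))
random.seed(0)
second_input, second_target = captioning_splitter(dict(record))

print(first_target)
print(first_target == second_target)
```

Fixing the seed is only useful for debugging; during training the variation between iterations is desirable, since it exposes the model to all captions of an image over time.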

In [7]:
from examples.splitters import captioning_splitter

docs.create_splitter('captioning_splitter', captioning_splitter)
captioning_splitter(docs.find_one())

({'_id': ObjectId('63f4d31b99e15ed933e61fcd'),
  'captions': ['A restaurant has modern wooden tables and chairs.',
   'A long restaurant table with rattan rounded back chairs.',
   'a long table with a plant on top of it surrounded with wooden chairs ',
   'A long table with a flower arrangement in the middle for meetings',
   'A table is adorned with wooden chairs with blue accents.'],
  'img': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x168>,
  '_fold': 'train',
  '_outputs': {'img': {'clip': tensor([ 0.0203,  0.0837,  0.0035,  ..., -0.0788,  0.0529, -0.1146]),
    'clip_projection': tensor([ 0.0012,  0.1153, -0.0723, -0.0510,  0.1098, -0.1030,  0.0434, -0.1056,
            -0.1817, -0.0636, -0.1248, -0.0559,  0.0486,  0.0592,  0.0267,  0.0602,
            -0.0153, -0.0122,  0.1966, -0.2138,  0.1524, -0.0079, -0.0866, -0.0067,
            -0.1458,  0.2234, -0.0030,  0.0827, -0.0748,  0.0598,  0.0271, -0.0271,
             0.0626, -0.0612,  0.0378,  0.1458,  0.0533,  0.

Since we have this new splitter, we need to create a new validation data set:

In [8]:
docs.create_validation_set('captioning', splitter=docs.splitters['captioning_splitter'],
                           sample_size=500, chunk_size=100)

downloading content from retrieved urls
found 0 urls
computing chunk (1/1)
finding documents under filter
done.
processing with clip
bulk writing...
done.
computing chunk (1/1)
finding documents under filter
done.
processing with clip_projection
bulk writing...
done.
computing chunk (1/1)
finding documents under filter
done.
bulk writing...
done.

100%|██████████| 500/500 [01:01<00:00,  8.15it/s]

bulk writing...
done.

Now we're ready to start the training:

In [None]:
docs.create_imputation(
    'image_captioner',
    model='conditional_lm',
    loss='autoregressive_loss',
    target='captioning_tokenizer',
    splitter='captioning_splitter',
    validation_sets=['captioning'],
    batch_size=50,
    lr=0.001,
)

Let's test the model on a few sample data points:

In [None]:
test_docs = list(docs.find().limit(20))
images = list(docs.find({}, {'img': 1}).limit(20))

results = docs.apply_model('conditional_lm', test_docs, batch_size=10)

for r, res in zip(images, results):
    display(r['img'])
    print(res)

We have now trained and evaluated several models of various types, including multiple interacting models with mutual dependencies. Our own efficient semantic search and the attribute predictor are downstream of the CLIP image model, in the sense that CLIP must be present at inference time in order to execute them. For attribute prediction, the training task was downstream of the spaCy pipeline for part-of-speech tagging, since these tags were used to produce the training targets. At run-time, however, the spaCy pipeline is not needed.

The models which we've added and trained are now ready to go. When data is added to the collection or updated, they will automatically process it and insert the model outputs into the collection documents.
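
One way to picture this automatic processing is as a dependency graph that is walked in order whenever a document arrives. The sketch below is not pinnacledb's implementation. It is a hedged illustration using the standard-library `graphlib` module (Python 3.9+), and the dependency table is an assumption based on the `features={'img': 'clip'}` argument used earlier and on the model names visible in `_outputs`.

```python
from graphlib import TopologicalSorter

# Assumed dependencies: each model maps to the set of models whose outputs it consumes.
dependencies = {
    'clip': set(),
    'clip_projection': {'clip'},
    'conditional_lm': {'clip'},
}

# static_order() yields every model only after all of its dependencies.
processing_order = list(TopologicalSorter(dependencies).static_order())
print(processing_order)
```

Under these assumptions `clip` always comes first, which matches the requirement that its output must exist before any downstream model can run.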

Here is the complete set of models which exist in the collection:

In [None]:
docs.list_models()

Not all of these respond to incoming data. To list only those that do, pass the `active` argument:

In [None]:
docs.list_models(active=True)

We can see that these models have processed all documents and that their outputs have been saved:

In [None]:
docs.find_one()