Application description:

- Performing visual entity extraction
    * First use tool to generate image caption
    * Then use NER on the caption to get the entities

This is useful in for example fake news detection where we want to use external knowledge on the visual part of a news article to further enhance fact checking.

Novelty

- Combining these two processes has not been done before to my knowledge, at least not in the context of fake news detection

In [9]:
import torch
from PIL import Image
# setup device to use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load sample image
raw_image = Image.open("img/181002113456-01-golden-gate-bridge-restricted.jpg").convert("RGB")

In [63]:

from lavis.models import load_model_and_preprocess
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# loads BLIP caption base model, with finetuned checkpoints on MSCOCO captioning dataset.
# this also loads the associated image processors
model, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=device)
# preprocess the image
# vis_processors stores image transforms for "train" and "eval" (validation / testing / inference)
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
# generate caption
img_caption = model.generate({"image": image})
# ['a large fountain spewing water into the air']
print(img_caption)

['a view of the golden gate bridge in san francisco']


In [50]:
from flair.datasets import CONLL_03
from flair.embeddings import WordEmbeddings, FlairEmbeddings, StackedEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

In [51]:
# 1. get the corpus
corpus = CONLL_03()
print(corpus)

2022-11-07 21:34:15,613 Reading data from /Users/oysteinlondalnilsen/.flair/datasets/conll_03
2022-11-07 21:34:15,614 Train: /Users/oysteinlondalnilsen/.flair/datasets/conll_03/train.txt
2022-11-07 21:34:15,615 Dev: None
2022-11-07 21:34:15,615 Test: /Users/oysteinlondalnilsen/.flair/datasets/conll_03/test.txt
Corpus: 13488 train + 1499 dev + 3684 test sentences


In [52]:
# 2. what label do we want to predict?
label_type = 'ner'

In [54]:
# 3. make the label dictionary from the corpus
label_dict = corpus.make_label_dictionary(label_type=label_type)
print(label_dict)

2022-11-07 21:35:54,568 Computing label dictionary. Progress:


13488it [00:00, 39995.99it/s]

2022-11-07 21:35:54,927 Dictionary created for label 'ner' with 10 values: o (seen 153761 times), b-loc (seen 6446 times), b-per (seen 5895 times), b-org (seen 5698 times), i-per (seen 4039 times), i-org (seen 3361 times), b-misc (seen 3085 times), i-misc (seen 1051 times), i-loc (seen 1034 times)
Dictionary with 10 tags: <unk>, o, b-loc, b-per, b-org, i-per, i-org, b-misc, i-misc, i-loc





In [55]:
# 4. initialize fine-tuneable transformer embeddings WITH document context
embeddings = TransformerWordEmbeddings(model='xlm-roberta-large',
                                       layers="-1",
                                       subtoken_pooling="first",
                                       fine_tune=True,
                                       use_context=True,
                                       )

Downloading config.json:   0%|          | 0.00/616 [00:00<?, ?B/s]

Downloading sentencepiece.bpe.model:   0%|          | 0.00/4.83M [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/8.68M [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/2.09G [00:00<?, ?B/s]

In [56]:
# 5. initialize bare-bones sequence tagger (no CRF, no RNN, no reprojection)
tagger = SequenceTagger(hidden_size=256,
                        embeddings=embeddings,
                        tag_dictionary=label_dict,
                        tag_type='ner',
                        use_crf=False,
                        use_rnn=False,
                        reproject_embeddings=False,
                        )

2022-11-07 21:43:58,146 SequenceTagger predicts: Dictionary with 10 tags: <unk>, o, b-loc, b-per, b-org, i-per, i-org, b-misc, i-misc, i-loc


In [57]:
# 6. initialize trainer
trainer = ModelTrainer(tagger, corpus)

In [66]:
# 7. run fine-tuning
trainer.fine_tune('resources/taggers/sota-ner-flert',
                  learning_rate=5.0e-6,
                  mini_batch_size=4,
                  mini_batch_chunk_size=1,  # remove this parameter to speed up computation if you have a big GPU
                  )

2022-11-08 01:31:38,176 ----------------------------------------------------------------------------------------------------
2022-11-08 01:31:38,179 Model: "SequenceTagger(
  (embeddings): TransformerWordEmbeddings(
    (model): XLMRobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(250002, 1024, padding_idx=1)
        (position_embeddings): Embedding(514, 1024, padding_idx=1)
        (token_type_embeddings): Embedding(1, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0): RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSelfAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linear(in_features=1024, out_feature

In [64]:
from flair.data import Sentence

# load the model you trained
model = SequenceTagger.load('resources/taggers/sota-ner-flert/final-model.pt')

# make a sentence
sentence = Sentence(img_caption)

# run NER over sentence
model.predict(sentence)

2022-11-08 01:29:09,021 loading file resources/taggers/sota-ner-flert/final-model.pt
2022-11-08 01:29:26,543 SequenceTagger predicts: Dictionary with 10 tags: <unk>, o, b-loc, b-per, b-org, i-per, i-org, b-misc, i-misc, i-loc


In [65]:
# print the sentence with all annotations
print(sentence)

print('The following NER tags are found:')

# iterate over entities and print each
for entity in sentence.get_spans('ner'):
    print(entity)

Sentence: "a view of the golden gate bridge in san francisco" → ["a view of the golden gate bridge in san francisco"/b-loc]
The following NER tags are found:
