# Image Caption Generator Using Transformer

<a href="https://colab.research.google.com/github/sha9189/image-captioning-using-transformers/blob/master/predict.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Click the **Open in Colab** button above to quickly run the model on any image of your choice.    
<br>   
The model's weights (500 MB) are hosted on Google Drive because of GitHub's file-size limit. Google Drive is a blocked service on the Allstate network, so the weights fail to load when the notebook is run there. It is therefore advised to try this model on your **personal laptop** (and not on an Allstate-issued machine).
<br>   
<br>   
This is a preview of an Image Captioning System developed to describe different **scenes of a house** (such as the living room, bedroom, kitchen, and bathroom). Its architecture was influenced by the recent success of the [DETR](https://github.com/facebookresearch/detr) (DEtection-TRansformer) model on the task of object detection.   
<br>   
To run this notebook, click inside the first cell below and press `Shift+Enter` repeatedly to run the cells in order.

### Download GitHub Repository and Set Up Connection with Google Drive

In [2]:
!git clone https://github.com/sha9189/image-captioning-using-transformers.git

%cd image-captioning-using-transformers/

### Set up actual model end-to-end

In [2]:
import requests
import torch
from PIL import Image
import spacy
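try:
    # The "en" shortcut only exists in spaCy 2.x
    spacy_eng = spacy.load("en")
except OSError:
    # spaCy 3.x removed shortcuts; load the model package by name
    spacy_eng = spacy.load("en_core_web_sm")
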
def tokenize_eng(text):
    return [tok.text for tok in spacy_eng.tokenizer(text)]

from configuration import Config
from models import caption 
from datasets import coco


def load_model():
    model, _ = caption.build_model(config)

    # load weights

    model.backbone.load_state_dict(torch.load("checkpoints/checkpoint-breakdown/backbone.pth", map_location='cpu'))
    model.input_proj.load_state_dict(torch.load("checkpoints/checkpoint-breakdown/input_proj.pth", map_location='cpu'))
    model.transformer.load_state_dict(torch.load("checkpoints/checkpoint-breakdown/transformer.pth", map_location='cpu'))
    model.mlp.load_state_dict(torch.load("checkpoints/checkpoint-breakdown/mlp.pth", map_location='cpu'))
    model.to("cuda" if torch.cuda.is_available() else "cpu")

    return model


def create_caption_and_mask(start_token, max_length):
    caption_template = torch.zeros((1, max_length), dtype=torch.long)
    mask_template = torch.ones((1, max_length), dtype=torch.bool)

    caption_template[:, 0] = start_token
    mask_template[:, 0] = False

    return caption_template, mask_template


@torch.no_grad()
def evaluate(image, caption, cap_mask):
    # Greedy decoding: fill the caption one token at a time
    model.eval()
    for i in range(config.max_position_embeddings - 1):
        predictions = model(image, caption, cap_mask)
        # Scores for the token that follows position i
        predictions = predictions[:, i, :]
        predicted_id = torch.argmax(predictions, dim=-1)

        caption[:, i+1] = predicted_id[0]

        # 3 is assumed to be the id of <eos> in this vocabulary
        if predicted_id[0] == 3:
            return caption

        cap_mask[:, i+1] = False

    return caption



def decode_caption(output, end_token):
    sentence = []
    for idx in output:
        if idx == end_token:
            break
        word = english.vocab.itos[idx]
        sentence.append(word)
    # Remove <sos> from sentence
    sentence = " ".join(sentence[1:])
    return sentence


def predict_nb(image, model):
    image = coco.val_transform(image)
    image = image.unsqueeze(0)

    start_token = english.vocab.stoi["<sos>"]
    end_token = english.vocab.stoi["<eos>"]

    caption, cap_mask = create_caption_and_mask(start_token, config.max_position_embeddings)
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    image = image.to(device)
    caption = caption.to(device)
    cap_mask = cap_mask.to(device)

    output = evaluate(image, caption, cap_mask)
    output = output.tolist()[0]
    output = decode_caption(output, end_token)
    output = output.capitalize()
    return output
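
The `evaluate` and `decode_caption` functions above implement greedy decoding: at each step the highest-scoring token is written into the next caption slot, and generation stops at the end token. The sketch below reproduces that control flow in plain Python so it can be followed without the model weights. The vocabulary, `SCRIPT`, and `toy_scores` are hypothetical stand-ins and are not part of the repository.

```python
# Hypothetical toy vocabulary; the real one is loaded from english.pth.
ITOS = ["<pad>", "<sos>", "<eos>", "a", "kitchen", "with", "sink"]
STOI = {tok: i for i, tok in enumerate(ITOS)}
SCRIPT = ["a", "kitchen", "with", "a", "sink", "<eos>"]


def toy_scores(caption, step):
    """Stand-in for model(image, caption, cap_mask)[:, step, :].

    Returns one score per vocabulary entry for the token that follows
    position `step`; it simply favours the next word of SCRIPT.
    """
    scores = [0.0] * len(ITOS)
    scores[STOI[SCRIPT[min(step, len(SCRIPT) - 1)]]] = 1.0
    return scores


def greedy_decode(max_length):
    # Same idea as create_caption_and_mask: every slot starts as filler,
    # and the mask is True wherever a position is still unfilled.
    caption = [STOI["<pad>"]] * max_length
    mask = [True] * max_length
    caption[0] = STOI["<sos>"]
    mask[0] = False

    for i in range(max_length - 1):
        scores = toy_scores(caption, i)
        # argmax over the vocabulary
        predicted_id = max(range(len(scores)), key=scores.__getitem__)
        caption[i + 1] = predicted_id
        if predicted_id == STOI["<eos>"]:
            break
        mask[i + 1] = False
    return caption, mask


def decode(ids):
    words = []
    for idx in ids[1:]:  # skip <sos>
        if idx == STOI["<eos>"]:
            break
        words.append(ITOS[idx])
    return " ".join(words)


caption, mask = greedy_decode(max_length=10)
print(decode(caption).capitalize())  # A kitchen with a sink
```

As in the notebook's loop, the slot holding the end token keeps its mask value, and every slot after it is left as filler.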

In [7]:
config = Config()
model = load_model()
english = torch.load('english.pth')

To select an image of your choice, follow these steps:   
- Open [Google Images](https://images.google.com/) and search for any one of these: bedroom, kitchen, living room, bathroom.   
- Click on the image of your choice and select `Copy image address`.   
- Run the cell below and paste the link in the input box when prompted.  

*To test multiple images, rerun only the cell below as many times as you'd like with different image addresses.*

<img src="./images/copy_image_address.jpg" alt="drawing" width="500"/>

In [4]:
print("Please enter the image address below:")
url = input()
image = Image.open(requests.get(url, stream=True).raw).resize((640,480)).convert('RGB')
print(predict_nb(image, model))
image
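
If the cell above fails with a request error, one possible cause is that the pasted text is not a fetchable web address (for instance a `data:` URI or a bare file name rather than an `http(s)` link). The helper below is a minimal sketch of a check that could run before `requests.get`. The name `is_fetchable_url` is hypothetical and not part of the repository, and the check only inspects the address itself, so it does not confirm that the link actually points to an image.

```python
from urllib.parse import urlparse


def is_fetchable_url(url):
    """Return True if `url` is an http(s) address with a host name."""
    parts = urlparse(url.strip())
    return parts.scheme in ("http", "https") and bool(parts.netloc)


print(is_fetchable_url("https://example.com/kitchen.jpg"))   # True
print(is_fetchable_url("data:image/jpeg;base64,AAAA"))       # False
```

A check like this could guard the `input()` call so that an unusable address produces a clear message instead of a traceback.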