In [15]:
import torch
import clip
from model import ClipGPT2Model
from transformers import AutoTokenizer
from model_config import config


In [16]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines without a usable GPU instead of failing at `.to(device)`.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# _, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
weights_path = config['weights_path']
CPU = torch.device("cpu")

# Build the captioning model from the configured prefix/CLIP lengths.
model = ClipGPT2Model(config['prefix_length'], config['clip_length'])
# Deserialize the checkpoint onto the CPU first; the model is moved to
# `device` only after the weights are loaded.
model.load_state_dict(torch.load(weights_path, map_location=CPU))
model = model.eval()
model = model.to(device)

In [17]:
def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
                  entry_length=67, temperature=1., stop_token: str = '.'):
    """Generate text with length-normalised beam search.

    Args:
        model: Module exposing ``model.gpt`` (called as
            ``model.gpt(inputs_embeds=...)`` and expected to return an object
            with a ``.logits`` attribute) and the token embedding table
            ``model.gpt.transformer.wte``.
        tokenizer: Object providing ``encode(str)`` and ``decode(ids)``.
        beam_size: Number of hypotheses kept alive at every step.
        prompt: Text used to seed generation when ``embed`` is not given.
        embed: Pre-computed prefix embeddings used as the initial input. It is
            expanded to ``beam_size`` along dim 0, so dim 0 is presumably 1 —
            TODO confirm against callers.
        entry_length: Maximum number of tokens to generate.
        temperature: Divisor applied to the logits; non-positive values are
            treated as 1.0.
        stop_token: Text whose first token id marks the end of a hypothesis.

    Returns:
        A list of decoded hypotheses, ordered from highest to lowest average
        log-probability. Empty when ``entry_length`` is 0.

    Raises:
        ValueError: If neither ``prompt`` nor ``embed`` is supplied.
    """
    if embed is None and prompt is None:
        raise ValueError("generate_beam requires either `prompt` or `embed`")

    model.eval()
    stop_token_index = tokenizer.encode(stop_token)[0]
    tokens = None
    scores = None
    device = next(model.parameters()).device
    seq_lengths = torch.ones(beam_size, device=device)
    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
    with torch.no_grad():
        if embed is not None:
            generated = embed
        else:
            tokens = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
            generated = model.gpt.transformer.wte(tokens)
        for _ in range(entry_length):
            outputs = model.gpt(inputs_embeds=generated)
            # Only the distribution over the next token (last position) matters.
            logits = outputs.logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            # log_softmax is the numerically stable form of softmax(...).log().
            logits = logits.log_softmax(-1)
            if scores is None:
                # First step: seed the beams with the top-k continuations of
                # the single initial sequence.
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                # A finished beam may only "emit" token 0 at zero cost, so its
                # score and length stay frozen while other beams keep growing.
                logits[is_stopped] = float('-inf')
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                # Flat index -> (source beam, token id). Explicit floor
                # division avoids the deprecated tensor `//` behaviour.
                vocab_size = scores_sum.shape[1]
                next_tokens_source = torch.div(next_tokens, vocab_size, rounding_mode='floor')
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = (next_tokens % vocab_size).unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                # Undo the length normalisation to recover summed log-probs.
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]
            # squeeze(1) keeps the beam dimension even when beam_size == 1.
            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze(1)).view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)
            is_stopped = is_stopped | next_tokens.eq(stop_token_index).squeeze(1)
            if is_stopped.all():
                break
    if scores is None:
        # entry_length == 0: nothing was generated.
        return []
    scores = scores / seq_lengths
    output_list = tokens.cpu().numpy()
    # Truncate each hypothesis to its recorded length to drop the padding
    # tokens appended after it stopped.
    output_texts = [tokenizer.decode(output[:int(length)])
                    for output, length in zip(output_list, seq_lengths.tolist())]
    order = scores.argsort(descending=True).tolist()
    output_texts = [output_texts[i] for i in order]
    return output_texts

In [18]:
from datasets import Dataset

# Load the validation split from its Arrow file; the records printed in the
# next cell carry 'image', 'caption' and 'clip' fields.
valid_dataset = Dataset.from_file("./data/coca_valid.arrow")

In [19]:
# Print only the first record to inspect the dataset schema, then stop.
for x in valid_dataset:
    print(x)
    break

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=400x400 at 0x7F4769642EE0>, 'caption': 'A bicycle replica with a clock as the front wheel.', 'clip': [0.166259765625, 0.06500244140625, -0.08172607421875, 0.129150390625, -0.336181640625, -0.44921875, 0.11492919921875, 0.467041015625, 0.091064453125, 0.15771484375, 0.385498046875, -0.10076904296875, 0.4599609375, -0.5029296875, 0.439453125, 0.374267578125, 0.402587890625, 0.96875, -0.3759765625, -0.373291015625, -0.39697265625, -0.444091796875, -0.40380859375, -0.08465576171875, 0.1541748046875, 0.1566162109375, -0.1502685546875, -0.2135009765625, 0.363037109375, 0.11151123046875, 0.356201171875, -0.1624755859375, -0.07275390625, 0.4619140625, -0.47802734375, -0.2098388671875, -0.51318359375, -0.15380859375, 0.2171630859375, 0.740234375, -0.037017822265625, -0.39794921875, 0.5791015625, -0.060638427734375, 0.0244140625, -1.5673828125, -0.383544921875, -0.2159423828125, 0.058441162109375, 0.02362060546875, -0.0223541259765

In [20]:
import json

# Load the COCO 2014 validation caption annotations. JSON is UTF-8 by
# specification, so the encoding is pinned rather than left to the locale.
with open("/data/qiaowei/coco2014/annotations/captions_val2014.json", encoding="utf-8") as f:
    data = json.load(f)


In [21]:
# Echo the parsed annotation dict to inspect its top-level structure.
data

{'info': {'description': 'COCO 2014 Dataset',
  'url': 'http://cocodataset.org',
  'version': '1.0',
  'year': 2014,
  'contributor': 'COCO Consortium',
  'date_created': '2017/09/01'},
 'images': [{'license': 3,
   'file_name': 'COCO_val2014_000000391895.jpg',
   'coco_url': 'http://images.cocodataset.org/val2014/COCO_val2014_000000391895.jpg',
   'height': 360,
   'width': 640,
   'date_captured': '2013-11-14 11:18:45',
   'flickr_url': 'http://farm9.staticflickr.com/8186/8119368305_4e622c8349_z.jpg',
   'id': 391895},
  {'license': 4,
   'file_name': 'COCO_val2014_000000522418.jpg',
   'coco_url': 'http://images.cocodataset.org/val2014/COCO_val2014_000000522418.jpg',
   'height': 480,
   'width': 640,
   'date_captured': '2013-11-14 11:38:44',
   'flickr_url': 'http://farm1.staticflickr.com/1/127244861_ab0c0381e7_z.jpg',
   'id': 522418},
  {'license': 3,
   'file_name': 'COCO_val2014_000000184613.jpg',
   'coco_url': 'http://images.cocodataset.org/val2014/COCO_val2014_000000184613.

In [22]:
import numpy as np

# Convert the validation dataset to a pandas DataFrame for the column-wise
# access used in the following cells.
df = valid_dataset.to_pandas()

In [23]:
# Display the DataFrame to inspect its columns and row count.
df

Unnamed: 0,image,caption,clip
0,"{'bytes': None, 'path': '/data/qiaowei/coco201...",A bicycle replica with a clock as the front wh...,"[0.166259765625, 0.06500244140625, -0.08172607..."
1,"{'bytes': None, 'path': '/data/qiaowei/coco201...",A black Honda motorcycle parked in front of a ...,"[0.340576171875, 0.326904296875, -0.1204833984..."
2,"{'bytes': None, 'path': '/data/qiaowei/coco201...",A room with blue walls and a white sink and door.,"[0.206298828125, 0.323974609375, -0.0002510547..."
3,"{'bytes': None, 'path': '/data/qiaowei/coco201...",A car that seems to be parked illegally behind...,"[0.285400390625, 0.314208984375, 0.5908203125,..."
4,"{'bytes': None, 'path': '/data/qiaowei/coco201...",A large passenger airplane flying through the ...,"[-0.1751708984375, 0.77783203125, -0.241699218..."
...,...,...,...
202649,"{'bytes': None, 'path': '/data/qiaowei/coco201...",A plate of food and a beverage are on a table.,"[0.215576171875, 0.21337890625, 0.1591796875, ..."
202650,"{'bytes': None, 'path': '/data/qiaowei/coco201...",This is an open faced sandwich with several co...,"[0.215576171875, 0.21337890625, 0.1591796875, ..."
202651,"{'bytes': None, 'path': '/data/qiaowei/coco201...",People eating in a restaurant near wine bottles.,"[-0.0170745849609375, 0.1572265625, -0.1843261..."
202652,"{'bytes': None, 'path': '/data/qiaowei/coco201...",The scissors with black handles are sitting open.,"[-0.1541748046875, 0.29345703125, 0.2028808593..."


In [24]:
# Sanity check on the stored image locations: show the first row's full path
# followed by its bare file name, then stop.
for image in df['image']:
    print(image['path'], image['path'].rsplit('/', 1)[-1], sep='\n')
    break

/data/qiaowei/coco2014/val2014/COCO_val2014_000000203564.jpg
COCO_val2014_000000203564.jpg


In [25]:
import re

def str_to_id(image_name):
    """Extract the numeric COCO image id from a val2014 file name.

    Args:
        image_name: File name (or path) containing
            ``COCO_val2014_<zero-padded id>.jpg``.

    Returns:
        The image id as an ``int``, with the zero padding removed.

    Raises:
        ValueError: If ``image_name`` contains no COCO val2014 file name.
    """
    # Raw string so ``\d`` reaches the regex engine intact; the dot is escaped
    # so only a literal ".jpg" extension matches. ``int`` strips the padding,
    # which also lets an all-zero id parse as 0.
    match = re.search(r"COCO_val2014_(\d+)\.jpg", image_name)
    if match is None:
        raise ValueError(f"not a COCO val2014 image name: {image_name!r}")
    return int(match.group(1))


In [26]:
# Map COCO image id -> CLIP embedding. The frame appears to hold several
# caption rows per image, so an id can occur more than once; as in a plain
# loop, the last row seen for an id wins.
# The comprehension's names stay local, so the imported `clip` module and the
# builtin `id` are no longer shadowed at notebook scope.
id_clip_map = {
    str_to_id(image_info["path"].split("/")[-1]): clip_embedding
    for image_info, clip_embedding in zip(df["image"], df["clip"])
}
    

In [27]:
from tqdm import tqdm

# Caption every image: project its CLIP embedding into the GPT-2 prefix space,
# decode with beam search, and keep the top-ranked hypothesis.
id_caption_map = {}
# dict.items() has a length, so tqdm can show the total without a list copy.
# The loop variable is named `clip_embedding` so the imported `clip` module is
# not shadowed.
for id_, clip_embedding in tqdm(id_clip_map.items()):
    with torch.no_grad():
        prefix = torch.tensor(clip_embedding, dtype=torch.float32, device=device).unsqueeze(0)
        prefix_embed = model.clip_project(prefix).reshape(1, config['prefix_length'], -1)
    generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
    print(generated_text_prefix)
    id_caption_map[id_] = generated_text_prefix

  next_tokens_source = next_tokens // scores_sum.shape[1]
  0%|          | 2/40504 [00:00<2:06:16,  5.35it/s]

A black and white photo of a clock with a clock on it.
A black motorcycle parked on the side of a driveway.


  0%|          | 4/40504 [00:00<2:07:43,  5.28it/s]

A bathroom with a white toilet and a blue towel.
A car sits on a sidewalk next to a parked car.


  0%|          | 5/40504 [00:00<1:57:19,  5.75it/s]

An airplane is taking off from a runway.
A bathroom with a toilet and a sink.


  0%|          | 8/40504 [00:01<1:59:11,  5.66it/s]

A modern kitchen with a stove top and a microwave oven.
A desk with a computer and a keyboard on it.


  0%|          | 10/40504 [00:01<1:56:07,  5.81it/s]

A bathroom with a sink and a marble counter top.
A white toilet sitting in a bathroom.


  0%|          | 12/40504 [00:02<1:59:36,  5.64it/s]

A woman sitting on a bench on a city street.
A box with a box of donuts on top of it.


  0%|          | 13/40504 [00:02<2:07:01,  5.31it/s]

An old car parked on the side of the road.


  0%|          | 15/40504 [00:02<2:04:24,  5.42it/s]

A kitchen with a stove and a sink.
A group of people riding motorcycles down a street.


  0%|          | 17/40504 [00:03<2:03:55,  5.45it/s]

A cat is standing on a toilet bowl.
A couple of people sitting on a bench next to a bench.


  0%|          | 19/40504 [00:03<1:55:59,  5.82it/s]

A man is sitting on a horse in a street.
A row of motorcycles parked on a street.


  0%|          | 21/40504 [00:03<1:54:44,  5.88it/s]

A cat sitting in a bowl on a table.
A man is looking at a toilet in a bathroom.


  0%|          | 22/40504 [00:03<2:03:23,  5.47it/s]

A white plate with a slice of cake on it.


  0%|          | 24/40504 [00:04<2:10:23,  5.17it/s]

A child sitting in a field with a kite in the background.
A bathroom with a toilet and a sink.


  0%|          | 26/40504 [00:04<1:55:18,  5.85it/s]

A military jet is flying through the air.
Two horses are walking in the street.


  0%|          | 27/40504 [00:04<1:49:52,  6.14it/s]

A purse is sitting on a park bench.


  0%|          | 28/40504 [00:05<4:30:24,  2.49it/s]

a close up of a piece of cake on a table                                                        


  0%|          | 30/40504 [00:06<3:22:46,  3.33it/s]

A toilet in a bathroom with a pink toilet seat.
A desk with a computer and a laptop on top of it.


  0%|          | 31/40504 [00:06<2:20:47,  4.79it/s]

A couple of dogs in the back of a car.





KeyboardInterrupt: 

In [None]:
# Persist the generated captions (image id -> caption) as a JSON document.
with open("predict_end.json", "w") as f:
    f.write(json.dumps(id_caption_map))