110 changes: 110 additions & 0 deletions Losses/BLIPLoss.py
@@ -0,0 +1,110 @@
import os
import torch
import torch.nn.functional as F
from Losses.LossInterface import LossInterface
from torchvision import transforms
from util import wget_file
from blip.blip_itm import blip_itm


blip_checkpoint_table = {
"model_base_retrieval_coco": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth",
}


def parse_prompt(prompt):
vals = prompt.rsplit(':', 2)
vals = vals + ['', '1', '-inf'][len(vals):]
return vals[0], float(vals[1]), float(vals[2])
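# Prompts follow the "text:weight:stop" convention; weight defaults to 1 and stop
# to -inf when omitted. For example, parse_prompt("a red barn:2") returns
# ("a red barn", 2.0, float("-inf")) and parse_prompt("a red barn") returns
# ("a red barn", 1.0, float("-inf")).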


class BLIPLoss(LossInterface):
def __init__(self,**kwargs):
super().__init__(**kwargs)
self.blip_model = "model_base_retrieval_coco"

checkpoint_path = f'models/blip_{self.blip_model}.ckpt'
if not os.path.exists(checkpoint_path):
wget_file(blip_checkpoint_table[self.blip_model], checkpoint_path)

self.image_size = 384
self.model = blip_itm(pretrained=checkpoint_path, image_size=self.image_size, vit='base')
self.model.eval()
self.model.requires_grad_(False)
self.model.to(self.device)

self.normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
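# Normalization statistics used by BLIP's (and CLIP's) image preprocessing,
# applied here directly to the cutout tensors.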

@staticmethod
def add_settings(parser):
return parser

def get_loss(self, cur_cutouts, out, args, globals=None, lossGlobals=None):
text_prompts, weights, _ = zip(*[parse_prompt(prompt) for prompt in args.prompts])

text = self.model.tokenizer(
text_prompts,
padding="max_length",
truncation=True,
max_length=35,
return_tensors="pt",
)

text_output_itc = self.model.text_encoder(
text.input_ids.to(self.device),
attention_mask=text.attention_mask.to(self.device),
return_dict=True,
mode="text",
)

text_features = F.normalize(
self.model.text_proj(text_output_itc.last_hidden_state[:, 0, :]),
dim=-1,
)
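# ITC text features: the projected, L2-normalized [CLS] state from the
# text-only encoder pass above.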

max_size = max(cur_cutouts.keys())
images = cur_cutouts[max_size]
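# The cutouts come keyed by size; this retrieval checkpoint expects square
# 384x384 inputs, so resize the selected batch below if needed.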
if images.shape[-2:] != (self.image_size, self.image_size):
images = F.interpolate(
images,
size=self.image_size,
mode="bicubic",
align_corners=False,
)

image_embeds = self.model.visual_encoder(self.normalize(images))

image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
self.device
)
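# ITM term: re-encode the prompt tokens with cross-attention over the image
# patch embeddings, then score match vs. no-match with the ITM head below.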

text_output_itm = self.model.text_encoder(
text.input_ids.to(self.device).repeat(len(images), 1),
attention_mask=text.attention_mask.repeat(len(images), 1).to(
self.device
),
encoder_hidden_states=image_embeds.repeat(len(text_prompts), 1, 1),
encoder_attention_mask=image_atts.repeat(len(text_prompts), 1),
return_dict=True,
)
itm_loss = -(F.softmax(  # softmax as in the original BLIP ITM head; optimizing the raw logit instead would let this term dominate the combined loss
self.model.itm_head(
text_output_itm.last_hidden_state[:, 0, :].to(self.device)
),
dim=1,
)[:, 1] * torch.tensor(weights).repeat(len(images)).to(self.device)).mean()

image_features = F.normalize(
self.model.vision_proj(image_embeds[:, 0, :]), dim=-1
)
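# For unit vectors u, v separated by angle theta, ||u - v|| = 2*sin(theta/2),
# so the 2 * arcsin(||u - v|| / 2)**2 computed below equals theta**2 / 2, i.e.
# half the squared great-circle distance between image and text embeddings.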

spherical_distance_itc = (
(image_features[None, :] - text_features[:, None])
.norm(dim=-1)
.div(2)
.arcsin()
.square()
.mul(2)
.mul(torch.tensor(weights)[:, None].to(self.device))
).mean()
return (spherical_distance_itc + itm_loss) / 2
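The ITC term above chains norm, div, arcsin, square, and mul; the short, self-contained sketch below (not part of the diff, assuming only torch and the standard math module) checks the identity it relies on numerically.

import math
import torch

# Two unit vectors separated by a known angle theta.
theta = 0.7
u = torch.tensor([1.0, 0.0])
v = torch.tensor([math.cos(theta), math.sin(theta)])

# Same chain of ops as in BLIPLoss.get_loss: norm -> /2 -> arcsin -> ^2 -> *2.
dist = (u - v).norm().div(2).arcsin().square().mul(2)

# Equals half the squared angle between u and v.
assert torch.allclose(dist, torch.tensor(theta ** 2 / 2), atol=1e-5)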
Empty file added Losses/__init__.py
Empty file.
Empty file added blip/__init__.py
Empty file.
237 changes: 237 additions & 0 deletions blip/blip.py
@@ -0,0 +1,237 @@
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
'''
import warnings
warnings.filterwarnings("ignore")

import os
from blip.vit import VisionTransformer, interpolate_pos_embed
from blip.med import BertConfig, BertModel, BertLMHeadModel
from transformers import BertTokenizer
import torch
from torch import nn
import torch.nn.functional as F
from urllib.parse import urlparse
from timm.models.hub import download_cached_file


class BLIP_Base(nn.Module):
def __init__(self,
med_config = "",
image_size = 224,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
):
"""
Args:
med_config (str): path to the configuration file for the mixture of encoder-decoder (MED) model
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()

self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)


def forward(self, image, caption, mode):

assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
text = self.tokenizer(caption, return_tensors="pt").to(image.device)

if mode=='image':
# return image features
image_embeds = self.visual_encoder(image)
return image_embeds

elif mode=='text':
# return text features
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
return text_output.last_hidden_state

elif mode=='multimodal':
# return multimodal features
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

text.input_ids[:,0] = self.tokenizer.enc_token_id
output = self.text_encoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)
return output.last_hidden_state
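# Illustrative usage (not part of this PR): with a preprocessed image batch
# `image` of shape [B, 3, H, W] and a matching list of caption strings `caption`:
#   model(image, caption, mode='image')       -> image embeddings [B, num_patches+1, D]
#   model(image, caption, mode='text')        -> token states from the text encoder
#   model(image, caption, mode='multimodal')  -> text states cross-attended to the image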



class BLIP_Decoder(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 384,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
prompt = 'a picture of ',
):
"""
Args:
med_config (str): path to the configuration file for the mixture of encoder-decoder (MED) model
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()

self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_decoder = BertLMHeadModel(config=med_config)

self.prompt = prompt
self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
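# tokenizer(prompt).input_ids is [CLS] + prompt tokens + [SEP]; dropping the
# trailing [SEP] gives the number of leading positions masked out of the
# language-modeling targets in forward() below.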


def forward(self, image, caption):

image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)

text.input_ids[:,0] = self.tokenizer.bos_token_id

decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
decoder_targets[:,:self.prompt_length] = -100

decoder_output = self.text_decoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
labels = decoder_targets,
return_dict = True,
)
loss_lm = decoder_output.loss

return loss_lm

def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
image_embeds = self.visual_encoder(image)

if not sample:
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)

image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}

prompt = [self.prompt] * image.size(0)
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
input_ids[:,0] = self.tokenizer.bos_token_id
input_ids = input_ids[:, :-1]

if sample:
#nucleus sampling
outputs = self.text_decoder.generate(input_ids=input_ids,
max_length=max_length,
min_length=min_length,
do_sample=True,
top_p=top_p,
num_return_sequences=1,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=1.1,
**model_kwargs)
else:
#beam search
outputs = self.text_decoder.generate(input_ids=input_ids,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=repetition_penalty,
**model_kwargs)

captions = []
for output in outputs:
caption = self.tokenizer.decode(output, skip_special_tokens=True)
captions.append(caption[len(self.prompt):])
return captions
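# Illustrative usage (not part of this PR):
#   captions = model.generate(image, sample=False, num_beams=3, max_length=30)
# returns one caption string per image, with the leading 'a picture of ' prompt
# stripped by the slice above.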


def blip_decoder(pretrained='',**kwargs):
model = BLIP_Decoder(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model

def blip_feature_extractor(pretrained='',**kwargs):
model = BLIP_Base(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model

def init_tokenizer():
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
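# [DEC] replaces [CLS] as the decoder's start token, and enc_token_id is what
# callers write into input_ids[:, 0] before multimodal (ITM-style) encoding,
# as in BLIP_Base.forward above.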
return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):

assert vit in ['base', 'large'], "vit parameter must be base or large"
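# Note: `0 or drop_path_rate` resolves to the caller's drop_path_rate for the
# base model, while `0.1 or drop_path_rate` always resolves to 0.1 for the
# large model; both expressions are kept as in the upstream BLIP code.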
if vit=='base':
vision_width = 768
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0 or drop_path_rate
)
elif vit=='large':
vision_width = 1024
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0.1 or drop_path_rate
)
return visual_encoder, vision_width

def is_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")

def load_checkpoint(model,url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
checkpoint = torch.load(cached_file, map_location='cpu')
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location='cpu')
else:
raise RuntimeError('checkpoint url or path is invalid')

state_dict = checkpoint['model']
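# Resize the ViT positional embedding if the checkpoint was trained at a
# different image size, then drop any tensors whose shapes still do not match
# before loading non-strictly.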

state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
model.visual_encoder_m)
for key in model.state_dict().keys():
if key in state_dict.keys():
if state_dict[key].shape!=model.state_dict()[key].shape:
del state_dict[key]

msg = model.load_state_dict(state_dict,strict=False)
print('load checkpoint from %s'%url_or_filename)
return model,msg
