Fix BLIP2 mixed precision on CPU (salesforce#179)

* Allow both the "cpu" string and torch.device("cpu") to be recognized when loading models (a usage sketch follows the changed-files summary below).

* Make BLIP-2 automatic mixed precision (AMP) compatible with CPU execution.

* Use dtype=float16 by default in the new autocast helper.
dxli94 committed Mar 6, 2023
1 parent a557de5 commit baad2d7
Showing 6 changed files with 61 additions and 44 deletions.
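A minimal usage sketch (not part of the diff): after this change, load_model_and_preprocess accepts either the string "cpu" or a torch.device("cpu") and casts the model back to float32 in both cases. The model name and type below are illustrative placeholders, not taken from this commit.

```python
import torch
from lavis.models import load_model_and_preprocess  # the function changed in lavis/models/__init__.py

# Previously only the literal string "cpu" reached the model.float() branch;
# a torch.device("cpu") argument skipped it.
device = torch.device("cpu")
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="blip2_feature_extractor",  # placeholder model name (illustrative only)
    model_type="pretrain",           # placeholder model type (illustrative only)
    is_eval=True,
    device=device,
)
```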
3 changes: 2 additions & 1 deletion lavis/models/__init__.py
@@ -6,6 +6,7 @@
"""

import logging
import torch
from omegaconf import OmegaConf
from lavis.common.registry import registry

@@ -211,7 +212,7 @@ def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
"""
)

if device == "cpu":
if device == "cpu" or device == torch.device("cpu"):
model = model.float()

return model.to(device), vis_processors, txt_processors
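The fix above compares the device argument against both the string and the torch.device form. A hedged alternative sketch, not what the commit does: normalize the argument once with torch.device(), which accepts either a string or an existing device, and branch on its .type.

```python
import torch

def is_cpu(device):
    # torch.device() accepts a string or a torch.device, so "cpu" and
    # torch.device("cpu") normalize to the same thing.
    return torch.device(device).type == "cpu"

print(is_cpu("cpu"), is_cpu(torch.device("cpu")), is_cpu("cuda:0"))
# expected: True True False
```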
30 changes: 21 additions & 9 deletions lavis/models/blip2_models/blip2.py
@@ -4,6 +4,7 @@
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import contextlib
import logging
import os
import time
@@ -32,6 +33,16 @@ def init_tokenizer(cls):
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
return tokenizer

def maybe_autocast(self, dtype=torch.float16):
# if on cpu, don't use autocast
# if on gpu, use autocast with dtype if provided, otherwise use torch.float16
enable_autocast = self.device != torch.device("cpu")

if enable_autocast:
return torch.cuda.amp.autocast(dtype=dtype)
else:
return contextlib.nullcontext()

@classmethod
def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
@@ -42,7 +53,7 @@ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
encoder_config.query_length = num_query_token
Qformer = BertLMHeadModel.from_pretrained(
"bert-base-uncased", config=encoder_config
)
)
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size)
)
@@ -52,16 +63,17 @@ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
@classmethod
def init_vision_encoder(
cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
):
assert model_name in ["eva_clip_g","clip_L"], "vit model must be eva_clip_g or clip_L"
if model_name=="eva_clip_g":
):
assert model_name in [
"eva_clip_g",
"clip_L",
], "vit model must be eva_clip_g or clip_L"
if model_name == "eva_clip_g":
visual_encoder = create_eva_vit_g(
img_size, drop_path_rate, use_grad_checkpoint, precision
)
elif model_name=="clip_L":
visual_encoder = create_clip_vit_L(
img_size, use_grad_checkpoint, precision
)
elif model_name == "clip_L":
visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision)
ln_vision = LayerNorm(visual_encoder.num_features)
return visual_encoder, ln_vision

@@ -80,7 +92,7 @@ def load_from_pretrained(self, url_or_filename):

msg = self.load_state_dict(state_dict, strict=False)

logging.info("Missing keys {}".format(msg.missing_keys))
# logging.info("Missing keys {}".format(msg.missing_keys))
logging.info("load checkpoint from %s" % url_or_filename)

return msg
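The core of the change is the maybe_autocast() helper added above. A self-contained sketch of the same pattern, under the assumption (implied by the diff) that the model exposes a device property derived from its parameters: on CPU the helper returns a no-op context, so the float32 weights are used directly; on CUDA it wraps the block in torch.cuda.amp.autocast with the requested dtype.

```python
import contextlib
import torch
from torch import nn

class AutocastDemo(nn.Module):
    """Tiny stand-in illustrating the maybe_autocast() pattern from blip2.py."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @property
    def device(self):
        # Assumed: the real Blip2Base derives self.device from its parameters.
        return next(self.parameters()).device

    def maybe_autocast(self, dtype=torch.float16):
        # CPU: no autocast, weights stay float32. CUDA: autocast with `dtype`.
        if self.device != torch.device("cpu"):
            return torch.cuda.amp.autocast(dtype=dtype)
        return contextlib.nullcontext()

    def forward(self, x):
        with self.maybe_autocast():
            return self.proj(x)

model = AutocastDemo()                   # stays on CPU in this sketch
print(model(torch.randn(2, 8)).dtype)    # torch.float32 on CPU; float16 under CUDA autocast
```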
4 changes: 2 additions & 2 deletions lavis/models/blip2_models/blip2_image_text_matching.py
@@ -54,9 +54,9 @@ def forward(self, samples, match_head="itm"):
image = samples["image"]
caption = samples["text_input"]

with torch.cuda.amp.autocast(enabled=(self.device != torch.device("cpu"))):
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image))
image_embeds = image_embeds.float()
image_embeds = image_embeds.float()
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
)
32 changes: 18 additions & 14 deletions lavis/models/blip2_models/blip2_opt.py
@@ -59,7 +59,7 @@ def __init__(
)
if freeze_vit:
for name, param in self.visual_encoder.named_parameters():
param.requires_grad = False
param.requires_grad = False
self.visual_encoder = self.visual_encoder.eval()
self.visual_encoder.train = disabled_train
logging.info("freeze vision encoder")
@@ -95,7 +95,8 @@ def __init__(

def forward(self, samples):
image = samples["image"]
image_embeds = self.ln_vision(self.visual_encoder(image))
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image))
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
)
@@ -138,12 +139,13 @@ def forward(self, samples):
inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

outputs = self.opt_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
return_dict=True,
labels=targets,
)
with self.maybe_autocast():
outputs = self.opt_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
return_dict=True,
labels=targets,
)
loss = outputs.loss

return {"loss": loss}
@@ -177,9 +179,7 @@ def generate(
captions (list): A list of strings of length batch_size * num_captions.
"""
image = samples["image"]
with torch.cuda.amp.autocast(
enabled=(self.device != torch.device("cpu"))
):
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image))
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
@@ -194,7 +194,9 @@
)

inputs_opt = self.opt_proj(query_output.last_hidden_state)
atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(image.device)
atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(
image.device
)

if "prompt" in samples.keys():
prompt = samples["prompt"]
@@ -203,7 +205,9 @@

prompt = [prompt] * image.size(0)

opt_tokens = self.opt_tokenizer(prompt, return_tensors="pt").to(image.device)
opt_tokens = self.opt_tokenizer(prompt, return_tensors="pt").to(
image.device
)
input_ids = opt_tokens.input_ids
attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

@@ -238,7 +242,7 @@ def generate(

@classmethod
def from_config(cls, cfg):
vit_model = cfg.get("vit_model","eva_clip_g")
vit_model = cfg.get("vit_model", "eva_clip_g")
img_size = cfg.get("image_size")
num_query_token = cfg.get("num_query_token")
opt_model = cfg.get("opt_model")
20 changes: 10 additions & 10 deletions lavis/models/blip2_models/blip2_qformer.py
@@ -64,9 +64,9 @@ def __init__(
)
if freeze_vit:
for name, param in self.visual_encoder.named_parameters():
param.requires_grad = False
param.requires_grad = False
self.visual_encoder = self.visual_encoder.eval()
self.visual_encoder.train = disabled_train
self.visual_encoder.train = disabled_train
logging.info("freeze vision encoder")
self.Qformer, self.query_tokens = self.init_Qformer(
num_query_token, self.visual_encoder.num_features, cross_attention_freq
@@ -90,7 +90,7 @@ def __init__(
def forward(self, samples):
image = samples["image"]
text = samples["text_input"]

image_embeds = self.ln_vision(self.visual_encoder(image))
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
@@ -247,7 +247,7 @@ def forward(self, samples):
return_dict=True,
labels=labels,
)

loss_lm = lm_output.loss

return BlipOutput(
@@ -403,9 +403,9 @@ def extract_features(self, samples, mode="multimodal"):
image is not None
), "Image is not provided for mode 'image' or 'multimodal'"
# return query features
with torch.cuda.amp.autocast(enabled=(self.device != torch.device("cpu"))):
with self.maybe_autocast():
image_embeds_frozen = self.ln_vision(self.visual_encoder(image))
image_embeds_frozen = image_embeds_frozen.float()
image_embeds_frozen = image_embeds_frozen.float()
image_atts = torch.ones(
image_embeds_frozen.size()[:-1], dtype=torch.long
).to(self.device)
@@ -443,9 +443,9 @@ def extract_features(self, samples, mode="multimodal"):

elif mode == "multimodal":
# return multimodel query features
with torch.cuda.amp.autocast(enabled=(self.device != torch.device("cpu"))):
with self.maybe_autocast():
image_embeds_frozen = self.ln_vision(self.visual_encoder(image))
image_embeds_frozen = image_embeds_frozen.float()
image_embeds_frozen = image_embeds_frozen.float()
image_atts = torch.ones(
image_embeds_frozen.size()[:-1], dtype=torch.long
).to(self.device)
@@ -482,10 +482,10 @@ def extract_features(self, samples, mode="multimodal"):

@classmethod
def from_config(cls, cfg):
vit_model = cfg.get("vit_model","eva_clip_g")
vit_model = cfg.get("vit_model", "eva_clip_g")
img_size = cfg.get("image_size")
num_query_token = cfg.get("num_query_token")
cross_attention_freq = cfg.get("cross_attention_freq",2)
cross_attention_freq = cfg.get("cross_attention_freq", 2)

drop_path_rate = cfg.get("drop_path_rate", 0)
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
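A hedged end-to-end sketch (not from the diff): with maybe_autocast() in place, extract_features also runs on CPU, where the context manager is a no-op and the frozen ViT executes in float32. The model name, model type, processor key, and image path are placeholders based on common LAVIS usage, not taken from this commit.

```python
import torch
from PIL import Image
from lavis.models import load_model_and_preprocess

device = torch.device("cpu")
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_feature_extractor",  # placeholder name
    model_type="pretrain",           # placeholder type
    is_eval=True,
    device=device,
)

raw = Image.open("example.jpg").convert("RGB")               # placeholder image path
image = vis_processors["eval"](raw).unsqueeze(0).to(device)  # "eval" key assumed from LAVIS conventions

# mode="image" exercises the frozen-ViT + Q-Former path shown in the diff above.
features = model.extract_features({"image": image}, mode="image")
```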
16 changes: 8 additions & 8 deletions lavis/models/blip2_models/blip2_t5.py
@@ -101,7 +101,9 @@ def __init__(

def forward(self, samples):
image = samples["image"]
image_embeds = self.ln_vision(self.visual_encoder(image))

with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image))
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
)
@@ -117,7 +119,7 @@ def forward(self, samples):
inputs_t5 = self.t5_proj(query_output.last_hidden_state)
atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
with self.maybe_autocast(dtype=torch.bfloat16):
input_tokens = self.t5_tokenizer(
samples["text_input"],
padding="longest",
@@ -182,9 +184,8 @@ def generate(
captions (list): A list of strings of length batch_size * num_captions.
"""
image = samples["image"]
enable_autocast = self.device != torch.device("cpu")

with torch.cuda.amp.autocast(enabled=enable_autocast):
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image))
image_embeds = image_embeds.float()
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
@@ -220,7 +221,7 @@ def generate(

encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

with torch.cuda.amp.autocast(enabled=enable_autocast, dtype=torch.bfloat16):
with self.maybe_autocast(dtype=torch.bfloat16):
inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)

@@ -257,7 +258,7 @@ def predict_answers(
**kwargs
):
image = samples["image"]
with torch.cuda.amp.autocast(enabled=(self.device != torch.device("cpu"))):
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image))
image_embeds = image_embeds.float()
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
Expand Down Expand Up @@ -288,8 +289,7 @@ def predict_answers(

encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

device_type = "cuda" if "cuda" in str(self.device) else "cpu"
with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
with self.maybe_autocast(dtype=torch.bfloat16):
inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)

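One detail worth noting in blip2_t5.py: the vision paths use the helper's float16 default, while the T5 language-model blocks call maybe_autocast(dtype=torch.bfloat16). (FLAN-T5 checkpoints are commonly run in bfloat16 rather than float16; that is general background, not something stated in this diff.) A small runnable sketch of the dtype override, with a plain Linear standing in for the T5 blocks:

```python
import contextlib
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
proj = torch.nn.Linear(64, 64).to(device)   # stand-in for a T5 block
x = torch.randn(1, 64, device=device)

# The same guard maybe_autocast() applies, written inline: bf16 autocast on CUDA,
# plain float32 execution on CPU.
ctx = (torch.cuda.amp.autocast(dtype=torch.bfloat16)
       if device != torch.device("cpu") else contextlib.nullcontext())
with ctx:
    out = proj(x)

print(out.dtype)  # torch.bfloat16 under CUDA autocast, torch.float32 on CPU
```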
