In [7]:
import torch
import torch.nn as nn
import random
from transformers import LlamaForCausalLM, LlamaTokenizer
from models.blip2 import Blip2Base

class DetGPT(Blip2Base):
    """BLIP-2 style vision front-end (visual encoder + Q-Former) feeding a LLaMA decoder.

    Images are encoded by ``visual_encoder``/``ln_vision`` and the Q-Former,
    projected into the LLaMA embedding space by ``llama_proj``, optionally
    wrapped in a text prompt containing an ``<ImageHere>`` placeholder, and
    followed by the target text. ``forward`` returns the LLaMA loss computed on
    the target-text tokens only.
    """

    # Marker inside a prompt where the projected image embeddings are spliced in.
    IMAGE_PLACEHOLDER = "<ImageHere>"

    def __init__(self, vit_model, q_former_model, img_size, llama_model, prompt_path=None,
                 prompt_template="", max_txt_len=32, end_sym='\n', low_resource=False,
                 device_8bit=0):
        """Build the vision encoder, Q-Former, LLaMA model and projection layer.

        Args:
            vit_model: Vision-encoder identifier forwarded to ``init_vision_encoder``.
            q_former_model: Checkpoint URL or path loaded via ``load_from_pretrained``.
            img_size: Image size forwarded to ``init_vision_encoder``.
            llama_model: Name or local path of the LLaMA checkpoint and tokenizer.
            prompt_path: Optional text file with one prompt per line. Only lines
                containing ``<ImageHere>`` are kept.
            prompt_template: ``str.format`` template applied to every kept prompt.
            max_txt_len: Maximum number of target-text tokens; longer text is truncated.
            end_sym: String appended to every target text before tokenization.
            low_resource: If true, load LLaMA with ``load_in_8bit`` on ``device_8bit``.
            device_8bit: Device used in the ``device_map`` when ``low_resource`` is set.
        """
        super().__init__()

        self.tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
        # Reuse EOS as the padding token; padded positions are excluded from the
        # loss in ``forward`` by masking ``pad_token_id``.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.visual_encoder, self.ln_vision = self.init_vision_encoder(vit_model, img_size)
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token=32, vision_width=self.visual_encoder.num_features
        )
        self.load_from_pretrained(url_or_filename=q_former_model)

        self.llama_model = LlamaForCausalLM.from_pretrained(
            llama_model,
            torch_dtype=torch.float16,
            load_in_8bit=low_resource,
            device_map={'': device_8bit} if low_resource else None,
        )
        # Maps Q-Former outputs into the LLaMA embedding space.
        self.llama_proj = nn.Linear(self.Qformer.config.hidden_size, self.llama_model.config.hidden_size)

        self.max_txt_len = max_txt_len
        self.end_sym = end_sym

        # Always define the attribute so code paths that consult it never hit
        # AttributeError when no prompt file was supplied.
        self.prompt_list = []
        if prompt_path:
            with open(prompt_path, 'r', encoding='utf-8') as f:
                raw_prompts = f.read().splitlines()
            self.prompt_list = [
                prompt_template.format(p) for p in raw_prompts if self.IMAGE_PLACEHOLDER in p
            ]

    def encode_img(self, image):
        """Encode images into LLaMA-space embeddings.

        Args:
            image: Image tensor accepted by ``visual_encoder``; its ``device`` is
                used for every tensor created here.

        Returns:
            ``(inputs_llama, atts_llama)``: the projected Q-Former outputs and an
            all-ones ``long`` attention mask covering every token except the last
            (hidden) dimension.
        """
        device = image.device
        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=device)

        # One copy of the learned query tokens per image in the batch.
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_llama = self.llama_proj(query_output.last_hidden_state)
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long, device=device)
        return inputs_llama, atts_llama

    def prompt_wrap(self, img_embeds, atts_img, prompt):
        """Surround image embeddings with the embedded text of ``prompt``.

        The text before the single ``<ImageHere>`` placeholder is embedded and
        placed in front of ``img_embeds``, the text after it behind. The same
        prompt is used for every item in the batch.

        Args:
            img_embeds: Image embeddings; dim 0 is the batch and dim 1 the sequence.
            atts_img: Attention mask matching ``img_embeds``.
            prompt: Prompt string containing exactly one ``<ImageHere>``.

        Returns:
            ``(wrapped_img_embeds, wrapped_atts_img)``. The mask is built by
            broadcasting the first column of ``atts_img`` across the new length.

        Raises:
            ValueError: If ``prompt`` does not contain exactly one placeholder.
        """
        if prompt.count(self.IMAGE_PLACEHOLDER) != 1:
            raise ValueError(
                f"prompt must contain exactly one {self.IMAGE_PLACEHOLDER!r} placeholder, got {prompt!r}"
            )

        batch_size = img_embeds.shape[0]
        device = img_embeds.device
        p_before, p_after = prompt.split(self.IMAGE_PLACEHOLDER)

        p_before_tokens = self.tokenizer(p_before, return_tensors="pt", add_special_tokens=False).to(device)
        p_after_tokens = self.tokenizer(p_after, return_tensors="pt", add_special_tokens=False).to(device)
        embed = self.llama_model.model.embed_tokens
        p_before_embeds = embed(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        p_after_embeds = embed(p_after_tokens.input_ids).expand(batch_size, -1, -1)

        wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
        wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
        return wrapped_img_embeds, wrapped_atts_img

    def forward(self, samples):
        """Compute the LLaMA loss for one batch.

        The input sequence is ``[BOS] + (prompt-wrapped) image embeddings + target
        text``; labels are -100 (ignored by the Hugging Face loss) for BOS, the
        image/prompt positions and padding, so only target-text tokens count.

        Args:
            samples: Mapping with ``"image"`` (passed to ``encode_img``),
                ``"text_input"`` (iterable of target strings) and optionally
                ``"task"`` (iterable of task strings used to build the prompt).
                Without ``"task"``, a random prompt from ``prompt_list`` is used
                if any were loaded; otherwise the image is not wrapped.

        Returns:
            ``{"loss": loss}`` as returned by the LLaMA model.
        """
        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)

        if "task" in samples:
            task_prompt = [self.IMAGE_PLACEHOLDER + ' ' + t for t in samples['task']]
            # NOTE(review): a single task prompt is sampled and applied to the
            # whole batch; if tasks differ per sample, images may be paired with
            # another sample's task — confirm this is intended.
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, random.choice(task_prompt))
        elif self.prompt_list:
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, random.choice(self.prompt_list))

        # Read unconditionally: previously this was only bound in the "task"
        # branch, so batches without "task" crashed with UnboundLocalError.
        text_input = samples['text_input']

        text = [t + self.end_sym for t in text_input]
        to_regress_tokens = self.tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False,
        ).to(image.device)
        input_ids = to_regress_tokens.input_ids

        # Exclude padding from the loss.
        targets = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # No loss on the BOS slot (+1) or on the image/prompt positions.
        empty_targets = torch.full(
            (atts_img.shape[0], atts_img.shape[1] + 1), -100, dtype=torch.long, device=image.device
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.full(
            (batch_size, 1), self.tokenizer.bos_token_id, dtype=input_ids.dtype, device=input_ids.device
        )
        bos_embeds = self.llama_model.model.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.llama_model.model.embed_tokens(input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        return {"loss": outputs.loss}

    @classmethod
    def from_config(cls, cfg):
        """Construct a ``DetGPT`` from a mapping-like config.

        Every constructor argument is read with ``cfg.get``; missing keys fall
        back to the defaults written below (note the config key for
        ``img_size`` is ``"image_size"``).
        """
        return cls(
            vit_model=cfg.get("vit_model", "eva_clip_g"),
            q_former_model=cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"),
            img_size=cfg.get("image_size", 224),
            llama_model=cfg.get("llama_model", "./vicuna-7b-v1.5"),
            prompt_path=cfg.get("prompt_path", ""),
            prompt_template=cfg.get("prompt_template", ""),
            max_txt_len=cfg.get("max_txt_len", 32),
            end_sym=cfg.get("end_sym", '\n'),
            low_resource=cfg.get("low_resource", False),
            device_8bit=cfg.get("device_8bit", 0)
        )


  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'detgpt'