In [1]:
# model implementation

from typing import Dict

import torch
from torch import nn
from transformers import CLIPVisionConfig, CLIPVisionModel, GPT2LMHeadModel, SwinModel


class VisionModel(nn.Module):
    """Pretrained Swin backbone followed by a linear projection.

    The backbone is loaded with ``SwinModel.from_pretrained(model_name)`` and
    its pooled output (``hidden_size`` features per image) is mapped to
    ``out_features`` values by ``projection_features``.

    Args:
        model_name: Checkpoint name or path handed to ``from_pretrained``.
        out_features: Output width of the projection layer.
        frozen_backbone: When True, gradients are disabled for every backbone
            parameter. The projection layer always stays trainable.
    """

    def __init__(
        self, model_name: str, out_features: int, frozen_backbone: bool = False
    ):
        super().__init__()
        self.model_name = model_name
        self.model = self._create_model()

        self.output_dimension = self._get_output_dimension()
        self.projection_features = nn.Linear(
            in_features=self.output_dimension,
            out_features=out_features,
        )

        if frozen_backbone:
            # The flag autograd consults is `requires_grad`. Assigning to any
            # other attribute name is silently accepted and freezes nothing.
            for p in self.model.parameters():
                p.requires_grad = False

    def _create_model(self):
        """Load the pretrained backbone identified by ``self.model_name``."""
        return SwinModel.from_pretrained(self.model_name)

    def _get_output_dimension(self):
        """Return the backbone's ``config.hidden_size``."""
        return self.model.config.hidden_size

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into projected embeddings.

        Args:
            pixel_values: Batched input for the backbone; the first dimension
                is taken as the batch size.

        Returns:
            Tensor of shape ``(batch_size, out_features)``.
        """
        batch_size = pixel_values.shape[0]
        embeddings = self.model(pixel_values)

        # Structured model outputs expose the pooled vector as `pooler_output`;
        # a backbone that already returns a plain tensor is used as-is.
        if not isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.pooler_output

        embeddings = embeddings.reshape(batch_size, self.output_dimension)
        embeddings = self.projection_features(embeddings)

        return embeddings


class LanguageModel(GPT2LMHeadModel):
    """GPT-2 LM head model that can consume image-derived token embeddings.

    ``forward`` overwrites selected positions of the token embeddings with
    image embeddings; ``generate`` uses the image embeddings as the whole
    prompt.
    """

    def __init__(self, config):
        super().__init__(config)

        # Width of one token embedding; image embeddings are reshaped to rows
        # of this size before being injected.
        self.n_embd = self.config.n_embd

    def forward(
        self,
        input_ids=None,
        image_token_embeddings=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        image_token_mask=None,
    ):
        """Run GPT-2, optionally replacing masked positions with image embeddings.

        When both ``image_token_embeddings`` and ``image_token_mask`` are given,
        ``input_ids`` is embedded with ``transformer.wte``, the positions where
        the mask is nonzero are overwritten with the image embeddings (reshaped
        to rows of ``n_embd``), and the result is passed on as ``inputs_embeds``
        with ``input_ids`` set to None. If either is missing, both are ignored
        and the call is forwarded unchanged.

        NOTE(review): ``image_token_mask`` is presumably shaped like
        ``input_ids`` with one nonzero entry per image-token row - confirm
        against the data pipeline.

        All remaining arguments are forwarded to ``GPT2LMHeadModel.forward``.
        """
        if image_token_embeddings is not None and image_token_mask is not None:
            inputs_embeds = self.transformer.wte(input_ids)
            ind = image_token_mask.nonzero(as_tuple=True)
            # Reshape so there is one row per image token.
            image_token_embeddings = image_token_embeddings.reshape(-1, self.n_embd)
            inputs_embeds[ind] = image_token_embeddings.type(inputs_embeds.dtype)
            # Only one of input_ids / inputs_embeds is passed on.
            input_ids = None

        return super().forward(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def generate(
        self,
        input_ids=None,
        image_token_embeddings=None,
        max_length=None,
        **generate_kwargs
    ):
        """Generate text, prompting with image embeddings when they are given.

        Args:
            input_ids: Token-id prompt. Used only when
                ``image_token_embeddings`` is None; otherwise it is discarded
                and the image embeddings form the entire prompt.
            image_token_embeddings: Tensor whose first dimension is the batch
                size; it is reshaped to ``(batch_size, -1, n_embd)`` and passed
                to the parent ``generate`` as ``inputs_embeds``.
            max_length: Forwarded to the parent ``generate``.
            **generate_kwargs: Forwarded to the parent ``generate``.
        """
        if image_token_embeddings is None:
            return super().generate(
                input_ids=input_ids, max_length=max_length, **generate_kwargs
            )

        # Any supplied input_ids are intentionally ignored here, so there is no
        # point embedding them first.
        batch_size = image_token_embeddings.shape[0]
        inputs_embeds = image_token_embeddings.reshape(batch_size, -1, self.n_embd)
        return super().generate(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            max_length=max_length,
            **generate_kwargs
        )


class EncoderDecoder(nn.Module):
    """Couples an image encoder with a language-model decoder.

    Args:
        encoder: Module whose call maps ``pixel_values`` to image token
            embeddings and which exposes its backbone as ``.model``.
        decoder: Module providing ``forward`` and ``generate`` that accept an
            ``image_token_embeddings`` keyword.
        config: Mapping with the boolean keys ``"encoder_frozen"`` and
            ``"decoder_frozen"``. A true value disables gradients for the
            encoder backbone (``encoder.model`` only) or for every decoder
            parameter, respectively.
    """

    def __init__(
        self, encoder: "VisionModel", decoder: "LanguageModel", config: Dict[str, bool]
    ):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        if config["encoder_frozen"]:
            # Only the backbone is frozen; other encoder parameters are left
            # untouched.
            for p in self.encoder.model.parameters():
                p.requires_grad = False

        if config["decoder_frozen"]:
            for p in self.decoder.parameters():
                p.requires_grad = False

    def forward(self, *args, **kwargs):
        """Run the decoder, encoding ``pixel_values`` first when supplied.

        A ``pixel_values`` keyword is replaced by
        ``image_token_embeddings=self.encoder(pixel_values)``. ``return_dict``
        defaults to True but an explicit caller value is respected.
        """
        if "pixel_values" in kwargs:
            pixel_values = kwargs.pop("pixel_values")
            kwargs["image_token_embeddings"] = self.encoder(pixel_values)
        # setdefault avoids "got multiple values for keyword argument" when the
        # caller passes return_dict themselves.
        kwargs.setdefault("return_dict", True)
        # Call the module rather than .forward so registered hooks run.
        return self.decoder(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Generate with the decoder, encoding ``pixel_values`` first when supplied.

        ``return_dict_in_generate`` defaults to True but an explicit caller
        value is respected.
        """
        if "pixel_values" in kwargs:
            pixel_values = kwargs.pop("pixel_values")
            # Generation is inference-only, so skip building an autograd graph
            # for the encoder pass.
            with torch.no_grad():
                kwargs["image_token_embeddings"] = self.encoder(pixel_values)
        kwargs.setdefault("return_dict_in_generate", True)
        return self.decoder.generate(*args, **kwargs)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

# Set before any CUDA work happens so only device 3 is visible to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '3'

from transformers import GPT2Config, GPT2Tokenizer

# Checkpoints for the vision backbone and the language model.
# Previously tried backbones: "google/vit-large-patch32-384",
# "laion/CLIP-ViT-L-14-laion2B-s32B-b82K".
vm_name = "microsoft/swin-base-patch4-window12-384-in22k"
lm_name = "gpt2-medium"

# Number of GPT-2 token embeddings produced per image by the encoder projection.
n_image_token = 4

config = GPT2Config.from_pretrained(lm_name)

# Tokenizer: reuse EOS as the padding token and register the image placeholder.
tokenizer = GPT2Tokenizer.from_pretrained(lm_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_tokens(["<IMG>"])

lm = LanguageModel.from_pretrained(lm_name, config=config)
vm = VisionModel(
    model_name=vm_name,
    out_features=n_image_token * config.n_embd,
)

# Grow the embedding matrix so the newly added token has a row.
lm.resize_token_embeddings(len(tokenizer))

freeze_flags = {
    "encoder_frozen": True,
    "decoder_frozen": True,
}
model = EncoderDecoder(encoder=vm, decoder=lm, config=freeze_flags)
model.eval()

Some weights of the model checkpoint at microsoft/swin-base-patch4-window12-384-in22k were not used when initializing SwinModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


EncoderDecoder(
  (encoder): VisionModel(
    (model): SwinModel(
      (embeddings): SwinEmbeddings(
        (patch_embeddings): SwinPatchEmbeddings(
          (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
        )
        (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): SwinEncoder(
        (layers): ModuleList(
          (0): SwinStage(
            (blocks): ModuleList(
              (0): SwinLayer(
                (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
                (attention): SwinAttention(
                  (self): SwinSelfAttention(
                    (query): Linear(in_features=128, out_features=128, bias=True)
                    (key): Linear(in_features=128, out_features=128, bias=True)
                    (value): Linear(in_features=128, out_features=128, bias=True)
                    (dropout): Dropout(p=0.0, inplace=False)
    

In [3]:
# Location of the saved state dict that is loaded into `model` below.
CKPT_PATH = '/home/jaehoon_scatterlab_co_kr/projects/image-to-prompts/results/swin-0414-384-v6/best-model.ckpt'

# map_location="cpu" keeps loading independent of the device the checkpoint was
# saved from; load_state_dict then copies the values into the model's own
# parameters.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
model.load_state_dict(ckpt)

<All keys matched successfully>

In [4]:
import cv2
# Per-channel mean / standard deviation (RGB order) used to normalise images
# after they have been scaled to [0, 1].
IMAGENET_MEAN_RGB = [0.485, 0.456, 0.406]
IMAGENET_STD_RGB = [0.229, 0.224, 0.225]


def process(path, image_size=(224,224)):
    """Load an image file and turn it into a normalised channels-first tensor.

    Steps: read with OpenCV, convert BGR to RGB, resize to ``image_size``
    using area interpolation, scale to [0, 1], normalise with
    ``IMAGENET_MEAN_RGB`` / ``IMAGENET_STD_RGB``, and move channels first.

    Args:
        path: Path of the image file to read.
        image_size: Target size handed to ``cv2.resize`` as its ``dsize``
            argument (OpenCV interprets this as ``(width, height)``).

    Returns:
        ``torch.float`` tensor with the channel axis first (no batch axis).

    Raises:
        OSError: If OpenCV cannot read an image from ``path`` (missing file,
            directory, or unsupported/corrupt format).
    """
    image = cv2.imread(path)
    if image is None:
        # imread signals failure by returning None instead of raising; fail
        # here with the offending path rather than deep inside cvtColor.
        raise OSError(f"cv2.imread could not read an image from {path!r}")

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, image_size, interpolation=cv2.INTER_AREA)
    image = image / 255.0
    image = (image - IMAGENET_MEAN_RGB) / IMAGENET_STD_RGB
    # HWC -> CHW
    image = image.transpose(2, 0, 1)
    pixel_values = torch.tensor(image, dtype=torch.float)

    return pixel_values

In [5]:
# Caption every file in the test folder, preprocessing each image at 384x384.
image_dir = '../datasets/test/'

# Sampling configuration shared by every generate() call in this cell.
sampling_options = dict(
    max_length=32,
    do_sample=True,
    top_p=0.1,
    no_repeat_ngram_size=3,
    repetition_penalty=1.2,
    pad_token_id=tokenizer.eos_token_id,
)

for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
    # Add the batch axis expected by the model.
    batch = process(image_path, image_size=(384,384)).unsqueeze(0)

    generated = model.generate(input_ids=None, pixel_values=batch, **sampling_options)
    caption = tokenizer.decode(generated.sequences[0].tolist(), skip_special_tokens=True)

    print(filename)
    print(caption)
    

a4e1c55a9.png
a young girl with a black and white portrait of the artist, anagrams in red ink on canvas.
 <IMG>  I am not sure what this picture
227ef0887.png
a man with a mask, wearing an old school classic movie poster.
 <IMG>  the great white shark in his bedroom at night on top of michael j
c98f79f71.png
a man with a gun, in the background of an old school movie theatre.
 <IMG>  I love this picture by David Bowie and his girlfriend on screen


f27825b2c.png
a man with a gun, in the background of an old school movie theatre.
 <IMG>  I love this picture by David Bowie and his girlfriend on screen


92e911621.png
a man with a gun, in the background of an old school movie poster.
 <IMG>  I love this picture by David Bowie and his girlfriend at night on
20057f34d.png
a man with a big black beard, wearing an old school classic
 <IMG>  hissing sound effect of the movie "The Great Movie Soundtrack" by John
d8edf2e40.png
a man with a mask, wearing an eyepatch and glasses
 <IMG>  the image o

In [8]:
# Caption every file in the test folder using process()'s default image size.
image_dir = '../datasets/test/'

# Sampling configuration shared by every generate() call in this cell.
sampling_options = dict(
    max_length=32,
    do_sample=True,
    top_p=0.1,
    no_repeat_ngram_size=3,
    repetition_penalty=1.2,
    pad_token_id=tokenizer.eos_token_id,
)

for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
    # Add the batch axis expected by the model.
    batch = process(image_path).unsqueeze(0)

    generated = model.generate(input_ids=None, pixel_values=batch, **sampling_options)
    caption = tokenizer.decode(generated.sequences[0].tolist(), skip_special_tokens=True)

    print(filename)
    print(caption)
    

a4e1c55a9.png
a robot with a mustache and beard, holding an axe in its hand. ink on paper style of the 1800 s vintage magazine illustration by greg rut
227ef0887.png
wood carving of a rose, intricate details and elegant design. high quality product photo by greg rutkowski with detailed background in the style james ens
c98f79f71.png
a portrait of a man with the head and body proportions, face features,of an orangutan wearing ornate jewelry. oil on canvas by greg
f27825b2c.png
a person working at a donut shop, looking up to the sky and seeing an apple in front of them. they are confused because it is not there
92e911621.png
a cute dinosaur eating a slice of pizza in the forest, digital art style by greg rutkowski and james c. lauchner with
20057f34d.png
a hole in the ground, a cave entrance with an upside down pyramid on top of it., digital art by greg rutkowski and james
d8edf2e40.png
a astronaut in a cherry blossom dress, standing on top of an overgrown city street with the moon behin