In [2]:
import torch

from idavatar.unet import UNetModel
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from transformers import CLIPTextModel

from transformers.models.clip.modeling_clip import (
    _expand_mask,
    CLIPTextTransformer,
    CLIPPreTrainedModel,
    CLIPModel,
)

import os

# Local HF-cache snapshot of Realistic_Vision_V5.1_noVAE (UNet / text encoder
# live in subfolders). Set IDAVATAR_MODEL_PATH to point elsewhere instead of
# editing this machine-specific absolute path.
model_path = os.environ.get(
    'IDAVATAR_MODEL_PATH',
    'D:/apps/tools/programming/.cache/huggingface/hub/models--SG161222--Realistic_Vision_V5.1_noVAE/snapshots/9cd4afd23ecbf0348e2c46f4ac712dbf032da73c',
)

In [5]:
class IDAvatarTextEncoder(CLIPPreTrainedModel):
    """CLIP text transformer whose forward pass can optionally drop the causal mask.

    Shares (does not copy) the ``embeddings``, ``encoder`` and
    ``final_layer_norm`` submodules of a ``CLIPTextModel.text_model`` and
    re-implements its forward so the causal self-attention mask can be
    switched off via ``use_causual_mask``.
    """

    # Reuse CLIP's own causal-mask builder as a bound method.
    # NOTE(review): this is a private `transformers` helper that is not present
    # in newer releases (mask construction moved elsewhere) — confirm against
    # the installed version before upgrading.
    _build_causal_attention_mask = CLIPTextTransformer._build_causal_attention_mask

    @staticmethod
    def from_pretrained(model_name_or_path, **kwargs):
        """Load a ``CLIPTextModel`` and wrap its inner ``text_model``.

        Args:
            model_name_or_path: Hub id or local directory accepted by
                ``CLIPTextModel.from_pretrained``.
            **kwargs: Forwarded to ``CLIPTextModel.from_pretrained``
                (e.g. ``subfolder='text_encoder'``).

        Returns:
            An ``IDAvatarTextEncoder`` sharing the loaded submodules.
        """
        model = CLIPTextModel.from_pretrained(model_name_or_path, **kwargs)
        return IDAvatarTextEncoder(model.text_model)

    def __init__(self, text_model):
        """Adopt ``text_model``'s config and submodules (shared by reference)."""
        super().__init__(text_model.config)
        self.config = text_model.config
        self.final_layer_norm = text_model.final_layer_norm
        self.embeddings = text_model.embeddings
        self.encoder = text_model.encoder

    def forward(
        self,
        input_ids,
        use_causual_mask=True,
    ):
        """Encode token ids.

        Args:
            input_ids: Integer token-id tensor; the last dim is the sequence
                length and any leading dims are flattened into the batch.
            use_causual_mask: If True (default), apply CLIP's causal attention
                mask; if False, no causal mask is passed to the encoder.

        Returns:
            Tuple ``(last_hidden_state, pooled_output, *extra)``:
            ``last_hidden_state`` is the final-layer-normed encoder output,
            ``pooled_output`` is its row at the position of the largest token
            id in each sequence, and ``extra`` are any further encoder outputs.
        """
        input_ids = input_ids.view(-1, input_ids.size(-1))
        # Unpack the *flattened* shape; unpacking the raw shape would fail
        # for inputs with more than two dims.
        bsz, seq_len = input_ids.shape

        hidden_states = self.embeddings(input_ids)

        # Always bind the mask: previously use_causual_mask=False raised
        # UnboundLocalError. None means "no causal mask" for the CLIP encoder.
        causal_attention_mask = None
        if use_causual_mask:
            causal_attention_mask = self._build_causal_attention_mask(
                bsz, seq_len, hidden_states.dtype
            ).to(hidden_states.device)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            causal_attention_mask=causal_attention_mask,
        )

        last_hidden_state = self.final_layer_norm(encoder_outputs[0])

        # Pool at the EOT token (it has the highest id in each sequence).
        # Cast to torch.int for ONNX compatibility: argmax doesn't support
        # int64 inputs with opset 14.
        eot_positions = input_ids.to(
            dtype=torch.int, device=last_hidden_state.device
        ).argmax(dim=-1)
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            eot_positions,
        ]

        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
    
text_encoder = IDAvatarTextEncoder.from_pretrained(model_path, subfolder='text_encoder')
# Seeded local generator: reproducible token ids without touching global RNG state.
input_ids = torch.randint(
    low=1, high=400, size=(2, 77), generator=torch.Generator().manual_seed(0)
)
# Shape smoke test only — no autograd graph needed.
with torch.no_grad():
    a, b = text_encoder(input_ids)  # a: last hidden state, b: pooled output
a.shape, b.shape

(torch.Size([2, 77, 768]), torch.Size([2, 768]))

In [3]:
import torchvision.transforms as T
import torch.nn.functional as F
import open_clip
import torch.nn as nn
# clip_path = 'D:/apps/tools/programming/.cache/huggingface/hub/models--openai--clip-vit-large-patch14/snapshots/8d052a0f05efbaefbc9e8786ba291cfdf93e5bff'

class IDAvatarCLIPImageEncoder(nn.Module):
    """Wrap an open_clip vision tower to return pooled and patch-token features.

    Inputs are resized to the tower's native resolution and normalized with
    the CLIP image mean/std before being encoded.
    """

    @staticmethod
    def from_pretrained(
        global_model_name_or_path,
        model_name='ViT-L-14',
    ):
        """Build the encoder from an open_clip checkpoint.

        Args:
            global_model_name_or_path: ``pretrained`` argument for
                ``open_clip.create_model_and_transforms`` (a pretrained tag such
                as ``'datacomp_xl_s13b_b90k'`` or a checkpoint path).
            model_name: open_clip architecture name. Defaults to ``'ViT-L-14'``.
                NOTE(review): assumes ``model.visual`` supports
                ``output_tokens`` (true for open_clip's VisionTransformer).

        Returns:
            An ``IDAvatarCLIPImageEncoder`` in eval mode.
        """
        model, _, _ = open_clip.create_model_and_transforms(
            model_name, global_model_name_or_path
        )
        vision_model = model.visual
        # Make the tower return (pooled, tokens) rather than only pooled.
        vision_model.output_tokens = True
        # open_clip returns models in train mode; use inference mode like an
        # HF-style from_pretrained would.
        vision_model.eval()
        vision_processor = T.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        )
        return IDAvatarCLIPImageEncoder(
            vision_model,
            vision_processor,
        )

    def __init__(
        self,
        vision_model,
        vision_processor,
    ):
        """Store the tower and the per-channel normalizer.

        ``image_size`` is kept as an ``(h, w)`` tuple so it compares directly
        against an input's spatial shape; a bare int is expanded to a square.
        """
        super().__init__()
        self.vision_model = vision_model
        self.vision_processor = vision_processor

        image_size = vision_model.image_size
        if isinstance(image_size, (tuple, list)):
            self.image_size = tuple(image_size)
        else:
            self.image_size = (image_size, image_size)

    def forward(self, person_pixel_values):
        """Encode a batch of images.

        Args:
            person_pixel_values: ``(b, c, h, w)`` image tensor. NOTE(review):
                presumably RGB scaled to [0, 1] since CLIP mean/std
                normalization is applied here — confirm with callers.

        Returns:
            person_embeds: ``(b, 1, d)`` pooled embedding (d = 768 for ViT-L-14).
            patch_features: patch tokens as returned by the tower
                (``(b, 256, 1024)`` for ViT-L-14 at 224x224).
        """
        b, _, h, w = person_pixel_values.shape

        if (h, w) != self.image_size:
            # Resize to the tower's native resolution (224x224 for ViT-L-14).
            person_pixel_values = F.interpolate(
                person_pixel_values, self.image_size, mode="bilinear", antialias=True
            )
        person_pixel_values = self.vision_processor(person_pixel_values)
        person_embeds, patch_features = self.vision_model(person_pixel_values)
        person_embeds = person_embeds.view(b, 1, -1)
        return person_embeds, patch_features
    
image_encoder = IDAvatarCLIPImageEncoder.from_pretrained('datacomp_xl_s13b_b90k')
# Seeded local generator: reproducible images without touching global RNG state.
images = torch.rand((2, 3, 256, 256), generator=torch.Generator().manual_seed(0))
# Shape smoke test only — no autograd graph needed.
with torch.no_grad():
    a, b = image_encoder(images)  # a: pooled embeds, b: patch tokens
a.shape, b.shape

(torch.Size([2, 1, 768]), torch.Size([2, 256, 1024]))

In [11]:
# Run the (expensive) encoder once and index the result instead of twice.
# NOTE(review): a previously saved 1-D output for this cell cannot come from the
# current class (which returns (b, 1, d) embeddings); re-run to refresh it.
with torch.no_grad():
    image_outputs = image_encoder(images)
image_outputs[0].shape, image_outputs[1].shape

(torch.Size([1024]), torch.Size([1024]))