In [1]:
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import torch.distributed.nn
from PIL import Image

  import pynvml  # type: ignore[import]


In [2]:
from utils.models import CLIP, Modality_Mergerv3
from utils.vision_transform import CLIPTransform
from utils.model_utils import text_process
from utils import _tokenizer

In [3]:
class I2I_CNCLIP(nn.Module):
    """Item encoder that merges image, title, category and id features.

    The module wraps the project's ``CLIP`` model and a ``Modality_Mergerv3``
    projection. Three ``nn.Embedding`` tables turn the sex / price-level /
    age-level ids into ``embed_dim``-sized vectors that are handed to the
    merger together with the image, title and category token features.

    Args:
        embed_dim: Passed to ``CLIP`` and used as the width of the id embeddings.
        image_resolution, vision_layers, vision_width, vision_patch_size:
            Vision settings, forwarded positionally to ``CLIP``.
        vocab_size, text_*: Text-encoder settings, forwarded positionally to ``CLIP``.
        tokenizer: Tokenizer object forwarded to ``CLIP``; ``encode_image_text``
            reads the ``[PAD]`` id from ``clip_model.tokenizer.vocab``.
        output_dim: Second argument given to ``Modality_Mergerv3``.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 vocab_size: int,
                 text_attention_probs_dropout_prob: float,
                 text_hidden_act: str,
                 text_hidden_dropout_prob: float,
                 text_hidden_size: int,
                 text_initializer_range: float,
                 text_intermediate_size: int,
                 text_max_position_embeddings: int,
                 text_num_attention_heads: int,
                 text_num_hidden_layers: int,
                 text_type_vocab_size: int,
                 tokenizer=_tokenizer,
                 output_dim=512):
        super().__init__()
        # Arguments are forwarded positionally, in the same order as this signature.
        self.clip_model = CLIP(embed_dim,
                               # vision
                               image_resolution,
                               vision_layers,
                               vision_width,
                               vision_patch_size,
                               # text
                               vocab_size,
                               text_attention_probs_dropout_prob,
                               text_hidden_act,
                               text_hidden_dropout_prob,
                               text_hidden_size,
                               text_initializer_range,
                               text_intermediate_size,
                               text_max_position_embeddings,
                               text_num_attention_heads,
                               text_num_hidden_layers,
                               text_type_vocab_size,
                               tokenizer)
        self.embed_dim = embed_dim
        self.mm_projection = Modality_Mergerv3(embed_dim, output_dim, layer_num=1)

        # Lookup tables for the categorical ids: 5 sex ids, 9 price levels, 8 age levels.
        self.sex_id_embedding = nn.Embedding(5, embed_dim)
        self.price_id_embedding = nn.Embedding(9, embed_dim)
        self.age_id_embedding = nn.Embedding(8, embed_dim)

        # Pre-processing helpers used by process_data().
        # Note: self.tokenizer is the text_process callable, not the `tokenizer` argument.
        self.transform = CLIPTransform()
        self.tokenizer = text_process(context_length=80, mlm_probability=0.0)
        self.initialize_parameters()

    def initialize_parameters(self):
        """Hook called at the end of ``__init__``; currently a no-op."""
        pass

    def process_data(self, image, text, cate, sex_c2c, price_level_c2c, age_level_c2c):
        """Turn one raw sample into model inputs with a leading axis of size 1.

        Args:
            image: Object with a ``convert`` method (e.g. a PIL image); it is
                converted to RGB and passed through ``self.transform``.
            text: Item title, passed to ``self.tokenizer``.
            cate: Category text, passed to ``self.tokenizer``.
            sex_c2c, price_level_c2c, age_level_c2c: Ids wrapped with ``torch.tensor``.

        Returns:
            Tuple ``(image, text, cate, sex_c2c, price_level_c2c, age_level_c2c)``
            where every element has had ``unsqueeze(0)`` applied.
        """
        image = self.transform(image.convert("RGB")).unsqueeze(0)
        text = self.tokenizer(text).unsqueeze(0)
        cate = self.tokenizer(cate).unsqueeze(0)
        sex_c2c = torch.tensor(sex_c2c).unsqueeze(0)
        price_level_c2c = torch.tensor(price_level_c2c).unsqueeze(0)
        age_level_c2c = torch.tensor(age_level_c2c).unsqueeze(0)
        return image, text, cate, sex_c2c, price_level_c2c, age_level_c2c

    def encode_image_text(self, image, text, cate, sex_c2c, price_level_c2c, age_level_c2c):
        """Encode one processed sample/batch.

        Args:
            image, text, cate, sex_c2c, price_level_c2c, age_level_c2c:
                Inputs in the form produced by :meth:`process_data`.

        Returns:
            Tuple ``(cls_image_feature, cls_text_features, features)``: the first
            outputs of ``clip_model.encode_image`` / ``encode_text`` and the
            concatenation (last axis) of the five tensors returned by
            ``self.mm_projection``.
        """
        cls_image_feature, all_image_feature = self.clip_model.encode_image(image)
        cls_text_features, all_text_features = self.clip_model.encode_text(text)
        _, all_cate_features = self.clip_model.encode_text(cate)

        # Masks are 1 where the token differs from [PAD], cast to the CLIP dtype.
        pad_index = self.clip_model.tokenizer.vocab['[PAD]']
        attn_mask = text.ne(pad_index).type(self.clip_model.dtype)
        cate_attn_mask = cate.ne(pad_index).type(self.clip_model.dtype)

        # For id tensors of shape [bs] this yields [bs, 3, embed_dim].
        c2c_features = torch.stack([self.sex_id_embedding(sex_c2c),
                                     self.price_id_embedding(price_level_c2c),
                                     self.age_id_embedding(age_level_c2c)], dim=1)

        (mm_features_d512, mm_features_d256, mm_features_d128, mm_features_d64,
         mm_features_d32) = self.mm_projection(all_image_feature,
                                               all_text_features, attn_mask,
                                               all_cate_features, cate_attn_mask,
                                               c2c_features
                                               )

        features = torch.cat(
            [mm_features_d512, mm_features_d256, mm_features_d128, mm_features_d64, mm_features_d32],
            dim=-1)

        return cls_image_feature, cls_text_features, features

    def forward(self, data):
        """Run the encoder on already-processed inputs.

        Previously this method was a stub that returned ``None``, so calling
        the module directly produced nothing.

        Args:
            data: Sequence ``(image, text, cate, sex_c2c, price_level_c2c,
                age_level_c2c)`` in the order returned by :meth:`process_data`.

        Returns:
            The tuple returned by :meth:`encode_image_text`.
        """
        return self.encode_image_text(*data)

In [4]:
# Build the encoder; keywords are grouped the same way as the constructor signature.
encoder = I2I_CNCLIP(
    embed_dim=512,
    # vision
    image_resolution=224,
    vision_layers=12,
    vision_width=768,
    vision_patch_size=16,
    # text
    vocab_size=21128,
    text_attention_probs_dropout_prob=0.1,
    text_hidden_act='gelu',
    text_hidden_dropout_prob=0.1,
    text_hidden_size=768,
    text_initializer_range=0.02,
    text_intermediate_size=3072,
    text_max_position_embeddings=512,
    text_num_attention_heads=12,
    text_num_hidden_layers=12,
    text_type_vocab_size=2,
    # merger
    output_dim=256,
)

  import pynvml  # type: ignore[import]


In [10]:
# Sample inputs for a single item.
image = Image.open('test.jpg')
text, cate = 'This is an item title', 'category'
# All three categorical ids use index 0 in this example.
sex_c2c = price_level_c2c = age_level_c2c = 0

In [11]:
# model
# Put the module in evaluation mode: nn.Module defaults to training mode, and the
# encoder was built with dropout probabilities of 0.1, so any dropout layers would
# otherwise make the embeddings differ from run to run.
encoder.eval()
input_data = encoder.process_data(image, text, cate, sex_c2c, price_level_c2c, age_level_c2c)
# Inference only: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    # Tuple of (cls_image_feature, cls_text_features, features).
    mm_multidimension_embed = encoder.encode_image_text(*input_data)

In [12]:
# Display the value returned by encode_image_text: the tuple
# (cls_image_feature, cls_text_features, features).
mm_multidimension_embed

(tensor([[-3.4935e-01,  4.0145e-01, -4.0440e-01,  2.5425e-01, -1.9349e-01,
           7.3589e-01,  1.7450e+00, -1.0388e-02,  1.3663e+00,  7.8796e-03,
           8.3719e-01, -1.9666e+00,  1.3385e-01,  1.5496e-01, -5.6515e-01,
           1.5917e+00, -1.7393e+00,  1.0893e+00,  1.8774e-01, -1.7934e-01,
           1.1228e+00,  7.0981e-02,  4.2332e-02,  2.0804e-01,  5.6650e-02,
           1.0504e-01, -1.2992e-01,  5.9869e-01, -9.6595e-01,  3.9554e-01,
          -1.0446e+00, -1.6147e+00, -1.3114e+00,  8.9954e-01,  1.3597e+00,
          -9.4507e-02,  3.8227e-01,  8.3222e-01, -1.1500e+00, -2.8498e-01,
          -8.6153e-01,  1.8639e+00,  3.6670e-01, -8.0880e-01,  5.9870e-01,
           6.2835e-01, -3.7673e-01,  1.2536e+00,  1.8498e+00,  2.9234e-01,
          -2.9505e-01, -8.4607e-01,  2.7340e-01, -6.0979e-01,  3.4249e-01,
          -1.7489e-01, -1.9689e-01, -2.2345e-01,  1.4845e-01,  1.3392e-01,
          -8.0229e-01, -4.4208e-01,  9.4932e-01, -7.3213e-01, -2.7078e-01,
           1.0461e+00,  3

In [16]:
# NOTE(review): In [10] binds `text` to a str, which has no `.shape`. The recorded
# output implies `text` was rebound to a token tensor in a cell that is not shown
# (execution counts jump from 12 to 16) — confirm this survives Restart & Run All.
text.shape

torch.Size([80])