In [1]:
# GPT generated cross-attention codes
# cross-attention用于在不同模态之间(比如text and image)进行信息交互。
# image embeddings as q to attention with text embeddings as k,v

## Cross_Attention: image to text

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BertModel, BertTokenizer
from torchvision import models, transforms

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### cross_attention

In [4]:
class CrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(CrossAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        
        # q, k, v projection layers
        self.q_proj = nn.Linear(d_model, d_model)  # (2048, 2048)
        self.k_proj = nn.Linear(768, d_model)  
        self.v_proj = nn.Linear(768, d_model)
        
        # output projection layer
        self.out_proj = nn.Linear(d_model, d_model)
        
        # Scaling factor
        self.scale = torch.sqrt(torch.FloatTensor([d_model // num_heads])).to(device)
    
    def forward(self, queries, keys, values, mask=None):  # image, text, text
        batch_size = queries.size(0)
        
        # Linear projections
        queries = self.q_proj(queries)  # (batch_size, seq_len, d_model)
        keys = self.k_proj(keys)
        values = self.v_proj(values)
        
        # Split into multiple heads and transpose
        queries = queries.view(batch_size, -1, self.num_heads, self.d_model // self.num_heads).transpose(1, 2)
        keys = keys.view(batch_size, -1, self.num_heads, self.d_model // self.num_heads).transpose(1, 2)
        values = values.view(batch_size, -1, self.num_heads, self.d_model // self.num_heads).transpose(1, 2)
        
        # Scaled dot-product attention
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply attention weights to values
        attention_output = torch.matmul(attention_weights, values)
        
        # Concatenate multiple heads and put through final linear layer
        attention_output = attention_output.transpose(1,2).contiguous().view(batch_size, -1, self.d_model)  # contiguous表示连续存储。
        output = self.out_proj(attention_output)
        
        return output, attention_weights

### multi-modal

In [5]:
class MultiModalModel(nn.Module):
    def __init__(self, image_model, text_model, d_model, num_heads):
        super(MultiModalModel, self).__init__()
        self.image_model = image_model
        self.text_model = text_model
        self.cross_attention = CrossAttention(d_model, num_heads)
        self.fc = nn.Linear(d_model, 1)   # 假设二分类任务
        
    def forward(self, images, input_ids, attention_mask):
        # image embeddings
        image_features = self.image_model(images)
        print('image_feature_shape:{}'.format(image_features.shape))
        
        # text embeddings
        text_outputs = self.text_model(input_ids, attention_mask=attention_mask)
        text_features = text_outputs.last_hidden_state
        print('text_feature_shape:{}'.format(text_features.shape))
        
        '''
        transformers.BertModel默认输出是BaseModelOutputWithPoolingAndCrossAttention，这是一个包含多个字段命名元组NamedTuple。
        last_hidden_state是bert_model最后一层的隐藏状态, 是一个shape=(batch_size, seq_len, hidden_size)的tensor。
        NamedTuple还包括的参数：
            - pooler_output，表示池化后的输出，这个输出可以用于分类任务。
            - attention_weights,
        '''
        
        # cross attention
        cross_output, attention_weights = self.cross_attention(image_features.unsqueeze(1), text_features, text_features)
        
        # classfication
        output = self.fc(cross_output.squeeze(1))
        return output, attention_weights

### image model

In [6]:
class ImageModel(nn.Module):
    def __init__(self):
        super(ImageModel, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        self.resnet.fc = nn.Identity()  # 移除ResNet模型的最后一层全连接层
        '''
        ResNet模型的最后一层通常是一个全连接层，用于分类任务。对于pretrained ResNet-50模型，这个全连接层的的作用是讲ResNet的output features
        (通常是一个2048维的向量)映射到一个指定数量的类别上。例如，Rest在ImageNet数据集上预训练，最后一层全连接层输出维度是1000，对应于ImageNet
        的1000个类别。在多模态学习中，我们通常不需要最后一层的分类器，而是需要获取iamge embeddings，以便与其他模态(i.e. text)特征进行融合。
        这种情况下，我们需要移除最后一层全连接层，只保留ResNet的特征提取部分。
        nn.Identity作用是torch的一个占位符层，它不改变输入的值，只是简单地返回输入。使用nn.Identity可以方便地移除某一层而不改变模型地其他部分!!!
        '''
    
    def forward(self, x):
        return self.resnet(x)

### text model

In [7]:
text_model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

### data_process

In [8]:
from PIL import Image

In [9]:
# data input
image = Image.open('./data/five.jpg')
text = 'five is a super pretty girl.'

In [10]:
# data preprocess
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [11]:
image = image_transforms(image).unsqueeze(0)
# tokenizer分词对象，来自transformers.BertTokenizer或RobertaTokenizer; truncation=True对长度超过max_length的文本进行截断；
# padding填充到最大长度。
text_tokens = tokenizer(text, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
input_ids = text_tokens['input_ids']
attention_mask = text_tokens['attention_mask']

### running

In [12]:
# configs
d_model = 2048
num_heads = 8

# initialize model
image_model = ImageModel()
multi_modal_model = MultiModalModel(image_model, text_model, d_model, num_heads)



In [13]:
image.shape

torch.Size([1, 3, 224, 224])

In [14]:
input_ids.shape

torch.Size([1, 128])

In [15]:
# feedforward
output, attention_weights = multi_modal_model(image, input_ids, attention_mask)

image_feature_shape:torch.Size([1, 2048])
text_feature_shape:torch.Size([1, 128, 768])


In [16]:
output

tensor([[0.0751]], grad_fn=<AddmmBackward0>)

In [17]:
attention_weights.shape  # cross_attention_weights

torch.Size([1, 8, 1, 128])