In [1]:
# GPT-4 generated code
# torch实现一个结合image和text的多模态model。

## GPT-4 Generation

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

from transformers import BertModel, BertTokenizer
from torchvision import models, transforms

### model_construction

#### image model

In [6]:
class ImageModel(nn.Module):
    """Image encoder wrapping an ImageNet-pretrained torchvision ResNet-50.

    Args:
        as_feature_extractor: When ``False`` (default) the ResNet is left
            intact and ``forward`` returns its 1000 classifier logits. This
            is the width the rest of this notebook is sized for (the
            recorded run logs ``image_features_shape: [1, 1000]``).
            When ``True`` the final ``fc`` classifier is replaced by
            ``nn.Identity`` so ``forward`` returns the pooled backbone
            features instead (2048 values per image for ResNet-50). Any
            layer consuming the output must then be resized to match.
    """

    def __init__(self, as_feature_extractor=False):
        super(ImageModel, self).__init__()
        # NOTE(review): recent torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; kept as-is because the installed version is
        # not visible from this notebook.
        self.resnet = models.resnet50(pretrained=True)
        if as_feature_extractor:
            # The classification head of torchvision's ResNet is the `fc`
            # attribute. The previous code assigned the identity to `fn`,
            # an attribute ResNet.forward never calls, so the head was never
            # actually removed. That dead assignment is gone; the head is now
            # only stripped on request so the default output width is
            # unchanged for existing callers.
            self.resnet.fc = nn.Identity()

    def forward(self, x):
        """Return the ResNet output for the image batch ``x``."""
        return self.resnet(x)

#### text model

In [8]:
class TextModel(nn.Module):
    """Text encoder backed by the pretrained ``bert-base-uncased`` checkpoint."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

    def forward(self, input_ids, attention_mask):
        """Encode a tokenized batch and return BERT's ``pooler_output``.

        Args:
            input_ids: Token ids passed to BERT.
            attention_mask: Mask passed to BERT alongside ``input_ids``.

        Returns:
            The ``pooler_output`` field of the BERT result (the recorded run
            in this notebook logs a shape of ``[1, 768]`` for one sample).
        """
        bert_result = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = bert_result.pooler_output
        return pooled

#### multi-modal integration

In [34]:
class MultiModalModel(nn.Module):
    """Late-fusion classifier: concatenates image and text features and maps
    them to a single logit (binary classification).

    Args:
        image_model: Module mapping an image batch to a 2-D feature tensor.
            Defaults to a freshly constructed ``ImageModel()``.
        text_model: Module called as ``text_model(input_ids, attention_mask)``
            and returning a 2-D feature tensor. Defaults to ``TextModel()``.
        image_dim: Width of the image feature vector. The default of 1000
            matches what the default image encoder produces in the recorded
            run of this notebook.
        text_dim: Width of the text feature vector. The default of 768
            matches what the default text encoder produces in the recorded
            run of this notebook.
        verbose: When ``True`` (default) the intermediate feature shapes are
            printed on every forward pass; set to ``False`` for training
            loops so the output is not flooded.

    ``image_dim + text_dim`` must equal the width of the concatenated
    features, otherwise the final linear layer raises a shape error.
    """

    def __init__(self, image_model=None, text_model=None,
                 image_dim=1000, text_dim=768, verbose=True):
        super(MultiModalModel, self).__init__()
        self.image_model = ImageModel() if image_model is None else image_model
        self.text_model = TextModel() if text_model is None else text_model
        self.verbose = verbose
        # One output unit: a single logit for a binary classification task.
        # The input width is derived from the two encoder widths instead of a
        # hard-coded sum (previously the literal 1768 = 1000 + 768).
        self.fc = nn.Linear(image_dim + text_dim, 1)

    def forward(self, image, input_ids, attention_mask):
        """Return one logit per sample, shape ``(batch, 1)``."""
        image_features = self.image_model(image)
        if self.verbose:
            print('image_features_shape: {}'.format(image_features.shape))
        text_features = self.text_model(input_ids, attention_mask)
        if self.verbose:
            print('text_features_shape: {}'.format(text_features.shape))
        # Fuse by concatenating along the feature axis (dim 1).
        combined_features = torch.cat((image_features, text_features), dim=1)
        if self.verbose:
            print('combined_features_shape: {}'.format(combined_features.shape))
        output = self.fc(combined_features)
        return output

### data_process

In [13]:
# Image preprocessing pipeline: resize to 224x224, convert to a tensor,
# then normalize each of the three channels with the given mean/std.
image_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [14]:
# Tokenizer loaded from the same 'bert-base-uncased' checkpoint name that
# TextModel loads its BERT weights from.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [19]:
def preprocess_data(image, text):
    """Turn one (image, text) pair into model-ready tensors.

    Args:
        image: Input image, passed through ``image_transforms``. PIL images
            that are not already in RGB mode are converted to RGB first.
        text: The text to tokenize with ``tokenizer``.

    Returns:
        A 3-tuple ``(image, input_ids, attention_mask)``. The token tensors
        are padded/truncated to 128 tokens and have the tokenizer's leading
        batch axis removed; none of the three carries a batch dimension, so
        the caller adds one before feeding the model.
    """
    # The Normalize step in image_transforms is configured with 3-channel
    # statistics, so grayscale / RGBA / palette images would fail there.
    # Only objects that expose PIL's `convert` and `mode` are touched.
    if hasattr(image, 'convert') and getattr(image, 'mode', 'RGB') != 'RGB':
        image = image.convert('RGB')
    image = image_transforms(image)
    text_tokens = tokenizer(text, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
    # squeeze(0) removes only the leading batch axis rather than every
    # size-1 axis, making the intent explicit.
    input_ids = text_tokens['input_ids'].squeeze(0)
    attention_mask = text_tokens['attention_mask'].squeeze(0)
    return image, input_ids, attention_mask

### running

In [16]:
# image data
from PIL import Image

In [18]:
# Sample (image, text) pair used for the single-example demo below.
# The image path is relative to the notebook's working directory.
image = Image.open('./data/five.jpg')
text = 'five is a super pretty girl.'

In [20]:
# Preprocess the sample, then prepend a batch axis (size 1) to each tensor
# so the single example can be fed to the model as a batch.
image, input_ids, attention_mask = (
    tensor.unsqueeze(0) for tensor in preprocess_data(image, text)
)

In [35]:
# Initialize the multimodal model (constructs the image and text encoders).
model = MultiModalModel()

In [26]:
# Loss and optimizer.
# BCEWithLogitsLoss takes the raw logit, so the model applies no sigmoid.
# The optimizer captures the parameters of the current `model` object, so
# this cell must be re-run whenever `model` is re-created.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [36]:
# Forward pass on the single preprocessed sample; `output` is the raw logit
# (no sigmoid applied), which is what BCEWithLogitsLoss expects.
output = model(image, input_ids, attention_mask)
print(output)

image_features_shape: torch.Size([1, 1000])
text_features_shape: torch.Size([1, 768])
combined_features_shape: torch.Size([1, 1768])
tensor([[-0.1499]], grad_fn=<AddmmBackward0>)


In [38]:
# Sanity check: the recorded run shows torch.Size([1, 1]), i.e. one logit
# for the one sample in the batch.
output.shape

torch.Size([1, 1])

In [41]:
# Target for the single sample: a float tensor of shape [1] holding 1.0.
label = torch.ones(1)

# Drop the trailing singleton axis of the logits so their shape matches the
# label before computing the loss.
loss = criterion(output.squeeze(-1), label)
print(loss)

tensor(0.7709, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


In [42]:
# Backpropagation and parameter update.
# Gradients accumulate across backward() calls, so clear whatever is left on
# the optimizer's parameters before computing the gradients for this step.
optimizer.zero_grad()
loss.backward()
optimizer.step()

### save_model

In [None]:
# Persist only the learned weights (state_dict) rather than pickling the whole
# module object; reload with model.load_state_dict(torch.load(path)).
# The original line `torch.save(model.)` was incomplete and did not parse.
# NOTE(review): the output path is a placeholder - adjust as needed.
torch.save(model.state_dict(), './multimodal_model.pt')