In [1]:
import os

import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader

from PIL import Image

from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer, TrainerCallback

In [2]:
class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        self.resnet.fc = nn.Identity()

    def forward(self, images):
        features = self.resnet(images)
        return features

In [3]:
def load_semtypes(file_path):
    data = [line.split() for line in open(file_path)]
    data = [[t[0], " ".join(t[1:])] for t in data]
    data = pd.DataFrame(data)
    data.columns = ["id", "semtypes"]
    return data

In [4]:
class CaptioningDataset(Dataset):
    def __init__(self, image_dir, data_file, semtypes_file, tokenizer):
        self.image_dir = image_dir        
        self.data = pd.read_csv(data_file)
        self.semtypes = load_semtypes(semtypes_file)

        # TODO (REMOVE DOWNSIZING)
        self.data = self.data[:int(len(self.data) * 0.001)]
        self.semtypes = self.semtypes[:int(len(self.semtypes) * 0.001)]

        self.tokenizer = tokenizer
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_dir, self.data.loc[index, 'name'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        caption = self.data.loc[index, 'caption']
        semtypes = self.semtypes.loc[index, 'semtypes']
        encoding = self.tokenizer.encode_plus(caption, semtypes, padding='max_length', max_length=128, truncation=True)
#         encoding = self.tokenizer.encode_plus(caption, semtypes, padding='longest')

        return {
            'image': image,
            'input_ids': torch.tensor(encoding['input_ids']),
            'attention_mask': torch.tensor(encoding['attention_mask']),
        }

In [5]:
class VLT5Model(T5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.image_encoder = ImageEncoder()
        self.visual_projection = nn.Linear(2048, config.d_model//6)

    def forward(self, input_ids, attention_mask, image):
        image_features = self.image_encoder(image)
        projected_features = self.visual_projection(image_features)
        
#         print(input_ids)

#         input_ids = torch.cat((projected_features, input_ids[:, 1:]), dim=1)
#         attention_mask = torch.cat((torch.ones_like(projected_features[:, :1]), attention_mask[:, 1:]), dim=1)

#         outputs = super().forward(
#             input_ids=input_ids.long(),
#             attention_mask=attention_mask.long(),
#             decoder_input_ids=input_ids.long(),
#             decoder_attention_mask=attention_mask.long()
#         )

        print(input_ids.shape)
        print(attention_mask.shape)
        print(projected_features.shape)

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_outputs=(projected_features,),
            decoder_input_ids=input_ids,
            decoder_attention_mask=attention_mask,
            return_dict=True
        )

        return outputs.logits

In [6]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = VLT5Model.from_pretrained('t5-base')

Some weights of VLT5Model were not initialized from the model checkpoint at t5-base and are newly initialized: ['image_encoder.resnet.conv1.weight', 'image_encoder.resnet.bn1.weight', 'image_encoder.resnet.bn1.bias', 'image_encoder.resnet.bn1.running_mean', 'image_encoder.resnet.bn1.running_var', 'image_encoder.resnet.layer1.0.conv1.weight', 'image_encoder.resnet.layer1.0.bn1.weight', 'image_encoder.resnet.layer1.0.bn1.bias', 'image_encoder.resnet.layer1.0.bn1.running_mean', 'image_encoder.resnet.layer1.0.bn1.running_var', 'image_encoder.resnet.layer1.0.conv2.weight', 'image_encoder.resnet.layer1.0.bn2.weight', 'image_encoder.resnet.layer1.0.bn2.bias', 'image_encoder.resnet.layer1.0.bn2.running_mean', 'image_encoder.resnet.layer1.0.bn2.running_var', 'image_encoder.resnet.layer1.0.conv3.weight', 'image_encoder.resnet.layer1.0.bn3.weight', 'image_encoder.resnet.layer1.0.bn3.bias', 'image_encoder.resnet.layer1.0.bn3.running_mean', 'image_encoder.resnet.layer1.0.bn3.running_var', 'image_en

In [7]:
train_image_dir = './all_data/train/radiology/images/'
valid_image_dir = './all_data/validation/radiology/images/'
train_data_file = './all_data/train/radiology/traindata.csv'
valid_data_file = './all_data/validation/radiology/valdata.csv'
train_semtypes_file = './all_data/train/radiology/semtypes.txt'
valid_semtypes_file = './all_data/validation/radiology/semtypes.txt'

In [8]:
train_dataset = CaptioningDataset(train_image_dir, train_data_file, train_semtypes_file, tokenizer)
valid_dataset = CaptioningDataset(valid_image_dir, valid_data_file, valid_semtypes_file, tokenizer)

In [9]:
train_dataset.__len__()

65

In [10]:
training_args = TrainingArguments(
    output_dir='./output',
    num_train_epochs=2,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    save_steps=50,
    save_total_limit=2,
    overwrite_output_dir=True,
    learning_rate=1e-4,
    warmup_steps=10,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy='steps',
    eval_steps=50,
    disable_tqdm=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
)

In [11]:
trainer.train()
trainer.save_model('./trained_model')

torch.Size([8, 128])
torch.Size([8, 128])
torch.Size([8, 128])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x128 and 768x768)