In [1]:
import torch
import torchvision
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ImageEncoder(nn.Module):
    def __init__(self, embed_size):
        super(ImageEncoder, self).__init__()
        self.resnet = nn.Sequential(*list(torchvision.models.resnet50(pretrained=True).children())[:-1])
        self.linear = nn.Linear(2048, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        out_linear = self.linear(features)
        embeddings = self.bn(out_linear)
        return embeddings


class SemanticEncoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super(SemanticEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)

    def forward(self, texts):
        embeddings = self.embedding(texts)
        hiddens, _ = self.lstm(embeddings)

        return hiddens[:, -1]


class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        # lstm cell
        self.lstm_cell = nn.LSTMCell(input_size=embed_size + hidden_size, hidden_size=hidden_size)
        # output fully connected layer
        self.fc_out = nn.Linear(in_features=hidden_size, out_features=vocab_size)
        # embedding layer
        self.embed = nn.Embedding(num_embeddings=vocab_size + hidden_size, embedding_dim=embed_size + hidden_size)
        # activations
        self.softmax = nn.Softmax(dim=1)

    def forward(self, features, captions):
        # batch size
        batch_size = features.size(0)
        # init the hidden and cell states to zeros
        hidden_state = torch.zeros((batch_size, self.hidden_size)).to(device)
        cell_state = torch.zeros((batch_size, self.hidden_size)).to(device)
        # define the output tensor placeholder
        outputs = torch.empty((batch_size, captions.size(1), self.vocab_size)).to(device)
        # embed the captions
        captions_embed = self.embed(captions)
        # pass the caption word by word
        for t in range(captions.size(1)):
            # for the first time step the input is the feature vector
            if t == 0:
                hidden_state, cell_state = self.lstm_cell(features, (hidden_state, cell_state))
            # for the 2nd+ time step, using teacher forcer
            else:
                hidden_state, cell_state = self.lstm_cell(captions_embed[:, t, :], (hidden_state, cell_state))
            # output of the attention mechanism
            out = self.fc_out(hidden_state)
            # build the output tensor
            outputs[:, t, :] = out
        return outputs

    def sample(self, features, max_seg_length):
        """Generate captions for given image features using greedy search."""
        sampled_ids = []
        # inputs = features.unsqueeze(1)
        batch_size = features.size(0)
        hidden_state = torch.zeros((batch_size, self.hidden_size)).to(device)
        cell_state = torch.zeros((batch_size, self.hidden_size)).to(device)
        for i in range(max_seg_length):
            # hiddens, states = self.lstm(inputs, states)  # hiddens: (batch_size, 1, hidden_size)
            hidden_state, cell_state = self.lstm_cell(features, (hidden_state, cell_state))
            # outputs = self.linear(hiddens.squeeze(1))  # outputs:  (batch_size, vocab_size)
            outputs = self.fc_out(hidden_state)
            _, predicted = outputs.max(1)  # predicted: (batch_size)
            sampled_ids.append(predicted)
            features = self.embed(predicted)  # inputs: (batch_size, embed_size)
        sampled_ids = torch.stack(sampled_ids, 1)  # sampled_ids: (batch_size, max_seq_length)
        return sampled_ids

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
import torch
from torchvision import transforms
from PIL import Image
import json
from nltk.tokenize import word_tokenize
from torch.utils.data import DataLoader
import os
import pickle
from ipynb.fs.full.vocabulary_builder import Vocabulary
from ipynb.fs.full.data_loader import ROCODataset


class CaptionGenerator:
    def __init__(self, vocab_file, embed_size, hidden_size, num_layers, image_encoder_path, semantic_encoder_path,
                 decoder_path):
        
        with open(vocab_file, 'rb') as f:
            self.vocab = pickle.load(f)
        # convert vocab into index2word format
        vocab_size = 2200

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            map_location = torch.device('cuda')
        else:
            map_location = torch.device('cpu')

        self.image_encoder = ImageEncoder(embed_size).to(self.device)
        self.semantic_encoder = SemanticEncoder(vocab_size, embed_size, hidden_size, num_layers).to(self.device)
        self.decoder = Decoder(vocab_size, embed_size, hidden_size).to(self.device)

        # Load trained weights (assuming you saved them earlier as 'encoder.pth' and 'decoder.pth')
        self.image_encoder.load_state_dict(torch.load(image_encoder_path, map_location=map_location))
        self.semantic_encoder.load_state_dict(torch.load(semantic_encoder_path, map_location=map_location))
        self.decoder.load_state_dict(torch.load(decoder_path, map_location=map_location))

        # Set to evaluation mode
        self.image_encoder.eval()
        self.semantic_encoder.eval()
        self.decoder.eval()

        # Image preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def generate_caption(self, image_path, max_length=20):
        # Load and preprocess the image
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image).unsqueeze(0).to(self.device)

        # Get the feature vector from the encoder
        features_image = self.image_encoder(image)
        # tokens = word_tokenize('CTLA venous phase of donor'.lower())
        caption = [self.vocab('<START>')]
        # caption.extend([self.vocab.get(token, self.vocab['<UNK>']) for token in tokens])
        # caption.append(self.vocab['<END>'])
        print(caption)
        features_text = self.semantic_encoder(torch.Tensor([caption]).long()[:, :torch.Tensor([caption]).shape[1]])

        combined_features = torch.cat((features_image, features_text), dim=1)
        print(combined_features.size())

        # Initialize the LSTM hidden state with the feature vector
        # hidden = combined_features.unsqueeze(0).long()

        # Start with the <START> token
        # print(self.vocab['<START>'])
        # input_token = torch.tensor([[self.vocab['<START>']]], device=self.device).long()
        # print('caption size', torch.Tensor([caption]).long().size())
        # print('input_token', input_token.size())
        gen_caption = []
        predicted_token_ids = self.decoder.sample(combined_features, max_seg_length=20)
        print(predicted_token_ids)
        for id in predicted_token_ids[0].tolist():
            word = self.vocab_i2w[id]
            if word == '<END>':
                break
            gen_caption.append(word)
        # Generate caption
        # for _ in range(max_length):
        #     output = self.decoder(combined_features, input_token)
        #     predicted_token_ids = torch.argmax(output, dim=2)
        #     print('predicted_token_ids', predicted_token_ids.size())
        #     # _, predicted = output.max(1)
        #     # print(predicted.size())
        #     input_token = predicted_token_ids
        #     word = self.vocab_i2w[predicted_token_ids.item()]
        #     if word == '<END>':
        #         break
        #     caption.append(word)
        #
        return ' '.join(gen_caption)

    def process_batch(self, batch_size, max_seg_length, json_data_path='selected_dataset/selected_dataset_info.json'):
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        
        test_dataset = ROCODataset(data_json='selected_dataset/selected_dataset_info.json',
                                transform=transform,
                                vocab=self.vocab,
                                dataset_type='test')

        test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True,
                                collate_fn=ROCODataset.collate_fn,
                                num_workers=2)

        processed_captions = []
        GT_captions = []
        for i, (images_val, GT, _) in enumerate(test_loader):
            caps = []
            for i in range(images_val.size()[0]):
                caps.append([self.vocab('<START>')])
            captions_val = torch.Tensor(caps).long()
            
            images = images_val.to(self.device)
            image_features = self.image_encoder(images)
            text_features = self.semantic_encoder(captions_val[:, :captions_val.shape[1]].to(self.device))
            combined_features = torch.cat((image_features, text_features), dim=1)
            
            # captions_target = captions_val.to(self.device)
            # predictions = torch.argmax(self.decoder(combined_features, captions_target), dim=2)
            
            predictions = self.decoder.sample(combined_features, max_seg_length=max_seg_length)
            
            for prediction in predictions:
                # Convert word_ids to words
                sampled_caption = []
                for word_id in prediction.detach().numpy():
                    word = self.vocab.idx2word[word_id]
                    if word != '<start>':
                        if word != '<unk>':
                            if word != '<end>':
                                sampled_caption.append(word)
                    if word == '<end>':
                        break
                image_caption = ' '.join(sampled_caption)
                processed_captions.append(image_caption)
            
            # getting GTs
            for gt in GT:
                ground_truth_caption = []
                for word_id in gt.detach().numpy():
                    word = self.vocab.idx2word[word_id]
                    if word != '<start>':
                        if word != '<unk>':
                            if word != '<end>':
                                if word != '<pad>':
                                    ground_truth_caption.append(word)
                image_gt_caption = ' '.join(ground_truth_caption)
                GT_captions.append(image_gt_caption)
        return processed_captions, GT_captions

In [13]:
path_image_encoder = 'train/image_encoder.pth'
path_semantic_encoder_path = 'train/semantic_encoder.pth'
path_decoder_path = 'train/decoder.pth'
# Initialize the caption generator and generate caption
caption_generator = CaptionGenerator(vocab_file='vocab.pkl', embed_size=256, hidden_size=256, num_layers=1,
                                        image_encoder_path=path_image_encoder,
                                        semantic_encoder_path=path_semantic_encoder_path,
                                        decoder_path=path_decoder_path)
# image_path = "selected_dataset/test/radiology/images/PMC4803869_GJHS-7-124-g006.jpg"
# caption = caption_generator.generate_caption(image_path)
captions, GTs = caption_generator.process_batch(batch_size=4, max_seg_length=10)
for caption, gt in zip(captions, GTs):
    print(caption, gt)

within within within within within within within within within cystic mass with irregular septations the head of the pancreas ( arrow ) , mimicking the appearance of a cystic .
contours contours contours contours contours contours contours contours ct with multiplanar reconstruction revealed a tumor originating at the lateral and distal part of the trachea and protruding into the
contours contours contours contours contours contours contours contours view on a left femoral hernia . arrows show the internal course of the suture
contours contours contours contours contours contours contours contours the arrow indicates the compression fracture at the level
within within within within within within within within within an ultrasound of the plantar arch was made and found a hypoechoic and homogeneous nodule at the thickness of the plantar fascia with a significant in doppler .
contours contours contours contours contours contours contours contours tissue of margins , with acoustic shadow (