In [1]:
import torchvision.models as models
from torch import nn
import torch
import pandas as pd
import random
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from torchvision.transforms import transforms
from nltk.tokenize import word_tokenize
from string import punctuation
from torchtext.vocab import build_vocab_from_iterator
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from xlstm import (
    xLSTMBlockStack,
    xLSTMBlockStackConfig,
    mLSTMBlockConfig,
    mLSTMLayerConfig,
    sLSTMBlockConfig,
    sLSTMLayerConfig,
    FeedForwardConfig,
)

# Shared xLSTM block-stack configuration. The Decoder class below builds its
# recurrent core from this module-level object via xLSTMBlockStack(cfg).
cfg = xLSTMBlockStackConfig(
    mlstm_block=mLSTMBlockConfig(
        mlstm=mLSTMLayerConfig(
            conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
        )
    ),
    slstm_block=sLSTMBlockConfig(
        slstm=sLSTMLayerConfig(
            backend="vanilla",
            num_heads=4,
            conv1d_kernel_size=4,
            bias_init="powerlaw_blockdependent",
        ),
        feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
    ),
    context_length=256,
    num_blocks=7,
    # NOTE(review): later cells set embed_dim = 256 and hidden_dim = 512 for the
    # Decoder; confirm those agree with the width configured here.
    embedding_dim=128,
    # presumably the block indices that use the sLSTM block — verify against the xlstm docs
    slstm_at=[1],

)

# Only PyTorch's RNG is seeded here; the imported `random` module is not.
torch.manual_seed(42)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

ModuleNotFoundError: No module named 'torch._custom_ops'

In [2]:
class Encoder(nn.Module):
    """ResNet-50 image encoder that projects an image batch to ``embed_dim`` features.

    Args:
        embed_dim: Size of the feature vector produced for every image.
        dropout: Dropout probability applied to the projected features.
        grad: When False (default) the pretrained backbone is frozen; the
            freshly created projection head is always left trainable.
    """

    def __init__(self, embed_dim, dropout = 0.5, grad = False):
        super(Encoder, self).__init__()
        self.resnet = models.resnet50(weights='DEFAULT')

        # Freeze the pretrained weights *before* swapping the head. Doing it
        # afterwards would also freeze the new, randomly initialised Linear
        # layer, leaving the image projection untrained.
        if not grad:
            for param in self.resnet.parameters():
                param.requires_grad = False

        # Replace the ImageNet classifier with a trainable projection head.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_dim)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``dropout(relu(resnet(x)))`` for the image batch ``x``."""
        feature = self.resnet(x)
        return self.dropout(self.relu(feature))
    

In [3]:
class Decoder(nn.Module):
    """Caption generator: image feature + token embeddings -> xLSTM stack -> vocabulary logits.

    Args:
        embed_dim: Width of the token embeddings (and of the encoder features).
        hidden_dim: Kept for interface compatibility. The output layer is now
            sized from ``cfg.embedding_dim`` because that is the width the
            xLSTM stack is configured with.
        vocab_size: Number of tokens in the vocabulary.
        num_layers: Stored on the instance; not used by the xLSTM stack itself.
        device: Device the encoder is moved to.
        encoder: Image encoder module returning ``(batch, embed_dim)`` features.
        dropout: Dropout probability applied to the token embeddings.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers, device, encoder, dropout=0.5):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = xLSTMBlockStack(cfg)

        # The stack consumes and produces tensors of width cfg.embedding_dim.
        # Bridge a differing embedding width with a learned projection instead
        # of failing inside the stack.
        lstm_dim = cfg.embedding_dim
        if embed_dim == lstm_dim:
            self.input_proj = nn.Identity()
        else:
            self.input_proj = nn.Linear(embed_dim, lstm_dim)

        # Size the classifier from the stack width, not from hidden_dim.
        self.linear = nn.Linear(lstm_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.device = device
        self.encoder = encoder.to(device)

    def forward(self, image, caption):
        """Return logits of shape ``(batch, caption_len + 1, vocab_size)``.

        The image feature is prepended as the first sequence element, so the
        output has one more time step than ``caption``.
        """
        features = self.encoder(image)

        embeddings = self.dropout(self.embed(caption))
        embeddings = torch.cat((features.unsqueeze(1), embeddings), dim=1)

        # xLSTMBlockStack.forward returns a single tensor (no state tuple).
        outputs = self.lstm(self.input_proj(embeddings))
        return self.linear(outputs)
    
#     def forward(self, image, captions):
#         features = self.encoder(image)
#         embeddings = self.dropout(self.embed(captions))
#         embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)

#         batch_size = features.size(0)
#         captions_length = captions.size(1)
#         vocab_size = self.linear.out_features

#         outputs = torch.zeros(batch_size, captions_length, vocab_size).to(self.device)
#         input = features.unsqueeze(1)

#         state = None
        
#         for i in range(captions_length):
#             output, state = self.lstm(input, state)
#             output = self.linear(output)
#             outputs[:, i, :] = output.squeeze(1)

#             top1 = output.argmax(2)
#             input = self.dropout(self.embed(top1))

#         return outputs

In [4]:
# Load the caption table (columns `image` and `caption`, one row per caption).
# NOTE(review): this is an absolute Windows path, while the dataset cells further
# down read images from /kaggle/input/... — confirm which environment is intended.
data = pd.read_csv("C:/Users/Admin/Desktop/Desktop/imgcaptioning/flick8k/captions.txt")

In [5]:
data.head()

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...


In [6]:
def clean_text(text, lowercase=False, remove_punc=False, remove_num=False, sos_token='<sos>', eos_token='<eos>'):
    """Normalise a caption and tokenise it, wrapped in start/end markers.

    Args:
        text: Raw caption string.
        lowercase: Lower-case the text first.
        remove_punc: Drop every character found in ``string.punctuation``.
        remove_num: Drop the ASCII digits 0-9.
        sos_token: Marker placed before the tokens.
        eos_token: Marker placed after the tokens.

    Returns:
        ``[sos_token, *tokens, eos_token]`` where ``tokens`` come from
        ``word_tokenize`` applied to the cleaned text.
    """
    if lowercase:
        text = text.lower()

    # Apply each enabled character filter in the same order as the flags.
    filters = ((remove_punc, punctuation), (remove_num, '1234567890'))
    for enabled, unwanted in filters:
        if enabled:
            text = ''.join(filter(lambda ch: ch not in unwanted, text))

    return [sos_token, *word_tokenize(text), eos_token]

In [7]:
clean_text("A cat is sitting on the table.", lowercase=True, remove_punc=True, remove_num=True)

['<sos>', 'a', 'cat', 'is', 'sitting', 'on', 'the', 'table', '<eos>']

In [8]:
# Special vocabulary tokens: unknown word, padding, and caption start/end markers.
unk_token = '<unk>'
pad_token = '<pad>'
sos_token = '<sos>'
eos_token = '<eos>'

In [9]:
clean_cap = data['caption'].apply(lambda x: clean_text(x, lowercase=True, remove_punc=True, remove_num=True))

In [10]:
clean_cap.head()

0    [<sos>, a, child, in, a, pink, dress, is, clim...
1    [<sos>, a, girl, going, into, a, wooden, build...
2    [<sos>, a, little, girl, climbing, into, a, wo...
3    [<sos>, a, little, girl, climbing, the, stairs...
4    [<sos>, a, little, girl, in, a, pink, dress, g...
Name: caption, dtype: object

In [11]:
data['clean_caption'] = clean_cap

In [12]:
data.head()

Unnamed: 0,image,caption,clean_caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...,"[<sos>, a, child, in, a, pink, dress, is, clim..."
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,"[<sos>, a, girl, going, into, a, wooden, build..."
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,"[<sos>, a, little, girl, climbing, into, a, wo..."
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...,"[<sos>, a, little, girl, climbing, the, stairs..."
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...,"[<sos>, a, little, girl, in, a, pink, dress, g..."


In [2]:
# Build the vocabulary from the tokenised captions, registering the four special tokens.
# NOTE(review): the `pip show` output below reports torchtext 0.6.0 — confirm that
# version accepts `specials=` and offers get_itos / set_default_index /
# lookup_tokens, which the following cells rely on.
vocab = build_vocab_from_iterator(clean_cap, specials=[unk_token, pad_token, sos_token, eos_token])

NameError: name 'build_vocab_from_iterator' is not defined

In [19]:
!pip show torchtext

Name: torchtext
Version: 0.6.0
Summary: Text utilities and datasets for PyTorch
Home-page: https://github.com/pytorch/text
Author: PyTorch core devs and James Bradbury
Author-email: jekbradbury@gmail.com
License: BSD
Location: C:\Users\Admin\anaconda3\envs\xlstm\Lib\site-packages
Requires: numpy, requests, sentencepiece, six, torch, tqdm
Required-by: 


In [None]:
vocab.get_itos()[:10]

In [None]:
pad_token_idx = vocab[pad_token]
unk_token_idx = vocab[unk_token]

In [None]:
vocab.set_default_index(unk_token_idx)

In [None]:
# to number
def text_to_number(text, vocab):
    """Map every token in ``text`` to its index by subscripting ``vocab``."""
    indices = []
    for token in text:
        indices.append(vocab[token])
    return indices

In [None]:
to_int = clean_cap.apply(lambda x: text_to_number(x, vocab))

In [None]:
to_int

In [None]:
data['embed_caption'] = to_int

In [None]:
data.head()

In [None]:
vocab.lookup_tokens(data['embed_caption'][0])

In [None]:
# Split by image rather than by caption row. Every image appears in several
# rows (one per caption), so a row-level split would put captions of the same
# picture in both partitions and leak test images into training.
unique_images = data['image'].unique()
train_images, test_images = train_test_split(unique_images, test_size=0.2, random_state=42)

# Re-number rows from 0 so positional and label-based lookups agree.
train = data[data['image'].isin(train_images)].reset_index(drop=True)
test = data[data['image'].isin(test_images)].reset_index(drop=True)

In [None]:
def get_collate_fn(pad_index):
    """Return a collate function that stacks images and pads captions with ``pad_index``."""
    def collate_fn(batch):
        # Every batch element is an (image, caption) pair.
        image_list = [img for img, _ in batch]
        caption_list = [cap for _, cap in batch]

        stacked_images = torch.stack(image_list)
        # Right-pad captions to the longest one in the batch.
        padded_captions = torch.nn.utils.rnn.pad_sequence(
            caption_list, batch_first=True, padding_value=pad_index
        )
        return stacked_images, padded_captions

    return collate_fn

In [None]:
class CustomDataset(Dataset):
    """Image/caption dataset backed by a table with `image` and `embed_caption` columns.

    Args:
        root_dir: Directory containing the image files.
        data: Table (e.g. a DataFrame) providing ``data['image']`` file names
            and ``data['embed_caption']`` token-index sequences.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, root_dir, data, transform=None):
        self.root_dir = root_dir
        self.captions = data['embed_caption']
        self.images = data['image']
        self.transform = transform

    def __len__(self):
        return len(self.captions)

    @staticmethod
    def _at(column, idx):
        """Positional lookup: use ``.iloc`` for pandas objects, plain indexing otherwise."""
        return getattr(column, 'iloc', column)[idx]

    def __getitem__(self, idx):
        """Return ``(image, caption_tensor)`` for the row at position ``idx``."""
        # DataLoader passes positions 0..len-1. Label-based Series indexing
        # would fail (or pick the wrong row) unless the index was reset.
        image_path = os.path.join(self.root_dir, self._at(self.images, idx))
        # Force three channels so grayscale/CMYK files still match a
        # 3-channel Normalize and the ResNet input layer.
        image = Image.open(image_path).convert("RGB")
        caption = torch.tensor(self._at(self.captions, idx))
        if self.transform:
            image = self.transform(image)

        return image, caption

In [None]:
# Preprocessing applied to every image before it reaches the encoder.
transform = transforms.Compose([
    # Bring every image to one fixed spatial size so a batch can be stacked.
    transforms.Resize((256, 256)),
    # PIL image -> float tensor in [0, 1], channels first.
    transforms.ToTensor(),
    # The encoder is an ImageNet-pretrained ResNet-50 whose backbone is frozen
    # by default, so inputs must use the ImageNet channel statistics the
    # weights were trained with instead of a generic 0.5/0.5 scaling.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Both splits read their image files from the same directory.
IMAGES_ROOT = "/kaggle/input/flickr8k/Images"
train_dataset = CustomDataset(root_dir=IMAGES_ROOT, data=train, transform=transform)
test_dataset = CustomDataset(root_dir=IMAGES_ROOT, data=test, transform=transform)

In [None]:
batch_size = 512 
num_workers = 4

In [None]:
# One padding-aware collate function shared by both loaders.
pad_collate = get_collate_fn(pad_token_idx)

# Training batches are reshuffled every epoch.
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=pad_collate)

# Evaluation keeps a fixed order. The validation loss is a mean of per-batch
# means, so reshuffling would change the batch composition and make the
# reported loss (and the best-checkpoint decision) vary between runs.
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=pad_collate)

In [None]:
# Model hyperparameters passed to Encoder / Decoder below.
# NOTE(review): the xLSTM config at the top of the notebook sets
# embedding_dim=128, while embed_dim and hidden_dim here are 256 and 512 —
# confirm these widths are meant to differ.
embed_dim = 256
hidden_dim = 512
vocab_size = len(vocab)
num_layers = 2
dropout = 0.5

In [None]:
# Build the image encoder and hand it to the captioning model, which stores it
# as `model.encoder` and moves it to `device`.
encoder = Encoder(embed_dim, dropout)
model = Decoder(embed_dim, hidden_dim, vocab_size, num_layers, device, encoder, dropout )
model = model.to(device)

In [None]:
# Optimisation settings used by the training cells below.
n_epochs = 100
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Positions holding the padding index do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index = pad_token_idx)
# Maximum gradient norm handed to train_fn for clipping.
clip = 1.0
# NOTE(review): not referenced by train_fn / evaluate_fn in the visible cells —
# confirm whether it is still needed.
teacher_forcing_ratio = 0.5
# Lowest validation loss seen so far; starts at +inf so the first epoch always saves.
best_valid_loss = float("inf")


In [None]:
def train_fn(model, data_loader, optimizer, criterion, clip, device):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    Args:
        model: Callable as ``model(images, captions)`` returning 3-D logits.
        data_loader: Iterable of ``(images, captions)`` batches with a length.
        optimizer: Optimiser stepping the model parameters.
        criterion: Loss taking ``(flat_logits, flat_targets)``.
        clip: Maximum gradient norm.
        device: Device the batch tensors are moved to.
    """
    model.train()
    running_loss = 0

    for images, captions in data_loader:
        images = images.to(device)
        captions = captions.to(device)

        optimizer.zero_grad()

        # The model is fed the caption without its last column; the full
        # caption is the prediction target.
        logits = model(images, captions[:, :-1])
        vocab_dim = logits.shape[2]
        flat_logits = logits.view(-1, vocab_dim).to(device)
        flat_targets = captions.view(-1)

        batch_loss = criterion(flat_logits, flat_targets)
        batch_loss.backward()

        # Clip before stepping to keep the update bounded.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(data_loader)


In [None]:
def evaluate_fn(model, data_loader, criterion, device):
    """Return the mean batch loss over ``data_loader`` without updating the model.

    Args:
        model: Callable as ``model(images, captions)`` returning 3-D logits.
        data_loader: Iterable of ``(images, captions)`` batches with a length.
        criterion: Loss taking ``(flat_logits, flat_targets)``.
        device: Device the batch tensors are moved to.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for images, captions in data_loader:
            images = images.to(device)
            captions = captions.to(device)

            # Same input/target layout as in training: input drops the last column.
            logits = model(images, captions[:, :-1])
            flat_logits = logits.view(-1, logits.shape[2]).to(device)
            flat_targets = captions.view(-1)

            running_loss += criterion(flat_logits, flat_targets).item()

    return running_loss / len(data_loader)


In [None]:
# Main training loop: one training pass and one evaluation pass per epoch,
# with a tqdm progress bar over epochs.
for epoch in tqdm(range(n_epochs)):
    train_loss = train_fn(
        model,
        train_data_loader,
        optimizer,
        criterion,
        clip,
        device)


    # Evaluation runs on the held-out loader; evaluate_fn does no optimiser step.
    valid_loss = evaluate_fn(
        model,
        test_data_loader,
        criterion,
        device,
    )
    # Checkpoint whenever the validation loss improves on the best seen so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), "best-model.pt")
    print(f"\tTrain Loss: {train_loss:7.3f}")
    print(f"\tValid Loss: {valid_loss:7.3f}")
# Also keep the weights from the final epoch, regardless of validation loss.
torch.save(model.state_dict(), "last_model.pt")

In [None]:
def captioning_image(
    image,
    model,
    vocab,
    eos_token,
    transform,
    device,
    max_output_length=25,
):
    """Greedily decode a caption for a single raw image.

    The previous body called ``model.decoder.get_prediction``, which the model
    in this notebook does not define, and repeated the same call in a loop
    without ever decoding. This version uses the model's own forward contract
    ``model(images, caption_prefix) -> (batch, len(prefix) + 1, vocab)`` logits.

    Args:
        image: Raw image accepted by ``transform`` (e.g. a PIL image).
        model: Captioning model callable as ``model(images, captions)``.
        vocab: Vocabulary supporting ``vocab[token]`` and ``lookup_tokens``.
        eos_token: End-of-sentence token string; decoding stops once it is produced.
        transform: Callable turning ``image`` into a ``(C, H, W)`` tensor.
        device: Device to run inference on.
        max_output_length: Upper bound on the number of generated tokens.

    Returns:
        List of generated token strings, including the end token if it was
        reached. Because training targets start with the start-of-sentence
        marker, that marker is normally the first generated token.
    """
    model.eval()
    eos_index = vocab[eos_token]
    generated = []

    with torch.no_grad():
        batch = transform(image).unsqueeze(0).to(device)
        # Start from an empty prefix: the image feature alone yields the first token.
        prefix = torch.empty((1, 0), dtype=torch.long, device=device)

        for _ in range(max_output_length):
            # Re-running the full prefix keeps this independent of any
            # recurrent-state API, at the cost of re-encoding the image each step.
            logits = model(batch, prefix)
            next_token = logits[:, -1, :].argmax(dim=-1)
            generated.append(next_token.item())
            if generated[-1] == eos_index:
                break
            prefix = torch.cat((prefix, next_token.unsqueeze(1)), dim=1)

    output = vocab.lookup_tokens(generated)
    return output

In [None]:
import matplotlib.pyplot as plt

In [None]:
image = Image.open("/kaggle/input/flickr8k/Images/1015118661_980735411b.jpg")

In [None]:
image

In [None]:
def predict_caption(model, image, vocab, max_length=50):
    """Greedily decode a caption for one already-preprocessed image batch.

    The previous body drove ``model.lstm`` like ``nn.LSTM`` (``hidden_size``
    attribute, ``(hidden, cell)`` state tuple). The decoder here is an
    ``xLSTMBlockStack``, so decoding instead re-runs the model on the growing
    token prefix and reads the logits of the last position.

    Args:
        model: Captioning model callable as ``model(images, captions)`` and
            returning logits of shape ``(batch, len(captions) + 1, vocab)``.
        image: Image tensor of shape ``(1, C, H, W)`` on the model's device.
        vocab: Vocabulary supporting ``vocab['<eos>']`` and ``lookup_tokens``.
        max_length: Upper bound on the number of generated tokens.

    Returns:
        List of generated token strings, including ``'<eos>'`` if it was
        reached. Because training targets start with the start-of-sentence
        marker, that marker is normally the first generated token.
    """
    model.eval()
    eos_index = vocab['<eos>']
    caption = []

    with torch.no_grad():
        # Empty prefix: the image feature alone yields the first token.
        prefix = torch.empty((image.size(0), 0), dtype=torch.long, device=image.device)

        for _ in range(max_length):
            logits = model(image, prefix)
            predicted = logits[:, -1, :].argmax(dim=-1)
            # .item() requires a single image in the batch.
            caption.append(predicted.item())
            if caption[-1] == eos_index:
                break
            prefix = torch.cat((prefix, predicted.unsqueeze(1)), dim=1)

    return vocab.lookup_tokens(caption)

In [None]:
image = transform(image)

In [None]:
pred = predict_caption(model, image.unsqueeze(0).to(device), vocab, 20)

In [None]:
pred