In [2]:
import torch
from torch import Tensor
from torch import nn
from torch.nn.functional import softmax
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime
from torch.nn.utils.rnn import pad_sequence
import math
import json
from scipy.optimize import linear_sum_assignment
import os
import spacy
width = 256
height = 256

## Dataset

In [2]:
spacy_en = spacy.load("en_core_web_sm")

class Vocabulary:
    def __init__(self, freq_threshold=1):
        self.itos = {0:'<sos>', 1:'<eos>', 2:'<pad>', 3:'<unk>'}
        self.stoi = {j:i for i,j in self.itos.items()}
        self.freq_threshold = freq_threshold
        self.loss_weight = {0:0,1:0,2:0,3:0}
    
    @staticmethod
    def tokenizer(sentence):
        return [tok.text.lower() for tok in spacy_en.tokenizer(sentence)]
    
    def build_vocab(self, sentences:list):
        frequencies = {}
        for sentence in sentences:
            for word in Vocabulary.tokenizer(sentence):
                if word not in frequencies.keys():
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1
        
        idx = 4
        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.loss_weight[idx] = 100/freq
                self.itos[idx] = word
                self.stoi[word] = idx
                idx += 1
        self.loss_weight[self.stoi['<eos>']] = 100 / len(self.stoi)        
    def numericalize(self,sentence):
        return [self.stoi[word] if word in self.stoi.keys() else self.stoi['<unk>']
                for word in Vocabulary.tokenizer(sentence)]
    
    def __len__(self):
        return len(self.stoi)

In [3]:
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

class ImageCaptionDateset(Dataset):
    def __init__(self,split='train'):
        super(ImageCaptionDateset,self).__init__()

        if split == 'train':
            self.directory = './train/'
            self.images = os.listdir(self.directory)
            with open('./annotations/train.json','rb') as f:
                data = json.load(f)
        else :
            self.directory = './val/'
            self.images = os.listdir(self.directory)
            with open('./annotations/val.json','rb') as f:
                data = json.load(f)
        self.captions = data['annotations']
        
        self.transforms = transforms.Compose([transforms.Resize((height,width)),
                                              transforms.ToTensor(),
                                              transforms.Normalize(mean,std)])
        
        self.vocab = Vocabulary()
        self.vocab.build_vocab([annotation['caption'] for annotation in self.captions])

    def __getitem__(self, index):
        image = Image.open(os.path.join(self.directory,self.images[index]))
        image = self.transforms(image)
        caption = '<sos>'
        for i in range(5*index,5*index+5):
            caption += self.captions[i]['caption']

        caption += '<eos>'
        caption = torch.tensor(self.vocab.numericalize(caption))

        return image, caption
            
    def __len__(self):
        return self.captions[-1]['image_id']+1

In [4]:
class collate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx
    
    def __call__(self, batch:list):
        images = [data[0] for data in batch]
        captions = [data[1] for data in batch]
        
        images = torch.stack(images)
        captions = pad_sequence(captions,True,self.pad_idx)
        return images, captions

In [5]:
batch_size = 32
input_shape = (batch_size,3,height,width)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

trainset = ImageCaptionDateset('train')
#testset = ImageCaptionDateset('test')

trainLoader = DataLoader(trainset,batch_size,shuffle=True,collate_fn=collate(pad_idx=trainset.vocab.stoi['<pad>']))
#testLoader = DataLoader(testset,batch_size,shuffle=True,collate_fn=collate(pad_idx=trainset.vocab.stoi['<pad>']))

## Model

### Swin Transformer

In [6]:
class PatchMerging(nn.Module):
    def __init__(self, in_channel, out_channel,  m=2):
        super().__init__()
        self.m = m
        self.lin = nn.Linear(in_channel* m**2,out_channel)

    def forward(self,x:torch.Tensor):
        # x has shape B x C x h x w
        x = x.unfold(-2,self.m,self.m).unfold(-2,self.m,self.m) # B x C x h/m x w/m x m x m
        x = x.permute(0,2,3,1,4,5).flatten(3,-1) # B x h/m x w/m x c * m^2
        x = self.lin(x)
        return x.permute(0,3,1,2) # B x out_channel x h/m x w/m

class W_MHA(nn.Module):
    def __init__(self,shifted:bool, in_channel,num_head, window_size):
        super().__init__()
        self.shifted = shifted
        self.window_size = window_size
        self.mha = nn.MultiheadAttention(in_channel,num_head,batch_first=True)
    
    def forward(self,x:torch.Tensor):
        # x has shape B x C x h x w
        if not(self.shifted):
            B, C, h, w = x.shape
            x = x.unfold(-2,self.window_size,self.window_size).unfold(-2,self.window_size,self.window_size).reshape((B,C,-1,self.window_size * self.window_size))
            x = x.permute(0,2,3,1).reshape(-1,self.window_size*self.window_size,C)
            x,_ = self.mha(x,x,x.clone())
            return x.reshape(B,h//self.window_size,w//self.window_size,self.window_size,self.window_size,C).permute(0,5,1,3,2,4).reshape(B,C,h,w)
        else:
            padding = self.window_size // 2
            x = F.pad(x,(padding,padding,padding,padding))
            B, C, h, w = x.shape
            x = x.unfold(-2,self.window_size,self.window_size).unfold(-2,self.window_size,self.window_size).reshape((B,C,-1,self.window_size * self.window_size))
            x = x.permute(0,2,3,1).reshape(-1,self.window_size*self.window_size,C)
            x,_ = self.mha(x,x,x.clone(),key_padding_mask= x[:,:,0]==0)
            return x.reshape(B,h//self.window_size,w//self.window_size,self.window_size,self.window_size,C).permute(0,5,1,3,2,4).reshape(B,C,h,w)[:,:,padding:-padding,padding:-padding]


class SwinTransformer(nn.Module):
    def __init__(self, input_shape):
        super(SwinTransformer,self).__init__()
        # block 1
        _, C, h, w = input_shape
        self.LN1_1 = nn.LayerNorm(C)
        self.wmsa1 = W_MHA(False,C,8,8)
        self.LN1_2 = nn.LayerNorm(C)
        self.mlp1 = nn.Sequential(nn.Linear(C,4*C),nn.GELU(),nn.Linear(4*C,C))
        
        self.LN2_1 = nn.LayerNorm(C)
        self.wmsa2 = W_MHA(True,C,8,8)
        self.LN2_2 = nn.LayerNorm(C)
        self.mlp2 = nn.Sequential(nn.Linear(C,4*C),nn.GELU(),nn.Linear(4*C,C))
        

    def forward(self, x:torch.Tensor):
        B, C, h, w = x.shape
        x = x.permute(0,2,3,1) # B x h x w x C
        y = x
        x = self.LN1_1(x)
        x = self.wmsa1(x.permute(0,3,1,2)).permute(0,2,3,1)
        x += y
        y = x
        x = self.LN1_2(x)
        x = self.mlp1(x)
        x += y
        
        y = x
        x = self.LN2_1(x)
        x = self.wmsa2(x.permute(0,3,1,2)).permute(0,2,3,1)
        x += y
        y = x
        x = self.LN2_2(x)
        x = self.mlp2(x)
        x += y
        return x.permute(0,3,1,2)

In [7]:
class featureExtractor(nn.Module):
    def __init__(self,input_shape):
        super().__init__()

        B, c, h, w = input_shape
        self.C = 96
        self.stage1 = nn.Sequential(PatchMerging(c,self.C,4),SwinTransformer((B,self.C,h//4,w//4)))
        
        self.stage2 = nn.Sequential(PatchMerging(self.C,2*self.C,2),SwinTransformer((B,2*self.C,h//8,w//8)))

        self.stage3 = nn.Sequential(PatchMerging(2*self.C,4*self.C,2),SwinTransformer((B,4*self.C,h//16,w//16))
                                    ,SwinTransformer((B,4*self.C,h//16,w//16)),SwinTransformer((B,4*self.C,h//16,w//16)))
        
        self.stage4 = nn.Sequential(PatchMerging(4*self.C,8*self.C,2),SwinTransformer((B,8*self.C,h//32,w//32)))
    
    def forward(self,x:torch.Tensor):
        x = self.stage1(x)
        x1 = x
        x = self.stage2(x)
        x2 = x
        x = self.stage3(x)
        x3 = x
        x4 = self.stage4(x)
        return [x1,x2,x3,x4]

### Deformable DETR Decoder

In [8]:
class Multi_Scale_Cross_Attention(nn.Module):
    def __init__(self, device, hidden_channel=256, num_head=8, layer_num=4, keys_num=4):
        super().__init__()
        
        self.device = device
        self.layer_num = 4
        self.num_head = 8
        self.hidden_channel = hidden_channel
        self.keys_num = keys_num

        self.sample_points = nn.Sequential(nn.Linear(hidden_channel, 4*hidden_channel), nn.GELU(), nn.Linear(4*hidden_channel, layer_num*keys_num*2),nn.Sigmoid())
        self.mlh = nn.MultiheadAttention(hidden_channel,num_head,batch_first=True)
    
    def forward(self, x:torch.Tensor, v:list):
        # x has shape B x N x d
        # v has shape L x [B x Hl x Wl x d]
        feature_maps_shape = torch.tensor([[map.shape[1]-1,map.shape[2]-1] for map in v],device=self.device)
        B, N, d = x.shape
        sample_points = self.sample_points(x).reshape(B,N,self.layer_num,self.keys_num,2)
        sample_points *= feature_maps_shape[None,None,:,None,:]
        keys = self.interpolation(v,sample_points) # B x N x LK x d
        assert keys.shape == (B,N,self.layer_num*self.keys_num,d)
        keys = keys.flatten(1,-2)
        mask = torch.ones((N,N*self.layer_num*self.keys_num),dtype=torch.bool)
        index = torch.arange(0,N*self.layer_num*self.keys_num).reshape(N,self.layer_num*self.keys_num)
        mask[torch.arange(N)[:,None],index] = False
        x,_ = self.mlh(x,keys,keys.clone(),attn_mask=mask.to(device))
        return x
    
    def interpolation(self,v:list, sample_points:torch.Tensor):
        # sample point has shape B x N x L x K x 2
        # v has shape L x [B x Hl x Wl x d]
        B, N, _, _,_= sample_points.shape
        up_points = torch.ceil(sample_points).type(torch.int)
        down_points = torch.floor(sample_points).type(torch.int)
        keys = torch.zeros((B,N,self.layer_num,self.keys_num,self.hidden_channel)).to(self.device) # B x N x L x K
        batches = torch.arange(0,B)[:,None]

        for l in range(self.layer_num):
            row_diff_up = 1 - torch.abs(up_points[:,:,l,:,0]-sample_points[:,:,l,:,0])[:,:,:,None] # B x N x K x 1
            row_diff_down = 1 - torch.abs(down_points[:,:,l,:,0]-sample_points[:,:,l,:,0])[:,:,:,None]
            
            col_diff_up = 1 - torch.abs(up_points[:,:,l,:,1]-sample_points[:,:,l,:,1])[:,:,:,None]
            col_diff_down = 1 - torch.abs(down_points[:,:,l,:,1]-sample_points[:,:,l,:,1])[:,:,:,None]
                       
            value1 = v[l][batches,down_points[:,:,l,:,0].reshape(B,-1),down_points[:,:,l,:,1].reshape(B,-1)].reshape(B,N,self.keys_num,self.hidden_channel) # B x N x K x d
            value2 = v[l][batches,down_points[:,:,l,:,0].reshape(B,-1),up_points[:,:,l,:,1].reshape(B,-1)].reshape(B,N,self.keys_num,self.hidden_channel)
            value3 = v[l][batches,up_points[:,:,l,:,0].reshape(B,-1),down_points[:,:,l,:,1].reshape(B,-1)].reshape(B,N,self.keys_num,self.hidden_channel)
            value4 = v[l][batches,up_points[:,:,l,:,0].reshape(B,-1),up_points[:,:,l,:,1].reshape(B,-1)].reshape(B,N,self.keys_num,self.hidden_channel)
            
            keys[:,:,l,:,:] = row_diff_down * col_diff_down * value1 + row_diff_down * col_diff_up * value2 + row_diff_up * col_diff_down * value3 + row_diff_up * col_diff_up * value4
            
        return keys.flatten(2,-2)


In [26]:
class DecoderBlock(nn.Module):
    def __init__(self, device, N=100, d=256, num_head=8, layer_num=4, key_num=4):
        super().__init__()
        self.N = N
        self.d = d
        self.device = device
        self.self_mha = nn.MultiheadAttention(d,num_head,batch_first=True)
        self.cross_mha = Multi_Scale_Cross_Attention(device,d,num_head,layer_num,key_num)
        self.LN1 = nn.LayerNorm(d)
        self.LN2 = nn.LayerNorm(d)
        self.LN3 = nn.LayerNorm(d)
        self.mlp = nn.Sequential(nn.Linear(d,4*d),nn.GELU(),nn.Linear(4*d,d))
        
    def forward(self, x:torch.Tensor, v:list):
        # x has shape B x N x d
        y = x
        x,_ = self.self_mha(x,x,x.clone())
        x += y
        x = self.LN1(x)
        y = x
        x = self.cross_mha(x,v)
        x += y
        x = self.LN2(x)
        y = x
        x = self.mlp(x)
        x += y
        x = self.LN3(x)
        return x

class Decoder(nn.Module):
    def __init__(self, device, N=100, hidden_channel=256, num_head=8, layer_num=4, key_num=4, batch_size=32):
        super().__init__()
        
        self.queries = nn.Parameter(torch.rand(batch_size,N,hidden_channel))
        self.dec1 = DecoderBlock(device,N,hidden_channel,num_head,layer_num,key_num) 
        self.dec2 = DecoderBlock(device,N,hidden_channel,num_head,layer_num,key_num) 
        self.dec3 = DecoderBlock(device,N,hidden_channel,num_head,layer_num,key_num) 
        self.dec4 = DecoderBlock(device,N,hidden_channel,num_head,layer_num,key_num) 
        self.LN = nn.LayerNorm(hidden_channel)
         
    def forward(self,v:list):
        # v has shape L x [B x dHl x Wl x d]
        output = torch.zeros_like(self.queries)
        output += self.queries
        output= self.LN(output)
        output = self.dec1(output,v)
        output = self.dec2(output,v)
        output = self.dec3(output,v)
        output = self.dec4(output,v)
        return output

### Grid Feature Network

In [10]:
class GridBlock(nn.Module):
    def __init__(self, hidden_channel=256, num_head=8):
        super().__init__()
        self.mlh = nn.MultiheadAttention(hidden_channel,num_head,batch_first=True)
        self.LN1 = nn.LayerNorm(hidden_channel)
        self.LN2 = nn.LayerNorm(hidden_channel)
        self.mlp = nn.Sequential(nn.Linear(hidden_channel,4*hidden_channel),nn.GELU(),nn.Linear(4*hidden_channel,hidden_channel))
    def forward(self,x:torch.Tensor):
        # x has shape B x H/32 * W/32 x d
        y = x
        x,_ = self.mlh(x,x,x.clone())
        x += y
        x = self.LN1(x)
        y = x
        x = self.mlp(x)
        x += y
        x = self.LN2(x)
        return x

class GridNet(nn.Module):
    def __init__(self, hidden_channel=256, num_head=8, block_num=3):
        super().__init__()
        self.LN = nn.LayerNorm(hidden_channel)
        self.blocks = nn.Sequential(*[GridBlock(hidden_channel,num_head) for _ in range(block_num)])
    
    def forward(self, x:torch.Tensor):
        # x has shape B x H/32 * W/32 x d
        x = self.LN(x)
        x = self.blocks(x)
        return x

### Caption Generator

In [11]:
class PosionalEncoding(nn.Module):
    def __init__(self,d, max_length=1000):
        super().__init__()
        assert d%2 == 0
        self.pe = torch.zeros((max_length,d))
        freq = 10000 ** (-2*torch.arange(0,d/2)/d)[None,:]
        indexes =  torch.arange(max_length)[:,None]
        self.pe[:,0::2] = torch.sin(freq*indexes)
        self.pe[:,1::2] = torch.cos(freq*indexes)

    def forward(self,x:torch.Tensor):
        # x has shape B x N x d
        return x + self.pe[None,:x.shape[1],:]

class CpationCrossAttention(nn.Module):
    def __init__(self, d=256, num_head=8):
        super().__init__()
        self.mlh_g = nn.MultiheadAttention(d,num_head,batch_first=True)
        self.mlh_def = nn.MultiheadAttention(d,num_head,batch_first=True)
        self.ling = nn.Sequential(nn.Linear(2*d,d),nn.Sigmoid())
        self.lindef = nn.Sequential(nn.Linear(2*d,d),nn.Sigmoid())

    def forward(self, x:torch.Tensor, grid_features:torch.Tensor, def_features:torch.Tensor):
        # x : B x L x d
        # grid_features : B x M x d
        # def_features : B x N x d
        x_g,_ = self.mlh_g(x,grid_features,grid_features.clone()) # x : B x L x d
        x_def,_ = self.mlh_g(x,def_features,def_features.clone()) # x : B x L x d
        c_g = self.ling(torch.cat((x,x_g),dim=-1))
        c_def = self.lindef(torch.cat((x,x_def),dim=-1)) # x : B x L x d
        return x_g * c_g + x_def * c_def + x

class Captionlayer(nn.Module):
    def __init__(self,device, hidden_channel=256, num_head=8):
        super().__init__()
        self.device = device
        self.self_mlh = nn.MultiheadAttention(hidden_channel,num_head,batch_first=True)
        self.cross_mlh = CpationCrossAttention(hidden_channel,num_head)
        self.LN1 = nn.LayerNorm(hidden_channel)
        self.LN2 = nn.LayerNorm(hidden_channel)
        self.LN3 = nn.LayerNorm(hidden_channel)
        self.mlp = nn.Sequential(nn.Linear(hidden_channel,4*hidden_channel),nn.GELU(),nn.Linear(4*hidden_channel,hidden_channel))

    def forward(self,x:torch.Tensor, grid_features:torch.Tensor, def_features:torch.Tensor):
        # x : B x L x d
        L = x.shape[1]
        y = x
        mask_atn = (~torch.tril(torch.ones((L,L),dtype=torch.bool))).to(self.device)
        x,_ = self.self_mlh(x,x,x.clone(),attn_mask=mask_atn)
        x += y
        x = self.LN1(x)
        y = x
        x = self.cross_mlh(x,grid_features,def_features)
        x += y
        x = self.LN2(x)
        y = x
        x = self.mlp(x)
        x += y
        x = self.LN3(x)
        return x

class CaptionGen(nn.Module):
    def __init__(self, device, vocab:Vocabulary, d=256, num_head=8):
        super().__init__()
        self.vocab = vocab
        self.pe = PosionalEncoding(d)
        self.embedding = nn.Embedding(len(vocab),d,vocab.stoi['<pad>'],device=device)
        self.layer1 = Captionlayer(device,d,num_head)
        self.layer2 = Captionlayer(device,d,num_head)
        self.layer3 = Captionlayer(device,d,num_head)
        self.fc = nn.Linear(d,len(vocab))
    
    def forward(self,trg:torch.Tensor, grid_features:torch.Tensor, def_features:torch.Tensor):

        x = self.embedding(trg) # B x L x d
        x = self.pe(x)
        x = self.layer1(x,grid_features,def_features)
        x = self.layer2(x,grid_features,def_features)
        x = self.layer3(x,grid_features,def_features)
        x = self.fc(x)
        return x # B x L x V
        

### Object Detector

In [27]:
class ObjDetector(nn.Module):
    def __init__(self, input_shape, device, batch_size, hidden_channel=256, class_num=10):
        super().__init__(self)
        self.backbone = featureExtractor(input_shape)
        self.decoder = Decoder(device,batch_size=batch_size)
        self.lin1 = nn.Linear(96,256)
        self.lin2 = nn.Linear(192,256)
        self.lin3 = nn.Linear(384,256)
        self.lin4 = nn.Linear(768,256)

        self.lin_class = nn.Sequential(nn.Linear(hidden_channel,class_num+1),nn.Softmax(class_num+1))
        self.lin_box = nn.Sequential(nn.Linear(hidden_channel,4*hidden_channel),nn.GELU(),nn.Linear(4*hidden_channel,hidden_channel),nn.GELU(),nn.Linear(hidden_channel,4),nn.Sigmoid())
    
    def forward(self, image:torch.tensor):
         # image : B x 3 x H x W
        # caption : B x L
        v1,v2,v3,v4 = self.backbone(image)
        
        v1 = self.lin1(v1.permute(0,2,3,1))
        v2 = self.lin1(v2.permute(0,2,3,1))
        v3 = self.lin1(v3.permute(0,2,3,1))
        v4 = self.lin1(v4.permute(0,2,3,1))

        image = self.decoder([v1,v2,v3,v4])
        classes = self.lin_class(image)
        boxes = self.lin_box(image)
        return classes, boxes

class myLoss(nn.Module):
    def __init__(self, num_class):
        super().__init__()
        self.background = num_class+1
        
    def forward(self, outputs, targets):
        classes, boxes = outputs
        classes_tgt, boxes_tgt = targets
        B, N, _ = classes.shape
        for b in range(B):
            mathcing_cost = torch.zeros((N,N))
                        
        

In [3]:
cost_matrix = torch.rand(3,100,100)
linear_sum_assignment(cost_matrix)

ValueError: expected a matrix (2-D array), got a 3 array

### Image Captioning

In [12]:
class ImageCaptioning(nn.Module):
    def __init__(self, input_shape:tuple, device, batch_size):
        super().__init__()
        self.tr_backbone = featureExtractor(input_shape)
        self.grid = GridNet()
        self.objDetector = Decoder(device,batch_size=batch_size)
        self.captionGenerator = CaptionGen(device,trainset.vocab)
        self.lin1 = nn.Linear(96,256)
        self.lin2 = nn.Linear(192,256)
        self.lin3 = nn.Linear(384,256)
        self.lin4 = nn.Linear(768,256)
        
    
    def forward(self,image:torch.Tensor, caption:torch.Tensor):
        # image : B x 3 x H x W
        # caption : B x L
        v1,v2,v3,v4 = self.tr_backbone(image)
        
        v1 = self.lin1(v1.permute(0,2,3,1))
        v2 = self.lin1(v2.permute(0,2,3,1))
        v3 = self.lin1(v3.permute(0,2,3,1))
        v4 = self.lin1(v4.permute(0,2,3,1))

        grid_features = self.grid(v4)
        def_features = self.objDetector([v1,v2,v3,v4])
        return self.captionGenerator(caption,grid_features,def_features)

In [13]:
class Inference(nn.Module):
    def __init__(self, vocab:Vocabulary, model_path, device, input_shape, batch_size, beam_size=2, max_length=100, alpha=0.75):
        super().__init__()
        self.device = device
        self.batch_size = batch_size
        self.model = ImageCaptioning(input_shape,device,batch_size).to(device)
        self.model.load_state_dict(torch.load(model_path))
        self.model.eval()
        self.vocab = vocab
        self.beam_size = beam_size
        self.max_length = max_length
        self.alpha = alpha
        
    def forward(self, image:torch.Tensor):
        # image : B x 3 x H x W
        B = image.shape[0]
        v1,v2,v3,v4 = self.model.tr_backbone(image)
        
        v1 = self.model.lin1(v1.permute(0,2,3,1))
        v2 = self.model.lin1(v2.permute(0,2,3,1))
        v3 = self.model.lin1(v3.permute(0,2,3,1))
        v4 = self.model.lin1(v4.permute(0,2,3,1))

        grid_features = self.model.grid(v4)
        def_features = self.model.objDetector([v1,v2,v3,v4])
        completed_captions = self.beam_search(grid_features,def_features)
        return completed_captions

    
    def beam_search(self, grid_features, def_features):
        captions = torch.zeros((self.batch_size,self.beam_size,self.max_length,2))
        batches = torch.arange(self.batch_size)[:,None]
        captions[:,:,0,0] = self.vocab.stoi['<sos>']
        captions[:,:,0,1] = 1
        vocab_len = len(self.vocab)
        beams = [i for i in range(self.beam_size)]
        eos = self.vocab.stoi['<eos>']

        completed_captions = self.batch_size*[[]]
        completed_scores= self.batch_size*[[]]

        for t in range(1,self.max_length):
            #for i in range(self.beam_size):
            #output = self.model.captionGenerator(captions.flatten(0,1)[beams][:,:t],grid_features,def_features,beams)
            outputs = []
            out_indexes = []
            k_sampling = 2 * self.beam_size
            for i in beams:
                output, out_index = torch.topk(self.model.captionGenerator(captions[:,i,:t,0],grid_features,def_features)[:,-1,:],k_sampling) # B x 2*beam_size
                output = softmax(output,dim=-1) * captions[:,i,t-1,[1]] # B x 2*beam_size, unsorted
                outputs.append(output) 
                out_indexes.append(out_index)
            
            scores , indexes = torch.topk(torch.cat(outputs,dim=-1),self.beam_size) # B x beam_size
            out_indexes = torch.cat(out_indexes,dim=-1) # B x 2*beam_size*beam_size
            captions= captions[batches,indexes // k_sampling] # update history
            new_token = out_indexes[batches,indexes] # B x beam_size, Vocabulary indexes
            captions[:,:,t,0] = new_token
            captions[:,:,t,1] = scores   

            rows, cols = torch.where(new_token==eos)
            for j in len(rows):
                batch = rows[j]
                beam = cols[j]
                completed_captions[batch].append(captions[batch,beam,:t+1,0])
                completed_scores[batch].append(captions[batch,beam,t,1])
            
        # Normalizing scores (different caption length)
        for batch in range(self.batch_size):
            for counter in len(completed_captions[batch]):
                den = len(completed_captions[batch][counter]) ** self.alpha
                score = completed_scores[batch][counter] 
                completed_scores[batch][counter] = torch.log(score) / den
        
        for batch in range(self.batch_size):
            scores = torch.cat(completed_scores[batch])
            index = torch.argmax(scores)
            completed_captions[batch] = completed_captions[batch][torch.argmax(scores).item()]
        
        return completed_captions
        

# Training

In [41]:
"""
learning_rate = 0.0001
epoch_num = 20
load = False
model = ImageCaptioning(input_shape,device,batch_size).to(device)
if load:
    model.load_state_dict(torch.load('./ImageCaptioning.pth'))

loss_weight = torch.tensor(list(trainset.vocab.loss_weight.values()))
loss_criterion = nn.CrossEntropyLoss(loss_weight).to(device)

optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)

batch_num = len(trainLoader)
losses = []
for epoch in range(epoch_num):  # loop over the dataset multiple times
    model.train()
    running_loss = 0.0
    tqdm_bar = tqdm(trainLoader, desc=f'Training Epoch {epoch} ', total=int(len(trainLoader)))

    for i, (image,caption) in enumerate(tqdm_bar):
        image = image.to(device)
        caption = caption.to(device)
        
        output = model(image,caption)

        optimizer.zero_grad()
        loss = loss_criterion(output[:,0:-1,:].flatten(0,1),caption[:,1:].flatten())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    running_loss /= batch_num
    losses.append(running_loss)
    print('epoch : ',epoch,', loss : ',running_loss)
    torch.save(model.state_dict(), './ImageCaptioning.pth')
    
print('Finished Training')
"""

"\nlearning_rate = 0.0001\nepoch_num = 20\nload = False\nmodel = ImageCaptioning(input_shape,device,batch_size).to(device)\nif load:\n    model.load_state_dict(torch.load('./ImageCaptioning.pth'))\n\nloss_weight = torch.tensor(list(trainset.vocab.loss_weight.values()))\nloss_criterion = nn.CrossEntropyLoss(loss_weight).to(device)\n\noptimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)\n\nbatch_num = len(trainLoader)\nlosses = []\nfor epoch in range(epoch_num):  # loop over the dataset multiple times\n    model.train()\n    running_loss = 0.0\n    tqdm_bar = tqdm(trainLoader, desc=f'Training Epoch {epoch} ', total=int(len(trainLoader)))\n\n    for i, (image,caption) in enumerate(tqdm_bar):\n        image = image.to(device)\n        caption = caption.to(device)\n        \n        output = model(image,caption)\n\n        optimizer.zero_grad()\n        loss = loss_criterion(output[:,0:-1,:].flatten(0,1),caption[:,1:].flatten())\n        loss.backward()\n        optimizer.step(