**LIBRARIES**

In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

import torchvision
from torchvision import transforms
import cv2
import math

from collections import Counter
from PIL import Image
import PIL

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as fun

import tensorflow
from tensorflow.keras import layers, models

from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from tensorflow.keras import Model
from tensorflow.keras.layers import Add, GlobalAveragePooling2D,\
	Dense, Flatten, Conv2D, Lambda,	Input, BatchNormalization, Activation
from tensorflow.keras.optimizers import schedules, SGD
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint

# import nltk
# import ssl

# try:
#     _create_unverified_https_context = ssl._create_unverified_context
# except AttributeError:
#     pass
# else:
#     ssl._create_default_https_context = _create_unverified_https_context

# nltk.download()
from nltk.tokenize import word_tokenize

import match
import pickle
import gc
import random

There was a problem when trying to write in your cache folder (C:\Users\asus\.cache\huggingface\hub). You should set the environment variable TRANSFORMERS_CACHE to a writable directory.


**DATA Preprocessing** 

In [2]:
data = pd.read_csv("captions.txt", sep=',')
display(data)

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...
...,...,...
40450,997722733_0cb5439472.jpg,A man in a pink shirt climbs a rock face
40451,997722733_0cb5439472.jpg,A man is rock climbing high in the air .
40452,997722733_0cb5439472.jpg,A person in a red shirt climbing up a rock fac...
40453,997722733_0cb5439472.jpg,A rock climber in a red shirt .


**Tokenization**

In [3]:
#Removes Single Char
def remove_single_char(caption_list):
    list = []
    for word in caption_list:
        if len(word)>1:
            list.append(word)
    return list

In [4]:
#Make an array of words out of caption and then remove useless single char words

data['caption'] = data['caption'].apply(lambda caption :word_tokenize(caption))

data['caption'] = data['caption'].apply(lambda word : remove_single_char(word))

#We need to make sure size of all the captions arrays is same so we add <cell> to cover up
lengths = []
lengths = data['caption'].apply(lambda caption : len(caption))

max_length = lengths.max()

data['caption'] = data['caption'].apply(lambda caption : ['<start>'] + caption + ['<cell>']*(max_length-len(caption)) + ['<end>'])

#For non truncated dataframe to appear
pd.set_option('display.max_colwidth', None)
display(data)

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,"[<start>, child, in, pink, dress, is, climbing, up, set, of, stairs, in, an, entry, way, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]"
1,1000268201_693b08cb0e.jpg,"[<start>, girl, going, into, wooden, building, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]"
2,1000268201_693b08cb0e.jpg,"[<start>, little, girl, climbing, into, wooden, playhouse, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]"
3,1000268201_693b08cb0e.jpg,"[<start>, little, girl, climbing, the, stairs, to, her, playhouse, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]"
4,1000268201_693b08cb0e.jpg,"[<start>, little, girl, in, pink, dress, going, into, wooden, cabin, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]"
...,...,...
40450,997722733_0cb5439472.jpg,"[<start>, man, in, pink, shirt, climbs, rock, face, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]"
40451,997722733_0cb5439472.jpg,"[<start>, man, is, rock, climbing, high, in, the, air, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]"
40452,997722733_0cb5439472.jpg,"[<start>, person, in, red, shirt, climbing, up, rock, face, covered, in, assist, handles, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]"
40453,997722733_0cb5439472.jpg,"[<start>, rock, climber, in, red, shirt, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]"


**Vocab and Dictionary**

In [5]:
#Extracting words 
words = data['caption'].apply(lambda word : " ".join(word)).str.cat(sep = ' ').split(' ')
display(words)

['<start>',
 'child',
 'in',
 'pink',
 'dress',
 'is',
 'climbing',
 'up',
 'set',
 'of',
 'stairs',
 'in',
 'an',
 'entry',
 'way',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<end>',
 '<start>',
 'girl',
 'going',
 'into',
 'wooden',
 'building',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<end>',
 '<start>',
 'little',
 'girl',
 'climbing',
 'into',
 'wooden',
 'playhouse',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',
 '<cell>',


In [6]:
#Arranging the words in order of their frequency
word_dict = sorted(Counter(words), key=Counter(words).get, reverse=True)

dict_size = len(word_dict)
vocab_threshold = 5

print(len(word_dict))
print(vocab_threshold)
display(word_dict)

9596
5


['<cell>',
 '<start>',
 '<end>',
 'in',
 'the',
 'on',
 'is',
 'and',
 'dog',
 'with',
 'man',
 'of',
 'Two',
 'white',
 'black',
 'are',
 'boy',
 'woman',
 'girl',
 'to',
 'The',
 'wearing',
 'at',
 'water',
 'red',
 'brown',
 'people',
 'young',
 'his',
 'blue',
 'dogs',
 'running',
 'through',
 'playing',
 'while',
 'an',
 'down',
 'shirt',
 'standing',
 'ball',
 'little',
 'grass',
 'snow',
 'child',
 'jumping',
 'over',
 'person',
 'front',
 'sitting',
 'holding',
 'field',
 'two',
 'up',
 'by',
 'green',
 'small',
 'yellow',
 'large',
 'her',
 'group',
 'walking',
 'Three',
 'into',
 'air',
 'beach',
 'men',
 'near',
 'one',
 'children',
 'mouth',
 'jumps',
 'another',
 'for',
 'street',
 'runs',
 'its',
 'from',
 'riding',
 'stands',
 'as',
 'bike',
 'girls',
 'outside',
 'other',
 'out',
 'rock',
 'next',
 'play',
 'off',
 'looking',
 'pink',
 'orange',
 'player',
 'their',
 'pool',
 'camera',
 'hat',
 'jacket',
 'around',
 'boys',
 'behind',
 'women',
 'background',
 'toy',
 '

In [7]:
#Encoding the words with index in dictionary made above
data['sequence'] = data['caption'].apply(lambda caption : [word_dict.index(word) for word in caption])

display(data)

Unnamed: 0,image,caption,sequence
0,1000268201_693b08cb0e.jpg,"[<start>, child, in, pink, dress, is, climbing, up, set, of, stairs, in, an, entry, way, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 43, 3, 90, 174, 6, 120, 52, 409, 11, 405, 3, 35, 5475, 714, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
1,1000268201_693b08cb0e.jpg,"[<start>, girl, going, into, wooden, building, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 18, 320, 62, 197, 118, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
2,1000268201_693b08cb0e.jpg,"[<start>, little, girl, climbing, into, wooden, playhouse, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 40, 18, 120, 62, 197, 2490, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
3,1000268201_693b08cb0e.jpg,"[<start>, little, girl, climbing, the, stairs, to, her, playhouse, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 40, 18, 120, 4, 405, 19, 58, 2490, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
4,1000268201_693b08cb0e.jpg,"[<start>, little, girl, in, pink, dress, going, into, wooden, cabin, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 40, 18, 3, 90, 174, 320, 62, 197, 3091, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
...,...,...,...
40450,997722733_0cb5439472.jpg,"[<start>, man, in, pink, shirt, climbs, rock, face, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 10, 3, 90, 37, 257, 85, 123, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
40451,997722733_0cb5439472.jpg,"[<start>, man, is, rock, climbing, high, in, the, air, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 10, 6, 85, 120, 198, 3, 4, 63, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
40452,997722733_0cb5439472.jpg,"[<start>, person, in, red, shirt, climbing, up, rock, face, covered, in, assist, handles, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 46, 3, 24, 37, 120, 52, 85, 123, 187, 3, 3701, 1763, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
40453,997722733_0cb5439472.jpg,"[<start>, rock, climber, in, red, shirt, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 85, 374, 3, 24, 37, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"


In [8]:
data = data.sort_values(by = 'image')

**ResNet18 Pretrained Model**

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [10]:
# resnet50 = torchvision.models.resnet50(pretrained=True).to(device)
# resnet50.eval()
# list(resnet50._modules)
# # for parameters in resnet50.parameters():
# #     parameters.requires_grad_(False)
# resnet50Layer4 = resnet50._modules.get('layer4').to(device)

**Extracting Features from Images**

In [11]:
preprocess = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])])

def get_image_tensor(index, preprocess, data):
    image_name = data.iloc[index]['image']
    img_loc = 'Images/'+str(image_name)
    img = Image.open(img_loc).convert('RGB')
    tensor_image = preprocess(img).unsqueeze(0)

    return tensor_image

**Encoder CNN**

In [12]:
class Encoder_CNN(nn.Module):
    def __init__(self, embed_size):
        super(Encoder_CNN, self).__init__()
        resnet50 = torchvision.models.resnet50(pretrained=True)
        resnet50.eval()
        # for parameters in resnet50.parameters():
        #     parameters.requires_grad_(False)
        modules = list(resnet50.children())[:-1]
        self.model = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet50.fc.in_features, embed_size)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, image):
        features = self.model(image)
        features = features.view(features.size(0), -1)
        features = self.dropout(self.relu(self.embed(features)))
        return features

In [None]:
encoder = Encoder_CNN(256)
encoder.eval()

embedded_vectors = []

for i, row in tqdm(data.iterrows(), total=len(data), desc="Processing Images"):
    image_tensor = get_image_tensor(i, preprocess, data)
    image_tensor = image_tensor.to(device)

    print(f"Initial tensor shape: {image_tensor.shape}")

    with torch.no_grad():
        for name, module in encoder.model.named_children():
            image_tensor = module(image_tensor)
            print(f"After layer {name}: {image_tensor.shape}")

    
        output = image_tensor.view(image_tensor.size(0), -1)
        print(f"Output tensor shape: {output.shape}")
        
        embed_vector = encoder.embed(output)
        embedded_vectors.append(embed_vector.cpu().numpy()) 

data['embedded'] = embedded_vectors
print(data)

embedded_data = open("Embedded_Images_data.pkl", "wb")
pickle.dump(data, embedded_data)
embedded_data.close()



Processing Images:   0%|          | 0/40455 [00:00<?, ?it/s]

Initial tensor shape: torch.Size([1, 3, 224, 224])
After layer 0: torch.Size([1, 64, 112, 112])
After layer 1: torch.Size([1, 64, 112, 112])
After layer 2: torch.Size([1, 64, 112, 112])
After layer 3: torch.Size([1, 64, 56, 56])
After layer 4: torch.Size([1, 256, 56, 56])
After layer 5: torch.Size([1, 512, 28, 28])
After layer 6: torch.Size([1, 1024, 14, 14])
After layer 7: torch.Size([1, 2048, 7, 7])
After layer 8: torch.Size([1, 2048, 1, 1])
Output tensor shape: torch.Size([1, 2048])
Initial tensor shape: torch.Size([1, 3, 224, 224])
After layer 0: torch.Size([1, 64, 112, 112])
After layer 1: torch.Size([1, 64, 112, 112])
After layer 2: torch.Size([1, 64, 112, 112])
After layer 3: torch.Size([1, 64, 56, 56])
After layer 4: torch.Size([1, 256, 56, 56])
After layer 5: torch.Size([1, 512, 28, 28])
After layer 6: torch.Size([1, 1024, 14, 14])
After layer 7: torch.Size([1, 2048, 7, 7])
After layer 8: torch.Size([1, 2048, 1, 1])
Output tensor shape: torch.Size([1, 2048])
Initial tensor sha

**DataLoader**

In [13]:
#Training and Validation Data
#Same images but different captions

display(data)
embed_data = pd.read_pickle('Embedded_Images_data.pkl')
train, validation = train_test_split(embed_data,test_size=0.1,train_size=0.9)

display(embed_data)

print(len(train), train['image'].nunique())

words = train['caption'].apply(lambda word : " ".join(word)).str.cat(sep = ' ').split(' ')
word_dict_train = sorted(Counter(words), key=Counter(words).get, reverse=True)
vocab_size_train = len(word_dict_train)

print(len(validation), validation['image'].nunique())

train_file_features = open("Image_Features_Embed_ResNet_Train.pkl", "wb")
pickle.dump(train, train_file_features)
train_file_features.close()

valid_file_features = open("Image_Features_Embed_ResNet_Valid.pkl", "wb")
pickle.dump(validation, valid_file_features)
valid_file_features.close()


Unnamed: 0,image,caption,sequence
0,1000268201_693b08cb0e.jpg,"[<start>, child, in, pink, dress, is, climbing, up, set, of, stairs, in, an, entry, way, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 43, 3, 90, 174, 6, 120, 52, 409, 11, 405, 3, 35, 5475, 714, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
1,1000268201_693b08cb0e.jpg,"[<start>, girl, going, into, wooden, building, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 18, 320, 62, 197, 118, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
2,1000268201_693b08cb0e.jpg,"[<start>, little, girl, climbing, into, wooden, playhouse, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 40, 18, 120, 62, 197, 2490, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
3,1000268201_693b08cb0e.jpg,"[<start>, little, girl, climbing, the, stairs, to, her, playhouse, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 40, 18, 120, 4, 405, 19, 58, 2490, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
4,1000268201_693b08cb0e.jpg,"[<start>, little, girl, in, pink, dress, going, into, wooden, cabin, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 40, 18, 3, 90, 174, 320, 62, 197, 3091, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
...,...,...,...
40452,997722733_0cb5439472.jpg,"[<start>, person, in, red, shirt, climbing, up, rock, face, covered, in, assist, handles, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 46, 3, 24, 37, 120, 52, 85, 123, 187, 3, 3701, 1763, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
40453,997722733_0cb5439472.jpg,"[<start>, rock, climber, in, red, shirt, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 85, 374, 3, 24, 37, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
40450,997722733_0cb5439472.jpg,"[<start>, man, in, pink, shirt, climbs, rock, face, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 10, 3, 90, 37, 257, 85, 123, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"
40451,997722733_0cb5439472.jpg,"[<start>, man, is, rock, climbing, high, in, the, air, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 10, 6, 85, 120, 198, 3, 4, 63, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]"


Unnamed: 0,image,caption,sequence,embedded
0,1000268201_693b08cb0e.jpg,"[<start>, child, in, pink, dress, is, climbing, up, set, of, stairs, in, an, entry, way, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 43, 3, 90, 174, 6, 120, 52, 409, 11, 405, 3, 35, 5475, 714, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]","[[-0.25657812, 0.1596796, -0.18215069, -0.12559445, -0.48660097, -0.24840401, -0.33725232, -0.29922915, -0.3841372, 0.4398388, 0.2446603, -0.12654766, 0.07274709, -0.5547995, -0.4264813, -0.175623, 0.7033246, -0.30885378, -0.096250094, -0.09895194, 0.12551063, -0.5078764, 0.47357342, -0.5843763, 0.1548141, -0.01372434, 0.7916916, -0.16412023, -0.3251785, -0.15404087, 0.45641607, 0.19545189, 0.1538179, -1.1228678, -0.11292797, 0.5357347, -0.71602875, -0.24327391, -0.10776442, 0.4421068, -0.124913335, -0.4749884, -0.20841569, -0.6593651, -0.6421927, -0.7728128, 0.042490676, -0.2585397, -0.15238129, 0.08205474, 0.23887092, 1.444557, 0.23176722, -0.36183125, 0.27660522, 0.18003926, 0.46660706, 0.4998935, 0.14573471, 0.043374393, 0.016538221, 0.18691948, -0.13965648, -0.21066114, 0.31631973, 0.4713003, -0.24490699, 0.41625744, -0.66051066, -0.14184164, 0.88977003, 0.24161881, 0.19531491, -0.7841556, -0.101970814, 0.028744856, 0.18281052, -0.20052283, 0.097187586, 0.22262156, 0.07440245, 0.16425854, 0.3325673, 0.3630356, -0.20271915, -0.34309855, -0.35268533, -0.20669198, -0.45370284, -0.28678265, -0.35213754, -0.31758907, -0.116825625, -0.43406555, 0.988559, -0.37352133, -0.5324662, 0.3071673, -0.3646447, -0.29254013, ...]]"
1,1000268201_693b08cb0e.jpg,"[<start>, girl, going, into, wooden, building, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 18, 320, 62, 197, 118, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]","[[-0.25657812, 0.1596796, -0.18215069, -0.12559445, -0.48660097, -0.24840401, -0.33725232, -0.29922915, -0.3841372, 0.4398388, 0.2446603, -0.12654766, 0.07274709, -0.5547995, -0.4264813, -0.175623, 0.7033246, -0.30885378, -0.096250094, -0.09895194, 0.12551063, -0.5078764, 0.47357342, -0.5843763, 0.1548141, -0.01372434, 0.7916916, -0.16412023, -0.3251785, -0.15404087, 0.45641607, 0.19545189, 0.1538179, -1.1228678, -0.11292797, 0.5357347, -0.71602875, -0.24327391, -0.10776442, 0.4421068, -0.124913335, -0.4749884, -0.20841569, -0.6593651, -0.6421927, -0.7728128, 0.042490676, -0.2585397, -0.15238129, 0.08205474, 0.23887092, 1.444557, 0.23176722, -0.36183125, 0.27660522, 0.18003926, 0.46660706, 0.4998935, 0.14573471, 0.043374393, 0.016538221, 0.18691948, -0.13965648, -0.21066114, 0.31631973, 0.4713003, -0.24490699, 0.41625744, -0.66051066, -0.14184164, 0.88977003, 0.24161881, 0.19531491, -0.7841556, -0.101970814, 0.028744856, 0.18281052, -0.20052283, 0.097187586, 0.22262156, 0.07440245, 0.16425854, 0.3325673, 0.3630356, -0.20271915, -0.34309855, -0.35268533, -0.20669198, -0.45370284, -0.28678265, -0.35213754, -0.31758907, -0.116825625, -0.43406555, 0.988559, -0.37352133, -0.5324662, 0.3071673, -0.3646447, -0.29254013, ...]]"
2,1000268201_693b08cb0e.jpg,"[<start>, little, girl, climbing, into, wooden, playhouse, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 40, 18, 120, 62, 197, 2490, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]","[[-0.25657812, 0.1596796, -0.18215069, -0.12559445, -0.48660097, -0.24840401, -0.33725232, -0.29922915, -0.3841372, 0.4398388, 0.2446603, -0.12654766, 0.07274709, -0.5547995, -0.4264813, -0.175623, 0.7033246, -0.30885378, -0.096250094, -0.09895194, 0.12551063, -0.5078764, 0.47357342, -0.5843763, 0.1548141, -0.01372434, 0.7916916, -0.16412023, -0.3251785, -0.15404087, 0.45641607, 0.19545189, 0.1538179, -1.1228678, -0.11292797, 0.5357347, -0.71602875, -0.24327391, -0.10776442, 0.4421068, -0.124913335, -0.4749884, -0.20841569, -0.6593651, -0.6421927, -0.7728128, 0.042490676, -0.2585397, -0.15238129, 0.08205474, 0.23887092, 1.444557, 0.23176722, -0.36183125, 0.27660522, 0.18003926, 0.46660706, 0.4998935, 0.14573471, 0.043374393, 0.016538221, 0.18691948, -0.13965648, -0.21066114, 0.31631973, 0.4713003, -0.24490699, 0.41625744, -0.66051066, -0.14184164, 0.88977003, 0.24161881, 0.19531491, -0.7841556, -0.101970814, 0.028744856, 0.18281052, -0.20052283, 0.097187586, 0.22262156, 0.07440245, 0.16425854, 0.3325673, 0.3630356, -0.20271915, -0.34309855, -0.35268533, -0.20669198, -0.45370284, -0.28678265, -0.35213754, -0.31758907, -0.116825625, -0.43406555, 0.988559, -0.37352133, -0.5324662, 0.3071673, -0.3646447, -0.29254013, ...]]"
3,1000268201_693b08cb0e.jpg,"[<start>, little, girl, climbing, the, stairs, to, her, playhouse, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 40, 18, 120, 4, 405, 19, 58, 2490, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]","[[-0.25657812, 0.1596796, -0.18215069, -0.12559445, -0.48660097, -0.24840401, -0.33725232, -0.29922915, -0.3841372, 0.4398388, 0.2446603, -0.12654766, 0.07274709, -0.5547995, -0.4264813, -0.175623, 0.7033246, -0.30885378, -0.096250094, -0.09895194, 0.12551063, -0.5078764, 0.47357342, -0.5843763, 0.1548141, -0.01372434, 0.7916916, -0.16412023, -0.3251785, -0.15404087, 0.45641607, 0.19545189, 0.1538179, -1.1228678, -0.11292797, 0.5357347, -0.71602875, -0.24327391, -0.10776442, 0.4421068, -0.124913335, -0.4749884, -0.20841569, -0.6593651, -0.6421927, -0.7728128, 0.042490676, -0.2585397, -0.15238129, 0.08205474, 0.23887092, 1.444557, 0.23176722, -0.36183125, 0.27660522, 0.18003926, 0.46660706, 0.4998935, 0.14573471, 0.043374393, 0.016538221, 0.18691948, -0.13965648, -0.21066114, 0.31631973, 0.4713003, -0.24490699, 0.41625744, -0.66051066, -0.14184164, 0.88977003, 0.24161881, 0.19531491, -0.7841556, -0.101970814, 0.028744856, 0.18281052, -0.20052283, 0.097187586, 0.22262156, 0.07440245, 0.16425854, 0.3325673, 0.3630356, -0.20271915, -0.34309855, -0.35268533, -0.20669198, -0.45370284, -0.28678265, -0.35213754, -0.31758907, -0.116825625, -0.43406555, 0.988559, -0.37352133, -0.5324662, 0.3071673, -0.3646447, -0.29254013, ...]]"
4,1000268201_693b08cb0e.jpg,"[<start>, little, girl, in, pink, dress, going, into, wooden, cabin, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 40, 18, 3, 90, 174, 320, 62, 197, 3091, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]","[[-0.25657812, 0.1596796, -0.18215069, -0.12559445, -0.48660097, -0.24840401, -0.33725232, -0.29922915, -0.3841372, 0.4398388, 0.2446603, -0.12654766, 0.07274709, -0.5547995, -0.4264813, -0.175623, 0.7033246, -0.30885378, -0.096250094, -0.09895194, 0.12551063, -0.5078764, 0.47357342, -0.5843763, 0.1548141, -0.01372434, 0.7916916, -0.16412023, -0.3251785, -0.15404087, 0.45641607, 0.19545189, 0.1538179, -1.1228678, -0.11292797, 0.5357347, -0.71602875, -0.24327391, -0.10776442, 0.4421068, -0.124913335, -0.4749884, -0.20841569, -0.6593651, -0.6421927, -0.7728128, 0.042490676, -0.2585397, -0.15238129, 0.08205474, 0.23887092, 1.444557, 0.23176722, -0.36183125, 0.27660522, 0.18003926, 0.46660706, 0.4998935, 0.14573471, 0.043374393, 0.016538221, 0.18691948, -0.13965648, -0.21066114, 0.31631973, 0.4713003, -0.24490699, 0.41625744, -0.66051066, -0.14184164, 0.88977003, 0.24161881, 0.19531491, -0.7841556, -0.101970814, 0.028744856, 0.18281052, -0.20052283, 0.097187586, 0.22262156, 0.07440245, 0.16425854, 0.3325673, 0.3630356, -0.20271915, -0.34309855, -0.35268533, -0.20669198, -0.45370284, -0.28678265, -0.35213754, -0.31758907, -0.116825625, -0.43406555, 0.988559, -0.37352133, -0.5324662, 0.3071673, -0.3646447, -0.29254013, ...]]"
...,...,...,...,...
40452,997722733_0cb5439472.jpg,"[<start>, person, in, red, shirt, climbing, up, rock, face, covered, in, assist, handles, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 46, 3, 24, 37, 120, 52, 85, 123, 187, 3, 3701, 1763, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]","[[-0.1538693, 0.01727238, -0.42711896, -0.069321975, -0.19021437, 0.1282044, -0.57983214, -0.38408458, -0.14259058, 1.0224355, 0.13585599, 0.30872723, 0.36482245, -0.65265876, -0.037631817, -0.83522165, 0.8193441, 0.2013842, -0.5694834, 0.56673944, 0.105403796, -0.03984061, 0.17162825, -0.4818016, -0.15630296, 0.48582703, 0.17576697, -0.71782917, 0.22920676, -0.26425207, 0.10103524, 0.035923366, 0.17855792, -0.7130602, 0.37215614, -0.19474639, -0.7028228, -0.4371991, 0.17535582, 0.41654384, -0.446667, -0.5539143, 0.018177724, -1.1418701, -0.5091324, -0.42707402, 0.42207944, -0.5784515, -0.464106, 0.13840695, 0.20463654, 0.5197857, 0.8890513, 0.00057186186, 0.2323207, 0.23825228, 0.5469229, 0.8805841, 0.011265902, 0.33942652, -0.26033875, 0.25930104, 0.07704953, -0.012371924, 0.39401752, 0.15471226, -0.5789501, -0.33814797, -0.7126805, 0.20962623, 0.68570745, 0.39234167, 0.22027324, -0.30258492, -0.15331626, -0.0218336, 0.16207498, -0.043828662, -0.37590557, -0.1493859, -0.18898186, 0.22827166, -0.028128767, 0.15688524, 0.12482412, -0.24013954, -0.7548429, -0.18155475, 0.052313913, -0.36837956, -0.57133776, -0.6624626, -0.5554635, -0.35049942, 0.70338154, -0.64652455, -0.57492536, -0.05822436, -0.2977836, -0.1466739, ...]]"
40453,997722733_0cb5439472.jpg,"[<start>, rock, climber, in, red, shirt, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 85, 374, 3, 24, 37, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]","[[-0.1538693, 0.01727238, -0.42711896, -0.069321975, -0.19021437, 0.1282044, -0.57983214, -0.38408458, -0.14259058, 1.0224355, 0.13585599, 0.30872723, 0.36482245, -0.65265876, -0.037631817, -0.83522165, 0.8193441, 0.2013842, -0.5694834, 0.56673944, 0.105403796, -0.03984061, 0.17162825, -0.4818016, -0.15630296, 0.48582703, 0.17576697, -0.71782917, 0.22920676, -0.26425207, 0.10103524, 0.035923366, 0.17855792, -0.7130602, 0.37215614, -0.19474639, -0.7028228, -0.4371991, 0.17535582, 0.41654384, -0.446667, -0.5539143, 0.018177724, -1.1418701, -0.5091324, -0.42707402, 0.42207944, -0.5784515, -0.464106, 0.13840695, 0.20463654, 0.5197857, 0.8890513, 0.00057186186, 0.2323207, 0.23825228, 0.5469229, 0.8805841, 0.011265902, 0.33942652, -0.26033875, 0.25930104, 0.07704953, -0.012371924, 0.39401752, 0.15471226, -0.5789501, -0.33814797, -0.7126805, 0.20962623, 0.68570745, 0.39234167, 0.22027324, -0.30258492, -0.15331626, -0.0218336, 0.16207498, -0.043828662, -0.37590557, -0.1493859, -0.18898186, 0.22827166, -0.028128767, 0.15688524, 0.12482412, -0.24013954, -0.7548429, -0.18155475, 0.052313913, -0.36837956, -0.57133776, -0.6624626, -0.5554635, -0.35049942, 0.70338154, -0.64652455, -0.57492536, -0.05822436, -0.2977836, -0.1466739, ...]]"
40450,997722733_0cb5439472.jpg,"[<start>, man, in, pink, shirt, climbs, rock, face, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 10, 3, 90, 37, 257, 85, 123, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]","[[-0.1538693, 0.01727238, -0.42711896, -0.069321975, -0.19021437, 0.1282044, -0.57983214, -0.38408458, -0.14259058, 1.0224355, 0.13585599, 0.30872723, 0.36482245, -0.65265876, -0.037631817, -0.83522165, 0.8193441, 0.2013842, -0.5694834, 0.56673944, 0.105403796, -0.03984061, 0.17162825, -0.4818016, -0.15630296, 0.48582703, 0.17576697, -0.71782917, 0.22920676, -0.26425207, 0.10103524, 0.035923366, 0.17855792, -0.7130602, 0.37215614, -0.19474639, -0.7028228, -0.4371991, 0.17535582, 0.41654384, -0.446667, -0.5539143, 0.018177724, -1.1418701, -0.5091324, -0.42707402, 0.42207944, -0.5784515, -0.464106, 0.13840695, 0.20463654, 0.5197857, 0.8890513, 0.00057186186, 0.2323207, 0.23825228, 0.5469229, 0.8805841, 0.011265902, 0.33942652, -0.26033875, 0.25930104, 0.07704953, -0.012371924, 0.39401752, 0.15471226, -0.5789501, -0.33814797, -0.7126805, 0.20962623, 0.68570745, 0.39234167, 0.22027324, -0.30258492, -0.15331626, -0.0218336, 0.16207498, -0.043828662, -0.37590557, -0.1493859, -0.18898186, 0.22827166, -0.028128767, 0.15688524, 0.12482412, -0.24013954, -0.7548429, -0.18155475, 0.052313913, -0.36837956, -0.57133776, -0.6624626, -0.5554635, -0.35049942, 0.70338154, -0.64652455, -0.57492536, -0.05822436, -0.2977836, -0.1466739, ...]]"
40451,997722733_0cb5439472.jpg,"[<start>, man, is, rock, climbing, high, in, the, air, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <cell>, <end>]","[1, 10, 6, 85, 120, 198, 3, 4, 63, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]","[[-0.1538693, 0.01727238, -0.42711896, -0.069321975, -0.19021437, 0.1282044, -0.57983214, -0.38408458, -0.14259058, 1.0224355, 0.13585599, 0.30872723, 0.36482245, -0.65265876, -0.037631817, -0.83522165, 0.8193441, 0.2013842, -0.5694834, 0.56673944, 0.105403796, -0.03984061, 0.17162825, -0.4818016, -0.15630296, 0.48582703, 0.17576697, -0.71782917, 0.22920676, -0.26425207, 0.10103524, 0.035923366, 0.17855792, -0.7130602, 0.37215614, -0.19474639, -0.7028228, -0.4371991, 0.17535582, 0.41654384, -0.446667, -0.5539143, 0.018177724, -1.1418701, -0.5091324, -0.42707402, 0.42207944, -0.5784515, -0.464106, 0.13840695, 0.20463654, 0.5197857, 0.8890513, 0.00057186186, 0.2323207, 0.23825228, 0.5469229, 0.8805841, 0.011265902, 0.33942652, -0.26033875, 0.25930104, 0.07704953, -0.012371924, 0.39401752, 0.15471226, -0.5789501, -0.33814797, -0.7126805, 0.20962623, 0.68570745, 0.39234167, 0.22027324, -0.30258492, -0.15331626, -0.0218336, 0.16207498, -0.043828662, -0.37590557, -0.1493859, -0.18898186, 0.22827166, -0.028128767, 0.15688524, 0.12482412, -0.24013954, -0.7548429, -0.18155475, 0.052313913, -0.36837956, -0.57133776, -0.6624626, -0.5554635, -0.35049942, 0.70338154, -0.64652455, -0.57492536, -0.05822436, -0.2977836, -0.1466739, ...]]"


36409 8091
4046 3339


In [14]:
def loader_dataset(pickle_file):
    data = pd.read_pickle(pickle_file)
    dataset = pd.DataFrame(columns = ['sequence', 'target', 'image'])
    
    for i, row in data.iterrows():
        sequence = torch.tensor(row['sequence'])
        target = torch.tensor(row['sequence'][1:]+[0])
        image_name = row['image']
        
        image_tensor = torch.tensor(row['embedded'], dtype=torch.float32)
        image_tensor_view = image_tensor

        dataset = dataset._append({'sequence' : sequence, 'target' : target, 'image' : image_tensor_view},
        ignore_index = True)

    display(dataset)
    return dataset

class CustomDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        return self.dataframe.loc[idx, 'sequence'], self.dataframe.loc[idx, 'target'], self.dataframe.loc[idx, 'image']

In [15]:
train_dataset = loader_dataset('Image_Features_Embed_ResNet_Train.pkl')
valid_dataset = loader_dataset('Image_Features_Embed_ResNet_Valid.pkl')

train_dataloader = DataLoader(CustomDataset(train_dataset), batch_size=32, shuffle=True)
valid_dataloader = DataLoader(CustomDataset(valid_dataset), batch_size=32, shuffle=True)

Unnamed: 0,sequence,target,image
0,"[tensor(1), tensor(280), tensor(21), tensor(29), tensor(97), tensor(7), tensor(172), tensor(6), tensor(397), tensor(36), tensor(134), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(280), tensor(21), tensor(29), tensor(97), tensor(7), tensor(172), tensor(6), tensor(397), tensor(36), tensor(134), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.1406), tensor(0.6248), tensor(-0.1775), tensor(-0.0899), tensor(-0.2103), tensor(0.6778), tensor(-0.4795), tensor(-0.6255), tensor(-0.8369), tensor(0.9090), tensor(0.5809), tensor(0.5518), tensor(0.2556), tensor(-0.1485), tensor(-0.0175), tensor(-0.3420), tensor(1.0069), tensor(0.0923), tensor(-0.0783), tensor(0.5203), tensor(0.5814), tensor(-0.4557), tensor(-0.2757), tensor(-0.7138), tensor(-0.1619), tensor(0.7562), tensor(0.2028), tensor(-0.6230), tensor(-0.0706), tensor(-0.6337), tensor(0.3270), tensor(0.4622), tensor(0.7288), tensor(-0.1924), tensor(-0.0891), tensor(-0.0721), tensor(-0.4069), tensor(-0.7046), tensor(-0.1451), tensor(0.2944), tensor(0.1115), tensor(0.5206), tensor(-0.2681), tensor(-1.1392), tensor(-0.1759), tensor(-0.4458), tensor(-0.1670), tensor(-0.3144), tensor(0.0174), tensor(-0.0540), tensor(0.2224), tensor(0.8146), tensor(0.4876), tensor(0.2195), tensor(0.3840), tensor(-0.1061), tensor(0.4558), tensor(0.2565), tensor(-0.2509), tensor(-0.1547), tensor(-0.3640), tensor(0.3248), tensor(0.0939), tensor(-0.5189), tensor(0.3629), tensor(-0.2266), tensor(0.0561), tensor(-0.4797), tensor(-0.4916), tensor(-0.0381), tensor(0.6459), tensor(0.2290), tensor(0.1389), tensor(-0.4287), tensor(0.0085), tensor(-0.5571), tensor(-0.1032), tensor(0.1630), tensor(0.1073), tensor(-0.0825), tensor(-0.0182), tensor(0.2683), tensor(0.3352), tensor(0.1355), tensor(-0.2545), tensor(0.3493), tensor(0.1190), tensor(-0.6093), tensor(0.0739), tensor(-0.0484), tensor(0.4114), tensor(-1.5148), tensor(-1.0970), tensor(-0.2340), tensor(0.7656), tensor(-0.5047), tensor(-0.5005), tensor(0.3177), tensor(0.1485), tensor(-0.1185), ...]]"
1,"[tensor(1), tensor(20), tensor(14), tensor(7), tensor(13), tensor(8), tensor(74), tensor(32), tensor(4), tensor(50), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(20), tensor(14), tensor(7), tensor(13), tensor(8), tensor(74), tensor(32), tensor(4), tensor(50), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.3389), tensor(0.3181), tensor(-0.4009), tensor(0.0327), tensor(0.0834), tensor(0.1306), tensor(-0.1251), tensor(-0.1513), tensor(-0.1883), tensor(0.7268), tensor(0.2345), tensor(0.3493), tensor(-0.0661), tensor(-0.0914), tensor(0.1656), tensor(-0.3151), tensor(0.3214), tensor(-0.0784), tensor(-0.2522), tensor(0.0723), tensor(0.0920), tensor(-0.1335), tensor(0.4806), tensor(-0.1740), tensor(-0.3828), tensor(0.1145), tensor(0.1903), tensor(-0.4766), tensor(-0.2232), tensor(-0.2443), tensor(0.2288), tensor(0.4340), tensor(0.3422), tensor(-0.8185), tensor(0.1021), tensor(0.1560), tensor(-0.2659), tensor(-0.3295), tensor(-0.1093), tensor(0.2122), tensor(0.0703), tensor(-0.1089), tensor(-0.1405), tensor(-0.8669), tensor(-0.3174), tensor(-0.1412), tensor(0.2463), tensor(-0.3545), tensor(-0.1413), tensor(-0.1128), tensor(0.1473), tensor(0.5440), tensor(0.1202), tensor(-0.0851), tensor(-0.3618), tensor(-0.1666), tensor(0.5508), tensor(0.5932), tensor(0.4621), tensor(-0.1752), tensor(0.2096), tensor(0.3319), tensor(-0.4131), tensor(-0.1092), tensor(-0.2223), tensor(0.0111), tensor(-0.2731), tensor(-0.1038), tensor(-0.8047), tensor(0.6019), tensor(0.2002), tensor(0.6634), tensor(0.3155), tensor(-0.6797), tensor(0.1174), tensor(-0.6728), tensor(-0.2563), tensor(0.0622), tensor(-0.2005), tensor(0.5590), tensor(-0.2895), tensor(0.4506), tensor(0.1937), tensor(0.2434), tensor(-0.4841), tensor(-0.1134), tensor(-0.1439), tensor(0.1765), tensor(0.1393), tensor(-0.1131), tensor(-0.2760), tensor(-0.2704), tensor(-0.5514), tensor(-0.3036), tensor(0.7312), tensor(-0.5125), tensor(-0.6665), tensor(0.3102), tensor(0.1509), tensor(-0.5645), ...]]"
2,"[tensor(1), tensor(20), tensor(46), tensor(3), tensor(4), tensor(13), tensor(1108), tensor(37), tensor(6), tensor(1100), tensor(36), tensor(22), tensor(4), tensor(39), tensor(3011), tensor(3), tensor(4), tensor(1404), tensor(1405), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(20), tensor(46), tensor(3), tensor(4), tensor(13), tensor(1108), tensor(37), tensor(6), tensor(1100), tensor(36), tensor(22), tensor(4), tensor(39), tensor(3011), tensor(3), tensor(4), tensor(1404), tensor(1405), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.1045), tensor(0.0898), tensor(-0.4953), tensor(-0.0399), tensor(-0.2682), tensor(-0.1485), tensor(-0.3909), tensor(-0.4440), tensor(-0.2988), tensor(0.3910), tensor(-0.0841), tensor(-0.4195), tensor(0.0560), tensor(-0.4834), tensor(-0.1418), tensor(0.0508), tensor(0.6889), tensor(-0.3389), tensor(-0.5127), tensor(0.1819), tensor(0.4584), tensor(-0.3689), tensor(0.3095), tensor(-0.1769), tensor(-0.4142), tensor(0.1997), tensor(0.4309), tensor(-0.2969), tensor(-0.0814), tensor(-0.1797), tensor(0.7351), tensor(0.5656), tensor(0.4160), tensor(0.0532), tensor(0.3270), tensor(0.3610), tensor(-0.7914), tensor(-0.3850), tensor(0.2533), tensor(0.8553), tensor(0.2581), tensor(-0.1959), tensor(-0.5223), tensor(-1.1583), tensor(-0.3215), tensor(-0.2011), tensor(-0.0426), tensor(-0.3599), tensor(-0.2746), tensor(0.2455), tensor(0.3682), tensor(0.1911), tensor(0.5976), tensor(-0.1726), tensor(-0.3281), tensor(0.2181), tensor(0.1963), tensor(0.3938), tensor(0.2254), tensor(0.1811), tensor(0.0894), tensor(0.3457), tensor(-0.3267), tensor(0.2154), tensor(0.0347), tensor(0.0352), tensor(0.1016), tensor(-0.2927), tensor(-0.8025), tensor(0.0945), tensor(0.5601), tensor(0.4269), tensor(0.1361), tensor(-1.1977), tensor(0.3328), tensor(-0.3585), tensor(-0.2332), tensor(0.2365), tensor(-0.3865), tensor(0.3403), tensor(-0.6547), tensor(0.0547), tensor(0.0835), tensor(0.3933), tensor(0.1639), tensor(-0.2855), tensor(-0.1968), tensor(0.0567), tensor(0.2798), tensor(-0.2193), tensor(-0.8428), tensor(-0.7396), tensor(-0.4698), tensor(-0.1582), tensor(1.3280), tensor(-0.6528), tensor(-0.8485), tensor(0.5674), tensor(0.0138), tensor(-0.2001), ...]]"
3,"[tensor(1), tensor(730), tensor(80), tensor(273), tensor(6), tensor(77), tensor(28), tensor(80), tensor(32), tensor(7207), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(730), tensor(80), tensor(273), tensor(6), tensor(77), tensor(28), tensor(80), tensor(32), tensor(7207), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.3824), tensor(0.2187), tensor(-0.5133), tensor(-0.0112), tensor(-0.3447), tensor(0.4584), tensor(-0.3467), tensor(0.0147), tensor(-0.2500), tensor(0.8801), tensor(-0.2017), tensor(0.6635), tensor(0.1622), tensor(-0.0961), tensor(0.4386), tensor(-0.0330), tensor(0.4198), tensor(0.2950), tensor(-0.1920), tensor(0.6327), tensor(0.0715), tensor(-0.0977), tensor(0.2392), tensor(-0.6131), tensor(0.5093), tensor(0.2635), tensor(0.1690), tensor(-0.3976), tensor(-0.2303), tensor(-0.2108), tensor(0.4485), tensor(0.2876), tensor(0.4865), tensor(-0.4467), tensor(0.2265), tensor(-0.0466), tensor(-0.3351), tensor(-0.4790), tensor(0.2439), tensor(0.1989), tensor(-0.2167), tensor(-0.0307), tensor(-0.3031), tensor(-1.0307), tensor(-0.2003), tensor(-0.4702), tensor(-0.0266), tensor(-0.3039), tensor(-0.0828), tensor(-0.1561), tensor(0.1352), tensor(0.6142), tensor(0.5322), tensor(0.0159), tensor(-0.1555), tensor(0.0010), tensor(0.3622), tensor(0.4626), tensor(0.0901), tensor(0.4114), tensor(-0.1373), tensor(0.3893), tensor(-0.1157), tensor(-0.3617), tensor(-0.2298), tensor(0.0114), tensor(-0.0188), tensor(-0.0673), tensor(-0.3911), tensor(0.0481), tensor(0.3722), tensor(0.2218), tensor(0.0137), tensor(-0.3431), tensor(0.3242), tensor(-0.4448), tensor(0.0726), tensor(-0.0114), tensor(-0.3342), tensor(0.1198), tensor(-0.1216), tensor(0.3488), tensor(-0.0841), tensor(0.1515), tensor(-0.6827), tensor(-0.3075), tensor(0.0693), tensor(-0.2588), tensor(0.0751), tensor(-0.3448), tensor(-0.2797), tensor(-0.5361), tensor(-0.6762), tensor(-0.4810), tensor(0.8319), tensor(-0.4966), tensor(-0.8377), tensor(-0.0826), tensor(-0.3026), tensor(-0.0174), ...]]"
4,"[tensor(1), tensor(46), tensor(6), tensor(120), tensor(52), tensor(85), tensor(9), tensor(90), tensor(215), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(46), tensor(6), tensor(120), tensor(52), tensor(85), tensor(9), tensor(90), tensor(215), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.3231), tensor(0.1184), tensor(0.0721), tensor(-0.0348), tensor(-0.7514), tensor(-0.1560), tensor(-0.6320), tensor(-0.2368), tensor(-0.3031), tensor(0.7405), tensor(0.3723), tensor(0.6165), tensor(-0.0450), tensor(-0.5056), tensor(-0.0399), tensor(-0.4700), tensor(0.4389), tensor(-0.0313), tensor(-0.3697), tensor(-0.1223), tensor(-0.2627), tensor(-0.1827), tensor(0.0007), tensor(-0.5413), tensor(0.0479), tensor(-0.0856), tensor(0.4709), tensor(-0.4601), tensor(-0.4802), tensor(0.3218), tensor(-0.0638), tensor(-0.0197), tensor(0.4005), tensor(-0.6807), tensor(0.3672), tensor(-0.0990), tensor(-0.8664), tensor(-0.0479), tensor(0.2629), tensor(0.3572), tensor(-0.3964), tensor(-0.6252), tensor(-0.2247), tensor(-0.5050), tensor(-0.1513), tensor(-0.5824), tensor(0.2107), tensor(0.1604), tensor(-0.2412), tensor(0.1355), tensor(0.0584), tensor(0.4019), tensor(0.6784), tensor(0.1662), tensor(-0.0983), tensor(0.4102), tensor(0.1892), tensor(1.0345), tensor(0.0234), tensor(-0.0642), tensor(-0.1055), tensor(0.2516), tensor(-0.0538), tensor(-0.0725), tensor(0.4350), tensor(0.2384), tensor(-0.1860), tensor(-0.3069), tensor(-0.3836), tensor(0.2998), tensor(0.1993), tensor(0.3900), tensor(0.3395), tensor(-0.3345), tensor(-0.4906), tensor(-0.0186), tensor(0.3854), tensor(-0.2316), tensor(-0.4340), tensor(-0.3216), tensor(-0.3771), tensor(0.5615), tensor(0.0873), tensor(-0.2584), tensor(-0.3823), tensor(-0.5116), tensor(-0.5482), tensor(0.1763), tensor(-0.0039), tensor(-0.4223), tensor(-0.4388), tensor(-0.6914), tensor(-0.4481), tensor(-0.1480), tensor(0.8232), tensor(-0.5395), tensor(-0.3744), tensor(0.1980), tensor(-0.1703), tensor(-0.2709), ...]]"
...,...,...,...
36404,"[tensor(1), tensor(20), tensor(40), tensor(18), tensor(3), tensor(4), tensor(90), tensor(37), tensor(6), tensor(31), tensor(220), tensor(145), tensor(3), tensor(4), tensor(683), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(20), tensor(40), tensor(18), tensor(3), tensor(4), tensor(90), tensor(37), tensor(6), tensor(31), tensor(220), tensor(145), tensor(3), tensor(4), tensor(683), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.4316), tensor(0.4016), tensor(-0.2763), tensor(0.2867), tensor(-0.4603), tensor(-0.2395), tensor(-0.6733), tensor(-0.4215), tensor(-0.8731), tensor(1.0322), tensor(0.0469), tensor(0.1936), tensor(0.3535), tensor(-0.4066), tensor(0.0947), tensor(-0.3053), tensor(0.2500), tensor(-0.4189), tensor(0.0742), tensor(0.5406), tensor(0.4396), tensor(-0.4409), tensor(-0.0678), tensor(-0.3414), tensor(0.5241), tensor(0.3352), tensor(0.5825), tensor(-0.6823), tensor(-0.2613), tensor(-0.2518), tensor(0.4980), tensor(0.2873), tensor(0.4910), tensor(-0.5593), tensor(0.3786), tensor(-0.1725), tensor(-0.6319), tensor(0.0084), tensor(-0.3540), tensor(0.5706), tensor(-0.2838), tensor(-0.4154), tensor(-0.3804), tensor(-0.8917), tensor(-0.2854), tensor(-0.2011), tensor(-0.1150), tensor(-0.3838), tensor(0.1694), tensor(-0.1765), tensor(0.0528), tensor(0.8229), tensor(0.6885), tensor(-0.1720), tensor(-0.1053), tensor(-0.3100), tensor(0.8358), tensor(0.4609), tensor(0.0658), tensor(0.2562), tensor(0.1903), tensor(-0.1107), tensor(-0.3531), tensor(0.0526), tensor(0.3510), tensor(0.3654), tensor(-0.1600), tensor(-0.2209), tensor(-0.6530), tensor(-0.2872), tensor(0.6236), tensor(0.2278), tensor(0.3891), tensor(-0.5704), tensor(0.1610), tensor(-0.1262), tensor(0.1056), tensor(-0.0262), tensor(-0.5325), tensor(-0.0433), tensor(-0.6366), tensor(0.0678), tensor(0.4746), tensor(0.3292), tensor(-0.3721), tensor(-0.2329), tensor(-0.0250), tensor(0.0909), tensor(-0.0274), tensor(-0.0692), tensor(-0.3150), tensor(-0.5145), tensor(-0.8965), tensor(-0.5562), tensor(0.9168), tensor(-0.4575), tensor(-0.6507), tensor(0.3869), tensor(-0.4372), tensor(-0.4497), ...]]"
36405,"[tensor(1), tensor(55), tensor(43), tensor(1435), tensor(5), tensor(4830), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(55), tensor(43), tensor(1435), tensor(5), tensor(4830), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.2939), tensor(-0.0849), tensor(-0.2186), tensor(0.0625), tensor(-0.5611), tensor(0.2214), tensor(-0.2852), tensor(-0.6773), tensor(-0.8304), tensor(1.0698), tensor(0.1501), tensor(0.2046), tensor(0.3009), tensor(-0.3827), tensor(-0.4909), tensor(0.0505), tensor(0.4222), tensor(-0.5277), tensor(-0.3137), tensor(-0.2530), tensor(0.1995), tensor(-0.4129), tensor(0.3088), tensor(-0.3033), tensor(0.1829), tensor(0.4480), tensor(0.6963), tensor(-0.4250), tensor(-0.1011), tensor(-0.1396), tensor(0.1337), tensor(0.2851), tensor(0.1642), tensor(-0.5413), tensor(-0.0891), tensor(0.3981), tensor(-0.8821), tensor(0.1348), tensor(-0.1966), tensor(0.9224), tensor(0.0430), tensor(-0.3046), tensor(-0.2029), tensor(-1.0047), tensor(-0.5944), tensor(-0.2955), tensor(0.3050), tensor(0.1992), tensor(-0.2020), tensor(0.1204), tensor(0.3570), tensor(0.5193), tensor(0.3108), tensor(0.3090), tensor(-0.1804), tensor(0.0737), tensor(0.3550), tensor(0.2387), tensor(0.2957), tensor(0.1695), tensor(-0.1098), tensor(0.2107), tensor(-0.0580), tensor(0.2158), tensor(0.2450), tensor(0.1549), tensor(-0.3035), tensor(-0.2027), tensor(-0.9892), tensor(0.2074), tensor(0.2187), tensor(0.5650), tensor(0.0983), tensor(-0.7598), tensor(0.2261), tensor(0.1483), tensor(-0.1000), tensor(0.1387), tensor(-0.3147), tensor(0.5460), tensor(-0.7397), tensor(0.2354), tensor(0.1839), tensor(0.3727), tensor(0.3081), tensor(0.1597), tensor(-0.5765), tensor(0.2367), tensor(-0.1976), tensor(0.2513), tensor(-0.5909), tensor(-0.5423), tensor(0.0507), tensor(-0.2847), tensor(0.9723), tensor(-0.7386), tensor(-0.7947), tensor(0.0938), tensor(-0.1125), tensor(-0.1760), ...]]"
36406,"[tensor(1), tensor(17), tensor(49), tensor(39), tensor(7), tensor(8), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(17), tensor(49), tensor(39), tensor(7), tensor(8), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.4878), tensor(0.6549), tensor(-0.7299), tensor(-0.5570), tensor(-0.2978), tensor(0.3435), tensor(-0.7120), tensor(-0.9591), tensor(-0.2978), tensor(0.8582), tensor(0.4837), tensor(-0.0921), tensor(-0.5118), tensor(-0.2703), tensor(-0.2174), tensor(-0.3345), tensor(0.5080), tensor(-0.5852), tensor(0.1651), tensor(0.0834), tensor(0.2302), tensor(-0.0004), tensor(0.1230), tensor(-0.4359), tensor(-0.5269), tensor(0.0336), tensor(0.9650), tensor(-0.4161), tensor(-0.0170), tensor(-0.2181), tensor(0.9166), tensor(0.4108), tensor(0.3747), tensor(-0.5599), tensor(-0.0382), tensor(-0.3904), tensor(-0.2777), tensor(-0.0560), tensor(-0.4425), tensor(0.4997), tensor(-0.0436), tensor(-0.8891), tensor(-0.3086), tensor(-0.7914), tensor(-0.7663), tensor(-0.1435), tensor(0.1670), tensor(-0.3012), tensor(-0.2897), tensor(-0.0781), tensor(-0.0054), tensor(0.6871), tensor(0.3481), tensor(-0.2061), tensor(-0.2053), tensor(-0.1001), tensor(0.6577), tensor(0.3754), tensor(0.2396), tensor(0.3099), tensor(0.2897), tensor(0.3528), tensor(-0.6482), tensor(0.1045), tensor(-0.0589), tensor(0.3058), tensor(-0.4156), tensor(-0.2549), tensor(-0.7625), tensor(0.0871), tensor(0.3127), tensor(0.2122), tensor(0.4578), tensor(-0.8982), tensor(0.4648), tensor(-0.3982), tensor(0.0261), tensor(0.2898), tensor(-0.0640), tensor(0.1399), tensor(0.0914), tensor(0.7873), tensor(0.0779), tensor(0.4957), tensor(-0.1688), tensor(-0.2431), tensor(-0.1718), tensor(0.1969), tensor(-0.1349), tensor(-0.0628), tensor(-0.0872), tensor(-0.9957), tensor(-0.7920), tensor(-0.7462), tensor(0.3809), tensor(-0.5819), tensor(-0.6891), tensor(0.5487), tensor(0.1435), tensor(-0.9014), ...]]"
36407,"[tensor(1), tensor(10), tensor(3), tensor(535), tensor(122), tensor(78), tensor(86), tensor(19), tensor(5497), tensor(629), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(10), tensor(3), tensor(535), tensor(122), tensor(78), tensor(86), tensor(19), tensor(5497), tensor(629), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.1957), tensor(0.2496), tensor(-0.2653), tensor(0.0641), tensor(-0.4143), tensor(-0.1122), tensor(-0.2588), tensor(-0.3179), tensor(-0.5631), tensor(0.5238), tensor(0.3114), tensor(0.2015), tensor(0.3615), tensor(-0.3474), tensor(0.0215), tensor(0.0585), tensor(0.5456), tensor(-0.2098), tensor(-0.2164), tensor(0.4018), tensor(0.5231), tensor(-0.4594), tensor(0.1489), tensor(-0.2067), tensor(0.1044), tensor(-0.0012), tensor(0.2088), tensor(-0.1935), tensor(-0.2680), tensor(-0.3899), tensor(0.5761), tensor(0.3510), tensor(0.3388), tensor(-0.4777), tensor(0.1069), tensor(0.3066), tensor(-0.1857), tensor(-0.0730), tensor(-0.2964), tensor(0.5066), tensor(0.2428), tensor(-0.1849), tensor(-0.2913), tensor(-0.6468), tensor(-0.1465), tensor(-0.2272), tensor(-0.0105), tensor(-0.3866), tensor(-0.3053), tensor(0.3204), tensor(-0.0491), tensor(0.4870), tensor(0.2978), tensor(-0.2321), tensor(0.1712), tensor(0.1164), tensor(0.1176), tensor(0.3617), tensor(-0.0553), tensor(0.0456), tensor(0.1587), tensor(-0.1391), tensor(-0.2846), tensor(-0.1520), tensor(0.1487), tensor(0.2404), tensor(-0.1852), tensor(-0.0973), tensor(-0.4611), tensor(-0.1566), tensor(0.5268), tensor(0.6563), tensor(0.0088), tensor(-0.1611), tensor(-0.0265), tensor(-0.1837), tensor(0.2477), tensor(-0.0438), tensor(-0.2053), tensor(-0.0343), tensor(-0.0454), tensor(0.0918), tensor(0.0171), tensor(0.2010), tensor(-0.0925), tensor(-0.2737), tensor(0.0191), tensor(0.1558), tensor(0.2431), tensor(-0.3254), tensor(-0.5947), tensor(0.0655), tensor(-0.4466), tensor(-0.1569), tensor(0.8178), tensor(-0.4832), tensor(-0.5162), tensor(-0.0178), tensor(-0.0204), tensor(-0.1514), ...]]"


Unnamed: 0,sequence,target,image
0,"[tensor(1), tensor(13), tensor(8), tensor(7), tensor(25), tensor(8), tensor(87), tensor(3), tensor(4), tensor(41), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(13), tensor(8), tensor(7), tensor(25), tensor(8), tensor(87), tensor(3), tensor(4), tensor(41), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.1574), tensor(0.5240), tensor(-0.1646), tensor(-0.4492), tensor(-0.1369), tensor(0.0135), tensor(-0.4912), tensor(-0.5353), tensor(-0.3384), tensor(0.5774), tensor(0.3731), tensor(0.2721), tensor(0.3061), tensor(-0.0813), tensor(-0.0961), tensor(-0.1329), tensor(0.5808), tensor(0.1491), tensor(-0.2382), tensor(0.1255), tensor(0.2206), tensor(-0.0425), tensor(0.2134), tensor(-0.1477), tensor(-0.2269), tensor(0.1593), tensor(0.2901), tensor(-0.5571), tensor(-0.1506), tensor(-0.2339), tensor(0.2801), tensor(0.3152), tensor(0.4599), tensor(-0.3976), tensor(-0.0118), tensor(0.0662), tensor(-0.3855), tensor(-0.0384), tensor(-0.0571), tensor(0.3091), tensor(0.0749), tensor(-0.3807), tensor(-0.1222), tensor(-0.7298), tensor(-0.3165), tensor(-0.2439), tensor(0.2341), tensor(-0.3919), tensor(-0.0237), tensor(0.1298), tensor(0.4268), tensor(0.5546), tensor(0.3199), tensor(0.0605), tensor(0.3142), tensor(0.0548), tensor(0.1948), tensor(0.3160), tensor(0.2159), tensor(-0.1696), tensor(0.3064), tensor(0.3680), tensor(-0.1781), tensor(0.1773), tensor(0.0493), tensor(-0.1654), tensor(-0.4430), tensor(-0.3569), tensor(-0.6341), tensor(0.2927), tensor(0.4210), tensor(0.1476), tensor(0.5620), tensor(-0.5013), tensor(0.1913), tensor(-0.5178), tensor(-0.0014), tensor(0.0461), tensor(-0.0896), tensor(0.1336), tensor(-0.0434), tensor(0.4529), tensor(0.0436), tensor(0.0871), tensor(-0.2500), tensor(-0.5120), tensor(-0.0868), tensor(-0.2879), tensor(0.0902), tensor(0.1159), tensor(-0.2162), tensor(-0.6548), tensor(-0.2951), tensor(0.0396), tensor(0.7492), tensor(-0.6921), tensor(-0.6022), tensor(0.2618), tensor(0.3600), tensor(-0.3024), ...]]"
1,"[tensor(1), tensor(10), tensor(3), tensor(14), tensor(379), tensor(452), tensor(1204), tensor(62), tensor(23), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(10), tensor(3), tensor(14), tensor(379), tensor(452), tensor(1204), tensor(62), tensor(23), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(0.0626), tensor(0.2891), tensor(-0.1070), tensor(0.0748), tensor(-0.6388), tensor(-0.0651), tensor(-0.7808), tensor(-0.3747), tensor(-0.4314), tensor(0.6211), tensor(-0.3471), tensor(0.2574), tensor(0.2968), tensor(-0.3971), tensor(0.0228), tensor(-0.1463), tensor(0.1725), tensor(-0.3317), tensor(0.1589), tensor(0.2798), tensor(-0.0501), tensor(0.0566), tensor(0.3680), tensor(-0.1061), tensor(-0.2152), tensor(0.4207), tensor(0.4211), tensor(-0.7258), tensor(-0.0421), tensor(0.1124), tensor(0.5267), tensor(0.3506), tensor(0.1053), tensor(-0.4636), tensor(0.5104), tensor(0.1626), tensor(-0.1980), tensor(0.0595), tensor(0.1096), tensor(0.5185), tensor(-0.2985), tensor(-0.5081), tensor(-0.5382), tensor(-0.3585), tensor(-0.3918), tensor(-0.0984), tensor(0.3251), tensor(-0.7730), tensor(-0.3862), tensor(-0.0110), tensor(0.2296), tensor(0.4485), tensor(0.7734), tensor(-0.0872), tensor(-0.0295), tensor(0.0966), tensor(0.2777), tensor(0.8083), tensor(0.0574), tensor(0.0017), tensor(0.0125), tensor(0.3371), tensor(-0.1171), tensor(-0.0216), tensor(0.3353), tensor(-0.0478), tensor(-0.4350), tensor(0.0049), tensor(-0.2798), tensor(0.0409), tensor(0.1312), tensor(0.6565), tensor(-0.3733), tensor(-0.1560), tensor(0.0120), tensor(-0.3406), tensor(-0.0941), tensor(-0.2806), tensor(-0.4395), tensor(0.3624), tensor(-0.2082), tensor(0.3669), tensor(-0.0798), tensor(0.0973), tensor(0.0203), tensor(-0.2944), tensor(-0.5649), tensor(0.1078), tensor(0.3444), tensor(0.1625), tensor(-0.4343), tensor(-0.6040), tensor(-0.3806), tensor(0.2219), tensor(0.9917), tensor(-0.7459), tensor(-0.5914), tensor(0.3133), tensor(0.2069), tensor(-0.1104), ...]]"
2,"[tensor(1), tensor(18), tensor(21), tensor(7605), tensor(96), tensor(7), tensor(14), tensor(97), tensor(6), tensor(60), tensor(297), tensor(19), tensor(413), tensor(175), tensor(84), tensor(11), tensor(58), tensor(29), tensor(889), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(18), tensor(21), tensor(7605), tensor(96), tensor(7), tensor(14), tensor(97), tensor(6), tensor(60), tensor(297), tensor(19), tensor(413), tensor(175), tensor(84), tensor(11), tensor(58), tensor(29), tensor(889), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.4019), tensor(0.1266), tensor(-0.2683), tensor(0.2931), tensor(-0.3177), tensor(0.1498), tensor(-0.7697), tensor(-0.4825), tensor(-0.6843), tensor(0.5914), tensor(0.0087), tensor(0.5240), tensor(0.3623), tensor(-0.3247), tensor(-0.2905), tensor(-0.3032), tensor(0.3543), tensor(-0.2230), tensor(-0.0839), tensor(0.2160), tensor(0.5979), tensor(-0.6551), tensor(0.1290), tensor(-0.2330), tensor(-0.0288), tensor(0.1657), tensor(0.4313), tensor(-0.3343), tensor(-0.1055), tensor(-0.0033), tensor(0.3802), tensor(0.5453), tensor(0.3513), tensor(-0.2567), tensor(0.0717), tensor(0.3573), tensor(-0.2724), tensor(-0.1868), tensor(-0.5853), tensor(0.9101), tensor(0.1200), tensor(-0.1458), tensor(-0.2665), tensor(-0.7852), tensor(-0.3184), tensor(-0.5016), tensor(0.2196), tensor(-0.5283), tensor(-0.0839), tensor(0.3318), tensor(0.3118), tensor(0.5926), tensor(0.5144), tensor(-0.3052), tensor(0.0005), tensor(0.0362), tensor(0.6595), tensor(0.5294), tensor(-0.0776), tensor(-0.1357), tensor(0.4660), tensor(0.2160), tensor(-0.5385), tensor(-0.4271), tensor(0.1483), tensor(0.4925), tensor(0.0326), tensor(-0.1645), tensor(-0.9039), tensor(-0.0749), tensor(0.5263), tensor(0.5013), tensor(0.4031), tensor(-0.4991), tensor(0.1084), tensor(-0.3521), tensor(-0.0175), tensor(0.1089), tensor(-0.0289), tensor(0.0020), tensor(-0.2095), tensor(0.2574), tensor(0.3940), tensor(0.4297), tensor(-0.0937), tensor(-0.0017), tensor(-0.3420), tensor(-0.3405), tensor(0.0651), tensor(0.1200), tensor(0.0641), tensor(-0.4132), tensor(-0.8971), tensor(-0.1333), tensor(0.7378), tensor(-0.8230), tensor(-0.5642), tensor(0.0925), tensor(0.0231), tensor(-0.3118), ...]]"
3,"[tensor(1), tensor(16), tensor(3), tensor(801), tensor(407), tensor(85), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(16), tensor(3), tensor(801), tensor(407), tensor(85), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(0.0062), tensor(0.2175), tensor(-0.3213), tensor(0.0275), tensor(-0.4646), tensor(-0.1892), tensor(-0.4391), tensor(-0.2791), tensor(-0.0749), tensor(0.5415), tensor(-0.1626), tensor(0.2948), tensor(-0.0508), tensor(-0.2598), tensor(-0.1096), tensor(-0.1077), tensor(0.5028), tensor(-0.0031), tensor(-0.0744), tensor(-0.2315), tensor(-0.1766), tensor(-0.1855), tensor(-0.1873), tensor(-0.0762), tensor(0.3184), tensor(0.4580), tensor(0.7229), tensor(-0.3795), tensor(-0.2055), tensor(-0.1197), tensor(0.4585), tensor(0.2578), tensor(0.3801), tensor(-0.4255), tensor(0.2819), tensor(-0.1232), tensor(-0.4963), tensor(-0.1386), tensor(0.0927), tensor(-0.0174), tensor(-0.1563), tensor(-0.0377), tensor(-0.0252), tensor(-0.9513), tensor(-0.1357), tensor(-0.2188), tensor(0.1043), tensor(-0.7385), tensor(0.1176), tensor(-0.3450), tensor(0.3953), tensor(0.8302), tensor(0.5298), tensor(0.1258), tensor(-0.0108), tensor(0.1708), tensor(0.4050), tensor(0.6823), tensor(0.1738), tensor(0.3227), tensor(0.1798), tensor(-0.1224), tensor(-0.0232), tensor(-0.0183), tensor(0.3636), tensor(0.3375), tensor(-0.3556), tensor(-0.2081), tensor(-0.3775), tensor(-0.1163), tensor(0.3062), tensor(0.3769), tensor(0.1757), tensor(-0.5194), tensor(-0.0397), tensor(0.3142), tensor(-0.1048), tensor(-0.2673), tensor(-0.3380), tensor(0.0976), tensor(-0.6549), tensor(0.4513), tensor(0.3252), tensor(0.1420), tensor(-0.1157), tensor(-0.1791), tensor(-0.3389), tensor(0.0245), tensor(-0.1445), tensor(-0.2239), tensor(-0.1414), tensor(-0.2069), tensor(-0.1170), tensor(-0.5025), tensor(0.5727), tensor(-0.2239), tensor(-0.1277), tensor(0.1240), tensor(-0.0672), tensor(-0.0398), ...]]"
4,"[tensor(1), tensor(27), tensor(10), tensor(7), tensor(17), tensor(15), tensor(60), tensor(113), tensor(35), tensor(347), tensor(749), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(27), tensor(10), tensor(7), tensor(17), tensor(15), tensor(60), tensor(113), tensor(35), tensor(347), tensor(749), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.2524), tensor(0.0457), tensor(-0.3864), tensor(0.3061), tensor(-0.4598), tensor(-0.1062), tensor(-0.9026), tensor(-0.3955), tensor(-0.5915), tensor(0.6081), tensor(-0.1351), tensor(0.5235), tensor(0.2712), tensor(-0.3663), tensor(-0.0024), tensor(-0.2619), tensor(0.4273), tensor(-0.4591), tensor(-0.0471), tensor(0.3995), tensor(0.7184), tensor(-0.5871), tensor(0.1925), tensor(-0.3087), tensor(0.1085), tensor(0.1657), tensor(0.4434), tensor(-0.5443), tensor(-0.0237), tensor(-0.1321), tensor(0.8531), tensor(0.1637), tensor(0.4874), tensor(-0.1390), tensor(0.0065), tensor(0.2927), tensor(-0.6289), tensor(0.1278), tensor(-0.2361), tensor(1.1467), tensor(-0.0665), tensor(-0.4785), tensor(-0.5070), tensor(-0.6974), tensor(-0.1888), tensor(-0.6399), tensor(0.1212), tensor(-0.2189), tensor(-0.1154), tensor(0.2302), tensor(0.1814), tensor(0.5829), tensor(0.6085), tensor(-0.4551), tensor(-0.0896), tensor(-0.1585), tensor(0.8420), tensor(0.5083), tensor(-0.2375), tensor(0.1556), tensor(0.3987), tensor(-0.0428), tensor(-0.1490), tensor(-0.1518), tensor(0.2000), tensor(0.4552), tensor(0.1696), tensor(-0.4505), tensor(-0.9201), tensor(-0.2913), tensor(0.3650), tensor(0.6225), tensor(0.2816), tensor(-0.9132), tensor(0.2395), tensor(-0.3327), tensor(-0.1231), tensor(-0.0605), tensor(0.1270), tensor(-0.0564), tensor(-0.3025), tensor(0.3221), tensor(0.2831), tensor(0.4898), tensor(-0.1512), tensor(0.2497), tensor(-0.4249), tensor(-0.1943), tensor(0.2565), tensor(-0.0760), tensor(-0.1728), tensor(-0.5670), tensor(-0.6512), tensor(-0.2871), tensor(0.8219), tensor(-0.5641), tensor(-0.6616), tensor(0.4478), tensor(0.1891), tensor(-0.1476), ...]]"
...,...,...,...
4041,"[tensor(1), tensor(171), tensor(68), tensor(120), tensor(52), tensor(42), tensor(187), tensor(134), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(171), tensor(68), tensor(120), tensor(52), tensor(42), tensor(187), tensor(134), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.0656), tensor(0.9152), tensor(-0.4532), tensor(0.3226), tensor(-0.2026), tensor(0.4178), tensor(-0.5167), tensor(-0.7019), tensor(-0.5624), tensor(0.6595), tensor(0.1428), tensor(0.4995), tensor(0.3493), tensor(-0.1300), tensor(0.1024), tensor(-0.3238), tensor(0.6502), tensor(0.1702), tensor(-0.3469), tensor(0.3367), tensor(0.4492), tensor(-0.3186), tensor(0.0094), tensor(-0.6813), tensor(-0.2712), tensor(0.5468), tensor(0.2200), tensor(-0.6038), tensor(-0.2868), tensor(-0.2477), tensor(0.3294), tensor(0.4596), tensor(0.5240), tensor(-0.0964), tensor(-0.0937), tensor(-0.0113), tensor(-0.1523), tensor(-0.4047), tensor(-0.1232), tensor(0.3250), tensor(-0.3714), tensor(0.4202), tensor(0.2447), tensor(-1.0834), tensor(-0.1531), tensor(-0.6934), tensor(0.1676), tensor(-0.8972), tensor(-0.2911), tensor(-0.4186), tensor(0.3242), tensor(0.6357), tensor(0.5903), tensor(-0.0231), tensor(0.5159), tensor(-0.1366), tensor(0.4882), tensor(0.5124), tensor(-0.2593), tensor(-0.0212), tensor(-0.1614), tensor(0.2914), tensor(0.2493), tensor(-0.2293), tensor(0.5020), tensor(-0.2144), tensor(-0.3141), tensor(-0.4867), tensor(-0.6548), tensor(-0.2738), tensor(0.8293), tensor(0.4789), tensor(0.1640), tensor(-0.2197), tensor(0.1526), tensor(-0.2662), tensor(0.1009), tensor(0.1661), tensor(-0.1258), tensor(-0.0396), tensor(-0.0319), tensor(0.3506), tensor(0.3802), tensor(0.0013), tensor(-0.1101), tensor(0.4173), tensor(0.0043), tensor(0.0016), tensor(0.3023), tensor(0.3593), tensor(0.2923), tensor(-1.1010), tensor(-0.6615), tensor(-0.4293), tensor(0.5116), tensor(-0.7376), tensor(-0.5022), tensor(0.3469), tensor(0.1958), tensor(0.0396), ...]]"
4042,"[tensor(1), tensor(20), tensor(3351), tensor(8641), tensor(11), tensor(23), tensor(1039), tensor(2413), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(20), tensor(3351), tensor(8641), tensor(11), tensor(23), tensor(1039), tensor(2413), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.0592), tensor(0.2914), tensor(-0.4901), tensor(0.1435), tensor(-0.6344), tensor(0.2497), tensor(-0.4544), tensor(-0.3666), tensor(-0.4852), tensor(0.7177), tensor(0.0533), tensor(0.0539), tensor(0.1328), tensor(-0.2880), tensor(-0.1122), tensor(0.1557), tensor(0.6630), tensor(-0.6257), tensor(0.0170), tensor(-0.0086), tensor(0.3657), tensor(-0.3847), tensor(-0.0935), tensor(-0.6297), tensor(-0.1048), tensor(0.0629), tensor(0.5651), tensor(-0.4006), tensor(-0.1338), tensor(-0.2331), tensor(0.5645), tensor(0.4385), tensor(0.5121), tensor(-0.3613), tensor(0.0416), tensor(0.0843), tensor(-0.8103), tensor(0.1554), tensor(-0.3432), tensor(0.6052), tensor(-0.2975), tensor(-0.5872), tensor(-0.1821), tensor(-1.1435), tensor(-0.4660), tensor(-0.2878), tensor(0.3529), tensor(-0.1348), tensor(0.0758), tensor(0.0709), tensor(0.1799), tensor(0.5905), tensor(0.5514), tensor(-0.2719), tensor(0.0751), tensor(0.2590), tensor(0.5403), tensor(0.5220), tensor(0.3370), tensor(0.2316), tensor(0.4528), tensor(-0.0665), tensor(-0.4867), tensor(0.2181), tensor(0.3202), tensor(0.2137), tensor(-0.2045), tensor(-0.2143), tensor(-0.8893), tensor(0.0108), tensor(0.3015), tensor(0.3900), tensor(0.4101), tensor(-1.1636), tensor(0.4464), tensor(-0.3842), tensor(0.0441), tensor(0.0733), tensor(-0.3994), tensor(0.4999), tensor(-0.3590), tensor(0.1978), tensor(0.1525), tensor(0.2776), tensor(-0.2256), tensor(0.2134), tensor(0.1262), tensor(0.0599), tensor(0.1632), tensor(0.1959), tensor(-0.3407), tensor(-0.6354), tensor(-0.4786), tensor(-0.6729), tensor(0.7828), tensor(-0.5838), tensor(-0.6604), tensor(0.4859), tensor(0.1083), tensor(0.0134), ...]]"
4043,"[tensor(1), tensor(43), tensor(78), tensor(3), tensor(4), tensor(42), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(43), tensor(78), tensor(3), tensor(4), tensor(42), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.0359), tensor(0.5903), tensor(-0.4362), tensor(-0.2188), tensor(-0.4905), tensor(0.3647), tensor(-0.7030), tensor(-1.1457), tensor(-0.8413), tensor(0.9564), tensor(0.1925), tensor(0.6394), tensor(0.6059), tensor(0.0869), tensor(-0.0847), tensor(0.0396), tensor(0.0192), tensor(-0.2722), tensor(-0.1822), tensor(0.4194), tensor(0.8097), tensor(-0.6182), tensor(0.1961), tensor(-0.4138), tensor(0.0105), tensor(0.2103), tensor(0.3722), tensor(-0.4882), tensor(-0.1564), tensor(-0.0456), tensor(0.2671), tensor(0.4247), tensor(0.5496), tensor(-0.7208), tensor(-0.0560), tensor(0.1887), tensor(-0.3188), tensor(-0.2877), tensor(-0.2215), tensor(0.7153), tensor(-0.5357), tensor(0.3225), tensor(-0.6256), tensor(-0.8511), tensor(-1.0527), tensor(-0.3137), tensor(-0.2507), tensor(-0.7342), tensor(-0.4733), tensor(0.2552), tensor(-0.1589), tensor(1.0088), tensor(0.9257), tensor(0.0914), tensor(0.3528), tensor(-0.3617), tensor(0.6919), tensor(0.2470), tensor(-0.2484), tensor(0.1552), tensor(-0.1196), tensor(0.3938), tensor(-0.1859), tensor(-0.3016), tensor(0.6244), tensor(0.3083), tensor(0.2000), tensor(-0.2520), tensor(-0.7388), tensor(0.1927), tensor(1.0014), tensor(0.6634), tensor(0.5951), tensor(-0.5863), tensor(0.1026), tensor(-0.5126), tensor(0.3335), tensor(-0.2618), tensor(-0.4085), tensor(-0.0329), tensor(-0.5353), tensor(0.6612), tensor(0.1862), tensor(0.2787), tensor(-0.0374), tensor(-0.1186), tensor(-0.0857), tensor(-0.1730), tensor(-0.4881), tensor(-0.2826), tensor(0.2064), tensor(-0.6221), tensor(-1.5272), tensor(-0.1478), tensor(1.3506), tensor(-1.0736), tensor(-0.4476), tensor(0.0429), tensor(-0.2042), tensor(-0.2622), ...]]"
4044,"[tensor(1), tensor(27), tensor(192), tensor(92), tensor(2811), tensor(52), tensor(19), tensor(637), tensor(4), tensor(39), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2)]","[tensor(27), tensor(192), tensor(92), tensor(2811), tensor(52), tensor(19), tensor(637), tensor(4), tensor(39), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(2), tensor(0)]","[[tensor(-0.5058), tensor(0.4205), tensor(-1.6087), tensor(-0.7068), tensor(-0.2418), tensor(0.0077), tensor(-0.2238), tensor(-0.6772), tensor(0.1125), tensor(0.8237), tensor(0.4941), tensor(-0.1922), tensor(-0.0088), tensor(-0.7087), tensor(0.1688), tensor(-0.3344), tensor(0.6629), tensor(0.0718), tensor(-0.0482), tensor(-0.0328), tensor(0.1761), tensor(0.0907), tensor(0.0346), tensor(-0.3966), tensor(-0.2319), tensor(0.5929), tensor(0.4989), tensor(-0.4789), tensor(0.0451), tensor(-0.3729), tensor(0.1132), tensor(-0.0437), tensor(0.4555), tensor(-0.6735), tensor(-0.1772), tensor(-0.1239), tensor(-0.3573), tensor(-0.2061), tensor(-0.7156), tensor(0.1845), tensor(-0.3893), tensor(-0.1554), tensor(-0.5098), tensor(-1.0611), tensor(-0.5606), tensor(-0.1466), tensor(0.3993), tensor(-0.4448), tensor(-0.3781), tensor(0.8573), tensor(0.1721), tensor(0.3506), tensor(0.2700), tensor(-0.0675), tensor(0.3878), tensor(-0.1151), tensor(0.9412), tensor(0.2498), tensor(0.3788), tensor(-0.2290), tensor(0.0243), tensor(0.3510), tensor(-0.4593), tensor(0.5528), tensor(-0.1430), tensor(-0.5836), tensor(-0.4842), tensor(0.0848), tensor(-1.1328), tensor(0.0628), tensor(0.4510), tensor(0.5814), tensor(0.5509), tensor(-0.4656), tensor(0.8338), tensor(-0.0225), tensor(0.0271), tensor(-0.0483), tensor(-0.1316), tensor(0.3730), tensor(-0.2024), tensor(1.0808), tensor(0.1503), tensor(0.1904), tensor(-0.5433), tensor(-0.2939), tensor(-0.1405), tensor(0.2208), tensor(0.1401), tensor(-0.4636), tensor(-0.6541), tensor(-0.4953), tensor(0.1081), tensor(-0.3202), tensor(-0.0305), tensor(-0.3415), tensor(-0.4114), tensor(0.0564), tensor(0.0779), tensor(-0.3032), ...]]"


**Decoder Transformer**

In [40]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=(max_length+2)):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        

    def forward(self, x):
        if self.pe.size(0) < x.size(0):
            self.pe = self.pe.repeat(x.size(0), 1, 1).to(device)
        self.pe = self.pe[:x.size(0), : , : ]
        
        x = x + self.pe
        return self.dropout(x)

In [41]:
class ImageCaptionModel(nn.Module):
    def __init__(self, n_head, n_decoder_layer, vocab_size, embedding_size):
        super(ImageCaptionModel, self).__init__()
        self.pos_encoder = PositionalEncoding(embedding_size, 0.1)
        self.TransformerDecoderLayer = nn.TransformerDecoderLayer(d_model =  embedding_size, nhead = n_head)
        self.TransformerDecoder = nn.TransformerDecoder(decoder_layer = self.TransformerDecoderLayer, num_layers = n_decoder_layer)
        self.embedding_size = embedding_size
        self.embedding = nn.Embedding(vocab_size , embedding_size)
        self.last_linear_layer = nn.Linear(embedding_size, vocab_size)
        self.init_weights()
        self.n_head = n_head

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.last_linear_layer.bias.data.zero_()
        self.last_linear_layer.weight.data.uniform_(-initrange, initrange)

    def generate_Mask(self, size, decoder_inp):
        decoder_input_mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        decoder_input_mask = decoder_input_mask.float().masked_fill(decoder_input_mask == 0, float('-inf')).masked_fill(decoder_input_mask == 1, float(0.0))

        decoder_input_pad_mask = decoder_inp.float().masked_fill(decoder_inp == 0, float(0.0)).masked_fill(decoder_inp > 0, float(1.0))
        decoder_input_pad_mask_bool = decoder_inp == 0

        return decoder_input_mask, decoder_input_pad_mask, decoder_input_pad_mask_bool

    def forward(self, encoded_image, decoder_inp):
        encoded_image = encoded_image.permute(1,0,2)

        decoder_inp_embed = self.embedding(decoder_inp)* math.sqrt(self.embedding_size)
        
        decoder_inp_embed = self.pos_encoder(decoder_inp_embed)
        decoder_inp_embed = decoder_inp_embed.permute(1,0,2)
        
        
        decoder_input_mask, decoder_input_pad_mask, decoder_input_pad_mask_bool = self.generate_Mask(decoder_inp.size(1), decoder_inp)
        decoder_input_mask = decoder_input_mask.to(device)
        decoder_input_pad_mask = decoder_input_pad_mask.to(device)
        decoder_input_pad_mask_bool = decoder_input_pad_mask_bool.to(device)
        
        decoder_output = self.TransformerDecoder(tgt = decoder_inp_embed, memory = encoded_image, tgt_mask = decoder_input_mask, tgt_key_padding_mask = decoder_input_pad_mask_bool)
        
        final_output = self.last_linear_layer(decoder_output)

        return final_output,  decoder_input_pad_mask

**Training Model**

In [42]:
EPOCH = 30

ictModel = ImageCaptionModel(16, 4, vocab_size_train, 256).to(device)
optimizer = torch.optim.Adam(ictModel.parameters(), lr = 0.00001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.8, patience=2, verbose = True)
criterion = torch.nn.CrossEntropyLoss(reduction='none')
min_val_loss = float('Inf')

In [43]:
for epoch in tqdm(range(EPOCH)):
    total_epoch_train_loss = 0
    total_epoch_valid_loss = 0
    total_train_words = 0
    total_valid_words = 0
    ictModel.train()

    ### Train Loop
    for caption_seq, target_seq, image_embed in train_dataloader:

        optimizer.zero_grad()

        image_embed = image_embed.to(device)
        caption_seq = caption_seq
        target_seq = target_seq

        output, padding_mask = ictModel.forward(image_embed, caption_seq)
        output = output.permute(1, 2, 0)

        loss = criterion(output,target_seq)

        loss_masked = torch.mul(loss, padding_mask)

        final_batch_loss = torch.sum(loss_masked)/torch.sum(padding_mask)

        final_batch_loss.backward()
        optimizer.step()
        total_epoch_train_loss += torch.sum(loss_masked).detach().item()
        total_train_words += torch.sum(padding_mask)

 
    total_epoch_train_loss = total_epoch_train_loss/total_train_words
  

    ### Eval Loop
    ictModel.eval()
    with torch.no_grad():
        for caption_seq, target_seq, image_embed in valid_dataloader:

            image_embed = image_embed.squeeze(1).to(device)
            caption_seq = caption_seq.to(device)
            target_seq = target_seq.to(device)

            output, padding_mask = ictModel.forward(image_embed, caption_seq)
            output = output.permute(1, 2, 0)

            loss = criterion(output,target_seq)

            loss_masked = torch.mul(loss, padding_mask)

            total_epoch_valid_loss += torch.sum(loss_masked).detach().item()
            total_valid_words += torch.sum(padding_mask)

    total_epoch_valid_loss = total_epoch_valid_loss/total_valid_words
  
    print("Epoch -> ", epoch," Training Loss -> ", total_epoch_train_loss.item(), "Eval Loss -> ", total_epoch_valid_loss.item() )
  
    if min_val_loss > total_epoch_valid_loss:
        print("Writing Model at epoch ", epoch)
        torch.save(ictModel, './BestModel')
        min_val_loss = total_epoch_valid_loss
  

    scheduler.step(total_epoch_valid_loss.item())


  0%|          | 0/30 [00:00<?, ?it/s]

IndexError: index out of range in self