In [1]:
import numpy as np
import torch
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torch.optim as optim
from nltk.translate.bleu_score import corpus_bleu

In [2]:
import os, sys

# Make the parent of the notebook's directory importable so that the
# `model`, `utils` and `experiment` packages used below can be resolved.
# sys.path[0] may be the empty string (meaning "current working directory"),
# and dirname('') would yield '' rather than the parent directory, so the
# entry is resolved to an absolute path before taking its parent.
project_root = os.path.dirname(os.path.abspath(sys.path[0]))
# Only append once so that re-running this cell does not keep growing sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
from model.encoder_perceiver import EncoderPerceiver
from model.encoder_cnn import EncoderResNet
from model.decoder import CaptioningTransformer
from utils.flickr8k_util import FlickrDataset, CapsCollate
from model.utils import show_image
from experiment.solver import CaptioningSolver

In [4]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(backend)
device

device(type='cpu')

In [5]:
# load data
# Image preprocessing pipeline handed to the dataset.
# NOTE: the mean/std below are the ImageNet channel statistics quoted in the
# torchvision documentation, not values computed on Flickr8k.
resize = T.Resize((224, 224))
to_tensor = T.ToTensor()
normalize = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
transform = T.Compose([resize, to_tensor, normalize])

In [6]:
# Root of the project checkout. Set the MULTIMODALPERCEIVER_DIR environment
# variable to run the notebook on another machine; when it is unset the
# original author's local path is used, so existing runs are unaffected.
base_dir = os.environ.get(
    "MULTIMODALPERCEIVER_DIR",
    "/Users/pinkaew/Stanford Files/Spring 2021/CS 231N/project/multimodalperceiver",
)
# Directory holding the Flickr8k files (the `Images` folder and `captions.txt`
# are read from here when the dataset is built).
data_location = base_dir + "/data/Flickr8k"

In [7]:
# Build the Flickr8k dataset from the image folder and the caption file,
# applying `transform` to the images.
images_dir = data_location + "/Images"
captions_path = data_location + "/captions.txt"
dataset = FlickrDataset(
    root_dir=images_dir,
    caption_file=captions_path,
    transform=transform,
    verbose=True,
)

building vocab
buidling caption alias


In [8]:
# Split by image rather than by caption: boundaries are first computed in
# units of images (`*_idx`) and then scaled to dataset positions (`*_abs_idx`).
# This assumes the captions of one image are stored contiguously in the
# dataset -- TODO confirm against FlickrDataset.
caption_per_image = 5
num_data = len(dataset)
# Floor division keeps the image count an integer; true division produced a
# float and let a fractional "image" leak into the boundaries whenever
# num_data is not an exact multiple of caption_per_image.
num_figures = num_data // caption_per_image
test_split = 0.2
val_split = 0.2
# Images [0, train_idx) are used for training.
train_idx = int((1 - test_split - val_split) * num_figures)
train_abs_idx = train_idx * caption_per_image
# Images [train_idx, val_idx) are used for validation; the rest is for testing.
val_idx = train_idx + int(val_split * num_figures)
val_abs_idx = val_idx * caption_per_image

In [9]:
# Contiguous dataset positions for each split: training first, then
# validation, with everything from `val_abs_idx` onward left for testing.
train_indices, val_indices, test_indices = (
    range(0, train_abs_idx),
    range(train_abs_idx, val_abs_idx),
    range(val_abs_idx, num_data),
)

In [10]:
# writing the dataloader
# setting the constants
BATCH_SIZE = 5
NUM_WORKER = 4

# token to represent the padding
pad_idx = dataset.vocab.stoi["<NULL>"]

loader_train = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    sampler = sampler.SubsetRandomSampler(range(10)),
    collate_fn=CapsCollate(pad_idx=pad_idx,batch_first=True)
)

loader_val = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    sampler = sampler.SubsetRandomSampler(val_indices),
    collate_fn=CapsCollate(pad_idx=pad_idx,batch_first=True)
)

loader_test = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    sampler = sampler.SubsetRandomSampler(test_indices),
    collate_fn=CapsCollate(pad_idx=pad_idx,batch_first=True)
)

In [11]:
# create models
# Model hyper-parameters.
input_dim = 224 * 224   # matches the 224x224 resize applied in `transform`
input_channels = 3
feature_dim = 128       # encoder latent_dim and decoder input_dim
word_emb_dim = 256

# Keyword configuration for the Perceiver encoder, kept in one place so the
# constructor call stays short.
perceiver_kwargs = dict(
    input_channels=input_channels,
    num_iterations=2,
    num_transformer_blocks=4,
    num_latents=32,
    latent_dim=feature_dim,
    cross_heads=4,
    cross_dim_head=16,
    latent_heads=4,
    latent_dim_head=16,
    attn_dropout=0.5,
    ff_dropout=0.5,
)
encoder = EncoderPerceiver(input_dim, **perceiver_kwargs)

decoder = CaptioningTransformer(
    dataset.vocab.stoi,
    input_dim=feature_dim,
    wordvec_dim=word_emb_dim,
    max_length=30,
)

# Alternative ResNet-based encoder/decoder pair (currently disabled):
#encoder_r = EncoderResNet(feature_dim)
#decoder_r = CaptioningTransformer(dataset.vocab.stoi, input_dim=feature_dim,
#                                  wordvec_dim = word_emb_dim, max_length=30)

In [12]:
# One Adam optimizer (default settings) per network: encoder first, then decoder.
encoder_optimizer, decoder_optimizer = (
    optim.Adam(network.parameters()) for network in (encoder, decoder)
)

In [13]:
# Bundle the models, their optimizers, the index-to-token vocabulary and the
# dataset's caption alias table into the solver, which runs on `device`.
solver = CaptioningSolver(
    encoder,
    decoder,
    encoder_optimizer,
    decoder_optimizer,
    dataset.vocab.itos,
    dataset.caption_alias,
    device=device,
)

In [14]:
#solver.train(loader_train, loader_val, num_epochs=1)

In [None]:
# Visual sanity check: display the first image of every training batch.
for batch in loader_train:
    images, captions = batch
    show_image(images[0])

applying transformation
applying transformation
