In [1]:
using Pkg
Pkg.activate(".")
using JSON
using WordTokenizers
using StatsBase
using Flux,CuArrays
using Flux:onehot
using Base.Iterators:partition

In [2]:
include("utils.jl")
BASE_PATH = "../data/"

"../data/"

In [3]:
punc = "!#%&()*+.,-/:;=?@[]^_`{|}~"
punctuation = [punc[i] for i in 1:length(punc)]

NUM_SENTENCES = 5
data = load_data(BASE_PATH,NUM_SENTENCES,punctuation)

5-element Array{Any,1}:
 ("<s> A very clean and well decorated empty bathroom </s>", "../data/train2014/COCO_train2014_000000318556.jpg")            
 ("<s> A panoramic view of a kitchen and all of its appliances </s>", "../data/train2014/COCO_train2014_000000116100.jpg")   
 ("<s> A blue and white bathroom with butterfly themed wall tiles </s>", "../data/train2014/COCO_train2014_000000318556.jpg")
 ("<s> A panoramic photo of a kitchen and dining room </s>", "../data/train2014/COCO_train2014_000000116100.jpg")            
 ("<s> A graffitied stop sign across the street from a red car  </s>", "../data/train2014/COCO_train2014_000000379340.jpg")  

In [4]:
captions = [d[1] for d in data]
tokens = cat([tokenize(sentence) for sentence in captions]...,dims=1)
vocab = unique(tokens)
# Sort according to frequencies
freqs = reverse(sort(collect(countmap(tokens)),by=x->x[2]))
# Find top-k tokens
K = 30
top_k_tokens = [freqs[i][1] for i in 1:K]
tokenized_captions = []
for i in 1:length(captions)
    sent_tokens = tokenize(captions[i])
    for j in 1:length(sent_tokens)
        sent_tokens[j] = !(sent_tokens[j] in top_k_tokens) ? "<UNK>" : sent_tokens[j]
    end
    push!(tokenized_captions,sent_tokens)
end
max_length_sentence = maximum([length(cap) for cap in tokenized_captions])
# Pad the sequences
for (i,cap) in enumerate(tokenized_captions)
    if length(cap) < max_length_sentence
        tokenized_captions[i] = [tokenized_captions[i]...,["<PAD>" for i in 1:(max_length_sentence - length(cap))]...]
    end
end
# Define the vocabulary
vocab = [top_k_tokens...,"<UNK>","<PAD>"]
# Define mappings
word2idx = Dict(word=>i for (i,word) in enumerate(vocab))
idx2word = Dict(value=>key for (key,value) in word2idx)
SEQ_LEN = max_length_sentence
# Now - tokenized_captions contains the tokens for each caption in the form of an array

13

In [5]:
tokenized_captions

5-element Array{Any,1}:
 ["<s>", "A", "<UNK>", "clean", "and", "well", "<UNK>", "<UNK>", "bathroom", "</s>", "<PAD>", "<PAD>", "<PAD>"]
 ["<s>", "A", "panoramic", "<UNK>", "of", "a", "kitchen", "and", "all", "of", "its", "appliances", "</s>"]     
 ["<s>", "A", "blue", "and", "<UNK>", "bathroom", "with", "<UNK>", "themed", "wall", "<UNK>", "</s>", "<PAD>"] 
 ["<s>", "A", "panoramic", "photo", "of", "a", "kitchen", "and", "dining", "room", "</s>", "<PAD>", "<PAD>"]   
 ["<s>", "A", "graffitied", "stop", "sign", "across", "the", "street", "from", "a", "red", "car", "</s>"]      

In [6]:
onehotword(word) = Float32.(onehot(word2idx[word],1:length(vocab)))

onehotword (generic function with 1 method)

In [7]:
BATCH_SIZE = 64
mb_idxs = partition(1:length(data),BATCH_SIZE)

Base.Iterators.PartitionIterator{UnitRange{Int64}}(1:5, 64)

In [8]:
using Metalhead
using JLD

image_names = [d[2] for d in data]

function extract_embedding_features(image_names)
    # extract features from the images and save them to a file
    vgg = VGG19() |> gpu
    Flux.testmode!(vgg)
    vgg = vgg.layers[1:end-3] |> gpu
    
    features = Dict()
    for im_name in image_names
        if im_name in keys(features)
            continue
        end
        
        img = Metalhead.preprocess(load(im_name)) |> gpu
        out = vgg(img)
        
        features[im_name] = out |> cpu
    end
    
    save("features.jld","features",features)
end

function load_embedding_features()
    load("features.jld")["features"]
end

load_embedding_features (generic function with 1 method)

In [9]:
# extract_embedding_features(image_names)
features = load_embedding_features()

Dict{Any,Any} with 3 entries:
  "../data/train2014/COCO_train2014_… => Float32[10.4405; 0.0; … ; 0.0; 1.86139]
  "../data/train2014/COCO_train2014_… => Float32[3.51554; 0.0; … ; 0.0; 0.0]
  "../data/train2014/COCO_train2014_… => Float32[0.0; 0.0; … ; 0.0; 0.0]

In [10]:
function get_mb(idx,features)
    cap = tokenized_captions[idx]
    img_names = image_names[idx]
    
    mb_captions = []
    mb_features = []
    mb_targets = []
    
    for i in 1:length(img_names)
         push!(mb_features,features[img_names[i]])
    end
    
    mb_features = hcat(mb_features...)
    # Convert to - Array[SEQ_LEN] with each element - [V,BATCH_SIZE]
    for i in 1:SEQ_LEN
        # Extract and form a batch of each word in sequence
        words = hcat([onehotword(sentence[i]) for sentence in cap]...)
        
        if i < SEQ_LEN
            push!(mb_targets,hcat([onehotword(sentence[i + 1]) for sentence in cap]...))
        else
            push!(mb_targets,hcat([onehotword("<PAD>") for sentence in cap]...))
        end
        
        push!(mb_captions,words)
    end
    
    (mb_captions,mb_features,mb_targets)
end

get_mb (generic function with 1 method)

In [11]:
function nullify_grad!(p)
  if typeof(p) <: TrackedArray
    p.grad .= 0.0f0
  end
  return p
end

function zero_grad!(model)
  model = mapleaves(nullify_grad!, model)
end

zero_grad! (generic function with 1 method)

In [12]:
EMBEDDING_DIM = 256
HIDDEN_DIM = 512
cnn_encoder = Chain(Dense(4096,EMBEDDING_DIM),x->relu.(x))
embedding = Chain(Dense(length(vocab),EMBEDDING_DIM))
rnn_decoder = Chain(LSTM(EMBEDDING_DIM,HIDDEN_DIM))
decoder = Chain(Dense(HIDDEN_DIM,length(vocab)))

function reset()
    Flux.reset!(rnn_decoder.layers[1])
end

reset (generic function with 1 method)

In [13]:
function zero_grad_models()
    zero_grad!(cnn_encoder)
    zero_grad!(embedding)
    zero_grad!(rnn_decoder)
    zero_grad!(decoder)
end

zero_grad_models (generic function with 1 method)

In [14]:
using BSON:@save,@load

function save_models()
    reset()
    @save "cnn_encoder.bson" cnn_encoder
    @save "embedding.bson" embedding
    @save "rnn_decoder.bson" rnn_decoder
    @save "decoder.bson" decoder
end

function load_models()
    @load "cnn_encoder.bson" cnn_encoder
    @load "embedding.bson" embedding
    @load "rnn_decoder.bson" rnn_decoder
    @load "decoder.bson" decoder
end

load_models (generic function with 1 method)

In [15]:
function get_loss_val(mb_captions,mb_features,mb_targets)
    reset()
    lstm_inp = cnn_encoder(mb_features)
    word_embeddings = embedding.(mb_captions)
    lstm_out = rnn_decoder(lstm_inp)
    predictions = softmax.(decoder.(rnn_decoder.(word_embeddings)))
    sum(Flux.crossentropy.(predictions,mb_targets))
end

get_loss_val (generic function with 1 method)

In [140]:
model_params = params(params(cnn_encoder)...,params(embedding)...,params(rnn_decoder)...,params(decoder)...)
lr = 1e-4
opt = ADAM(model_params,lr)

LOG_FREQUENCY = 10
EPOCHS = 300
SAVE_FREQUENCY = 10
global_step = 1

for epoch in 1:EPOCHS
    for idx in mb_idxs
        mb_captions,mb_features,mb_targets = get_mb(idx,features)
        zero_grad_models()
        Flux.back!(get_loss_val(mb_captions,mb_features,mb_targets))
        opt()
        global_step += 1
        
        if global_step % LOG_FREQUENCY == 0
            println("---Global Step : $(global_step)")
            println("Loss : $(get_loss_val(mb_captions,mb_features,mb_targets))")
        end
        
        if global_step % SAVE_FREQUENCY == 0
            save_models()
            println("Saved Models!")
        end
    end
end

│   caller = top-level scope at In[140]:3
└ @ Core In[140]:3


---Global Step : 10
Loss : 1.396667f0 (tracked)
Saved Models!
---Global Step : 20
Loss : 1.1391783f0 (tracked)
Saved Models!
---Global Step : 30
Loss : 1.0661407f0 (tracked)
Saved Models!
---Global Step : 40
Loss : 1.0148447f0 (tracked)
Saved Models!
---Global Step : 50
Loss : 0.9833323f0 (tracked)
Saved Models!
---Global Step : 60
Loss : 0.9555775f0 (tracked)
Saved Models!
---Global Step : 70
Loss : 0.9302694f0 (tracked)
Saved Models!
---Global Step : 80
Loss : 0.9072114f0 (tracked)
Saved Models!
---Global Step : 90
Loss : 0.8858993f0 (tracked)
Saved Models!
---Global Step : 100
Loss : 0.8661209f0 (tracked)
Saved Models!
---Global Step : 110
Loss : 0.8477492f0 (tracked)
Saved Models!
---Global Step : 120
Loss : 0.83065426f0 (tracked)
Saved Models!
---Global Step : 130
Loss : 0.8147495f0 (tracked)
Saved Models!
---Global Step : 140
Loss : 0.7999724f0 (tracked)
Saved Models!
---Global Step : 150
Loss : 0.7862482f0 (tracked)
Saved Models!
---Global Step : 160
Loss : 0.7734897f0 (tracked)

In [70]:
vgg = VGG19() |> gpu
Flux.testmode!(vgg)
vgg = vgg.layers[1:end-3] |> gpu

Chain(Conv((3, 3), 3=>64, NNlib.relu), Conv((3, 3), 64=>64, NNlib.relu), getfield(Metalhead, Symbol("##42#48"))(), Conv((3, 3), 64=>128, NNlib.relu), Conv((3, 3), 128=>128, NNlib.relu), getfield(Metalhead, Symbol("##43#49"))(), Conv((3, 3), 128=>256, NNlib.relu), Conv((3, 3), 256=>256, NNlib.relu), Conv((3, 3), 256=>256, NNlib.relu), Conv((3, 3), 256=>256, NNlib.relu), getfield(Metalhead, Symbol("##44#50"))(), Conv((3, 3), 256=>512, NNlib.relu), Conv((3, 3), 512=>512, NNlib.relu), Conv((3, 3), 512=>512, NNlib.relu), Conv((3, 3), 512=>512, NNlib.relu), getfield(Metalhead, Symbol("##45#51"))(), Conv((3, 3), 512=>512, NNlib.relu), Conv((3, 3), 512=>512, NNlib.relu), Conv((3, 3), 512=>512, NNlib.relu), Conv((3, 3), 512=>512, NNlib.relu), getfield(Metalhead, Symbol("##46#52"))(), getfield(Metalhead, Symbol("##47#53"))(), Dense(25088, 4096, NNlib.relu), Dropout{Float32}(0.5f0, false), Dense(4096, 4096, NNlib.relu))

In [122]:
function sample(image_path)
    img = Metalhead.preprocess(load(image_path)) |> gpu
    features = vgg(img) |> cpu
    
    reset()
    prev_word = "<s>"
    lstm_inp = cnn_encoder(features)
    lstm_out = rnn_decoder(lstm_inp)
    output = ""
    
    for i in 1:15
        output = string(output," ",prev_word)
        if prev_word == "</s>"
            break
        end
        word_embeddings = embedding(onehotword(prev_word))
        predictions = softmax(decoder(rnn_decoder(word_embeddings)))
        next_word = idx2word[Flux.argmax(predictions)[1]]
        prev_word = next_word
    end
    
    output
end

sample (generic function with 1 method)

In [128]:
sample(image_names[1])

" <s> A blue and <UNK> bathroom with <UNK> themed wall <UNK> </s>"

In [137]:
reset()
@save "t.bson" rnn_decoder

In [136]:
m = Chain(LSTM(2,3))
@save "t.bson" m

In [17]:
mb_captions,mb_features,mb_targets = get_mb(collect(mb_idxs)[1],features)

(Any[Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[1.0 1.0 … 1.0 1.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 1.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 1.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 1.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 1.0 0.0 … 0.0 0.0], Float32[0.0 0.0 

In [19]:
gpu.(mb_captions)

13-element Array{CuArray{Float32,2},1}:
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [1.0 1.0 … 1.0 1.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 1.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 1.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 1.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 1.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 

In [20]:
gpu(mb_features)

4096×5 CuArray{Float32,2}:
  3.51554  10.4405    3.51554  10.4405    0.0    
  0.0       0.0       0.0       0.0       0.0    
  0.0       0.0       0.0       0.0       0.0    
  0.0       0.0       0.0       0.0       0.0    
 10.7507    0.0      10.7507    0.0       0.0    
  0.0       6.03384   0.0       6.03384  10.7011 
  0.0       0.0       0.0       0.0       0.0    
  0.0       0.0       0.0       0.0       5.73202
  0.0       0.0       0.0       0.0       0.0    
  0.0       0.0       0.0       0.0       0.0    
  2.32066   0.0       2.32066   0.0       0.0    
  0.0       0.0       0.0       0.0       0.0    
 10.5443    0.0      10.5443    0.0       0.0    
  ⋮                                              
  0.0       0.0       0.0       0.0       0.0    
  0.0       0.0       0.0       0.0       0.0    
  0.0       0.0       0.0       0.0       0.0    
  0.0       0.0       0.0       0.0       0.0    
  5.90421   0.0       5.90421   0.0       0.0    
  0.0       0.0       0

In [21]:
gpu.(mb_targets)

13-element Array{CuArray{Float32,2},1}:
 [1.0 1.0 … 1.0 1.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 1.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 1.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 1.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 1.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 1.0 0.0 … 1.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 1.0 … 0.0 1.0; … ; 0.0 0.0 … 0.0 