In [1]:
using Pkg
Pkg.activate(".")
using JSON
using WordTokenizers
using StatsBase
using Flux,CuArrays
using Flux:onehot
using Base.Iterators:partition
using Metalhead
using JLD
using BSON:@save,@load

include("utils.jl")
BASE_PATH = "../data/"

#--------HYPERPARAMETERS----------#
NUM_SENTENCES = 5
# Find top-k tokens
K = 30
BATCH_SIZE = 64
EMBEDDING_DIM = 256
HIDDEN_DIM = 512
LOG_FREQUENCY = 10
EPOCHS = 300
SAVE_FREQUENCY = 10
global global_step = 1
device = gpu

gpu (generic function with 1 method)

In [2]:
punc = "!#%&()*+.,-/:;=?@[]^_`{|}~"
punctuation = [punc[i] for i in 1:length(punc)]
data = load_data(BASE_PATH,NUM_SENTENCES,punctuation)

captions = [d[1] for d in data]
tokens = cat([tokenize(sentence) for sentence in captions]...,dims=1)
vocab = unique(tokens)
# Sort according to frequencies
freqs = reverse(sort(collect(countmap(tokens)),by=x->x[2]))
top_k_tokens = [freqs[i][1] for i in 1:K]
tokenized_captions = []
for i in 1:length(captions)
    sent_tokens = tokenize(captions[i])
    for j in 1:length(sent_tokens)
        sent_tokens[j] = !(sent_tokens[j] in top_k_tokens) ? "<UNK>" : sent_tokens[j]
    end
    push!(tokenized_captions,sent_tokens)
end
max_length_sentence = maximum([length(cap) for cap in tokenized_captions])
# Pad the sequences
for (i,cap) in enumerate(tokenized_captions)
    if length(cap) < max_length_sentence
        tokenized_captions[i] = [tokenized_captions[i]...,["<PAD>" for i in 1:(max_length_sentence - length(cap))]...]
    end
end
# Define the vocabulary
vocab = [top_k_tokens...,"<UNK>","<PAD>"]
# Define mappings
word2idx = Dict(word=>i for (i,word) in enumerate(vocab))
idx2word = Dict(value=>key for (key,value) in word2idx)
SEQ_LEN = max_length_sentence
# Now - tokenized_captions contains the tokens for each caption in the form of an array

onehotword(word) = Float32.(onehot(word2idx[word],1:length(vocab)))
mb_idxs = partition(1:length(data),BATCH_SIZE)
image_names = [d[2] for d in data]

5-element Array{String,1}:
 "../data/train2014/COCO_train2014_000000318556.jpg"
 "../data/train2014/COCO_train2014_000000116100.jpg"
 "../data/train2014/COCO_train2014_000000318556.jpg"
 "../data/train2014/COCO_train2014_000000116100.jpg"
 "../data/train2014/COCO_train2014_000000379340.jpg"

In [3]:
function extract_embedding_features(image_names)
    # extract features from the images and save them to a file
    vgg = VGG19() |> gpu
    Flux.testmode!(vgg)
    vgg = vgg.layers[1:end-3] |> gpu
    
    features = Dict()
    for im_name in image_names
        if im_name in keys(features)
            continue
        end
        
        img = Metalhead.preprocess(load(im_name)) |> gpu
        out = vgg(img)
        
        features[im_name] = out |> cpu
    end
    
    save("features.jld","features",features)
end

function load_embedding_features()
    load("features.jld")["features"]
end

# extract_embedding_features(image_names)
features = load_embedding_features()

function get_mb(idx,features)
    cap = tokenized_captions[idx]
    img_names = image_names[idx]
    
    mb_captions = []
    mb_features = []
    mb_targets = []
    
    for i in 1:length(img_names)
         push!(mb_features,features[img_names[i]])
    end
    
    mb_features = hcat(mb_features...)
    # Convert to - Array[SEQ_LEN] with each element - [V,BATCH_SIZE]
    for i in 1:SEQ_LEN
        # Extract and form a batch of each word in sequence
        words = hcat([onehotword(sentence[i]) for sentence in cap]...)
        
        if i < SEQ_LEN
            push!(mb_targets,hcat([onehotword(sentence[i + 1]) for sentence in cap]...))
        else
            push!(mb_targets,hcat([onehotword("<PAD>") for sentence in cap]...))
        end
        
        push!(mb_captions,words)
    end
    
    (mb_captions,mb_features,mb_targets)
end

function nullify_grad!(p)
  if typeof(p) <: TrackedArray
    p.grad .= 0.0f0
  end
  return p
end

function zero_grad!(model)
  model = mapleaves(nullify_grad!, model)
end

zero_grad! (generic function with 1 method)

In [11]:
global cnn_encoder = Chain(Dense(4096,EMBEDDING_DIM),x->relu.(x))
global embedding = Chain(Dense(length(vocab),EMBEDDING_DIM))
global rnn_decoder = Chain(LSTM(EMBEDDING_DIM,HIDDEN_DIM))
global decoder = Chain(Dense(HIDDEN_DIM,length(vocab)))

function zero_grad_models()
    global cnn_encoder,embedding,rnn_decoder,decoder
    
    zero_grad!(cnn_encoder)
    zero_grad!(embedding)
    zero_grad!(rnn_decoder)
    zero_grad!(decoder)
end

function to_device()
    global cnn_encoder,embedding,rnn_decoder,decoder
    
	cnn_encoder = cnn_encoder |> device
	embedding = embedding |> device
	rnn_decoder = rnn_decoder |> device
	decoder = decoder |> device
end

function to_cpu()
    global cnn_encoder,embedding,rnn_decoder,decoder
    
	cnn_encoder = cnn_encoder |> cpu
	embedding = embedding |> cpu
	rnn_decoder = rnn_decoder |> cpu
	decoder = decoder |> cpu
end

function reset_rnn()
    global rnn_decoder
    
    Flux.reset!(rnn_decoder.layers[1])
end

function save_models()
    global cnn_encoder,embedding,rnn_decoder,decoder
    
	to_cpu()

    reset_rnn()
    @save "cnn_encoder.bson" cnn_encoder
    @save "embedding.bson" embedding
    @save "rnn_decoder.bson" rnn_decoder
    @save "decoder.bson" decoder

	to_device()
end

function load_models()
    global cnn_encoder,embedding,rnn_decoder,decoder
    
    @load "cnn_encoder.bson" cnn_encoder
    @load "embedding.bson" embedding
    @load "rnn_decoder.bson" rnn_decoder
    @load "decoder.bson" decoder
    to_device()
end

load_models (generic function with 1 method)

In [12]:
# Move models to device
to_device()
println("Move models to respective device...")

function get_loss_val(mb_captions,mb_features,mb_targets)
    global cnn_encoder,embedding,rnn_decoder,decoder
    
    reset_rnn()
    lstm_inp = cnn_encoder(mb_features)
    word_embeddings = embedding.(mb_captions)
    lstm_out = rnn_decoder(lstm_inp)
    predictions = softmax.(decoder.(rnn_decoder.(word_embeddings)))
    sum(Flux.crossentropy.(predictions,mb_targets))
end

model_params = params(params(cnn_encoder)...,params(embedding)...,params(rnn_decoder)...,params(decoder)...)
lr = 1e-4
opt = ADAM(model_params,lr)

for epoch in 1:EPOCHS
    for idx in mb_idxs
    	global global_step

        mb_captions,mb_features,mb_targets = get_mb(idx,features)

        # Move data to device
        mb_captions = device.(mb_captions)
        mb_features = device(mb_features)
        mb_targets = device.(mb_targets)

        zero_grad_models()
        Flux.back!(get_loss_val(mb_captions,mb_features,mb_targets))
        opt()
        global_step += 1
        
        if global_step % LOG_FREQUENCY == 0
            println("---Global Step : $(global_step)")
            println("Loss : $(get_loss_val(mb_captions,mb_features,mb_targets))")
        end
        
        if global_step % SAVE_FREQUENCY == 0
            save_models()
            println("Saved Models!")
        end
    end
end

Move models to respective device...


│   caller = top-level scope at In[12]:18
└ @ Core In[12]:18
│   caller = ip:0x0
└ @ Core :-1


---Global Step : 10
Loss : 39.364044f0 (tracked)
Saved Models!
---Global Step : 20
Loss : 39.364044f0 (tracked)
Saved Models!
---Global Step : 30
Loss : 39.364044f0 (tracked)


InterruptException: InterruptException:

In [6]:
idx = collect(mb_idxs)[1]

5-element Array{Int64,1}:
 1
 2
 3
 4
 5

In [7]:
mb_captions,mb_features,mb_targets = get_mb(idx,features)

# Move data to device
mb_captions = device.(mb_captions)
mb_features = device(mb_features)
mb_targets = device.(mb_targets)

13-element Array{CuArray{Float32,2},1}:
 [1.0 1.0 … 1.0 1.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 1.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 1.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 1.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 1.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 1.0 0.0 … 0.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 1.0 0.0 … 1.0 0.0]
 [0.0 0.0 … 0.0 0.0; 0.0 1.0 … 0.0 1.0; … ; 0.0 0.0 … 0.0 

In [8]:
get_loss_val(mb_captions,mb_features,mb_targets,cnn_encoder,embedding,rnn_decoder,decoder)

ArgumentError: ArgumentError: cannot take the CPU address of a CuArray{Float32,2}

In [21]:
to_device()

Chain(Dense(512, 32))

In [22]:
cnn_encoder.layers[1].W

Tracked 256×4096 CuArray{Float32,2}:
 -0.018943    -0.026169     0.0272873   …  -0.0175114    0.0318894 
 -0.0369808    0.0291363    0.0308406       0.0356787   -0.0365992 
 -0.0153231   -0.0110527    0.00839798      0.0101099    0.00530886
 -0.00649967  -0.0217655   -0.0305017      -0.032863     0.0365551 
 -0.0243772   -0.00878064  -0.0234215      -0.0218636    0.0352688 
  0.0255528    0.035949     0.0209666   …   0.0103029    0.00866148
  0.0166642    0.032007     0.0267857      -0.014934     0.0337227 
  0.00662626   0.0167969    0.0115766      -0.0198655    0.0192146 
  0.0192281    0.0308421   -0.0168888       0.0218251   -0.0300144 
 -0.0163103   -0.0256428   -0.00554477      0.0347868    0.0322564 
  0.0349331   -0.00920759  -0.0338403   …   0.00486652   0.0173492 
  0.0289183   -0.0302206    0.0179387       0.0284328   -0.0140727 
  0.0148864    0.0166573    0.0351221      -0.0223551   -0.0184862 
  ⋮                                     ⋱                ⋮         
 -0.0368572

In [15]:
cnn_encoder = cnn_encoder |> device

Chain(Dense(4096, 256), getfield(Main, Symbol("##31#32"))())