In [None]:
require 'xlua'
require 'optim'
require 'nn'
model_utils = require 'model_utils'
dofile './provider.lua'
c = require 'trepl.colorize'
require 'image'

-- Global experiment configuration, built in one table constructor.
-- The values mirror what would otherwise come from command-line options.
cmd_params = {
   ----- from the opt settings ------
   save = 'logs/trial',          -- output directory (logs, report, snapshots)
   batchSize = 128,              -- mini-batch size used by train()
   learningRate = 1,             -- initial SGD learning rate
   learningRateDecay = 1e-7,
   weightDecay = 0.0005,
   momentum = 0.9,
   epoch_step = 25,              -- train() halves the learning rate on multiples of this
   -- model definition files, resolved as 'models/<name>.lua'
   model_local = 'atten_1_softmax_conv/vgg_conv',
   model_global = 'atten_1_softmax_conv/vgg_full',
   model_atten = 'atten_1_softmax_conv/atten',
   model_match = 'atten_1_softmax_conv/match_singleimagepred',
   ----------------------------------
   max_epoch = 300,              -- number of train()/test() rounds in the main loop
   backend = 'nn',               -- 'cudnn' triggers cudnn.convert on the loaded models
   type = 'cuda',                -- tensor backend chosen by cast(): 'cuda' | 'float' | 'cl'
   ----------------------------------
   gpumode = 1,                  -- when 1, cast() calls cutorch.setDevice
   gpu_setDevice = 1,            -- device index handed to cutorch.setDevice
}

In [None]:
--- Convert a tensor/module/criterion to the backend named by cmd_params.type.
-- @param t  any object exposing :cuda(), :float() or :cl()
-- @return   the result of the matching conversion method
-- Raises an error when cmd_params.type is not 'cuda', 'float' or 'cl'.
function cast(t)
   local backend_kind = cmd_params.type

   if backend_kind == 'float' then
      return t:float()
   end

   if backend_kind == 'cl' then
      require 'clnn'
      return t:cl()
   end

   if backend_kind ~= 'cuda' then
      error('Unknown type '..cmd_params.type)
   end

   -- CUDA path: load cunn, then optionally pin the configured device.
   require 'cunn'
   -- note: gpumode is written as a global here, as in the original code
   gpumode = cmd_params.gpumode
   if gpumode == 1 then
      cutorch.setDevice(cmd_params.gpu_setDevice)
   end
   return t:cuda()
end

In [None]:
-- Seed Torch's random number generator with a fixed value.
local rng_seed = 1234567890
torch.manualSeed(rng_seed)

-- NOTE(review): train_or_val, im_dir and gt_dir are not assigned in the
-- cmd_params table visible in this file, so unless they are set elsewhere
-- these three globals end up nil -- confirm whether they are still needed.
train_or_val, im_path, gt_path =
   cmd_params.train_or_val, cmd_params.im_dir, cmd_params.gt_dir

In [None]:
----Data Augmentation
--- Horizontally flip a random subset of the samples of a batch, in place.
-- A random permutation of 1..N is thresholded at N/2, so floor(N/2) randomly
-- chosen positions are selected for flipping.
-- @param input  batch whose first dimension indexes the samples
-- @return       the same `input` object (modified in place)
function data_aug(input)
    local batch_len = input:size(1)
    local should_flip = torch.randperm(batch_len):le(batch_len/2)

    for idx = 1, batch_len do
        if should_flip[idx] == 1 then
            -- source and destination are the same slice: in-place flip
            image.hflip(input[idx], input[idx])
        end
    end

    return input
end

In [None]:
--[[
require 'loadcaffe'
model = loadcaffe.load('VGG_ILSVRC_16_layers_deploy.prototxt','VGG_ILSVRC_16_layers.caffemodel')
print(model)
]]--

In [None]:
----Initiation
-- Loads the datasets, builds the four sub-networks, flattens their parameters
-- into a single tensor pair, and sets up the criterion, logger and optimizer
-- state. Everything defined here is a global consumed by train() / test().

--1. Data loading
print(c.blue '==>' ..' loading data')
-- dataset used as network input
provider = torch.load 'provider.t7'
provider.trainData.data = provider.trainData.data:float()
provider.testData.data = provider.testData.data:float()

-- second dataset loaded from 'unnorm_provider.t7'
unnorm_provider = torch.load 'unnorm_provider.t7'
unnorm_provider.trainData.data = unnorm_provider.trainData.data:float()
unnorm_provider.testData.data = unnorm_provider.testData.data:float()

--2. Model creation

------model A
-- model_local = [type-conversion layer] -> [network from models/<model_local>.lua]
model_local = nn.Sequential()
-- nn.Copy converts the FloatTensor batches to whatever tensor type cast() produces
model_local:add(cast(nn.Copy('torch.FloatTensor', torch.type(cast(torch.Tensor())))))
model_local:add(cast(dofile('models/'..cmd_params.model_local..'.lua')))
-- the Copy layer sits on the raw input: skip computing a gradient w.r.t. the images
model_local:get(1).updateGradInput = function(input) return end
if cmd_params.backend == 'cudnn' then
   require 'cudnn'
   cudnn.convert(model_local:get(2), cudnn)
end

model_global = nn.Sequential()
model_global:add(cast(dofile('models/'..cmd_params.model_global..'.lua')))
if cmd_params.backend == 'cudnn' then
    cudnn.convert(model_global:get(1), cudnn)
end

model_atten = nn.Sequential()
model_atten:add(cast(dofile('models/'..cmd_params.model_atten..'.lua')))
if cmd_params.backend == 'cudnn' then
    cudnn.convert(model_atten:get(1),cudnn)
end

model_match = nn.Sequential()
model_match:add(cast(dofile('models/' ..cmd_params.model_match..'.lua')))
if cmd_params.backend == 'cudnn' then
    -- pass the cudnn module itself (not the string 'cudnn'), like the other conversions
    cudnn.convert(model_match:get(1), cudnn)
end

------model B
--[[
model_local2 = model_local:clone('weight','bias','gradWeight','gradBias')
joint_local = nn.Sequential():add(model_local):add(model_local2)

model_global2 = model_global:clone('weight','bias','gradWeight','gradBias')
joint_global = nn.Sequential():add(model_global):add(model_global2)

model_atten2 = model_atten:clone('weight','bias','gradWeight','gradBias')
joint_atten = nn.Sequential():add(model_atten):add(model_atten2)
]]--

-- all sub-networks, in forward order
model_all = {}
table.insert(model_all, model_local)
table.insert(model_all, model_global)
table.insert(model_all, model_atten)
table.insert(model_all, model_match)

-- Flatten the parameters / gradients of every sub-network so one optimizer
-- call updates them all. train() reads the gradient tensor under the name
-- `gradParameters`; `gradParams` is kept as an alias of the same tensor for
-- any code that used the old name.
parameters, gradParameters = model_utils.combine_all_parameters(model_all)
gradParams = gradParameters
print(parameters:size())

-------------------------------------------------------------------------------------------------

--3. Criterion
print(c.blue'==>' ..' setting criterion')
criterion = cast(nn.CrossEntropyCriterion())

--4. Testing and saving
-- confusion matrix over 10 classes, shared by train() and test()
confusion = optim.ConfusionMatrix(10)
print('Will save at '..cmd_params.save)
paths.mkdir(cmd_params.save)
testLogger = optim.Logger(paths.concat(cmd_params.save, 'test.log'))
testLogger:setNames{'% mean class accuracy (train set)', '% mean class accuracy (test set)'}
testLogger.showPlot = false


--5. Learning settings
print(c.blue'==>' ..' configuring optimizer')
-- state table handed to optim.sgd; train() mutates learningRate over time
optimState = {
  learningRate = cmd_params.learningRate,
  weightDecay = cmd_params.weightDecay,
  momentum = cmd_params.momentum,
  learningRateDecay = cmd_params.learningRateDecay,
}

In [None]:
--Training
--- Run one training epoch over provider.trainData using optim.sgd.
-- Reads the globals set up by the initialisation cell (the four models,
-- `parameters`, the flattened gradient tensor, `criterion`, `confusion`,
-- `optimState`, `provider`, `unnorm_provider`) and updates the globals
-- `epoch` (incremented by one) and `train_acc` (training accuracy in percent).
function train()

    model_local:training()
    model_global:training()
    model_atten:training()
    model_match:training()

    epoch = epoch or 1

    -- halve the learning rate every cmd_params.epoch_step epochs
    if epoch % cmd_params.epoch_step == 0 then optimState.learningRate = optimState.learningRate/2 end
    print(c.blue '==>'.." online epoch # " .. epoch .. ' [batchSize = ' .. cmd_params.batchSize .. ']')

    -- The flattened gradient tensor is created under the name `gradParams`
    -- in this file; accept either spelling so this function does not index nil.
    local grads = gradParameters or gradParams

    local targets = cast(torch.FloatTensor(cmd_params.batchSize))
    local indices = torch.randperm(provider.trainData.data:size(1)):long():split(cmd_params.batchSize)
    -- `targets` has a fixed size, so every batch must be full: drop the last
    -- split only when it is actually shorter than batchSize
    if #indices > 0 and indices[#indices]:size(1) ~= cmd_params.batchSize then
        indices[#indices] = nil
    end
    print(#indices)

    local tic = torch.tic()
    ---------- the entire epoch run
    for t,v in ipairs(indices) do

        xlua.progress(t, #indices)
        ----gather images and labels for the current batch
        -- `inputs` / `unnorm_inputs` are deliberately left as globals, as in
        -- the original code, so they stay inspectable from other cells.
        --1. image
        inputs = provider.trainData.data:index(1,v)
        --2. data augmentation
        inputs = data_aug(inputs)
        --3. unnormed image
        unnorm_inputs = unnorm_provider.trainData.data:index(1,v)
        --4. labels
        targets:copy(provider.trainData.labels:index(1,v))

        -----------------------------------------------------------------------

        -- Closure for optim.sgd: returns the loss and the gradient at x.
        local feval = function(x)
            if x ~= parameters then parameters:copy(x) end
            grads:zero()

            ---------forward
            local lfeat = model_local:forward(inputs)
            local gfeat = model_global:forward(lfeat)
            -- attention model returns a table; element 2 feeds the classifier
            local att_con = model_atten:forward({lfeat,gfeat})
            local prediction = model_match:forward(att_con[2])

            local err = criterion:forward(prediction, targets)

            ---------backward
            local df_pred = criterion:backward(prediction, targets)
            -- backward must receive the same input that forward was given
            local df_context = model_match:backward(att_con[2], df_pred)

            -- no loss is attached to the first attention output: feed it a
            -- zero gradient of matching type and size (no hard-coded :cuda())
            local zero_grad = att_con[1].new(att_con[1]:size()):zero()
            local df_feat = model_atten:backward({lfeat,gfeat}, {zero_grad, df_context})

            local df_lfeat = model_global:backward(lfeat, df_feat[2])

            -- lfeat feeds both the global and the attention model; the two
            -- gradients are averaged here.
            -- NOTE(review): the chain rule would sum them -- confirm the /2 is intended.
            model_local:backward(inputs,(df_lfeat+df_feat[1])/2)

            confusion:batchAdd(prediction, targets)

            -- optim.sgd expects (loss, gradient)
            return err,grads
        end
        optim.sgd(feval, parameters, optimState)

    end
    -------------
    confusion:updateValids()
    print(('Train accuracy: '..c.cyan'%.2f'..' %%\t time: %.2f s'):format(confusion.totalValid * 100, torch.toc(tic)))

    train_acc = confusion.totalValid * 100

    confusion:zero()
    epoch = epoch + 1

end


In [None]:
--- Evaluate the models on provider.testData, log the accuracy, write an HTML
-- report into cmd_params.save, and snapshot the models when `epoch` is a
-- multiple of 50.
-- Reads the globals `train_acc` and `epoch` set by train(); resets `confusion`
-- before returning.
function test()
  -- disable flips, dropouts and batch normalization
  model_local:evaluate()
  model_global:evaluate()
  model_atten:evaluate()
  model_match:evaluate()

  print(c.blue '==>'.." testing")
  local bs = 125
  local n_test = provider.testData.data:size(1)
  for i=1,n_test,bs do
    -- the final batch may be shorter than bs when n_test is not a multiple of it
    local len = math.min(bs, n_test - i + 1)

    ---------forward
    local lfeat = model_local:forward(provider.testData.data:narrow(1,i,len))

    local gfeat = model_global:forward(lfeat)

    local att_con = model_atten:forward({lfeat,gfeat})

    local prediction = model_match:forward(att_con[2])

    confusion:batchAdd(prediction, provider.testData.labels:narrow(1,i,len))
  end

  confusion:updateValids()
  print('Test accuracy:', confusion.totalValid * 100)

  if testLogger then
    paths.mkdir(cmd_params.save)
    testLogger:add{train_acc, confusion.totalValid * 100}
    testLogger:style{'-','-'}
    testLogger:plot()

    -- Convert the logger plot to PNG and base64 so it can be inlined in the report.
    local base64im
    do
      os.execute(('convert -density 200 %s/test.log.eps %s/test.png'):format(cmd_params.save,cmd_params.save))
      os.execute(('openssl base64 -in %s/test.png -out %s/test.base64'):format(cmd_params.save,cmd_params.save))
      local f = io.open(cmd_params.save..'/test.base64')
      if f then
        base64im = f:read'*all'
        f:close()
      end
    end

    local report_path = cmd_params.save..'/report.html'
    local file, open_err = io.open(report_path,'w')
    if file then
      -- base64im is nil when the external tools failed, and epoch is nil if
      -- train() has not run yet; string.format would raise on either one
      file:write(([[
    <!DOCTYPE html>
    <html>
    <body>
    <title>%s - %s</title>
    <img src="data:image/png;base64,%s">
    <h4>optimState:</h4>
    <table>
    ]]):format(cmd_params.save,tostring(epoch),base64im or ''))
      for k,v in pairs(optimState) do
        if torch.type(v) == 'number' then
          file:write('<tr><td>'..k..'</td><td>'..v..'</td></tr>\n')
        end
      end
      file:write'</table><pre>\n'
      file:write(tostring(confusion)..'\n')
      file:write(tostring(model_local)..'\n')
      file:write(tostring(model_global)..'\n')
      file:write(tostring(model_atten)..'\n')
      file:write(tostring(model_match)..'\n')
      file:write'</pre></body></html>'
      file:close()
    else
      -- the report is best-effort: do not abort the run over it
      print('==> could not write '..report_path..': '..tostring(open_err))
    end
  end

  -- save model every 50 epochs
  if epoch and epoch % 50 == 0 then
    -- model_local:get(1) is the input type-conversion layer; only the network itself is saved
    local filename_loc = paths.concat(cmd_params.save, 'model_local.net')
    print('==> saving model to '..filename_loc)
    torch.save(filename_loc, model_local:get(2):clearState())

    local filename_glo = paths.concat(cmd_params.save, 'model_global.net')
    print('==> saving model to '.. filename_glo)
    torch.save(filename_glo, model_global:get(1):clearState())

    local filename_att = paths.concat(cmd_params.save, 'model_atten.net')
    print('==> saving model to '.. filename_att)
    torch.save(filename_att, model_atten:get(1):clearState())

    local filename_mat = paths.concat(cmd_params.save, 'model_match.net')
    print('==> saving model to '.. filename_mat)
    torch.save(filename_mat, model_match:get(1):clearState())

  end

  confusion:zero()
end

In [None]:
-- Main driver: alternate one training epoch and one evaluation pass,
-- cmd_params.max_epoch times.
for _ = 1, cmd_params.max_epoch do
  train(); test()
end