In [1]:
import sys
import json
import os; 
import numpy as np
from data import ImageSegmentationDataset
from torch.utils.data import DataLoader
from model import ImgSegRefExpModel


def load_vocab_dict_from_file(dict_file):
    """Build a word -> index lookup from a vocabulary file.

    The file is read line by line; each line (with surrounding whitespace
    stripped) is one word and its zero-based line number is its index.
    If a word occurs on several lines, the last line number wins.

    Args:
        dict_file: Path of the vocabulary text file, one word per line.

    Returns:
        dict mapping each word to its line index.
    """
    vocab_dict = {}
    with open(dict_file) as vocab_fh:
        for index, line in enumerate(vocab_fh):
            vocab_dict[line.strip()] = index
    return vocab_dict

#Util functions

# All masks are [num, height, width] binary arrays
def compute_mask_IU(masks, target):
    """Count the intersection and union of two binary masks.

    Args:
        masks: Binary mask array.
        target: Binary mask array whose last two dimensions match ``masks``.

    Returns:
        Tuple ``(I, U)`` where ``I`` is the number of positions set in both
        inputs and ``U`` the number of positions set in at least one of them.

    Raises:
        AssertionError: If the last two dimensions of the inputs differ.
    """
    assert(target.shape[-2:] == masks.shape[-2:])
    # Reduce with the result's own ``sum`` method rather than ``torch.sum``:
    # ``torch.sum`` rejects the NumPy arrays produced by ``np.logical_*``.
    I = np.logical_and(masks, target).sum()
    U = np.logical_or(masks, target).sum()
    return I, U

def resize_and_pad(im, input_h, input_w):
    """Resize ``im`` to fit inside ``input_h x input_w`` and zero-pad the rest.

    The image is scaled by the smaller of the two height/width ratios, so the
    aspect ratio is kept and the result fits in the target size. The resized
    image is then placed in the centre of a zero-filled canvas.

    Args:
        im: Image array whose first two dimensions are height and width; a
            third (channel) dimension is carried over when present.
        input_h: Target height.
        input_w: Target width.

    Returns:
        Array of shape ``(input_h, input_w)`` or ``(input_h, input_w, C)``
        with the dtype of the resized image.
    """
    # Imported here because the notebook never imports skimage at the top;
    # without this the call below raises NameError.
    import skimage.transform

    im_h, im_w = im.shape[:2]
    scale = min(input_h / im_h, input_w / im_w)
    resized_h = int(np.round(im_h * scale))
    resized_w = int(np.round(im_w * scale))
    # Half of the leftover space goes before the image (floor division).
    pad_h = int((input_h - resized_h) // 2)
    pad_w = int((input_w - resized_w) // 2)

    resized_im = skimage.transform.resize(im, [resized_h, resized_w])
    if im.ndim > 2:
        new_im = np.zeros((input_h, input_w, im.shape[2]), dtype=resized_im.dtype)
    else:
        new_im = np.zeros((input_h, input_w), dtype=resized_im.dtype)
    new_im[pad_h:pad_h+resized_h, pad_w:pad_w+resized_w, ...] = resized_im

    return new_im

def resize_and_crop(im, input_h, input_w):
    """Resize ``im`` to cover ``input_h x input_w`` and centre-crop the excess.

    The image is scaled by the larger of the two height/width ratios, so the
    aspect ratio is kept and the result covers the target size. The central
    ``input_h x input_w`` window of the resized image is returned.

    Args:
        im: Image array whose first two dimensions are height and width; a
            third (channel) dimension is carried over when present.
        input_h: Target height.
        input_w: Target width.

    Returns:
        Array of shape ``(input_h, input_w)`` or ``(input_h, input_w, C)``
        with the dtype of the resized image.
    """
    # Imported here because the notebook never imports skimage at the top;
    # without this the call below raises NameError.
    import skimage.transform

    im_h, im_w = im.shape[:2]
    scale = max(input_h / im_h, input_w / im_w)
    resized_h = int(np.round(im_h * scale))
    resized_w = int(np.round(im_w * scale))
    # Half of the overhang is dropped before the window (floor division).
    crop_h = int((resized_h - input_h) // 2)
    crop_w = int((resized_w - input_w) // 2)

    resized_im = skimage.transform.resize(im, [resized_h, resized_w])
    if im.ndim > 2:
        new_im = np.zeros((input_h, input_w, im.shape[2]), dtype=resized_im.dtype)
    else:
        new_im = np.zeros((input_h, input_w), dtype=resized_im.dtype)
    new_im[...] = resized_im[crop_h:crop_h+input_h, crop_w:crop_w+input_w, ...]

    return new_im

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Args:
        optimizer: Object exposing an iterable ``param_groups`` whose items
            are mappings with an ``'lr'`` entry.

    Returns:
        The ``'lr'`` value of the first parameter group, or ``None`` when
        there are no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

In [2]:

################################################################################
# Parameters
################################################################################

# Root of the ReferIt experiment directory. The original location is kept as
# the default; set the REFERIT_ROOT environment variable to run the notebook
# on a machine with a different layout.
root = os.environ.get('REFERIT_ROOT', '/home/nishaddawkhar/text_objseg/exp-referit/')

# os.path.join tolerates a root with or without a trailing separator.
# Directory paths keep their trailing slash so that `image_dir + name` style
# concatenation keeps working.
image_dir = os.path.join(root, 'referit-dataset/images/')
mask_dir = os.path.join(root, 'referit-dataset/mask/')
query_file = os.path.join(root, 'data/referit_query_train.json')
bbox_file = os.path.join(root, 'data/referit_bbox.json')
imcrop_file = os.path.join(root, 'data/referit_imcrop.json')
imsize_file = os.path.join(root, 'data/referit_imsize.json')
vocab_file = os.path.join(root, 'data/vocabulary_referit.txt')

query_file_val = os.path.join(root, 'data/referit_query_val.json')

# Single source of truth for the batch size used by both loaders.
BATCH_SIZE = 10

# Training data is shuffled each pass; validation order is kept fixed.
train_dataset = ImageSegmentationDataset(query_file, image_dir, mask_dir)
val_dataset = ImageSegmentationDataset(query_file_val, image_dir, mask_dir)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [3]:
# #saving image sizes
# query_dict = json.load(open(query_file))

# image_names = set()
# for key, value in query_dict.items():
#     image_names.add(key.split('_')[0])

# import skimage.io
# image_sizes = {}
# for name in image_names:
#     im = skimage.io.imread(image_dir + name + '.jpg')
#     image_sizes[name] = im.shape

In [4]:
import torch
import torch.nn as nn
import time

# Use the GPU when one is available, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# trained model
pretrained_model_file = "/home/nishaddawkhar/text_objseg_pretrained_torch_converted_with_lstm.dms"
vocab_file = './vocabulary_referit.txt'

# Load vocabulary
vocab_dict = load_vocab_dict_from_file(vocab_file)

# Load model and weights
model = ImgSegRefExpModel(mlp_hidden=500, vocab_size=8803, emb_size=1000, lstm_hidden_size=1000)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU can still be loaded when only the CPU is available.
pre_trained = torch.load(pretrained_model_file, map_location=device)
model.load_state_dict(pre_trained)
model.to(device)

ImgSegRefExpModel(
  (text_features): LanguageModule(
    (embedding): Embedding(8803, 1000)
    (lstm): LSTM(1000, 1000, batch_first=True)
  )
  (img_features): ImageModule(
    (feature_extractor): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inp

In [None]:
from torch.optim import SGD

# Model Params
# NOTE(review): T and N are not used in this cell; presumably the text
# sequence length and the batch size — confirm against the data pipeline.
T = 20
N = 10
# Feature-map size is the input size divided by 32.
input_H = 512; featmap_H = (input_H // 32)
input_W = 512; featmap_W = (input_W // 32)
num_vocab = 8803
embed_dim = 1000
lstm_dim = 1000
mlp_hidden_dims = 500

#Training Params
pos_loss_mult = 1.
neg_loss_mult = 1.

start_lr = 0.01
lr_decay_step = 10000
lr_decay_rate = 0.1
weight_decay = 0.0005
momentum = 0.9
max_iter = 30000

fix_convnet = False
vgg_dropout = False
mlp_dropout = False
vgg_lr_mult = 1.

cls_loss_avg = 0
avg_accuracy_all, avg_accuracy_pos, avg_accuracy_neg = 0, 0, 0
# Smoothing factor of the exponential moving average of the loss.
decay = 0.99

# Combine weight decay regularisation with optimiser
optimiser = torch.optim.SGD(model.parameters(),lr=start_lr, momentum=momentum, weight_decay=weight_decay)
# Keep a handle on the scheduler: the learning rate only decays when
# scheduler.step() is called (after optimiser.step()) during training.
scheduler = torch.optim.lr_scheduler.StepLR(optimiser, step_size=lr_decay_step, gamma=lr_decay_rate)
# torch.Tensor(a, b) allocates an *uninitialised* a x b tensor, so build the
# weight from its value instead. BCEWithLogitsLoss only weights the positive
# term, hence the positive/negative ratio; the common factor neg_loss_mult
# (1.0 here) is not applied.
pos_weight = torch.tensor([pos_loss_mult / neg_loss_mult], device=device)
loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Each outer iteration runs one (currently disabled) pass over the training
# set followed by one full pass over the validation set.
for n_iter in range(max_iter):
    start = time.time()
    model.train()
    print("Training\nNumber of batches: {}\tBatch Size: {}\tDataset size: {}".format(len(train_loader),train_loader.batch_size,len(train_loader.dataset)))
    cls_loss_avg = 0.0
    end_time = 0
#     for batchId,(image, text, gt_mask, original_image_name) in enumerate(train_loader):
#         optimiser.zero_grad()
#         batch_start = time.time()
#         text = text.long()
#         output_mask = model((image.to(device), text.to(device)))
#         output_mask = output_mask.squeeze(1)
#         cls_loss_val = loss(output_mask,gt_mask.float().to(device))
#         cls_loss_val.backward()
#         cls_loss_avg = decay*cls_loss_avg + (1-decay)*cls_loss_val.item()
#         optimiser.step()
# #         import pdb; pdb.set_trace();
#         print("Batch Time with data loading = {}s, Batch #{}: Loss = {}\tAvg Loss: {}\tTime: {}s".format(time.time()-end_time,batchId,cls_loss_val.item(),cls_loss_avg,time.time()-batch_start))
#         end_time = time.time()
#     print('\titer = {},  Batch Loss (avg) = {}, lr = {}, time = {}s'.format(n_iter, cls_loss_avg, get_lr(optimiser),time.time()-start))

    print("Validating\nNumber of batches: {}\tBatch Size: {}\tDataset size: {}".format(len(val_loader),val_loader.batch_size,len(val_loader.dataset)))
    # The average below is an exponential moving average that starts at 0, so
    # the first reported values understate the loss.
    cls_loss_avg = 0.0
    model.eval()
    # No gradients are computed during validation, so there is nothing for
    # the optimiser to reset here.
    with torch.no_grad():
        # Reference point for the cumulative validation time printed below.
        batch_start = time.time()
        for batchId,(image, text, gt_mask, original_image_name) in enumerate(val_loader):
            text = text.long()
            output_mask = model((image.to(device), text.to(device)))
            output_mask = output_mask.squeeze(1)
            cls_loss_val = loss(output_mask,gt_mask.float().to(device))
            cls_loss_avg = decay*cls_loss_avg + (1-decay)*cls_loss_val.item()
            if batchId % 100 == 0:
                print("Batch #{}: Loss = {}\tAvg Loss: {}\tTime: {}s".format(batchId,cls_loss_val.item(),cls_loss_avg,time.time()-batch_start))
        # Report the time elapsed since this iteration started, not the
        # absolute wall-clock timestamp.
        print('\titer = {},  Batch Loss (avg) = {}, lr = {}, time = {}s'.format(n_iter, cls_loss_avg, get_lr(optimiser),time.time()-start))
print('Optimization done.')


Training
Number of batches: 5414	Batch Size: 10	Dataset size: 54134
Validating
Number of batches: 585	Batch Size: 10	Dataset size: 5842
torch.Size([10, 1000, 16, 16])
16 16
Batch #0: Loss = 0.0009770632022991776	Avg Loss: 9.770632022991785e-06	Time: 0.9425015449523926s
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
torch.Size([10, 1000, 16, 16])
16 16
