In [1]:
CUDA_VISIBLE_DEVICES=0

In [1]:
import torch
from options  import stage2_opts
from utils    import logger, recorders
from datasets import custom_data_loader
from models   import custom_model, solver_utils, model_utils

import train_stage2 as train_utils
import test_stage2 as test_utils

In [3]:
import argparse
parser = argparse.ArgumentParser() 

In [4]:
parser.add_argument('--dataset',     default='UPS_PRPS_Dataset')
parser.add_argument('--data_dir',    default='data/datasets/PRPS_Dataset')
parser.add_argument('--data_dir2',   default='data/datasets/PS_Sculpture_Dataset')
parser.add_argument('--concat_data', default=False, action='store_false')
parser.add_argument('--l_suffix',    default='_mtrl.txt')

#### Training Data and Preprocessing Arguments ####
parser.add_argument('--rescale',     default=True,  action='store_false')
parser.add_argument('--rand_sc',     default=True,  action='store_false')
parser.add_argument('--scale_h',     default=128,   type=int)
parser.add_argument('--scale_w',     default=128,   type=int)
parser.add_argument('--crop',        default=True,  action='store_false')
parser.add_argument('--crop_h',      default=128,   type=int)
parser.add_argument('--crop_w',      default=128,   type=int)
parser.add_argument('--test_h',      default=128,   type=int)
parser.add_argument('--test_w',      default=128,   type=int)
parser.add_argument('--test_resc',   default=True,  action='store_false')
parser.add_argument('--int_aug',     default=False,  action='store_false')
parser.add_argument('--noise_aug',   default=False,  action='store_false')
parser.add_argument('--noise',       default=0.05,  type=float)
parser.add_argument('--color_aug',   default=False,  action='store_false')
parser.add_argument('--color_ratio', default=3,     type=float)
parser.add_argument('--normalize',   default=False, action='store_true')

#### Device Arguments ####
parser.add_argument('--cuda',        default=True,  action='store_false')
parser.add_argument('--multi_gpu',   default=False, action='store_true')
parser.add_argument('--time_sync',   default=False, action='store_true')
parser.add_argument('--workers',     default=4,     type=int)
parser.add_argument('--seed',        default=0,     type=int)

#### Stage 1 Model Arguments ####
parser.add_argument('--dirs_cls',    default=36,    type=int)
parser.add_argument('--ints_cls',    default=20,    type=int)
parser.add_argument('--dir_int',     default=False, action='store_true')
parser.add_argument('--model',       default='LCNet')
parser.add_argument('--fuse_type',   default='max')
parser.add_argument('--in_img_num',  default=10,    type=int)
parser.add_argument('--s1_est_n',    default=False, action='store_true')
parser.add_argument('--s1_est_d',    default=True,  action='store_false')
parser.add_argument('--s1_est_i',    default=True,  action='store_false')
parser.add_argument('--in_light',    default=False, action='store_true')
parser.add_argument('--in_mask',     default=True,  action='store_false')
parser.add_argument('--use_BN',      default=False, action='store_true')
parser.add_argument('--resume',      default=None)
parser.add_argument('--retrain',     default='data/logdir/UPS_Synth_Dataset/CVPR2019/10_images/checkp_20.pth.tar')
parser.add_argument('--save_intv',   default=10,     type=int)

#### Stage 2 Model Arguments ####
parser.add_argument('--stage2',      default=True, action='store_true')
parser.add_argument('--model_s2',    default='NENet')
parser.add_argument('--retrain_s2',  default=None)
parser.add_argument('--s2_est_n',    default=True,  action='store_false')
parser.add_argument('--s2_est_i',    default=False, action='store_true')
parser.add_argument('--s2_est_d',    default=False, action='store_true')
parser.add_argument('--s2_in_light', default=True,  action='store_false')

#### Displaying Arguments ####
parser.add_argument('--train_disp',    default=20,  type=int)
parser.add_argument('--train_save',    default=200, type=int)
parser.add_argument('--val_intv',      default=1,   type=int)
parser.add_argument('--val_disp',      default=1,   type=int)
parser.add_argument('--val_save',      default=1,   type=int)
parser.add_argument('--max_train_iter',default=-1,  type=int)
parser.add_argument('--max_val_iter',  default=-1,  type=int)
parser.add_argument('--max_test_iter', default=-1,  type=int)
parser.add_argument('--train_save_n',  default=4,   type=int)
parser.add_argument('--test_save_n',   default=4,   type=int)

#### Log Arguments ####
parser.add_argument('--save_root',  default='data/logdir/')
parser.add_argument('--item',       default='CVPR2019')
parser.add_argument('--suffix',     default=None)
parser.add_argument('--debug',      default=False, action='store_true')
parser.add_argument('--make_dir',   default=True,  action='store_false')
parser.add_argument('--save_split', default=False, action='store_true')


_StoreTrueAction(option_strings=['--save_split'], dest='save_split', nargs=0, const=True, default=False, type=None, choices=None, help=None, metavar=None)

In [5]:
parser.add_argument('--solver',      default='adam', help='adam|sgd')
parser.add_argument('--milestones',  default=[2, 4, 6, 8, 10], nargs='+', type=int)
parser.add_argument('--start_epoch', default=1,      type=int)
parser.add_argument('--epochs',      default=200,     type=int)
parser.add_argument('--batch',       default=8,     type=int)
parser.add_argument('--val_batch',   default=1,      type=int)
parser.add_argument('--init_lr',     default=0.0005, type=float)
parser.add_argument('--lr_decay',    default=0.5,    type=float)
parser.add_argument('--beta_1',      default=0.9,    type=float, help='adam')
parser.add_argument('--beta_2',      default=0.999,  type=float, help='adam')
parser.add_argument('--momentum',    default=0.9,    type=float, help='sgd')
parser.add_argument('--w_decay',     default=4e-4,   type=float)

#### Loss Arguments ####
parser.add_argument('--normal_loss', default='cos',  help='cos|mse')
parser.add_argument('--normal_w',    default=1,      type=float)
parser.add_argument('--dir_loss',    default='mse',  help='cos|mse')
parser.add_argument('--dir_w',       default=1,      type=float)
parser.add_argument('--ints_loss',   default='mse',  help='mse')
parser.add_argument('--ints_w',      default=1,      type=float)
parser.add_argument('--rec_loss', default='L1',  help='L1|L2')
parser.add_argument('--rec_w',    default=60,      type=float)

_StoreAction(option_strings=['--rec_w'], dest='rec_w', nargs=None, const=None, default=60, type=<class 'float'>, choices=None, help=None, metavar=None)

In [6]:
args = parser.parse_args(args=[])

In [7]:
args

Namespace(batch=8, beta_1=0.9, beta_2=0.999, color_aug=False, color_ratio=3, concat_data=False, crop=True, crop_h=128, crop_w=128, cuda=True, data_dir='data/datasets/PRPS_Dataset', data_dir2='data/datasets/PS_Sculpture_Dataset', dataset='UPS_PRPS_Dataset', debug=False, dir_int=False, dir_loss='mse', dir_w=1, dirs_cls=36, epochs=200, fuse_type='max', in_img_num=10, in_light=False, in_mask=True, init_lr=0.0005, int_aug=False, ints_cls=20, ints_loss='mse', ints_w=1, item='CVPR2019', l_suffix='_mtrl.txt', lr_decay=0.5, make_dir=True, max_test_iter=-1, max_train_iter=-1, max_val_iter=-1, milestones=[2, 4, 6, 8, 10], model='LCNet', model_s2='NENet', momentum=0.9, multi_gpu=False, noise=0.05, noise_aug=False, normal_loss='cos', normal_w=1, normalize=False, rand_sc=True, rec_loss='L1', rec_w=60, rescale=True, resume=None, retrain='data/logdir/UPS_Synth_Dataset/CVPR2019/10_images/checkp_20.pth.tar', retrain_s2=None, s1_est_d=True, s1_est_i=True, s1_est_n=False, s2_est_d=False, s2_est_i=False, s

In [8]:
def buildModel(args):
    print('Creating Model %s' % (args.model))
    in_c = 4
    other = {
            'img_num':  args.in_img_num, 
            'test_h':   args.test_h,   'test_w':   args.test_w,
            'in_mask':  args.in_mask,  'in_light': args.in_light, 
            'dirs_cls': args.dirs_cls, 'ints_cls': args.ints_cls,
            's1_est_d': args.s1_est_d, 's1_est_i': args.s1_est_i, 's1_est_n': args.s1_est_n, 
            }
    models = __import__('models.' + args.model)
    model_file = getattr(models, args.model)
    model = getattr(model_file, args.model)(args.fuse_type, args.use_BN, in_c, other)

    if args.cuda: model = model.cuda()

    if args.retrain: 
        model_utils.loadCheckpoint(args.retrain, model, cuda=args.cuda)

    if args.resume:
        model_utils.loadCheckpoint(args.resume, model, cuda=args.cuda)
    print(model)
    return model

In [9]:
model = buildModel(args)

Creating Model LCNet
LCNet(
  (featExtractor): FeatExtractor(
    (conv1): Sequential(
      (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (conv2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (conv3): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (conv4): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (conv5): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (conv6): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): Le

In [10]:
def buildModelStage2(args):
    print('Creating Stage2 Model %s' % (args.model_s2))
    # in_c = 6 if args.s2_in_light else 3
    in_c = 4
    other = {
            'img_num':  args.in_img_num,
            'in_mask':  args.in_mask,  'in_light': args.in_light, 
            'dirs_cls': args.dirs_cls, 'ints_cls': args.ints_cls,
            }
    models = __import__('models.' + args.model_s2)
    model_file = getattr(models, args.model_s2)
    model = getattr(model_file, args.model_s2)(args.fuse_type, args.use_BN, in_c, other)

    if args.cuda: model = model.cuda()

    if args.retrain_s2: 
        model_utils.loadCheckpoint(args.retrain_s2, model, cuda=args.cuda)

    print(model)
    return model

In [11]:
model_s2 = buildModelStage2(args)

Creating Stage2 Model NENet
NENet(
  (generator): Generator(
    (main): Sequential(
      (0): Conv2d(1, 16, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
      (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (4): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (7): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace=True)
      (9): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (10): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU(inplace=True)
      (12): ResidualBlock(
        (main): Sequential(
     

In [12]:
models = [model, model_s2]

In [13]:
def configOptimizer(args, model):
    records = None
    optimizer = getOptimizer(args, model.parameters())
    if args.resume:
        records, start_epoch = loadRecords(args.resume, model, optimizer)
        args.start_epoch = start_epoch
    scheduler = getLrScheduler(args, optimizer)
    return optimizer, scheduler, records

In [14]:
def getOptimizer(args, params):
#     args.log.printWrite('=> Using %s solver for optimization' % (args.solver))
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(params, args.init_lr, betas=(args.beta_1, args.beta_2))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(params, args.init_lr, momentum=args.momentum)
    else:
        raise Exception("=> Unknown Optimizer %s" % (args.solver))
    return optimizer

def getLrScheduler(args, optimizer):
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 
            milestones=args.milestones, gamma=args.lr_decay, last_epoch=args.start_epoch-2)
    return scheduler

def loadRecords(path, model, optimizer):
    records = None
    if os.path.isfile(path):
        records = torch.load(path[:-8] + '_rec' + path[-8:])
        optimizer.load_state_dict(records['optimizer'])
        start_epoch = records['epoch'] + 1
        records = records['records']
        print("=> loaded Records")
    else:
        raise Exception("=> no checkpoint found at '{}'".format(path))
    return records, start_epoch

In [15]:
optimizer, scheduler, records = configOptimizer(args, model_s2)

In [16]:
class Stage2Crit(object): # Second stage
    def __init__(self, args):
        self.s2_est_n = args.s2_est_n 
        self.s2_est_d = args.s2_est_d
        self.s2_est_i = args.s2_est_i
        self.setupLightCrit(args)
        if self.s2_est_n:
            self.setupNormalCrit(args)
        self.setupRecCrit(args)

    def setupRecCrit(self, args):
#         args.log.printWrite('=> Using reconstruction criterion')
        if args.rec_loss == 'L1':
            self.rec_crit = torch.nn.L1Loss()
        elif args.rec_loss == 'L2':
            self.rec_crit = torch.nn.MSELoss()
        self.rec_w = args.rec_w

    def setupLightCrit(self, args):
#         args.log.printWrite('=> Using light criterion')
        if self.s2_est_d:
            self.dir_w = args.dir_w
            self.dirs_crit = torch.nn.CosineEmbeddingLoss()
            if args.cuda: self.dirs_crit = self.dirs_crit.cuda()
        if self.s2_est_i:
            self.ints_w = args.ints_w
            self.ints_crit = torch.nn.MSELoss()
            if args.cuda: self.ints_crit = self.ints_crit.cuda()

    def setupNormalCrit(self, args):
#         args.log.printWrite('=> Using {} for criterion normal'.format(args.normal_loss))
        self.normal_w = args.normal_w
        if args.normal_loss == 'mse':
            self.n_crit = torch.nn.MSELoss()
        elif args.normal_loss == 'cos':
            self.n_crit = torch.nn.CosineEmbeddingLoss()
        else:
            raise Exception("=> Unknown Criterion '{}'".format(args.normal_loss))
        if args.cuda:
            self.n_crit = self.n_crit.cuda()

    def forward(self, output, target, random_loc, s2_est_obMp):
        self.loss = 0
        out_loss = {}

        if self.s2_est_n:
            random_x_loc, random_y_loc = random_loc
            n_est, n_tar = output['n'], target['n'][:,:,random_x_loc - 8:random_x_loc + 8,random_y_loc - 8:random_y_loc + 8]
            n_num = n_tar.nelement() // n_tar.shape[1]
            if not hasattr(self, 'n_flag') or n_num != self.n_flag.nelement():
                self.n_flag = n_tar.data.new().resize_(n_num).fill_(1)
            self.out_reshape = n_est.permute(0, 2, 3, 1).contiguous().view(-1, 3)
            self.gt_reshape  = n_tar.permute(0, 2, 3, 1).contiguous().view(-1, 3)
            normal_loss = self.n_crit(self.out_reshape, self.gt_reshape, self.n_flag)
            normal_loss = torch.acos(1 - normal_loss) / 3.14159 * 180
            self.loss += self.normal_w * normal_loss 
            out_loss['N_loss'] = normal_loss.item() 
        if s2_est_obMp:
            ob_map_est, ob_map_tar = output['ob_map_dense'], target['ob_map_real']

            ob_map_mask = torch.gt(ob_map_tar,0)
            ob_map_est = ob_map_est * ob_map_mask

            rec_loss = self.rec_crit(ob_map_est, ob_map_tar)
            self.loss += self.rec_w * rec_loss
            out_loss['Rec_loss'] = rec_loss.item()
        return out_loss

    def backward(self):
        self.loss.backward()

In [17]:
class Records(object):
    """
    Records->Train,Val->Loss,Accuracy->Epoch1,2,3->[v1,v2]
    IterRecords->Train,Val->Loss, Accuracy,->[v1,v2]
    """
    def __init__(self, log_dir, records=None):
        if records == None:
            self.records = OrderedDict()
        else:
            self.records = records
        self.iter_rec = OrderedDict()
        self.log_dir  = log_dir
        self.classes = ['loss', 'acc', 'err', 'ratio']

    def resetIter(self):
        self.iter_rec.clear()

    def checkDict(self, a_dict, key, sub_type='dict'):
        if key not in a_dict.keys():
            if sub_type == 'dict':
                a_dict[key] = OrderedDict()
            if sub_type == 'list':
                a_dict[key] = []

    def updateIter(self, split, keys, values):
        self.checkDict(self.iter_rec, split, 'dict')
        for k, v in zip(keys, values):
            self.checkDict(self.iter_rec[split], k, 'list')
            self.iter_rec[split][k].append(v)

    def saveIterRecord(self, epoch, reset=True):
        for s in self.iter_rec.keys(): # s stands for split
            self.checkDict(self.records, s, 'dict')
            for k in self.iter_rec[s].keys():
                self.checkDict(self.records[s], k, 'dict')
                self.checkDict(self.records[s][k], epoch, 'list')
                self.records[s][k][epoch].append(np.mean(self.iter_rec[s][k]))
        if reset: 
            self.resetIter()

    def insertRecord(self, split, key, epoch, value):
        self.checkDict(self.records, split, 'dict')
        self.checkDict(self.records[split], key, 'dict')
        self.checkDict(self.records[split][key], epoch, 'list')
        self.records[split][key][epoch].append(value)

    def iterRecToString(self, split, epoch):
        rec_strs = ''
        for c in self.classes:
            strs = ''
            for k in self.iter_rec[split].keys():
                if (c in k.lower()):
                    strs += '{}: {:.3f}| '.format(k, np.mean(self.iter_rec[split][k]))
            if strs != '':
                rec_strs += '\t [{}] {}\n'.format(c.upper(), strs)
        self.saveIterRecord(epoch)
        return rec_strs

    def epochRecToString(self, split, epoch):
        rec_strs = ''
        for c in self.classes:
            strs = ''
            for k in self.records[split].keys():
                if (c in k.lower()) and (epoch in self.records[split][k].keys()):
                    strs += '{}: {:.3f}| '.format(k, np.mean(self.records[split][k][epoch]))
            if strs != '':
                rec_strs += '\t [{}] {}\n'.format(c.upper(), strs)
        return rec_strs

    def recordToDictOfArray(self, splits, epoch=-1, intv=1):
        if len(self.records) == 0: return {}
        if type(splits) == str: splits = [splits]

        dict_of_array = OrderedDict()
        for split in splits:
            for k in self.records[split].keys():
                y_array, x_array = [], []
                if epoch < 0:
                    for ep in self.records[split][k].keys():
                        y_array.append(np.mean(self.records[split][k][ep]))
                        x_array.append(ep)
                else:
                    if epoch in self.records[split][k].keys():
                        y_array = np.array(self.records[split][k][epoch])
                        x_array = np.linspace(intv, intv*len(y_array), len(y_array))
                dict_of_array[split[0] + split[-1] + '_' + k]      = y_array
                dict_of_array[split[0] + split[-1] + '_' + k+'_x'] = x_array
        return dict_of_array


In [18]:
optimizers = [optimizer, -1]
criterion = Stage2Crit(args)
recorder  = recorders.Records(args.log_dir, records)

AttributeError: 'Namespace' object has no attribute 'log_dir'

In [21]:
def customDataloader(args):
#     args.log.printWrite("=> fetching img pairs in %s" % (args.data_dir))
    datasets = __import__('datasets.' + args.dataset)
    dataset_file = getattr(datasets, args.dataset)
    train_set = getattr(dataset_file, args.dataset)(args, args.data_dir, 'train')
    val_set   = getattr(dataset_file, args.dataset)(args, args.data_dir, 'val')

    if args.concat_data:
#         args.log.printWrite('****** Using cocnat data ******')
#         args.log.printWrite("=> fetching img pairs in '{}'".format(args.data_dir2))
        train_set2 = getattr(dataset_file, args.dataset)(args, args.data_dir2, 'train')
        val_set2   = getattr(dataset_file, args.dataset)(args, args.data_dir2, 'val')
        train_set  = torch.utils.data.ConcatDataset([train_set, train_set2])
        val_set    = torch.utils.data.ConcatDataset([val_set,   val_set2])

#     args.log.printWrite('Found Data:\t %d Train and %d Val' % (len(train_set), len(val_set)))
#     args.log.printWrite('\t Train Batch: %d, Val Batch: %d' % (args.batch, args.val_batch))

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch,
        num_workers=args.workers, pin_memory=args.cuda, shuffle=True)
    test_loader  = torch.utils.data.DataLoader(val_set , batch_size=args.val_batch,
        num_workers=args.workers, pin_memory=args.cuda, shuffle=False)
    return train_loader, test_loader

In [22]:
train_loader, val_loader = customDataloader(args)

In [23]:
loader = train_loader

In [24]:
models[1].train()
models[0].eval()
optimizer, optimizer_c = optimizers

In [25]:
data_iter = iter(loader)

In [26]:
sample = next(data_iter)

In [143]:
data = model_utils.parseData(args, sample,None, 'train')

In [144]:
input = model_utils.getInput(args, data)

In [145]:
with torch.no_grad():
    pred_c = models[0](input); 
input.append(pred_c)

In [146]:
torch.nn.functional.softmax(pred_c['dirs_x'][1,:])

  """Entry point for launching an IPython kernel.


tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 1.0412e-42, 5.8342e-37, 1.9666e-31, 4.7624e-25, 4.1670e-19,
        2.5075e-13, 1.9390e-08, 1.2796e-04, 1.3743e-01, 8.4075e-01, 2.1676e-02,
        1.2641e-05, 1.5157e-10, 4.3395e-15, 2.7454e-20, 8.4375e-26, 1.9152e-30,
        9.0211e-34, 1.5246e-42, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00], device='cuda:0')

In [173]:
s2_est_obMp = True
if s2_est_obMp:
    start_loc, end_loc = 20, 108
    random_loc = torch.randint(start_loc,end_loc,[2,1])
    input.append(random_loc)
    # print (random_loc)
    data['ob_map_real'] = model_utils.parseData_stage2(args, sample, random_loc, 'train')

torch.Size([8, 1, 512, 512])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')


In [147]:
img_all = sample['img_all']
if args.in_light:
    dirs = sample['dirs_all'].expand_as(img)
else: # predict lighting, prepare ground truth
    n, c, h, w = sample['dirs_all'].shape
    dirs_split = torch.split(sample['dirs_all'].view(n, c), 3, 1)

x_loc, y_loc = random_loc
img_all_crop = img_all[:,:,x_loc - 8:x_loc + 8, y_loc - 8:y_loc + 8]
del img_all
if args.cuda:
    img_all_crop = img_all_crop.cuda()
n, c, h, w = img_all_crop.shape
imgs = list(torch.split(img_all_crop, 3, 1))
for i in range(len(imgs)):
    img_patch = imgs[i].mean(1)
    img_patch = img_patch.repeat_interleave(32,1).repeat_interleave(32,2)
    dirs = dirs_split[i]
    if args.cuda:
        dirs = dirs.cuda()
    x= 0.5*(dirs[:,0]+1)*(32-1); 
    x=torch.round(x).type(torch.uint8).unsqueeze(1);
    x_one_hot = torch.zeros(n, 32).cuda().scatter_(1, x.long(), 1).unsqueeze(2).repeat(1,1,32)
    y= 0.5*(dirs[:,1]+1)*(32-1);
    y=torch.round(y).type(torch.uint8).unsqueeze(1);
    y_one_hot = torch.zeros(n, 32).cuda().scatter_(1, y.long(), 1).unsqueeze(1).repeat(1,32,1)
    loc_one_hot = x_one_hot * y_one_hot
    loc_one_hot = loc_one_hot.repeat(1,16,16)
    if i == 0:
        ob_map_real = img_patch * loc_one_hot
    else:
        ob_map_real,_ = torch.stack([ob_map_real, img_patch * loc_one_hot],1).max(1)
ob_map_real = ob_map_real.unsqueeze(1)

In [162]:
loc_one_hot[1,:,:].sum()

tensor(1., device='cuda:0')

In [156]:
dirs = dirs.cuda()
x= 0.5*(dirs[:,0]+1)*(32-1); 
x=torch.round(x).type(torch.uint8).unsqueeze(1);
x_one_hot = torch.zeros(n, 32).cuda().scatter_(1, x.long(), 1).unsqueeze(2).repeat(1,1,32)
y= 0.5*(dirs[:,1]+1)*(32-1);
y=torch.round(y).type(torch.uint8).unsqueeze(1);
y_one_hot = torch.zeros(n, 32).cuda().scatter_(1, y.long(), 1).unsqueeze(1).repeat(1,32,1)
loc_one_hot = x_one_hot * y_one_hot

In [159]:
loc_one_hot.shape[0,0:32,0:32].sum()

TypeError: tuple indices must be integers or slices, not tuple

In [None]:
x= 0.5*(dirs[:,0]+1)*(32-1); 

In [153]:
dirs = dirs_split[i]
img_patch = imgs[i].mean(1)
img_patch = img_patch.repeat_interleave(32,1).repeat_interleave(32,2)

In [154]:
dirs.shape

torch.Size([8, 3])

In [32]:
img_all_crop.max().max().max().max()

tensor(1., device='cuda:0')

In [64]:
img_patch = imgs[i].mean(1)
img_patch = img_patch.repeat_interleave(32,1).repeat_interleave(32,2)

In [79]:
imgs[996].max().max().max()

tensor(2.8220, device='cuda:0')

In [54]:
x

tensor([[27],
        [25],
        [ 7],
        [20],
        [ 1],
        [ 5],
        [12],
        [24]], device='cuda:0', dtype=torch.uint8)

In [56]:
loc_one_hot = x_one_hot * y_one_hot

In [61]:
x_one_hot.shape

torch.Size([8, 32, 32])

In [62]:
x_one_hot[0,:32,:32].sum().sum()

tensor(32., device='cuda:0')

In [60]:
loc_one_hot.shape

torch.Size([8, 32, 32])

In [63]:
loc_one_hot[0,:32,:32].sum().sum()

tensor(1., device='cuda:0')

In [50]:
img_patch[:4,:4]

tensor([[[0.2276, 0.2436, 0.2650, 0.2624, 0.4042, 0.4872, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2383, 0.2597, 0.2757, 0.3025, 0.4444, 0.4685, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2383, 0.2543, 0.2757, 0.3025, 0.4605, 0.4765, 0.5220, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2329, 0.2490, 0.2650, 0.3641, 0.4471, 0.4846, 0.5220, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0959, 0.0148, 0.0553, 0.0443, 0.0037, 0.0111, 0.0148, 0.0037,
          0.0074, 0.0295, 0.0258, 0.0443, 0.0295, 0.0000, 0.0000, 0.0000],
         [0.0332, 0.0332, 0.0184, 0.0258, 0.0295, 0.0221, 0.0074, 0.0074,
          0.0074, 0.0221, 0.0111, 0.0037, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0258, 0.0221, 0.0111, 0.0111, 0.0000, 0.0037, 0.0000, 0.0000,
          0.0000, 0.0000, 0.00

In [33]:
a = data['ob_map_real']

In [35]:
a[0,0,,16]

tensor(0.2972, device='cuda:0')

In [37]:
i=10
j=10

In [39]:
a[3,0,32*i+12:32*(i+1)-12,32*j+12:32*(j+1)-12]

tensor([[0.3077, 0.3027, 0.2928, 0.2804, 0.0000, 0.2457, 0.2357, 0.2283],
        [0.0000, 0.0000, 0.0000, 0.2779, 0.2655, 0.2506, 0.0000, 0.2208],
        [0.3127, 0.3002, 0.2878, 0.2729, 0.2655, 0.2531, 0.2432, 0.2208],
        [0.3102, 0.3027, 0.2878, 0.2754, 0.2680, 0.2556, 0.2382, 0.2258],
        [0.3077, 0.3002, 0.2804, 0.2779, 0.2655, 0.2531, 0.2382, 0.2208],
        [0.0000, 0.2953, 0.2854, 0.2705, 0.2605, 0.2432, 0.2332, 0.0000],
        [0.0000, 0.2878, 0.2754, 0.2655, 0.2556, 0.2382, 0.2258, 0.2159],
        [0.2953, 0.2829, 0.2655, 0.2655, 0.2506, 0.2382, 0.2159, 0.2035]],
       device='cuda:0')

In [34]:
sum(sum(sum(sum(a==0))))

tensor(1227515, device='cuda:0')

In [35]:
img_all_crop.shape

torch.Size([8, 3000, 16, 16])

In [36]:
sum(sum(sum(sum(img_all_crop==0))))

tensor(2078078, device='cuda:0')

In [112]:
import cv2
import numpy as np

In [2]:
import torch

In [1]:
b = torch.ones(100,100)

NameError: name 'torch' is not defined

In [117]:
cv2.imshow('img',np.mat(a[1,0,:,:].cpu()))

In [73]:
a.max().max().max()

tensor(2.8220, device='cuda:0')

In [174]:
x = input

In [175]:
imgs = torch.split(x[0], 3, 1)
idx = 1

In [176]:
idx = 2

In [177]:
dirs_x = torch.split(x[idx]['dirs_x'], x[0].shape[0], 0)
dirs_y = torch.split(x[idx]['dirs_y'], x[0].shape[0], 0)
dirs = torch.split(x[idx]['dirs'], x[0].shape[0], 0)

In [178]:
random_x_loc, random_y_loc = x[idx + 1]

In [179]:
random_x_loc

tensor([57])

In [180]:
random_y_loc

tensor([60])

In [181]:
len(imgs)

10

In [182]:
i = 1

In [183]:
import torch.nn as nn

In [184]:
dirs_map = nn.functional.softmax(dirs_x[i],1).unsqueeze(2).repeat(1,1,dirs_x[i].shape[1]) * nn.functional.softmax(dirs_y[i],1).unsqueeze(1).repeat(1,dirs_y[i].shape[1],1)
dirs_map = dirs_map.repeat(1,16,16).unsqueeze(1)

In [185]:
dirs_map.shape

torch.Size([8, 1, 512, 512])

In [186]:
img = imgs[i][:,:,random_x_loc - 8:random_x_loc + 8,random_y_loc - 8:random_y_loc + 8]
img = img.repeat_interleave(32,2).repeat_interleave(32,3)

In [187]:
img[1,1,0:32,0:32]

tensor([[0.0824, 0.0824, 0.0824,  ..., 0.0824, 0.0824, 0.0824],
        [0.0824, 0.0824, 0.0824,  ..., 0.0824, 0.0824, 0.0824],
        [0.0824, 0.0824, 0.0824,  ..., 0.0824, 0.0824, 0.0824],
        ...,
        [0.0824, 0.0824, 0.0824,  ..., 0.0824, 0.0824, 0.0824],
        [0.0824, 0.0824, 0.0824,  ..., 0.0824, 0.0824, 0.0824],
        [0.0824, 0.0824, 0.0824,  ..., 0.0824, 0.0824, 0.0824]],
       device='cuda:0')

In [188]:
_, x_idx = dirs_x[i].data.max(1)
_, y_idx = dirs_y[i].data.max(1)

In [189]:
a,_ = dirs_map[0,0,0:32,0:32].max(1)

In [190]:
a.max(0)

torch.return_types.max(
values=tensor(0.3065, device='cuda:0'),
indices=tensor(6, device='cuda:0'))

In [191]:
x_idx

tensor([ 6, 18, 10, 22, 14,  7, 10, 27], device='cuda:0')

In [192]:
y_idx

tensor([15, 22, 13, 28, 29,  4, 16, 21], device='cuda:0')

In [133]:
img.shape

torch.Size([8, 3, 512, 512])

In [195]:
x=x_idx.type(torch.uint8).unsqueeze(1);
x_one_hot = torch.zeros(n, 32).cuda().scatter_(1, x.long(), 1).unsqueeze(2).repeat(1,1,32)
y=y_idx.type(torch.uint8).unsqueeze(1);
y_one_hot = torch.zeros(n, 32).cuda().scatter_(1, y.long(), 1).unsqueeze(1).repeat(1,32,1)
loc_one_hot = x_one_hot * y_one_hot

In [201]:
loc_one_hot.shape

torch.Size([8, 32, 32])

In [136]:
img_gray[0,0:32,0:32]

tensor([[0.0784, 0.0784, 0.0784,  ..., 0.0784, 0.0784, 0.0784],
        [0.0784, 0.0784, 0.0784,  ..., 0.0784, 0.0784, 0.0784],
        [0.0784, 0.0784, 0.0784,  ..., 0.0784, 0.0784, 0.0784],
        ...,
        [0.0784, 0.0784, 0.0784,  ..., 0.0784, 0.0784, 0.0784],
        [0.0784, 0.0784, 0.0784,  ..., 0.0784, 0.0784, 0.0784],
        [0.0784, 0.0784, 0.0784,  ..., 0.0784, 0.0784, 0.0784]],
       device='cuda:0')

In [138]:
max_filter = torch.zeros(32,32)
max_filter[x_idx,y_idx] = 1

In [140]:
max_filter.sum()

tensor(8.)

In [137]:
img_gray_filtered[0,0:32,0:32]

tensor(0.6275, device='cuda:0')

In [210]:
imgs[i].mean(1).shape

torch.Size([8, 128, 128])

In [203]:
_, x_idx = dirs_x[i].data.max(1)
_, y_idx = dirs_y[i].data.max(1)
x=x_idx.type(torch.uint8).unsqueeze(1);
x_one_hot = torch.zeros(n, 32).cuda().scatter_(1, x.long(), 1).unsqueeze(2).repeat(1,1,32)
y=y_idx.type(torch.uint8).unsqueeze(1);
y_one_hot = torch.zeros(n, 32).cuda().scatter_(1, y.long(), 1).unsqueeze(1).repeat(1,32,1)
loc_one_hot = x_one_hot * y_one_hot
max_filter = loc_one_hot.repeat(1,16,16)

In [209]:
max_filter[0,6,15]

tensor(1., device='cuda:0')

In [None]:
dirs_x = torch.split(x[idx]['dirs_x'], x[0].shape[0], 0)
dirs_y = torch.split(x[idx]['dirs_y'], x[0].shape[0], 0)
dirs = torch.split(x[idx]['dirs'], x[0].shape[0], 0)
random_x_loc, random_y_loc = x[idx + 1]
s2_inputs = []
tmp = []
for i in range(len(imgs)):
    n, c, h, w = imgs[i].shape
    dirs_map = nn.functional.softmax(dirs_x[i],1).unsqueeze(2).repeat(1,1,dirs_x[i].shape[1]) * nn.functional.softmax(dirs_y[i],1).unsqueeze(1).repeat(1,dirs_y[i].shape[1],1)
    dirs_map = dirs_map.repeat(1,16,16).unsqueeze(1)
    dirs_map = dirs_map.cuda()
    # l_dir = dirs[i] if dirs[i].dim() == 4 else dirs[i].view(n, -1, 1, 1)
    # l_int = torch.diag(1.0 / (ints[i].contiguous().view(-1)+1e-8))
    # img   = imgs[i].contiguous().view(n * c, h * w)
    # img   = torch.mm(l_int, img).view(n, c, h, w)
    img = imgs[i][:,:,random_x_loc - 8:random_x_loc + 8,random_y_loc - 8:random_y_loc + 8]
    img = img.repeat_interleave(32,2).repeat_interleave(32,3)
    # img = img.mean(1)
    # img = img.unsqueeze(1)
    img_light = torch.cat([img, dirs_map], 1)
    s2_inputs.append(img_light)

    _, x_idx = dirs_x[i].data.max(1)
    _, y_idx = dirs_y[i].data.max(1)
    x=x_idx.type(torch.uint8).unsqueeze(1);
    x_one_hot = torch.zeros(n, 32).cuda().scatter_(1, x.long(), 1).unsqueeze(2).repeat(1,1,32)
    y=y_idx.type(torch.uint8).unsqueeze(1);
    y_one_hot = torch.zeros(n, 32).cuda().scatter_(1, y.long(), 1).unsqueeze(1).repeat(1,32,1)
    loc_one_hot = x_one_hot * y_one_hot
    max_filter = loc_one_hot.repeat(1,16,16)
    max_filter = max_filter.cuda()
    img_gray = img.mean(1)
    img_gray_filtered = img_gray * max_filter
    tmp.append(img_gray_filtered)
regressor_inputs,_ = torch.stack(tmp,1).max(1)
regressor_inputs = regressor_inputs.unsqueeze(1)

In [48]:
data_iter2 = iter(val_loader)

In [49]:
sample = next(data_iter2)

In [50]:
data = model_utils.parseData(args, sample, None, 'val')
input = model_utils.getInput(args, data)

In [51]:
pred_c = models[0](input);
input.append(pred_c)

In [52]:
s2_est_obMp = True
if s2_est_obMp:
    start_loc, end_loc = 20, 108
    random_loc = torch.randint(start_loc,end_loc,[2,1])
    input.append(random_loc)
    data['ob_map_real'] = model_utils.parseData_stage2(args, sample, random_loc, 'train')

In [53]:
pred = models[1](input)

In [54]:
pred

{'ob_map_dense': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]]], device='cuda:0',
        grad_fn=<IndexPutBackward>),
 'n': tensor([[[[ 0.3467,  0.1389,  0.0499,  0.0376, -0.0399,  0.0072,  0.0375,
             0.0549,  0.0800,  0.0204, -0.0094, -0.0165, -0.0160, -0.0198,
            -0.0111,  0.0291],
           [ 0.0899,  0.2337,  0.0402, -0.1092, -0.2564, -0.2216, -0.3326,
            -0.4097, -0.2549, -0.2658, -0.3584, -0.3657, -0.3777, -0.3904,
            -0.3536, -0.2839],
           [ 0.1298,  0.5079,  0.3457,  0.0325, -0.2606, -0.3157, -0.3540,
            -0.5260, -0.4461, -0.2996, -0.4105, -0.4190, -0.3985, -0.4169,
            -0.4052, -0.3512],
           [ 0.2383,  0.6385,  0.2333, -0.1055, -0.3842, -0.4986, -0.4190,
            -0.5082, -0.4944, 

In [55]:
split = 'val'

In [56]:
from utils import eval_utils, time_utils 

In [57]:
mask_var = data['m']
data_batch = args.val_batch if split == 'val' else args.test_batch
iter_res = []
error = ''
# if args.s1_est_d:
#     l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred_c['dirs'].data, data_batch)
#     recorder.updateIter(split, l_acc.keys(), l_acc.values())
#     iter_res.append(l_acc['l_err_mean'])
#     error += 'D_%.3f-' % (l_acc['l_err_mean']) 
# if args.s1_est_i:
#     int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred_c['intens'].data, data_batch)
#     recorder.updateIter(split, int_acc.keys(), int_acc.values())
#     iter_res.append(int_acc['ints_ratio'])
#     error += 'I_%.3f-' % (int_acc['ints_ratio'])
if args.s1_est_d:
    l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred_c['dirs'].data, data_batch)
    iter_res.append(l_acc['l_err_mean'])
    error += 'D_%.3f-' % (l_acc['l_err_mean']) 

In [58]:
if args.s2_est_n:
    random_x_loc, random_y_loc = random_loc
    n_tar = data['n'][:,:,random_x_loc - 8:random_x_loc + 8,random_y_loc - 8:random_y_loc + 8]
    mask_var = mask_var[:,:,random_x_loc - 8:random_x_loc + 8,random_y_loc - 8:random_y_loc + 8]
    acc, error_map = eval_utils.calNormalAcc(n_tar.data, pred['n'].data, mask_var.data)
    iter_res.append(acc['n_err_mean'])
    error += 'N_%.3f-' % (acc['n_err_mean'])
    data['error_map'] = error_map['angular_map']



In [59]:
acc

{'n_err_mean': 57.3135986328125,
 'n_acc_11': 0.00913241971284151,
 'n_acc_30': 0.09132420271635056,
 'n_acc_45': 0.2009132355451584}

In [None]:
dot_product = (gt_n * pred_n).sum(1).clamp(-1,1)
error_map   = torch.acos(dot_product) # [-pi, pi]
angular_map = error_map * 180.0 / math.pi
angular_map = angular_map * mask.narrow(1, 0, 1).squeeze(1)

valid = mask.narrow(1, 0, 1).sum()
ang_valid  = angular_map[mask.narrow(1, 0, 1).squeeze(1).byte()]
n_err_mean = ang_valid.sum() / valid
n_err_med  = ang_valid.median()
n_acc_11   = (ang_valid < 11.25).sum().float() / valid
n_acc_30   = (ang_valid < 30).sum().float() / valid
n_acc_45   = (ang_valid < 45).sum().float() / valid

angular_map = colorMap(angular_map.cpu().squeeze(1))
value = {'n_err_mean': n_err_mean.item(), 
        'n_acc_11': n_acc_11.item(), 'n_acc_30': n_acc_30.item(), 'n_acc_45': n_acc_45.item()}
angular_error_map = {'angular_map': angular_map}

In [30]:
mask = torch.zeros(1,1,16,16)
# mask[0,0,1,1]=1

In [31]:
gt_n = torch.rand(1,1,16,16)
pred_n = torch.rand(1,1,16,16)

In [32]:
import math

In [33]:
dot_product = (gt_n * pred_n).sum(1).clamp(-1,1)
error_map   = torch.acos(dot_product) # [-pi, pi]
angular_map = error_map * 180.0 / math.pi
angular_map = angular_map * mask.narrow(1, 0, 1).squeeze(1)

In [34]:
valid = mask.narrow(1, 0, 1).sum()
ang_valid  = angular_map[mask.narrow(1, 0, 1).squeeze(1).byte()]



In [35]:
valid

tensor(0.)

In [37]:
ang_valid

tensor([])

In [36]:
n_err_mean = ang_valid.sum() / valid

In [38]:
n_err_mean

tensor(nan)

In [47]:
n_err_med  = ang_valid.median() if valid else

TypeError: unsupported operand type(s) for /: 'list' and 'int'

In [45]:
n_err_med

In [27]:
mask.narrow(1, 0, 1).shape

torch.Size([1, 1, 16, 16])

In [28]:
n_err_mean

tensor(84.9280)

In [29]:
valid

tensor(1.)

In [40]:
n_acc_11   = (ang_valid < 11.25).sum().float() / valid
n_acc_30   = (ang_valid < 30).sum().float() / valid
n_acc_45   = (ang_valid < 45).sum().float() / valid

In [41]:
n_acc_11

tensor(nan)

In [54]:
import numpy as np
from matplotlib import cm

In [55]:
def colorMap(diff):
    thres = 90
    diff_norm = np.clip(diff, 0, thres) / thres
    diff_cm = torch.from_numpy(cm.jet(diff_norm.numpy()))[:,:,:, :3]
    return diff_cm.permute(0,3,1,2).clone().float()

In [56]:
angular_map = colorMap(angular_map.cpu().squeeze(1))

In [58]:
value = {'n_err_mean': n_err_mean.item(), 
            'n_acc_11': n_acc_11.item(), 'n_acc_30': n_acc_30.item(), 'n_acc_45': n_acc_45.item()}
angular_error_map = {'angular_map': angular_map}

In [60]:
acc = value

In [61]:
error_map = angular_map

In [62]:
ang_valid

tensor([])

In [65]:
n_acc_11 +1 / 2

tensor(nan)

In [67]:
valid == 0

tensor(True)

In [68]:
valid = 1

In [69]:
valid

1

In [70]:
10e-5

0.0001

In [None]:
def parseData_stage2(args, sample, random_loc, split='train'):
    img_all = sample['img_all']
    if args.in_light:
        dirs = sample['dirs_all'].expand_as(img)
    else: # predict lighting, prepare ground truth
        n, c, h, w = sample['dirs_all'].shape
        dirs_split = torch.split(sample['dirs_all'].view(n, c), 3, 1)
    
    x_loc, y_loc = random_loc
    img_all_crop = img_all[:,:,x_loc - 8:x_loc + 8, y_loc - 8:y_loc + 8]
    del img_all
    if args.cuda:
        img_all_crop = img_all_crop.cuda()
    # img_all_crop = img_all_crop.repeat_interleave(32,2).repeat_interleave(32,3)
    n, c, h, w = img_all_crop.shape
    imgs = list(torch.split(img_all_crop, 3, 1))
    # ob_map_real = torch.zeros(n, 512, 512).cuda()
    for i in range(len(imgs)):
        img_patch = imgs[i].mean(1)
        img_patch = img_patch.repeat_interleave(32,1).repeat_interleave(32,2)
        dirs = dirs_split[i]
        if args.cuda:
            dirs = dirs.cuda()
        x= 0.5*(dirs[:,0]+1)*(32-1); 
        x=torch.round(x).type(torch.uint8).unsqueeze(1);
        x_one_hot = torch.zeros(n, 32).cuda().scatter_(1, x.long(), 1).unsqueeze(2).repeat(1,1,32)
        y= 0.5*(dirs[:,1]+1)*(32-1);
        y=torch.round(y).type(torch.uint8).unsqueeze(1);
        y_one_hot = torch.zeros(n, 32).cuda().scatter_(1, y.long(), 1).unsqueeze(1).repeat(1,32,1)
        loc_one_hot = x_one_hot * y_one_hot
        loc_one_hot = loc_one_hot.repeat(1,16,16)
        if i == 0:
            ob_map_real = img_patch * loc_one_hot
        else:
            ob_map_real,_ = torch.stack([ob_map_real, img_patch * loc_one_hot],1).max(1)
        # ob_map.append(img_patch * loc_one_hot)
        # for j in range(n):
        #     x= 0.5*(dirs[j,0]+1)*(32-1); 
        #     x=torch.round(x).type(torch.uint8);
        #     y= 0.5*(dirs[j,1]+1)*(32-1);
        #     y=torch.round(y).type(torch.uint8);
        #     for k in range(16):
        #         for l in range(16):
        #             ob_map_real[j,x + 32 * k, y + 32 * l] = img_patch[j,k,l]
    # ob_map_real = torch.stack(ob_map,1).sum(1)
    ob_map_real = ob_map_real.unsqueeze(1)
    # print (ob_map_real.shape)
    # print (ob_map_real.squeeze()[0,0:32,0:32])
    # if args.cuda:
    #     ob_map_real = ob_map_real.cuda() 
    return ob_map_real

In [None]:
outputs['dir_x'] = self.dir_x_est(out).squeeze(2).squeeze(2)
            outputs['dir_y'] = self.dir_y_est(out).squeeze(2).squeeze(2)

In [None]:
def prepareInputs(self, x):
    imgs = torch.split(x[0], 3, 1)
    idx = 1
    if self.other['in_light']: idx += 1
    if self.other['in_mask']:  idx += 1
    dirs_x = torch.split(x[idx]['dirs_x'], x[0].shape[0], 0)
    dirs_y = torch.split(x[idx]['dirs_y'], x[0].shape[0], 0)
    dirs = torch.split(x[idx]['dirs'], x[0].shape[0], 0)
    # ints = torch.split(x[idx]['intens'], 3, 1)
    random_x_loc, random_y_loc = x[idx + 1]
    # s2_inputs = []
    # for i in range(len(imgs)):
    #     n, c, h, w = imgs[i].shape
    #     l_dir = dirs[i] if dirs[i].dim() == 4 else dirs[i].view(n, -1, 1, 1)
    #     # l_int = torch.diag(1.0 / (ints[i].contiguous().view(-1)+1e-8))
    #     # img   = imgs[i].contiguous().view(n * c, h * w)
    #     # img   = torch.mm(l_int, img).view(n, c, h, w)
    #     img_light = torch.cat([img, l_dir.expand_as(img)], 1)
    #     s2_inputs.append(img_light)
    # return s2_inputs
    s2_inputs = []
    tmp = []
    for i in range(len(imgs)):
        n, c, h, w = imgs[i].shape
        dirs_map = nn.functional.softmax(dirs_x[i],1).unsqueeze(2).repeat(1,1,dirs_x[i].shape[1]) * nn.functional.softmax(dirs_y[i],1).unsqueeze(1).repeat(1,dirs_y[i].shape[1],1)
        dirs_map = dirs_map.repeat(1,16,16).unsqueeze(1)
        dirs_map = dirs_map.cuda()
        # l_dir = dirs[i] if dirs[i].dim() == 4 else dirs[i].view(n, -1, 1, 1)
        # l_int = torch.diag(1.0 / (ints[i].contiguous().view(-1)+1e-8))
        # img   = imgs[i].contiguous().view(n * c, h * w)
        # img   = torch.mm(l_int, img).view(n, c, h, w)
        img = imgs[i][:,:,random_x_loc - 8:random_x_loc + 8,random_y_loc - 8:random_y_loc + 8]
        img = img.repeat_interleave(32,2).repeat_interleave(32,3)
        # img = img.mean(1)
        # img = img.unsqueeze(1)
        img_light = torch.cat([img, dirs_map], 1)
        s2_inputs.append(img_light)

        _, x_idx = dirs_x[i].data.max(1)
        _, y_idx = dirs_y[i].data.max(1)
        x=x_idx.type(torch.uint8).unsqueeze(1);
        x_one_hot = torch.zeros(n, 32).cuda().scatter_(1, x.long(), 1).unsqueeze(2).repeat(1,1,32)
        y=y_idx.type(torch.uint8).unsqueeze(1);
        y_one_hot = torch.zeros(n, 32).cuda().scatter_(1, y.long(), 1).unsqueeze(1).repeat(1,32,1)
        loc_one_hot = x_one_hot * y_one_hot
        max_filter = loc_one_hot.repeat(1,16,16)
        max_filter = max_filter.cuda()
        img_gray = img.mean(1)
        img_gray_filtered = img_gray * max_filter
        tmp.append(img_gray_filtered)
    regressor_inputs,_ = torch.stack(tmp,1).max(1)
    regressor_inputs = regressor_inputs.unsqueeze(1)
    return s2_inputs, regressor_inputs