In [1]:
# @inproceedings{brachmann2019ngransac,
#   title={{N}eural- {G}uided {RANSAC}: {L}earning Where to Sample Model Hypotheses},
#   author={Brachmann, Eric and Rother, Carsten},
#   booktitle={ICCV},
#   year={2019}
# }

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import torch
import torchvision
import torchvision.transforms as transforms

from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/My Drive/Colab Notebooks/ngdsac_horizon/
from google.colab import files

from hlw_dataset import HLWDataset
from model import Model
import torchvision.utils as vutils

from skimage.io import imsave
import skimage.io as io
from skimage import color
from skimage.io import imsave
from skimage.draw import line, set_color, circle

import time
import warnings
import argparse
import os

from ngdsac import NGDSAC
from loss import Loss

import cv2

Mounted at /content/drive
/content/drive/My Drive/Colab Notebooks/ngdsac_horizon


In [4]:
# parser = argparse.ArgumentParser(description='Train a horizon line estimation network on the HLW dataset using (NG-)DSAC.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# parser.add_argument('--session', '-sid', default='', 
# 	help='custom session name appended to output files, useful to separate different runs of the program')

# parser.add_argument('--capacity', '-c', type=int, default=4, 
# 	help='controls the model capacity of the network by scaling the number of channels in each layer')

# parser.add_argument('--imagesize', '-is', type=int, default=256, 
# 	help='rescale images to this max side length')

# parser.add_argument('--inlierthreshold', '-it', type=float, default=0.05, 
# 	help='threshold used in the soft inlier count, relative to image size (1 = image width)')

# parser.add_argument('--inlieralpha', '-ia', type=float, default=0.1, 
# 	help='scaling factor for the soft inlier scores (controls the peakiness of the hypothesis distribution)')

# parser.add_argument('--inlierbeta', '-ib', type=float, default=100.0, 
# 	help='scaling factor within the sigmoid of the soft inlier count')

# parser.add_argument('--storeinterval', '-si', type=int, default=1000, 
# 	help='store network weights and a prediction visualization every x training iterations')

# parser.add_argument('--hypotheses', '-hyps', type=int, default=16, 
# 	help='number of line hypotheses sampled for each image')

# parser.add_argument('--batchsize', '-bs', type=int, default=32, 
# 	help='training batch size')

# parser.add_argument('--learningrate', '-lr', type=float, default=0.0001, 
# 	help='learning rate')

# parser.add_argument('--iterations', '-i', type=int, default=250000, 
# 	help='number of training iterations (parameter updates)')

# parser.add_argument('--scheduleoffset', '-soff', type=int, default=150000, 
# 	help='start learning rate schedule after this many iterations')

# parser.add_argument('--schedulestep', '-sstep', type=int, default=25000, 
# 	help='halve learning rate after this many iterations')

# parser.add_argument('--samplesize', '-ss', type=int, default=2, 
# 	help='number of ng-dsac runs for each training image to approximate the expectation when learning neural guidance')

# parser.add_argument('--invalidloss', '-il', type=float, default=1, 
# 	help='penalty for sampling invalid hypotheses')

# parser.add_argument('--uniform', '-u', action='store_true', 
# 	help='disable neural-guidance and sample data points uniformly; corresponds to a DSAC model')

#opt = parser.parse_args()

# ---------------------------------------------------------------------------
# Training configuration. These module-level values replace the command line
# options of the script version, which cannot be parsed inside a notebook cell.
# ---------------------------------------------------------------------------
session = 'colab_baseline_26'   # session name appended to output files
capacity = 4                    # scales the number of channels in each layer
imagesize = 256                 # rescale images to this max side length
inlierthreshold = 0.05          # soft inlier threshold, relative to image size
inlieralpha = 0.1               # scaling factor for the soft inlier scores
inlierbeta = 100.0              # scaling factor within the sigmoid of the soft inlier count
storeinterval = 1000            # store network weights every x training iterations
hypotheses = 16                 # number of line hypotheses sampled for each image
batchsize = 32                  # training batch size
learningrate = 0.0002           # learning rate
iterations = 10000              # number of parameter updates; default = 250000
scheduleoffset = 150000         # unused while the learning rate schedule is disabled
schedulestep = 25000            # unused while the learning rate schedule is disabled
samplesize = 2                  # ng-dsac runs per image to approximate the expectation
invalidloss = 1                 # penalty for sampling invalid hypotheses
uniform = False                 # True disables neural guidance (plain DSAC)

# Directories
tngDir = '/content/drive/My Drive/Colab Notebooks/ngdsac_horizon/hlw_1_2/split/'

# output files of this session (built once so that saving and downloading agree)
log_path = 'log_' + session + '.txt'
weights_path = './weights_' + session + '.net'

# setup training set
splitFileName = 'train.txt'
path = tngDir + splitFileName
trainset = HLWDataset(path, imagesize, training=True)
# NOTE(review): an older comment here said the workers were set to 0 because
# of a multithreading issue, but the value actually in use is 12 -- confirm
# which one is intended for this runtime.
trainset_loader = torch.utils.data.DataLoader(trainset, shuffle=True, num_workers=12, batch_size=batchsize)

# setup ng dsac estimator
loss = Loss(imagesize)
ngdsac = NGDSAC(hypotheses, inlierthreshold, inlierbeta, inlieralpha, loss, invalidloss)

# setup network
# NOTE(review): `nn` rebinds the `torch.nn` alias imported in the first cell.
# The name is kept because other cells may refer to the network through it.
nn = Model(capacity)
nn.train()
nn = nn.cuda()

# optimizer; the step-wise learning rate schedule is deliberately disabled.
# To re-enable it, create the scheduler below and call scheduler.step() after
# optimizer.step() once iteration >= scheduleoffset.
optimizer = optim.Adam(nn.parameters(), lr=learningrate)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=schedulestep, gamma=0.5)

iteration = 0
epoch = 0

# keep track of training progress: one "<iteration> <expected loss>" line per
# parameter update. The context manager closes the file even when training is
# interrupted or raises.
with open(log_path, 'w') as train_log:

    print('Logging training progress to ' + log_path)

    # run exactly `iterations` parameter updates, cycling through the data
    while iteration < iterations:

        print('=== Epoch: ', epoch, '========================================')

        for inputs, labels, xStart, xEnd, imh, idx in trainset_loader:

            start_time = time.time()

            # predict points and neural guidance
            inputs = inputs.cuda()
            points, log_probs = nn(inputs)

            if uniform:
                # overwrite neural guidance with uniform sampling probabilities
                log_probs.fill_(1/log_probs.size(1))
                log_probs = torch.log(log_probs)

            # gradient accumulators, allocated directly next to the predictions
            g_log_probs = torch.zeros_like(log_probs) # gradients for neural guidance
            g_points = torch.zeros_like(points) # gradients for point positions

            # approximate neural guidance expectation by sampling

            exp_loss = 0 # mean loss over samples
            losses = [] # losses per sample, we will subtract the mean loss later as baseline
            sample_grads = [] # gradients per sample

            for s in range(samplesize):

                # fit lines with ngdsac (also calculates expected loss for DSAC)
                cur_loss = ngdsac(points, log_probs, labels, xStart, xEnd, imh)

                # calculate gradients for 2D point predictions by PyTorch (autograd of expected loss)
                g_points += torch.autograd.grad(cur_loss, points)[0]
                # gradients for neural guidance have been calculated by NGDSAC
                sample_grads.append(ngdsac.g_log_probs.cuda() / batchsize)

                exp_loss += cur_loss
                losses.append(cur_loss)

            g_points /= samplesize
            exp_loss /= samplesize

            # convert once; reused as baseline and for reporting
            mean_loss = float(exp_loss)

            # subtract baseline (mean over samples) for neural guidance gradients to reduce variance
            for sample_grad, sample_loss in zip(sample_grads, losses):
                g_log_probs += sample_grad * (float(sample_loss) - mean_loss)
            # NOTE(review): the extra factor of 10 further scales down the
            # guidance gradient; it is an undocumented constant -- confirm intent.
            g_log_probs /= samplesize * 10

            if uniform:
                # DSAC, no neural guidance
                torch.autograd.backward(points, g_points)
            else:
                # full NG-DSAC
                torch.autograd.backward((points, log_probs), (g_points, g_log_probs))

            optimizer.step()
            optimizer.zero_grad()

            # wrap up
            end_time = time.time()-start_time
            print('Iteration: %6d, Exp. Loss: %2.2f, Time: %.2fs' % (iteration, mean_loss, end_time), flush=True)

            train_log.write('%d %f\n' % (iteration, mean_loss))

            if iteration % storeinterval == 0:
                torch.save(nn.state_dict(), weights_path)

            # drop references to this batch's tensors before the next batch is processed
            del exp_loss, points, log_probs, g_log_probs, g_points, losses, sample_grads

            iteration += 1

            if iteration >= iterations:
                break

        epoch += 1

# store the final weights so the downloaded file matches the last update,
# not just the most recent periodic checkpoint
torch.save(nn.state_dict(), weights_path)

print('Done without errors.')
files.download(log_path)
files.download(weights_path)


Downloaded log_colab_baseline_26.txt
Iteration:      0, Exp. Loss: 0.53, Time: 2.64s
Iteration:      1, Exp. Loss: 0.62, Time: 2.92s
Iteration:      2, Exp. Loss: 0.66, Time: 3.19s
Iteration:      3, Exp. Loss: 0.62, Time: 3.49s
Iteration:      4, Exp. Loss: 0.45, Time: 3.09s
Iteration:      5, Exp. Loss: 0.46, Time: 2.91s
Iteration:      6, Exp. Loss: 0.42, Time: 3.39s
Iteration:      7, Exp. Loss: 0.43, Time: 3.04s
Iteration:      8, Exp. Loss: 0.40, Time: 3.21s
Iteration:      9, Exp. Loss: 0.37, Time: 3.51s
Iteration:     10, Exp. Loss: 0.38, Time: 3.06s
Iteration:     11, Exp. Loss: 0.37, Time: 3.17s
Iteration:     12, Exp. Loss: 0.36, Time: 3.57s
Iteration:     13, Exp. Loss: 0.36, Time: 3.41s
Iteration:     14, Exp. Loss: 0.32, Time: 3.07s
Iteration:     15, Exp. Loss: 0.30, Time: 2.88s
Iteration:     16, Exp. Loss: 0.35, Time: 3.32s
Iteration:     17, Exp. Loss: 0.31, Time: 3.16s
Iteration:     18, Exp. Loss: 0.31, Time: 3.04s
Iteration:     19, Exp. Loss: 0.34, Time: 3.34s
Ite

  " Skipping tag %s" % (size, len(data), tag)


Iteration:    176, Exp. Loss: 0.28, Time: 3.33s
Iteration:    177, Exp. Loss: 0.23, Time: 3.29s
Iteration:    178, Exp. Loss: 0.22, Time: 2.96s
Iteration:    179, Exp. Loss: 0.26, Time: 3.24s
Iteration:    180, Exp. Loss: 0.31, Time: 2.91s
Iteration:    181, Exp. Loss: 0.26, Time: 3.03s
Iteration:    182, Exp. Loss: 0.25, Time: 3.10s
Iteration:    183, Exp. Loss: 0.24, Time: 3.09s
Iteration:    184, Exp. Loss: 0.28, Time: 2.91s
Iteration:    185, Exp. Loss: 0.22, Time: 3.21s
Iteration:    186, Exp. Loss: 0.21, Time: 3.21s
Iteration:    187, Exp. Loss: 0.29, Time: 2.90s
Iteration:    188, Exp. Loss: 0.24, Time: 3.28s
Iteration:    189, Exp. Loss: 0.24, Time: 3.16s
Iteration:    190, Exp. Loss: 0.31, Time: 3.31s
Iteration:    191, Exp. Loss: 0.22, Time: 3.50s
Iteration:    192, Exp. Loss: 0.20, Time: 3.14s
Iteration:    193, Exp. Loss: 0.23, Time: 3.65s
Iteration:    194, Exp. Loss: 0.27, Time: 3.11s
Iteration:    195, Exp. Loss: 0.24, Time: 3.45s
Iteration:    196, Exp. Loss: 0.26, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:    419, Exp. Loss: 0.29, Time: 3.49s
Iteration:    420, Exp. Loss: 0.23, Time: 2.77s
Iteration:    421, Exp. Loss: 0.23, Time: 3.05s
Iteration:    422, Exp. Loss: 0.23, Time: 3.20s
Iteration:    423, Exp. Loss: 0.22, Time: 3.33s
Iteration:    424, Exp. Loss: 0.27, Time: 3.21s
Iteration:    425, Exp. Loss: 0.20, Time: 2.88s
Iteration:    426, Exp. Loss: 0.21, Time: 2.66s
Iteration:    427, Exp. Loss: 0.20, Time: 3.28s
Iteration:    428, Exp. Loss: 0.22, Time: 3.05s
Iteration:    429, Exp. Loss: 0.21, Time: 3.07s
Iteration:    430, Exp. Loss: 0.24, Time: 3.25s
Iteration:    431, Exp. Loss: 0.21, Time: 2.95s
Iteration:    432, Exp. Loss: 0.23, Time: 3.20s
Iteration:    433, Exp. Loss: 0.23, Time: 3.27s
Iteration:    434, Exp. Loss: 0.20, Time: 3.48s
Iteration:    435, Exp. Loss: 0.20, Time: 2.91s
Iteration:    436, Exp. Loss: 0.22, Time: 3.60s
Iteration:    437, Exp. Loss: 0.25, Time: 2.99s
Iteration:    438, Exp. Loss: 0.20, Time: 3.34s
Iteration:    439, Exp. Loss: 0.22, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:    729, Exp. Loss: 0.23, Time: 3.23s
Iteration:    730, Exp. Loss: 0.23, Time: 2.74s
Iteration:    731, Exp. Loss: 0.23, Time: 2.85s
Iteration:    732, Exp. Loss: 0.20, Time: 2.99s
Iteration:    733, Exp. Loss: 0.23, Time: 3.05s
Iteration:    734, Exp. Loss: 0.24, Time: 2.70s
Iteration:    735, Exp. Loss: 0.20, Time: 3.11s
Iteration:    736, Exp. Loss: 0.23, Time: 3.32s
Iteration:    737, Exp. Loss: 0.23, Time: 2.82s
Iteration:    738, Exp. Loss: 0.23, Time: 2.80s
Iteration:    739, Exp. Loss: 0.20, Time: 2.55s
Iteration:    740, Exp. Loss: 0.23, Time: 3.24s
Iteration:    741, Exp. Loss: 0.18, Time: 2.83s
Iteration:    742, Exp. Loss: 0.24, Time: 3.03s
Iteration:    743, Exp. Loss: 0.19, Time: 3.29s
Iteration:    744, Exp. Loss: 0.19, Time: 3.21s
Iteration:    745, Exp. Loss: 0.18, Time: 3.67s
Iteration:    746, Exp. Loss: 0.23, Time: 3.21s
Iteration:    747, Exp. Loss: 0.19, Time: 3.23s
Iteration:    748, Exp. Loss: 0.24, Time: 3.24s
Iteration:    749, Exp. Loss: 0.20, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   1033, Exp. Loss: 0.21, Time: 3.41s
Iteration:   1034, Exp. Loss: 0.22, Time: 3.16s
Iteration:   1035, Exp. Loss: 0.18, Time: 2.63s
Iteration:   1036, Exp. Loss: 0.22, Time: 2.48s
Iteration:   1037, Exp. Loss: 0.21, Time: 2.35s
Iteration:   1038, Exp. Loss: 0.23, Time: 2.20s
Iteration:   1039, Exp. Loss: 0.24, Time: 2.12s
Iteration:   1040, Exp. Loss: 0.27, Time: 2.01s
Iteration:   1041, Exp. Loss: 0.19, Time: 2.09s
Iteration:   1042, Exp. Loss: 0.24, Time: 1.95s
Iteration:   1043, Exp. Loss: 0.22, Time: 1.99s
Iteration:   1044, Exp. Loss: 0.21, Time: 1.96s
Iteration:   1045, Exp. Loss: 0.17, Time: 1.96s
Iteration:   1046, Exp. Loss: 0.18, Time: 1.97s
Iteration:   1047, Exp. Loss: 0.22, Time: 1.93s
Iteration:   1048, Exp. Loss: 0.22, Time: 1.97s
Iteration:   1049, Exp. Loss: 0.23, Time: 2.01s
Iteration:   1050, Exp. Loss: 0.21, Time: 1.97s
Iteration:   1051, Exp. Loss: 0.16, Time: 2.01s
Iteration:   1052, Exp. Loss: 0.23, Time: 1.99s
Iteration:   1053, Exp. Loss: 0.19, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   1190, Exp. Loss: 0.16, Time: 3.26s
Iteration:   1191, Exp. Loss: 0.26, Time: 2.40s
Iteration:   1192, Exp. Loss: 0.16, Time: 3.21s
Iteration:   1193, Exp. Loss: 0.19, Time: 2.77s
Iteration:   1194, Exp. Loss: 0.25, Time: 2.90s
Iteration:   1195, Exp. Loss: 0.18, Time: 3.09s
Iteration:   1196, Exp. Loss: 0.21, Time: 3.01s
Iteration:   1197, Exp. Loss: 0.26, Time: 2.99s
Iteration:   1198, Exp. Loss: 0.22, Time: 3.01s
Iteration:   1199, Exp. Loss: 0.20, Time: 3.18s
Iteration:   1200, Exp. Loss: 0.21, Time: 2.75s
Iteration:   1201, Exp. Loss: 0.20, Time: 2.56s
Iteration:   1202, Exp. Loss: 0.23, Time: 3.33s
Iteration:   1203, Exp. Loss: 0.16, Time: 3.74s
Iteration:   1204, Exp. Loss: 0.16, Time: 2.75s
Iteration:   1205, Exp. Loss: 0.19, Time: 2.97s
Iteration:   1206, Exp. Loss: 0.19, Time: 2.98s
Iteration:   1207, Exp. Loss: 0.23, Time: 2.88s
Iteration:   1208, Exp. Loss: 0.19, Time: 2.86s
Iteration:   1209, Exp. Loss: 0.19, Time: 2.89s
Iteration:   1210, Exp. Loss: 0.24, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   1548, Exp. Loss: 0.18, Time: 3.06s
Iteration:   1549, Exp. Loss: 0.19, Time: 3.42s
Iteration:   1550, Exp. Loss: 0.17, Time: 3.33s
Iteration:   1551, Exp. Loss: 0.19, Time: 3.65s
Iteration:   1552, Exp. Loss: 0.22, Time: 2.85s
Iteration:   1553, Exp. Loss: 0.21, Time: 3.23s
Iteration:   1554, Exp. Loss: 0.21, Time: 3.04s
Iteration:   1555, Exp. Loss: 0.26, Time: 2.73s
Iteration:   1556, Exp. Loss: 0.18, Time: 3.28s
Iteration:   1557, Exp. Loss: 0.19, Time: 2.88s
Iteration:   1558, Exp. Loss: 0.17, Time: 3.47s
Iteration:   1559, Exp. Loss: 0.19, Time: 3.08s
Iteration:   1560, Exp. Loss: 0.21, Time: 3.54s
Iteration:   1561, Exp. Loss: 0.15, Time: 2.81s
Iteration:   1562, Exp. Loss: 0.22, Time: 2.92s
Iteration:   1563, Exp. Loss: 0.17, Time: 2.88s
Iteration:   1564, Exp. Loss: 0.19, Time: 2.91s
Iteration:   1565, Exp. Loss: 0.16, Time: 2.37s
Iteration:   1566, Exp. Loss: 0.17, Time: 2.43s
Iteration:   1567, Exp. Loss: 0.13, Time: 2.26s
Iteration:   1568, Exp. Loss: 0.19, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   1841, Exp. Loss: 0.18, Time: 3.21s
Iteration:   1842, Exp. Loss: 0.19, Time: 3.19s
Iteration:   1843, Exp. Loss: 0.22, Time: 2.67s
Iteration:   1844, Exp. Loss: 0.20, Time: 2.86s
Iteration:   1845, Exp. Loss: 0.22, Time: 2.64s
Iteration:   1846, Exp. Loss: 0.24, Time: 2.85s
Iteration:   1847, Exp. Loss: 0.19, Time: 2.92s
Iteration:   1848, Exp. Loss: 0.19, Time: 2.83s
Iteration:   1849, Exp. Loss: 0.18, Time: 2.71s
Iteration:   1850, Exp. Loss: 0.16, Time: 3.08s
Iteration:   1851, Exp. Loss: 0.15, Time: 3.21s
Iteration:   1852, Exp. Loss: 0.17, Time: 3.00s
Iteration:   1853, Exp. Loss: 0.18, Time: 3.32s
Iteration:   1854, Exp. Loss: 0.15, Time: 3.23s
Iteration:   1855, Exp. Loss: 0.18, Time: 2.86s
Iteration:   1856, Exp. Loss: 0.15, Time: 2.96s
Iteration:   1857, Exp. Loss: 0.22, Time: 3.19s
Iteration:   1858, Exp. Loss: 0.18, Time: 3.11s
Iteration:   1859, Exp. Loss: 0.19, Time: 3.19s
Iteration:   1860, Exp. Loss: 0.21, Time: 3.36s
Iteration:   1861, Exp. Loss: 0.18, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   1970, Exp. Loss: 0.23, Time: 3.18s
Iteration:   1971, Exp. Loss: 0.15, Time: 3.09s
Iteration:   1972, Exp. Loss: 0.16, Time: 3.17s
Iteration:   1973, Exp. Loss: 0.22, Time: 2.88s
Iteration:   1974, Exp. Loss: 0.16, Time: 3.14s
Iteration:   1975, Exp. Loss: 0.24, Time: 3.07s
Iteration:   1976, Exp. Loss: 0.17, Time: 3.09s
Iteration:   1977, Exp. Loss: 0.22, Time: 3.06s
Iteration:   1978, Exp. Loss: 0.20, Time: 2.68s
Iteration:   1979, Exp. Loss: 0.22, Time: 3.02s
Iteration:   1980, Exp. Loss: 0.22, Time: 2.72s
Iteration:   1981, Exp. Loss: 0.19, Time: 3.08s
Iteration:   1982, Exp. Loss: 0.15, Time: 3.25s
Iteration:   1983, Exp. Loss: 0.24, Time: 3.08s
Iteration:   1984, Exp. Loss: 0.13, Time: 3.21s
Iteration:   1985, Exp. Loss: 0.24, Time: 2.94s
Iteration:   1986, Exp. Loss: 0.21, Time: 3.46s
Iteration:   1987, Exp. Loss: 0.19, Time: 3.25s
Iteration:   1988, Exp. Loss: 0.16, Time: 2.85s
Iteration:   1989, Exp. Loss: 0.22, Time: 3.14s
Iteration:   1990, Exp. Loss: 0.19, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   2267, Exp. Loss: 0.17, Time: 3.27s
Iteration:   2268, Exp. Loss: 0.15, Time: 3.00s
Iteration:   2269, Exp. Loss: 0.24, Time: 3.01s
Iteration:   2270, Exp. Loss: 0.14, Time: 3.34s
Iteration:   2271, Exp. Loss: 0.15, Time: 3.32s
Iteration:   2272, Exp. Loss: 0.12, Time: 3.04s
Iteration:   2273, Exp. Loss: 0.16, Time: 2.50s
Iteration:   2274, Exp. Loss: 0.19, Time: 3.44s
Iteration:   2275, Exp. Loss: 0.19, Time: 2.64s
Iteration:   2276, Exp. Loss: 0.18, Time: 2.51s
Iteration:   2277, Exp. Loss: 0.17, Time: 2.91s
Iteration:   2278, Exp. Loss: 0.18, Time: 2.92s
Iteration:   2279, Exp. Loss: 0.14, Time: 3.72s
Iteration:   2280, Exp. Loss: 0.14, Time: 3.10s
Iteration:   2281, Exp. Loss: 0.21, Time: 2.72s
Iteration:   2282, Exp. Loss: 0.14, Time: 2.70s
Iteration:   2283, Exp. Loss: 0.20, Time: 3.38s
Iteration:   2284, Exp. Loss: 0.16, Time: 3.27s
Iteration:   2285, Exp. Loss: 0.20, Time: 2.81s
Iteration:   2286, Exp. Loss: 0.16, Time: 2.77s
Iteration:   2287, Exp. Loss: 0.17, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   2297, Exp. Loss: 0.16, Time: 2.88s
Iteration:   2298, Exp. Loss: 0.23, Time: 3.82s
Iteration:   2299, Exp. Loss: 0.18, Time: 2.92s
Iteration:   2300, Exp. Loss: 0.24, Time: 2.97s
Iteration:   2301, Exp. Loss: 0.17, Time: 2.94s
Iteration:   2302, Exp. Loss: 0.20, Time: 2.53s
Iteration:   2303, Exp. Loss: 0.17, Time: 2.98s
Iteration:   2304, Exp. Loss: 0.18, Time: 2.86s
Iteration:   2305, Exp. Loss: 0.16, Time: 3.05s
Iteration:   2306, Exp. Loss: 0.18, Time: 2.82s
Iteration:   2307, Exp. Loss: 0.15, Time: 3.58s
Iteration:   2308, Exp. Loss: 0.14, Time: 2.57s
Iteration:   2309, Exp. Loss: 0.16, Time: 3.19s
Iteration:   2310, Exp. Loss: 0.20, Time: 2.83s
Iteration:   2311, Exp. Loss: 0.24, Time: 3.44s
Iteration:   2312, Exp. Loss: 0.15, Time: 3.53s
Iteration:   2313, Exp. Loss: 0.20, Time: 2.77s
Iteration:   2314, Exp. Loss: 0.17, Time: 3.39s
Iteration:   2315, Exp. Loss: 0.17, Time: 2.95s
Iteration:   2316, Exp. Loss: 0.17, Time: 2.90s
Iteration:   2317, Exp. Loss: 0.21, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   2984, Exp. Loss: 0.17, Time: 2.85s
Iteration:   2985, Exp. Loss: 0.18, Time: 2.94s
Iteration:   2986, Exp. Loss: 0.17, Time: 2.83s
Iteration:   2987, Exp. Loss: 0.18, Time: 3.01s
Iteration:   2988, Exp. Loss: 0.19, Time: 3.01s
Iteration:   2989, Exp. Loss: 0.18, Time: 3.07s
Iteration:   2990, Exp. Loss: 0.15, Time: 3.29s
Iteration:   2991, Exp. Loss: 0.15, Time: 3.44s
Iteration:   2992, Exp. Loss: 0.21, Time: 2.77s
Iteration:   2993, Exp. Loss: 0.14, Time: 2.83s
Iteration:   2994, Exp. Loss: 0.19, Time: 3.08s
Iteration:   2995, Exp. Loss: 0.18, Time: 3.28s
Iteration:   2996, Exp. Loss: 0.17, Time: 2.93s
Iteration:   2997, Exp. Loss: 0.16, Time: 3.07s
Iteration:   2998, Exp. Loss: 0.15, Time: 2.98s
Iteration:   2999, Exp. Loss: 0.19, Time: 2.87s
Iteration:   3000, Exp. Loss: 0.19, Time: 3.46s
Iteration:   3001, Exp. Loss: 0.14, Time: 3.09s
Iteration:   3002, Exp. Loss: 0.13, Time: 3.03s
Iteration:   3003, Exp. Loss: 0.19, Time: 2.65s
Iteration:   3004, Exp. Loss: 0.16, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   3021, Exp. Loss: 0.20, Time: 2.80s
Iteration:   3022, Exp. Loss: 0.11, Time: 3.36s
Iteration:   3023, Exp. Loss: 0.16, Time: 3.27s
Iteration:   3024, Exp. Loss: 0.18, Time: 3.33s
Iteration:   3025, Exp. Loss: 0.12, Time: 2.92s
Iteration:   3026, Exp. Loss: 0.15, Time: 3.00s
Iteration:   3027, Exp. Loss: 0.19, Time: 3.30s
Iteration:   3028, Exp. Loss: 0.13, Time: 2.95s
Iteration:   3029, Exp. Loss: 0.23, Time: 2.80s
Iteration:   3030, Exp. Loss: 0.20, Time: 2.85s
Iteration:   3031, Exp. Loss: 0.19, Time: 2.66s
Iteration:   3032, Exp. Loss: 0.20, Time: 2.82s
Iteration:   3033, Exp. Loss: 0.19, Time: 3.16s
Iteration:   3034, Exp. Loss: 0.15, Time: 2.87s
Iteration:   3035, Exp. Loss: 0.20, Time: 3.11s
Iteration:   3036, Exp. Loss: 0.22, Time: 3.36s
Iteration:   3037, Exp. Loss: 0.15, Time: 3.23s
Iteration:   3038, Exp. Loss: 0.16, Time: 3.02s
Iteration:   3039, Exp. Loss: 0.12, Time: 3.31s
Iteration:   3040, Exp. Loss: 0.18, Time: 3.05s
Iteration:   3041, Exp. Loss: 0.16, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   3328, Exp. Loss: 0.19, Time: 2.98s
Iteration:   3329, Exp. Loss: 0.11, Time: 2.84s
Iteration:   3330, Exp. Loss: 0.17, Time: 2.50s
Iteration:   3331, Exp. Loss: 0.16, Time: 3.15s
Iteration:   3332, Exp. Loss: 0.15, Time: 3.03s
Iteration:   3333, Exp. Loss: 0.17, Time: 2.98s
Iteration:   3334, Exp. Loss: 0.16, Time: 3.48s
Iteration:   3335, Exp. Loss: 0.16, Time: 3.17s
Iteration:   3336, Exp. Loss: 0.20, Time: 2.82s
Iteration:   3337, Exp. Loss: 0.13, Time: 2.97s
Iteration:   3338, Exp. Loss: 0.22, Time: 2.87s
Iteration:   3339, Exp. Loss: 0.17, Time: 2.87s
Iteration:   3340, Exp. Loss: 0.18, Time: 3.08s
Iteration:   3341, Exp. Loss: 0.18, Time: 2.97s
Iteration:   3342, Exp. Loss: 0.19, Time: 2.57s
Iteration:   3343, Exp. Loss: 0.17, Time: 3.30s
Iteration:   3344, Exp. Loss: 0.14, Time: 3.17s
Iteration:   3345, Exp. Loss: 0.19, Time: 3.44s
Iteration:   3346, Exp. Loss: 0.16, Time: 2.70s
Iteration:   3347, Exp. Loss: 0.15, Time: 2.47s
Iteration:   3348, Exp. Loss: 0.16, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   3589, Exp. Loss: 0.21, Time: 3.18s
Iteration:   3590, Exp. Loss: 0.22, Time: 3.49s
Iteration:   3591, Exp. Loss: 0.14, Time: 3.01s
Iteration:   3592, Exp. Loss: 0.14, Time: 3.30s
Iteration:   3593, Exp. Loss: 0.18, Time: 2.87s
Iteration:   3594, Exp. Loss: 0.15, Time: 3.50s
Iteration:   3595, Exp. Loss: 0.16, Time: 3.10s
Iteration:   3596, Exp. Loss: 0.15, Time: 2.86s
Iteration:   3597, Exp. Loss: 0.17, Time: 3.12s
Iteration:   3598, Exp. Loss: 0.16, Time: 3.07s
Iteration:   3599, Exp. Loss: 0.13, Time: 3.08s
Iteration:   3600, Exp. Loss: 0.18, Time: 3.33s
Iteration:   3601, Exp. Loss: 0.28, Time: 3.00s
Iteration:   3602, Exp. Loss: 0.14, Time: 2.94s
Iteration:   3603, Exp. Loss: 0.18, Time: 3.02s
Iteration:   3604, Exp. Loss: 0.16, Time: 3.05s
Iteration:   3605, Exp. Loss: 0.14, Time: 2.95s
Iteration:   3606, Exp. Loss: 0.17, Time: 2.49s
Iteration:   3607, Exp. Loss: 0.16, Time: 3.11s
Iteration:   3608, Exp. Loss: 0.22, Time: 3.26s
Iteration:   3609, Exp. Loss: 0.15, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   4138, Exp. Loss: 0.16, Time: 3.35s
Iteration:   4139, Exp. Loss: 0.15, Time: 3.00s
Iteration:   4140, Exp. Loss: 0.20, Time: 3.32s
Iteration:   4141, Exp. Loss: 0.12, Time: 2.89s


  " Skipping tag %s" % (size, len(data), tag)


Iteration:   4142, Exp. Loss: 0.13, Time: 3.26s
Iteration:   4143, Exp. Loss: 0.14, Time: 2.90s
Iteration:   4144, Exp. Loss: 0.15, Time: 3.24s
Iteration:   4145, Exp. Loss: 0.12, Time: 3.04s
Iteration:   4146, Exp. Loss: 0.17, Time: 3.51s
Iteration:   4147, Exp. Loss: 0.17, Time: 2.99s
Iteration:   4148, Exp. Loss: 0.16, Time: 3.03s
Iteration:   4149, Exp. Loss: 0.16, Time: 2.77s
Iteration:   4150, Exp. Loss: 0.12, Time: 2.97s
Iteration:   4151, Exp. Loss: 0.12, Time: 3.09s
Iteration:   4152, Exp. Loss: 0.11, Time: 3.45s
Iteration:   4153, Exp. Loss: 0.15, Time: 3.22s
Iteration:   4154, Exp. Loss: 0.15, Time: 3.11s
Iteration:   4155, Exp. Loss: 0.17, Time: 3.48s
Iteration:   4156, Exp. Loss: 0.11, Time: 3.06s
Iteration:   4157, Exp. Loss: 0.13, Time: 3.00s
Iteration:   4158, Exp. Loss: 0.15, Time: 3.34s
Iteration:   4159, Exp. Loss: 0.13, Time: 3.33s
Iteration:   4160, Exp. Loss: 0.14, Time: 2.88s
Iteration:   4161, Exp. Loss: 0.18, Time: 3.36s
Iteration:   4162, Exp. Loss: 0.18, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   4409, Exp. Loss: 0.44, Time: 1.97s
Iteration:   4410, Exp. Loss: 0.48, Time: 2.37s
Iteration:   4411, Exp. Loss: 0.40, Time: 2.49s
Iteration:   4412, Exp. Loss: 0.44, Time: 2.86s
Iteration:   4413, Exp. Loss: 0.32, Time: 3.44s
Iteration:   4414, Exp. Loss: 0.39, Time: 3.10s
Iteration:   4415, Exp. Loss: 0.28, Time: 1.92s
Iteration:   4416, Exp. Loss: 0.30, Time: 2.52s
Iteration:   4417, Exp. Loss: 0.33, Time: 2.20s
Iteration:   4418, Exp. Loss: 0.33, Time: 2.61s
Iteration:   4419, Exp. Loss: 0.29, Time: 2.51s
Iteration:   4420, Exp. Loss: 0.34, Time: 3.00s
Iteration:   4421, Exp. Loss: 0.37, Time: 2.72s
Iteration:   4422, Exp. Loss: 0.41, Time: 2.81s
Iteration:   4423, Exp. Loss: 0.29, Time: 3.34s
Iteration:   4424, Exp. Loss: 0.30, Time: 2.43s
Iteration:   4425, Exp. Loss: 0.25, Time: 2.75s
Iteration:   4426, Exp. Loss: 0.20, Time: 2.65s
Iteration:   4427, Exp. Loss: 0.27, Time: 2.58s
Iteration:   4428, Exp. Loss: 0.19, Time: 2.69s
Iteration:   4429, Exp. Loss: 0.25, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   4491, Exp. Loss: 0.16, Time: 3.14s
Iteration:   4492, Exp. Loss: 0.18, Time: 2.79s
Iteration:   4493, Exp. Loss: 0.15, Time: 2.83s
Iteration:   4494, Exp. Loss: 0.14, Time: 2.81s
Iteration:   4495, Exp. Loss: 0.22, Time: 3.38s
Iteration:   4496, Exp. Loss: 0.14, Time: 3.11s
Iteration:   4497, Exp. Loss: 0.16, Time: 2.93s
Iteration:   4498, Exp. Loss: 0.13, Time: 3.11s
Iteration:   4499, Exp. Loss: 0.19, Time: 3.02s
Iteration:   4500, Exp. Loss: 0.12, Time: 2.82s
Iteration:   4501, Exp. Loss: 0.17, Time: 3.16s
Iteration:   4502, Exp. Loss: 0.15, Time: 2.83s
Iteration:   4503, Exp. Loss: 0.15, Time: 3.16s
Iteration:   4504, Exp. Loss: 0.15, Time: 2.88s
Iteration:   4505, Exp. Loss: 0.12, Time: 2.80s
Iteration:   4506, Exp. Loss: 0.14, Time: 2.71s
Iteration:   4507, Exp. Loss: 0.19, Time: 3.54s
Iteration:   4508, Exp. Loss: 0.14, Time: 3.17s
Iteration:   4509, Exp. Loss: 0.16, Time: 2.61s
Iteration:   4510, Exp. Loss: 0.21, Time: 3.14s
Iteration:   4511, Exp. Loss: 0.19, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   4795, Exp. Loss: 0.13, Time: 3.13s
Iteration:   4796, Exp. Loss: 0.15, Time: 3.10s
Iteration:   4797, Exp. Loss: 0.17, Time: 2.39s
Iteration:   4798, Exp. Loss: 0.16, Time: 3.22s
Iteration:   4799, Exp. Loss: 0.09, Time: 3.86s
Iteration:   4800, Exp. Loss: 0.17, Time: 2.85s
Iteration:   4801, Exp. Loss: 0.17, Time: 2.71s
Iteration:   4802, Exp. Loss: 0.14, Time: 2.66s
Iteration:   4803, Exp. Loss: 0.13, Time: 2.98s
Iteration:   4804, Exp. Loss: 0.14, Time: 2.59s
Iteration:   4805, Exp. Loss: 0.15, Time: 2.98s
Iteration:   4806, Exp. Loss: 0.18, Time: 2.59s
Iteration:   4807, Exp. Loss: 0.15, Time: 3.08s
Iteration:   4808, Exp. Loss: 0.12, Time: 3.52s
Iteration:   4809, Exp. Loss: 0.14, Time: 3.10s
Iteration:   4810, Exp. Loss: 0.16, Time: 3.01s
Iteration:   4811, Exp. Loss: 0.15, Time: 2.75s
Iteration:   4812, Exp. Loss: 0.21, Time: 3.19s
Iteration:   4813, Exp. Loss: 0.14, Time: 2.82s
Iteration:   4814, Exp. Loss: 0.26, Time: 2.61s
Iteration:   4815, Exp. Loss: 0.12, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   4847, Exp. Loss: 0.21, Time: 2.67s
Iteration:   4848, Exp. Loss: 0.21, Time: 3.00s
Iteration:   4849, Exp. Loss: 0.13, Time: 2.96s
Iteration:   4850, Exp. Loss: 0.14, Time: 2.78s
Iteration:   4851, Exp. Loss: 0.11, Time: 3.27s
Iteration:   4852, Exp. Loss: 0.16, Time: 3.25s
Iteration:   4853, Exp. Loss: 0.14, Time: 2.79s
Iteration:   4854, Exp. Loss: 0.14, Time: 3.14s
Iteration:   4855, Exp. Loss: 0.14, Time: 2.71s
Iteration:   4856, Exp. Loss: 0.13, Time: 2.80s
Iteration:   4857, Exp. Loss: 0.19, Time: 2.97s
Iteration:   4858, Exp. Loss: 0.12, Time: 3.81s
Iteration:   4859, Exp. Loss: 0.12, Time: 2.50s
Iteration:   4860, Exp. Loss: 0.11, Time: 3.12s
Iteration:   4861, Exp. Loss: 0.22, Time: 2.59s
Iteration:   4862, Exp. Loss: 0.16, Time: 3.00s
Iteration:   4863, Exp. Loss: 0.16, Time: 3.17s
Iteration:   4864, Exp. Loss: 0.10, Time: 2.89s
Iteration:   4865, Exp. Loss: 0.16, Time: 2.65s
Iteration:   4866, Exp. Loss: 0.11, Time: 2.85s
Iteration:   4867, Exp. Loss: 0.15, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   5572, Exp. Loss: 0.18, Time: 2.52s
Iteration:   5573, Exp. Loss: 0.17, Time: 3.04s
Iteration:   5574, Exp. Loss: 0.15, Time: 3.24s
Iteration:   5575, Exp. Loss: 0.17, Time: 2.89s
Iteration:   5576, Exp. Loss: 0.15, Time: 3.27s
Iteration:   5577, Exp. Loss: 0.15, Time: 3.26s
Iteration:   5578, Exp. Loss: 0.19, Time: 3.07s
Iteration:   5579, Exp. Loss: 0.19, Time: 2.88s
Iteration:   5580, Exp. Loss: 0.15, Time: 2.99s
Iteration:   5581, Exp. Loss: 0.14, Time: 3.05s
Iteration:   5582, Exp. Loss: 0.14, Time: 2.89s
Iteration:   5583, Exp. Loss: 0.21, Time: 2.82s
Iteration:   5584, Exp. Loss: 0.11, Time: 3.08s
Iteration:   5585, Exp. Loss: 0.12, Time: 3.05s
Iteration:   5586, Exp. Loss: 0.13, Time: 3.10s
Iteration:   5587, Exp. Loss: 0.14, Time: 3.03s
Iteration:   5588, Exp. Loss: 0.17, Time: 3.07s
Iteration:   5589, Exp. Loss: 0.11, Time: 2.71s
Iteration:   5590, Exp. Loss: 0.15, Time: 2.84s
Iteration:   5591, Exp. Loss: 0.18, Time: 2.48s
Iteration:   5592, Exp. Loss: 0.13, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   5728, Exp. Loss: 0.16, Time: 3.46s
Iteration:   5729, Exp. Loss: 0.17, Time: 3.28s
Iteration:   5730, Exp. Loss: 0.14, Time: 3.48s
Iteration:   5731, Exp. Loss: 0.16, Time: 2.77s
Iteration:   5732, Exp. Loss: 0.15, Time: 3.10s
Iteration:   5733, Exp. Loss: 0.14, Time: 2.99s
Iteration:   5734, Exp. Loss: 0.10, Time: 2.83s
Iteration:   5735, Exp. Loss: 0.17, Time: 2.97s
Iteration:   5736, Exp. Loss: 0.13, Time: 2.94s
Iteration:   5737, Exp. Loss: 0.21, Time: 2.68s
Iteration:   5738, Exp. Loss: 0.20, Time: 3.08s
Iteration:   5739, Exp. Loss: 0.14, Time: 2.78s
Iteration:   5740, Exp. Loss: 0.18, Time: 2.98s
Iteration:   5741, Exp. Loss: 0.13, Time: 2.80s
Iteration:   5742, Exp. Loss: 0.21, Time: 2.98s
Iteration:   5743, Exp. Loss: 0.14, Time: 2.92s
Iteration:   5744, Exp. Loss: 0.16, Time: 2.97s
Iteration:   5745, Exp. Loss: 0.16, Time: 3.00s
Iteration:   5746, Exp. Loss: 0.11, Time: 3.29s
Iteration:   5747, Exp. Loss: 0.19, Time: 3.36s
Iteration:   5748, Exp. Loss: 0.16, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   6059, Exp. Loss: 0.13, Time: 3.06s
Iteration:   6060, Exp. Loss: 0.11, Time: 3.33s
Iteration:   6061, Exp. Loss: 0.13, Time: 3.19s
Iteration:   6062, Exp. Loss: 0.19, Time: 3.30s
Iteration:   6063, Exp. Loss: 0.12, Time: 2.70s
Iteration:   6064, Exp. Loss: 0.14, Time: 2.99s
Iteration:   6065, Exp. Loss: 0.24, Time: 2.95s
Iteration:   6066, Exp. Loss: 0.13, Time: 2.90s
Iteration:   6067, Exp. Loss: 0.15, Time: 2.82s
Iteration:   6068, Exp. Loss: 0.17, Time: 3.15s
Iteration:   6069, Exp. Loss: 0.16, Time: 3.04s
Iteration:   6070, Exp. Loss: 0.14, Time: 2.98s
Iteration:   6071, Exp. Loss: 0.13, Time: 3.20s
Iteration:   6072, Exp. Loss: 0.14, Time: 3.03s
Iteration:   6073, Exp. Loss: 0.16, Time: 2.85s
Iteration:   6074, Exp. Loss: 0.15, Time: 3.61s
Iteration:   6075, Exp. Loss: 0.13, Time: 3.03s
Iteration:   6076, Exp. Loss: 0.17, Time: 2.84s
Iteration:   6077, Exp. Loss: 0.18, Time: 2.88s
Iteration:   6078, Exp. Loss: 0.13, Time: 3.44s
Iteration:   6079, Exp. Loss: 0.13, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   6273, Exp. Loss: 0.15, Time: 2.98s
Iteration:   6274, Exp. Loss: 0.12, Time: 3.16s
Iteration:   6275, Exp. Loss: 0.11, Time: 2.91s
Iteration:   6276, Exp. Loss: 0.13, Time: 2.82s
Iteration:   6277, Exp. Loss: 0.13, Time: 2.98s
Iteration:   6278, Exp. Loss: 0.11, Time: 2.99s
Iteration:   6279, Exp. Loss: 0.09, Time: 3.14s
Iteration:   6280, Exp. Loss: 0.15, Time: 3.18s
Iteration:   6281, Exp. Loss: 0.17, Time: 3.27s
Iteration:   6282, Exp. Loss: 0.14, Time: 3.02s
Iteration:   6283, Exp. Loss: 0.11, Time: 3.29s
Iteration:   6284, Exp. Loss: 0.14, Time: 3.27s
Iteration:   6285, Exp. Loss: 0.14, Time: 3.34s
Iteration:   6286, Exp. Loss: 0.16, Time: 3.21s
Iteration:   6287, Exp. Loss: 0.15, Time: 3.06s
Iteration:   6288, Exp. Loss: 0.14, Time: 3.18s
Iteration:   6289, Exp. Loss: 0.14, Time: 3.04s
Iteration:   6290, Exp. Loss: 0.14, Time: 2.76s
Iteration:   6291, Exp. Loss: 0.12, Time: 3.04s
Iteration:   6292, Exp. Loss: 0.11, Time: 3.25s
Iteration:   6293, Exp. Loss: 0.15, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   6348, Exp. Loss: 0.16, Time: 10.69s
Iteration:   6349, Exp. Loss: 0.13, Time: 8.86s
Iteration:   6350, Exp. Loss: 0.12, Time: 4.22s
Iteration:   6351, Exp. Loss: 0.13, Time: 3.66s
Iteration:   6352, Exp. Loss: 0.14, Time: 3.91s
Iteration:   6353, Exp. Loss: 0.15, Time: 3.95s
Iteration:   6354, Exp. Loss: 0.12, Time: 2.64s
Iteration:   6355, Exp. Loss: 0.15, Time: 2.38s
Iteration:   6356, Exp. Loss: 0.16, Time: 2.91s
Iteration:   6357, Exp. Loss: 0.14, Time: 3.16s
Iteration:   6358, Exp. Loss: 0.13, Time: 3.34s
Iteration:   6359, Exp. Loss: 0.15, Time: 3.79s
Iteration:   6360, Exp. Loss: 0.12, Time: 2.41s
Iteration:   6361, Exp. Loss: 0.13, Time: 2.38s


  " Skipping tag %s" % (size, len(data), tag)


Iteration:   6362, Exp. Loss: 0.09, Time: 2.55s
Iteration:   6363, Exp. Loss: 0.13, Time: 3.19s
Iteration:   6364, Exp. Loss: 0.14, Time: 3.01s
Iteration:   6365, Exp. Loss: 0.13, Time: 3.57s
Iteration:   6366, Exp. Loss: 0.17, Time: 2.88s
Iteration:   6367, Exp. Loss: 0.11, Time: 3.30s
Iteration:   6368, Exp. Loss: 0.13, Time: 2.87s
Iteration:   6369, Exp. Loss: 0.18, Time: 2.41s
Iteration:   6370, Exp. Loss: 0.10, Time: 2.99s
Iteration:   6371, Exp. Loss: 0.10, Time: 3.19s
Iteration:   6372, Exp. Loss: 0.12, Time: 2.97s
Iteration:   6373, Exp. Loss: 0.13, Time: 2.58s
Iteration:   6374, Exp. Loss: 0.11, Time: 2.87s
Iteration:   6375, Exp. Loss: 0.13, Time: 3.05s
Iteration:   6376, Exp. Loss: 0.14, Time: 2.43s
Iteration:   6377, Exp. Loss: 0.13, Time: 2.50s
Iteration:   6378, Exp. Loss: 0.12, Time: 2.78s
Iteration:   6379, Exp. Loss: 0.11, Time: 3.23s
Iteration:   6380, Exp. Loss: 0.16, Time: 3.63s
Iteration:   6381, Exp. Loss: 0.08, Time: 3.41s
Iteration:   6382, Exp. Loss: 0.13, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   6887, Exp. Loss: 0.19, Time: 3.09s
Iteration:   6888, Exp. Loss: 0.14, Time: 2.26s
Iteration:   6889, Exp. Loss: 0.11, Time: 2.51s
Iteration:   6890, Exp. Loss: 0.10, Time: 3.26s
Iteration:   6891, Exp. Loss: 0.10, Time: 3.01s
Iteration:   6892, Exp. Loss: 0.12, Time: 3.02s
Iteration:   6893, Exp. Loss: 0.11, Time: 4.24s
Iteration:   6894, Exp. Loss: 0.19, Time: 3.03s
Iteration:   6895, Exp. Loss: 0.11, Time: 2.39s
Iteration:   6896, Exp. Loss: 0.09, Time: 2.74s
Iteration:   6897, Exp. Loss: 0.14, Time: 2.76s
Iteration:   6898, Exp. Loss: 0.09, Time: 2.79s
Iteration:   6899, Exp. Loss: 0.12, Time: 3.54s
Iteration:   6900, Exp. Loss: 0.08, Time: 3.50s
Iteration:   6901, Exp. Loss: 0.08, Time: 3.21s
Iteration:   6902, Exp. Loss: 0.10, Time: 3.12s
Iteration:   6903, Exp. Loss: 0.16, Time: 3.08s
Iteration:   6904, Exp. Loss: 0.10, Time: 2.93s
Iteration:   6905, Exp. Loss: 0.11, Time: 3.06s
Iteration:   6906, Exp. Loss: 0.19, Time: 3.21s
Iteration:   6907, Exp. Loss: 0.12, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   7014, Exp. Loss: 0.12, Time: 3.52s
Iteration:   7015, Exp. Loss: 0.14, Time: 3.15s
Iteration:   7016, Exp. Loss: 0.10, Time: 3.17s
Iteration:   7017, Exp. Loss: 0.15, Time: 3.53s
Iteration:   7018, Exp. Loss: 0.11, Time: 3.02s
Iteration:   7019, Exp. Loss: 0.12, Time: 2.80s
Iteration:   7020, Exp. Loss: 0.14, Time: 2.70s
Iteration:   7021, Exp. Loss: 0.13, Time: 3.11s
Iteration:   7022, Exp. Loss: 0.09, Time: 2.96s
Iteration:   7023, Exp. Loss: 0.10, Time: 2.84s
Iteration:   7024, Exp. Loss: 0.13, Time: 2.71s
Iteration:   7025, Exp. Loss: 0.15, Time: 3.53s
Iteration:   7026, Exp. Loss: 0.12, Time: 2.76s
Iteration:   7027, Exp. Loss: 0.10, Time: 2.99s
Iteration:   7028, Exp. Loss: 0.15, Time: 3.15s
Iteration:   7029, Exp. Loss: 0.16, Time: 3.07s
Iteration:   7030, Exp. Loss: 0.07, Time: 2.95s
Iteration:   7031, Exp. Loss: 0.13, Time: 2.92s
Iteration:   7032, Exp. Loss: 0.16, Time: 3.07s
Iteration:   7033, Exp. Loss: 0.16, Time: 3.13s
Iteration:   7034, Exp. Loss: 0.15, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   7423, Exp. Loss: 0.13, Time: 2.63s
Iteration:   7424, Exp. Loss: 0.10, Time: 3.30s
Iteration:   7425, Exp. Loss: 0.14, Time: 2.77s
Iteration:   7426, Exp. Loss: 0.12, Time: 3.42s
Iteration:   7427, Exp. Loss: 0.15, Time: 3.20s
Iteration:   7428, Exp. Loss: 0.06, Time: 2.81s
Iteration:   7429, Exp. Loss: 0.13, Time: 2.68s
Iteration:   7430, Exp. Loss: 0.11, Time: 3.44s
Iteration:   7431, Exp. Loss: 0.19, Time: 2.79s
Iteration:   7432, Exp. Loss: 0.17, Time: 2.80s
Iteration:   7433, Exp. Loss: 0.11, Time: 3.27s
Iteration:   7434, Exp. Loss: 0.08, Time: 2.60s
Iteration:   7435, Exp. Loss: 0.09, Time: 2.80s
Iteration:   7436, Exp. Loss: 0.16, Time: 3.11s
Iteration:   7437, Exp. Loss: 0.10, Time: 3.10s
Iteration:   7438, Exp. Loss: 0.08, Time: 2.27s
Iteration:   7439, Exp. Loss: 0.10, Time: 3.40s
Iteration:   7440, Exp. Loss: 0.18, Time: 2.57s
Iteration:   7441, Exp. Loss: 0.11, Time: 3.88s
Iteration:   7442, Exp. Loss: 0.17, Time: 2.42s
Iteration:   7443, Exp. Loss: 0.08, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   7559, Exp. Loss: 0.13, Time: 3.19s
Iteration:   7560, Exp. Loss: 0.14, Time: 2.80s
Iteration:   7561, Exp. Loss: 0.10, Time: 3.03s
Iteration:   7562, Exp. Loss: 0.13, Time: 3.32s
Iteration:   7563, Exp. Loss: 0.13, Time: 2.86s
Iteration:   7564, Exp. Loss: 0.08, Time: 2.90s
Iteration:   7565, Exp. Loss: 0.11, Time: 3.36s
Iteration:   7566, Exp. Loss: 0.12, Time: 3.31s
Iteration:   7567, Exp. Loss: 0.10, Time: 3.30s
Iteration:   7568, Exp. Loss: 0.12, Time: 2.46s
Iteration:   7569, Exp. Loss: 0.18, Time: 2.63s
Iteration:   7570, Exp. Loss: 0.13, Time: 2.59s
Iteration:   7571, Exp. Loss: 0.11, Time: 2.65s
Iteration:   7572, Exp. Loss: 0.07, Time: 2.93s
Iteration:   7573, Exp. Loss: 0.09, Time: 2.99s
Iteration:   7574, Exp. Loss: 0.13, Time: 3.75s
Iteration:   7575, Exp. Loss: 0.11, Time: 3.04s
Iteration:   7576, Exp. Loss: 0.15, Time: 3.36s
Iteration:   7577, Exp. Loss: 0.11, Time: 2.42s
Iteration:   7578, Exp. Loss: 0.08, Time: 2.84s
Iteration:   7579, Exp. Loss: 0.10, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   7989, Exp. Loss: 0.17, Time: 2.31s
Iteration:   7990, Exp. Loss: 0.09, Time: 3.22s
Iteration:   7991, Exp. Loss: 0.08, Time: 3.03s
Iteration:   7992, Exp. Loss: 0.14, Time: 3.16s
Iteration:   7993, Exp. Loss: 0.12, Time: 2.77s
Iteration:   7994, Exp. Loss: 0.12, Time: 2.72s
Iteration:   7995, Exp. Loss: 0.12, Time: 3.16s
Iteration:   7996, Exp. Loss: 0.12, Time: 3.08s
Iteration:   7997, Exp. Loss: 0.07, Time: 3.54s
Iteration:   7998, Exp. Loss: 0.10, Time: 3.21s
Iteration:   7999, Exp. Loss: 0.10, Time: 2.42s
Iteration:   8000, Exp. Loss: 0.15, Time: 2.71s
Iteration:   8001, Exp. Loss: 0.16, Time: 2.81s
Iteration:   8002, Exp. Loss: 0.09, Time: 2.93s
Iteration:   8003, Exp. Loss: 0.17, Time: 2.61s
Iteration:   8004, Exp. Loss: 0.17, Time: 3.32s
Iteration:   8005, Exp. Loss: 0.12, Time: 2.85s
Iteration:   8006, Exp. Loss: 0.14, Time: 2.76s
Iteration:   8007, Exp. Loss: 0.14, Time: 2.74s
Iteration:   8008, Exp. Loss: 0.10, Time: 2.79s
Iteration:   8009, Exp. Loss: 0.09, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   8195, Exp. Loss: 0.08, Time: 3.02s
Iteration:   8196, Exp. Loss: 0.11, Time: 2.84s
Iteration:   8197, Exp. Loss: 0.15, Time: 3.40s
Iteration:   8198, Exp. Loss: 0.11, Time: 2.90s
Iteration:   8199, Exp. Loss: 0.10, Time: 3.33s
Iteration:   8200, Exp. Loss: 0.12, Time: 2.65s
Iteration:   8201, Exp. Loss: 0.13, Time: 2.93s
Iteration:   8202, Exp. Loss: 0.10, Time: 2.88s
Iteration:   8203, Exp. Loss: 0.07, Time: 3.67s
Iteration:   8204, Exp. Loss: 0.12, Time: 3.40s
Iteration:   8205, Exp. Loss: 0.12, Time: 3.24s
Iteration:   8206, Exp. Loss: 0.12, Time: 3.10s
Iteration:   8207, Exp. Loss: 0.12, Time: 2.82s
Iteration:   8208, Exp. Loss: 0.13, Time: 2.58s
Iteration:   8209, Exp. Loss: 0.10, Time: 2.71s
Iteration:   8210, Exp. Loss: 0.16, Time: 3.16s
Iteration:   8211, Exp. Loss: 0.16, Time: 2.86s
Iteration:   8212, Exp. Loss: 0.17, Time: 2.63s
Iteration:   8213, Exp. Loss: 0.11, Time: 3.39s
Iteration:   8214, Exp. Loss: 0.09, Time: 2.94s
Iteration:   8215, Exp. Loss: 0.10, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   8836, Exp. Loss: 0.11, Time: 2.99s
Iteration:   8837, Exp. Loss: 0.13, Time: 3.06s
Iteration:   8838, Exp. Loss: 0.13, Time: 3.02s
Iteration:   8839, Exp. Loss: 0.15, Time: 2.93s
Iteration:   8840, Exp. Loss: 0.09, Time: 3.38s
Iteration:   8841, Exp. Loss: 0.10, Time: 2.98s
Iteration:   8842, Exp. Loss: 0.12, Time: 3.18s
Iteration:   8843, Exp. Loss: 0.09, Time: 3.23s
Iteration:   8844, Exp. Loss: 0.09, Time: 2.99s
Iteration:   8845, Exp. Loss: 0.13, Time: 2.70s
Iteration:   8846, Exp. Loss: 0.06, Time: 2.90s
Iteration:   8847, Exp. Loss: 0.08, Time: 3.29s
Iteration:   8848, Exp. Loss: 0.10, Time: 3.05s
Iteration:   8849, Exp. Loss: 0.15, Time: 2.94s
Iteration:   8850, Exp. Loss: 0.10, Time: 3.28s
Iteration:   8851, Exp. Loss: 0.10, Time: 3.19s
Iteration:   8852, Exp. Loss: 0.12, Time: 3.18s
Iteration:   8853, Exp. Loss: 0.12, Time: 3.13s
Iteration:   8854, Exp. Loss: 0.09, Time: 3.44s
Iteration:   8855, Exp. Loss: 0.09, Time: 2.99s
Iteration:   8856, Exp. Loss: 0.09, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   8880, Exp. Loss: 0.11, Time: 3.19s
Iteration:   8881, Exp. Loss: 0.12, Time: 2.77s
Iteration:   8882, Exp. Loss: 0.11, Time: 3.12s
Iteration:   8883, Exp. Loss: 0.07, Time: 3.04s
Iteration:   8884, Exp. Loss: 0.13, Time: 2.88s
Iteration:   8885, Exp. Loss: 0.12, Time: 3.21s
Iteration:   8886, Exp. Loss: 0.09, Time: 3.07s
Iteration:   8887, Exp. Loss: 0.11, Time: 3.43s
Iteration:   8888, Exp. Loss: 0.09, Time: 3.00s
Iteration:   8889, Exp. Loss: 0.16, Time: 2.85s
Iteration:   8890, Exp. Loss: 0.09, Time: 3.11s
Iteration:   8891, Exp. Loss: 0.20, Time: 3.25s
Iteration:   8892, Exp. Loss: 0.07, Time: 2.90s
Iteration:   8893, Exp. Loss: 0.10, Time: 3.16s
Iteration:   8894, Exp. Loss: 0.10, Time: 3.00s
Iteration:   8895, Exp. Loss: 0.10, Time: 3.25s
Iteration:   8896, Exp. Loss: 0.10, Time: 2.96s
Iteration:   8897, Exp. Loss: 0.09, Time: 2.88s
Iteration:   8898, Exp. Loss: 0.13, Time: 3.08s
Iteration:   8899, Exp. Loss: 0.10, Time: 3.36s
Iteration:   8900, Exp. Loss: 0.11, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   9003, Exp. Loss: 0.10, Time: 2.67s
Iteration:   9004, Exp. Loss: 0.13, Time: 3.64s
Iteration:   9005, Exp. Loss: 0.12, Time: 3.04s
Iteration:   9006, Exp. Loss: 0.15, Time: 2.72s
Iteration:   9007, Exp. Loss: 0.09, Time: 3.00s
Iteration:   9008, Exp. Loss: 0.09, Time: 3.58s
Iteration:   9009, Exp. Loss: 0.10, Time: 2.33s
Iteration:   9010, Exp. Loss: 0.11, Time: 2.59s
Iteration:   9011, Exp. Loss: 0.14, Time: 2.93s
Iteration:   9012, Exp. Loss: 0.07, Time: 2.94s
Iteration:   9013, Exp. Loss: 0.08, Time: 3.49s
Iteration:   9014, Exp. Loss: 0.12, Time: 2.88s
Iteration:   9015, Exp. Loss: 0.11, Time: 2.97s
Iteration:   9016, Exp. Loss: 0.10, Time: 3.27s
Iteration:   9017, Exp. Loss: 0.10, Time: 2.51s
Iteration:   9018, Exp. Loss: 0.11, Time: 3.08s
Iteration:   9019, Exp. Loss: 0.14, Time: 2.52s
Iteration:   9020, Exp. Loss: 0.13, Time: 3.25s
Iteration:   9021, Exp. Loss: 0.08, Time: 2.71s
Iteration:   9022, Exp. Loss: 0.13, Time: 3.22s
Iteration:   9023, Exp. Loss: 0.14, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   9217, Exp. Loss: 0.10, Time: 2.91s
Iteration:   9218, Exp. Loss: 0.15, Time: 2.70s
Iteration:   9219, Exp. Loss: 0.08, Time: 2.98s
Iteration:   9220, Exp. Loss: 0.13, Time: 3.10s
Iteration:   9221, Exp. Loss: 0.11, Time: 2.82s
Iteration:   9222, Exp. Loss: 0.11, Time: 3.28s
Iteration:   9223, Exp. Loss: 0.08, Time: 3.09s
Iteration:   9224, Exp. Loss: 0.09, Time: 2.94s
Iteration:   9225, Exp. Loss: 0.14, Time: 3.43s
Iteration:   9226, Exp. Loss: 0.13, Time: 3.07s
Iteration:   9227, Exp. Loss: 0.09, Time: 3.44s
Iteration:   9228, Exp. Loss: 0.11, Time: 3.48s
Iteration:   9229, Exp. Loss: 0.10, Time: 3.11s
Iteration:   9230, Exp. Loss: 0.15, Time: 2.85s
Iteration:   9231, Exp. Loss: 0.08, Time: 2.86s
Iteration:   9232, Exp. Loss: 0.10, Time: 3.27s
Iteration:   9233, Exp. Loss: 0.10, Time: 3.52s
Iteration:   9234, Exp. Loss: 0.09, Time: 2.57s
Iteration:   9235, Exp. Loss: 0.11, Time: 3.11s
Iteration:   9236, Exp. Loss: 0.11, Time: 3.02s
Iteration:   9237, Exp. Loss: 0.10, Time

  " Skipping tag %s" % (size, len(data), tag)


Iteration:   9704, Exp. Loss: 0.13, Time: 3.81s
Iteration:   9705, Exp. Loss: 0.11, Time: 3.34s
Iteration:   9706, Exp. Loss: 0.09, Time: 2.97s
Iteration:   9707, Exp. Loss: 0.11, Time: 3.36s
Iteration:   9708, Exp. Loss: 0.11, Time: 3.47s
Iteration:   9709, Exp. Loss: 0.17, Time: 3.32s
Iteration:   9710, Exp. Loss: 0.09, Time: 3.05s
Iteration:   9711, Exp. Loss: 0.11, Time: 3.12s
Iteration:   9712, Exp. Loss: 0.08, Time: 2.96s
Iteration:   9713, Exp. Loss: 0.11, Time: 2.84s
Iteration:   9714, Exp. Loss: 0.15, Time: 3.16s
Iteration:   9715, Exp. Loss: 0.11, Time: 3.22s
Iteration:   9716, Exp. Loss: 0.09, Time: 3.27s
Iteration:   9717, Exp. Loss: 0.13, Time: 2.81s
Iteration:   9718, Exp. Loss: 0.08, Time: 3.18s
Iteration:   9719, Exp. Loss: 0.17, Time: 3.29s
Iteration:   9720, Exp. Loss: 0.06, Time: 2.53s
Iteration:   9721, Exp. Loss: 0.13, Time: 3.45s
Iteration:   9722, Exp. Loss: 0.11, Time: 3.79s
Iteration:   9723, Exp. Loss: 0.10, Time: 3.21s
Iteration:   9724, Exp. Loss: 0.11, Time

  " Skipping tag %s" % (size, len(data), tag)
  " Skipping tag %s" % (size, len(data), tag)


Iteration:   9874, Exp. Loss: 0.07, Time: 3.14s
Iteration:   9875, Exp. Loss: 0.13, Time: 3.60s
Iteration:   9876, Exp. Loss: 0.10, Time: 2.99s
Iteration:   9877, Exp. Loss: 0.12, Time: 3.29s
Iteration:   9878, Exp. Loss: 0.12, Time: 3.80s
Iteration:   9879, Exp. Loss: 0.12, Time: 3.29s
Iteration:   9880, Exp. Loss: 0.07, Time: 3.13s
Iteration:   9881, Exp. Loss: 0.12, Time: 3.47s
Iteration:   9882, Exp. Loss: 0.09, Time: 3.12s
Iteration:   9883, Exp. Loss: 0.11, Time: 3.63s
Iteration:   9884, Exp. Loss: 0.15, Time: 3.49s
Iteration:   9885, Exp. Loss: 0.09, Time: 3.53s
Iteration:   9886, Exp. Loss: 0.11, Time: 3.09s
Iteration:   9887, Exp. Loss: 0.10, Time: 3.98s
Iteration:   9888, Exp. Loss: 0.10, Time: 3.35s
Iteration:   9889, Exp. Loss: 0.10, Time: 3.35s
Iteration:   9890, Exp. Loss: 0.07, Time: 3.73s
Iteration:   9891, Exp. Loss: 0.08, Time: 3.24s
Iteration:   9892, Exp. Loss: 0.10, Time: 3.61s
Iteration:   9893, Exp. Loss: 0.14, Time: 3.35s
Iteration:   9894, Exp. Loss: 0.12, Time

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [7]:
# @inproceedings{brachmann2019ngransac,
#   title={{N}eural- {G}uided {RANSAC}: {L}earning Where to Sample Model Hypotheses},
#   author={Brachmann, Eric and Rother, Carsten},
#   booktitle={ICCV},
#   year={2019}
# }

import torch
import numpy as np

from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/My Drive/Colab Notebooks/ngdsac_horizon/

from hlw_dataset import HLWDataset
from model import Model

import time
import argparse

from ngdsac import NGDSAC
from loss import Loss

# Command line interface of the original script, kept for reference only.
# In this notebook the options are hard-coded in the "Arguments" section below.

# parser = argparse.ArgumentParser(description='Test a trained horizon line estimation network (DSAC or NG-DSAC).', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# parser.add_argument('model', type=str,
# 	help='a trained network')

# parser.add_argument('--capacity', '-c', type=int, default=4, 
# 	help='controls the model capacity of the network, must match the model to load (multiplicative factor for number of channels)')

# parser.add_argument('--imagesize', '-is', type=int, default=256, 
# 	help='size of input images to the network, must match the model to load')

# parser.add_argument('--inlierthreshold', '-it', type=float, default=0.05, 
# 	help='threshold used in the soft inlier count, relative to image size')

# parser.add_argument('--inlieralpha', '-ia', type=float, default=0.1, 
# 	help='scaling factor for the soft inlier scores (controls the peakiness of the hypothesis distribution)')

# parser.add_argument('--inlierbeta', '-ib', type=float, default=100.0, 
# 	help='scaling factor within the sigmoid of the soft inlier count')

# parser.add_argument('--hypotheses', '-hyps', type=int, default=16, 
# 	help='number of line hypotheses sampled for each image')

# parser.add_argument('--session', '-sid', default='', 
# 	help='custom session name appended to output files; useful to separate different runs of the program')

# parser.add_argument('--invalidloss', '-il', type=int, default=1, 
# 	help='penalty for sampling invalid hypotheses')

# parser.add_argument('--uniform', '-u', action='store_true', 
# 	help='disable neural-guidance and sample data points uniformly; corresponds to a DSAC model')

# Arguments (hard-coded replacements for the command line options above)

model = '/content/drive/My Drive/Colab Notebooks/ngdsac_horizon/models/weights_colab_baseline_26.net' # trained network weights to evaluate
capacity = 4 # channel multiplier of the network, must match the model to load
imagesize = 256 # size of input images to the network, must match the model to load
inlierthreshold = 0.05 # threshold of the soft inlier count, relative to image size
inlieralpha = 0.1 # scaling factor for the soft inlier scores
inlierbeta = 100.0 # scaling factor within the sigmoid of the soft inlier count
hypotheses = 16 # number of line hypotheses sampled for each image
session = 'test26' # appended to the name of the results file written below
invalidloss = 1 # penalty for sampling invalid hypotheses
uniform = False # True replaces the neural guidance with uniform sampling probabilities (DSAC)


#opt = parser.parse_args()

# folder containing the dataset split files; 'test.txt' is read from here
testDir = '/content/drive/My Drive/Colab Notebooks/ngdsac_horizon/hlw_1_2/split/'

# setup test set
# batch_size=1 is required: the evaluation loop below only reads element 0 of each batch
testset = HLWDataset(testDir + 'test.txt', imagesize, training=False)
testset_loader = torch.utils.data.DataLoader(testset, shuffle=False, num_workers=6, batch_size=1)

# setup ng dsac estimator
# NOTE(review): cut_off = 100 presumably caps the loss value -- confirm against loss.py
loss = Loss(imagesize, cut_off = 100) 
ngdsac = NGDSAC(hypotheses, inlierthreshold, inlierbeta,inlieralpha, loss,invalidloss)

# load network
# NOTE: this rebinds the name `nn`, which the first notebook cell uses for torch.nn
nn = Model(capacity)
nn.load_state_dict(torch.load(model))
nn.eval()
nn = nn.cuda()

# write test results
# one line per test image is written by the evaluation loop; third argument 1 = line buffered
test_log = open('test_'+session+'.txt', 'w', 1)

def AUC(losses, thresholds, binsize):
	"""Compute the AUC up to a set of error thresholds.

	Builds a normalized cumulative histogram of the losses with bins of width
	binsize and, for each threshold, averages the cumulative values of all bins
	below that threshold.

	Return multiple AUC values, one per threshold provided, in the same order.

	Keyword arguments:
	losses -- list of losses which the AUC should be calculated for
	thresholds -- list of threshold values up to which the AUC should be calculated
	binsize -- bin size to be used for the cumulative histogram when calculating the AUC, the finer the more accurate
	"""

	def bin_count(limit):
		# Number of whole bins that fit below limit. A plain int(limit / binsize)
		# truncates quotients that are an integer only up to floating point error
		# (e.g. 0.3 / 0.1 == 2.9999999999999996) and would silently drop the last
		# bin, so floor with a small tolerance instead.
		return int(np.floor(limit / binsize + 1e-9))

	bin_num = bin_count(max(thresholds))
	bins = np.arange(bin_num + 1) * binsize

	hist, _ = np.histogram(losses, bins) # histogram up to the max threshold
	hist = hist.astype(np.float32) / len(losses) # normalized histogram
	hist = np.cumsum(hist) # cumulative normalized histogram

	# calculate AUC for each threshold
	return [np.mean(hist[:bin_count(t)]) for t in thresholds]

# Evaluate the network on every test image, log the per-image loss and report AUC@0.25.
losses = []

# make sure the results file is closed even if the evaluation fails part-way
try:
	for inputs, labels, xStart, xEnd, imh, idx in testset_loader:

		start_time = time.time()

		with torch.no_grad():
			# forward pass of neural network
			points, log_probs = nn(inputs.cuda())

			if uniform:
				# overwrite neural guidance with uniform sampling probabilities
				log_probs.fill_(1/log_probs.size(1))
				log_probs = torch.log(log_probs)

			# fit line with NG-DSAC, the estimate is read from ngdsac.est_parameters below
			ngdsac(points, log_probs, labels, xStart, xEnd, imh)

		# evaluate (assumes a batch size of 1)
		cur_loss = loss(ngdsac.est_parameters[0], labels[0], xStart[0], xEnd[0], imh[0])

		# wrap up
		end_time = time.time()-start_time
		print('Image: %s, Loss: %2.2f, Time: %.2fs' 
			% (testset.images[idx[0]], cur_loss, end_time), flush=True)

		test_log.write('%s %f\n' % (testset.images[idx[0]], cur_loss))
		losses.append(cur_loss)
finally:
	test_log.close()

auc = AUC(losses, thresholds=[0.25], binsize=0.0001)

print("\n==========================================")
print("AUC@0.25: %.1f%%" % (auc[0]*100))
print("==========================================\n")
print('Done without errors.')

# this cell re-imports its other dependencies, so do the same for the download helper
from google.colab import files
files.download('test_'+session+'.txt')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive/Colab Notebooks/ngdsac_horizon
Image: /content/drive/My Drive/Colab Notebooks/ngdsac_horizon/hlw_1_2/images/Alamo/10047400_4909449851_o.jpg, Loss: 0.10, Time: 0.05s
Image: /content/drive/My Drive/Colab Notebooks/ngdsac_horizon/hlw_1_2/images/Alamo/110935763_66fc01689c_o.jpg, Loss: 0.21, Time: 0.02s
Image: /content/drive/My Drive/Colab Notebooks/ngdsac_horizon/hlw_1_2/images/Alamo/110935776_5625ee33a9_o.jpg, Loss: 0.09, Time: 0.03s
Image: /content/drive/My Drive/Colab Notebooks/ngdsac_horizon/hlw_1_2/images/Alamo/1157929690_376fa9d177_o.jpg, Loss: 0.09, Time: 0.02s
Image: /content/drive/My Drive/Colab Notebooks/ngdsac_horizon/hlw_1_2/images/Alamo/1157940322_c3423c93a7_o.jpg, Loss: 0.08, Time: 0.02s
Image: /content/drive/My Drive/Colab Notebooks/ngdsac_horizon/hlw_1_2/images/Alamo/1190606546_daf2749711_o.jpg, Loss: 0.05, Time: 0.02s
Imag

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

# Demo

In [19]:
# @inproceedings{brachmann2019ngransac,
#   title={{N}eural- {G}uided {RANSAC}: {L}earning Where to Sample Model Hypotheses},
#   author={Brachmann, Eric and Rother, Carsten},
#   booktitle={ICCV},
#   year={2019}
# }


# Demo configuration. These module-level values are read by process_frame() and the driver code below.
model = '/content/drive/My Drive/Colab Notebooks/ngdsac_horizon/models/weights_colab_baseline_2.net' # trained network weights to load
capacity = 4 # passed to Model(); presumably has to match the loaded weights -- confirm
imagesize = 256 # images are rescaled so that their longer side has this length, then padded to a square
inlierthreshold = 0.05
inlieralpha = 0.1
inlierbeta = 100.0
hypotheses = 16
invalidloss = 1 # NOTE(review): not used in this cell, a literal 1 is passed to NGDSAC below
uniform = False # True replaces the neural guidance with uniform sampling probabilities
scorethreshold = 0.4 # the estimated line is only drawn into the first panel if its inlier score exceeds this
verbose = True # True adds three extra visualization panels to the output image

# NOTE: `input` shadows the Python builtin of the same name for the rest of the session
input = '/content/drive/My Drive/Colab Notebooks/ngdsac_horizon/test/2.jpg'
outputName = '/2.png' # appended to output_folder as is, hence the leading slash


output_folder = '/content/drive/My Drive/Colab Notebooks/ngdsac_horizon/test/results'
# NOTE(review): folder creation is disabled, so the folder presumably has to exist already -- confirm
#if not os.path.isdir(output_folder): os.makedirs(output_folder)

# setup ng dsac estimator
ngdsac = NGDSAC(hypotheses, inlierthreshold, inlierbeta, inlieralpha, Loss(imagesize), 1)

# load network
# NOTE: this rebinds the name `nn`, which the first notebook cell uses for torch.nn
nn = Model(capacity)
nn.load_state_dict(torch.load(model))
nn.eval()
nn = nn.cuda()


def process_frame(image):
	'''
	Estimate horizon line for an image and return a visualization.

	Reads the module-level settings imagesize, uniform, scorethreshold and verbose
	as well as the module-level network (nn) and estimator (ngdsac).

	image -- numpy image, either 3 dim (with color channels) or 2 dim (gray scale)

	Returns a single uint8 numpy image. It contains the input (gray scale) with the
	estimated line drawn if its score exceeds scorethreshold. If verbose is set, 
	three more panels are appended to the right: neural guidance, soft inliers and 
	the line colored by its score.
	'''

	# determine image scaling factor
	image_scale = max(image.shape[0], image.shape[1])
	image_scale = imagesize / image_scale 

	# convert gray scale images to 3 channels
	# NOTE(review): the callers in this file read frames with OpenCV, which delivers BGR 
	# channel order, while the gray scale conversion below treats the image as RGB -- confirm 
	# that the slightly different channel weighting is acceptable
	if len(image.shape) < 3:
		image = color.gray2rgb(image)

	# image dimensions after rescaling the longer side to imagesize
	src_h = int(image.shape[0] * image_scale)
	src_w = int(image.shape[1] * image_scale)

	# resize and to gray scale
	image = transforms.functional.to_pil_image(image)
	image = transforms.functional.resize(image, (src_h, src_w))
	image = transforms.functional.adjust_saturation(image, 0)
	image = transforms.functional.to_tensor(image)

	# make image square by zero padding
	padding_left = int((imagesize - image.size(2)) / 2)
	padding_right = imagesize - image.size(2) - padding_left
	padding_top = int((imagesize - image.size(1)) / 2)
	padding_bottom = imagesize - image.size(1) - padding_top

	padding = torch.nn.ZeroPad2d((padding_left, padding_right, padding_top, padding_bottom))
	image = padding(image)

	# keep an un-normalized copy (with batch dimension) for the visualization
	image_src = image.clone().unsqueeze(0)

	# normalize image (mean and variance), values estimated offline from HLW training set
	# only pixels with a non-zero channel sum are normalized, so the zero padding stays zero
	img_mask = image.sum(0) > 0
	image[:,img_mask] -= 0.45
	image[:,img_mask] /= 0.25
	image = image.unsqueeze(0).cuda()

	with torch.no_grad():
		#predict data points and neural guidance
		points, log_probs = nn(image)

		if uniform:
			# overwrite neural guidance with uniform sampling probabilities
			log_probs.fill_(1/log_probs.size(1))
			log_probs = torch.log(log_probs)

		# fit line with NG-DSAC, providing dummy ground truth labels
		ngdsac(points, log_probs, torch.zeros((1,2)), torch.zeros((1)), torch.ones((1)), torch.ones((1))) 

	def draw_line(data, lX1, lY1, lX2, lY2, clr):
		'''
		Draw a line with the given color.

		data -- image to draw to
		lX1 -- x value of line segment start point
		lY1 -- y value of line segment start point
		lX2 -- x value of line segment end point
		lY2 -- y value of line segment end point
		clr -- line color, triple of values
		'''

		# skimage works with (row, column) coordinates, hence y before x
		rr, cc = line(lY1, lX1, lY2, lX2)
		set_color(data, (rr, cc), clr)

	def draw_models(labels, clr, data):
		'''
		Draw lines for a batch of images.

		labels -- line parameters, array shape (Nx2) where 
			N is the number of images in the batch
			2 is the number of line parameters (offset,  slope)
		clr -- line color, triple of values
		data -- batch of images to draw to

		Returns data, which has been drawn to in place.
		'''

		# number of image in batch
		n = labels.shape[0]

		for i in range (n):

			# line from the left border (x = 0) to the right border (x = imagesize)
			lY1 = int(labels[i, 0] * imagesize)
			lY2 = int(labels[i, 1] * imagesize + labels[i, 0] * imagesize)
			draw_line(data[i], 0, lY1, imagesize, lY2, clr)

		return data	

	def draw_wpoints(points, data, weights, clrmap):
		'''
		Draw 2D points for a batch of images.

		points -- 2D points, array shape (Nx2xM) where 
			N is the number of images in the batch
			2 is the number of point dimensions
			M is the number of points
		data -- batch of images to draw to
		weights -- array shape (NxM), one weight per point, for visualization (not modified)
		clrmap -- OpenCV color map for visualizing weights

		Returns data, which has been drawn to in place.
		'''

		# create explicit color map
		color_map = np.arange(256).astype('u1')
		color_map = cv2.applyColorMap(color_map, clrmap)
		color_map = color_map[:,:,::-1] # BGR to RGB

		n = points.shape[0] # number of images
		m = points.shape[2] # number of points

		for i in range (0, n):

			s_idx = weights[i].sort(descending=False)[1] # draw low weight points first

			# normalize weights for visualization
			# the result is kept in a local tensor: writing it back into weights would 
			# silently rescale the tensor of the caller (e.g. ngdsac.batch_inliers)
			norm_weights = weights[i] / weights[i].max()

			for j in range(0, m):

				idx = int(s_idx[j])

				# convert weight to color
				clr_idx = float(min(1, norm_weights[idx]))
				clr_idx = int(clr_idx * 255)
				clr = color_map[clr_idx, 0] / 255

				# draw point, dimension 0 is used as row and dimension 1 as column
				r = int(points[i, 0, idx] * imagesize)
				c = int(points[i, 1, idx] * imagesize)
				rr, cc = circle(r, c, 2)
				set_color(data[i], (rr, cc), clr)

		return data

	# normalized inlier score of the estimated line
	score = ngdsac.batch_inliers[0].sum() / points.shape[2]

	image_src = image_src.cpu().permute(0,2,3,1).numpy() #Torch to Numpy
	viz_probs = image_src.copy() * 0.2 # make a faint copy of the input image

	# draw estimated line
	if score > scorethreshold:
		image_src = draw_models(ngdsac.est_parameters, clr=(0,0,1), data=image_src)

	viz = [image_src]

	if verbose:	
		# create additional visualizations

		# draw faint estimated line 
		viz_score = viz_probs.copy() # copied before the faint line is drawn
		viz_probs = draw_models(ngdsac.est_parameters, clr=(0.3,0.3,0.3), data=viz_probs)
		viz_inliers = viz_probs.copy()

		# draw predicted points with neural guidance and soft inlier count, respectively
		viz_probs = draw_wpoints(points, viz_probs, weights=torch.exp(log_probs), clrmap=cv2.COLORMAP_PLASMA)
		viz_inliers = draw_wpoints(points, viz_inliers, weights=ngdsac.batch_inliers, clrmap=cv2.COLORMAP_WINTER)

		# create a explicit color map for visualize score of estimate line
		color_map = np.arange(256).astype('u1')
		color_map = cv2.applyColorMap(color_map, cv2.COLORMAP_HSV)	
		color_map = color_map[:,:,::-1] # BGR to RGB

		# map score to color
		score = int(score*100) #using only the first portion of HSV to get a nice (red, yellow, green) gradient
		clr = color_map[score, 0] / 255

		viz_score = draw_models(ngdsac.est_parameters, clr=clr, data=viz_score)

		viz = viz + [viz_probs, viz_inliers, viz_score]

	#undo zero padding of inputs (panels are batch x height x width x channels)
	if padding_left > 0:
		viz = [img[:,:,padding_left:,:] for img in viz]
	if padding_right > 0:
		viz = [img[:,:,:-padding_right,:] for img in viz]
	if padding_top > 0:
		viz = [img[:,padding_top:,:,:] for img in viz]
	if padding_bottom > 0:
		viz = [img[:,:-padding_bottom,:,:] for img in viz]		

	# convert to a single uchar image, panels are placed side by side
	viz = np.concatenate(viz, axis=2)
	viz = viz * 255
	viz = viz.astype('u1')

	return viz[0]

# Run the demo on `input`: first try to read it as a single image, otherwise as a video.

# try to read input as image
image = cv2.imread(input)

if image is not None:
	#success, it was an image
	viz = process_frame(image)
	imsave(output_folder + outputName, viz)

else:
	#failure, try interpreting it as video
	cap = cv2.VideoCapture(input)
	iteration = 0

	if not cap.isOpened():
		# neither an image nor a video, say so instead of silently doing nothing
		print("Could not read '%s' as an image or a video." % input)

	# release the capture even if processing a frame fails
	try:
		while(cap.isOpened()):
			ret, image = cap.read()

			if not ret:
				# no more frames
				break

			print("Processing frame %5d." % iteration)

			# one visualization is written per frame, numbered with five digits
			viz = process_frame(image)
			imsave(output_folder + '/frame_' + str(iteration).zfill(5) + '.png', viz)

			iteration = iteration + 1
	finally:
		cap.release()
