--[[
test.lua — executable file · 139 lines (98 loc) · 4.16 KB
(Scrape artifacts — original hosting-page header and line-number gutter —
reconstructed into this comment so the file parses as Lua.)
]]
-- Torch7 dependencies for the joint Structure/Style GAN sampling script.
require 'torch'
require 'nngraph'
require 'cunn'
require 'optim'
require 'image'
require 'pl'
require 'paths'
require 'nnx'
-- Optional 'display' package (browser-based image viewer); script degrades
-- gracefully without it. NOTE(review): `ok` and `disp` are accidental
-- globals — harmless here, but they leak into _G.
ok, disp = pcall(require, 'display')
if not ok then print('display not found. unable to plot') end
-- Command-line options, parsed by Penlight's lapp. The literal below is both
-- the parser spec and the --help text, so its wording is user-visible.
-- Fix: "windsow" -> "window" in the -w description (help-text typo only;
-- option names, defaults, and parsed fields are unchanged).
opt = lapp[[
-s,--save (default "logs") subdirectory to save logs
--saveFreq (default 10) save every saveFreq epochs
-n,--network (default "") reload pretrained network
-p,--plot plot while training
-r,--learningRate (default 0.02) learning rate
-b,--batchSize (default 100) batch size
-m,--momentum (default 0) momentum, for SGD only
--coefL1 (default 0) L1 penalty on the weights
--coefL2 (default 0) L2 penalty on the weights
-t,--threads (default 4) number of threads
-g,--gpu (default 0) gpu to run on (default cpu)
-d,--noiseDim (default 100) dimensionality of noise vector
--K (default 1) number of iterations to optimize D for
-w, --window (default 3) window id of sample image
--hidden_G (default 8000) number of units in hidden layers of G
--hidden_D (default 1600) number of units in hidden layers of D
--scale (default 72) scale of images to train on
]]
-- Out-of-range GPU ids disable the cutorch branch below.
-- NOTE(review): 0 is truthy in Lua, so the lapp default of 0 actually selects
-- the first GPU, not the CPU as the help text suggests. Left as-is because
-- the generators are moved to CUDA unconditionally later in the script.
if opt.gpu < 0 or opt.gpu > 8 then opt.gpu = false end
print(opt)
-- fix seed
-- A random seed per run, so generated samples differ between invocations.
torch.manualSeed(torch.random(1,10000))
-- torch.manualSeed(1)
-- threads
torch.setnumthreads(opt.threads)
print('<torch> set nb of threads to ' .. torch.getnumthreads())
-- Select device and default tensor type BEFORE any tensors or models are
-- created, so subsequent torch.Tensor() allocations land on the right device.
if opt.gpu then
-- cutorch devices are 1-indexed; opt.gpu is treated as 0-based here.
cutorch.setDevice(opt.gpu + 1)
print('<gpu> using device ' .. opt.gpu)
torch.setdefaulttensortype('torch.CudaTensor')
else
torch.setdefaulttensortype('torch.FloatTensor')
end
-- Earlier checkpoint path, kept for reference:
-- model = torch.load('/nfs.yoda/xiaolonw/torch_projects/models/train_3dnormal_jointall_bi_s4/adversarial_G_9.net')
-- model_G = model.G:cuda()
-- model_G1= model.G1:cuda()
-- Load the two pretrained generators. `model` is deliberately reused as a
-- scratch variable: first load yields the Style-GAN generator (model_G),
-- second the Structure-GAN generator (model_G1). Both are moved to CUDA.
model = torch.load('../ssgan_models/joint_Style_GAN.net')
model_G = model.G
model = torch.load('../ssgan_models/Structure_GAN.net')
model_G1 = model.G
model_G = model_G:cuda()
model_G1 = model_G1:cuda()
-- Fixed sampling dimensions (overrides any CLI noiseDim value).
opt.noiseDim = {100, 1, 1}
opt.geometry = {3, opt.scale, opt.scale}
opt.condDim = {3, opt.scale, opt.scale}
opt.div_num = 127.5 -- (x + 1) * 127.5 maps [-1,1] net output to [0,255] pixels
opt.finescale = opt.scale * 2
-- CPU pipeline applied to G1's output before feeding it to G:
-- upsample to 128x128, go NCHW -> NHWC, flatten to rows of 3 channels,
-- L2-normalize each row (presumably per-pixel surface normals, per the
-- "_norm" output filenames — TODO confirm), reshape back, NHWC -> NCHW.
model_upsample = nn.Sequential()
model_upsample:add(nn.SpatialReSampling({owidth=128,oheight=128}) )
model_upsample:add(nn.Transpose({2,3},{3,4}))
model_upsample:add(nn.View(-1, 3))
model_upsample:add(nn.Normalize(2))
model_upsample:add(nn.View(-1, 128, 128, 3))
model_upsample:add(nn.Transpose({4,3},{3,2}))
model_upsample:float()
-- Get examples to plot
--- Draw N samples from the two-stage generator and save them as JPEGs.
-- Pipeline: G1 (Structure-GAN) generates maps from noise, model_upsample
-- rescales/renormalizes them on the CPU, then G (Style-GAN) renders images
-- conditioned on those maps. Each sample i is written to
-- <resultpath>/%04d_norm.jpg and %04d_img.jpg using index i + beg.
-- @param dataset unused; kept for call-site compatibility
-- @param N       number of samples (default 8)
-- @param beg     filename index offset (default 0)
function getSamples(dataset, N, beg)
  local resultpath = '../results/joint_all_results/'
  os.execute('mkdir -p ' .. resultpath)
  local N = N or 8
  local beg = beg or 0 -- fix: previously errored ("arithmetic on nil") when omitted
  -- Noise for the structure generator (G1) and the style generator (G).
  -- (Removed two unused conditioning-tensor allocations from the original.)
  local noise_inputs = torch.Tensor(N, opt.noiseDim[1], opt.noiseDim[2], opt.noiseDim[3])
  local noise_inputs2 = torch.Tensor(N, opt.noiseDim[1], opt.noiseDim[2], opt.noiseDim[3])
  -- Generate samples
  noise_inputs:uniform(-1, 1)
  noise_inputs2:uniform(-1, 1)
  local samples1 = model_G1:forward(noise_inputs)
  local samples = model_upsample:forward(samples1:float())
  local imgsamples = model_G:forward({noise_inputs2, samples:cuda()})
  for i = 1, N do
    -- fix: these four were accidental globals in the original
    local output_name = paths.concat(resultpath, string.format('%04d_norm.jpg', i + beg))
    local output_imgname = paths.concat(resultpath, string.format('%04d_img.jpg', i + beg))
    -- Map net output from [-1,1] to [0,255] in place before byte conversion.
    samples[i] = (samples[i] + 1) * opt.div_num
    imgsamples[i] = (imgsamples[i] + 1) * opt.div_num
    local output_norm = samples[i]:clone():byte()
    image.save(output_name, output_norm)
    local output_img = imgsamples[i]:clone():byte()
    image.save(output_imgname, output_img)
  end
end
-- Emit 100 samples total: 10 batches of 10, with filename offsets
-- 0, 10, ..., 90 (same call sequence as the original loop).
local batch_size = 10
for offset = 0, 90, batch_size do
  getSamples(trainData, batch_size, offset)
end