-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
43 lines (40 loc) · 1.53 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from trainer import train
from validation import testModel
from DatasetLoader import GANDataGenerator, GANDataGeneratorXY
import numpy as np
from Datasetlabels import *
from config import get_parameters
def main(config):
    """Train the CycleGAN model, or validate it on a single image.

    Args:
        config: Run configuration (as returned by ``get_parameters()``).
            Attributes read here: ``train``, ``model_save_path``,
            ``validate``, ``subject``, ``dataset``, ``height``, ``width``.

    When ``config.train`` is truthy, the model returned by ``train(config)``
    has its weights saved to ``config.model_save_path + 'model_weights'``.
    Otherwise a loader is built around the single path ``config.validate``,
    and its first ``(source, destination)`` item is passed to ``testModel``.
    """
    if config.train:
        # Train the CycleGAN model, then persist its weights.
        model = train(config)
        # NOTE(review): plain string concatenation, so model_save_path
        # presumably ends with a path separator (e.g. 'weights/') -- confirm.
        model.save_weights(config.model_save_path + 'model_weights')
        return

    # Validate the CycleGAN model on the single image path in config.validate,
    # wrapped in a one-element array for the dataset loader.
    validation_image_path = np.array([config.validate])
    if config.subject == 0:
        # Dataset loader used for sketch-to-colorized-image.
        # NOTE(review): the positional 1 is presumably the batch size -- confirm
        # against DatasetLoader.
        loader = GANDataGenerator(
            validation_image_path,
            config.dataset,
            1,
            dim=(config.height, config.width),
        )
    else:
        # Dataset loader used for gender-bender and glasses-to-no-glasses;
        # the same image path is passed as both of its path arguments.
        loader = GANDataGeneratorXY(
            validation_image_path,
            validation_image_path,
            config.dataset,
            1,
            dim=(config.height, config.width),
        )
    # Only the first item produced by the loader is evaluated.
    source, destination = next(iter(loader))
    testModel(source, destination, config)
if __name__ == '__main__':
    # Script entry point: build the run configuration, echo it so the
    # settings appear in the run's output, then dispatch to main().
    run_config = get_parameters()
    print(run_config)
    main(run_config)