Permalink
Browse files

add functions that load latest weights files and PRINT_INTERVAL

  • Loading branch information...
youyuge34 committed Jan 8, 2019
1 parent 0fe46be commit c5818b24550efac9a16acb51f0d14e1607037c49
Showing with 39 additions and 23 deletions.
  1. +3 −3 main.py
  2. +10 −8 src/edge_connect.py
  3. +11 −11 src/models.py
  4. +12 −0 src/utils.py
  5. +3 −1 test.py
@@ -53,11 +53,11 @@ def main(mode=None):
# model test
elif config.MODE == 2:
print('\nstart testing...\n')
import time
start = time.time()
# import time
# start = time.time()
with torch.no_grad():
model.test()
print(time.time() - start)
# print(time.time() - start)

# eval mode
else:
@@ -90,15 +90,15 @@ def train(self):
epoch += 1
print('\n\nTraining epoch: %d' % epoch)

self.edge_model.train()
self.inpaint_model.train()
# ['epoch', 'iter'] will not be auto-averaged during training, others will
progbar = Progbar(total, width=20, stateful_metrics=['epoch', 'iter'])

for items in train_loader:
images, images_gray, edges, masks = self.cuda(*items)

self.edge_model.train()
self.inpaint_model.train()

images, images_gray, edges, masks = self.cuda(*items)

# edge model
if model == 1:
# train
@@ -188,16 +188,18 @@ def train(self):
("iter", iteration),
] + logs

if iteration % 20 == 0:
progbar.add(len(images), values=logs if self.config.VERBOSE else [x for x in logs if not x[0].startswith('l_')])
# terminal prints
if self.config.PRINT_INTERVAL and iteration % self.config.PRINT_INTERVAL == 0:
progbar.add(len(images) * int(self.config.PRINT_INTERVAL), values=logs if self.config.VERBOSE else [x for x in logs if not x[0].startswith('l_')])

# log model at checkpoints
if self.config.LOG_INTERVAL and iteration % self.config.LOG_INTERVAL == 0:
self.log(logs)

# sample model at checkpoints
if self.config.SAMPLE_INTERVAL and iteration % self.config.SAMPLE_INTERVAL == 0:
# with torch.no_grad():
print('\nstart sampling...\n')
with torch.no_grad():
self.sample()

# evaluate model at checkpoints
@@ -312,7 +314,7 @@ def test(self):
for items in test_loader:
name = self.test_dataset.load_name(index)
images, images_gray, edges, masks = self.cuda(*items)
print('images size is {}, \n edges size is {}, \n masks size is {}'.format(images.size(), edges.size(), masks.size()))
# print('images size is {}, \n edges size is {}, \n masks size is {}'.format(images.size(), edges.size(), masks.size()))
index += 1

# edge model
@@ -7,6 +7,7 @@
from .networks import InpaintGenerator, EdgeGenerator, Discriminator
from .dataset import Dataset
from .loss import AdversarialLoss, PerceptualLoss, StyleLoss
from .utils import get_model_list


class BaseModel(nn.Module):
@@ -17,32 +18,31 @@ def __init__(self, name, config):
self.config = config
self.iteration = 0

self.gen_weights_path = os.path.join(config.PATH, name + '_gen.pth')
self.dis_weights_path = os.path.join(config.PATH, name + '_dis.pth')

def load(self):
if os.path.exists(self.gen_weights_path):
print('Loading %s generator...' % self.name)
data = torch.load(self.gen_weights_path)
gen_path = get_model_list(self.config.PATH, self.name, 'gen')
dis_path = get_model_list(self.config.PATH, self.name, 'dis')
if gen_path is not None:
print('Loading {} generator weights file: {}...'.format(self.name, gen_path))
data = torch.load(gen_path)
self.generator.load_state_dict(data['generator'])
self.iteration = data['iteration']

# load discriminator only when training
if self.config.MODE == 1 and os.path.exists(self.dis_weights_path):
print('Loading %s discriminator...' % self.name)
data = torch.load(self.dis_weights_path)
if self.config.MODE == 1 and dis_path is not None:
print('Loading {} discriminator weights file: {}...'.format(self.name, dis_path))
data = torch.load(dis_path)
self.discriminator.load_state_dict(data['discriminator'])

def save(self):
print('\nsaving %s...\n' % self.name)
torch.save({
'iteration': self.iteration,
'generator': self.generator.state_dict()
}, os.path.join(self.config.PATH, '{}_{}_gen.pth'.format(self.name, self.iteration)))
}, os.path.join(self.config.PATH, '{}_gen_{}.pth'.format(self.name, self.iteration)))

torch.save({
'discriminator': self.discriminator.state_dict()
}, os.path.join(self.config.PATH, '{}_{}_dis.pth'.format(self.name, self.iteration)))
}, os.path.join(self.config.PATH, '{}_dis_{}.pth'.format(self.name, self.iteration)))


class EdgeModel(BaseModel):
@@ -20,6 +20,18 @@ def create_mask(width, height, mask_width, mask_height, x=None, y=None):
mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
return mask

# Get model list for resume, key_phase = 'EdgeModel' or 'InpaintModel', key_model = 'gen' or 'dis'
def get_model_list(dirname, key_phase, key_model):
if os.path.exists(dirname) is False:
return None
gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
os.path.isfile(os.path.join(dirname, f)) and key_phase in f and key_model in f and ".pth" in f]
if gen_models is None:
return None
gen_models.sort()
last_model_name = gen_models[-1]
return last_model_name


def stitch_images(inputs, *outputs, img_per_row=2):
gap = 5
@@ -1,2 +1,4 @@
from main import main
main(mode=2)

if __name__ == '__main__':
main(mode=2)

0 comments on commit c5818b2

Please sign in to comment.