Skip to content

Commit

Permalink
make Image2Image support windows (fix #1412)
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Mar 26, 2020
1 parent e50423d commit f1a8acf
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
9 changes: 5 additions & 4 deletions examples/GAN/Image2Image.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,12 @@ def optimizer(self):
return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)


def split_input(img):
def split_input(dp):
"""
img: an RGB image of shape (s, 2s, 3).
dp: the datapoint. first component is an RGB image of shape (s, 2s, 3).
:return: [input, output]
"""
img = dp[0]
# split the image into left + right pairs
s = img.shape[0]
assert img.shape[1] == 2 * s
Expand All @@ -169,7 +170,7 @@ def get_data():
imgs = glob.glob(os.path.join(datadir, '*.jpg'))
ds = ImageFromFile(imgs, channel=3, shuffle=True)

ds = MapData(ds, lambda dp: split_input(dp[0]))
ds = MapData(ds, split_input)
augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
ds = AugmentImageComponents(ds, augs, (0, 1))
ds = BatchData(ds, BATCH)
Expand All @@ -186,7 +187,7 @@ def sample(datadir, model_path):

imgs = glob.glob(os.path.join(datadir, '*.jpg'))
ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = MapData(ds, lambda dp: split_input(dp[0]))
ds = MapData(ds, split_input)
ds = AugmentImageComponents(ds, [imgaug.Resize(256)], (0, 1))
ds = BatchData(ds, 6)

Expand Down
1 change: 1 addition & 0 deletions tensorpack/models/batch_norm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) Tensorpack Contributors. All Rights Reserved
# -*- coding: utf-8 -*-
# File: batch_norm.py

Expand Down
2 changes: 1 addition & 1 deletion tensorpack/tfutils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def collect_env_info():
# List devices with NVML
data.append(
("CUDA_VISIBLE_DEVICES",
os.environ.get("CUDA_VISIBLE_DEVICES", str(None))))
os.environ.get("CUDA_VISIBLE_DEVICES", "Unspecified")))
try:
devs = defaultdict(list)
with NVMLContext() as ctx:
Expand Down

0 comments on commit f1a8acf

Please sign in to comment.