Skip to content

Commit 7aab770

Browse files
committed
Fix tfrecord-converter
1 parent 08bf31e commit 7aab770

File tree

5 files changed

+23
-13
lines changed

5 files changed

+23
-13
lines changed

run-gqn.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
6868

6969
def step(engine, batch):
70+
model.train()
71+
7072
x, v = batch
7173
x, v = x.to(device), v.to(device)
7274
x, v, x_q, v_q = partition(x, v)
@@ -101,7 +103,10 @@ def step(engine, batch):
101103
# Trainer and metrics
102104
trainer = Engine(step)
103105
metric_names = ["elbo", "kl", "sigma", "mu"]
104-
metrics = [RunningAverage(output_transform=lambda x: x[m]).attach(trainer, m) for m in metric_names]
106+
RunningAverage(output_transform=lambda x: x["elbo"]).attach(trainer, "elbo")
107+
RunningAverage(output_transform=lambda x: x["kl"]).attach(trainer, "kl")
108+
RunningAverage(output_transform=lambda x: x["sigma"]).attach(trainer, "sigma")
109+
RunningAverage(output_transform=lambda x: x["mu"]).attach(trainer, "mu")
105110
ProgressBar().attach(trainer, metric_names=metric_names)
106111

107112
# Model checkpointing
@@ -142,6 +147,7 @@ def save_images(engine):
142147

143148
@trainer.on(Events.EPOCH_COMPLETED)
144149
def validate(engine):
150+
model.eval()
145151
with torch.no_grad():
146152
x, v = next(iter(valid_loader))
147153
x, v = x.to(device), v.to(device)

scripts/data.sh

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
#!/usr/bin/env bash
22

3-
LOCATION=$1
4-
BATCH_SIZE=$2
3+
LOCATION=$1 # example: /tmp/data
4+
BATCH_SIZE=$2 # example: 64
55

66
echo "Downloading data"
77
gsutil -m cp -R gs://gqn-dataset/shepard_metzler_5_parts $LOCATION
88

9-
echo "Deleting small records"
10-
TRAIN_PATH="$LOCATION/shepard_metzler_5_parts/train"
11-
find "$TRAIN_PATH/*.tfrecord" -type f -size -10M | xargs rm # remove smaller than 10mb
9+
echo "Deleting small records" # less than 10MB
10+
DATA_PATH="$LOCATION/shepard_metzler_5_parts/**/*.tfrecord"
11+
find $DATA_PATH -type f -size -10M | xargs rm
1212

1313
echo "Converting data"
1414
python tfrecord-converter.py $LOCATION shepard_metzler_5_parts -b $BATCH_SIZE -m "train"
1515
echo "Training data: done"
1616
python tfrecord-converter.py $LOCATION shepard_metzler_5_parts -b $BATCH_SIZE -m "test"
17-
echo "Testing data: done"
17+
echo "Testing data: done"
18+
19+
echo "Removing original records"
20+
rm -rf "$LOCATION/shepard_metzler_5_parts/**/*.tfrecord"

scripts/gpu.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ python ../run-gqn.py \
1414
--log_dir "../logs" \
1515
--data_parallel "True" \
1616
--batch_size 1 \
17-
--n_workers 6
17+
--workers 6

scripts/tfrecord-converter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def process(record):
4646

4747
# Convert
4848
images = tf.map_fn(tf.image.decode_jpeg, tf.reshape(images, [-1]), **kwargs)
49-
images = tf.reshape(images, (-1, SEQ_DIM, 3, IMG_DIM, IMG_DIM))
49+
images = tf.reshape(images, (-1, SEQ_DIM, IMG_DIM, IMG_DIM, 3))
5050
poses = tf.reshape(poses, (-1, SEQ_DIM, POSE_DIM))
5151

5252
# Numpy conversion
@@ -64,8 +64,8 @@ def convert(record, batch_size):
6464
batch_process = lambda r: chunk(process(r), batch_size)
6565

6666
for i, batch in enumerate(batch_process(record)):
67-
path = os.path.join(path, "{0:}-{1:02}.pt.gz".format(basename, i))
68-
with gzip.open(path, 'wb') as f:
67+
p = os.path.join(path, "{0:}-{1:02}.pt.gz".format(basename, i))
68+
with gzip.open(p, 'wb') as f:
6969
torch.save(list(batch), f)
7070

7171
if __name__ == '__main__':
@@ -91,4 +91,4 @@ def convert(record, batch_size):
9191

9292
with mp.Pool(processes=mp.cpu_count()) as pool:
9393
f = partial(convert, batch_size=args.batch_size)
94-
pool.map(f, records)
94+
pool.map(f, records)

shepardmetzler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ def __getitem__(self, idx):
4949
images, viewpoints = list(zip(*data))
5050

5151
# (b, m, c, h, w)
52-
images = torch.FloatTensor(images)
52+
images = torch.FloatTensor(images)/255
53+
images = images.permute(0, 1, 4, 2, 3)
5354
if self.transform:
5455
images = self.transform(images)
5556

0 commit comments

Comments
 (0)