Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/build_torch_wheels.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function install_and_setup_conda() {
conda activate "$ENVNAME"
export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"

conda install -y numpy pyyaml setuptools cmake cffi typing tqdm coverage
conda install -y numpy pyyaml setuptools cmake cffi typing tqdm coverage tensorboard
/usr/bin/yes | pip install --upgrade google-api-python-client
/usr/bin/yes | pip install --upgrade oauth2client
/usr/bin/yes | pip install lark-parser
Expand Down
7 changes: 6 additions & 1 deletion test/test_train_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@

from common_utils import TestCase, run_tests
import os
from statistics import mean
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms as transforms
import torch_xla
Expand Down Expand Up @@ -164,10 +166,13 @@ def test_loop_fn(model, loader, device, context):
return correct / total_samples

accuracy = 0.0
writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
for epoch in range(1, FLAGS.num_epochs + 1):
model_parallel(train_loop_fn, train_loader)
accuracies = model_parallel(test_loop_fn, test_loader)
accuracy = sum(accuracies) / len(accuracies)
accuracy = mean(accuracies)
print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy, epoch)
if FLAGS.metrics_debug:
print(torch_xla._XLAC._xla_metrics_report())

Expand Down
8 changes: 8 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
import sys


# summary_writer should be an instance of torch.utils.tensorborad.SummaryWriter
# or None. If None, no summary files will be written.
def add_scalar_to_summary(summary_writer, metric_name, metric_value,
global_step):
if summary_writer is not None:
summary_writer.add_scalar(metric_name, metric_value, global_step)


def parse_common_options(datadir=None,
logdir=None,
num_cores=None,
Expand Down