Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/apps/lightning_classy_vision/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def trainer(
image=image,
resource=named_resources[resource]
if resource
else torchx.Resource(cpu=1, gpu=0, memMB=1024),
else torchx.Resource(cpu=1, gpu=0, memMB=1500),
)
],
)
Expand Down
19 changes: 15 additions & 4 deletions examples/apps/lightning_classy_vision/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import os.path
import subprocess
from typing import Tuple
from typing import Tuple, Optional, List

import fsspec
import pytorch_lightning as pl
Expand All @@ -29,15 +29,26 @@ class TinyImageNetModel(pl.LightningModule):
An very simple linear model for the tiny image net dataset.
"""

def __init__(self) -> None:
def __init__(self, layer_sizes: Optional[List[int]] = None) -> None:
super().__init__()
self.l1 = torch.nn.Linear(64 * 64, 4096)

# build a model with hidden layers specified by layer_sizes
if layer_sizes is None:
layer_sizes = []
dims = [64 * 64] + layer_sizes + [4096]
layers = []
for i, (a, b) in enumerate(zip(dims, dims[1:])):
if i > 0:
layers.append(torch.nn.ReLU(inplace=True))
layers.append(torch.nn.Linear(a, b))
self.seq = torch.nn.Sequential(*layers)

self.train_acc = Accuracy()
self.val_acc = Accuracy()

# pyre-fixme[14]
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.relu(self.l1(x.view(x.size(0), -1)))
return self.seq(x.view(x.size(0), -1))

# pyre-fixme[14]
def training_step(
Expand Down
55 changes: 55 additions & 0 deletions examples/apps/lightning_classy_vision/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Simple Logging Profiler
===========================

This is a simple profiler that's used as part of the trainer app example. This
logs the Lightning training stage durations a logger such as Tensorboard. This
output is used for HPO optimization with Ax.
"""

import time
from typing import Dict

from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.profiler.base import BaseProfiler


class SimpleLoggingProfiler(BaseProfiler):
"""
This profiler records the duration of actions (in seconds) and reports the
mean duration of each action to the specified logger. Reported metrics are
in the format `duration_<event>`.
"""

def __init__(self, logger: LightningLoggerBase) -> None:
super().__init__()

self.current_actions: Dict[str, float] = {}
self.logger = logger

def start(self, action_name: str) -> None:
if action_name in self.current_actions:
raise ValueError(
f"Attempted to start {action_name} which has already started."
)
self.current_actions[action_name] = time.monotonic()

def stop(self, action_name: str) -> None:
end_time = time.monotonic()
if action_name not in self.current_actions:
raise ValueError(
f"Attempting to stop recording an action ({action_name}) which was never started."
)
start_time = self.current_actions.pop(action_name)
duration = end_time - start_time
self.logger.log_metrics({"duration_" + action_name: duration})

def summary(self) -> str:
return ""
Empty file.
31 changes: 31 additions & 0 deletions examples/apps/lightning_classy_vision/test/model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from examples.apps.lightning_classy_vision.model import (
TinyImageNetModel,
)


class ModelTest(unittest.TestCase):
def test_basic(self) -> None:
model = TinyImageNetModel()
self.assertEqual(len(model.seq), 1)
out = model(torch.zeros((1, 64, 64)))
self.assertIsNotNone(out)

def test_layer_sizes(self) -> None:
model = TinyImageNetModel(
layer_sizes=[
10,
15,
],
)
self.assertEqual(len(model.seq), 5)
out = model(torch.zeros((1, 64, 64)))
self.assertIsNotNone(out)
12 changes: 11 additions & 1 deletion examples/apps/lightning_classy_vision/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
TinyImageNetModel,
export_inference_model,
)
from examples.apps.lightning_classy_vision.profiler import (
SimpleLoggingProfiler,
)


def parse_args(argv: List[str]) -> argparse.Namespace:
Expand Down Expand Up @@ -75,6 +78,12 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
help="path to place the tensorboard logs",
default="/tmp",
)
parser.add_argument(
"--layers",
nargs="+",
type=int,
help="the MLP hidden layers and sizes, used for neural architecture search",
)
return parser.parse_args(argv)


Expand All @@ -83,7 +92,7 @@ def main(argv: List[str]) -> None:
args = parse_args(argv)

# Init our model
model = TinyImageNetModel()
model = TinyImageNetModel(args.layers)

# Download and setup the data module
if args.test:
Expand Down Expand Up @@ -124,6 +133,7 @@ def main(argv: List[str]) -> None:
logger=logger,
max_epochs=args.epochs,
callbacks=[checkpoint_callback],
profiler=SimpleLoggingProfiler(logger),
)

# Train the model ⚡
Expand Down
2 changes: 1 addition & 1 deletion scripts/kfpint.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def run_pipeline(build: BuildInfo, pipeline_file: str) -> object:
experiment_name="integration-tests",
run_name=f"integration test {build.id} - {os.path.basename(pipeline_file)}",
)
ui_url = f"{HOST}/_/pipeline/#/runs/details/{resp.run_id}"
ui_url = f"{HOST}/#/runs/details/{resp.run_id}"
print(f"{resp.run_id} - launched! view run at {ui_url}")
return resp

Expand Down