-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
tpu_controller.py
228 lines (197 loc) · 7.02 KB
/
tpu_controller.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU controller class for common TPU manipulation."""
import functools
import multiprocessing
import os
import subprocess
from typing import List, Optional, Iterable, Callable, Any, Mapping, Union
from absl import logging
from fabric import Connection
import patchwork.transfers
import tpu_api
# Private SSH key that `connect` authenticates with. If it does not exist,
# TPUController._maybe_configure_ssh_on_admin runs `gcloud compute config-ssh`.
_SSH_KEYS_PATH = os.path.expanduser("~/.ssh/google_compute_engine")
def connect(ip_address: str) -> Connection:
    """Builds a fabric Connection to a single TPU worker.

    Args:
        ip_address: host name or IP address of the worker.

    Returns:
        A `fabric.Connection` configured to authenticate with the private key
        stored at `_SSH_KEYS_PATH`.
    """
    auth_options = {"key_filename": _SSH_KEYS_PATH}
    return Connection(ip_address, connect_kwargs=auth_options)
class TPUController:
    """Generic TPU controller interface.

    Attributes:
        tpu_name: the TPU name.
        accelerator_type: the TPU generation, e.g. V4.
        accelerator_topology: the topology of the TPU. E.g. '4x4x4'
        zone: the GCP zone.
        project: the GCP project.
        version: the TPU version, e.g. 'tpu_vm_v4_base'.
        startup_script: an optional set of commands that will be concatenated
            to run on TPU VM startup.
        network: network name forwarded to `tpu_api.create_tpu`.
        subnetwork: subnetwork name forwarded to `tpu_api.create_tpu`.
        preemptible: preemptible flag forwarded to `tpu_api.create_tpu`.
        reserved: reserved flag forwarded to `tpu_api.create_tpu`.
    """

    def __init__(
        self,
        tpu_name: str,
        zone: str,
        project: str,
        accelerator_type: str,
        accelerator_topology: str,
        version: str,
        startup_script: Optional[List[str]],
        network: Optional[str] = "default",
        subnetwork: Optional[str] = "default",
        preemptible: bool = False,
        reserved: bool = False,
    ):
        self._tpu_name = tpu_name
        self._zone = zone
        self._project = project
        self._accelerator_type = accelerator_type
        self._accelerator_topology = accelerator_topology
        self._version = version
        self._startup_script = startup_script
        # Lazily-populated caches of worker IPs and their fabric connections.
        # Both are reset whenever the TPU is (re)created or deleted.
        self._ip_addresses: List[str] = []
        self._connections = {}
        self._network = network
        self._subnetwork = subnetwork
        self._preemptible = preemptible
        self._reserved = reserved

    @property
    def tpu_name(self) -> str:
        return self._tpu_name

    def tpu_exists(self) -> bool:
        """Checks if the TPU exists."""
        return tpu_api.tpu_exists(
            tpu_name=self._tpu_name, project=self._project, zone=self._zone
        )

    def get_ip_addresses(self) -> List[str]:
        """Returns the IP addresses of the workers in the cluster.

        The list is cached after the first non-empty lookup. Endpoints that have
        no "ipAddress" field are skipped.

        Raises:
            KeyError: if the TPU info has no "networkEndpoints" entry.
        """
        if not self._ip_addresses:
            endpoints = self.get_tpu()["networkEndpoints"]
            self._ip_addresses.extend(
                endpoint["ipAddress"]
                for endpoint in endpoints
                if "ipAddress" in endpoint
            )
        return self._ip_addresses

    def _maybe_configure_ssh_on_admin(self) -> None:
        """Runs the bash command to generate necessary SSH keys on the admin VM.

        Does nothing if the key at `_SSH_KEYS_PATH` already exists.
        """
        if not os.path.exists(_SSH_KEYS_PATH):
            subprocess.check_output("gcloud compute config-ssh", shell=True)

    def get_connections(self) -> Mapping[str, Connection]:
        """Returns the mapping between IP and fabric.Connection.

        Connections are built once, after making sure the SSH key exists, and
        then cached until the TPU is recreated or deleted.
        """
        if not self._connections:
            self._maybe_configure_ssh_on_admin()
            for ip_address in self.get_ip_addresses():
                self._connections[ip_address] = connect(ip_address)
        return self._connections

    def _reset_worker_cache(self) -> None:
        """Closes and forgets cached connections and worker IP addresses.

        Without this, a recreated or deleted TPU would keep being addressed
        through the previous workers' IPs and connections.
        """
        for connection in self._connections.values():
            # fabric's Connection.close() is a no-op if the connection never opened.
            connection.close()
        self._connections.clear()
        self._ip_addresses.clear()

    def create_tpu(self):
        """Creates the TPU and drops any cached state about previous workers."""
        tpu_api.create_tpu(
            tpu_name=self._tpu_name,
            zone=self._zone,
            project=self._project,
            accelerator_type=self._accelerator_type,
            accelerator_topology=self._accelerator_topology,
            version=self._version,
            startup_script=self._startup_script,
            network=self._network,
            subnetwork=self._subnetwork,
            preemptible=self._preemptible,
            reserved=self._reserved,
        )
        self._reset_worker_cache()

    def maybe_create_tpu(self) -> bool:
        """Creates the TPU if it doesn't exist.

        Returns:
            True if the TPU needed to be created, False otherwise.
        """
        if not self.tpu_exists():
            self.create_tpu()
            return True
        return False

    def delete_tpu(self):
        """Deletes the TPU and drops cached worker IPs and connections."""
        tpu_api.delete_tpu(
            tpu_name=self._tpu_name, project=self._project, zone=self._zone
        )
        self._reset_worker_cache()

    def get_tpu(self):
        """Gets the TPU info."""
        return tpu_api.get_tpu(
            tpu_name=self._tpu_name, project=self._project, zone=self._zone
        )

    def get_health(self):
        """Returns the "health" entry of the TPU info from `get_tpu()`."""
        return self.get_tpu()["health"]

    def get_state(self):
        """Returns the "state" entry of the TPU info from `get_tpu()`."""
        return self.get_tpu()["state"]

    @staticmethod
    def _as_list(items: Union[str, Iterable[str]]) -> List[str]:
        """Wraps a single string in a list; materializes any other iterable.

        A bare `str` is itself an `Iterable[str]`. Without this, it would be
        processed one character at a time. Materializing also makes generator
        inputs picklable for `multiprocessing.Pool`.
        """
        if isinstance(items, str):
            return [items]
        return list(items)

    def _run_on_worker(
        self, ip_address: str, commands: Iterable[str], verbose: bool = True
    ):
        """Runs command(s) sequentially on a single worker.

        Commands starting with "sudo " have that prefix stripped and are run
        through `Connection.sudo`. All other commands go through
        `Connection.run`.

        Args:
            ip_address: the worker to run on; must be a key of get_connections().
            commands: commands to run, in order.
            verbose: whether to log each command's stdout.
        """
        sudo_prefix = "sudo "
        connection = self.get_connections()[ip_address]
        for command in self._as_list(commands):
            logging.info("Running %s on %s", command, ip_address)
            if command.startswith(sudo_prefix):
                # Require the trailing space so commands that merely begin with
                # the letters "sudo" (e.g. "sudoku") are not truncated.
                output = connection.sudo(command[len(sudo_prefix):].lstrip())
            else:
                output = connection.run(command)
            if verbose:
                logging.info("%s: %s", ip_address, output.stdout)

    def _run_per_worker(self, fn: Callable[..., Any]):
        """Calls `fn(ip_address)` for every worker, one process per worker.

        `fn` is pickled into the pool's child processes, so any connections the
        children create are not shared back with this process. Exceptions
        raised in a child are re-raised here by `Pool.map`.

        Raises:
            ValueError: if the TPU currently reports no worker IP addresses.
        """
        ip_addresses = self.get_ip_addresses()
        if not ip_addresses:
            # multiprocessing.Pool would otherwise fail with an opaque
            # "Number of processes must be at least 1".
            raise ValueError(
                f"TPU {self._tpu_name!r} has no worker IP addresses."
            )
        with multiprocessing.Pool(processes=len(ip_addresses)) as pool:
            pool.map(fn, ip_addresses)

    def run_commands_on_workers(
        self, commands: Union[str, Iterable[str]], verbose: bool = True
    ):
        """Runs a list of commands for all workers.

        Args:
            commands: a single command or an iterable of commands.
            verbose: whether each worker logs the stdout of its commands.
        """
        self._run_per_worker(
            functools.partial(
                self._run_on_worker,
                commands=self._as_list(commands),
                verbose=verbose,
            )
        )

    def _copy_files_to_worker(
        self, ip_address: str, files: Union[str, Iterable[str]]
    ):
        """Copies files to a single worker.

        Directories are rsync'ed into the remote home directory (excluding
        .git). Plain files are uploaded with `Connection.put`.
        """
        connection = self.get_connections()[ip_address]
        for file in self._as_list(files):
            if os.path.isdir(file):
                patchwork.transfers.rsync(
                    connection, file, "~/", exclude=".git", strict_host_keys=False
                )
            else:
                connection.put(file)

    def copy_files_to_workers(self, files: Union[str, Iterable[str]]):
        """Copies files to all workers."""
        self._run_per_worker(
            functools.partial(self._copy_files_to_worker, files=self._as_list(files))
        )

    def _get_files_from_worker(
        self, ip_address: str, files: Union[str, Iterable[str]]
    ):
        """Gets files from a single worker via `Connection.get`."""
        connection = self.get_connections()[ip_address]
        for file in self._as_list(files):
            # NOTE(review): no local path is given. Workers fetching the same
            # remote path concurrently presumably write to the same local
            # file, so the last one wins. Confirm whether per-worker
            # destinations are needed.
            connection.get(file)

    def get_files_from_workers(self, files: Union[str, Iterable[str]]):
        """Gets files from all workers."""
        self._run_per_worker(
            functools.partial(self._get_files_from_worker, files=self._as_list(files))
        )

    def get_num_nodes(self):
        """Returns the number of hosts in the TPU pod."""
        return len(self.get_ip_addresses())