Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 177 additions & 16 deletions torch_xla_py/xla_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@
from __future__ import division
from __future__ import print_function

import requests

try:
from googleapiclient import discovery
from oauth2client.client import GoogleCredentials
except ImportError:
raise ImportError('googleapiclient and oauth2client must be installed '
'before using the xla_dist. Execute: '
'`pip install --upgrade google-api-python-client` '
'and `pip install --upgrade oauth2client` to '
'install with pip')

_GCE_METADATA_ENDPOINT = 'http://metadata.google.internal'


class Worker(object):

Expand All @@ -18,6 +32,25 @@ def __init__(self, internal_ip, machine_type, zone, hostname=None):
super(ClientWorker, self).__init__(internal_ip, machine_type, zone)
self._hostname = hostname

def __repr__(self):
return ('{{{internal_ip}, {machine_type}, {zone},'
' {hostname}}}').format(
internal_ip=self._internal_ip,
machine_type=self._machine_type,
zone=self._zone,
hostname=self._hostname)

def __eq__(self, other):
return (self._internal_ip == other._internal_ip and
self._machine_type == other._machine_type and
self._zone == other._zone and self._hostname == other._hostname)

def __ne__(self, other):
return not self.__eq__(self, other)

def __hash__(self):
return hash(repr(self))


class ServiceWorker(Worker):
# Same as base Worker ATM.
Expand All @@ -26,16 +59,18 @@ class ServiceWorker(Worker):

class Cluster(object):

def __init__(self, client_workers, service_workers,
def __init__(self,
client_workers,
service_workers,
check_client_machine_type=True,
check_service_machine_type=True):
"""Creates a cluster object.

Args:
client_workers: a list of ClientWorker objects.
service_workers: a list of ServiceWorker objects.
check_client_machine_type: whether to check if client workers all have
the same machine type.
check_client_machine_type: whether to check if client workers all have the
same machine type.
check_service_machine_type: whether to check if service workers all have
the same machine type.
"""
Expand Down Expand Up @@ -64,7 +99,7 @@ def validate(self):
"""
if len(self._client_workers) == 0 or len(self._service_workers) == 0:
raise RuntimeError(
'Both client_workers and service_workers should not be empty')
'Both client_workers and service_workers should not be empty')

if len(self._client_workers) != len(self._service_workers):
raise RuntimeError(
Expand All @@ -78,38 +113,164 @@ def validate(self):

if self._check_client_machine_type:
client_machine_types = {
worker._machine_type for worker in self._client_workers}
worker._machine_type for worker in self._client_workers
}
if len(client_machine_types) != 1:
raise RuntimeError(
'All client_workers must have the same machine_type, got: {}'.format(
client_machine_types))
'All client_workers must have the same machine_type, got: {}'
.format(client_machine_types))

if self._check_service_machine_type:
server_machine_types = {
worker._machine_type for worker in self._service_workers}
worker._machine_type for worker in self._service_workers
}
if len(server_machine_types) != 1:
raise RuntimeError(
'All service_workers must have the same machine_type, got: {}'.format(
server_machine_types))
'All service_workers must have the same machine_type, got: {}'
.format(server_machine_types))


class ClusterResolver(object):
"""Cluster Resolver for Client VM and Cloud TPU mesh."""

def __init__(self,
tpus,
vms=None,
zone=None,
project=None):
@staticmethod
def _get_instance_metadata(metadata):
response = requests.get(
'{}/computeMetadata/v1/{}'.format(_GCE_METADATA_ENDPOINT, metadata),
headers={'Metadata-Flavor': 'Google'})
return response.content.decode('utf-8')

@staticmethod
def _parse_resource_url(url, name):
parts = url.split('/')
idx = parts.index(name)
return parts[idx + 1]

def __init__(self, tpus, vms=None, zone=None, project=None):
"""Creates a new ClusterResolver object."""

assert tpus, "TPU name list must not be empty."
if not isinstance(tpus, list) or len(tpus) == 0:
raise ValueError('tpus must be a non-empty list')
if vms is not None:
if not isinstance(vms, list) or len(vms) == 0:
raise ValueError('vms must be a non-empty list if provided')

self._tpus = tpus
self._vms = vms
self._zone = zone
self._project = project

self._credentials = GoogleCredentials.get_application_default()
self._tpu_service = discovery.build(
'tpu', 'v1', credentials=self._credentials, cache_discovery=False)
self._compute_service = discovery.build(
'compute', 'v1', credentials=self._credentials, cache_discovery=False)

if project is None:
self._project = self._get_instance_metadata('project/project-id')
if zone is None:
zone_path = self._get_instance_metadata('instance/zone')
self._zone = self._parse_resource_url(zone_path, 'zones')
self._vm_master = self._get_instance_metadata('instance/name')

def _get_instance_group(self):
"""Gets the instance group that the current VM belongs to."""
resp = self._compute_service.instances().get(
project=self._project,
zone=self._zone,
instance=self._vm_master,
fields='metadata').execute()

if 'metadata' in resp and 'items' in resp['metadata']:
for item in resp['metadata']['items']:
if (item['key'] == 'created-by' and
'instanceGroupManagers' in item['value']):
return self._parse_resource_url(item['value'],
'instanceGroupManagers')

raise RuntimeError(('A vm list must be passed to ClusterResolver '
'if not using an instance group'))

def _get_member_instance_names(self, instance_group):
"""Gets all the instance names that belong to the given instance group."""
resp = self._compute_service.instanceGroups().listInstances(
project=self._project, zone=self._zone,
instanceGroup=instance_group).execute()

instances = []
for item in resp.get('items', []):
if 'instance' not in item or 'status' not in item:
continue
instance_path = item['instance']
instances.append(self._parse_resource_url(instance_path, 'instances'))

return instances

def _get_client_workers(self):
"""Gets client workers.

The instance group that the current VM belongs to is picked up from
the GCE instance metadata set of the VM. If a list of VMs was used for
initializing cluster resolver, we use that instead.

Returns:
A list of ClientWorker.

Raises:
RuntimeError: If the red VM cluster is not healthy.
"""
if self._vms is None:
# Using an instance group
instance_group = self._get_instance_group()
self._vms = self._get_member_instance_names(instance_group)
if len(self._vms) == 0:
raise RuntimeError('Client worker vms is empty in instance group')

workers = []
batch = self._compute_service.new_batch_http_request()

def add_worker(request_id, resp, exception):
"""Callback for each request in BatchHttpRequest."""
if exception is not None:
raise exception
hostname = self._parse_resource_url(resp['selfLink'], 'instances')
if resp['status'] != 'RUNNING':
raise RuntimeError(
('Instance {hostname} is not running yet. '
'Re-run when all VMs are running').format(hostname=hostname))
worker = ClientWorker(
internal_ip=resp['networkInterfaces'][0]['networkIP'],
machine_type=self._parse_resource_url(resp['machineType'],
'machineTypes'),
zone=self._parse_resource_url(resp['zone'], 'zones'),
hostname=hostname)
workers.append(worker)

for vm in self._vms:
req = self._compute_service.instances().get(
project=self._project, zone=self._zone, instance=vm,
fields=('machineType,metadata,selfLink,'
'networkInterfaces/networkIP,status,zone'))
batch.add(req, add_worker)
batch.execute()

return workers

def _get_service_workers(self):
"""Gets TPU VM cluster info.

Calls the TPU CLH to get TPU node data and returns list of TPU worker
VMs internal IP addresses. If zone and project are not specified at
ClusterResolver init time, we infer these bits from GCE metadata.

Returns:
A list of ServiceWorker.

Raises:
RuntimeError: If the TPU DNE or the TPU is in not in HEALTHY state.
"""
raise NotImplementedError()

def get_cluster(self):
"""Gets client and server side cluster info.

Expand Down
Loading