Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import logging
import os
import re
import socket
import tempfile
import subprocess
import time

from .version import __version__

logging.basicConfig()
logger = logging.getLogger(__name__)
Expand All @@ -11,6 +15,53 @@
# Anchored pattern for a full command line of the form
# `python3 -m <XRT_RUN_SERVER_PROCESS> <digits>`, i.e. the server module
# launched with a single all-numeric argument.
XRT_SERVER_REGEX = '^python3 -m {} [0-9]+$'.format(XRT_RUN_SERVER_PROCESS)


def _maybe_select_tpu_version():
  """Set up the correct TPU runtime version for Colab and Kaggle.

  Does nothing unless the `TPU_NAME` environment variable starts with
  `grpc://`. Otherwise asks `cloud_tpu_client` to configure the TPU with
  runtime `pytorch-<__version__>` and then polls the TPU address until it
  accepts TCP connections.

  If `cloud_tpu_client` is not installed, a warning is logged and the
  function returns. Any other failure in the configure/wait sequence
  (including a timeout of the first wait) falls back to one more silent wait.

  Raises:
    RuntimeError: if the fallback wait also times out.
  """

  def _is_open(ip, port):
    # Probe with a plain TCP connect; connect_ex reports connect() failures
    # through its return code instead of raising.
    # The `with` block closes the probe socket on every path. This helper is
    # called once per polling iteration, so an unclosed socket would leak a
    # file descriptor per attempt.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
      return probe.connect_ex((ip, int(port))) == 0

  def _wait_for_open(version, timeout=100, interval=10, log=True):
    # TPU_NAME has the form grpc://<ip>:<port>; keep only "<ip>:<port>".
    tpu_addr = os.environ['TPU_NAME'].split('grpc://')[1]
    host_and_port = tpu_addr.split(':')
    deadline = time.time() + timeout

    while not _is_open(*host_and_port):
      if log:
        logging.warning(
            f'Waiting for TPU to be start up with version pytorch-{version}...')
      # The deadline is only checked after a failed probe, so at least one
      # connection attempt is always made.
      if time.time() > deadline:
        raise RuntimeError('Timed out waiting for TPU to start up')
      time.sleep(interval)

    if log:
      logging.warning(
          f'TPU has started up successfully with version pytorch-{version}')

  try:
    tpu_name = os.environ.get('TPU_NAME', '')
    if not tpu_name.startswith('grpc://'):
      # Not colab/kaggle
      return

    # Imported lazily so that the package is only required when a grpc://
    # TPU is actually configured.
    import cloud_tpu_client
    client = cloud_tpu_client.Client(tpu_name)
    client.configure_tpu_version(
        f'pytorch-{__version__}', restart_type='ifNeeded')
    # client.wait_for_healthy() API doesn't work as we don't have TPU API access
    _wait_for_open(__version__)
  except ImportError:
    logging.warning((
        'Not selecting corresponding TPU runtime since cloud_tpu_client is not '
        'installed. Ignore if not running on Colab/Kaggle TPU.'))
  except Exception:
    # This path is hit when we get throttled by the version changer
    # when we import torch_xla from xmp.spawn-ed processes.
    _wait_for_open(__version__, log=False)


def server_is_alive():
# pgrep returns 0 when at least one running process matches the requested name.
# Otherwise, the exit code is 1. If pgrep is not available in the system, it
Expand Down Expand Up @@ -92,6 +143,7 @@ def _tpu_vm_init():


# These need to be called before the _XLAC module is loaded.
# TPU runtime version selection runs first, ahead of the remaining setup
# steps below.
_maybe_select_tpu_version()
_setup_default_env()
_setup_grpc()
_setup_xla_flags()
Expand Down