Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/python/tensorflow_cloud/core/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@
from __future__ import division
from __future__ import print_function

import io
import logging
import os
import sys
import tempfile

from . import machine_config

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

try:
from nbconvert import PythonExporter # pylint: disable=g-import-not-at-top
except ImportError:
Expand Down Expand Up @@ -164,7 +169,13 @@ def get_preprocessed_entry_point(
'exec(open("{}").read())\n'.format(entry_point_file_name))
else:
if called_from_notebook:
py_content = _get_colab_notebook_content()
# Kaggle integration
if os.getenv("KAGGLE_CONTAINER_NAME"):
logger.info("Preprocessing Kaggle notebook...")
py_content = _get_kaggle_notebook_content()
else:
# Colab integration
py_content = _get_colab_notebook_content()
else:
if PythonExporter is None:
raise RuntimeError(
Expand Down Expand Up @@ -212,6 +223,25 @@ def _get_colab_notebook_content():
return py_content


def _get_kaggle_notebook_content():
"""Returns the kaggle notebook python code contents."""
if PythonExporter is None:
raise RuntimeError(
# This should never occur.
# `nbconvert` is always installed on Kaggle.
"Please make sure you have installed `nbconvert` package."
)
from kaggle_session import UserSessionClient # pylint: disable=g-import-not-at-top # pytype: disable=import-error
kaggle_session_client = UserSessionClient()
try:
response = kaggle_session_client.get_exportable_ipynb()
ipynb_stream = io.StringIO(response["source"])
py_content, _ = PythonExporter().from_file(ipynb_stream)
return py_content.splitlines(keepends=True)
except:
raise RuntimeError("Unable to get the notebook contents.")


def get_tpu_cluster_resolver_fn():
"""Returns the fn required for runnning custom container on cloud TPUs.

Expand Down