# Copy Data from GCP to Databricks DBFS

In [0]:
%pip install --upgrade google-cloud-storage
dbutils.library.restartPython()

In [0]:
import json
from google.cloud import storage
from pyspark.dbutils import DBUtils

from time import time
from pathlib import Path

In [0]:
dbutils = DBUtils(spark)

# Google Cloud credentials and bucket details
json_key_file_path = "../../dbfs/FileStore/shared_uploads/bubbly_sandbox_406100_b00687afd8a0.json"
gcs_bucket_name = "data516project"
gcs_blob_name = "data516project/Military Aircraft Detection/data/dataset"  # File path in GCS

# Authenticate and initialize GCS client
client = storage.Client.from_service_account_json(json_key_file_path)
bucket = client.get_bucket(gcs_bucket_name)

In [0]:
# Download the file from GCS and log time
time_start = time()
blob = bucket.blob(gcs_blob_name)

blobs = list(bucket.list_blobs(prefix='Military Aircraft Detection/data/dataset'))

for idx, blob in enumerate(blobs):
    if idx % 2000 == 0:
        print(f"{idx} files downloaded in {time()-time_start}s")
    if(not blob.name.endswith("/")):
        if(not blob.name.endswith(".DS_Store")):
            if(not blob.name.endswith(".csv")):
                blob.download_to_filename('../../dbfs/FileStore/shared_uploads/data/' + blob.name.split('/')[-1])

print(f'Download Process took {time()-time_start}s for approximately {len(blobs)} files')

0 files downloaded in 3.4030089378356934s
2000 files downloaded in 303.46774554252625s
4000 files downloaded in 597.6927881240845s
6000 files downloaded in 889.1370847225189s
8000 files downloaded in 1182.4302940368652s
10000 files downloaded in 1478.9855275154114s
12000 files downloaded in 1778.523443698883s
14000 files downloaded in 2081.5568969249725s
16000 files downloaded in 2371.435096025467s
18000 files downloaded in 2730.791647672653s
20000 files downloaded in 3084.4676949977875s
22000 files downloaded in 3368.206021308899s
24000 files downloaded in 3651.8589203357697s
Download Process took 3790.761983394623s for approximately 24974 files


# Mount GCP Storage 

In [0]:
import json
from pyspark.dbutils import DBUtils

dbutils = DBUtils(spark)

# Path to the JSON key file in DBFS
json_key_file_path = "../../dbfs/FileStore/shared_uploads/bubbly_sandbox_406100_b00687afd8a0.json"

# Load the JSON key file
with open(json_key_file_path) as f:
    gcs_key = json.load(f)

# Name of your GCS bucket and the mount point in Databricks
gcs_bucket_name = "data516project"
databricks_mount_point = "/mnt/"

# Mount the GCS bucket
dbutils.fs.mount(
    source = f"gs://{gcs_bucket_name}",
    mount_point = databricks_mount_point,
    extra_configs = {
                    # "google.cloud.auth.service.account.json.keyfile": json_key_file_path,
                     "fs.gs.project.id": gcs_key['project_id']}
)

True

# Inference Using Local Data

In [0]:
%pip install torch efficientnet_pytorch torchvision pillow

[43mNote: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.[0m
Collecting torch
  Downloading torch-2.1.1-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 670.2/670.2 MB 991.4 kB/s eta 0:00:00
Collecting efficientnet_pytorch
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting torchvision
  Downloading torchvision-0.16.1-cp310-cp310-manylinux1_x86_64.whl (6.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.8/6.8 MB 38.9 MB/s eta 0:00:00
Collecting nvidia-cuda-nvrtc-cu12==12.1.105
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 29.7 MB/s eta 0:00:00
Collecting fsspec
  Downloading fsspec-2023.12.2-py3-none-any.whl (168 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 169.0/169.0 kB 12.9

In [0]:
import torch
from efficientnet_pytorch import EfficientNet
from torchvision import transforms
from PIL import Image
import time
import os

# Broadcast the model to all nodes
model = EfficientNet.from_pretrained('efficientnet-b1')
model.eval()
broadcast_model = sc.broadcast(model)

# Image preprocessing transformation
transform = transforms.Compose([
    transforms.Resize((240, 240)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Function to run inference on a single image
def predict_image(image_path):
    # Load the broadcasted model
    model = broadcast_model.value
    try:
        image = Image.open(image_path).convert('RGB')
        # Add batch dimension
        image = transform(image).unsqueeze(0) 
    except Exception as e:
        print(e)
        return 0

    with torch.no_grad():
        outputs = model(image)
    return torch.argmax(outputs, dim=1)

# List of image paths (modify this to point to your images)
image_paths = list()

for file in os.listdir('../../dbfs/FileStore/shared_uploads/data/'):
    if file.endswith('.csv'):
        continue
    image_paths.append('dbfs:/FileStore/shared_uploads/data/' + file)

dir_lists = [
            # '../../dbfs/FileStore/shared_uploads/data2/data', 
            #  '../../dbfs/FileStore/shared_uploads/data3/data',
            #  '../../dbfs/FileStore/shared_uploads/data4/data', 
            #  '../../dbfs/FileStore/shared_uploads/data5/data', 
            #  '../../dbfs/FileStore/shared_uploads/data6/data', 
            #  '../../dbfs/FileStore/shared_uploads/data7/data',
            #  '../../dbfs/FileStore/shared_uploads/data8/', 
            #  '../../dbfs/FileStore/shared_uploads/data9/data',
            #  '../../dbfs/FileStore/shared_uploads/data10/data'
            ]
for path in dir_lists:
    for file in os.listdir(path):
        if file.endswith('.csv'):
            continue
        image_paths.append(path + file)



Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b1-f1951068.pth
0.0%0.1%0.1%0.1%0.1%0.2%0.2%0.2%0.2%0.3%0.3%0.3%0.3%0.4%0.4%0.4%0.4%0.5%0.5%0.5%0.5%0.6%0.6%0.6%0.6%0.7%0.7%0.7%0.8%0.8%0.8%0.8%0.9%0.9%0.9%0.9%1.0%1.0%1.0%1.0%1.1%1.1%1.1%1.1%1.2%1.2%1.2%1.2%1.3%1.3%1.3%1.4%1.4%1.4%1.4%1.5%1.5%1.5%1.5%1.6%1.6%1.6%1.6%1.7%1.7%1.7%1.7%1.8%1.8%1.8%1.8%1.9%1.9%1.9%1.9%2.0%2.0%2.0%2.1%2.1%2.1%2.1%2.2%2.2%2.2%2.2%2.3%2.3%2.3%2.3%2.4%2.4%2.4%2.4%2.5%2.5%2.5%2.5%2.6%2.6%2.6%2.7%2.7%2.7%2.7%2.8%2.8%2.8%2.8%2.9%2.9%2.9%2.9%3.0%3.0%3.0%3.0%3.1%3.1%3.1%3.1%3.2%3.2%3.2%3.2%3.3%3.3%3.3%3.4%3.4%3.4%3.4%3.5%3.5%3.5%3.5%3.6%3.6%3.6%3.6%3.7%3.7%3.7%3.7%3.8%3.8%3.8%3.8%3.9%3.9%3.9%4.0%4.0%4.0%4.0%4.1%4.1%4.1%4.1%4.2%4.2%4.2%4.2%4

Loaded pretrained weights for efficientnet-b1


In [0]:
image_paths = [x.replace('/dbfs', '/dbfs/') for x in image_paths]
len(image_paths)

12487

In [0]:
# Start timing
start_time_1 = time.time()

# Parallelize the image paths list with Spark
# The paths are already in the correct format for Spark executors
rdd = sc.parallelize(image_paths)

start_time_2 = time.time()
# Run inference in parallel
predictions = rdd.map(predict_image).collect()

# End timing
end_time = time.time()

# Benchmark results
total_time = end_time - start_time_1
sub_time = end_time - start_time_2
avg_time_per_image = total_time / len(predictions)

print('Number of predictions', len(predictions))
print(f"Total inference time: {total_time} seconds")
print(f"Total inference time: {sub_time} seconds")
print(f"Average time per image: {avg_time_per_image:} seconds")

Number of predictions 12487
Total inference time: 0.6807379722595215 seconds
Total inference time: 0.6461248397827148 seconds
Average time per image: 5.4515734144271764e-05 seconds


## Inference and RDD creation


### Inference Only
- Total inference time: 62.35115885734558 seconds
- Average time per image: 0.0049932857257424185 seconds

- Total inference time: 54.9078311920166 seconds
- Average time per image: 0.004397199582927573 seconds

- Total inference time: 15.035117149353027 seconds
- Average time per image: 0.0012040615960080905 seconds

- Total inference time: 10.614421606063843 seconds
- Average time per image: 0.0008500377677635815 seconds

- Total inference time: 2.1653757095336914 seconds
- Average time per image: 0.00017341040358242102 seconds

- Total inference time: 16.094763040542603 seconds
- Average time per image: 0.001288921521625899 seconds

# Inference using Google Cloud Storage Python API

In [0]:
%pip install torch efficientnet_pytorch torchvision pillow

[43mNote: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.[0m
Collecting torch
  Downloading torch-2.1.1-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 670.2/670.2 MB 1.5 MB/s eta 0:00:00
Collecting efficientnet_pytorch
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting torchvision
  Downloading torchvision-0.16.1-cp310-cp310-manylinux1_x86_64.whl (6.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.8/6.8 MB 62.7 MB/s eta 0:00:00
Collecting nvidia-cuda-nvrtc-cu12==12.1.105
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 41.9 MB/s eta 0:00:00
Collecting nvidia-cufft-cu12==11.0.2.54
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
     ━━━━━━━━━

In [0]:
%pip install google-cloud-storage
dbutils.library.restartPython()

[43mNote: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.[0m
Collecting google-cloud-storage
  Downloading google_cloud_storage-2.13.0-py2.py3-none-any.whl (121 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.1/121.1 kB 1.6 MB/s eta 0:00:00
Collecting google-resumable-media>=2.6.0
  Downloading google_resumable_media-2.6.0-py2.py3-none-any.whl (80 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 80.3/80.3 kB 12.9 MB/s eta 0:00:00
Collecting google-auth<3.0dev,>=2.23.3
  Downloading google_auth-2.25.1-py2.py3-none-any.whl (184 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 184.2/184.2 kB 20.2 MB/s eta 0:00:00
Collecting google-crc32c<2.0dev,>=1.0
  Downloading google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (32 kB)
Collecting google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5
  Downloading google_api_core-2.14.0-py3-none-any.whl (122 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1

In [0]:
from google.cloud import storage
import torch
from efficientnet_pytorch import EfficientNet
from torchvision import transforms
from PIL import Image
import time
import io
# Allow Image.Io to load truncated files as well.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Initialize the GCS client
json_key_file_path = "../../dbfs/FileStore/shared_uploads/bubbly_sandbox_406100_b00687afd8a0.json"
storage_client = storage.Client.from_service_account_json(json_key_file_path)

# Bucket and folder details
bucket_name = 'data516project'
folder_name = 'Military Aircraft Detection/data/dataset'

# Broadcast the model to all nodes
model = EfficientNet.from_pretrained('efficientnet-b1')
model.eval()
broadcast_model = sc.broadcast(model)

# Image preprocessing transformation
transform = transforms.Compose([
    transforms.Resize((240, 240)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Function to run inference on a single image
def predict_image(image_blob):
    # Last version failed on loading a truncated bad image:
    # counter with try:except
    try:
        # Load the broadcasted model
        model = broadcast_model.value
        # Download the image from GCS
        storage_client = storage.Client.from_service_account_info(
        {
            "type": "service_account",
            "project_id": "bubbly-sandbox-406100",
            "private_key_id": "b00687afd8a0557d2daa2281fdf8a9ca4667b171",
            "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCY/29n5UWYZUIM\nmy9t6hrSBXcPXbptIUTO8aitE6zbSX6G7OicZe1zxQnSpkqhM/ZoiKCCrL1GmnRf\netkS61Crvx/mbur+5MZcqOnwWfnkazbu+udFwWC0TkuRnjheelgKjPG7Ial1O6iz\nFgnP77Tj3pF0A0w5R4tZAfnTVe1HRunx3iKkAFOfsSfVZ9G445+KvgfonWtlsJ09\nryaVt1GFxRwsZFzoFS+RaJJxZ8UGAr81uj4zbvtlfet2jABPnFfFlWTISVNCNNtj\niAMDX1e6Q8Iq3AYk/SEZy5taukDqQ8A454eMcBeLF252LUX8Z3YwwWTNZ8j9zkCO\nKEjwpwBLAgMBAAECggEALS2AIBW7bzy3PDG216YBNXMXrRnoKsFgKBuDy1e7+yTA\nj+b06PnQHBdaxzEDaIaHhGaTvwJzAZ4qLud/oe6Ka6yMJucpzQqQ6FuYeLiY6bIf\nDoJYOGjZvWRBPEaULpHB3Zuw5akS+OMc9YTUcOo/Z+oY2UEACQnvQ+EpfBYrcpgw\n9Vcq6WQmkxUBjn7Ecg+XGIi0LBniwcBPF96vXrw5WgjLFB7vP4Qfhp7R3CByuoFU\nnzqumaD8W4j7w3k+O6SqzSJ6YS1Vy7pz5iNI8sDENxZUHh7G1qCLlAv8Vbf82l1K\nnuv+VYvcrR75wI6DsI1nkWIzlz4SWxoRXhbwkaDMwQKBgQDSKyYpRU20lN11EdXG\n+JLwJEil3SP4iGVRwvDMTUAqMKnMbJVitR8ZDSoZaFPdKFhUCdWFtiYsvMCDejsJ\nr7aMJ+AcIlzp/xav9gUp26PGME9GFZCRQAYRmVJRtw2uZnD1MgMXgfOh5/suGFH5\nSEZ2ThgeBtKsdZQ4d72F4q6ldwKBgQC6XK0tHOJ006SLCCSukDcqRBEZzlS9vj7m\n1XfhdeyfXYtsJvcKupbayvWn/AfAJBgRogH1SLFkjw8Cb9CtHV9OttNYeV+RCi27\nPcsJHyNLYFSkBxNNPi/ciZSxOScxpoHEnxex87QzsVlUX90l+4yFI9QSQ/e5CZjL\njScnsLeAzQKBgAwPn9wEdyYi2OasBFmr0DrpyiWCoFrV8QKVLl57HHCPZF8v6A7H\nGCbG46CbpRxvAUqpWdCXmG8+0cl5zUOMCuzWKP1UV4KyOeoVM0yopyhrg1EhUa+U\nPgTqiJfZ6nYUuOzJVvYVZhbnijml+aZVqH2Hk8i9WX++K6hl7sFtNdqXAoGAFCcX\ndwxrlLSaU4qS5OmlfYWUz9SXJ1meVbn518C30Nf9zrk798YLsCdIf1zqvaVpkOaT\nDaJJxJUPy3Sp/T5y4wE0Q5oEcpvGLDxXa2oZl/oXzRESzkSOSmv6T1vd2CmoWMnX\ndx/iecCQ52XidUwJexBPy5XcAP9/rWtbM5iRj3UCgYAHOP36xNOxx+yNhxxhT1PB\n4kUyJ+lXsYeWo4Fy8eSZzTVL9RKpuETfCENLK3yElx6NuEP718epKWqQrMjdouEj\n7pC9j/Rcrm1FhBp+8YkNedIu8k13lBBtU47AejIMLLYRI6O1UhTa2nM5xEFEwmSY\nLBXE20dl2N6diASM1QOVvA==\n-----END PRIVATE KEY-----\n",
            "client_email": "452869396641-compute@developer.gserviceaccount.com",
            "client_id": "111941241038080001862",
            "auth_uri": "https://accounts.google.com/o/oauth2/auth",
            "token_uri": "https://oauth2.googleapis.com/token",
            "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
            "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/452869396641-compute%40developer.gserviceaccount.com",
            "universe_domain": "googleapis.com"
        })
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(image_blob)
        image_bytes = blob.download_as_bytes()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        # Add batch dimension
        image = transform(image).unsqueeze(0)

        with torch.no_grad():
            outputs = model(image)
        return torch.argmax(outputs, dim=1)
    except Exception as e:
        if "image file is truncated" in str(e):
            # Handle the truncated image file exception
            print("Error: Image file is truncated.")
        else:
            # Handle other exceptions
            print("An unexpected error occurred:", e)
        print(f"unable to predict {image_blob}")
        return -1
# Get the list of image blobs in the specified folder
bucket = storage_client.bucket(bucket_name)
broadcast_bucket = sc.broadcast(bucket)
blobs = bucket.list_blobs(prefix=folder_name)
image_blobs = [blob.name for blob in blobs if not blob.name.endswith('.csv')]

# # Parallelize the image blob names list with Spark
rdd = sc.parallelize(image_blobs)

# Start timing
start_time = time.time()

# Run inference in parallel
predictions = rdd.map(predict_image).collect()

# End timing
end_time = time.time()

# Benchmark results
total_time = end_time - start_time
avg_time_per_image = total_time / len(predictions)

print(f"Total inference time: {total_time:.3f} seconds")
print(f"Average time per image: {avg_time_per_image:.3f} seconds")


Loaded pretrained weights for efficientnet-b1


[0;31m---------------------------------------------------------------------------[0m
[0;31mPicklingError[0m                             Traceback (most recent call last)
File [0;32m<command-565462421550211>, line 75[0m
[1;32m     73[0m [38;5;66;03m# Get the list of image blobs in the specified folder[39;00m
[1;32m     74[0m bucket [38;5;241m=[39m storage_client[38;5;241m.[39mbucket(bucket_name)
[0;32m---> 75[0m broadcast_bucket [38;5;241m=[39m sc[38;5;241m.[39mbroadcast(bucket)
[1;32m     76[0m blobs [38;5;241m=[39m bucket[38;5;241m.[39mlist_blobs(prefix[38;5;241m=[39mfolder_name)
[1;32m     77[0m image_blobs [38;5;241m=[39m [blob[38;5;241m.[39mname [38;5;28;01mfor[39;00m blob [38;5;129;01min[39;00m blobs [38;5;28;01mif[39;00m [38;5;129;01mnot[39;00m blob[38;5;241m.[39mname[38;5;241m.[39mendswith([38;5;124m'[39m[38;5;124m.csv[39m[38;5;124m'[39m)]

File [0;32m/databricks/spark/python/pyspark/context.py:1856[0m, in [0;36mSparkConte

In [0]:
bucket.blob(image_blobs[1])

<Blob: data516project, Military Aircraft Detection/data/dataset/00032844ab679240fc03ecd27d29a6aa.jpg, None>

In [0]:
type(image_blobs)

list

[<Blob: data516project, Military Aircraft Detection/data/dataset/000106393cfe2343888c584e65fd2274.jpg, None>,
 <Blob: data516project, Military Aircraft Detection/data/dataset/00032844ab679240fc03ecd27d29a6aa.jpg, None>,
 <Blob: data516project, Military Aircraft Detection/data/dataset/0003f56298fa8999168d7988a2e9549d.jpg, None>,
 <Blob: data516project, Military Aircraft Detection/data/dataset/000aa01b25574f28b654718db0700f72.jpg, None>,
 <Blob: data516project, Military Aircraft Detection/data/dataset/000e7662268a1071827c5a8663e773f9.jpg, None>,
 <Blob: data516project, Military Aircraft Detection/data/dataset/000ec980b5b17156a55093b4bd6004ab.jpg, None>,
 <Blob: data516project, Military Aircraft Detection/data/dataset/0036c2784a5b8b2a4a1a4bd9109eb2f7.jpg, None>,
 <Blob: data516project, Military Aircraft Detection/data/dataset/0039eb1ff33d29c55f943e05730bb259.jpg, None>,
 <Blob: data516project, Military Aircraft Detection/data/dataset/0041e69431bf872309d1aff628b6494f.jpg, None>,
 <Blob: da

### Run Times:
Total inference time: 2624.192 seconds
Average time per image: 0.210 seconds

Total inference time: 2800.103 seconds
Average time per image: 0.224 seconds

Total inference time: 2589.346 seconds
Average time per image: 0.207 seconds

# Distributed Training

In [0]:
%pip install horovod[tensorflow]

[43mNote: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.[0m
Collecting horovod[tensorflow]
  Using cached horovod-0.28.1-cp310-cp310-linux_x86_64.whl
Collecting cloudpickle
  Using cached cloudpickle-3.0.0-py3-none-any.whl (20 kB)
Collecting pyyaml
  Using cached PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (705 kB)
Collecting tensorflow
  Using cached tensorflow-2.15.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (475.2 MB)
Collecting h5py>=2.9.0
  Using cached h5py-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.8 MB)
Collecting termcolor>=1.1.0
  Using cached termcolor-2.4.0-py3-none-any.whl (7.7 kB)
Collecting wrapt<1.15,>=1.11.0
  Using cached wrapt-1.14.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (77 kB)
Collecting google-pasta>=0.1.1
  Using cached google_pasta-0.2.0-py3-none-any.whl (57 kB)
Collecting pr

In [0]:
import horovod.tensorflow.keras as hvd
import tensorflow as tf

hvd.init()

[0;31m---------------------------------------------------------------------------[0m
[0;31mImportError[0m                               Traceback (most recent call last)
File [0;32m<command-3171562424461272>, line 1[0m
[0;32m----> 1[0m [38;5;28;01mimport[39;00m [38;5;21;01mhorovod[39;00m[38;5;21;01m.[39;00m[38;5;21;01mtensorflow[39;00m[38;5;21;01m.[39;00m[38;5;21;01mkeras[39;00m [38;5;28;01mas[39;00m [38;5;21;01mhvd[39;00m
[1;32m      2[0m [38;5;28;01mimport[39;00m [38;5;21;01mtensorflow[39;00m [38;5;28;01mas[39;00m [38;5;21;01mtf[39;00m
[1;32m      4[0m [38;5;66;03m# Initialize Horovod[39;00m

File [0;32m/databricks/python_shell/dbruntime/PostImportHook.py:218[0m, in [0;36m_ImportHookChainedLoader.load_module[0;34m(self, fullname)[0m
[1;32m    216[0m [38;5;28;01mdef[39;00m [38;5;21mload_module[39m([38;5;28mself[39m, fullname):
[1;32m    217[0m     [38;5;28;01mtry[39;00m:
[0;32m--> 218[0m         module [38;5;241m=[39m [38;5;2

In [0]:
scaled_lr = 0.001 * hvd.size()
optimizer = tf.keras.optimizers.Adam(scaled_lr)

optimizer = hvd.DistributedOptimizer(optimizer)

from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler

checkpoint_dir = '/dbfs/tmp/checkpoints' if hvd.rank() == 0 else None
callbacks = [
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),
    hvd.callbacks.MetricAverageCallback(),
    hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, verbose=1),
    ModelCheckpoint(checkpoint_dir, save_weights_only=True)
]

from tensorflow.keras.applications import EfficientNetB1
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

base_model = EfficientNetB1(weights='imagenet', include_top=False, input_shape=(240, 240, 3))
base_model.trainable = False

model = Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    Dense(1, activation='sigmoid')
])

model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

batch_size = 32
model.fit(dataset, callbacks=callbacks, epochs=1)