## Distributed Deep Learning Pipeline on GCP

Install libraries

In [2]:
!pip install torch
!pip install torchvision
!pip install cmake
!pip install horovod[pytorch,spark]

Collecting torch
  Downloading torch-1.10.0-cp38-cp38-manylinux1_x86_64.whl (881.9 MB)
     |███████████████████████████     | 742.7 MB 108.4 MB/s eta 0:00:02███████████████████████      | 714.2 MB 108.4 MB/s eta 0:00:02

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



     |████████████████████████████████| 881.9 MB 1.8 kB/s              
Installing collected packages: torch
Successfully installed torch-1.10.0
Collecting torchvision
  Downloading torchvision-0.11.1-cp38-cp38-manylinux1_x86_64.whl (23.3 MB)
     |████████████████████████████████| 23.3 MB 4.9 MB/s            
Installing collected packages: torchvision
Successfully installed torchvision-0.11.1
Collecting cmake
  Downloading cmake-3.22.1-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.7 MB)
     |████████████████████████████████| 22.7 MB 4.8 MB/s            
[?25hInstalling collected packages: cmake
Successfully installed cmake-3.22.1
Collecting horovod[pytorch,spark]
  Downloading horovod-0.23.0.tar.gz (3.4 MB)
     |████████████████████████████████| 3.4 MB 5.0 MB/s            
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting pytorch_lightning
  Downloading pytorch_lightning-1.5.5-py3-none-any.whl (525 kB)
     |████████████████████████████████| 525 kB 9

In [3]:
import os
import subprocess
import sys
import numpy as np

import pyspark
import pyspark.sql.types as T
from pyspark import SparkConf
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import horovod.spark.torch as hvd
from horovod.spark.common.backend import SparkBackend
from horovod.spark.common.store import Store

In [4]:
conf = SparkConf().setAppName('pytorch_spark_CheXpert').set('spark.sql.shuffle.partitions', '2')
spark = SparkSession.builder.config(conf=conf).getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [5]:
%%time
patientImagedf = spark.read.format("image").option("dropInvalid", True).load("gs://chexpertcse6250fall2021/CheXpert-v1.0-small/train/*/*")
patientImagedf.printSchema()

                                                                                

root
 |-- image: struct (nullable = true)
 |    |-- origin: string (nullable = true)
 |    |-- height: integer (nullable = true)
 |    |-- width: integer (nullable = true)
 |    |-- nChannels: integer (nullable = true)
 |    |-- mode: integer (nullable = true)
 |    |-- data: binary (nullable = true)

CPU times: user 238 ms, sys: 73.1 ms, total: 311 ms
Wall time: 1min 19s


In [6]:
patientImagedf.count()

                                                                                

10001

In [7]:
import matplotlib.pyplot as plt
import seaborn as sb
import cv2

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models

In [8]:
sparkLabelsdf = spark.read.option('header', True).\
csv('gs://chexpertcse6250fall2021/CheXpert-v1.0-small/train_mod.csv')

In [9]:
sparkLabelsdf.printSchema()

root
 |-- Path: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- Frontal/Lateral: string (nullable = true)
 |-- AP/PA: string (nullable = true)
 |-- No Finding: string (nullable = true)
 |-- Enlarged Cardiomediastinum: string (nullable = true)
 |-- Cardiomegaly: string (nullable = true)
 |-- Lung Opacity: string (nullable = true)
 |-- Lung Lesion: string (nullable = true)
 |-- Edema: string (nullable = true)
 |-- Consolidation: string (nullable = true)
 |-- Pneumonia: string (nullable = true)
 |-- Atelectasis: string (nullable = true)
 |-- Pneumothorax: string (nullable = true)
 |-- Pleural Effusion: string (nullable = true)
 |-- Pleural Other: string (nullable = true)
 |-- Fracture: string (nullable = true)
 |-- Support Devices: string (nullable = true)



In [10]:
# Change Column Names (to remove spaces)
sparkLabelsdf = sparkLabelsdf.withColumnRenamed("No Finding","NoFinding") \
.withColumnRenamed('Enlarged Cardiomediastinum','EnlargedCardiomediastinum')\
.withColumnRenamed('Lung Opacity','LungOpacity')\
.withColumnRenamed('Lung Lesion','LungLesion')\
.withColumnRenamed('Pleural Effusion','PleuralEffusion')\
.withColumnRenamed('Pleural Other','PleuralOther')\
.withColumnRenamed('Support Devices','SupportDevices')

sparkLabelsdf.printSchema()


root
 |-- Path: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- Frontal/Lateral: string (nullable = true)
 |-- AP/PA: string (nullable = true)
 |-- NoFinding: string (nullable = true)
 |-- EnlargedCardiomediastinum: string (nullable = true)
 |-- Cardiomegaly: string (nullable = true)
 |-- LungOpacity: string (nullable = true)
 |-- LungLesion: string (nullable = true)
 |-- Edema: string (nullable = true)
 |-- Consolidation: string (nullable = true)
 |-- Pneumonia: string (nullable = true)
 |-- Atelectasis: string (nullable = true)
 |-- Pneumothorax: string (nullable = true)
 |-- PleuralEffusion: string (nullable = true)
 |-- PleuralOther: string (nullable = true)
 |-- Fracture: string (nullable = true)
 |-- SupportDevices: string (nullable = true)



Label standardization:
 - Findings with a '-1' label which implies uncertainty is replaced with 0

In [13]:
from pyspark.sql.functions import *


standardLabels = sparkLabelsdf\
.withColumn('NoFinding_mod', when(sparkLabelsdf.NoFinding == 1, 1.0)
            .when(sparkLabelsdf.NoFinding == '-1', 0.0)
            .otherwise(0.0))\
.withColumn('EnlargedCardiomediastinum_mod', when(sparkLabelsdf.EnlargedCardiomediastinum == '1', 1.0)
            .when(sparkLabelsdf.EnlargedCardiomediastinum == -1, 0.0)
            .otherwise(0.0))\
.withColumn('Cardiomegaly_mod', when(sparkLabelsdf.Cardiomegaly == 1, 1.0)
            .when(sparkLabelsdf.Cardiomegaly == '-1', 0.0)
            .otherwise(0.0))\
.withColumn('LungOpacity_mod', when(sparkLabelsdf.LungOpacity == 1, 1.0)
            .when(sparkLabelsdf.LungOpacity == '-1', 0.0)
            .otherwise(0.0))\
.withColumn('LungLesion_mod', when(sparkLabelsdf.LungLesion == 1, 1.0)
            .when(sparkLabelsdf.LungLesion == '-1', 0.0)
            .otherwise(0.0))\
.withColumn('Edema_mod', when(sparkLabelsdf.Edema == 1, 1.0)
            .when(sparkLabelsdf.Edema == '-1', 0.0)
            .otherwise(0.0))\
.withColumn('Consolidation_mod', when(sparkLabelsdf.Consolidation == 1, 1.0)
            .when(sparkLabelsdf.Consolidation == -1, 0.0)
            .otherwise(0.0))\
.withColumn('Pneumonia_mod', when(sparkLabelsdf.Pneumonia == 1, 1.0)
            .when(sparkLabelsdf.Pneumonia == -1, 0.0)
            .otherwise(0.0))\
.withColumn('Atelectasis_mod', when(sparkLabelsdf.Atelectasis == '1', 1.0)
            .when(sparkLabelsdf.Atelectasis == -1, 0.0)
            .otherwise(0.0))\
.withColumn('Pneumothorax_mod', when(sparkLabelsdf.Pneumothorax == '1', 1.0)
            .when(sparkLabelsdf.Pneumothorax ==-1, 0.0)
            .otherwise(0.0))\
.withColumn('PleuralEffusion_mod', when(sparkLabelsdf.PleuralEffusion == '1', 1.0)
            .when(sparkLabelsdf.PleuralEffusion == -1, 0.0)
            .otherwise(0.0))\
.withColumn('PleuralOther_mod', when(sparkLabelsdf.PleuralOther == '1', 1.0)
            .when(sparkLabelsdf.PleuralOther == -1, 0.0)
            .otherwise(0.0))\
.withColumn('Fracture_mod', when(sparkLabelsdf.Fracture == 1, 1.0)
            .when(sparkLabelsdf.Fracture == -1, 0.0)
            .otherwise(0.0))\
.withColumn('SupportDevices_mod', when(sparkLabelsdf.SupportDevices == 1, 1.0)
            .when(sparkLabelsdf.SupportDevices == -1, 0.0)
            .otherwise(0.0)) \
.drop('NoFinding', 'LungLesion', 'EnlargedCardiomediastinum','Edema', 'Consolidation', 'Pneumonia',
      'Atelectasis', 'Pneumothorax', 'PleuralEffusion', 'PleuralOther',
     'Fracture', 'SupportDevices', 'Cardiomegaly', 'LungOpacity')

In [14]:
standardLabels.printSchema()

root
 |-- Path: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- Frontal/Lateral: string (nullable = true)
 |-- AP/PA: string (nullable = true)
 |-- NoFinding_mod: double (nullable = false)
 |-- EnlargedCardiomediastinum_mod: double (nullable = false)
 |-- Cardiomegaly_mod: double (nullable = false)
 |-- LungOpacity_mod: double (nullable = false)
 |-- LungLesion_mod: double (nullable = false)
 |-- Edema_mod: double (nullable = false)
 |-- Consolidation_mod: double (nullable = false)
 |-- Pneumonia_mod: double (nullable = false)
 |-- Atelectasis_mod: double (nullable = false)
 |-- Pneumothorax_mod: double (nullable = false)
 |-- PleuralEffusion_mod: double (nullable = false)
 |-- PleuralOther_mod: double (nullable = false)
 |-- Fracture_mod: double (nullable = false)
 |-- SupportDevices_mod: double (nullable = false)



In [15]:
from pyspark.sql.functions import regexp_replace, col

patientImagedfMod = patientImagedf.withColumn('pathgcp', regexp_replace('image.origin', 'gs://chexpertcse6250fall2021/',''))

In [17]:
patientImagedfMod = patientImagedfMod.select('image.origin', 'image.height', 'image.width', 'image.nChannels', 'image.mode', 'image.data', 'pathgcp')

In [18]:
patientImagedfMod.printSchema()
patientImagedfMod.show()

root
 |-- origin: string (nullable = true)
 |-- height: integer (nullable = true)
 |-- width: integer (nullable = true)
 |-- nChannels: integer (nullable = true)
 |-- mode: integer (nullable = true)
 |-- data: binary (nullable = true)
 |-- pathgcp: string (nullable = true)



                                                                                

+--------------------+------+-----+---------+----+--------------------+--------------------+
|              origin|height|width|nChannels|mode|                data|             pathgcp|
+--------------------+------+-----+---------+----+--------------------+--------------------+
|gs://chexpertcse6...|   320|  652|        1|   0|[FD FE FF FF FD F...|CheXpert-v1.0-sma...|
|gs://chexpertcse6...|   320|  425|        1|   0|[08 0B 0D 0D 0D 0...|CheXpert-v1.0-sma...|
|gs://chexpertcse6...|   320|  422|        1|   0|[02 0A 05 03 06 0...|CheXpert-v1.0-sma...|
|gs://chexpertcse6...|   320|  413|        1|   0|[25 38 39 31 34 3...|CheXpert-v1.0-sma...|
|gs://chexpertcse6...|   452|  320|        1|   0|[0E 16 16 0C 13 1...|CheXpert-v1.0-sma...|
|gs://chexpertcse6...|   320|  390|        1|   0|[06 05 03 02 02 0...|CheXpert-v1.0-sma...|
|gs://chexpertcse6...|   320|  421|        1|   0|[02 02 03 03 03 0...|CheXpert-v1.0-sma...|
|gs://chexpertcse6...|   320|  423|        1|   0|[11 0F 10 18 1E 2...

In [19]:
joineddf = patientImagedfMod.join(standardLabels, patientImagedfMod.pathgcp == standardLabels.Path, 'inner')
joineddf.printSchema()

root
 |-- origin: string (nullable = true)
 |-- height: integer (nullable = true)
 |-- width: integer (nullable = true)
 |-- nChannels: integer (nullable = true)
 |-- mode: integer (nullable = true)
 |-- data: binary (nullable = true)
 |-- pathgcp: string (nullable = true)
 |-- Path: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- Frontal/Lateral: string (nullable = true)
 |-- AP/PA: string (nullable = true)
 |-- NoFinding_mod: double (nullable = false)
 |-- EnlargedCardiomediastinum_mod: double (nullable = false)
 |-- Cardiomegaly_mod: double (nullable = false)
 |-- LungOpacity_mod: double (nullable = false)
 |-- LungLesion_mod: double (nullable = false)
 |-- Edema_mod: double (nullable = false)
 |-- Consolidation_mod: double (nullable = false)
 |-- Pneumonia_mod: double (nullable = false)
 |-- Atelectasis_mod: double (nullable = false)
 |-- Pneumothorax_mod: double (nullable = false)
 |-- PleuralEffusion_mod: double (nullable = fals

In [20]:
joineddf.count()

                                                                                

7826

Keep only the frontal images

In [21]:
%%time
joineddf = joineddf.filter(joineddf.pathgcp.contains('frontal'))



CPU times: user 227 ms, sys: 75.9 ms, total: 303 ms
Wall time: 2min 3s


                                                                                

7826

In [22]:

trainingdf = joineddf.select('data', 'height', 'width', 'NoFinding_mod', 'EnlargedCardiomediastinum_mod',
                             'Cardiomegaly_mod', 'LungOpacity_mod', 'LungLesion_mod', 'Edema_mod',
                             'Consolidation_mod', 'Pneumonia_mod', 'Atelectasis_mod', 'Pneumothorax_mod', 
                             'PleuralEffusion_mod', 'PleuralOther_mod','Fracture_mod', 'SupportDevices_mod')


trainingdf.printSchema()

root
 |-- data: binary (nullable = true)
 |-- height: integer (nullable = true)
 |-- width: integer (nullable = true)
 |-- NoFinding_mod: double (nullable = false)
 |-- EnlargedCardiomediastinum_mod: double (nullable = false)
 |-- Cardiomegaly_mod: double (nullable = false)
 |-- LungOpacity_mod: double (nullable = false)
 |-- LungLesion_mod: double (nullable = false)
 |-- Edema_mod: double (nullable = false)
 |-- Consolidation_mod: double (nullable = false)
 |-- Pneumonia_mod: double (nullable = false)
 |-- Atelectasis_mod: double (nullable = false)
 |-- Pneumothorax_mod: double (nullable = false)
 |-- PleuralEffusion_mod: double (nullable = false)
 |-- PleuralOther_mod: double (nullable = false)
 |-- Fracture_mod: double (nullable = false)
 |-- SupportDevices_mod: double (nullable = false)



Making a vector column out of the 14 labels

In [23]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(
    inputCols=["NoFinding_mod", "EnlargedCardiomediastinum_mod", "Cardiomegaly_mod","LungOpacity_mod", 
              "LungLesion_mod", "Edema_mod", "Consolidation_mod","Pneumonia_mod","Atelectasis_mod","PleuralEffusion_mod",
              "PleuralOther_mod","Fracture_mod","SupportDevices_mod"],
    outputCol="labels")

output = assembler.transform(joineddf)


In [24]:
output.select(col('Path'), col('labels')).show()

21/12/13 00:57:35 WARN org.apache.spark.sql.catalyst.util.package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
[Stage 16:>                                                         (0 + 1) / 1]

+--------------------+--------------------+
|                Path|              labels|
+--------------------+--------------------+
|CheXpert-v1.0-sma...|          (13,[],[])|
|CheXpert-v1.0-sma...|(13,[3,5,12],[1.0...|
|CheXpert-v1.0-sma...|      (13,[3],[1.0])|
|CheXpert-v1.0-sma...|      (13,[6],[1.0])|
|CheXpert-v1.0-sma...|     (13,[12],[1.0])|
|CheXpert-v1.0-sma...|      (13,[5],[1.0])|
|CheXpert-v1.0-sma...|      (13,[3],[1.0])|
|CheXpert-v1.0-sma...|(13,[5,6],[1.0,1.0])|
|CheXpert-v1.0-sma...|      (13,[3],[1.0])|
|CheXpert-v1.0-sma...|(13,[0,12],[1.0,1...|
|CheXpert-v1.0-sma...|      (13,[5],[1.0])|
|CheXpert-v1.0-sma...|      (13,[6],[1.0])|
|CheXpert-v1.0-sma...|(13,[5,12],[1.0,1...|
|CheXpert-v1.0-sma...|     (13,[12],[1.0])|
|CheXpert-v1.0-sma...|(13,[0,12],[1.0,1...|
|CheXpert-v1.0-sma...|(13,[5,12],[1.0,1...|
|CheXpert-v1.0-sma...|      (13,[0],[1.0])|
|CheXpert-v1.0-sma...|(13,[2,3,12],[1.0...|
|CheXpert-v1.0-sma...|      (13,[3],[1.0])|
|CheXpert-v1.0-sma...|      (13,

                                                                                

In [25]:
from pyspark.sql.functions import col
from petastorm.spark import SparkDatasetConverter, make_spark_converter
import io
import numpy as np
import torch
import torchvision
from PIL import Image
from functools import partial 
from petastorm import TransformSpec
from torchvision import transforms 
# from hyperopt import fmin, tpe, hp, SparkTrials, STATUS_OK
import horovod.torch as hvd
# from sparkdl import HorovodRunner

In [26]:
df_train, df_val = trainingdf.randomSplit([0.9, 0.1], seed=12345)

# Make sure the number of partitions is at least the number of workers which is required for distributed training.
df_train = df_train.repartition(2)
df_val = df_val.repartition(2)

### Image data transformation

In [27]:
def rawBytesToPIL(img_str, height, width):
    nparr = np.frombuffer(img_str, np.uint8).reshape(height, width, 1)
    img_np = cv2.cvtColor(nparr, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(img_np)
    return img

Saving the standardized dataframe to a Parquet file

In [28]:
%%time
df_train \
    .coalesce(1) \
    .write \
    .mode('overwrite') \
    .option('parquet.block.size', 1024*1024) \
    .parquet('gs://chexpertcse6250fall2021/parquetCache4')

                                                                                

CPU times: user 277 ms, sys: 50.8 ms, total: 328 ms
Wall time: 2min 24s


Create a spark converter instance for the train and validation dataframes

In [None]:
%%time
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, "gs://chexpertcse6250fall2021/parquetCache4")

converter_train = make_spark_converter(df_train)
converter_val = make_spark_converter(df_val)

Converting floating-point columns to float32

CPU times: user 302 ms, sys: 110 ms, total: 411 ms
Wall time: 3min 6s


                                                                                

In [30]:
print(f"train: {len(converter_train)}")

train: 7027


In [37]:
converter_train.make_torch_dataloader()

<petastorm.spark.spark_dataset_converter.TorchDatasetContextManager at 0x7fa6d2340760>

In [39]:
def metric_average(val, name):
    tensor = torch.tensor(val)
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.item()

def train_and_evaluate_hvd(lr=0.001):
    hvd.init()  # Initialize Horovod.
  
  # Horovod: pin GPU to local rank.
    if torch.cuda.is_available():
        torch.cuda.set_device(hvd.local_rank())
        device = torch.cuda.current_device()
    else:
        device = torch.device("cpu")
  
    model = get_model(lr=lr)
    model = model.to(device)

    criterion = torch.nn.CrossEntropyLoss()
  
  # Effective batch size in synchronous distributed training is scaled by the number of workers.
  # An increase in learning rate compensates for the increased batch size.
    optimizer = torch.optim.SGD(model.classifier[1].parameters(), lr=lr * hvd.size(), momentum=0.9)
  
  # Broadcast initial parameters so all workers start with the same parameters.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
  
  # Wrap the optimizer with Horovod's DistributedOptimizer.
    optimizer_hvd = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())

    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_hvd, step_size=7, gamma=0.1)

    with converter_train.make_torch_dataloader(transform_spec=get_transform_spec(is_train=True), 
                                             cur_shard=hvd.rank(), shard_count=hvd.size(),
                                             batch_size=BATCH_SIZE) as train_dataloader, \
       converter_val.make_torch_dataloader(transform_spec=get_transform_spec(is_train=False),
                                           cur_shard=hvd.rank(), shard_count=hvd.size(),
                                           batch_size=BATCH_SIZE) as val_dataloader:
    
        train_dataloader_iter = iter(train_dataloader)
        steps_per_epoch = len(converter_train) // (BATCH_SIZE * hvd.size())
    
        val_dataloader_iter = iter(val_dataloader)
        validation_steps = max(1, len(converter_val) // (BATCH_SIZE * hvd.size()))
    
        for epoch in range(NUM_EPOCHS):
            print('Epoch {}/{}'.format(epoch + 1, NUM_EPOCHS))
            print('-' * 10)

            train_loss, train_acc = train_one_epoch(model, criterion, optimizer_hvd, exp_lr_scheduler, 
                                              train_dataloader_iter, steps_per_epoch, epoch, 
                                              device)
            val_loss, val_acc = evaluate(model, criterion, val_dataloader_iter, validation_steps,
                                   device, metric_agg_fn=metric_average)

    return val_loss

In [31]:
torchvision.models.mobilenet_v2(pretrained=True)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth


  0%|          | 0.00/13.6M [00:00<?, ?B/s]

MobileNetV2(
  (features): Sequential(
    (0): ConvNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): ConvNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05,

In [32]:
def get_model(lr=0.001):
  # Load a MobileNetV2 model from torchvision
  model = torchvision.models.mobilenet_v2(pretrained=True)
  # Freeze parameters in the feature extraction layers
  for param in model.parameters():
    param.requires_grad = False
    
  # Add a new classifier layer for transfer learning
  num_ftrs = model.classifier[1].in_features
  # Parameters of newly constructed modules have requires_grad=True by default
  model.classifier[1] = torch.nn.Linear(num_ftrs, num_classes)
  
  return model

This line fails on GCP

In [40]:
# from sparkdl import HorovodRunner

In [102]:
# !pip install keras
# !pip install tensorflow
# !pip install tensorframes

In [113]:
import horovod.spark.common._namedtuple_fix

import copy
import io
import numbers
import time

from pyspark import keyword_only
from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.util import MLWritable, MLReadable
from pyspark.sql import SparkSession

from horovod.runner.common.util import codec
from horovod.spark.common import util
from horovod.spark.common.estimator import HorovodEstimator, HorovodModel
from horovod.spark.common.params import EstimatorParams
from horovod.spark.common.serialization import \
    HorovodParamsWriter, HorovodParamsReader
from horovod.spark.torch import remote
from horovod.spark.torch.util import deserialize_fn, serialize_fn, \
    save_into_bio

In [None]:
import horovod.spark.keras as hvd
from horovod.spark.common.backend import SparkBackend
from horovod.spark.common.store import Store
from horovod.tensorflow.keras.callbacks import BestModelCheckpoint

In [None]:
!conda install tensorflow

Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.json): - 