In [35]:
def get_img_from_s3(bucket, key):
    """Download one S3 object, decode it as a colour image and resize it.

    Args:
        bucket: Name of the S3 bucket that holds the object.
        key: Key of the object to download.

    Returns:
        The result of passing the decoded image through ``resize_img``.

    Raises:
        ValueError: If the downloaded bytes cannot be decoded as an image.
    """
    # Keep the ``bucket`` argument intact (it is reused in the error message)
    # instead of rebinding it to the Bucket resource.
    s3_bucket = s3_resource.Bucket(bucket)
    file_stream = BytesIO()
    s3_bucket.Object(key).download_fileobj(file_stream)

    np_1d_array = np.frombuffer(file_stream.getbuffer(), dtype="uint8")
    img = cv2.imdecode(np_1d_array, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imdecode reports undecodable data by returning None; fail with
        # a message naming the object rather than an AttributeError on .copy().
        raise ValueError(f"Could not decode s3://{bucket}/{key} as an image")

    return resize_img(img.copy())

In [36]:
def get_category_list_from_s3(bucket, key, delimiter, n_cats):
    """List up to ``n_cats`` category prefixes found directly under ``key``.

    Args:
        bucket: Name of the S3 bucket to list.
        key: Parent prefix; a trailing "/" is appended before listing.
        delimiter: Delimiter passed to S3 to group keys into common prefixes.
        n_cats: Maximum number of prefixes to return.

    Returns:
        A list of prefix strings; empty when the listing reports no common
        prefixes.
    """
    list_resp = s3_client.list_objects_v2(
        Bucket=bucket,
        Prefix=key + "/",
        Delimiter=delimiter,
    )
    print("List_resp", list_resp)

    # The response carries no "CommonPrefixes" entry when nothing matches, so
    # fall back to an empty list instead of raising KeyError.
    common_prefixes = list_resp.get('CommonPrefixes', [])
    category_list = [entry['Prefix'] for entry in common_prefixes[:n_cats]]
    print(category_list)

    return category_list

In [37]:
def get_images_and_labels(bucket, category_list, n_images):
    """Load up to ``n_images`` images per category together with their labels.

    The label of an image is the position of its category in
    ``category_list``.

    Args:
        bucket: Name of the S3 bucket holding the images.
        category_list: S3 key prefixes, one per category.
        n_images: Maximum number of objects to load from each category.

    Returns:
        A tuple ``(images, labels)`` of NumPy arrays built from the loaded
        images and their integer labels.

    NOTE(review): each category is listed with a single ``list_objects_v2``
    call and no pagination, so at most one page of keys per category is
    considered — confirm this is sufficient for the intended ``n_images``.
    """
    img_list = []
    labels_list = []

    for index, category in enumerate(category_list):
        list_resp = s3_client.list_objects_v2(Bucket=bucket, Prefix=category)

        # The response carries no "Contents" entry when the prefix matches no
        # objects; treat that as an empty category instead of raising KeyError.
        contents = list_resp.get('Contents', [])

        for obj in contents[:n_images]:
            img_list.append(get_img_from_s3(bucket, obj['Key']))
            labels_list.append(index)

    images = np.array(img_list)
    labels = np.array(labels_list)

    return images, labels

In [38]:
def load_training_data(params):
    """Load training images and labels as described by ``params``.

    Args:
        params: Mapping that provides the keys ``'repo_name'`` (S3 bucket),
            ``'image_path'`` (parent prefix of the categories),
            ``'delimiter'``, ``'n_cats'`` and ``'n_images'``.

    Returns:
        The ``(images, labels)`` tuple produced by ``get_images_and_labels``.
    """
    # Read the bucket once from ``params``; the second call previously used an
    # undefined bare name ``repo_name`` instead of the value in ``params``.
    bucket = params['repo_name']

    category_list = get_category_list_from_s3(
        bucket=bucket,
        key=params['image_path'],
        delimiter=params['delimiter'],
        n_cats=params['n_cats'],
    )

    images, labels = get_images_and_labels(
        bucket=bucket,
        category_list=category_list,
        n_images=params['n_images'],
    )

    return images, labels

In [39]:
def model_load(model_name, bucket_name, key):
    """Fetch a joblib-serialised model from S3 and deserialise it.

    Args:
        model_name: File name of the model; appended to ``key``.
        bucket_name: Name of the S3 bucket holding the model.
        key: Prefix under which the model object is stored.

    Returns:
        The object returned by ``joblib.load``.

    NOTE(review): ``joblib.load`` unpickles the downloaded bytes, so this
    should only be pointed at trusted objects.
    """
    object_key = f"{key}/{model_name}"

    # Stage the download in an anonymous temp file, rewind, then deserialise.
    with tempfile.TemporaryFile() as buffer:
        s3_client.download_fileobj(Bucket=bucket_name, Key=object_key, Fileobj=buffer)
        buffer.seek(0)
        loaded = joblib.load(buffer)

    # The object is left in S3 after loading (deletion is disabled).

    print(type(loaded))

    return loaded

In [40]:
def model_save(model, model_name, bucket_name, key):
    """Serialise a model with joblib, locally and to S3.

    The model is written to a local file named ``model_name`` and also
    uploaded to ``bucket_name`` under the key ``"{key}/{model_name}"``.

    Args:
        model: Object to serialise.
        model_name: Local file name and final component of the S3 key.
        bucket_name: Name of the destination S3 bucket.
        key: Prefix under which the model object is stored.

    Returns:
        None.
    """
    # Local copy on disk, named after the model.
    joblib.dump(model, model_name)

    object_key = f"{key}/{model_name}"
    print(model_name, bucket_name, object_key)

    # Serialise again into an anonymous temp file and upload its bytes.
    with tempfile.TemporaryFile() as buffer:
        joblib.dump(model, buffer)
        buffer.seek(0)
        payload = buffer.read()
        s3_client.put_object(Bucket=bucket_name, Key=object_key, Body=payload)

    return None

In [41]:
def save_metrics(metrics, bucket_name, key):
    """Write the loss and accuracy metrics to S3 as JSON via Spark.

    Both values are converted to strings and stored in a one-row DataFrame
    with string columns ``loss`` and ``accuracy``. The schema and the row are
    printed before writing.

    Args:
        metrics: Mapping providing the keys ``'loss'`` and ``'accuracy'``.
        bucket_name: Name of the destination S3 bucket.
        key: Key (path) inside the bucket to write to.

    Returns:
        None.
    """
    column_names = ("loss", "accuracy")

    # One row holding the stringified metric values, in column order.
    row = tuple(str(metrics[name]) for name in column_names)

    schema = StructType(
        [StructField(name, StringType(), True) for name in column_names]
    )

    df = spark.createDataFrame(data=[row], schema=schema)
    df.printSchema()
    df.show(truncate=False)

    df.write.json(f"s3a://{bucket_name}/{key}")
    return None

In [42]:
def load_metrics(bucket_name, key):
    """Read the loss and accuracy from a JSON dataset in S3 via Spark.

    Args:
        bucket_name: Name of the S3 bucket holding the metrics.
        key: Key (path) inside the bucket to read from.

    Returns:
        A tuple ``(loss, accuracy)`` taken from the first row read.
    """
    source = f"s3a://{bucket_name}/{key}"

    # Only the first collected row is used.
    first_row = spark.read.json(source).collect()[0]

    return first_row['loss'], first_row['accuracy']