In [13]:
import io
import itertools
import os
import warnings
from getpass import getpass

from PIL import Image
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation


In [14]:
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
# Set up our connection to the API.
# The API key is read from the STABILITY_KEY environment variable (or prompted for)
# instead of being hardcoded: anything in the notebook source is visible to everyone
# who can read or clone it.
# NOTE(review): the key that was previously hardcoded in this cell should be treated
# as leaked and revoked in the Stability dashboard.
stability_api = client.StabilityInference(
    key=os.environ.get('STABILITY_KEY') or getpass('Stability API key: '),  # API Key reference.
    verbose=True,  # Print debug messages.
    engine="stable-diffusion-xl-1024-v1-0",  # Set the engine to use for generation.
    # Check out the following link for a list of available engines: https://platform.stability.ai/docs/features/api-parameters#engine
)

In [15]:

def load_images_from_folder(folder, size=(700, 700)):
    """Load every image file directly inside ``folder``, resized to ``size``.

    Args:
        folder: Directory to scan. Only regular files are considered;
            subdirectories are ignored (no recursion).
        size: ``(width, height)`` passed to ``Image.resize``. Defaults to the
            700x700 used by the rest of this notebook.

    Returns:
        A list of ``(resized_image, filename)`` tuples in sorted filename order,
        so results do not depend on the arbitrary order of ``os.listdir``.

    Files Pillow cannot open or decode (e.g. ``.DS_Store``, text files) are
    skipped with a warning instead of aborting the whole run.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        img_path = os.path.join(folder, filename)
        if not os.path.isfile(img_path):
            continue
        try:
            # The context manager closes the source file handle. resize() returns
            # a new, fully loaded copy, so it stays usable after the file closes.
            with Image.open(img_path) as img:
                resized = img.resize(size)
        except OSError as exc:
            # Pillow's UnidentifiedImageError (non-image file) and truncated-file
            # errors are both OSError subclasses.
            warnings.warn(f"Skipping {img_path}: {exc}")
            continue
        images.append((resized, filename))
    return images

def ensure_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of a separate existence check, which avoids the
    race where the directory appears between the check and the creation. If
    ``path`` exists but is a regular file, ``os.makedirs`` raises
    ``FileExistsError`` rather than silently returning.
    """
    os.makedirs(path, exist_ok=True)

def get_cell_description(cell_type):
    """Translate a dataset folder name such as ``cell_type_A549`` into prompt text.

    Args:
        cell_type: Folder name of the cell line, in the form ``cell_type_<LINE>``.

    Returns:
        A plain-English description of the cell line, or ``"Unknown cell type"``
        when ``cell_type`` is not one of the six known folders.
    """
    known_lines = dict(
        [
            ("cell_type_A549", "human lung adenocarcinoma alveolar basal epithelial cells"),
            ("cell_type_CACO-2", "colorectal adenocarcinoma epithelial cells"),
            ("cell_type_HPMEC", "human pulmonary microvascular endothelial cells"),
            ("cell_type_HSAEC", "human small airway epithelial cells"),
            ("cell_type_HUVEC", "human umbilical vein endothelial cells"),
            ("cell_type_NHBE", "normal human bronchial epithelial cells"),
        ]
    )
    if cell_type in known_lines:
        return known_lines[cell_type]
    return "Unknown cell type"

In [20]:
def _generate_variant(api, init_image, prompt, out_file):
    """Request one image-to-image variant of ``init_image`` and save it to ``out_file``.

    Artifacts flagged by the API's safety filter trigger a warning and are NOT
    saved, so filter output never ends up in the synthetic training data.
    """
    answers = api.generate(
        prompt=prompt,
        init_image=init_image,
        start_schedule=0.2,
        steps=50,
        cfg_scale=10,
        # NOTE(review): SDXL engines only accept specific dimensions (e.g. 1024x1024);
        # confirm 700x700 is valid for the engine the client is configured with.
        width=700,
        height=700,
    )
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                warnings.warn("Your request activated the API's safety filters.")
                continue
            if artifact.type == generation.ARTIFACT_IMAGE:
                Image.open(io.BytesIO(artifact.binary)).save(out_file)


def generate_synthetic_images(base_folder, output_base, samples_per_image=3, api=None, overwrite=False):
    """Copy resized real images and generate synthetic variants of each one.

    Walks ``base_folder/<category>/<cell_type>/<period>`` for every combination
    of the categories, cell types and time periods listed below. For each image
    found:

    * the 700x700 resized copy is saved under the mirrored path in ``output_base``;
    * ``samples_per_image`` image-to-image variants are requested and saved as
      ``synthetic_<stem>_<k>.png`` (k starting at 1).

    Args:
        base_folder: Root of the real dataset (e.g. ``"OOC_image_dataset/train"``).
        output_base: Root that resized and synthetic images are written to.
        samples_per_image: Number of synthetic variants per real image
            (default 3, the value previously hard-coded).
        api: Client exposing ``generate(...)``. Defaults to the module-level
            ``stability_api``.
        overwrite: When False (default), variants whose output file already
            exists are skipped. A run interrupted partway (e.g. by the API's
            RESOURCE_EXHAUSTED balance error) then resumes instead of paying to
            regenerate everything. Pass True to force regeneration.
    """
    if api is None:
        api = stability_api

    categories = ['bad', 'good']
    cell_types = ['cell_type_A549', 'cell_type_CACO-2', 'cell_type_HPMEC', 'cell_type_HSAEC', 'cell_type_HUVEC', 'cell_type_NHBE']
    time_periods = ['0-1_days', '2-3_days', '4+_days', '4_days']

    # product() yields combinations in the same order as the nested loops it replaces.
    for category, cell_type, period in itertools.product(categories, cell_types, time_periods):
        real_path = os.path.join(base_folder, category, cell_type, period)
        if not os.path.exists(real_path):
            # Not every combination exists; don't create empty output folders for it.
            continue
        synthetic_path = os.path.join(output_base, category, cell_type, period)
        ensure_dir(synthetic_path)

        # The prompt depends only on the folder, so build it once per folder.
        prompt = f"{get_cell_description(cell_type)}, cell age {period}"
        for img_resized, filename in load_images_from_folder(real_path):
            img_resized.save(os.path.join(synthetic_path, filename))  # Save the resized image
            stem = os.path.splitext(filename)[0]
            for i in range(samples_per_image):
                out_file = os.path.join(synthetic_path, f"synthetic_{stem}_{i+1}.png")
                if not overwrite and os.path.exists(out_file):
                    continue  # already generated (and paid for) on a previous run
                _generate_variant(api, img_resized, prompt, out_file)


In [21]:
# Mirror the real training split into a synthetic split with the same
# category / cell-type / time-period folder layout.
base_path, output_path = "OOC_image_dataset/train", "synthetic_data/train"
generate_synthetic_images(base_folder=base_path, output_base=output_path)

_MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.RESOURCE_EXHAUSTED
	details = "Your organization does not have enough balance to request this action (need $0.0025134375, have $0.00116944 in active grants, $0 in balance)."
	debug_error_string = "UNKNOWN:Error received from peer ipv4:172.64.153.32:443 {grpc_message:"Your organization does not have enough balance to request this action (need $0.0025134375, have $0.00116944 in active grants, $0 in balance).", grpc_status:8, created_time:"2024-04-29T16:45:55.481018996+00:00"}"
>