<a href="https://colab.research.google.com/github/smandalika/satty/blob/master/Final_Flow_IFrame.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Logo image directory path and logo image file names; edit these for your
# environment.
# Alternative directory: "/content/drive/MyDrive/Colab Notebooks/Gpu_logo/"
UI_CONFIG = dict(
    logo_image_dir="/content/drive/MyDrive/Demokratik-AI/Images/",
    logo_image_names=[
        'all_gpu.jpeg',
        'A100.png',
        'H100.jpeg',
        'A10.png',
        'L4.jpg',
    ],
)

# @title Code for setting config variables
class Config:
    """Key/value configuration store implemented as a singleton.

    Every ``Config()`` call hands back the same object, so a value stored
    through one reference is visible through every other reference.
    """

    _instance = None  # The single shared instance, created on first use

    def __new__(cls):
        existing = cls._instance
        if existing is not None:
            return existing
        # First construction: build the instance and its backing dictionary.
        created = super().__new__(cls)
        created.config_data = {}
        cls._instance = created
        return created

    def set(self, key, value):
        """Store ``value`` under ``key``, replacing any earlier value."""
        self.config_data.update({key: value})

    def get(self, key, default=None):
        """Return the value stored under ``key``, or ``default`` if absent."""
        try:
            return self.config_data[key]
        except KeyError:
            return default

# # Usage Example
# config = Config()  # Create or access the singleton instance
# config.set('database_url', 'mysql://localhost/mydb')
# print(config.get('database_url'))  # Output: mysql://localhost/mydb

# @title Code for enabling a cloud
import subprocess
from pathlib import Path
import shutil
import os
import yaml
import ipywidgets as widgets
from IPython.display import display

class CloudManager:
    def __init__(self, _selected_cloud):
        self._selected_cloud = _selected_cloud
        self.available_clouds = ['AWS', 'GCP', 'Azure', 'IBM', 'Cudo', 'Lambda', 'SCP', 'Kubernetes', 'OCI', 'Paperspace', 'RunPod', 'Vsphere', 'Fluidstack']
        self.credentials_widgets = {}  # To store widget references for later use

    def create_widgets_for_inputs(self):
        """
        Create input widgets based on the selected cloud provider and store them in a dictionary.
        """
        credentials = {}
        if self._selected_cloud == "AWS":
            credentials['aws_access_key_id'] = widgets.Text(description='AWS Access Key Id:', style={'description_width': 'initial'})
            credentials['aws_secret_access_key'] = widgets.Text(description='AWS Secret Key:', style={'description_width': 'initial'})
            credentials['region'] = widgets.Text(description='Region:', style={'description_width': 'initial'})
            credentials['output'] = widgets.Text(description='Output Format:', style={'description_width': 'initial'})
        elif self._selected_cloud == "Lambda":
            credentials['api_key'] = widgets.Text(description='Lambda API Key:', style={'description_width': 'initial'})
        elif self._selected_cloud == "Paperspace":
            credentials['api_key'] = widgets.Text(description='Paperspace API Key:', style={'description_width': 'initial'})
        elif self._selected_cloud == "RunPod":
            credentials['api_key'] = widgets.Text(description='RunPod API Key:', style={'description_width': 'initial'})
        elif self._selected_cloud == "Fluidstack":
            credentials['api_key'] = widgets.Text(description='Fluidstack API Key:', style={'description_width': 'initial'})
            credentials['api_token'] = widgets.Text(description='Fluidstack API Token:', style={'description_width': 'initial'})
        elif self._selected_cloud == "IBM":
            credentials['iam_api_key'] = widgets.Text(description='IBM IAM API Key:', style={'description_width': 'initial'})
            credentials['resource_group_id'] = widgets.Text(description='Resource Group Id:', style={'description_width': 'initial'})
            credentials['access_key_id'] = widgets.Text(description='Access Key Id:', style={'description_width': 'initial'})
            credentials['secret_access_key'] = widgets.Text(description='Secret Access Key:', style={'description_width': 'initial'})
        elif self._selected_cloud == "scp":
            credentials['access_key'] = widgets.Text(description='SCP Access Key:', style={'description_width': 'initial'})
            credentials['secret_key'] = widgets.Text(description='SCP Secret Key:', style={'description_width': 'initial'})
            credentials['project_id'] = widgets.Text(description='SCP Project Id:', style={'description_width': 'initial'})
        elif self._selected_cloud == "OCI":
            credentials['private_key_path'] = widgets.Text(description='OCI Private Key Path:', style={'description_width': 'initial'})
            credentials['user_ocid'] = widgets.Text(description='User OCID:', style={'description_width': 'initial'})
            credentials['fingerprint'] = widgets.Text(description='Fingerprint:', style={'description_width': 'initial'})
            credentials['tenancy_ocid'] = widgets.Text(description='Tenancy OCID:', style={'description_width': 'initial'})
            credentials['region'] = widgets.Text(description='OCI Region:', style={'description_width': 'initial'})
        elif self._selected_cloud == "Cudo":
            credentials['api_key'] = widgets.Text(description='Cudo API Key:', style={'description_width': 'initial'})
            credentials['key_name'] = widgets.Text(description='Cudo Key Name:', style={'description_width': 'initial'})
            credentials['data_center'] = widgets.Text(description='Cudo Data Center:', style={'description_width': 'initial'})
            credentials['project'] = widgets.Text(description='Cudo Project:', style={'description_width': 'initial'})
            credentials['billing_account'] = widgets.Text(description='Billing Account:', style={'description_width': 'initial'})
        elif self._selected_cloud == "Vsphere":
            credentials['vsphere_server_name'] = widgets.Text(description='vSphere Server Name:', style={'description_width': 'initial'})
            credentials['username'] = widgets.Text(description='vSphere Username:', style={'description_width': 'initial'})
            credentials['password'] = widgets.Text(description='vSphere Password:', style={'description_width': 'initial'})
        elif self._selected_cloud == "Kubernetes":
            credentials['kubeconfig_path'] = widgets.Text(description='Kubeconfig Path:', style={'description_width': 'initial'})
        elif self._selected_cloud == "GCP":
            credentials['service_account_path'] = widgets.Text(description='Service Account Path:', style={'description_width': 'initial'})
            credentials['project_id'] = widgets.Text(description='Project Id:', style={'description_width': 'initial'})
        elif self._selected_cloud == "Azure":
            credentials['client_id'] = widgets.Text(description='Client Id:', style={'description_width': 'initial'})
            credentials['client_secret'] = widgets.Text(description='Client Secret:', style={'description_width': 'initial'})
            credentials['tenant_id'] = widgets.Text(description='Tenant Id:', style={'description_width': 'initial'})
            credentials['subscription_id'] = widgets.Text(description='Subscription Id:', style={'description_width': 'initial'})
        else:
            print("Unknown cloud provider.")
            return {}

        self.credentials_widgets = credentials

    def display_widgets(self):
        """
        Render every stored credential widget, followed by a Submit button
        whose click handler is ``self.on_submit``.
        """
        # Credential inputs, in the order they were stored
        for field_name in self.credentials_widgets:
            display(self.credentials_widgets[field_name])

        # Submit button wired to the click handler
        button_options = {
            'description': "Submit",
            'button_style': 'success',
            'style': {'description_width': 'initial'},
        }
        submit = widgets.Button(**button_options)
        submit.on_click(self.on_submit)
        display(submit)

    def on_submit(self, button):
        """
        Handle the Submit button click to collect user inputs and enable the cloud.
        """
        # Collect values from widgets
        credentials = {key: widget.value for key, widget in self.credentials_widgets.items()}
        print("Collected credentials:", credentials)
        self.enable_cloud(credentials)
        print("################## GPU SELECTION COMPLETED ##################")

    def input_credentials(self):
        """
        Create widgets for input and display them.

        Calls ``create_widgets_for_inputs()`` and then ``display_widgets()``;
        the display step runs regardless of what the creation step returned.
        """
        self.create_widgets_for_inputs()
        self.display_widgets()


    def enable_cloud(self, credentials):
          if self._selected_cloud:
              # print(f"Enabling {self._selected_cloud} with credentials {credentials}")

              if self._selected_cloud == "AWS":
                  self._enable_aws(credentials)
              elif self._selected_cloud == "Lambda":
                  self._enable_lambda_cloud(credentials)
              elif self._selected_cloud == "Paperspace":
                  self._enable_paperspace(credentials)
              elif self._selected_cloud == "RunPod":
                  self._enable_runpod(credentials)
              elif self._selected_cloud == "Fluidstack":
                  self._enable_fluidstack(credentials)
              elif self._selected_cloud == "IBM":
                  self._enable_ibm(credentials)
              elif self._selected_cloud == "scp":
                  self._enable_scp(credentials)
              elif self._selected_cloud == "OCI":
                  self._enable_oci(credentials)
              elif self._selected_cloud == "Cudo":
                  self._enable_cudo(credentials)
              elif self._selected_cloud == "Vsphere":
                  self._enable_vsphere(credentials)
              elif self._selected_cloud == "Kubernetes":
                  self._enable_kubernetes(credentials)
              elif self._selected_cloud == "GCP":
                  self._enable_gcp(credentials)

              elif self._selected_cloud == "Azure":
                  self._enable_azure(credentials)
          else:
              print("Cloud selection is missing.")


    def _enable_aws(self, credentials):
        try:
            # Install boto3 library
            subprocess.run(['pip', 'install', 'boto3'], check=True)
            # Run the aws configure commands
            subprocess.run(['aws', 'configure', 'set', 'aws_access_key_id', credentials['aws_access_key_id']], check=True)
            subprocess.run(['aws', 'configure', 'set', 'aws_secret_access_key', credentials['aws_secret_access_key']], check=True)
            if credentials['region']:
                subprocess.run(['aws', 'configure', 'set', 'region', credentials['region']], check=True)
            if credentials['output']:
                subprocess.run(['aws', 'configure', 'set', 'output', credentials['output']], check=True)
            print("AWS configuration is set successfully.")
        except subprocess.CalledProcessError as e:
            print(f"Error in configuring AWS: {e}")

    def _enable_lambda_cloud(self, credentials):
        try:
            # Step 1: Create the directory if it doesn't exist
            lambda_cloud_dir = Path.home() / ".lambda_cloud"
            lambda_cloud_dir.mkdir(parents=True, exist_ok=True)
            # Step 2: Write the API key to the file
            lambda_keys_path = lambda_cloud_dir / "lambda_keys"
            with lambda_keys_path.open('w') as file:
                file.write(f"api_key = {credentials['api_key']}\n")

            print("Lambda Cloud configuration is set successfully.")
        except Exception as e:
            print(f"Error in setting up Lambda Cloud: {e}")

    def _enable_paperspace(self, credentials):
        try:
            # Step 1: Create the directory if it doesn't exist
            paperspace_dir = Path.home() / ".paperspace"
            paperspace_dir.mkdir(parents=True, exist_ok=True)
            # Step 2: Write the API key to the config.json file
            config_path = paperspace_dir / "config.json"
            paperspace_api_key = credentials['api_key']
            config_content = f'{{"apiKey" : "{paperspace_api_key}"}}'

            with config_path.open('w') as file:
                file.write(config_content)
            print("Paperspace Cloud configuration is set successfully.")
        except Exception as e:
            print(f"Error in setting up Paperspace Cloud: {e}")

    def _enable_runpod(self, credentials):
        """Install the RunPod package and register the API key with its CLI.

        Any existing ``~/.runpod`` directory is deleted first so that
        ``runpod config`` starts from a clean state.

        Args:
            credentials: Mapping that provides ``api_key``.

        Failures are reported with a printed message instead of being raised.
        Success is only reported when ``runpod config`` exits with status 0.
        """
        try:
            # Step 1: Install the RunPod package
            print("Installing RunPod package...")
            subprocess.run(['pip', 'install', 'runpod>=1.5.1'], check=True)
            print("RunPod package installed successfully.")
            # Define the hidden directory path
            config_dir_path = os.path.join(os.path.expanduser('~'), '.runpod')
            # Delete the directory and its contents if it exists
            if os.path.isdir(config_dir_path):
                shutil.rmtree(config_dir_path)
                print(f"Deleted the existing config: {config_dir_path}")
            # Step 2: Configure RunPod with the provided API key
            print("Configuring RunPod...")
            runpod_api_key = credentials['api_key']
            result = subprocess.run(['runpod', 'config', runpod_api_key], capture_output=True, text=True)
            if result.returncode != 0:
                # Report the CLI's own output; the command line itself is not
                # printed because it contains the API key.
                details = (result.stderr or result.stdout or "").strip()
                print(f"Error in setting up Runpod Cloud: runpod config exited with status {result.returncode}. {details}")
                return
            print("Runpod Cloud configuration is set successfully.")
        except subprocess.CalledProcessError as e:
            # Only the pip install step uses check=True.
            print(f"Error in setting up Runpod Cloud: pip install exited with status {e.returncode}")
        except Exception as e:
            print(f"Error in setting up Runpod Cloud: {e}")

    def _enable_fluidstack(self, credentials):
        try:
            # Step 1: Ensure the Fluidstack config directory exists
            fluidstack_dir = Path.home() / ".fluidstack"
            fluidstack_dir.mkdir(parents=True, exist_ok=True)
            # Step 2: Write the API key to the api_key file
            api_key_path = fluidstack_dir / "api_key"
            with api_key_path.open('w') as key_file:
                key_file.write(credentials['api_key'])
            print("API key written to ~/.fluidstack/api_key")
            # Step 3: Write the API token to the api_token file
            api_token_path = fluidstack_dir / "api_token"
            with api_token_path.open('w') as token_file:
                token_file.write(credentials['api_token'])
            print("API token written to ~/.fluidstack/api_token")
            print("Fluidstack Cloud configuration is set successfully.")
        except Exception as e:
            print(f"An error occurred while setting up Fluidstack: {e}")

    def _enable_ibm(self,credentials):
        """Write ``~/.ibm/credentials.yaml`` and install rclone.

        The file receives four top-level ``key: value`` lines: the VPC
        entries (``iam_api_key``, ``resource_group_id``) followed by the COS
        entries (``access_key_id``, ``secret_access_key``). The lines are
        built individually so that none of them carries source-code
        indentation, which would make the YAML invalid.

        Args:
            credentials: Mapping that provides the four fields named above.

        Any failure, including a failed rclone install after the file has
        been written, is reported with a printed message instead of raised.
        """
        try:
            # Step 1: Ensure the IBM config directory exists
            config_dir = Path.home() / ".ibm"
            config_dir.mkdir(parents=True, exist_ok=True)
            # Step 2: Build the VPC and COS credential lines
            field_names = (
                'iam_api_key',
                'resource_group_id',
                'access_key_id',
                'secret_access_key',
            )
            yaml_lines = [f"{field}: {credentials[field]}" for field in field_names]
            # Step 3: Write everything to the credentials.yaml file
            config_path = config_dir / "credentials.yaml"
            with config_path.open('w') as credentials_file:
                credentials_file.write("\n".join(yaml_lines) + "\n")
            # Step 4: install rclone (the pipe requires a shell; the command is a constant)
            cmd = "curl https://rclone.org/install.sh | sudo bash"
            subprocess.check_output(cmd, shell=True)
            print("IBM Cloud configuration is set successfully.")
        except Exception as e:
            print(f"An error occurred while setting up IBM credentials: {e}")


    def _enable_scp(self,credentials):
        try:
            # Step 1: Ensure the SCP config directory exists
            config_dir = Path.home() / ".scp"
            config_dir.mkdir(parents=True, exist_ok=True)
            access_key = credentials['access_key']
            secret_key = credentials['secret_key']
            project_id = credentials['project_id']
            # Step 2: Define the SCP credentials content
            credentials_content = f"""
    access_key = {access_key}
    secret_key = {secret_key}
    project_id = {project_id}
    """
            # Step 3: Write the credentials to the scp_credential file
            config_path = config_dir / "scp_credential"
            with config_path.open('w') as credentials_file:
                credentials_file.write(credentials_content.strip())
                credentials_file.write("\n")  # Ensure there's a newline at the end
            print("SCP configuration is set successfully.")
        except Exception as e:
            print(f"An error occurred while setting up SCP credentials: {e}")


    def _enable_oci(self, credentials):
        try:
            # Step 1: Ensure the OCI config directory exists
            config_dir = Path.home() / ".oci"
            config_dir.mkdir(parents=True, exist_ok=True)
            # Step 2: Copy the private key file to the .oci directory
            key_filename = Path(credentials['private_key_path']).name
            key_destination = config_dir / key_filename
            shutil.copy(credentials['private_key_path'], key_destination)

            user_ocid = credentials['user_ocid']
            fingerprint = credentials['fingerprint']
            tenancy_ocid = credentials['tenancy_ocid']
            region = credentials['region']
            # Step 3: Define the OCI configuration content
            config_content = f"""
        [DEFAULT]
        user={user_ocid}
        fingerprint={fingerprint}
        tenancy={tenancy_ocid}
        region={region}
        key_file={key_destination}
        """

            # Step 4: Write the configuration content to the config file
            config_path = config_dir / "config"
            with config_path.open('w') as config_file:
                config_file.write(config_content.strip())
                config_file.write("\n")  # Ensure there's a newline at the end

            print("OCI credentials have been set up successfully.")
        except Exception as e:
            print(f"An error occurred while setting up OCI credentials: {e}")

    def _enable_cudo(self, credentials):
        """Install ``cudoctl``, write ``~/.config/cudo/cudo.yml`` and install cudo-compute.

        The YAML file holds one key entry and one context that share the
        name given in ``key_name``; that name is also set as the current
        context.

        Args:
            credentials: Mapping that provides ``api_key``, ``key_name``,
                ``data_center``, ``billing_account`` and ``project``.

        Any failure is reported with a printed message instead of being raised.
        """
        # Shell commands that download cudoctl and put it on the PATH, in order.
        install_steps = (
            "wget https://download.cudo.org/compute/cudoctl_0.3.2_linux_amd64.tar.gz -O cudoctl.tar.gz",
            "tar -xzf cudoctl.tar.gz",
            "sudo mv cudoctl /usr/local/bin/",
            "chmod +x /usr/local/bin/cudoctl",
        )
        try:
            for shell_command in install_steps:
                subprocess.check_output(shell_command, shell=True)
            print("cudoctl installed")

            # Make sure the Cudo config directory exists
            cudo_dir = Path.home().joinpath(".config", "cudo")
            cudo_dir.mkdir(parents=True, exist_ok=True)

            api_key = credentials['api_key']
            key_name = credentials['key_name']
            data_center = credentials['data_center']
            billing_account = credentials['billing_account']
            project = credentials['project']

            # Assemble the configuration document piece by piece
            key_entry = {'key': api_key, 'name': key_name}
            context_entry = {
                'name': key_name,
                'key': key_name,
                'billing-account': billing_account,
                'data-center': data_center,
                'project': project,
            }
            document = {
                'keys': [key_entry],
                'configVersion': 'v0',
                'contexts': [context_entry],
                'current-context': key_name,
            }

            # Write the configuration to cudo.yml
            with (cudo_dir / "cudo.yml").open('w') as yml_file:
                yaml.dump(document, yml_file, default_flow_style=False)

            # Install the cudo-compute library
            subprocess.run(['pip', 'install', 'cudo-compute>=0.1.10'], check=True)
            print("cudo-compute package installed successfully.")
            print("Cudo Compute Cloud configuration has been set up successfully.")
        except Exception as e:
            print(f"An error occurred while setting up Cudo Compute Cloud: {e}")


    def _enable_vsphere(self, credentials):
        """Write the vCenter login to ``~/.vsphere/credential.yaml``.

        The file lists a single vCenter entry with certificate verification
        skipped and an empty cluster list. A pointer to the SkyPilot vSphere
        preparation guide is printed afterwards.

        Args:
            credentials: Mapping that provides ``vsphere_server_name``,
                ``username`` and ``password``.

        Any failure is reported with a printed message instead of being raised.
        """
        try:
            vcenter_entry = {
                "name": credentials['vsphere_server_name'],
                "username": credentials['username'],
                "password": credentials['password'],
                "skip_verification": True,
                "clusters": [],
            }
            vsphere_dir = Path.home() / ".vsphere"
            vsphere_dir.mkdir(parents=True, exist_ok=True)
            with (vsphere_dir / "credential.yaml").open('w') as yaml_file:
                yaml.dump({"vcenters": [vcenter_entry]}, yaml_file, default_flow_style=False)

            print("vsphere Cloud configuration has been set up successfully.")
            guide_url = "https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/vsphere.html#cloud-prepare-vsphere"
            print(f"After configuring the vSphere credentials, ensure that the necessary preparations for vSphere are completed.\n"
                    f"Please refer to this guide for more information: {guide_url}")

        except Exception as e:
            print(f"An error occurred while setting up VMware vSphere Cloud: {e}")

    def _enable_kubernetes(self,credentials):
        try:
            source_kubeconfig_path = credentials['kubeconfig_path']
            # Define the target directory and file
            kube_dir = Path.home() / ".kube"
            kube_config_path = kube_dir / "config"
            # Create the .kube directory if it does not exist
            kube_dir.mkdir(parents=True, exist_ok=True)
            # Copy the kubeconfig file to the .kube directory
            shutil.copy(source_kubeconfig_path, kube_config_path)
            print("kubernetes configuration has been set up successfully.")
        except Exception as e:
            print(f"An error occurred while setting up Kubernetes: {e}")

    def _enable_gcp(self,credentials):
        try:
          service_account_path = credentials['service_account_path']
          project_id = credentials['project_id']

          os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = service_account_path
          gcp_cmds = [f"gcloud auth activate-service-account --key-file=$GOOGLE_APPLICATION_CREDENTIALS", f"gcloud config set project {project_id}"]

          for cmd in gcp_cmds:
            output = subprocess.check_output(cmd, shell=True)
          print("GCP  enabled")
        except Exception as e:
            print(f"An error occurred while setting up GCP: {e}")

    def _enable_azure(self,credentials):
        """Log in to Azure as a service principal and select the subscription.

        The ``az`` commands are run from argument lists rather than through a
        shell, so credential values containing spaces or shell metacharacters
        are passed through unchanged and cannot inject commands.

        Args:
            credentials: Mapping that provides ``client_id``,
                ``client_secret``, ``tenant_id`` and ``subscription_id``.

        Failures are reported with a printed message instead of being raised.
        For a failed ``az`` command only the step name and exit status are
        printed, because the text of ``CalledProcessError`` contains the full
        command line, which includes the client secret.
        """
        step = "reading credentials"
        try:
            client_id = credentials['client_id']
            client_secret = credentials['client_secret']
            tenant_id = credentials['tenant_id']
            subscription_id = credentials['subscription_id']

            azure_cmds = [
                ("az login", ['az', 'login', '--service-principal', '-u', client_id, '-p', client_secret, '--tenant', tenant_id]),
                ("az account set", ['az', 'account', 'set', '--subscription', subscription_id]),
            ]

            for step, cmd in azure_cmds:
                subprocess.check_output(cmd)
            print("Azure  enabled")
        except subprocess.CalledProcessError as e:
            print(f"An error occurred while setting up Azure: '{step}' exited with status {e.returncode}")
        except Exception as e:
            print(f"An error occurred while setting up Azure: {e}")


    def install_skypilot_cloud_packages(self):
        cloud = self._selected_cloud.lower()
        command = f'pip install -U "skypilot-nightly[{cloud}]"'
        subprocess.run(command, shell=True, check=True)

    def configure_and_enable_cloud(self):
        """
        Main function to configure and enable the selected cloud.

        Installs the SkyPilot packages for the selected cloud first, then
        builds and displays the credential input widgets. The cloud itself is
        enabled later, from the Submit button's click handler.
        """
        self.install_skypilot_cloud_packages()
        self.input_credentials()



# @title Code to extract GPU Rate sheet and select Cloud & GPU accelerator
from typing import Any, Dict, List, Optional, Union
import pandas as pd

import click
import subprocess
# Install skypilot before importing any sky modules.
# The install runs through the shell when this cell/module executes; with
# check=True a failed install raises subprocess.CalledProcessError, so the
# `import sky` lines below are not reached in that case.
command = 'pip install -U skypilot-nightly'
subprocess.run(command, shell=True, check=True)

import sky
from sky import clouds
from sky.clouds import service_catalog
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.usage import usage_lib



class GPURates:
    def __init__(
        self,
        accelerator_str: Optional[str]=None,
        all: bool=False,
        cloud: Optional[str]=None,
        region: Optional[str]=None,
        all_regions: Optional[bool]=None
        ):
        """Store the query filters unchanged; nothing is validated here.

        Args:
            accelerator_str: Accelerator to look up, in the form
                ``<accelerator_name>[:<quantity>]``; None lists accelerators.
            all: Whether to include all accelerators; only valid without an
                accelerator name.
            cloud: Cloud to restrict the query to.
            region: Region to restrict the query to; requires ``cloud``.
            all_regions: Whether to query all regions; requires an
                accelerator and excludes ``region``.
        """
        # What to look for
        self.accelerator_str = accelerator_str
        self.all = all
        # Where to look for it
        self.cloud, self.region = cloud, region
        self.all_regions = all_regions

    def validate_parameters(self):
        """Reject filter combinations that contradict each other.

        Raises:
            click.UsageError: For the first violated rule, checked in the
                order listed in the table below.
        """
        region_given = self.region is not None
        accelerator_given = self.accelerator_str is not None

        # (rule is violated, message to raise), in checking order
        rules = (
            (region_given and self.cloud is None,
             'The region flag is only valid when the cloud flag is set.'),
            (self.all_regions and not accelerator_given,
             'The all-regions flag is only valid when an accelerator is specified.'),
            (self.all_regions and region_given,
             'all-regions and region flags cannot be used simultaneously.'),
            (self.all and accelerator_given,
             'all is only allowed without a GPU name.'),
        )
        for violated, message in rules:
            if violated:
                raise click.UsageError(message)

    def get_cloud_object(self):
        """Resolve ``self.cloud`` to a cloud object and validate the region.

        ``self.cloud_obj`` is always assigned: it starts as None and is only
        replaced when the registry lookup succeeds, so later code can read it
        even after a failure. Errors from the lookup or from the region
        validation are printed, not raised.
        """
        # Assign up front; otherwise a failed lookup leaves the attribute unset.
        self.cloud_obj = None
        try:
            self.cloud_obj = clouds.CLOUD_REGISTRY.from_str(self.cloud)
            service_catalog.validate_region_zone(self.region, None, clouds=self.cloud)
        except Exception as e:
            print("Exception : ", e)
    def _list_to_str(self, lst):
        """Return the items of ``lst`` converted with ``str`` and joined by ``', '``."""
        return ', '.join(map(str, lst))

    def _output(self):
        """Look up accelerator offerings in the SkyPilot catalog as DataFrames.

        When no accelerator was requested, the catalog is first queried for
        accelerator counts; that result is only used to detect a Kubernetes
        cloud without GPUs, and a usage hint is printed unless ``all`` is
        set. When ``accelerator_str`` is set it is parsed as
        ``<accelerator_name>[:<quantity>]``. In both cases the offerings are
        then fetched with ``service_catalog.list_accelerators`` (called with
        ``case_sensitive=False``) and turned into one DataFrame per
        accelerator. The ``Region`` column is only included when ``all`` is
        not set.

        Returns:
            A ``(grouped, merged_df)`` tuple, where ``merged_df`` is the
            concatenation of the per-accelerator DataFrames and ``grouped``
            is ``merged_df.groupby('Cloud')``. None is returned, after
            printing an explanation, when nothing matches.

        Raises:
            click.UsageError: If ``accelerator_str`` has more than one ``:``
                or its quantity is not a positive integer.
        """
        import pandas as pd

        show_all = self.all
        # get_cloud_object() may not have been called, or may have failed
        # before assigning the attribute.
        cloud_obj = getattr(self, 'cloud_obj', None)

        name, quantity = None, None
        if self.accelerator_str is None:
            result = service_catalog.list_accelerator_counts(
                gpus_only=True,
                clouds=self.cloud,
                region_filter=self.region,
            )

            if (len(result) == 0 and cloud_obj is not None and
                    cloud_obj.is_same_cloud(clouds.Kubernetes())):
                # Print and return None, like the other "nothing found"
                # exits below, instead of returning the bare message string.
                print(kubernetes_utils.NO_GPU_ERROR_MESSAGE)
                return None

            if not show_all:
                print('\n\nHint: use all=True to see all accelerators '
                       '(including non-common ones) and pricing.')

        else:
            # Parse accelerator string
            accelerator_split = self.accelerator_str.split(':')
            if len(accelerator_split) > 2:
                raise click.UsageError(
                    f'Invalid accelerator string {self.accelerator_str}. '
                    'Expected format: <accelerator_name>[:<quantity>].')

            if len(accelerator_split) == 2:
                name = accelerator_split[0]
                # Check if quantity is valid
                try:
                    quantity = int(accelerator_split[1])
                    if quantity <= 0:
                        raise ValueError(
                            'Quantity cannot be non-positive integer.')
                except ValueError as invalid_quantity:
                    raise click.UsageError(
                        f'Invalid accelerator quantity {accelerator_split[1]}. '
                        'Expected a positive integer.') from invalid_quantity
            else:
                name, quantity = self.accelerator_str, None

        # Name matching is case-insensitive.
        result = service_catalog.list_accelerators(gpus_only=True,
                                                   name_filter=name,
                                                   quantity_filter=quantity,
                                                   region_filter=self.region,
                                                   clouds=self.cloud,
                                                   case_sensitive=False,
                                                   all_regions=self.all_regions)

        if len(result) == 0:
            if self.cloud == 'kubernetes':
                print(kubernetes_utils.NO_GPU_ERROR_MESSAGE)
                return None

            quantity_str = (f' with requested quantity {quantity}'
                            if quantity else '')
            print(f'Resources \'{name}\'{quantity_str} not found. ')
            print('Try GPURates(all=True) ')
            print('to show available accelerators.')
            return None

        # Define DataFrame columns
        df_columns = [
            'GPU', 'Quantity', 'Cloud', 'Instance Type', 'Device Memory (GB)',
            'vCPUs', 'Host Memory (GB)', 'Hourly Price ($)', 'Hourly Spot Price ($)',
        ]

        # The 'Region' column is only present when not showing all details
        if not show_all:
            df_columns.append('Region')

        all_dfs = []
        # One DataFrame per accelerator
        for items in result.values():
            data_for_df = []

            for item in items:
                instance_type_str = item.instance_type if not pd.isna(
                    item.instance_type) else '(attachable)'

                cpu_count = item.cpu_count
                if pd.isna(cpu_count):
                    cpu_str = '-'
                else:
                    # Always assign cpu_str: numeric values (including
                    # non-builtin numeric types) are shown as an integer
                    # when whole and with one decimal otherwise; anything
                    # else falls back to its string form.
                    try:
                        cpu_float = float(cpu_count)
                    except (TypeError, ValueError):
                        cpu_str = str(cpu_count)
                    else:
                        if cpu_float.is_integer():
                            cpu_str = str(int(cpu_float))
                        else:
                            cpu_str = f'{cpu_float:.1f}'

                device_memory_str = (f'{item.device_memory:.0f}GB' if
                                     not pd.isna(item.device_memory) else '-')
                host_memory_str = f'{item.memory:.0f}GB' if not pd.isna(
                    item.memory) else '-'
                price_str = f'$ {item.price:.3f}' if not pd.isna(
                    item.price) else '-'
                spot_price_str = f'$ {item.spot_price:.3f}' if not pd.isna(
                    item.spot_price) else '-'
                region_str = item.region if not pd.isna(item.region) else '-'

                row = {
                    'GPU': item.accelerator_name,
                    'Quantity': item.accelerator_count,
                    'Cloud': item.cloud,
                    'Instance Type': instance_type_str,
                    'Device Memory (GB)': device_memory_str,
                    'vCPUs': cpu_str,
                    'Host Memory (GB)': host_memory_str,
                    'Hourly Price ($)': price_str,
                    'Hourly Spot Price ($)': spot_price_str,
                }

                # Add region information conditionally
                if not show_all:
                    row['Region'] = region_str

                data_for_df.append(row)

            all_dfs.append(pd.DataFrame(data_for_df, columns=df_columns))

        # Concatenate all DataFrames into one
        merged_df = pd.concat(all_dfs)
        # Group by 'Cloud' column
        grouped = merged_df.groupby('Cloud')

        return grouped, merged_df


    @service_catalog.fallback_to_default_catalog
    @usage_lib.entrypoint
    def select_gpu(self):
        """Validate the filters, resolve the cloud and return the GPU rate tables.

        Returns:
            The ``(grouped, merged_df)`` tuple produced by ``_output()``, or
            None when nothing matched or when any step raised. Exceptions
            are printed, not propagated.
        """
        try:
            self.validate_parameters()
            self.get_cloud_object()
            outcome = self._output()

            # _output() does not always produce a tuple: it can return None
            # after printing its own explanation, or a bare message string.
            # Unpacking those would only add a misleading "cannot unpack"
            # error on top.
            if outcome is None:
                return None
            if isinstance(outcome, str):
                print(outcome)
                return None

            grouped, merged_df = outcome
            return grouped, merged_df
        except Exception as e:
            print("Exception : ", e)
            return None



# @title GPU selection UI
import ipywidgets as widgets
from ipywidgets import interact, HBox, VBox
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from IPython.display import display


# Maps a cloud name, as shown in the rate sheet and the cloud dropdown, to the
# matching `sky` cloud class. Looked up with the dropdown value when the user
# confirms a machine selection.
SKY_CLOUD = {
        'AWS': sky.AWS,
        'GCP': sky.GCP,
        'Azure': sky.Azure,
        'IBM': sky.IBM,
        'Cudo': sky.Cudo,
        'Lambda': sky.Lambda,
        # Upper-case spelling matches the other keys and the name used in
        # CloudManager.available_clouds; without it a lookup with 'SCP'
        # raises KeyError.
        'SCP': sky.SCP,
        # Lower-case alias kept so existing lookups with 'scp' keep working.
        'scp': sky.SCP,
        'Kubernetes': sky.Kubernetes,
        'OCI': sky.OCI,
        'Paperspace': sky.Paperspace,
        'RunPod': sky.RunPod,
        'Vsphere': sky.Vsphere,
        'Fluidstack': sky.Fluidstack,
    }

class GpuRateUI:
    """Notebook widget UI for browsing GPU hourly rates and picking a machine.

    One column is rendered per logo listed in ``UI_CONFIG['logo_image_names']``.
    Each column holds a price gauge driven by a slider, the logo image, and a
    Submit button. Submitting prints the matching rate rows grouped by cloud
    and shows dropdowns for choosing a cloud, GPU and instance type. The
    confirmed choice is stored on the instance (``selected_config``) and in
    the ``Config`` singleton.
    """

    def __init__(self, max_value=45):
        """Initialise UI state.

        Args:
            max_value: Upper bound of the price slider and of the gauge scale.
        """
        self.max_value = max_value
        self.image_path = UI_CONFIG['logo_image_dir']
        self.selected_config = {}  # Store selected configuration
        self.GPU_Family_Name = ""
        self.RAM = ""
        self.Price = 0.0

    @staticmethod
    def _image_format(img_name):
        """Return the ``widgets.Image`` format string for a logo file name.

        ``.jpg`` and ``.jpeg`` both map to ``'jpeg'``; any other extension is
        used as-is (lower-cased); a name without an extension falls back to
        ``'png'``.
        """
        extension = Path(img_name).suffix.lower().lstrip('.')
        if extension in ('jpg', 'jpeg'):
            return 'jpeg'
        return extension or 'png'

    def draw_speedometer(self, price=0):
        """Draw a half-circle gauge with a needle pointing at ``price``.

        Args:
            price: Value to display; clamped to the range ``[0, max_value]``.
        """
        price = max(0, min(price, self.max_value))

        fig, ax = plt.subplots(figsize=(4, 2))

        # The scale runs from 180 degrees (price 0) down to 0 degrees
        # (price == max_value). The first three quarters are drawn green and
        # the remaining quarter red.
        theta1, theta2 = 180, 0
        split_point = theta1 - 3/4 * (theta1 - theta2)

        green_arc = np.linspace(np.deg2rad(theta1), np.deg2rad(split_point), 500)
        ax.plot(np.cos(green_arc), np.sin(green_arc), color='green', lw=10)

        red_arc = np.linspace(np.deg2rad(split_point), np.deg2rad(theta2), 500)
        ax.plot(np.cos(red_arc), np.sin(red_arc), color='red', lw=10)

        # Needle angle is linear in price across the half circle.
        angle = np.deg2rad(theta1 - (price / self.max_value) * (theta1 - theta2))

        ax.arrow(0, 0, 0.8 * np.cos(angle), 0.8 * np.sin(angle),
                 head_width=0.05, head_length=0.1, fc='blue', ec='blue', lw=2)

        ax.text(0, -1.2, f'Price: ${price:.2f}', fontsize=10, ha='center')

        # Label placed just outside the middle of the red segment.
        label_angle = (theta2 + split_point) / 2
        label_radius = 1.2
        ax.text(label_radius * np.cos(np.deg2rad(label_angle)),
                label_radius * np.sin(np.deg2rad(label_angle)),
                'Big 3 clouds',
                fontsize=10, ha='center', va='center', rotation=label_angle-90, rotation_mode='anchor')

        ax.axis('off')
        plt.show()

    def on_submit(self, image_name, radio_buttons, slider):
        """Handle a column's Submit button.

        Records the selection, prints the rate rows below the chosen price
        grouped by cloud, then shows the cloud/GPU/instance dropdowns.

        Args:
            image_name: Logo file name of the column; its stem is used as the
                GPU name filter (``all_gpu`` means no GPU filter).
            radio_buttons: The column's RadioButtons widget; its value is
                stored in ``self.RAM``.
            slider: The column's price slider; its value is the price limit.
        """
        self.GPU_Family_Name = image_name
        self.RAM = radio_buttons.value
        self.Price = slider.value
        selected_gpu = Path(image_name).stem
        print(f"GPU Family: {selected_gpu}, Price: ${self.Price:.2f}")
        print("PROCESSING .....")
        filtered_df = self.process_gpu_rate_sheet(selected_gpu, self.Price)

        grouped = filtered_df.groupby('Cloud')
        for j, (name, group) in enumerate(grouped, start=1):
            print(f"{j}. ---- Cloud: {name} ----")
            # Renumber rows from 1 for display. Work on a fresh frame rather
            # than mutating the groupby slice in place.
            group = group.reset_index(drop=True)
            group.index += 1
            print(group.to_string())
            print("\n")

        self.create_dropdowns(filtered_df)

    def create_dropdowns(self, df):
        """Show dropdowns for picking a cloud, GPU and instance type from ``df``.

        Args:
            df: Rate rows to choose from. The columns 'Cloud', 'GPU',
                'Instance Type', 'Hourly Price ($)' and 'Quantity' are read.
        """
        cloud_dropdown = widgets.Dropdown(
            options=df['Cloud'].unique().tolist(),
            layout=widgets.Layout(width='350px')
        )
        cloud_label = widgets.Label(
            value='Select Cloud :',
            layout=widgets.Layout(width='100px')
        )
        cloud_dropbox = widgets.HBox([cloud_label, cloud_dropdown])

        gpu_dropdown = widgets.Dropdown(
            options=[],
            layout=widgets.Layout(width='350px')
        )
        gpu_label = widgets.Label(
            value='Select GPU :',
            layout=widgets.Layout(width='100px')
        )
        gpu_dropbox = widgets.HBox([gpu_label, gpu_dropdown])

        instance_dropdown = widgets.Dropdown(
            options=[],
        )
        instance_label = widgets.Label(
            value='Select Instance :',
            layout=widgets.Layout(width='100px')
        )
        instance_dropbox = widgets.HBox([instance_label, instance_dropdown])

        refresh_button = widgets.Button(description="Select", button_style='success', layout=widgets.Layout(width='200px'))

        def refresh_gpu_options(b):
            # Repopulate the GPU list for the currently selected cloud.
            gpu_dropdown.options = []
            instance_dropdown.options = []

            cloud = cloud_dropdown.value
            if cloud:
                gpu_options = df[df['Cloud'] == cloud]['GPU'].unique().tolist()
                gpu_dropdown.options = gpu_options

        def update_instance_options(change):
            # Repopulate the instance list whenever the GPU selection changes.
            instance_dropdown.options = []

            gpu = change['new']
            if gpu:
                instances = df[(df['Cloud'] == cloud_dropdown.value) & (df['GPU'] == gpu)]
                instance_options = [
                    f"{row['Instance Type']} - ${row['Hourly Price ($)']}/hour"
                    for _, row in instances.iterrows()
                ]
                instance_dropdown.options = instance_options

        gpu_dropdown.observe(update_instance_options, names='value')
        refresh_button.on_click(refresh_gpu_options)

        display(cloud_dropbox, refresh_button, gpu_dropbox, instance_dropbox)

        confirm_button = widgets.Button(description="Confirm Selection", button_style='success', layout=widgets.Layout(width='200px'))

        def on_confirm_button_clicked(b):
            if not gpu_dropdown.value:
                print("Please select a valid GPU.")
                return
            if not instance_dropdown.value:
                print("Please select a valid instance.")
                return

            cloud = cloud_dropdown.value
            gpu = gpu_dropdown.value
            # Option labels are "<instance type> - $<price>/hour"; split on the
            # last separator so the instance type is recovered intact.
            instance_type = instance_dropdown.value.rsplit(' - ', 1)[0]
            matches = df[(df['Cloud'] == cloud) &
                         (df['GPU'] == gpu) &
                         (df['Instance Type'] == instance_type)]
            if matches.empty:
                # Guard against indexing an empty frame below.
                print("Please select a valid instance.")
                return

            quantity_value = matches.iloc[0]['Quantity']
            cloud_class = SKY_CLOUD[cloud]

            # Store the selected configuration
            self.selected_config = {
                'cloud': cloud,
                'cloud_class': cloud_class,
                'gpu': gpu,
                'instance_type': instance_type,
                'quantity': quantity_value
            }

            # Mirror the selection into the shared Config singleton.
            config = Config()
            config.set('cloud', cloud)
            config.set('cloud_class', cloud_class)
            config.set('gpu', gpu)
            config.set('instance_type', instance_type)
            config.set('quantity', quantity_value)

            print(f"Selected Machine Config ==>  Cloud : {cloud} , GPU : {gpu} , Instance Type : {instance_type} , Quantity : {quantity_value}")

            print("\n")
            cloud_manager = CloudManager(cloud)
            cloud_manager.configure_and_enable_cloud()

        confirm_button.on_click(on_confirm_button_clicked)
        display(confirm_button)

    def run_gpu_rates_ui(self):
        """Build and display one gauge/logo/submit column per configured logo."""
        columns = []
        for img_name in UI_CONFIG['logo_image_names']:
            output = widgets.Output()
            price_slider = widgets.FloatSlider(
                value=0,
                min=0,
                max=self.max_value,
                step=0.1,
                description='Price:',
                continuous_update=True,
                orientation='horizontal'
            )
            with output:
                interact(self.draw_speedometer, price=price_slider)

            # read_bytes() closes the file; the format follows the file
            # extension instead of assuming every logo is a PNG.
            logo_bytes = Path(f'{self.image_path}{img_name}').read_bytes()
            image = widgets.Image(value=logo_bytes, format=self._image_format(img_name), width=100, height=100)

            radio_buttons = widgets.RadioButtons(
                layout={'width': 'max-content', 'margin': '1px'},
                style={'description_width': 'initial'}
            )

            submit_button = widgets.Button(
                description='Submit',
                button_style='success',
                tooltip='Click to submit',
                icon='check'
            )
            # Bind the per-column values as lambda defaults so each button
            # submits its own column, not the last one built by the loop.
            submit_button.on_click(lambda b, img_name=img_name, radio_buttons=radio_buttons, slider=price_slider: self.on_submit(img_name, radio_buttons, slider))

            column = VBox([output, image, radio_buttons, submit_button], layout=widgets.Layout(align_items='center', padding='0px', margin='1px'))
            columns.append(column)

        row = HBox(columns, layout=widgets.Layout(justify_content='center'))
        display(row)

    def process_gpu_rate_sheet(self, selected_accelerator, hourly_cost_limit):
        """Return the rate rows priced strictly below ``hourly_cost_limit``.

        Args:
            selected_accelerator: GPU name to keep, or ``"all_gpu"`` to keep
                every GPU.
            hourly_cost_limit: Exclusive upper bound on 'Hourly Price ($)'.

        Returns:
            A DataFrame sorted by 'Cloud', without the spot-price column, with
            'Hourly Price ($)' converted to float and a fresh 0-based index.

        Raises:
            RuntimeError: if the rate lookup failed and returned nothing.
        """
        gpu_rates = GPURates(all=True) # accelerator_str="A100",all_regions=True
        result = gpu_rates.select_gpu()
        if result is None:
            # select_gpu() prints the underlying exception and returns None on
            # failure; fail with a clear message instead of an unpack error.
            raise RuntimeError("GPU rate lookup failed; see the exception printed above.")
        _, merged_df = result

        # Sort by column
        merged_df_sorted = merged_df.sort_values(by='Cloud')
        # Drop unwanted columns
        merged_df_sorted.drop(columns=['Hourly Spot Price ($)'], inplace=True)
        # Change hourly rate column value from str to float
        merged_df_sorted['Hourly Price ($)'] = (
            merged_df_sorted['Hourly Price ($)']
            .str.replace('$', '', regex=False)  # Remove the dollar sign
            .replace('-', float('nan'))         # Replace '-' with NaN
            .astype(float)                      # Convert to float
        )
        # Filter rows with Hourly Price < hourly_cost_limit (NaN prices drop out)
        merged_df_sorted = merged_df_sorted[merged_df_sorted['Hourly Price ($)'] < hourly_cost_limit]
        merged_df_sorted = merged_df_sorted.reset_index(drop=True)

        if selected_accelerator != "all_gpu":
            # filter gpu using selected_accelerator
            merged_df_sorted = merged_df_sorted[merged_df_sorted['GPU'] == selected_accelerator]
            merged_df_sorted = merged_df_sorted.reset_index(drop=True)
        return merged_df_sorted

    def get_selected_config(self):
        """Return the configuration stored by the last confirmed selection."""
        return self.selected_config




# @title Functions for Firestore database operations
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore

import os
import pathlib
import sqlite3
import pickle
from datetime import datetime, timezone

# Initialize Firebase Admin SDK

import json

# SECURITY: a service-account private key used to be embedded here in plain
# text. A key that has been committed to a notebook or repository must be
# treated as compromised: revoke it in the Google Cloud console and issue a
# new one. The credentials are now loaded at runtime and never stored in the
# source.
#
# Supply them in one of two ways (checked in this order):
#   * FIREBASE_CREDENTIALS_JSON      - the service-account JSON document itself
#   * GOOGLE_APPLICATION_CREDENTIALS - path to the service-account JSON file


def _load_firebase_credentials():
    """Return the Firebase service-account info as a dict.

    Raises:
        RuntimeError: if neither environment variable is set.
    """
    inline_json = os.environ.get("FIREBASE_CREDENTIALS_JSON")
    if inline_json:
        return json.loads(inline_json)

    key_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
    if key_path:
        with open(key_path, encoding="utf-8") as key_file:
            return json.load(key_file)

    raise RuntimeError(
        "Firebase credentials not configured: set FIREBASE_CREDENTIALS_JSON "
        "to the service-account JSON, or GOOGLE_APPLICATION_CREDENTIALS to "
        "the path of the service-account key file."
    )


firebase_credentials = _load_firebase_credentials()

try:
    # Reuse the default app when the cell is re-run.
    firebase_admin.get_app()
except ValueError:
    # If not initialized, initialize with the provided credentials
    cred = credentials.Certificate(firebase_credentials)
    firebase_admin.initialize_app(cred)

db = firestore.client()



def update_cluster_history_collection(user_id, data):
    """Create or merge the 'cluster-history' document for one cluster.

    A document is matched on ``user_id`` together with ``data['cluster_hash']``.

    Args:
        user_id: Owner id used for the lookup and stamped on new documents.
        data: Field values to store; must contain 'cluster_hash'. When a new
            document is created, 'user_id' and 'created_at' are added to this
            dict in place.

    Returns:
        The id of the document that was created or updated.
    """
    history = db.collection(u'cluster-history')
    matches = (
        history
        .where(u'user_id', u'==', user_id)
        .where(u'cluster_hash', u'==', data['cluster_hash'])
        .get()
    )

    if len(matches) != 0:
        # Existing document: merge in only the fields that carry a value.
        target = history.document(matches[0].id)
        present_fields = {k: v for k, v in data.items() if v is not None}
        target.set(present_fields, merge=True)
        return target.id

    # No match: create a fresh document stamped with owner and creation time.
    target = history.document()
    data['user_id'] = user_id
    data['created_at'] = datetime.now(timezone.utc)
    target.set(data)
    return target.id


def extract_cluster_history_data(user_id):
    """Push every row of the local ``cluster_history`` table to Firestore.

    Reads ``~/.sky/state.db`` and calls ``update_cluster_history_collection``
    once per row. Columns are read by position: 0 cluster hash, 1 name,
    2 node count, 4 pickled launched resources, 5 pickled usage intervals.

    Args:
        user_id: Owner id passed through to the Firestore update.

    Returns:
        None.
    """
    # Define the path to the states SQLite database
    _DB_PATH = os.path.expanduser('~/.sky/state.db')
    pathlib.Path(_DB_PATH).parents[0].mkdir(parents=True, exist_ok=True)

    connection = sqlite3.connect(_DB_PATH)
    try:
        cursor = connection.cursor()
        cursor.execute("SELECT * FROM cluster_history;")
        rows = cursor.fetchall()

        for row in rows:
            # NOTE(security): pickle.loads can execute arbitrary code; this is
            # only acceptable because the database is a local file written by
            # this machine's own tooling. Never point this at untrusted data.
            launched_resources = pickle.loads(row[4])
            usage_intervals = pickle.loads(row[5])

            data_dict = {
                'cluster_hash': row[0],
                'name': row[1],
                'num_nodes': row[2],
                # First interval's start and last interval's end.
                'usage_start': usage_intervals[0][0],
                'usage_end': usage_intervals[-1][1],
            }

            # Resource fields are only added when launch information exists.
            if launched_resources is not None:
                accelerators = launched_resources.accelerators
                if accelerators:
                    # Only the first accelerator entry is recorded.
                    accelerator_name, accelerator_count = next(iter(accelerators.items()))
                else:
                    accelerator_name = accelerator_count = None

                data_dict.update({
                    'cloud': str(launched_resources.cloud),
                    'instance_type': launched_resources.instance_type,
                    'cpus': launched_resources.cpus,
                    'memory': launched_resources.memory,
                    'accelerator_name': accelerator_name,
                    'accelerator_count': accelerator_count,
                    'accelerator_args': launched_resources.accelerator_args,
                    'use_spot': launched_resources.use_spot,
                    'region': launched_resources.region,
                    'zone': launched_resources.zone,
                    'disk_size': launched_resources.disk_size,
                })

            update_cluster_history_collection(user_id, data_dict)
    finally:
        # Always release the connection, including for an empty table or when
        # a query, unpickle or Firestore call raises.
        connection.close()


# Update clusters collection
def update_clusters_collection(user_id, data):
    """Create or merge the 'clusters' document for one cluster.

    A document is matched on ``user_id`` together with ``data['cluster_hash']``.

    Args:
        user_id: Owner id used for the lookup and stamped on new documents.
        data: Field values to store; must contain 'cluster_hash'. When a new
            document is created, 'user_id' and 'created_at' are added to this
            dict in place.

    Returns:
        The id of the document that was created or updated.
    """
    clusters = db.collection(u'clusters')
    found = (
        clusters
        .where(u'user_id', u'==', user_id)
        .where(u'cluster_hash', u'==', data['cluster_hash'])
        .get()
    )

    if len(found) != 0:
        # Existing document: merge in only the fields that carry a value.
        target = clusters.document(found[0].id)
        non_null = {field: val for field, val in data.items() if val is not None}
        target.set(non_null, merge=True)
        return target.id

    # No match: create a fresh document stamped with owner and creation time.
    target = clusters.document()
    data['user_id'] = user_id
    data['created_at'] = datetime.now(timezone.utc)
    target.set(data)
    return target.id

# Extract data from clusters table
def extract_clusters_data(user_id):
    """Push every row of the local ``clusters`` table to Firestore.

    Reads ``~/.sky/state.db`` and calls ``update_clusters_collection`` once
    per row. Columns are read by position: 0 name, 1 launched_at, 2 pickled
    handle, 3 last_use, 4 status, 5 autostop, 6 metadata, 7 to_down, 8 owner,
    9 cluster_hash, 10 storage_mounts_metadata, 11 cluster_ever_up.

    Args:
        user_id: Owner id passed through to the Firestore update.

    Returns:
        None.
    """
    # Define the path to the states SQLite database
    _DB_PATH = os.path.expanduser('~/.sky/state.db')
    pathlib.Path(_DB_PATH).parents[0].mkdir(parents=True, exist_ok=True)

    connection = sqlite3.connect(_DB_PATH)
    try:
        cursor = connection.cursor()
        # Fetch all rows from the 'clusters' table
        cursor.execute("SELECT * FROM clusters;")
        rows = cursor.fetchall()

        for row in rows:
            # NOTE(security): pickle.loads can execute arbitrary code; this is
            # only acceptable because the database is a local file written by
            # this machine's own tooling. Never point this at untrusted data.
            handle = pickle.loads(row[2])

            launched_resources = handle.launched_resources
            cloud = instance_type = accelerators = None
            if launched_resources is not None:
                cloud = launched_resources.cloud
                instance_type = launched_resources.instance_type
                accelerators = launched_resources.accelerators

            if accelerators:
                # Only the first accelerator entry is recorded.
                accelerator_name, accelerator_count = next(iter(accelerators.items()))
            else:
                accelerator_name = accelerator_count = None

            data_dict = {
                'name': row[0],
                'launched_at': row[1],
                'cluster_name_on_cloud': handle.cluster_name_on_cloud,
                'head_ip': handle.head_ip,
                # Keep a missing cloud as None rather than the string 'None',
                # so the merge path can skip it like any other empty field.
                'cloud': str(cloud) if cloud is not None else None,
                'instance_type': instance_type,
                'docker_user': handle.docker_user,
                'ssh_user': handle.ssh_user,
                'tpu_create_script': handle.tpu_create_script,
                'tpu_delete_script': handle.tpu_delete_script,
                'last_use': row[3],
                'status': row[4],
                'autostop': row[5],
                'metadata': row[6],
                'to_down': row[7],
                'owner': row[8],
                'cluster_hash': row[9],
                'storage_mounts_metadata': row[10],
                'cluster_ever_up': row[11],
                'accelerator_name': accelerator_name,
                'accelerator_count': accelerator_count,
            }

            update_clusters_collection(user_id, data_dict)
    finally:
        # Always release the connection, even when a query, unpickle or
        # Firestore call raises.
        connection.close()


def update_clusters_status(user_id, new_status, cluster_name):
    """Set a new status on a user's cluster documents that are 'UP' or 'INIT'.

    Args:
        user_id: Owner id the documents must match.
        new_status: Status value to write.
        cluster_name: Value the documents' 'name' field must match.

    Prints a line per document describing what was done; returns None.
    """
    base_query = (
        db.collection(u'clusters')
        .where(u'user_id', u'==', user_id)
        .where(u'name', u'==', cluster_name)
    )

    # One query per status of interest; results are concatenated in order.
    candidates = []
    for wanted_status in ('UP', 'INIT'):
        candidates.extend(base_query.where(u'status', u'==', wanted_status).get())

    if not candidates:
        print("No matching documents found.")
        return

    print("Document(s) found:", len(candidates))
    for snapshot in candidates:
        snapshot_id = snapshot.id
        target = db.collection(u'clusters').document(snapshot_id)
        current_status = snapshot.to_dict().get('status')
        # Only write when the stored status actually differs.
        if current_status == new_status:
            print(f"No update needed for document {snapshot_id}, already has status: {new_status}")
        else:
            target.update({'status': new_status})
            print(f"Updated document {snapshot_id} with new status: {new_status}")


def update_job_details(jobs_data, user_id):
    """Create or merge a 'jobs' document for every job in ``jobs_data``.

    A document is matched on ``user_id``, 'job_id' and 'job_name'. The
    function previously returned from inside the loop, so only the first job
    was ever written; all jobs are now processed.

    Args:
        jobs_data: Iterable of job mappings with the keys 'job_id',
            'job_name', 'username', 'submitted_at', 'status', 'start_at',
            'end_at' and 'resources'. ``status`` must expose a ``.name``.
        user_id: Owner id used for the lookup and stamped on new documents.

    Returns:
        The document id written for the first job (the same value earlier
        versions returned), or None when ``jobs_data`` is empty.
    """
    first_doc_id = None

    for job in jobs_data:
        job_id = job['job_id']
        job_name = job['job_name']

        data = {
            "job_id": job_id,
            "job_name": job_name,
            "username": job['username'],
            "submitted_at": job['submitted_at'],
            "start_at": job['start_at'],
            "end_at": job['end_at'],
            "resources": job['resources'],
            "status": job['status'].name,
        }

        jobs = db.collection(u'jobs')
        existing_doc = (
            jobs
            .where(u'user_id', u'==', user_id)
            .where(u'job_id', u'==', job_id)
            .where(u'job_name', u'==', job_name)
            .get()
        )

        if len(existing_doc) == 0:
            # New job: stamp owner and creation time.
            doc_ref = jobs.document()
            data['user_id'] = user_id
            data['created_at'] = datetime.now(timezone.utc)
            doc_ref.set(data)
        else:
            # Known job: merge the current field values into the document.
            doc_ref = jobs.document(existing_doc[0].id)
            doc_ref.set(data, merge=True)

        if first_doc_id is None:
            first_doc_id = doc_ref.id

    return first_doc_id

# @title Task selector code

# SELECT TYPE OF TASK
import ipywidgets as widgets
from ipywidgets import VBox, Output, HTML
from IPython.display import display, clear_output

class TaskConfigUI:
    """Notebook widget UI for choosing which kind of task to run.

    The user first picks a task type; a second dropdown matching that type
    and a Submit button are then shown. On submit the choice is stored in
    ``self.task_config`` and mirrored into the ``Config`` singleton.
    """

    def __init__(self):
        """Build the widgets, wire the handlers and display the initial UI."""
        # Initialize the task configuration as an empty dictionary
        self.task_config = {}

        # Create a heading to guide the user
        self.heading = HTML(
            "<h3>Please select the task you would like to perform in this notebook:</h3>"
        )

        # Create UI components
        self.task_dropdown = widgets.Dropdown(
            options=['Select', 'Ready-to-Use Models', 'Custom Training or Inference'],
            value='Select',
            description='Task Type:',
            style={'description_width': 'initial'}
        )

        self.ready_to_use_model_dropdown = widgets.Dropdown(
            options=['Select', 'LLaMA 3.1', 'SAM2'],
            value='Select',
            description='Ready-to-Use Models:',
            style={'description_width': 'initial'}
        )

        self.custom_task_dropdown = widgets.Dropdown(
            options=['Select', 'Training', 'Inference'],
            value='Select',
            description='Custom Task:',
            style={'description_width': 'initial'}
        )

        self.submit_button = widgets.Button(
            description="Submit",
            button_style='success'
        )

        # Output area for the dependent dropdown and the submit result. It
        # must exist before the display() call below and before any handler
        # runs; without it every use of self.output raises AttributeError.
        self.output = Output()

        # Set up event handlers
        self.task_dropdown.observe(self.display_relevant_dropdown, names='value')
        self.submit_button.on_click(self.on_submit_clicked)

        # Display the initial UI
        display(VBox([self.heading, self.task_dropdown]), self.output)

    def display_relevant_dropdown(self, change):
        """Show the second dropdown that matches the selected task type.

        Args:
            change: Widget change notification; ``change['new']`` is the newly
                selected task type. Any other value just clears the output.
        """
        with self.output:
            clear_output(wait=True)

            # Display appropriate dropdown based on the selected task type
            if change['new'] == 'Ready-to-Use Models':
                display(VBox([self.ready_to_use_model_dropdown, self.submit_button]))
            elif change['new'] == 'Custom Training or Inference':
                display(VBox([self.custom_task_dropdown, self.submit_button]))

    def on_submit_clicked(self, b):
        """Validate the selection and record it.

        Stores 'task_type' plus either 'model' or 'mode' in
        ``self.task_config`` and in the ``Config`` singleton, then prints the
        configuration. Prints a prompt and records nothing when either
        dropdown is still on 'Select'.
        """
        with self.output:
            clear_output(wait=True)
            # Populate task configuration based on user selections
            if self.task_dropdown.value == 'Ready-to-Use Models' and self.ready_to_use_model_dropdown.value != 'Select':
                self.task_config['task_type'] = "Ready-to-Use Models"
                self.task_config['model'] = self.ready_to_use_model_dropdown.value

                config = Config()  # Access the singleton instance
                config.set('task_type', "Ready-to-Use Models")
                config.set('model', self.ready_to_use_model_dropdown.value)
            elif self.task_dropdown.value == 'Custom Training or Inference' and self.custom_task_dropdown.value != 'Select':
                self.task_config['task_type'] = "Custom Training or Inference"
                self.task_config['mode'] = self.custom_task_dropdown.value

                config = Config()  # Access the singleton instance
                config.set('task_type', "Custom Training or Inference")
                config.set('mode', self.custom_task_dropdown.value)
            else:
                print("Please make a valid selection for all fields.")
                return

            # Display the current configuration
            print("Configuration:", self.task_config)

    def get_task_config(self):
        """Returns the current task configuration."""
        return self.task_config



# @title Run this cell select type of task you want to run
# Create the task-selection UI. The constructor displays its widgets
# immediately; read the submitted choice with task_ui.get_task_config().
task_ui = TaskConfigUI()

import ipywidgets as widgets
from IPython.display import display

# Top-level question. 'Select' is a neutral first entry so that picking a
# real option always fires a value-change event.
dropdown_main = widgets.Dropdown(
    options=[
        'Select',  # neutral placeholder entry
        'Are you looking for a GPU',
        'Do you have an LLM in mind'
    ],
    description='Options:',
    disabled=False,
)

# Follow-up dropdowns are created on demand by on_main_dropdown_change;
# None means "not shown yet". dropdown_task is never assigned in this cell.
dropdown_gpu = None
dropdown_task = None

# Function to display the secondary dropdown based on GPU selection
def on_main_dropdown_change(change):
    """React to a new selection in the top-level dropdown.

    Closes any follow-up dropdown that is currently shown, then, if the GPU
    question was chosen, creates and displays the GPU follow-up dropdown and
    attaches its change handler.

    Args:
        change: Widget change notification; ``change['new']`` is the newly
            selected option.
    """
    global dropdown_gpu, dropdown_task

    # Clear any previous dynamic dropdowns
    for stale_widget in (dropdown_gpu, dropdown_task):
        if stale_widget:
            stale_widget.close()

    if change['new'] != 'Are you looking for a GPU':
        return

    # Follow-up question for the GPU path, again led by a neutral 'Select'.
    gpu_questions = [
        'Select',
        'Do you want to view a GPU rate card across 20 GPU clouds',
        'Do you have an LLM task in mind'
    ]
    dropdown_gpu = widgets.Dropdown(
        options=gpu_questions,
        description='GPU Options:',
        disabled=False,
    )
    display(dropdown_gpu)

    # Add event handler for the secondary dropdown
    dropdown_gpu.observe(on_gpu_dropdown_change, names='value')

# Function to call when the GPU rate card option is selected
def on_gpu_dropdown_change(change):
    """Open the GPU rate-card UI when that option is picked.

    Args:
        change: Widget change notification; ``change['new']`` is the newly
            selected option. Any other option is ignored.
    """
    wants_rate_card = change['new'] == 'Do you want to view a GPU rate card across 20 GPU clouds'
    if not wants_rate_card:
        return

    # Build a fresh rate-sheet UI and render it below the dropdowns.
    rate_card_ui = GpuRateUI()
    rate_card_ui.run_gpu_rates_ui()

# Call on_main_dropdown_change whenever the top-level dropdown's value changes
dropdown_main.observe(on_main_dropdown_change, names='value')

# Render the top-level dropdown; follow-up widgets appear as options are picked
display(dropdown_main)


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



AttributeError: 'TaskConfigUI' object has no attribute 'output'

During handling of the above exception, another exception occurred:

AttributeError: 'AttributeError' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

AssertionError
AttributeError: 'TaskConfigUI' object has no attribute 'output'

During handling of the above exception, another exception occurred:

AttributeError: 'AttributeError' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

TypeError: object of type 'NoneType' has no len()

During handling of the above exception, another exception occurred:

AttributeError: 'TypeError' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

AssertionError
AttributeError: 'TaskConfigUI' object has no attribute 'output'

During handling of the above exception, another exception occurred:

Attribute

In [None]:
# @title Run this cell to get GPU rate sheet and to select  Cloud and GPU accelerators
# Show the GPU rate-sheet UI: one column per logo in UI_CONFIG, each with a
# price gauge and a Submit button.

# Instantiate and run the UI
gpu_rate_ui = GpuRateUI()
gpu_rate_ui.run_gpu_rates_ui()
# After confirming a selection in the UI, read it back with:
# selected_config = gpu_rate_ui.get_selected_config()

HBox(children=(VBox(children=(Output(), Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00…

In [None]:
# @title Python class to launch vm and run task, and generate unique user_id

import typing
from typing import (Any, Callable, Dict, Iterable, List, Optional, Set, Tuple,
                    Union)

import sky
from sky import backends, optimizer
from sky import clouds
from sky.utils import resources_utils
from sky import cli
import uuid

if typing.TYPE_CHECKING:
    from sky import resources as resources_lib


# Function to generate a unique user_id
def generate_user_id():
    """Return a new random (version 4) UUID as a string."""
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)


class ClusterTaskRunner:
    """
    ClusterTaskRunner is responsible for setting up a cluster and running task/training on it.

    The constructor only stores its arguments. ``launch_task`` builds a ``sky.Task`` and a
    ``sky.Resources`` from them, calls ``sky.launch`` and then copies cluster and job records
    to Firestore through helper functions that are expected to be defined elsewhere in the
    notebook (``extract_cluster_history_data``, ``extract_clusters_data``,
    ``update_job_details``) together with ``core`` (used as ``core.queue``).
    """
    def __init__(
        self,
        user_id: str,
        #launch params
        cluster_name: Optional[str] = None,
        retry_until_up: bool = False,
        idle_minutes_to_autostop: Optional[int] = None,
        dryrun: bool = False,
        down: bool = False,
        stream_logs: bool = True,
        backend: Optional[backends.Backend] = None,
        optimize_target: optimizer.OptimizeTarget = optimizer.OptimizeTarget.COST,
        detach_setup: bool = False,
        detach_run: bool = False,
        no_setup: bool = False,
        clone_disk_from: Optional[str] = None,
        # Task params:
        name: Optional[str] = None,
        setup: Optional[str] = None,
        run: Optional[str] = None,
        envs: Optional[Dict[str, str]] = None,
        workdir: Optional[str] = None,
        num_nodes: Optional[int] = None,
        # Advanced:
        docker_image: Optional[str] = None,
        event_callback: Optional[str] = None,
        blocked_resources: Optional[Iterable['resources_lib.Resources']] = None,
        # Resources params:
        cloud: Optional[clouds.Cloud] = None,
        instance_type: Optional[str] = None,
        cpus: Union[None, int, float, str] = None,
        memory: Union[None, int, float, str] = None,
        accelerators: Union[None, str, Dict[str, int]] = None,
        accelerator_args: Optional[Dict[str, str]] = None,
        use_spot: Optional[bool] = None,
        # spot_recovery: Optional[str] = None,
        region: Optional[str] = None,
        zone: Optional[str] = None,
        image_id: Union[Dict[str, str], str, None] = None,
        disk_size: Optional[int] = None,
        disk_tier: Optional[Union[str, resources_utils.DiskTier]] = None,
        ports: Optional[Union[int, str, List[str], Tuple[str]]] = None,

    ):

        """Initializes the ClusterTaskRunner with the provided parameters.

        Every argument is stored unchanged on the instance; nothing is validated or
        launched here.

        Args:
        user_id : Unique id identifying the user; passed to the Firestore helper functions.

        #sky.launch params#
        cluster_name (str): Optional name of the cluster to create/reuse. If not provided, a name will be auto-generated.
        retry_until_up (bool): Whether to retry launching the cluster until it is up.
        idle_minutes_to_autostop (int): Optional number of idle minutes after which the cluster should be automatically stopped.
                                            If not set, the cluster will not be autostopped.
        dryrun (bool): If True, the cluster will not actually be launched.
        down (bool): Whether to tear down the cluster after all jobs finish (successfully or abnormally).
        stream_logs (bool): Whether to show the logs in the terminal.
        backend (backends.Backend): Optional backend to use. If not provided, the default backend (CloudVMRayBackend) will be used.
        optimize_target (optimizer.OptimizeTarget): Target to optimize for. Choices are OptimizeTarget.COST and OptimizeTarget.TIME.
        detach_setup (bool): Whether to run setup in non-interactive mode as part of the job itself.
        detach_run (bool): Whether to return as soon as a job is submitted instead of streaming execution logs.
        no_setup (bool): Forwarded to sky.launch as ``no_setup``; controls whether setup commands are re-run.
        clone_disk_from (str): Optional cluster from which to clone the disk.
                                    This is useful to migrate the cluster to a different availability zone or region.

        #sky.Task params#
        name (str): Optional name for the Task, for display purposes.
        setup (str): Optional setup command, run before executing the run commands.
        run (str): Optional command for the task. It can be a shell command or a command generator (callable).
        envs (Dict[str, str]): Optional environment variables to set before running the setup and run commands.
        workdir (str): Optional local working directory. This directory will be synced to a location on the remote VM(s).
        num_nodes (int): Optional number of nodes to provision for this Task. If not provided, it is treated as 1 node.
        docker_image (str): Optional base docker image that this Task will be built on. Only in effect when LocalDockerBackend is used.
        event_callback (str): Optional value forwarded to sky.Task as ``event_callback``.
        blocked_resources (Iterable['resources_lib.Resources']): Optional set of resources that this task cannot run on.

        #sky.Resources params#
        cloud (clouds.Cloud): Optional cloud provider to use.
                                It should be an instance of a Cloud class.
                                eg: sky.RunPod(),sky.AWS(),sky.Azure()
        instance_type (str): Optional instance type to use on the cloud provider.
        cpus (Union[None, int, float, str]): Optional number of CPUs required for the task.
        memory (Union[None, int, float, str]): Optional amount of memory (in GiB) required for the task.
        accelerators (Union[None, str, Dict[str, int]]): Optional accelerators (like GPUs) required for the task.
        accelerator_args: Optional dictionary of accelerator-specific arguments.
        use_spot (bool): Optional flag selecting spot instances.
        region (str): Optional region to use on the cloud provider.
        zone (str): Optional zone to use on the cloud provider.
        image_id (Union[Dict[str, str], str, None]): Optional image ID to use.
        disk_size (int): Optional size of the OS disk in GiB.
        disk_tier (Union[str, resources_utils.DiskTier]): Optional disk performance tier to use.
        ports (Union[int, str, List[str], Tuple[str]]): Optional ports to open on the instance.

        Note: ``spot_recovery`` is currently disabled (commented out in the signature,
        the attribute list and ``create_resources``).
        """
        self.user_id = user_id

        # sky.launch parameters
        self.cluster_name = cluster_name
        self.retry_until_up = retry_until_up
        self.idle_minutes_to_autostop = idle_minutes_to_autostop
        self.dryrun = dryrun
        self.down = down
        self.stream_logs = stream_logs
        self.backend = backend
        self.optimize_target = optimize_target
        self.detach_setup = detach_setup
        self.detach_run = detach_run
        self.no_setup = no_setup
        self.clone_disk_from = clone_disk_from

        # sky.Task parameters
        self.name = name
        self.setup = setup
        self.run = run
        self.envs = envs
        self.workdir = workdir
        self.num_nodes = num_nodes
        self.docker_image = docker_image
        self.event_callback = event_callback
        self.blocked_resources = blocked_resources

        # sky.Resources parameters
        self.cloud = cloud
        self.instance_type = instance_type
        self.cpus = cpus
        self.memory = memory
        self.accelerators = accelerators
        self.accelerator_args = accelerator_args
        self.use_spot = use_spot
        # self.spot_recovery = spot_recovery
        self.region = region
        self.zone = zone
        self.image_id = image_id
        self.disk_size = disk_size
        self.disk_tier = disk_tier
        self.ports = ports



    def create_task(self):
        """
        Create a task object with the specified parameters.

        Returns:
            sky.Task: The created task object.
        """
        task = sky.Task(
            name=self.name,
            run=self.run,
            setup=self.setup,
            envs=self.envs,
            workdir=self.workdir,
            num_nodes=self.num_nodes,
            docker_image=self.docker_image,
            event_callback=self.event_callback,
            blocked_resources=self.blocked_resources
        )

        print(f"Created task: {task}")
        return task

    def create_resources(self):
        """
        Creates and configures the resources needed for the task.

        Returns:
            sky.Resources: The created resources object.
        """
        resources = sky.Resources(
            cloud=self.cloud,
            instance_type=self.instance_type,
            cpus=self.cpus,
            memory=self.memory,
            accelerators=self.accelerators,
            accelerator_args=self.accelerator_args,
            use_spot=self.use_spot,
            # spot_recovery=self.spot_recovery,
            region=self.region,
            zone=self.zone,
            image_id=self.image_id,
            disk_size=self.disk_size,
            disk_tier=self.disk_tier,
            ports=self.ports
        )

        print("Created resources")
        return resources

    def launch_task(self):
        """
        Launches a task by creating the task, setting resources, and launching it on the cluster.

        After the launch, cluster records are copied to Firestore and the cluster's job
        queue is fetched and stored as well. The job queue is looked up by the configured
        ``cluster_name``; when none was configured (auto-generated name) the name is taken
        from the handle returned by ``sky.launch``. If neither is available (for example a
        dry run that returns no handle) the job sync is skipped.

        Returns:
            job_id: The job id returned by ``sky.launch``.
            handle: The cluster handle returned by ``sky.launch``.
        """
        task = self.create_task()
        resources = self.create_resources()
        task.set_resources(resources)
        print("Set resources for the task")

        job_id, handle = sky.launch(
            task=task,
            cluster_name=self.cluster_name,
            retry_until_up=self.retry_until_up,
            idle_minutes_to_autostop=self.idle_minutes_to_autostop,
            dryrun=self.dryrun,
            down=self.down,
            stream_logs=self.stream_logs,
            backend=self.backend,
            optimize_target=self.optimize_target,
            detach_setup=self.detach_setup,
            detach_run=self.detach_run,
            no_setup=self.no_setup,
            clone_disk_from=self.clone_disk_from
        )

        print(f"Job ID: {job_id}")
        print(f"Handle: {handle}")
        ## Copy datas saved in local mysql db to google firestore db ##
        print("Extracting data from local db and saving to firestore db")
        # Call function to update cluster-history collection
        extract_cluster_history_data(self.user_id)
        # Call function to update clusters collection
        extract_clusters_data(self.user_id)

        # The queue must be queried by the cluster's real name. self.cluster_name is None
        # when the name was auto-generated, so fall back to the launched handle.
        cluster_name = self.cluster_name
        if cluster_name is None and handle is not None:
            cluster_name = handle.get_cluster_name()

        if cluster_name is None:
            print("No cluster name available; skipping job details update")
        else:
            # fetch job details
            jobs = core.queue(cluster_name, skip_finished=False, all_users=False)
            # update job details to firestore collection
            update_job_details(jobs, self.user_id)
        return job_id, handle

# Create class to manage cluster including stopping/terminating the cluster
class ClusterManager:
    """Terminates a list of clusters and records the outcome through the notebook's Firestore helpers."""

    def __init__(self, user_id: str, clusters: List[str], all: Optional[bool] = None, yes: bool = False, purge: bool = False):
        """Store the teardown options.

        Args:
            user_id: Id passed to the Firestore helper functions.
            clusters: Names of the clusters to terminate.
            all: Forwarded to the teardown call as ``apply_to_all``.
            yes: Forwarded to the teardown call as ``no_confirm``.
            purge: Forwarded to the teardown call as ``purge``.
        """
        self.user_id = user_id
        self.clusters = clusters
        self.all = all
        self.yes = yes
        self.purge = purge

    def down_clusters(self):
        """Save each cluster's job queue, tear the clusters down, then update their stored status."""
        # Job details are captured before the teardown call.
        for cluster_name in self.clusters:
            queued_jobs = core.queue(cluster_name, skip_finished=False, all_users=False)
            update_job_details(queued_jobs, self.user_id)

        # down=True requests termination rather than a stop.
        cli._down_or_stop_clusters(
            self.clusters,
            apply_to_all=self.all,
            down=True,
            no_confirm=self.yes,
            purge=self.purge,
        )
        print("Cluster terminated successfully")

        # Mark every requested cluster as TERMINATED in firestore.
        print("Updating terminated status in firestore db")
        for cluster_name in self.clusters:
            update_clusters_status(self.user_id, "TERMINATED", cluster_name)

        # Finally refresh the cluster-history collection.
        print("copiying data from local db to firestore db..")
        extract_cluster_history_data(self.user_id)


In [None]:
# @title Code to generate files required for COG

##################################################################################################################################################
            # COG TRAINING FILE GENERATOR
##################################################################################################################################################

import os
import shutil

class CogTraining_FileGenerator:
    """Interactively generates the files Cog needs to run a training function.

    ``generate_cog_files`` asks for a work directory and writes two files into it:
    ``cog_train.py`` (a ``Predictor`` class wrapping the training function, optionally
    preceded by dataset-download code) and ``cog.yaml``.
    """

    def __init__(self):
        pass

    def is_valid_identifier(self, identifier):
        """Return True if ``identifier`` can be used as a parameter name in the generated script.

        Reserved words such as ``class`` pass ``str.isidentifier`` but would make the
        generated ``predict`` signature a syntax error, so they are rejected as well.
        """
        import keyword  # function-scope import: not part of this cell's import lines

        return identifier.isidentifier() and not keyword.iskeyword(identifier)

    def format_default(self, value, value_type):
        """Render a default value for embedding in generated source code.

        String defaults are returned as a Python string literal (via ``repr``) so that
        quotes and backslashes inside the value cannot break the generated script.
        ``None`` and non-string types are returned unchanged.
        """
        if value_type == "str" and value is not None:
            return repr(str(value))
        return value

    def get_user_input_for_arg(self):
        """Prompt for one function argument description.

        Returns:
            dict with keys ``name``, ``type`` and ``default``, or ``None`` when the user
            enters ``done`` as the argument name.
        """
        arg = {}
        arg["name"] = input("Enter argument name: ")
        if arg["name"].lower() == 'done':
            return None

        arg["type"] = input("Enter argument type (int, str, bool, etc.): ")

        if arg["type"] == "int":
            default = input("Enter default value (int): ")
            try:
                arg["default"] = int(default) if default else None
            except ValueError:
                # A typo should not abort the whole interactive session.
                print(f"'{default}' is not a valid int, defaulting to None")
                arg["default"] = None
        elif arg["type"] == "bool":
            default = input("Enter default value (True/False): ")
            arg["default"] = default.lower() == "true"
        elif arg["type"] == "str":
            default = input("Enter default value (str): ")
            arg["default"] = default if default else None
        else:
            print("Type not supported, defaulting to None")
            arg["default"] = None

        return arg

    def _render_signature(self, function_args):
        """Return ``(args, params)``: the ``predict`` parameter lines and the call argument list.

        Raises:
            ValueError: if an argument name is not a usable identifier.
        """
        args_list = []
        params_list = []
        for arg in function_args:
            if not self.is_valid_identifier(arg['name']):
                raise ValueError(f"Invalid argument name: {arg['name']}")

            default_value = self.format_default(arg['default'], arg['type'])
            args_list.append(f"        {arg['name']}: {arg['type']} = cog.Input(default={default_value})")
            params_list.append(arg['name'])

        return ",\n".join(args_list), ", ".join(params_list)

    def generate_cog_python_script(self, work_dir):
        """Write ``cog_train.py`` into ``work_dir``.

        The user is asked whether dataset-download code should be included. If so, the
        storage type (aws/gcs), credentials, bucket and object path are prompted for and
        embedded in the generated script; for GCS the service-account file is copied into
        ``work_dir``. A failure to write the script is reported with a message, not raised.

        Raises:
            ValueError: if a function argument name is not a usable identifier.
        """
        # NOTE: prompting for the module name, function name and arguments (see
        # get_user_input_for_arg) is currently disabled; the values are hard-coded.
        module_name = "train"
        function_name = "main"
        function_args = [
            {'name': 'epochs', 'type': 'int', 'default': 5},
            {'name': 'batch_size', 'type': 'int', 'default': 10},
        ]

        # Both templates share the same predict() signature and call arguments.
        args, params = self._render_signature(function_args)

        # Ask user if they want to configure dataset downloading
        include_dataset_download = input("\nDo you want to include dataset download configuration? (y/n): ").strip().lower()
        if include_dataset_download == "y":

            # Get user input for dataset downloading
            print("\n--- Dataset Download Configuration ---\n")
            aws_access_key_id = None
            aws_secret_access_key = None
            region_name = None
            service_account_path = None

            while True:
                storage_type = input("Enter storage type (aws/gcs): ").strip().lower()
                if storage_type in ["aws", "gcs"]:
                    break
                print("Invalid storage type. Please enter 'aws' or 'gcs'.")

            if storage_type == "aws":
                print("\nAWS Configuration:")
                aws_access_key_id = input("Enter your AWS Access Key Id: ")
                aws_secret_access_key = input("Enter your AWS Secret Access Key: ")
                region_name = input("Enter AWS Region Name (e.g., us-east-1): ")

            elif storage_type == 'gcs':
                print("\nGCP Configuration:")
                service_account_path = input("Enter path to your GCP service account key file: ")
                if service_account_path:
                    # Copy the GCP service account file to the work directory
                    new_service_account_path = os.path.join(work_dir, os.path.basename(service_account_path))
                    if not os.path.exists(new_service_account_path):
                        shutil.copy(service_account_path, new_service_account_path)
                        print(f'Service account file copied to {new_service_account_path}')
                    else:
                        print(f'Service account file already exists at {new_service_account_path}')
                else:
                    # Nothing to copy; the generated script then leaves
                    # GOOGLE_APPLICATION_CREDENTIALS untouched.
                    print("No service account file given; none will be copied.")

            # Common inputs for both AWS and GCS
            print("\nStorage Bucket Configuration:")
            bucket_name = input("Enter storage bucket name: ")
            object_path = input("Enter dataset path in storage bucket (e.g., path/to/dataset.zip): ")

            # Template for python script. Values are inserted with the !r conversion so
            # they become valid Python literals: unset values stay None instead of the
            # truthy string "None", and quotes inside values are escaped.
            # NOTE(security): credentials entered above are written in plain text into
            # the generated cog_train.py.
            class_template = """
# Install necessary libraries
import subprocess
subprocess.run(['pip', 'install', 'boto3', 'google-cloud-storage'], check=True)

import os
from typing import Any
import cog
import boto3
import zipfile
import shutil
from google.cloud import storage

from {module_name} import {function_name}

def unzip_file(zip_path, extract_to):
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)

        # Check if there is a single directory inside the extract_to directory
        extracted_contents = os.listdir(extract_to)
        if len(extracted_contents) == 1:
            single_dir_path = os.path.join(extract_to, extracted_contents[0])
            if os.path.isdir(single_dir_path):
                # Move all contents of the single directory up one level
                for item in os.listdir(single_dir_path):
                    shutil.move(os.path.join(single_dir_path, item), extract_to)
                # Remove the now empty single directory
                os.rmdir(single_dir_path)

        print(f'Unzipped file {{zip_path}} successfully to {{extract_to}}')
    except Exception as e:
        print(f'Error unzipping file: {{e}}')


def download_dataset(storage_type, bucket_name, object_path, download_path, aws_access_key_id=None, aws_secret_access_key=None, region_name=None, service_account_path=None):
    if storage_type == 'aws':
        session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name
        )
        s3 = session.client('s3')
        try:
            s3.download_file(bucket_name, object_path, download_path)
            print(f'File {{object_path}} downloaded successfully to {{download_path}}')
        except Exception as e:
            print(f'Error downloading file from AWS S3: {{e}}')

    elif storage_type == 'gcs':
        if service_account_path:
            os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = service_account_path
        client = storage.Client()
        try:
            bucket = client.bucket(bucket_name)
            blob = bucket.blob(object_path)
            blob.download_to_filename(download_path)
            print(f"Downloaded storage object {{object_path}} from bucket {{bucket_name}} to local file {{download_path}}.")
        except Exception as e:
            print(f'Error downloading file from GCS: {{e}}')

    unzip_file(download_path, "./data")


class Predictor:

    def predict(self,
{args}
    ) -> Any:

        download_dataset(
            storage_type={storage_type!r},
            bucket_name={bucket_name!r},
            object_path={object_path!r},
            download_path={download_path!r},
            aws_access_key_id={aws_access_key_id!r},
            aws_secret_access_key={aws_secret_access_key!r},
            region_name={region_name!r},
            service_account_path={service_account_path!r}
        )
        output = {function_name}({params})
        print(f'\\nResult ==== {{output}}')

        """

            # The generated script refers to the copy inside the work directory.
            local_service_account_path = os.path.basename(service_account_path) if service_account_path else None

            # Formatting the class template
            formatted_class = class_template.format(
                args=args,
                params=params,
                module_name=module_name,
                function_name=function_name,
                storage_type=storage_type,
                bucket_name=bucket_name,
                object_path=object_path,
                download_path="./dataset.zip",
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                region_name=region_name,
                service_account_path=local_service_account_path,
            )

        else:
            # Template for python script
            class_template = """
import os
from typing import Any
import cog

from {module_name} import {function_name}


class Predictor:

    def predict(self,
{args}
    ) -> Any:

        output = {function_name}({params})
        print(f'\\nResult ==== {{output}}')
    """

            # Fill in the template with the user input
            formatted_class = class_template.format(module_name=module_name, function_name=function_name, args=args, params=params)

        # Write the content to a python file
        file_name = f"{work_dir}/cog_train.py"
        try:
            with open(file_name, "w", encoding="utf-8") as file:
                file.write(formatted_class)
            print(f"Script {file_name} created successfully.")
        except OSError as e:
            print(f"Failed to write to {file_name}: {e}")

    def create_cog_yaml(self, work_dir):
        """Write a fixed ``cog.yaml`` into ``work_dir`` that points at ``cog_train.py:Predictor``."""
        # YAML content template for cog.yaml
        yaml_content = """
build:
    python_version: "3.8"
    gpu: true
    system_packages:
        - "ffmpeg"
        - "libgl1-mesa-glx"
        - "libglib2.0-0"
    python_requirements: "requirements.txt"

predict: "cog_train.py:Predictor"
    """

        file_path = f"{work_dir}/cog.yaml"
        # Writing the YAML content to the file
        with open(file_path, 'w', encoding="utf-8") as file:
            file.write(yaml_content)

        print(f"'{file_path}' has been created with the specified content.")

    def generate_cog_files(self):
        """Prompt for the work directory, then generate ``cog_train.py`` and ``cog.yaml`` in it."""
        print("GENERATE COG TRAINING FILES ... ")
        work_dir = input("Enter the work directory path: ")
        # Function to generate cog python script
        self.generate_cog_python_script(work_dir)
        # Function to create cog.yaml file
        self.create_cog_yaml(work_dir)





##################################################################################################################################################
            # COG INFERENCE FILE GENERATOR
##################################################################################################################################################

# @title Code to generate files required for COG
import os
import shutil


class CogInference_FileGenerator:
    def __init__(self):
        # work_dir is read by configure_download as the destination directory when it
        # copies a GCS service-account file; presumably the caller assigns it first.
        self.work_dir = None

    def get_user_input(self, prompt, default=None):
        """Prompt on stdin and return the stripped reply, or ``default`` when the reply is blank."""
        reply = input(prompt).strip()
        if reply:
            return reply
        return default

    def is_valid_identifier(self, identifier):
        """Return True if ``identifier`` can be used as a parameter name in the generated script.

        Reserved words such as ``class`` pass ``str.isidentifier`` but would make the
        generated ``predict`` signature a syntax error, so they are rejected as well.
        """
        import keyword  # function-scope import: not part of this cell's import lines

        return identifier.isidentifier() and not keyword.iskeyword(identifier)

    def format_default(self, value, value_type):
        """Render a default value for embedding in generated source code.

        String defaults are returned as a Python string literal (via ``repr``) so that
        quotes and backslashes inside the value cannot break the generated script.
        ``None`` and non-string types are returned unchanged.
        """
        if value_type == "str" and value is not None:
            return repr(str(value))
        return value

    def get_user_input_for_arg(self):
        arg = {}
        arg["name"] = self.get_user_input("Enter argument name: ")
        if arg["name"].lower() == 'done':
            return None

        arg["type"] = self.get_user_input("Enter argument type (int, str, bool, etc.): ")

        if arg["type"] == "int":
            default = self.get_user_input("Enter default value (int): ")
            arg["default"] = int(default) if default else None
        elif arg["type"] == "bool":
            default = self.get_user_input("Enter default value (True/False): ")
            arg["default"] = default.lower() == "true" if default else None
        elif arg["type"] == "str":
            default = self.get_user_input("Enter default value (str): ")
            arg["default"] = default if default else None
        else:
            print("Type not supported, defaulting to None")
            arg["default"] = None

        return arg

    def configure_storage(self):
        """This function gets the storage type and configuration details from the user.
        """
        while True:
            storage_type = self.get_user_input("Enter storage type (aws/gcs): ").strip().lower()
            if storage_type in ["aws", "gcs"]:
                break
            else:
                print("Invalid storage type. Please enter 'aws' or 'gcs'.")

        config = {
            "storage_type": storage_type,
            "aws_access_key_id": None,
            "aws_secret_access_key": None,
            "region_name": None,
            "service_account_path": None
        }

        if storage_type == "aws":
            print("\nAWS Configuration:")
            config["aws_access_key_id"] = self.get_user_input("Enter your AWS Access Key Id: ")
            config["aws_secret_access_key"] = self.get_user_input("Enter your AWS Secret Access Key: ")
            config["region_name"] = self.get_user_input("Enter AWS Region Name (e.g., us-east-1): ")
        elif storage_type == 'gcs':
            print("\nGCP Configuration:")
            config["service_account_path"] = self.get_user_input("Enter path to your GCP service account key file: ")

        return config

    def configure_bucket(self, config):
        """This function gets the bucket name and model & dataset path from the user.
        """
        model_path = None
        dataset_path = None
        print("\nStorage Bucket Configuration:")
        bucket_name = self.get_user_input("Enter storage bucket name: ")
        if config["download_dataset"]:
            dataset_path = self.get_user_input("Enter dataset path in storage bucket (e.g., path/to/dataset.zip): ")
        if config["download_model"]:
            model_path = self.get_user_input("Enter model path in storage bucket (e.g., path/to/model.h5): ")

        return {
            "bucket_name": bucket_name,
            "dataset_path": dataset_path,
            "model_path": model_path,
        }

    def configure_download(self):
        """This function configures download options for dataset and model separately."""
        config = {
            "download_dataset": False,
            "download_model": False
        }

        download_dataset = self.get_user_input("\nDo you want to include download configuration for the dataset? (y/n): ").strip().lower()
        if download_dataset == "y":
            config["download_dataset"] = True

        download_model = self.get_user_input("\nDo you want to include download configuration for the model? (y/n): ").strip().lower()
        if download_model == "y":
            config["download_model"] = True

        if config["download_dataset"] or config["download_model"]:
            storage_config = self.configure_storage()
            bucket_config = self.configure_bucket(config)
            config.update({**storage_config, **bucket_config})

            if storage_config["storage_type"] == 'gcs' and storage_config.get("service_account_path"):
                service_account_path = storage_config["service_account_path"]
                new_service_account_path = os.path.join(self.work_dir, os.path.basename(service_account_path))
                if not os.path.exists(new_service_account_path):
                    shutil.copy(service_account_path, new_service_account_path)
                    print(f'Service account file copied to {new_service_account_path}')
                else:
                    print(f'Service account file already exists at {new_service_account_path}')

            return config
        else:
            return config

    def generate_function_args(self, function_args):
        args_list = []
        for arg in function_args:
            if not self.is_valid_identifier(arg['name']):
                raise ValueError(f"Invalid argument name: {arg['name']}")

            default_value = self.format_default(arg['default'], arg['type'])
            arg_line = f"        {arg['name']}: {arg['type']} = cog.Input(default={default_value})"
            args_list.append(arg_line)

        return ",\n".join(args_list)

    def generate_function_params(self, function_args):
        """Return the argument names joined with ``", "``, in the order given."""
        names = [spec['name'] for spec in function_args]
        return ", ".join(names)

    def generate_class_template(self, function_args, module_name, function_name, download_config=None):
        args = self.generate_function_args(function_args)
        params = self.generate_function_params(function_args)

        if download_config["download_dataset"] or download_config["download_model"]:
            storage_type = download_config["storage_type"]
            bucket_name = download_config["bucket_name"]
            dataset_path = download_config.get("dataset_path", None)
            model_path = download_config.get("model_path", None)
            aws_access_key_id = download_config.get("aws_access_key_id")
            aws_secret_access_key = download_config.get("aws_secret_access_key")
            region_name = download_config.get("region_name")
            service_account_path = download_config.get("service_account_path")
            local_service_account_path = os.path.basename(service_account_path) if service_account_path else None

            dataset_download_code = ""
            if download_config.get('download_dataset'):
                dataset_download_code = f"""
        # Download dataset
        fetch_resource(
            storage_type="{storage_type}",
            bucket_name="{bucket_name}",
            object_path="{dataset_path}",
            download_path="./dataset.zip",
            aws_access_key_id="{aws_access_key_id}",
            aws_secret_access_key="{aws_secret_access_key}",
            region_name="{region_name}",
            service_account_path="{local_service_account_path}"
        )
        # Extract dataset
        unzip_file("./dataset.zip", "./data")"""

            model_download_code = ""
            if download_config.get('download_model'):
                model_filename = os.path.basename(model_path)
                model_download_code = f"""
        # Download model
        model_download_path = f"./{model_filename}"
        fetch_resource(
            storage_type="{storage_type}",
            bucket_name="{bucket_name}",
            object_path="{model_path}",
            download_path=model_download_path,
            aws_access_key_id="{aws_access_key_id}",
            aws_secret_access_key="{aws_secret_access_key}",
            region_name="{region_name}",
            service_account_path="{local_service_account_path}"
        )"""

            return f"""
# Install necessary libraries
import subprocess
subprocess.run(['pip', 'install', 'boto3', 'google-cloud-storage'], check=True)

import os
import json
from typing import Any
import cog
import boto3
import zipfile
import shutil
from google.cloud import storage

from {module_name} import {function_name}

def unzip_file(zip_path, extract_to):
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)

        extracted_contents = os.listdir(extract_to)
        if len(extracted_contents) == 1:
            single_dir_path = os.path.join(extract_to, extracted_contents[0])
            if os.path.isdir(single_dir_path):
                for item in os.listdir(single_dir_path):
                    shutil.move(os.path.join(single_dir_path, item), extract_to)
                os.rmdir(single_dir_path)

        print(f'Unzipped file {{zip_path}} successfully to {{extract_to}}')
    except Exception as e:
        print(f'Error unzipping file: {{e}}')

def fetch_resource(storage_type, bucket_name, object_path, download_path, aws_access_key_id=None, aws_secret_access_key=None, region_name=None, service_account_path=None):
    if storage_type == 'aws':
        session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name
        )
        s3 = session.client('s3')
        try:
            s3.download_file(bucket_name, object_path, download_path)
            print(f'File {{object_path}} downloaded successfully to {{download_path}}')
        except Exception as e:
            print(f'Error downloading file from AWS S3: {{e}}')

    elif storage_type == 'gcs':
        if service_account_path:
            os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = service_account_path
        client = storage.Client()
        try:
            bucket = client.bucket(bucket_name)
            blob = bucket.blob(object_path)
            blob.download_to_filename(download_path)
            print(f"Downloaded storage object {{object_path}} from bucket {{bucket_name}} to local file {{download_path}}.")
        except Exception as e:
            print(f'Error downloading file from GCS: {{e}}')

class Predictor:

    def predict(self,
{args}
    ) -> Any:
        {dataset_download_code}
        {model_download_code}

        output = {function_name}({params})
        print(f'\\nResult ==== {{output}}')
        result = {{"Prediction": output}}
        return result
"""


        else:
            return f"""
import os
import json
from typing import Any
import cog

from {module_name} import {function_name}

class Predictor:

    def predict(self,
{args}
    ) -> Any:

        output = {function_name}({params})
        print(f'\\nResult ==== {{output}}')
        result = {{"Prediction": output}}
        return result
"""

    def generate_cog_python_script(self):
        """Write the COG predictor script ``cog_inference.py`` into ``self.work_dir``.

        The inference entry point is fixed here (module ``inference``, function
        ``run_inference`` with one ``image_url`` string argument) instead of
        being collected interactively. Download settings are obtained from
        ``self.configure_download()`` and the script text from
        ``self.generate_class_template(...)``. A failure to write the file is
        reported on stdout rather than raised.
        """
        # Hardcoded values :
        module_name = "inference"
        function_name = "run_inference"
        function_args = [{'name': 'image_url', 'type': 'str', 'default': ""}]

        download_settings = self.configure_download()
        script_source = self.generate_class_template(
            function_args,
            module_name,
            function_name,
            download_settings,
        )

        script_path = f"{self.work_dir}/cog_inference.py"
        try:
            with open(script_path, "w") as script_file:
                script_file.write(script_source)
            print(f"Script {script_path} created successfully.")
        except OSError as write_error:
            print(f"Failed to write to {script_path}: {write_error}")

    def create_cog_yaml(self):
        yaml_content = """
build:
    python_version: "3.8"
    gpu: true
    system_packages:
        - "ffmpeg"
        - "libgl1-mesa-glx"
        - "libglib2.0-0"
    python_requirements: "requirements.txt"

predict: "cog_inference.py:Predictor"
        """
        file_path = f"{self.work_dir}/cog.yaml"
        with open(file_path, 'w') as file:
            file.write(yaml_content)
        print(f"'{file_path}' has been created with the specified content.")

    def generate_cog_files(self):
        """Prompt for the work directory, then generate both COG files in it.

        The answer is stored in ``self.work_dir`` before ``cog_inference.py``
        and ``cog.yaml`` are written, in that order.
        """
        print("GENERATE COG INFERENCE FILES ..")
        chosen_dir = self.get_user_input("Enter the work directory path: ")
        self.work_dir = chosen_dir
        for build_step in (self.generate_cog_python_script, self.create_cog_yaml):
            build_step()






In [None]:
# @title Class to get cost report for all VM's and Class to terminate VM's that exceeds cost limit
import firebase_admin
from firebase_admin import credentials, firestore
import time

import sky
from sky.resources import Resources
from sky import status_lib
from sky import core
from sky.utils.cli_utils import status_utils
from sky.utils import controller_utils



class ClusterCostReport:
    """Build and display a cost report for the clusters launched by one user.

    Usage data is read from the Firestore ``cluster-history`` collection and
    the current status from the ``clusters`` collection; both are filtered by
    ``user_id``.
    """

    # Forwarded to the SkyPilot table helpers as their "show all" flag. The
    # previous code passed the builtin ``all`` here, which is always truthy.
    _SHOW_ALL = True

    def __init__(self, user_id, credential_path=None):
        """Connect to Firestore, initialising the Firebase app when needed.

        Args:
            user_id: Identifier used to filter the Firestore documents.
            credential_path: Optional path to a service-account JSON key file.
                Only used when no default Firebase app exists yet; when it is
                omitted, ``firebase_admin.initialize_app()`` is called without
                an explicit credential.
        """
        self.user_id = user_id
        self.credential_path = credential_path

        # SECURITY: a service-account private key used to be embedded here and
        # printed to the cell output. It was never used for authentication
        # (initialisation read the then-undefined ``self.credential_path``),
        # so it has been removed. Treat that key as leaked and revoke it.
        try:
            # get_app() raises ValueError when no default app exists yet.
            firebase_admin.get_app()
            print(f"already initialised")
        except ValueError:
            print("initialise firestore")
            if self.credential_path:
                cred = credentials.Certificate(self.credential_path)
                firebase_admin.initialize_app(cred)
            else:
                firebase_admin.initialize_app()

        self.db = firestore.client()

    @staticmethod
    def _total_duration(usage_intervals):
        """Return the summed length, in seconds, of ``(start, end)`` intervals.

        Intervals without a start are skipped. An interval without an end is
        measured up to the current time; only the last interval may be
        open-ended.
        """
        total = 0
        last_index = len(usage_intervals) - 1
        for i, (start_time, end_time) in enumerate(usage_intervals):
            if start_time is None:
                continue
            if end_time is None:
                assert i == last_index, i
                end_time = int(time.time())
            total += int(end_time) - int(start_time)
        return total

    def _fetch_cluster_statuses(self):
        """Return ``{cluster_hash: status}`` for this user's ``clusters`` docs.

        A stored status of ``None`` or ``'TERMINATED'`` maps to ``None``; any
        other value is looked up by name in ``status_lib.ClusterStatus``.
        """
        statuses = {}
        clusters_query = self.db.collection('clusters').where('user_id', '==', self.user_id)
        for doc in clusters_query.stream():
            # The streamed snapshot already carries the document fields, so no
            # second read per document is required.
            doc_data = doc.to_dict()
            cluster_hash = doc_data.get('cluster_hash')
            if cluster_hash is None:
                continue

            raw_status = doc_data.get('status')
            status = None
            if raw_status is not None and raw_status != 'TERMINATED':
                status = status_lib.ClusterStatus[raw_status]
            statuses[cluster_hash] = status
        return statuses

    def get_cluster_history_datas(self):
        """Return this user's cluster-history records with their status.

        Returns:
            list[dict]: One dict per ``cluster-history`` document with the keys
            ``name``, ``num_nodes``, ``cluster_hash``, ``resources`` (a
            ``Resources`` object rebuilt from the stored fields),
            ``usage_intervals``, ``duration`` (seconds), ``launched_at`` and
            ``status`` (``None`` unless a ``clusters`` document with the same
            ``cluster_hash`` supplies one).
        """
        query = self.db.collection('cluster-history').where('user_id', '==', self.user_id)
        cluster_history_docs = query.stream()

        clouds_data = {
            "Azure": sky.Azure(), "AWS": sky.AWS(), "GCP": sky.GCP(), "IBM": sky.IBM(),
            "Cudo": sky.Cudo(), "Lambda": sky.Lambda(), "SCP": sky.SCP(),
            "Kubernetes": sky.Kubernetes(), "OCI": sky.OCI(),
            "RunPod": sky.RunPod(), "Vsphere": sky.Vsphere(), "Fluidstack": sky.Fluidstack()
        }

        cluster_history_data = []
        for doc in cluster_history_docs:
            doc_data = doc.to_dict()

            # Fall back to the single (usage_start, usage_end) pair when the
            # document has no interval list (key missing or stored as None).
            usage_intervals = doc_data.get('usage_intervals')
            if usage_intervals is None:
                usage_intervals = [(doc_data.get('usage_start'), doc_data.get('usage_end'))]
            total_duration = self._total_duration(usage_intervals)

            accelerator_name = doc_data.get('accelerator_name')
            accelerator_count = doc_data.get('accelerator_count')
            accelerators = None
            if accelerator_name and accelerator_count:
                accelerators = {accelerator_name: accelerator_count}

            resource = Resources(
                cloud=clouds_data.get(doc_data.get('cloud')),
                instance_type=doc_data.get('instance_type'),
                cpus=doc_data.get('cpus'),
                memory=doc_data.get('memory'),
                accelerators=accelerators,
                accelerator_args=doc_data.get('accelerator_args'),
                use_spot=doc_data.get('use_spot'),
                region=doc_data.get('region'),
                zone=doc_data.get('zone'),
                image_id=doc_data.get('image_id'),
                disk_size=doc_data.get('disk_size'),
                disk_tier=doc_data.get('disk_tier'),
                ports=doc_data.get('ports')
            )

            cluster_history_data.append({
                "name": doc_data.get('name'),
                "num_nodes": doc_data.get('num_nodes'),
                "cluster_hash": doc_data.get('cluster_hash'),
                "resources": resource,
                "usage_intervals": usage_intervals,
                "duration": total_duration,
                "launched_at": doc_data.get('usage_start'),
                "status": None
            })

        # Attach the current status to every history record that has a
        # matching cluster document.
        statuses = self._fetch_cluster_statuses()
        for history in cluster_history_data:
            if history['cluster_hash'] in statuses:
                history['status'] = statuses[history['cluster_hash']]

        return cluster_history_data

    def cluster_cost_report(self):
        """Return the history records, newest first, with ``total_cost`` added.

        ``total_cost`` is ``resources.get_cost(duration) * num_nodes``.
        Records whose ``launched_at`` is missing or ``None`` sort last.
        """
        cluster_history_data = self.get_cluster_history_datas()
        cluster_records = sorted(
            cluster_history_data,
            # ``launched_at`` is always present as a key but may hold None.
            key=lambda record: -(record.get('launched_at') or 0),
        )

        for cluster_record in cluster_records:
            duration = cluster_record['duration']
            launched_nodes = cluster_record['num_nodes']
            launched_resources = cluster_record['resources']
            cluster_record['total_cost'] = launched_resources.get_cost(duration) * launched_nodes

        return cluster_records

    def show_cost_report(self, cluster_records):
        """Display the cost tables and return the combined total cost.

        Records whose name resolves to a SkyPilot controller are shown in a
        separate table per controller (first record per controller only); all
        other records share one table.

        Args:
            cluster_records: Records as returned by ``cluster_cost_report``.

        Returns:
            The total reported by ``get_total_cost_of_displayed_records`` for
            the normal clusters plus the ``total_cost`` of each controller
            record shown.
        """
        normal_cluster_records = []
        controllers = dict()
        for cluster_record in cluster_records:
            cluster_name = cluster_record['name']
            controller = controller_utils.Controllers.from_name(cluster_name)
            if controller:
                controller_name = controller.value.name
                if controller_name not in controllers:
                    controllers[controller_name] = cluster_record
            else:
                normal_cluster_records.append(cluster_record)

        total_cost = status_utils.get_total_cost_of_displayed_records(
            normal_cluster_records, self._SHOW_ALL)
        status_utils.show_cost_report_table(normal_cluster_records, self._SHOW_ALL)
        for controller_name, cluster_record in controllers.items():
            status_utils.show_cost_report_table(
                [cluster_record], self._SHOW_ALL,
                controller_name=controller_name.capitalize())
            total_cost += cluster_record['total_cost']

        return total_cost

    def run(self):
        """Compute the cost report and display it."""
        cluster_records = self.cluster_cost_report()
        self.show_cost_report(cluster_records)

class ClusterCostManager:
    """
    Class to manage cluster cost and perform actions based on cost limits.

    Args:
        cost_limit (float): The cost limit in USD.
        email_id (str): The email ID to send notifications.
        user_id (str): Identifier of the user whose clusters are checked.

    """

    # Fraction of the cost limit above which a warning notification is sent.
    WARNING_FRACTION = 0.8

    def __init__(self, cost_limit, email_id, user_id):
        self.cost_limit = cost_limit # Cost limit in USD
        self.email_id = email_id    # Email ID to send notifications
        self.user_id = user_id

    def run(self):
        """
        Fetch the cost report for the user's clusters and act on it.
        """
        cluster_history = self.get_cluster_history()
        self.process_clusters(cluster_history)

    def get_cluster_history(self):
        """
        Get the cluster history.

        Returns:
            list: Cluster records with ``total_cost`` set, as produced by
            ``ClusterCostReport.cluster_cost_report``.
        """
        manager = ClusterCostReport(self.user_id)
        return manager.cluster_cost_report()

    def process_clusters(self, cluster_history):
        """
        Process the clusters based on their cost.

        Only clusters whose status is UP or STOPPED are considered. A cost
        strictly between 80% of the limit and the limit triggers a
        notification; a cost at or above the limit shuts the cluster down.

        Args:
            cluster_history (list): A list of cluster data.
        """
        billable_statuses = (
            status_lib.ClusterStatus["UP"],
            status_lib.ClusterStatus["STOPPED"],
        )
        warning_threshold = self.WARNING_FRACTION * self.cost_limit

        for cluster_data in cluster_history:
            if cluster_data["status"] not in billable_statuses:
                continue

            estimated_cost = cluster_data["total_cost"]
            if estimated_cost >= self.cost_limit:
                print("Estimated cost has exceeded the cost limit.")
                self.shutdown_cluster([cluster_data['name']])
            elif estimated_cost > warning_threshold:
                print("Estimated cost has reached more than 80% of the cost limit but has not exceeded the cost limit.")
                self.send_notification()

    def send_notification(self):
        """
        Send a notification about the cost threshold.
        """
        # Placeholder for notification logic
        print(f"Notification sent to {self.email_id} about cost threshold.")

    @staticmethod
    def _normalize_cluster_names(cluster_name):
        """Return ``cluster_name`` as a list; a bare string becomes one name."""
        if isinstance(cluster_name, str):
            return [cluster_name]
        return list(cluster_name)

    def shutdown_cluster(self, cluster_name):
        """
        Shutdown clusters to prevent further costs.

        For each cluster the job queue is saved to Firestore first (skipped
        when the cluster is not UP), then the clusters are torn down, marked
        "TERMINATED" in Firestore, and the cluster history is synchronised.

        Args:
            cluster_name (str | list[str]): A cluster name or a list of
                cluster names to shut down.
        """
        # Imported here because this cell does not import them at the top.
        from sky import cli
        from sky import exceptions as sky_exceptions

        # A bare string would otherwise be iterated character by character.
        cluster_names = self._normalize_cluster_names(cluster_name)

        for cluster in cluster_names:
            try:
                # fetch job details
                jobs = core.queue(cluster, skip_finished=False, all_users=False)
            except sky_exceptions.ClusterNotUpError as err:
                # STOPPED clusters reach this method too; a queue that cannot
                # be read must not block the cost-limit teardown.
                print(f"Skipping job snapshot for cluster {cluster}: {err}")
                continue
            # update job details to firestore collection
            update_job_details(jobs, self.user_id)

        cli._down_or_stop_clusters(cluster_names,
                               apply_to_all=None,
                               down=True,
                               no_confirm=True,
                               purge=False)

        print(f"Cluster {cluster_names} has been shut down to prevent further costs.")
        for cluster in cluster_names:
            # update terminate cluster status
            update_clusters_status(self.user_id, "TERMINATED", cluster)

        print("copiying data from local db to firestore db..")
        # Call function to update cluster-history collection
        extract_cluster_history_data(self.user_id)




In [None]:
# @title Python Class to get the status of jobs

from sky import core
from sky.skylet import job_lib


class SkyJobManager:
    """Look up and format the SkyPilot job queue of a single cluster."""

    def __init__(self, cluster_name: str, skip_finished: bool = False, all_users: bool = False):
        """
        Remember which cluster to query and how to filter its jobs.

        Args:
            cluster_name (str): The name of the cluster to get jobs from.
            skip_finished (bool): Forwarded to ``core.queue`` as
                ``skip_finished``. Defaults to False.
            all_users (bool): Forwarded to ``core.queue`` as ``all_users``.
                Defaults to False.
        """
        self.cluster_name, self.skip_finished, self.all_users = (
            cluster_name,
            skip_finished,
            all_users,
        )

    def fetch_jobs(self):
        """
        Query the cluster's job queue and format it as a table.

        Returns:
            The result of ``job_lib.format_job_queue`` applied to the jobs
            returned by ``core.queue`` for this cluster.
        """
        queued_jobs = core.queue(
            self.cluster_name,
            skip_finished=self.skip_finished,
            all_users=self.all_users,
        )
        return job_lib.format_job_queue(queued_jobs)


In [None]:
#@title Code for calling COG HTTP API in the VM instance

import requests
import json
from sky import core
from sky.skylet import job_lib
from sky.skylet.job_lib import JobStatus


def trigger_cog_inference_api(instance_ip):
    """Send one prediction request to the COG HTTP server on ``instance_ip``.

    Prompts on stdin for an image URL, POSTs it as ``{"input": {"image_url":
    ...}}`` to port 5000 ``/predictions``, then prints the HTTP status code and
    a JSON dump of the selected response fields that are present. Any
    exception is caught and printed instead of raised.
    """
    reported_keys = ["output", "started_at", "completed_at","error","status","metrics"]
    try:
        endpoint = f'http://{instance_ip}:5000/predictions'
        request_headers = {"Content-Type": "application/json"}

        #TODO: store the input params chosen at cog file creation and build the payload from them
        image_url = input("Enter the image url for prediction:")
        payload = {"input": {"image_url": image_url}}

        response = requests.post(
            endpoint,
            headers=request_headers,
            data=json.dumps(payload),
            timeout=3600,
        )
        print(response.status_code)

        body = response.json()
        summary = {}
        for key in reported_keys:
            if key in body:
                summary[key] = body[key]
        print(json.dumps(summary, indent=4))
    except Exception as err:
        print(F"Error : {err}")


def run_inference_api(instance_ip,job_id,cluster_name):
    """Trigger the COG inference API only when job ``job_id`` has succeeded.

    The job is looked up in the queue of ``cluster_name``. A message is
    printed when the job is missing or its status is not
    ``JobStatus.SUCCEEDED``; otherwise ``trigger_cog_inference_api`` is called
    with ``instance_ip``.
    """
    queued_jobs = core.queue(cluster_name, skip_finished=False, all_users=False)

    # Pick the first queue entry with the requested ID, if any.
    match = None
    for entry in queued_jobs:
        if entry['job_id'] == job_id:
            match = entry
            break

    if not match:
        print(f"Job with ID {job_id} not found.")
        return

    if match['status'] == JobStatus.SUCCEEDED:
        print(f"Job {match['job_id']} succeeded. Triggering Inference API.")
        trigger_cog_inference_api(instance_ip)
        return

    print(f"Job {match['job_id']} is not completed. Status: {match['status']}. Cannot trigger Inference API.")





In [None]:
# @title Run selected task

def launch_any_task(run_command, setup_cmd, workdir, cluster_name, sky_cloud, accelerator=None, quantity=None, instance_type=None):
    """Launch a task on a cluster through ``ClusterTaskRunner``.

    A fresh user ID is generated and stored in the ``Config`` singleton under
    ``user_id``; after the launch, ``job_id`` and ``instance_ip`` (the
    handle's ``head_ip``) are stored there as well.

    Args:
        run_command: Command passed to the runner as ``run``.
        setup_cmd: Command passed to the runner as ``setup``.
        workdir: Working directory passed to the runner.
        cluster_name: Name of the cluster to launch on.
        sky_cloud: Callable returning the cloud object (called with no args).
        accelerator: Accelerator name, or None to request no accelerator.
        quantity: Accelerator count paired with ``accelerator``.
        instance_type: Instance type passed to the runner.

    Returns:
        The generated user ID.
    """
    from sky import optimizer

    user_id = generate_user_id()
    print(f"User ID: {user_id}")

    config = Config()  # Access the singleton instance
    config.set('user_id', user_id)

    # With the default arguments the old code passed ``{None: None}``; pass
    # None instead when no accelerator was requested.
    accelerators = {accelerator: quantity} if accelerator else None

    trainer = ClusterTaskRunner(
        run=run_command,
        setup=setup_cmd,
        workdir=workdir,
        cloud=sky_cloud(),  # Update this as needed
        accelerators=accelerators,
        retry_until_up=True,
        instance_type=instance_type,
        cluster_name=cluster_name,
        optimize_target=optimizer.OptimizeTarget.COST,
        detach_setup=True,
        detach_run=True,
        ports=5000,
        user_id=user_id
    )
    job_id, handle = trainer.launch_task()
    print(f"Job ID: {job_id}")
    config.set('job_id', job_id)
    config.set('instance_ip', handle.head_ip)
    return user_id

def run_task(task_run_config):
    """Launch the task selected in the ``Config`` singleton on a VM.

    Reads ``task_type``, ``gpu``, ``quantity``, ``instance_type``,
    ``cloud_class`` and (for custom tasks) ``mode`` from ``Config``. For
    ``mode == 'Training'`` the cluster builds a COG image and runs
    ``cog predict`` with the training parameters; for ``mode == 'Inference'``
    it builds the image and starts it as a container publishing port 5000.

    Args:
        task_run_config (dict): Must contain ``workdir`` and ``vm_name``;
            ``train_params`` (dict of ``-i key=value`` inputs) is used in
            Training mode and defaults to no parameters.
    """
    import shlex

    config = Config()  # Access the singleton instance
    task_type = config.get('task_type')
    selected_accelerator_name = config.get('gpu')
    raw_quantity = config.get('quantity')
    # ``quantity`` may be unset (e.g. no GPU chosen); int(None) would raise.
    quantity = int(raw_quantity) if raw_quantity is not None else None
    selected_instance_type = config.get('instance_type')
    sky_cloud = config.get('cloud_class')

    # Shared prefix of both setup commands: download the cog binary and make
    # it executable.
    cog_install_cmd = 'set -e && sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)" && sudo chmod +x /usr/local/bin/cog'

    if task_type == "Ready-to-Use Models":
        print(f"Run Ready-to-Use Models ... ")
        return

    if task_type != "Custom Training or Inference":
        print(f"Unsupported task type: {task_type!r}")
        return

    mode = config.get('mode')
    if mode == 'Training':
        print("TRAINING ....")

        setup_cmd = f'{cog_install_cmd} && sudo cog build -t my-train'
        run_command = 'sudo cog predict my-train'

        params = task_run_config.get("train_params", {})
        # Quote each ``key=value`` so values with spaces or shell
        # metacharacters reach cog as a single argument.
        param_string = ' '.join(
            f'-i {shlex.quote(f"{key}={value}")}' for key, value in params.items()
        )
        run_command_with_params = f'{run_command} {param_string}'

        launch_any_task(run_command_with_params, setup_cmd, task_run_config['workdir'], task_run_config['vm_name'], sky_cloud, selected_accelerator_name, quantity, selected_instance_type)
        print("RUN TRAINING TASK IN VM")
    elif mode == 'Inference':
        print("INFERENCE ....")

        setup_cmd = f'{cog_install_cmd} && sudo cog build -t pred'
        run_command = 'sudo docker run -d -p 5000:5000 pred'
        launch_any_task(run_command, setup_cmd, task_run_config['workdir'], task_run_config['vm_name'], sky_cloud, selected_accelerator_name, quantity, selected_instance_type)
        print("RUN INFERENCE TASK IN VM")
    else:
        print(f"Unsupported mode: {mode!r}")




In [None]:
################################################################ FUNCTION CALLS #######################################################################


In [None]:
# Mount Google Drive at /content/drive (Colab only).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# @title Run this cell to generate necessary files for running COG
# Run this cell to create the COG-related files.
# The user will need to input the module name and function name to be imported from your training dir
"""
Example input params: add this these values according to user input

INFERENCE :
Enter the work directory path:    /content/drive/MyDrive/Colab Notebooks/Traffic_Inference
Enter storage bucket name:   demai
Enter dataset path in storage bucket (e.g., path/to/dataset.zip):   traffic_dataset.zip
Enter model path in storage bucket (e.g., path/to/model.h5):       traffic_model/classifier_model.h5

Enter "done" to complete the process
"""
# Pick the file generator that matches the mode stored in the shared Config.
# Nothing is generated unless the task type is "Custom Training or Inference".
config = Config()  # Access the singleton instance
task_type = config.get('task_type')
if task_type == "Custom Training or Inference":
    mode = config.get('mode')
    if mode == 'Training':
        instance = CogTraining_FileGenerator()
        instance.generate_cog_files()
    elif mode == 'Inference':
        instance = CogInference_FileGenerator()
        instance.generate_cog_files()

In [None]:
# @title Run task in vm

# Settings consumed by run_task():
#   train_params - passed to `cog predict` as `-i key=value` (Training mode only)
#   workdir      - working directory handed to the task launcher
#   vm_name      - used as the cluster name
task_run_config = {
    "train_params" : {"epochs": 5,
                      "batch_size": 10
                      },
    "workdir": "/content/drive/MyDrive/Colab Notebooks/Traffic_Inference",
    "vm_name": "gcp-instance-1"
}

run_task(task_run_config)

I 09-10 17:28:56 cloud_vm_ray_backend.py:3319] Job submitted with Job ID: [1m1[0m
I 09-10 17:28:56 cloud_vm_ray_backend.py:3354] [36mJob ID: [1m1[0m
I 09-10 17:28:56 cloud_vm_ray_backend.py:3354] To cancel the job:	[1msky cancel gcp-instance-1 1[0m
I 09-10 17:28:56 cloud_vm_ray_backend.py:3354] To stream job logs:	[1msky logs gcp-instance-1 1[0m
I 09-10 17:28:56 cloud_vm_ray_backend.py:3354] To view the job queue:	[1msky queue gcp-instance-1[0m
I 09-10 17:28:56 cloud_vm_ray_backend.py:3450] 
I 09-10 17:28:56 cloud_vm_ray_backend.py:3450] [36mCluster name: [1mgcp-instance-1[0m
I 09-10 17:28:56 cloud_vm_ray_backend.py:3450] To log into the head VM:	[1mssh gcp-instance-1[0m
I 09-10 17:28:56 cloud_vm_ray_backend.py:3450] To submit a job:		[1msky exec gcp-instance-1 yaml_file[0m
I 09-10 17:28:56 cloud_vm_ray_backend.py:3450] To stop the cluster:	[1msky stop gcp-instance-1[0m
I 09-10 17:28:56 cloud_vm_ray_backend.py:3450] To teardown the cluster:	[1msky down gcp-instance

  return query.where(field_path, op_string, value)
  existing_doc = db.collection(u'cluster-history').where(u'user_id', u'==', user_id).where(u'cluster_hash', u'==', data['cluster_hash']).get()
  existing_doc = db.collection(u'clusters').where(u'user_id', u'==', user_id).where(u'cluster_hash', u'==', data['cluster_hash']).get()


Job ID: 1
RUN INFERENCE TASK IN VM


  existing_doc = db.collection(u'jobs').where(u'user_id', u'==', user_id).where(u'job_id', u'==', job_id).where(u'job_name', u'==', job_name).get()


In [None]:
config = Config()

# Read back the user ID stored by launch_any_task() and reuse the VM name as
# the cluster name for the cells below.
user_id = config.get('user_id')
cluster_name = task_run_config['vm_name']


In [None]:
# @title Run only if you are doing inference
## NOTE: the training arguments like epoch, etc. are currently hardcoded for testing purposes
# job_id and instance_ip were stored in Config by launch_any_task().
config = Config()
job_id = config.get('job_id')
instance_ip = config.get('instance_ip')

run_inference_api(instance_ip,job_id,cluster_name=cluster_name)


Job 1 succeeded. Triggering Inference API.
Enter the image url for prediction:https://storage.cloud.google.com/demai/traffic_inference_data/00032.png?authuser=2
200
{
    "output": {
        "Prediction": "Stop"
    },
    "started_at": "2024-09-10T17:35:06.693783+00:00",
    "completed_at": "2024-09-10T17:35:18.652973+00:00",
    "error": null,
    "status": "succeeded",
    "metrics": {
        "predict_time": 11.95919
    }
}


In [None]:
# Show the estimated cost report of the clusters launched by the user

# NOTE(review): self-assignment is a no-op; user_id comes from an earlier cell.
user_id = user_id
manager = ClusterCostReport(user_id)
manager.run()

already initialised
[36m[1mClusters[0m
NAME            LAUNCHED     DURATION  RESOURCES                                   STATUS      COST/hr  COST (est.)  
gcp-instance-1  10 mins ago  10m 20s   1x GCP(g2-standard-4, cpus=4.0, {'L4': 1})  [32mUP[0m          $ 0.70   $ 0.12       
gcp-instance-1  42 mins ago  14m 14s   1x GCP(g2-standard-4, cpus=4.0, {'L4': 1})  [2mTERMINATED[0m  $ 0.70   $ 0.17       


In [None]:
# Terminate a VM if its estimated cost reaches the cost_limit given by the user
# (a notification is sent once the cost exceeds 80% of the limit).
cost_manager = ClusterCostManager(
                            cost_limit=0.14,
                            email_id='user@example.com',
                            user_id=user_id,
                            )

cost_manager.run()


In [None]:
# Fetch the job queue of the cluster and print it as a formatted table
manager = SkyJobManager(cluster_name=cluster_name, all_users=False)
job_table = manager.fetch_jobs()
print(job_table)

ID  NAME  SUBMITTED   STARTED     DURATION  RESOURCES  STATUS     LOG                                        
1   -     6 mins ago  3 mins ago  1s        1x[L4:1]   [32mSUCCEEDED[0m  ~/sky_logs/sky-2024-09-10-17-25-11-843496  


In [None]:
## Run ClusterManager to terminate the cluster
user_id = user_id   # user_id that was created by the user initially (no-op self-assignment)
clusters = [cluster_name]  # names of the clusters to terminate
manager = ClusterManager(user_id, clusters)
manager.down_clusters()

Cluster terminated successfully
Updating terminated status in firestore db
Document(s) found: 1


  query = db.collection(u'clusters').where(u'user_id', u'==', user_id).where(u'name', u'==', cluster_name)
  existing_docs = [doc for status in statuses for doc in query.where(u'status', u'==', status).get()]


Updated document Egjg2YJ1GvTJkTsOXuQH with new status: TERMINATED
copiying data from local db to firestore db..


  existing_doc = db.collection(u'cluster-history').where(u'user_id', u'==', user_id).where(u'cluster_hash', u'==', data['cluster_hash']).get()
