In [3]:
import os
import sys

# Resolve the kit folder (one level above the working directory) and the
# repository root (one level above the kit), then make both importable.
current_dir = os.getcwd()
kit_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
repo_dir = os.path.abspath(os.path.join(kit_dir, os.pardir))

sys.path.extend([kit_dir, repo_dir])

import io
import json
import logging
import shutil
import subprocess
import tarfile
import time

import pandas as pd
import requests
import yaml
from dotenv import load_dotenv
from pandas import DataFrame


In [4]:
# Read environment variables from the .env file at the repository root.
env_file = os.path.join(repo_dir, '.env')
load_dotenv(env_file)

# Raise the root logger to INFO so the logging.info() calls below are shown.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [30]:
# List the apps reported by the snapi CLI; the ID of the app to use is copied from this output.
!snapi app list 

ASR With Diarization
Name                : ASR With Diarization
ID                  : b6aefdf7-02a4-4384-9c3c-8a81d735a54e
Playground          : False
Prediction Input    : text

ASR Without Diarization
Name                : ASR Without Diarization
ID                  : a36cc322-dd36-40e3-9641-d87ac48fe2c4
Playground          : False
Prediction Input    : file

CLIP
Name                : CLIP
ID                  : 6c14325a-1be7-4e48-b38f-19b33745fc3b
Playground          : False
Prediction Input    : text

Databox
Name                : Databox
ID                  : 199e9684-785c-4df0-8dc3-49e808d8eba5
Playground          : False
Prediction Input    : text

Deepseek 6.7B single socket
Name                : Deepseek 6.7B single socket
ID                  : 2eeb4b7f-bc56-48c4-8814-ef9d1e8806b8
Playground          : True
Prediction Input    : text

DePlot
Name                : DePlot
ID                  : 40f16b58-72a9-404f-a7c3-afc0d27a2343
Playground          : False
Prediction Input    :

In [31]:
# App ID of "ASR With Diarization", copied from the `snapi app list` output above.
# NOTE(review): BatchASRProcessor reads its app id from config.yaml
# (asr.apps.asr_with_diarization_app_id), not from this variable; confirm the two match.

app_id = 'b6aefdf7-02a4-4384-9c3c-8a81d735a54e'

In [32]:
# `snapi dataset add` is pointed at a JSON file that holds the path of the
# dataset folder to upload; print the command's options for reference.

!snapi dataset add --help

[1m                                                                                [0m
[1m [0m[1;33mUsage: [0m[1msnapi dataset add [OPTIONS][0m[1m                                            [0m[1m [0m
[1m                                                                                [0m
 Add a new dataset                                                              
                                                                                
[2m╭─[0m[2m Options [0m[2m───────────────────────────────────────────────────────────────────[0m[2m─╮[0m
[2m│[0m    [1;36m-[0m[1;36m-file[0m                          [1;33mTEXT[0m                                      [2m│[0m
[2m│[0m [31m*[0m  [1;36m-[0m[1;36m-dataset[0m[1;36m-name[0m       [1;32m-n[0m         [1;33mTEXT[0m  Dataset name [2m[default: None][0m        [2m│[0m
[2m│[0m                                          [2;31m[required]  [0m                        [2m│[0m
[2m│[0m [31m*

# ASR Pipeline

In [31]:
def load_config(file_path):
    """Parse the YAML file at ``file_path`` and return the loaded object."""
    with open(file_path, 'r') as stream:
        return yaml.safe_load(stream)

# Parsed contents of the kit-level config.yaml; passed to BatchASRProcessor below.
config = load_config(os.path.join(kit_dir,'config.yaml'))

# Job status strings compared against the `status` field returned by the jobs endpoint.
# PENDING_RDU is the status assumed before the first poll; EXIT_WITH_0 ends the polling loop.
PENDING_RDU_JOB_STATUS = 'PENDING_RDU'
SUCCESS_JOB_STATUS = 'EXIT_WITH_0'

In [32]:
class BatchASRProcessor():
    """Drive a batch ASR (with diarization) run against the SambaStudio REST API.

    The typical call order, as used by the cells below, is:
    ``create_load_project`` -> ``create_dataset`` -> ``check_dataset_creation_progress``
    -> ``run_job`` -> ``check_job_progress`` -> ``retrieve_results`` -> ``delete_job``
    / ``delete_dataset``.

    All settings are read from the ``asr`` section of the config mapping passed to
    ``__init__``; the API key is read from the ``SAMBASTUDIO_KEY`` environment variable.
    """

    # Divisor that converts the `sample_duration` column of the results CSV to seconds.
    # NOTE(review): presumably the audio sample rate in Hz - confirm against the ASR app output spec.
    SAMPLES_PER_SECOND = 16000

    # Column names applied to the (header-less) results CSV.
    _RESULT_COLUMNS = [
        'audio_path', 'results_path', 'speaker', 'start_time',
        'sample_duration', 'unformatted_transcript', 'formatted_transcript',
    ]

    def __init__(self, config) -> None:
        """Copy every setting the processor needs out of ``config['asr']``.

        Args:
            config: nested mapping with an ``asr`` key containing the ``datasets``,
                ``apps``, ``urls``, ``projects``, ``jobs`` and ``output`` sections.

        Raises:
            KeyError: if any expected config entry is missing.
        """
        api_key = os.getenv('SAMBASTUDIO_KEY')
        if not api_key:
            # Requests would otherwise be sent unauthenticated and fail later with a
            # much less obvious error.
            logging.warning('SAMBASTUDIO_KEY is not set; API calls are likely to be rejected')

        self.headers = {
            'content-type': 'application/json',
            'key': api_key,
        }

        datasets_cfg = config['asr']['datasets']
        # The configured paths are prefixed with '.' before use.
        self.datasets_path = f".{datasets_cfg['datasets_path']}"
        self.dataset_id = None
        self.dataset_name = datasets_cfg['dataset_name']
        self.dataset_description = datasets_cfg['dataset_description']
        self.dataset_source_type = datasets_cfg['dataset_source_type']
        self.dataset_source_file = f".{datasets_cfg['dataset_source_file']}"
        self.dataset_language = datasets_cfg['dataset_language']

        apps_cfg = config['asr']['apps']
        self.asr_with_diarization_app_id = apps_cfg['asr_with_diarization_app_id']
        self.application_field = apps_cfg['application_field']

        urls_cfg = config['asr']['urls']
        self.base_url = urls_cfg['base_url']
        self.datasets_url = urls_cfg['datasets_url']
        self.projects_url = urls_cfg['projects_url']
        self.jobs_url = urls_cfg['jobs_url']
        self.download_results_url = urls_cfg['download_results_url']

        projects_cfg = config['asr']['projects']
        self.project_name = projects_cfg['project_name']
        self.project_description = projects_cfg['project_description']
        self.project_id = None

        jobs_cfg = config['asr']['jobs']
        self.job_name = jobs_cfg['job_name']
        self.job_task = jobs_cfg['job_task']
        self.job_type = jobs_cfg['job_type']
        self.job_description = jobs_cfg['job_description']
        self.model_checkpoint = jobs_cfg['model_checkpoint']

        self.output_path = config['asr']['output']['output_path']

    def _jobs_endpoint(self):
        """Return the jobs URL for the currently loaded project."""
        return self.base_url + self.projects_url + self.jobs_url.format(project_id=self.project_id)

    def _get_call(self, url, params = None, success_message = None):
        """Send a GET request and log the outcome.

        A non-200 response is logged but NOT raised, because ``create_load_project``
        inspects the body of a "not found" response. Callers must check the response.

        Returns:
            The ``requests`` response object.
        """
        response = requests.get(url, params=params, headers=self.headers)

        if response.status_code == 200:
            logging.info('GET request successful!')
            # Skip the message when none was supplied instead of logging "None".
            if success_message is not None:
                logging.info(success_message)
            logging.debug(f'Response: {response.text}')
        else:
            logging.error(f'GET request failed with status code: {response.status_code}')
            logging.error(f'Error message: {response.text}')
        return response

    def _post_call(self, url, params, success_message = None):
        """Send ``params`` as a JSON POST body.

        Raises:
            Exception: if the response status code is not 200.
        """
        response = requests.post(url, json=params, headers=self.headers)

        if response.status_code == 200:
            logging.info('POST request successful!')
            if success_message is not None:
                logging.info(success_message)
            logging.debug(f'Response: {response.text}')
        else:
            logging.error(f'POST request failed with status code: {response.status_code}')
            raise Exception(f'Error message: {response.text}')
        return response

    def _delete_call(self, url, success_message = None):
        """Send a DELETE request.

        Args:
            url: resource URL to delete.
            success_message: text logged on success; a generic message is used when
                omitted, so deleting a job is no longer reported as a dataset deletion.

        Raises:
            Exception: if the response status code is not 200.
        """
        response = requests.delete(url, headers=self.headers)
        if response.status_code == 200:
            logging.info(success_message if success_message is not None else 'DELETE request successful!')
            logging.debug(f'Response: {response.text}')
        else:
            logging.error(f'Failed to delete the resource. Status code: {response.status_code}')
            raise Exception(f'Error message: {response.text}')
        return response

    def _time_to_seconds(self, time_str):
        """Convert a colon-separated timestamp to whole seconds.

        Accepts ``SS``, ``MM:SS`` and ``HH:MM:SS`` (each part an integer).

        Raises:
            ValueError: if any part is not an integer.
        """
        total_seconds = 0
        for part in str(time_str).strip().split(':'):
            total_seconds = total_seconds * 60 + int(part)
        return total_seconds

    def _get_df_output(self, response_content: bytes) -> DataFrame:
        """Extract the results CSV from a gzipped tar archive and tidy it.

        Args:
            response_content: raw bytes of the ``.tar.gz`` archive returned by the
                results download endpoint.

        Returns:
            DataFrame with columns ``start_time`` and ``end_time`` (seconds),
            ``speaker`` and ``text`` (the formatted transcript).

        Raises:
            KeyError: if ``self.output_path`` is not a member of the archive.
            Exception: if that member is not a regular file.
        """
        with tarfile.open(fileobj=io.BytesIO(response_content), mode="r:gz") as tar:
            output_file = tar.extractfile(tar.getmember(self.output_path))
            if output_file is None:
                # extractfile() returns None for members that are not regular files or links.
                raise Exception(f'{self.output_path} is not a regular file in the results archive')
            output_df = pd.read_csv(io.BytesIO(output_file.read()), names=self._RESULT_COLUMNS)

        start_seconds = output_df['start_time'].map(self._time_to_seconds)
        duration_seconds = output_df['sample_duration'].astype(int) / self.SAMPLES_PER_SECOND
        output_df['start_time'] = start_seconds
        output_df['end_time'] = start_seconds + duration_seconds

        return output_df[['start_time', 'end_time', 'speaker', 'formatted_transcript']].rename(
            columns={'formatted_transcript': 'text'}
        )

    def search_dataset(self, dataset_name):
        """Look up a dataset by name and return its ``dataset_id``."""
        url = self.base_url + self.datasets_url + '/search'
        params = {
            'dataset_name': dataset_name
        }
        response = self._get_call(url, params, f'Dataset {dataset_name} found in SambaStudio')
        parsed_response = response.json()
        return parsed_response['data']['dataset_id']

    def delete_dataset(self, dataset_name):
        """Resolve ``dataset_name`` to its id and delete that dataset."""
        dataset_id = self.search_dataset(dataset_name)
        url = self.base_url + self.datasets_url + '/' + dataset_id
        response = self._delete_call(url, success_message=f'Dataset {dataset_name} deleted successfully.')
        logging.info(response.text)

    def create_dataset(self, path):
        """Stage one audio file locally and register it as a dataset via ``snapi``.

        The file is copied to ``<datasets_path>/<dataset_name>_<unix time>/pca_file.<ext>``,
        the source JSON file is pointed at that folder, and ``snapi dataset add`` is run
        with the confirmation prompt answered automatically.

        Args:
            path: path to an ``.mp3`` or ``.wav`` file (extension is case-insensitive).

        Returns:
            The generated, timestamp-suffixed dataset name.

        Raises:
            Exception: if the file is not mp3/wav, or if ``snapi`` exits with a
                non-zero status.
        """
        # Validate before touching the filesystem so a bad input leaves nothing behind.
        audio_format = os.path.splitext(path)[1].lower().lstrip('.')
        if audio_format not in ('mp3', 'wav'):
            raise Exception('Only mp3 and wav audio files supported')

        dataset_name = f'{self.dataset_name}_{int(time.time())}'
        pca_directory = self.datasets_path + '/' + dataset_name

        os.makedirs(self.datasets_path, exist_ok=True)

        if not os.path.isdir(pca_directory):
            logging.info(f'Datasets path: {pca_directory} wasn\'t found')
            os.mkdir(pca_directory)
            logging.info(f'PCA Directory: {pca_directory} created')

        # Always (re)write the source file so it points at the folder being uploaded.
        with open(self.dataset_source_file, 'w') as json_file:
            json.dump({"source_path": pca_directory}, json_file)

        shutil.copyfile(path, f'{pca_directory}/pca_file.{audio_format}')

        # Argument list (no shell) so config values containing spaces or quotes
        # cannot break or inject into the command line.
        command = [
            'snapi', 'dataset', 'add',
            '--dataset-name', dataset_name,
            '--job_type', str(self.job_type),
            '--apps', str(self.asr_with_diarization_app_id),
            '--source_type', str(self.dataset_source_type),
            '--source_file', str(self.dataset_source_file),
            '--application_field', str(self.application_field),
            '--language', str(self.dataset_language),
            '--description', str(self.dataset_description),
        ]

        logging.info(f'Creating dataset: {dataset_name}')
        # "yes" answers the CLI's "Are you sure you want to proceed?" prompt.
        completed = subprocess.run(command, input='yes\n', text=True)
        if completed.returncode != 0:
            raise Exception(f'snapi dataset add failed with exit code {completed.returncode}')

        return dataset_name

    def check_dataset_creation_progress(self, dataset_name):
        """Return True when the dataset's reported status is ``"Available"``."""
        url = self.base_url + self.datasets_url + '/' + dataset_name
        response = self._get_call(url)
        return response.json()["data"]["status"] == "Available"

    def create_load_project(self):
        """Fetch the configured project, creating it first if it does not exist.

        Returns:
            The project id, which is also stored on ``self.project_id``.
        """
        url = self.base_url + self.projects_url + '/' + self.project_name

        response = self._get_call(url, success_message=f'Project {self.project_name} found in SambaStudio')
        not_found_error_message = f"{self.project_name} not found"

        if not_found_error_message in response.text:
            logging.info(f'Project {self.project_name} wasn\'t found in SambaStudio')

            params = {
                'project_name': self.project_name,
                'description': self.project_description
            }
            response = self._post_call(
                self.base_url + self.projects_url,
                params,
                success_message=f'Project {self.project_name} created!',
            )

        self.project_id = response.json()['data']['project_id']
        return self.project_id

    def run_job(self, dataset_name):
        """Submit a job for ``dataset_name`` in the loaded project and return its id.

        ``create_load_project`` must have been called first so ``self.project_id`` is set.
        """
        params = {
            'task': self.job_task,
            'job_type': self.job_type,
            'job_name': f'{self.job_name}_{int(time.time())}',
            'project': self.project_id,
            'model_checkpoint': self.model_checkpoint,
            'description': self.job_description,
            'dataset': dataset_name,
        }

        response = self._post_call(self._jobs_endpoint(), params, success_message='Job running')
        return response.json()['data']['job_id']

    def check_job_progress(self, job_id, poll_interval = 10, timeout = None):
        """Block until the job reports the success status.

        Args:
            job_id: id returned by ``run_job``.
            poll_interval: seconds to sleep between status requests.
            timeout: optional maximum number of seconds to wait; ``None`` waits
                indefinitely for a terminal status.

        Returns:
            True once the status equals ``SUCCESS_JOB_STATUS``.

        Raises:
            Exception: if the job reports another ``EXIT_WITH_*`` status.
            TimeoutError: if ``timeout`` elapses first.
        """
        url = self._jobs_endpoint() + '/' + job_id
        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            response = self._get_call(url, success_message='Still waiting for job to finish')
            status = response.json()['data']['status']
            logging.info(f'Job status: {status}')

            if status == SUCCESS_JOB_STATUS:
                logging.info('Job finished!')
                return True

            # NOTE(review): assumes any other EXIT_WITH_<code> status is a terminal
            # failure - confirm against the SambaStudio job status list.
            if isinstance(status, str) and status.startswith('EXIT_WITH_'):
                raise Exception(f'Job {job_id} ended with status: {status}')

            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(f'Job {job_id} did not finish within {timeout} seconds (last status: {status})')

            time.sleep(poll_interval)

    def delete_job(self, job_id):
        """Delete the given job from the loaded project."""
        url = self._jobs_endpoint() + '/' + job_id
        response = self._delete_call(url, success_message=f'Job {job_id} deleted successfully.')
        logging.info(response.text)

    def retrieve_results(self, job_id):
        """Download the job's results archive and return it as a tidy DataFrame.

        Raises:
            Exception: if the download request does not return status 200.
        """
        url = self._jobs_endpoint() + '/' + job_id + self.download_results_url
        response = self._get_call(url, success_message='Results downloaded!')
        if response.status_code != 200:
            # Fail here rather than handing an error body to tarfile.
            raise Exception(f'Error message: {response.text}')
        return self._get_df_output(response.content)

In [33]:
# Build the processor from the parsed config.yaml contents.
asr = BatchASRProcessor(config)

In [34]:
# Look up the configured project (creating it if it is missing) and store its id on `asr`;
# the job cells below build their URLs from that id.
asr.create_load_project()

INFO:root:GET request successful!
INFO:root:Project PCA_Project found in SambaStudio


'2d49f0cd-d807-488b-a527-c2dd43377dd9'

In [35]:
# Stage the sample call recording and register it as a new, timestamp-suffixed dataset.
dataset_name = asr.create_dataset(path=os.path.join(kit_dir,'data/conversations/audio/911_call.wav'))
# Poll once per second until the dataset status is reported as "Available".
while not asr.check_dataset_creation_progress(dataset_name):
    print("waiting for dataset available")
    time.sleep(1)

INFO:root:Datasets path: ../data/datasets/PCA_dataset_1709326005 wan't found
INFO:root:PCA Directory: ../data/datasets/PCA_dataset_1709326005 created



Folder Information:
  - Number of Files: 1
  - Total Size: 7.12 MB

Are you sure you want to proceed? ([33myes[0m/no)
: Uploading files


INFO:root:Creating dataset: PCA_dataset_1709326005


Dataset folder upload complete: ../data/datasets/PCA_dataset_1709326005
Dataset added successfully.
Time taken to upload the dataset: 3.8197240829467773 seconds


INFO:root:GET request successful!
INFO:root:None


In [36]:
# Confirm the dataset can be found by name; the cell output is its dataset id.
asr.search_dataset(dataset_name)

INFO:root:GET request successful!
INFO:root:Dataset PCA_dataset_1709326005 found in SambaStudio


'b125ec06-4c05-4960-901d-a811a6e01954'

In [37]:
# Submit the job for the uploaded dataset and keep its id for the cells below.
job_id = asr.run_job(dataset_name)

INFO:root:POST request successful!
INFO:root:Job running


In [38]:
# Block here, polling the job status, until it reports EXIT_WITH_0.
result = asr.check_job_progress(job_id) 

INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: PENDING_RDU
INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: PENDING_RDU
INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: PENDING_RDU
INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: PREDICTING
INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: PREDICTING
INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: PREDICTING
INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: PREDICTING
INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: PREDICTING
INFO:root:GET request successful!
INFO:root:Still waiting for job to finish
INFO:root:Job status: PREDICTING
INFO:root:GET re

In [39]:
# Download the results archive and display the transcript table
# (start_time / end_time in seconds, speaker label, formatted text).
df = asr.retrieve_results(job_id)
df

INFO:root:GET request successful!
INFO:root:Results downloaded!


Unnamed: 0,start_time,end_time,speaker,text
0,0,2.5,SPEAKER_01,Our Primeti 33. What is yet as emergency?
1,3,11.9,SPEAKER_00,"Yes, sir, I need to, uh, uh. I need an ambulan..."
2,11,12.3,SPEAKER_01,A car.
3,13,14.7,SPEAKER_00,Carol Wood Drive. Yes.
4,14,14.6,SPEAKER_01,
5,15,16.6,SPEAKER_00,"Yeah, a."
6,17,21.4,SPEAKER_01,"Okay, sir. What's the phone number you calling..."
7,22,31.6,SPEAKER_00,"Uh, sir. Oh, I have a we have a a gentleman he..."
8,32,37.2,SPEAKER_01,"Okay, how does. He's a 50 years old. Ser 50. O..."
9,38,39.5,SPEAKER_00,"Yes, he's not breathing, sir."


In [40]:
# Clean up: delete the finished job from the project.
asr.delete_job(job_id)

INFO:root:Dataset PCA_dataset deleted successfully.
INFO:root:{}


In [41]:
# Clean up: look the dataset up by name and delete it.
asr.delete_dataset(dataset_name)

INFO:root:GET request successful!
INFO:root:Dataset PCA_dataset_1709326005 found in SambaStudio
INFO:root:Dataset PCA_dataset deleted successfully.
INFO:root:{"detail":"The Dataset: b125ec06-4c05-4960-901d-a811a6e01954 was successfully marked for deletion from the Dataset Hub."}
