In [1]:
# !chmod 777 install_reqs.sh
# !./install_reqs.sh


In [2]:
# For ASR
import os
import tempfile
import traceback
from urllib.parse import unquote, urlsplit

import torch
from transformers import AutoModelForCTC, AutoProcessor
import torchaudio.functional as F
import pyctcdecode
from transformers import AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, pipeline
import soundfile
import torchaudio
import requests

from models.complaint_model import ComplaintModel




In [3]:
# Log a startup message before the (slow) model-loading cells below run.
print("starting backend service")

starting backend service


In [4]:
# ASR Service


# Run inference on the GPU when torch can see one, otherwise fall back to CPU.
DEVICE_ID = "cpu"
if torch.cuda.is_available():
    DEVICE_ID = "cuda"

# Supported ASR languages, keyed by the lower-case language name used on
# complaints.  `asr_model_ids` maps each language to its Hugging Face model id.
asr_langs = dict(hindi='hi')
asr_model_ids = dict(hindi="ai4bharat/indicwav2vec-hindi")

# Caches filled by `load_asr_models()` or lazily by `convert_stt()`:
# language name -> loaded CTC model / processor.
asr_models = dict()
asr_processors = dict()

def load_asr_models():
    """Eagerly load the CTC model and processor for every language in `asr_langs`.

    Each model is moved to `DEVICE_ID`; results are stored in the module-level
    `asr_models` / `asr_processors` dicts, keyed by language name.
    Returns None.
    """
    for language in asr_langs:
        model_id = asr_model_ids[language]
        asr_models[language] = AutoModelForCTC.from_pretrained(model_id).to(DEVICE_ID)
        asr_processors[language] = Wav2Vec2Processor.from_pretrained(model_id)


def load_audio_from_url(url, timeout=(10, 60)):
    """Download the audio file at `url` and load it as a single-channel waveform.

    The file is streamed into a temporary file that is deleted again once
    torchaudio has read it, so repeated calls do not leave downloads behind.

    Args:
        url: HTTP(S) URL of the audio file.
        timeout: `requests` timeout, either one number or a
            `(connect, read)` tuple in seconds.  Without a timeout a stalled
            server would block the caller forever.

    Returns:
        `(waveform, sample_rate)` where `waveform` is the first (only) channel
        of the tensor returned by `torchaudio.load`.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.Timeout: the server did not answer within `timeout`.
        ValueError: the audio has more than one channel.
    """
    # Only the file extension is taken from the URL.  Parse the *path* (so the
    # query string and any '/' inside it are ignored) and percent-decode it, so
    # URLs with zero, one or several "%2F" separators are all handled.
    object_name = unquote(urlsplit(url).path).rsplit('/', 1)[-1]
    suffix = os.path.splitext(object_name)[1]

    # Keep the extension on the temp file in case the audio backend uses it to
    # detect the container format.  delete=False so the file can be reopened by
    # torchaudio after we close it (required on Windows).
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    try:
        with tmp:
            # NOTE the stream=True parameter below
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                for chunk in r.iter_content(chunk_size=8192):
                    tmp.write(chunk)
        waveform, sample_rate = torchaudio.load(tmp.name)
    finally:
        os.remove(tmp.name)

    num_channels, _ = waveform.shape
    if num_channels != 1:
        raise ValueError("Waveform with more than 1 channels are not supported.")
    return waveform[0], sample_rate

def convert_stt(complaint:ComplaintModel):
    """Transcribe the audio referenced by `complaint.audioURL`.

    The model and processor for `complaint.lang` are loaded on first use and
    cached in `asr_models` / `asr_processors`.  The audio is downloaded with
    `load_audio_from_url`, resampled to 16 kHz, run through the CTC model and
    decoded greedily (arg-max over the logits).

    Args:
        complaint: complaint whose `lang` and `audioURL` attributes are read.

    Returns:
        The decoded transcription string (also printed).

    Raises:
        KeyError: no ASR model is configured for `complaint.lang`.
    """
    lang = complaint.lang
    target_rate = 16000

    # Check both caches: if a previous attempt loaded the model but failed on
    # the processor, testing `asr_models` alone would never retry the load.
    if lang not in asr_models or lang not in asr_processors:
        if lang not in asr_model_ids:
            raise KeyError(f"no ASR model configured for language {lang!r}")
        print(f"---------loading asr model for {lang}----------")
        asr_models[lang] = AutoModelForCTC.from_pretrained(asr_model_ids[lang]).to(DEVICE_ID)
        asr_processors[lang] = Wav2Vec2Processor.from_pretrained(asr_model_ids[lang])

    model = asr_models[lang]
    processor = asr_processors[lang]

    # Load from url
    waveform, sample_rate = load_audio_from_url(complaint.audioURL)
    resampled_audio = torchaudio.functional.resample(waveform, sample_rate, target_rate)

    # The feature extractor documents NumPy arrays / lists as input, and
    # passing `sampling_rate` lets it verify the audio matches the rate the
    # model was trained on instead of silently accepting a mismatch.
    input_values = processor(
        resampled_audio.numpy(),
        sampling_rate=target_rate,
        return_tensors="pt",
    ).input_values

    with torch.no_grad():
        logits = model(input_values.to(DEVICE_ID)).logits.cpu()

    prediction_ids = torch.argmax(logits, dim=-1)
    output_str = processor.batch_decode(prediction_ids)[0]
    print(f"Greedy Decoding: {output_str}")
    return output_str



In [5]:
# Translation Service
# code for translation

import torch
from transformers import AutoModelForSeq2SeqLM
from IndicTransTokenizer import IndicProcessor, IndicTransTokenizer
import models.complaint_model

# Language name (as stored on a complaint) -> language code passed to the
# IndicTrans processor as `src_lang`.
trans_langs = dict(hindi='hin_Deva')

# Tokenizer, pre/post-processor and model for the Indic -> English direction.
trans_tokenizer = IndicTransTokenizer(direction="indic-en")
trans_ip = IndicProcessor(inference=True)
trans_model = AutoModelForSeq2SeqLM.from_pretrained(
    "ai4bharat/indictrans2-indic-en-dist-200M",
    trust_remote_code=True,
)

def translate(complaint):
    """Translate `complaint.complaintText` from `complaint.lang` to English.

    Args:
        complaint: object whose `lang` (a key of `trans_langs`) and
            `complaintText` attributes are read.

    Returns:
        The translated sentence (the single element of the decoded batch).
        The full decoded batch is also printed.

    Raises:
        KeyError: `complaint.lang` is not a key of `trans_langs`.
    """
    src_lang = trans_langs[complaint.lang]
    tgt_lang = "eng_Latn"
    sentences = [complaint.complaintText]

    batch = trans_ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
    batch = trans_tokenizer(batch, src=True, return_tensors="pt")

    with torch.inference_mode():
        outputs = trans_model.generate(**batch, num_beams=5, num_return_sequences=1, max_length=256)

    outputs = trans_tokenizer.batch_decode(outputs, src=False)
    # Post-processing must use the language of the *generated* text.  The
    # original passed the source language here, which applies Indic-script
    # post-processing to English output.
    outputs = trans_ip.postprocess_batch(outputs, lang=tgt_lang)
    print(outputs)
    return outputs[0]



  return self.fget.__get__(instance, owner)()


In [7]:
# Fine-tuned FLAN-T5 (ft5): model loading and the complaint categorization function
from peft import AutoPeftModelForSeq2SeqLM,PeftModel
from transformers import AutoTokenizer,AutoConfig,T5ForConditionalGeneration,AutoModelForSeq2SeqLM
import torch
import numpy as np
import pandas as pd
import subprocess
import time
from finetune.instruction import instruction

# wd_path='E:\ETC\Progress\Projects\web_app\mh_bhasha3\cms_backend'
wd_path = '~/mh_bhasha3/cms_backend'

# Toggle between the fine-tuned PEFT checkpoint and the plain base model.
use_ft_model = True
base_model_path = "google/flan-t5-xl"

# Both configurations use the tokenizer of the base model.
ft5_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

if not use_ft_model:
    model_path = base_model_path
    ft5_model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
else:
    # Disabled experiment: registering extra tokens on the tokenizer.
    # from tokenizers import AddedToken
    # stokens_v3=["{","}","<","`","\\"]
    # stokens=stokens_v3
    # for st in stokens:
    #     ft5_tokenizer.add_tokens(AddedToken(st, normalized=False),special_tokens=False)

    lcp_path = './finetune/ft_models/flan-t5-xl-mt5-v1/checkpoint-22900'
    model_path = lcp_path

    # Load the base model and size its embedding matrix to the tokenizer
    # before attaching the adapter weights.
    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_path)
    base_model.resize_token_embeddings(len(ft5_tokenizer))

    # Load PEFT model
    # ft5_model = AutoPeftModelForSeq2SeqLM(model,)
    ft5_model = PeftModel.from_pretrained(model=base_model, model_id=lcp_path)
    
    



def categorize(complaint):
    """Generate the model's response for `complaint.complaintText`.

    The complaint text and the module-level `instruction` are tokenized as a
    text pair and passed to `ft5_model.generate` with sampling enabled, so
    repeated calls may return different text.  The generation time is printed.

    Returns:
        A dict `{'output': <first decoded sequence>}`.
    """
    print("-------------generating response-----------------")

    complaint_line = f'complaint:{complaint.complaintText}'
    prompt_pair = [f'Input:\n{complaint_line}\n\n', f'Instruction:\n{instruction}']

    encoded = ft5_tokenizer(
        [prompt_pair],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    started_at = time.time()
    generated = ft5_model.generate(input_ids=encoded.input_ids, do_sample=True, max_length=150)
    elapsed = time.time() - started_at
    print(f"--------------time taken by model for generating response={elapsed} seconds")

    decoded = list(ft5_tokenizer.batch_decode(generated, skip_special_tokens=True))
    return {'output': f'{decoded[0]}'}


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [None]:
# input="complaint:my train was late, my pnr is 2345"

# from finetune.instruction import instruction
# # categories=['Train','station','suggestion','enquiry','staff']
# # instruction=f"""Identify the following from the complaint:
# # 1)category of the complaint from the following categories:{categories.__str__()}.
# # 2)a priority for the complaint.
# # 3)extract attributes from the complaint.
# # 4)identify sentiment from the complaint.
# # return the output in format: tags:[],priority:[],attributes:[],sentiment:[]
# # """
# prompts=[[input,instruction]]
# prompts=[[f'Input:\n{input}\n\n',f'Instruction:\n{instruction}',]]

# res=[]
# input_ids = ft5_tokenizer(prompts, return_tensors="pt" ,padding=True,truncation=True, max_length=512).input_ids
# start_time = time.time()
# outputs = ft5_model.generate(input_ids=input_ids, do_sample=True, max_length=150)
# pt=time.time() - start_time
# print(f"--------------time taken by model for generating response={pt} seconds")

# res+=ft5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
# print(res)

In [8]:
# listener function
from models.complaint_model import ComplaintModel
# from cms_backend.utils.ft5 import categorize
# from cms_backend.utils.indictrans2 import translate
# from cms_backend.utils.indicwav2vec import convert_stt
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
import datetime
import json
import threading
import time

from pandas import Timestamp



def listen_msgs():
    """Watch the incoming-complaint document and publish a processed copy.

    Subscribes to `mh_bhasha3/mh_bhasha3_user/rComplaints/complaint`.  For every
    snapshot the complaint is transcribed (when it carries an audio URL),
    translated (when its language is not English) and categorized; the result
    is written to `.../pComplaints/<complaint.id>`.

    The function blocks until interrupted; the Firestore watch is unsubscribed
    on the way out.  A failure while processing one snapshot is printed with
    its traceback and does not stop the listener.
    """
    coll_name = 'mh_bhasha3'
    user_uid = 'mh_bhasha3_user'

    # Use a service account.  The key file is only needed the first time the
    # default Firebase app is created in this kernel.
    if not firebase_admin._apps:
        cred = credentials.Certificate('../keys/sa.json')
        firebase_admin.initialize_app(cred)

    db = firestore.client()
    user_doc = db.collection(coll_name).document(user_uid)

    def process_document(doc):
        """Run one complaint snapshot through the pipeline and store the result."""
        data = doc.to_dict()
        print(f"Received document snapshot: {doc.id}")
        print(data)
        complaint = ComplaintModel.from_dict(data)
        complaint.lang = complaint.lang.lower()

        # Truthiness rather than `!= ''`, so a missing/None URL is not
        # mistaken for an audio complaint.
        if complaint.audioURL:
            complaint.complaintText = convert_stt(complaint=complaint)
        if complaint.lang != 'english':
            complaint.complaintText = translate(complaint=complaint)

        pred_res = categorize(complaint)

        complaint.senderId = 'backend@red'
        complaint.output = pred_res['output']
        # Timezone-aware UTC timestamp (`datetime.utcnow()` is naive and
        # deprecated since Python 3.12).
        complaint.ots = datetime.datetime.now(datetime.timezone.utc)
        print(f'sending response: {complaint.output}')

        doc_id = complaint.id
        user_doc.collection('pComplaints').document(doc_id).set(complaint.__dict__)
        print(f"----------sent----------------with id: {doc_id} ")

    # Create a callback on_snapshot function to capture changes
    def on_snapshot(doc_snapshot, changes, read_time):
        for doc in doc_snapshot:
            try:
                process_document(doc)
            except Exception:
                # Report the failure and keep handling later snapshots instead
                # of letting the error escape the callback.
                print(f"failed to process document {doc.id}:")
                traceback.print_exc()

    doc_ref = user_doc.collection('rComplaints').document("complaint")

    # Watch the document
    doc_watch = doc_ref.on_snapshot(on_snapshot)
    print("listening for messages...",)

    try:
        # Keep the calling thread alive; the empty print flushes buffered
        # output from the callback.
        while True:
            print('', end='', flush=True)
            time.sleep(1)
    finally:
        # Stop the watch when the loop is interrupted (e.g. kernel interrupt).
        doc_watch.unsubscribe()



In [None]:
# calling the listener function
try:
    listen_msgs()
except Exception as err:
    # Print the error message instead of letting the traceback end the cell.
    print('error occured:\n\n\n {}'.format(err))


listening for messages...
Received document snapshot: complaint
{'audioURL': '', 'ots': None, 'id': '1712418185632', 'complaintText': 'my train with pnr 12345 got late', 'senderId': 'mhs@red', 'lang': 'Hindi', 'output': '', 'timestamp': DatetimeWithNanoseconds(2024, 4, 6, 15, 43, 5, 632000, tzinfo=datetime.timezone.utc)}
['my train with pnr 12345 got late']
-------------generating response-----------------
--------------time taken by model for generating response=38.16107392311096 seconds
sending response: tags:['train'],priority:['2'],attributes:['pnr=12345'],sentiment:['negative']
----------sent----------------with id: 1712418185632 
Received document snapshot: complaint
{'audioURL': '', 'ots': None, 'id': '1712418488261', 'complaintText': 'my train with pnr 568 was late', 'senderId': 'mhs@red', 'lang': 'Hindi', 'timestamp': DatetimeWithNanoseconds(2024, 4, 6, 15, 48, 8, 261000, tzinfo=datetime.timezone.utc), 'output': ''}
['my train with pnr 568 was late']
-------------generating re

In [None]:
# !pip install soundfile 
# !pip install cffi -t /home/u5f8a87982dcb10536c03fdd4afd4637/.local/lib/python3.9/site-packages
# !pip3 uninstall -y cffi
# import soundfile

# !pip show sox
# !pip install pysoundfile
# import soundfile
# import cffi