## Setup

In [None]:

!pip install -q peft sacrebleu
!git clone https://github.com/AI4Bharat/IndicTrans2
%cd /content/IndicTrans2/huggingface_interface
!bash install.sh


## Unzip Dataset

In [None]:
!unzip /content/en-indic-exp.zip -d /content/

## Fine-tune

In [None]:
%cd /content/IndicTrans2/huggingface_interface

fine_tuning_args = ' '.join([
    '--model_name "ai4bharat/indictrans2-en-indic-dist-200M"',
    '--direction "en-indic"',
    '--src_lang_list "eng_Latn"',
    '--tgt_lang_list "hin_Deva"',
    '--data_dir "/content/en-indic-exp"', 
    '--output_dir "output"',
    '--batch_size 4',
    '--max_steps 2000',
    '--num_workers 1',
    '--lora_r 32',
    '--lora_alpha 64',
])

!python train_lora.py {fine_tuning_args}

## Upload the LoRA adapter to the 🤗 Hub

Log in to the Hugging Face Hub (a token with write access is required to push the adapter).

In [None]:
from huggingface_hub import notebook_login
# Interactive Hugging Face login: the token is typed into the notebook prompt rather
# than hard-coded. A token with write permission is needed for the push_to_hub call
# in the next cell.
notebook_login()

Training checkpoints are saved to `IndicTrans2/huggingface_interface/output/checkpoint-{steps}`. Point `lora_ckpt_dir` in the next cell at the checkpoint you want to upload.

In [None]:

import os
import re

from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel


base_ckpt_dir = "ai4bharat/indictrans2-en-indic-dist-200M"
# Hub repo name for the adapter (created under the logged-in user's namespace).
lora_identifier = 'indictrans2-conv'
# Path of the LoRA checkpoint to upload, e.g. ".../output/checkpoint-2000".
# Leave empty to use the highest-numbered checkpoint written by the fine-tuning cell.
lora_ckpt_dir = ''
# Matches `--output_dir "output"` run from IndicTrans2/huggingface_interface.
TRAIN_OUTPUT_DIR = "/content/IndicTrans2/huggingface_interface/output"

if not lora_ckpt_dir:
    step_dirs = (
        [name for name in os.listdir(TRAIN_OUTPUT_DIR) if re.fullmatch(r"checkpoint-\d+", name)]
        if os.path.isdir(TRAIN_OUTPUT_DIR)
        else []
    )
    if not step_dirs:
        raise FileNotFoundError(
            f"No checkpoint-* directories found in {TRAIN_OUTPUT_DIR}; "
            "run the fine-tuning cell or set lora_ckpt_dir explicitly."
        )
    # Compare numerically so checkpoint-1000 sorts after checkpoint-500.
    latest = max(step_dirs, key=lambda name: int(name.split("-")[1]))
    lora_ckpt_dir = os.path.join(TRAIN_OUTPUT_DIR, latest)
    print(f"Using latest checkpoint: {lora_ckpt_dir}")

base_model = AutoModelForSeq2SeqLM.from_pretrained(base_ckpt_dir, trust_remote_code=True)
lora_model = PeftModel.from_pretrained(base_model, lora_ckpt_dir)
lora_model.push_to_hub(repo_id=lora_identifier)



## Inference with LoRA

In [None]:

!pip install peft
!git clone https://github.com/VarunGumma/IndicTransTokenizer.git



⚠️ *Now **restart** the session so the freshly installed packages can be imported, then continue with the next cell.*

In [None]:
import requests

# Hosted IndicTrans2 demo endpoint, used as the non-LoRA baseline.
url = "https://demo-api.models.ai4bharat.org/inference/translation/v2"

# Request template for English -> Hindi. The "input" list is filled per call.
# NOTE(review): dataTracking=True presumably lets the service log submitted text;
# avoid sending sensitive sentences.
payload = dict(
    controlConfig=dict(dataTracking=True),
    input=[],
    config=dict(
        serviceId="",
        language=dict(sourceLanguage="en", targetLanguage="hi"),
    ),
)


def indictrans2_api(inputs, max_retries=5, timeout=60, backoff=2.0):
    """Translate sentences with the hosted IndicTrans2 demo API (baseline).

    Uses the module-level ``url`` and ``payload`` template (English -> Hindi).

    Args:
        inputs: iterable of source sentences.
        max_retries: maximum number of POST attempts before giving up.
        timeout: per-request timeout in seconds, passed to ``requests.post``.
        backoff: base delay in seconds between attempts; attempt ``k`` is
            followed by a ``backoff * k`` second sleep.

    Returns:
        list of translated strings (the ``"target"`` field of each item in the
        response's ``"output"`` list). Returns ``[]`` for empty ``inputs``
        without calling the API.

    Raises:
        RuntimeError: if no attempt produced a JSON response containing ``"output"``.
    """
    import time

    sources = [{"source": text} for text in inputs]
    if not sources:
        return []

    # Per-call body: do not mutate the shared module-level `payload` template.
    request_body = {**payload, "input": sources}

    last_result = None
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.post(url, json=request_body, timeout=timeout)
            resp_data = response.json()
        except (requests.RequestException, ValueError) as exc:
            # Network errors, timeouts, and non-JSON bodies (e.g. a gateway error page).
            last_result = exc
        else:
            if isinstance(resp_data, dict) and "output" in resp_data:
                return [output["target"] for output in resp_data["output"]]
            last_result = resp_data
        if attempt < max_retries:
            time.sleep(backoff * attempt)

    raise RuntimeError(
        f"IndicTrans2 API returned no 'output' after {max_retries} attempts; "
        f"last result: {last_result!r}"
    )


In [None]:
import sys
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
from IndicTransTokenizer.IndicTransTokenizer.utils import IndicProcessor
from IndicTransTokenizer.IndicTransTokenizer.tokenizer import IndicTransTokenizer
from peft import PeftModel


# Base checkpoint and LoRA adapter used for inference.
# NOTE(review): the adapter id points at a specific user's Hub repo (sam749); replace it
# with "<your-username>/indictrans2-conv" to evaluate the adapter uploaded earlier.
en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-dist-200M"
lora_ckpt_dir = 'sam749/indictrans2-conv'

BATCH_SIZE = 8
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is enabled only when running on GPU.
HALF = DEVICE == "cuda"

# None, "4-bit" or "8-bit"; selects the bitsandbytes config in initialize_model_and_tokenizer.
quantization = None


def initialize_model_and_tokenizer(ckpt_dir, direction, quantization, lora_dir=None):
    """Load the IndicTrans2 tokenizer and model, optionally wrapping a LoRA adapter.

    Args:
        ckpt_dir: Hub id or local path of the base IndicTrans2 checkpoint.
        direction: tokenizer direction passed to ``IndicTransTokenizer`` (e.g. "en-indic").
        quantization: "4-bit" or "8-bit" for bitsandbytes quantization; any other
            value (e.g. None) loads the model unquantized.
        lora_dir: Hub id or local path of a PEFT/LoRA adapter. ``None`` (default)
            falls back to the notebook-level ``lora_ckpt_dir``, read at call time.
            An empty string skips the adapter.

    Returns:
        ``(tokenizer, model)``; ``model`` is a ``PeftModel`` when an adapter is loaded.
    """
    if quantization == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif quantization == "8-bit":
        # 8-bit mode has no double-quant / compute-dtype knobs; the previous
        # bnb_8bit_* kwargs are not BitsAndBytesConfig parameters and were ignored.
        qconfig = BitsAndBytesConfig(load_in_8bit=True)
    else:
        qconfig = None

    tokenizer = IndicTransTokenizer(direction=direction)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        ckpt_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    if qconfig is None:
        # transformers refuses .to()/.half() on bitsandbytes-quantized models,
        # so only the unquantized model is moved and cast.
        model = model.to(DEVICE)
        if HALF:
            model.half()

    model.eval()

    if lora_dir is None:
        lora_dir = lora_ckpt_dir
    if lora_dir:
        lora_model = PeftModel.from_pretrained(model, lora_dir)
        lora_model.eval()
        return tokenizer, lora_model

    return tokenizer, model


def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
    """Translate ``input_sentences`` in chunks of ``BATCH_SIZE``.

    For each chunk the sentences are:
    1. preprocessed by ``ip.preprocess_batch``;
    2. tokenized (longest-padding, truncation) and moved to ``DEVICE``;
    3. decoded with beam search (5 beams, max_length 256);
    4. detokenized with ``tokenizer.batch_decode``;
    5. restored with ``ip.postprocess_batch``.

    Preprocess and postprocess run back-to-back per chunk. Presumably ``ip`` keeps
    per-batch state (e.g. entity placeholders) between the two calls — TODO confirm.

    Returns:
        list of translated strings, in the same order as ``input_sentences``.
    """
    results = []
    for start in range(0, len(input_sentences), BATCH_SIZE):
        chunk = input_sentences[start : start + BATCH_SIZE]
        prepared = ip.preprocess_batch(chunk, src_lang=src_lang, tgt_lang=tgt_lang)

        encoded = tokenizer(
            prepared,
            src=True,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        with torch.inference_mode():
            output_ids = model.generate(
                **encoded,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        decoded = tokenizer.batch_decode(output_ids.detach().cpu().tolist(), src=False)
        results.extend(ip.postprocess_batch(decoded, lang=tgt_lang))

        # Drop this chunk's tensors before releasing cached GPU memory.
        del encoded, output_ids
        torch.cuda.empty_cache()

    return results


ip = IndicProcessor(inference=True)

# English -> Indic model; initialize_model_and_tokenizer also attaches the LoRA
# adapter when lora_ckpt_dir is set.
en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(
    ckpt_dir=en_indic_ckpt_dir,
    direction="en-indic",
    quantization=quantization,
)





In [None]:

# ---------------------------------------------------------------------------
#                              English to Hindi
# ---------------------------------------------------------------------------

# Each utterance appears twice with speaker and addressee swapped, so the two
# lines of a pair must be identical apart from the names (presumably to compare
# speaker-dependent Hindi forms — TODO confirm).
en_sents = [
    "Ajay to Kritika: Hello! How can I help you?",
    "Kritika to Ajay: Hello! How can I help you?",
    "Ajay to Kritika: Did you mean 'Do you like pizza?' I don't actually eat it. Can I get a photo of it instead?",
    "Kritika to Ajay: Did you mean 'Do you like pizza?' I don't actually eat it. Can I get a photo of it instead?",
]

src_lang, tgt_lang = "eng_Latn", "hin_Deva"
# Local model (with the LoRA adapter when lora_ckpt_dir is set).
hi_translations = batch_translate(en_sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip)
# Baseline from the hosted IndicTrans2 API (network call).
indictrans2_MTs = indictrans2_api(en_sents)


# Side-by-side view: source, API baseline, then the LoRA model's output.
for source, lora_hi, baseline_hi in zip(en_sents, hi_translations, indictrans2_MTs):
    print(
        f"input: {source}\n"
        f"hi_IndicTrans2: {baseline_hi}\n"
        f"hi_LoRA: {lora_hi}\n"
        + "-" * 20
    )
