# Import required Libraries

In [2]:
import pandas as pd
import re, os
import nltk
from nltk.corpus import stopwords

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from sklearn.metrics import classification_report, accuracy_score
import pandas as pd

# Path Variables

In [3]:
ground_truth_csv = "../../data/Qualitative/NER.csv"

# Ground Truth Preparation

In [4]:
def preprocess_tweet(text: str) -> str:
    text = re.sub(r'https?://\S+', '', text)
    text = re.sub(r'^RT\s*:\s*', '', text)
    text = re.sub(r'&\w+;', ' ', text)
    text = re.sub(r'&#\d+;', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\w\u0600-\u06FF]', '', text)
    text = text.strip()
    return text            

In [5]:
if os.path.exists(ground_truth_csv):
    df = pd.read_csv(ground_truth_csv)

In [6]:
df

Unnamed: 0,Text,NER,LID
0,shirt,PRODUCT,English
1,wesi,O,Hindi
2,hii,O,Hindi
3,thi,O,Hindi
4,jese,O,Hindi
...,...,...,...
982,full,O,English
983,bakwas,O,Telugu
984,time,O,English
985,waste,O,English


In [7]:
list(set(df["LID"]))

['Telugu', 'Malayalam', 'Hindi', 'Kannada', 'English', 'Tamil']

# Language Identification

In [8]:
available_models = [
    ("xlmr", "xlm-roberta-base"),
    ("mdeberta", "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"),
    ("labse", "setu4993/LaBSE"),
    ("muril", "google/muril-base-cased")
]

hf_token = "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"

In [9]:
# Encode sentences
def encode_sentences(tokenizer, model, sentences, device):
    encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')
    # print(encoded_input)
    encoded_input = encoded_input.to(device)
    with torch.no_grad():
        model_output = model(**encoded_input)
    return model_output.last_hidden_state[:, 0, :]

# Zero-shot prediction
def zero_shot_predict_single(text, tokenizer, model, label_embeddings, labels, device):
    text_embedding = encode_sentences(tokenizer, model, [text], device)
    cosine_similarities = F.cosine_similarity(text_embedding.unsqueeze(1), label_embeddings.unsqueeze(0), dim=2)
    predicted_index = torch.argmax(cosine_similarities, dim=1).item()
    return labels[predicted_index]

In [11]:
#  Choose the model here (1-based index): 1 = XLM-R, 2 = mDeBERTa, 3 = LaBSE, 4 = MuRIL
for index in range(1,5):
    df = pd.read_csv(ground_truth_csv)
    choose_model = index
    key, model_name = available_models[choose_model - 1]
    labels_list = ['Kannada', 'Malayalam', 'Hindi', 'English', 'Tamil', 'Telugu']
    descriptions = [
        "The text is in Kannada.",
        "The text is in Malayalam.",
        "The text is in Hindi.",
        "The text is in English.",
        "The text is in Tamil.",
        "The text is in Telugu."
    ]
    if key not in df.columns:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
        model = AutoModel.from_pretrained(model_name, use_auth_token=hf_token).to(device)
        model.eval()

        label_embeddings = encode_sentences(tokenizer, model, descriptions, device)

        predictions = []
        for idx, item in enumerate(df["Text"].tolist()):
            try:
                pred = zero_shot_predict_single(item, tokenizer, model, label_embeddings, labels_list, device)
                predictions.append(pred)
            except Exception as e:
                print(f"Exception for model {model_name} at idx {idx}: {e}")
                predictions.append("O")

        df[key] = predictions
        df.to_csv(ground_truth_csv, index=False)
        print(key)



xlmr




mdeberta




labse




muril


In [12]:
df

Unnamed: 0,Text,NER,LID,xlmr,mdeberta,labse,muril
0,shirt,PRODUCT,English,Malayalam,Telugu,Malayalam,English
1,wesi,O,Hindi,Malayalam,English,Malayalam,English
2,hii,O,Hindi,Malayalam,English,Hindi,Hindi
3,thi,O,Hindi,Malayalam,English,Tamil,Telugu
4,jese,O,Hindi,Malayalam,English,Malayalam,English
...,...,...,...,...,...,...,...
982,full,O,English,Malayalam,English,Kannada,English
983,bakwas,O,Telugu,Malayalam,English,Kannada,Hindi
984,time,O,English,Malayalam,English,Tamil,English
985,waste,O,English,Malayalam,English,Hindi,Hindi


# Evaluation

In [13]:
df = pd.read_csv(ground_truth_csv)

In [14]:
for index in range(1,5):
    choose_model = index
    key, model_name = available_models[choose_model - 1]
    all_true = df["LID"].tolist()
    all_pred = df[key].tolist()
    report = classification_report(all_true, all_pred, digits=4)
    print(f"\nCLASSIFICATION REPORT: {key}")
    print(report)


CLASSIFICATION REPORT: xlmr
              precision    recall  f1-score   support

     English     0.1875    0.0146    0.0271       410
       Hindi     0.2500    0.0079    0.0153       127
     Kannada     0.0000    0.0000    0.0000        97
   Malayalam     0.1276    0.9528    0.2251       127
       Tamil     0.0000    0.0000    0.0000       114
      Telugu     0.0000    0.0000    0.0000       112

    accuracy                         0.1297       987
   macro avg     0.0942    0.1625    0.0446       987
weighted avg     0.1265    0.1297    0.0422       987


CLASSIFICATION REPORT: mdeberta
              precision    recall  f1-score   support

     English     0.3822    0.8585    0.5289       410
       Hindi     0.0000    0.0000    0.0000       127
     Kannada     0.0000    0.0000    0.0000        97
   Malayalam     0.0000    0.0000    0.0000       127
       Tamil     0.0000    0.0000    0.0000       114
      Telugu     0.0833    0.0357    0.0500       112

    accuracy   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
