In [4]:
pip install "transformers==4.40.2"


Collecting transformers==4.40.2
  Downloading transformers-4.40.2-py3-none-any.whl.metadata (137 kB)
Collecting tokenizers<0.20,>=0.19 (from transformers==4.40.2)
  Downloading tokenizers-0.19.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (6.7 kB)
Downloading transformers-4.40.2-py3-none-any.whl (9.0 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.0/9.0 MB[0m [31m962.3 kB/s[0m  [33m0:00:09[0m[0m eta [36m0:00:01[0m0:01[0m:02[0m
[?25hDownloading tokenizers-0.19.1-cp311-cp311-macosx_11_0_arm64.whl (2.4 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m1.0 MB/s[0m  [33m0:00:02[0m1.0 MB/s[0m eta [36m0:00:01[0m:01[0m
[?25hInstalling collected packages: tokenizers, transformers
[2K  Attempting uninstall: tokenizers
[2K    Found existing installation: tokenizers 0.22.0
[2K    Uninstalling tokenizers-0.22.0:
[2K      Successfully uninstalled tokenizers-0.22.0
[2K  Attempting uninstall: trans

In [2]:
pip install epitran


Collecting epitran
  Downloading epitran-1.34.0-py3-none-any.whl.metadata (36 kB)
Collecting panphon>=0.20 (from epitran)
  Downloading panphon-0.22.2-py2.py3-none-any.whl.metadata (15 kB)
Collecting marisa-trie (from epitran)
  Downloading marisa_trie-1.3.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (10 kB)
Collecting jamo (from epitran)
  Downloading jamo-0.4.1-py3-none-any.whl.metadata (2.3 kB)
Collecting unicodecsv (from panphon>=0.20->epitran)
  Downloading unicodecsv-0.14.1.tar.gz (10 kB)
  Preparing metadata (setup.py) ... [?25ldone
Collecting editdistance (from panphon>=0.20->epitran)
  Downloading editdistance-0.8.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (3.9 kB)
Collecting munkres (from panphon>=0.20->epitran)
  Downloading munkres-1.1.4-py2.py3-none-any.whl.metadata (980 bytes)
Downloading epitran-1.34.0-py3-none-any.whl (222 kB)
Downloading panphon-0.22.2-py2.py3-none-any.whl (78 kB)
Downloading editdistance-0.8.1-cp311-cp311-macosx_11_0_arm64.whl (79 kB)
Downloading j

In [16]:
# Install required packages if not already installed
!pip install indic-transliteration googletrans==4.0.0-rc1 g2p_en



In [5]:
# %%
# Basic imports
import os
import random
import time
from pathlib import Path
import math
from collections import Counter

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, AutoConfig, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix
from tqdm.auto import tqdm

# Set seeds
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE)

# Paths
TRAIN_PATH = 'training.xlsx'
TEST_PATH = 'testing.xlsx'


Using device: cpu


In [6]:
# %%
def load_excel(path):
    df = pd.read_excel(path)
    return df

train_df = load_excel(TRAIN_PATH)
test_df = load_excel(TEST_PATH)

print('Train shape:', train_df.shape)
print('Test shape:', test_df.shape)

# Check required columns
for df in [train_df, test_df]:
    for col in ['text', 'label', 'lang']:
        if col not in df.columns:
            raise ValueError(f'Column `{col}` missing in dataset')
    df.dropna(subset=['text'], inplace=True)
    df['lang'] = df['lang'].str.lower().str.strip()

# Class balance
print(train_df['label'].value_counts(normalize=True))


Train shape: (2400, 3)
Test shape: (600, 3)
label
0    0.5
1    0.5
Name: proportion, dtype: float64


In [7]:
# %%
# Try to import helpful libraries
try:
    from indic_transliteration.sanscript import transliterate as it_transliterate, SCHEMES
    HAS_INDIC = True
except:
    HAS_INDIC = False

try:
    import epitran
    HAS_EPITRAN = True
except:
    HAS_EPITRAN = False

# Transliteration
def transliterate_text(lang, text):
    if lang == 'en': return text
    if HAS_INDIC:
        scheme_map = {'te':'telugu', 'ml':'malayalam'}
        try:
            return it_transliterate(text, scheme_map.get(lang,'iast'), 'iast')
        except:
            return text
    import unicodedata
    return ''.join(c for c in unicodedata.normalize('NFD', text) if unicodedata.category(c) != 'Mn')

# G2P mapping
def g2p_map(lang, text):
    if lang == 'en':
        return ' '.join(list(text.replace(' ', '_')))
    if HAS_EPITRAN:
        try:
            epi = epitran.Epitran({'te':'tel-Deva','ml':'mal-Mlym'}.get(lang,'eng-Latn'))
            phons = epi.transliterate(text)
            return ' '.join(list(phons))
        except:
            return ' '.join(list(text.replace(' ','_')))
    else:
        return ' '.join(list(text.replace(' ','_')))

# English gloss fallback (original text if translation unavailable)
def english_gloss(lang, text):
    return text  # For now, placeholder (can integrate Helsinki-NLP translation if desired)

# Precompute columns
for df in [train_df, test_df]:
    df['translit'] = df.apply(lambda r: transliterate_text(r['lang'], str(r['text'])), axis=1)
    df['phonemes'] = df.apply(lambda r: g2p_map(r['lang'], r['translit']), axis=1)
    df['gloss'] = df['text']  # simple fallback


In [10]:
# Save preprocessed data
train_df.to_pickle('train_preprocessed.pkl')
test_df.to_pickle('test_preprocessed.pkl')
print("✅ Preprocessed data saved.")


✅ Preprocessed data saved.


In [None]:
# Load preprocessed data
#train_df = pd.read_pickle('train_preprocessed.pkl')
#test_df = pd.read_pickle('test_preprocessed.pkl')
#print("✅ Preprocessed data loaded.")


In [8]:
# %%
PHONEME_VOCAB_SIZE = 5000

def build_phoneme_vocab(series_phonemes, max_vocab=PHONEME_VOCAB_SIZE-2):
    ctr = Counter()
    for s in series_phonemes:
        ctr.update(s.strip().split())
    most = ctr.most_common(max_vocab)
    idx2tok = ['<pad>','<unk>'] + [w for w,_ in most]
    tok2idx = {t:i for i,t in enumerate(idx2tok)}
    return tok2idx, idx2tok

phon_tok2idx, phon_idx2tok = build_phoneme_vocab(train_df['phonemes'].tolist())
PHON_VOCAB_SIZE = len(phon_tok2idx)
print('Phoneme vocab size:', PHON_VOCAB_SIZE)


Phoneme vocab size: 161


In [11]:
# %%
TEXT_MODEL_NAME = 'xlm-roberta-base'
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_config = AutoConfig.from_pretrained(TEXT_MODEL_NAME)
text_model = AutoModel.from_pretrained(TEXT_MODEL_NAME).to(DEVICE)

# Adapter module
class Adapter(nn.Module):
    def __init__(self, hidden_size, bottleneck=256):
        super().__init__()
        self.down = nn.Linear(hidden_size, bottleneck)
        self.act = nn.ReLU()
        self.up = nn.Linear(bottleneck, hidden_size)
    def forward(self, x):
        return self.up(self.act(self.down(x)))

text_adapter = Adapter(text_config.hidden_size, bottleneck=128).to(DEVICE)

# Phoneme encoder
PHONEME_EMB_DIM = 256
PHONEME_NHEAD = 8
PHONEME_NLAYERS = 3

class PhonemeEncoder(nn.Module):
    def __init__(self, vocab_size, emb_dim=PHONEME_EMB_DIM, nhead=PHONEME_NHEAD, nlayers=PHONEME_NLAYERS):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=nlayers)
    def forward(self, input_ids):
        emb = self.embedding(input_ids).transpose(0,1)
        out = self.transformer(emb).transpose(0,1)
        mask = (input_ids!=0).unsqueeze(-1).float()
        summed = (out*mask).sum(1)
        lengths = mask.sum(1).clamp(min=1.0)
        return summed/lengths

phoneme_encoder = PhonemeEncoder(vocab_size=PHON_VOCAB_SIZE).to(DEVICE)


model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]



In [44]:
# %%
MAX_TEXT_LEN = 128
MAX_PH_LEN = 128
lang2idx = {'en':0,'te':1,'ml':2}

def phoneme_tokenize(ph_str, tok2idx, max_len=MAX_PH_LEN):
    toks = ph_str.strip().split()
    ids = [tok2idx.get(t, tok2idx['<unk>']) for t in toks][:max_len]
    if len(ids) < max_len:
        ids += [tok2idx['<pad>']]*(max_len-len(ids))
    return ids

from torch.utils.data import Dataset
import torch

class AbuseDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        r = self.df.iloc[idx]

        # use 'text' instead of 'translit'
        text = str(r['text'])

        # if phonemes column exists, use it; otherwise generate or fill zeros
        if 'phonemes' in self.df.columns:
            phon_seq = phoneme_tokenize(r['phonemes'], phon_tok2idx)
        else:
            # fallback: create dummy phoneme ids of max length
            phon_seq = [0] * MAX_TEXT_LEN  

        # text encoding (same as before)
        enc = text_tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=MAX_TEXT_LEN,
            return_tensors='pt'
        )

        # label (convert to float for BCE, or long for CE — adjust as needed)
        label = torch.tensor(int(r['label']), dtype=torch.float)

        # language ID mapping
        lid = torch.tensor(lang2idx.get(r['lang'], 0), dtype=torch.long)

        # convert phoneme list to tensor
        ph_ids = torch.tensor(phon_seq, dtype=torch.long)

        return {
            'input_ids': enc['input_ids'].squeeze(0),
            'attention_mask': enc['attention_mask'].squeeze(0),
            'phon_ids': ph_ids,
            'label': label,
            'lang_id': lid
        }


train_dataset = AbuseDataset(train_df)
valid_dataset = AbuseDataset(test_df)

BATCH_SIZE = 16
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)


In [45]:
# %%
class CrossAttentionFusion(nn.Module):
    def __init__(self, text_dim, phon_dim, hidden_dim=512, nhead=8):
        super().__init__()
        self.proj_text = nn.Linear(text_dim, hidden_dim)
        self.proj_phon = nn.Linear(phon_dim, hidden_dim)
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads=nhead, batch_first=True)
        self.norm = nn.LayerNorm(hidden_dim)
        self.ff = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))
    def forward(self, text_feats, phon_feats):
        q = self.proj_text(text_feats)
        kv = self.proj_phon(phon_feats).unsqueeze(1)
        attn_out, _ = self.cross_attn(q, kv, kv)
        out = self.norm(attn_out + q)
        return out.squeeze(1)

class AbuseModel(nn.Module):
    def __init__(self, text_model, text_adapter, phoneme_encoder, n_langs=3, hidden_dim=512):
        super().__init__()
        self.text_model = text_model
        self.text_adapter = text_adapter
        self.phoneme_encoder = phoneme_encoder
        self.fusion = CrossAttentionFusion(text_dim=text_config.hidden_size, phon_dim=PHONEME_EMB_DIM, hidden_dim=hidden_dim)
        self.abuse_head = nn.Linear(hidden_dim, 1)
        self.lid_head = nn.Linear(hidden_dim, n_langs)
        self.reconstructor = nn.Sequential(nn.Linear(hidden_dim, text_config.hidden_size), nn.ReLU(), nn.Linear(text_config.hidden_size, text_config.hidden_size))

    def forward(self, input_ids, attention_mask, phon_ids):
        out = self.text_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        last = out.last_hidden_state
        mask = attention_mask.unsqueeze(-1)
        pooled = (last*mask).sum(1)/mask.sum(1).clamp(min=1.0)
        adapted = pooled + self.text_adapter(pooled)
        phon_pooled = self.phoneme_encoder(phon_ids.to(pooled.device))
        fused = self.fusion(adapted.unsqueeze(1), phon_pooled)
        abuse_logits = self.abuse_head(fused).squeeze(-1)
        lid_logits = self.lid_head(fused)
        recon_text_embed = self.reconstructor(fused)
        return {'abuse_logits': abuse_logits,'lid_logits':lid_logits,'text_embed':pooled,'fused':fused,'recon_text_embed':recon_text_embed,'phon_embed':phon_pooled}

model = AbuseModel(text_model, text_adapter, phoneme_encoder).to(DEVICE)


In [46]:
# %%
# Focal BCE
class FocalBCELoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    def forward(self, logits, targets):
        prob = torch.sigmoid(logits)
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        p_t = prob*targets + (1-prob)*(1-targets)
        alpha_factor = self.alpha*targets + (1-self.alpha)*(1-targets)
        mod_factor = (1.0 - p_t)**self.gamma
        return (alpha_factor*mod_factor*bce).mean()

focal_loss_fn = FocalBCELoss()
ce_loss_fn = nn.CrossEntropyLoss()
mse_loss_fn = nn.MSELoss()

def compute_metrics(y_true, y_pred_probs, threshold=0.5):
    y_pred = (np.array(y_pred_probs) >= threshold).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=0)
    acc = accuracy_score(y_true, y_pred)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    sensitivity = tp/(tp+fn) if tp+fn>0 else 0
    specificity = tn/(tn+fp) if tn+fp>0 else 0
    return {'macro_precision':precision,'macro_recall':recall,'macro_f1':f1,'accuracy':acc,'sensitivity':sensitivity,'specificity':specificity,'error_rate':1-acc}


In [47]:
# %%
EPOCHS = 6
lr = 2e-5
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.01)
total_steps = len(train_loader)*EPOCHS
warmup_steps = int(0.06*total_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)


In [48]:
# %%
def train_one_epoch(model, loader, optimizer, scheduler):
    model.train()
    losses = []
    for batch in tqdm(loader, desc='train'):
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        phon_ids = batch['phon_ids'].to(DEVICE)
        labels = batch['label'].to(DEVICE)
        lids = batch['lang_id'].to(DEVICE)

        out = model(input_ids, attention_mask, phon_ids)
        loss_abuse = focal_loss_fn(out['abuse_logits'], labels)
        loss_lid = ce_loss_fn(out['lid_logits'], lids)
        loss_rt = mse_loss_fn(out['recon_text_embed'], out['text_embed'])
        loss = loss_abuse + 0.5*loss_lid + 0.5*loss_rt

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),5.0)
        optimizer.step()
        scheduler.step()
        losses.append(loss.item())
    return np.mean(losses)

def evaluate(model, loader):
    model.eval()
    all_probs, all_labels = [],[]
    with torch.no_grad():
        for batch in tqdm(loader, desc='eval'):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            phon_ids = batch['phon_ids'].to(DEVICE)
            labels = batch['label'].numpy()
            out = model(input_ids, attention_mask, phon_ids)
            probs = torch.sigmoid(out['abuse_logits']).cpu().numpy()
            all_probs.extend(probs.tolist())
            all_labels.extend(labels.tolist())
    return compute_metrics(all_labels, all_probs)


In [49]:
import pandas as pd
df = pd.read_excel('training.xlsx')
print(df.columns)



Index(['text', 'label', 'lang'], dtype='object')


In [50]:
# -------------------------
# Utilities, checkpointing, logging, inference, adversarial routines
# -------------------------
import json
from sklearn.metrics import classification_report
import os
from copy import deepcopy

CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def save_checkpoint(state, name='latest.pth'):
    path = os.path.join(CHECKPOINT_DIR, name)
    torch.save(state, path)
    print(f"Saved checkpoint: {path}")

def load_checkpoint(path, model, optimizer=None, scheduler=None, map_location=DEVICE):
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state'])
    if optimizer and 'optim_state' in checkpoint:
        optimizer.load_state_dict(checkpoint['optim_state'])
    if scheduler and 'sched_state' in checkpoint:
        scheduler.load_state_dict(checkpoint['sched_state'])
    print(f"Loaded checkpoint from {path}")
    return checkpoint

# Simple logger
def save_metrics(metrics, fname="metrics.json"):
    with open(os.path.join(CHECKPOINT_DIR, fname), "w") as f:
        json.dump(metrics, f, indent=2)

# -------------------------
# Adversarial: FGM-style for embeddings
# -------------------------
class FGM:
    """
    Fast Gradient Method (single-step) applied to embedding parameters.
    We will copy and add small perturbation to embedding weights for a forward pass.
    This supports both text model embeddings and phoneme embeddings.
    """
    def __init__(self, model, epsilon=1e-3):
        self.model = model
        self.epsilon = epsilon
        self.backup = {}

    def attack_embedding(self, emb_param_name_substr='embeddings'):
        """
        Adds perturbation to any parameter whose name contains emb_param_name_substr.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_param_name_substr in name:
                if param.grad is None:
                    continue
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0:
                    r_at = self.epsilon * param.grad / (norm + 1e-8)
                    param.data.add_(r_at)

    def restore(self):
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}

# -------------------------
# Training + evaluation loops (with optional adversarial step)
# -------------------------
import datetime
def train_loop(model, train_loader, valid_loader, optimizer, scheduler,
               epochs=EPOCHS, early_stop_patience=3, adv_train=False,
               adv_epsilon=1e-3, grad_clip=5.0, save_every=1):
    best_val_f1 = -1.0
    best_epoch = -1
    no_improve = 0
    metrics_history = {'train_loss':[], 'val':[]}
    fgm = FGM(model, epsilon=adv_epsilon)

    for epoch in range(1, epochs+1):
        t0 = time.time()
        model.train()
        train_losses = []
        for batch in tqdm(train_loader, desc=f"Epoch {epoch} train"):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            phon_ids = batch['phon_ids'].to(DEVICE)
            labels = batch['label'].to(DEVICE)
            lids = batch['lang_id'].to(DEVICE)

            out = model(input_ids, attention_mask, phon_ids)
            loss_abuse = focal_loss_fn(out['abuse_logits'], labels)
            loss_lid = ce_loss_fn(out['lid_logits'], lids)
            loss_rt = mse_loss_fn(out['recon_text_embed'], out['text_embed'])
            loss = loss_abuse + 0.5*loss_lid + 0.5*loss_rt

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            train_losses.append(loss.item())

            # Optional adversarial training: single-step FGM on embedding grads
            if adv_train:
                # ensure grads exist
                # attack text model embeddings and phoneme embeddings
                fgm.attack_embedding('embeddings')   # will match huggingface embedding param names
                # forward with perturbed embeddings
                out_adv = model(input_ids, attention_mask, phon_ids)
                loss_abuse_adv = focal_loss_fn(out_adv['abuse_logits'], labels)
                loss_lid_adv = ce_loss_fn(out_adv['lid_logits'], lids)
                loss_rt_adv = mse_loss_fn(out_adv['recon_text_embed'], out_adv['text_embed'])
                loss_adv = loss_abuse_adv + 0.5*loss_lid_adv + 0.5*loss_rt_adv

                optimizer.zero_grad()
                loss_adv.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
                fgm.restore()

        avg_train_loss = np.mean(train_losses)
        metrics_history['train_loss'].append(avg_train_loss)

        # validation
        val_metrics = evaluate(model, valid_loader)
        metrics_history['val'].append(val_metrics)
        print(f"[{datetime.datetime.now().isoformat()}] Epoch {epoch} finished in {time.time()-t0:.1f}s; train_loss={avg_train_loss:.4f} val_f1={val_metrics['macro_f1']:.4f} acc={val_metrics['accuracy']:.4f}")

        # save checkpoint
        if epoch % save_every == 0:
            ck_name = f"epoch_{epoch}.pth"
            save_checkpoint({
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optim_state': optimizer.state_dict(),
                'sched_state': scheduler.state_dict() if scheduler else None,
                'val_metrics': val_metrics
            }, name=ck_name)

        # early stopping on macro_f1
        if val_metrics['macro_f1'] > best_val_f1 + 1e-5:
            best_val_f1 = val_metrics['macro_f1']
            best_epoch = epoch
            no_improve = 0
            # save best
            save_checkpoint({
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optim_state': optimizer.state_dict(),
                'sched_state': scheduler.state_dict() if scheduler else None,
                'val_metrics': val_metrics
            }, name='best.pth')
        else:
            no_improve += 1
            if no_improve >= early_stop_patience:
                print(f"Early stopping at epoch {epoch} (best epoch {best_epoch} f1={best_val_f1:.4f})")
                break

    save_metrics(metrics_history, fname="metrics.json")
    return metrics_history

# -------------------------
# Inference helpers
# -------------------------
def preprocess_single(text, lang):
    """
    Take raw text + language tag -> produce tokenized inputs similar to dataset
    """
    translit = transliterate_text(lang, text)
    phon = g2p_map(lang, translit)
    enc = text_tokenizer(translit, truncation=True, padding='max_length', max_length=MAX_TEXT_LEN, return_tensors='pt')
    ph_ids = torch.tensor(phoneme_tokenize(phon, phon_tok2idx), dtype=torch.long).unsqueeze(0)
    lid = lang2idx.get(lang, 0)
    return enc['input_ids'].squeeze(0), enc['attention_mask'].squeeze(0), ph_ids.squeeze(0), lid

def predict_batch(model, texts, langs):
    """
    texts: list of strings
    langs: list of lang codes corresponding to texts
    returns: list of dicts: {'text', 'lang', 'prob', 'label', 'lid_pred', 'lid_confidence'}
    """
    model.eval()
    results = []
    enc_batch = text_tokenizer([transliterate_text(l, t) for t, l in zip(texts, langs)],
                               truncation=True, padding='longest', max_length=MAX_TEXT_LEN, return_tensors='pt')
    ph_batch = []
    for t, l in zip(texts, langs):
        ph = g2p_map(l, transliterate_text(l, t))
        ph_ids = phoneme_tokenize(ph, phon_tok2idx)
        ph_batch.append(ph_ids)
    ph_batch = torch.tensor(ph_batch, dtype=torch.long)
    with torch.no_grad():
        out = model(enc_batch['input_ids'].to(DEVICE), enc_batch['attention_mask'].to(DEVICE), ph_batch.to(DEVICE))
        probs = torch.sigmoid(out['abuse_logits']).detach().cpu().numpy()
        lids_logits = out['lid_logits'].detach().cpu().numpy()
        lids_pred = lids_logits.argmax(axis=1)
        lids_conf = (torch.softmax(torch.tensor(lids_logits), dim=1).max(dim=1).values).numpy()
    for i, txt in enumerate(texts):
        results.append({
            'text': txt,
            'lang': langs[i],
            'prob': float(probs[i]),
            'label': int(probs[i] >= 0.5),
            'lid_pred': int(lids_pred[i]),
            'lid_confidence': float(lids_conf[i])
        })
    return results

def predict_single(model, text, lang):
    input_ids, attention_mask, ph_ids, lid = preprocess_single(text, lang)
    model.eval()
    with torch.no_grad():
        out = model(input_ids.unsqueeze(0).to(DEVICE),
                    attention_mask.unsqueeze(0).to(DEVICE),
                    ph_ids.unsqueeze(0).to(DEVICE))
        prob = float(torch.sigmoid(out['abuse_logits']).item())
        lid_logits = out['lid_logits'].squeeze(0).cpu().numpy()
        lid_pred = int(lid_logits.argmax())
        lid_conf = float(torch.softmax(torch.tensor(lid_logits), dim=0).max().item())
    return {'text':text, 'lang':lang, 'prob':prob, 'label': int(prob>=0.5), 'lid_pred':lid_pred, 'lid_confidence':lid_conf}

# -------------------------
# Robustness test utilities (noise on phonemes / transliteration / char swaps)
# -------------------------
import random, re
def phoneme_dropout(ph_str, drop_prob=0.1):
    toks = ph_str.strip().split()
    out = [t for t in toks if random.random() > drop_prob]
    if len(out)==0:
        out = toks[:1]
    return ' '.join(out)

def char_swap_noise(text, swap_prob=0.05):
    chars = list(text)
    for i in range(len(chars)-1):
        if random.random() < swap_prob:
            chars[i], chars[i+1] = chars[i+1], chars[i]
    return ''.join(chars)

def translit_noise(lang, text, noise_level=0.05):
    # apply char swaps to transliterated text
    t = transliterate_text(lang, text)
    return char_swap_noise(t, swap_prob=noise_level)

def robustness_evaluation(model, df, perturbation='phon_drop', n_samples=100):
    """
    Run evaluation on a subset with a particular perturbation applied and return metrics.
    perturbation: 'phon_drop', 'char_swap', 'translit_noise'
    """
    # Pick n_samples random rows
    df_sample = df.sample(min(n_samples, len(df)), random_state=SEED).reset_index(drop=True)
    texts = []
    langs = []
    labels = []
    for _, r in df_sample.iterrows():
        if perturbation == 'phon_drop':
            ph = phoneme_dropout(r['phonemes'], drop_prob=0.15)
            translit = r['translit']  # keep same transliteration
            # For prediction we'll feed translit into tokenizer but phonemes as modified
            texts.append(translit)
            langs.append(r['lang'])
            labels.append(int(r['label']))
        elif perturbation == 'char_swap':
            t_noisy = char_swap_noise(r['text'], swap_prob=0.08)
            texts.append(t_noisy)
            langs.append(r['lang'])
            labels.append(int(r['label']))
        elif perturbation == 'translit_noise':
            t_noisy = translit_noise(r['lang'], r['text'], noise_level=0.08)
            texts.append(t_noisy)
            langs.append(r['lang'])
            labels.append(int(r['label']))
        else:
            texts.append(r['text'])
            langs.append(r['lang'])
            labels.append(int(r['label']))
    preds = predict_batch(model, texts, langs)
    probs = [p['prob'] for p in preds]
    metrics = compute_metrics(labels, probs, threshold=0.5)
    return metrics

# -------------------------
# Ablation runner
# -------------------------
def run_ablation(ablation_name, disable_adapter=False, disable_phoneme=False, disable_reconstructor=False):
    """
    ablation_name: string
    toggles: disable parts to evaluate their effect
    This creates a shallow copy model and modifies components to disable them.
    """
    model_copy = deepcopy(model)
    model_copy.to(DEVICE)
    if disable_adapter:
        # replace adapter with identity
        model_copy.text_adapter = nn.Identity()
        print("Adapter disabled for ablation")
    if disable_phoneme:
        # replace phoneme encoder with zero-output module
        class DummyPhoneme(nn.Module):
            def __init__(self, out_dim=PHONEME_EMB_DIM):
                super().__init__()
                self.out_dim = out_dim
            def forward(self, x):
                b = x.size(0)
                return torch.zeros((b, self.out_dim), device=x.device)
        model_copy.phoneme_encoder = DummyPhoneme()
        print("Phoneme encoder disabled for ablation")
    if disable_reconstructor:
        model_copy.reconstructor = nn.Identity()
        print("Reconstructor disabled for ablation")

    # quick eval
    metrics = evaluate(model_copy, valid_loader)
    print(f"Ablation {ablation_name} results: {metrics}")
    return metrics

# -------------------------
# Example: run full training
# -------------------------
if __name__ == "__main__":
    # Re-create optimizer/scheduler if needed
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    total_steps = len(train_loader) * EPOCHS
    warmup_steps = int(0.06 * total_steps)
    from transformers import get_cosine_schedule_with_warmup
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

    metrics_history = train_loop(model, train_loader, valid_loader, optimizer, scheduler,
                                 epochs=EPOCHS, early_stop_patience=3,
                                 adv_train=True, adv_epsilon=1e-3, save_every=1)

    # final evaluation on validation/test
    val_metrics = evaluate(model, valid_loader)
    print("Final val metrics:", val_metrics)

    # robustness tests
    for pert in ['phon_drop', 'char_swap', 'translit_noise']:
        rm = robustness_evaluation(model, test_df, perturbation=pert, n_samples=200)
        print(f"Robustness ({pert}):", rm)

    # run ablations
    ablations = {
        'no_adapter': {'disable_adapter':True},
        'no_phoneme': {'disable_phoneme':True},
        'no_recon': {'disable_reconstructor':True}
    }
    ablation_results = {}
    for name, opts in ablations.items():
        res = run_ablation(name, **opts)
        ablation_results[name] = res
    save_metrics({'final_val': val_metrics, 'robustness': rm, 'ablations': ablation_results}, fname="final_results.json")

    # save tokenizer & phoneme vocab
    text_tokenizer.save_pretrained(CHECKPOINT_DIR)
    with open(os.path.join(CHECKPOINT_DIR, "phon_tok2idx.json"), "w") as f:
        json.dump(phon_tok2idx, f)

    print("Training + evaluation complete. Best model saved as checkpoints/best.pth")

# -------------------------
# Notes & quick tips
# -------------------------
# - If GPU memory is tight: lower MAX_TEXT_LEN or reduce BATCH_SIZE.
# - For stronger adversarial training, replace FGM with multi-step PGD (increase compute).
# - For view-agreement losses: you can add contrastive loss between text_embed and phon_embed
#   e.g., NT-Xent on pooled representations; add weights to main loss.
# - For data augmentation: consider back-translation (Helsinki models), code-mixed insertion,
#   or phoneme-level substitution to improve generalization.
# - To run evaluation on a specific checkpoint:
#     ck = torch.load('checkpoints/best.pth', map_location=DEVICE)
#     model.load_state_dict(ck['model_state'])
#     print(evaluate(model, valid_loader))
#
# - To produce per-class reports:
#     y_true, y_probs = ... collect from evaluate loop and call classification_report or compute_metrics with thresholds
#
# Hyperparameter hints:
# - Epsilon for FGM: 1e-3 or 1e-2 (too large -> unstable)
# - FocalLoss alpha/gamma tuneable per class imbalance (alpha ~ 0.25, gamma 2.0 is a good start)
# - Use gradient accumulation if batch size too small.


Epoch 1 train: 100%|██████████████████████████| 150/150 [18:44<00:00,  7.50s/it]
eval: 100%|█████████████████████████████████████| 38/38 [00:18<00:00,  2.07it/s]


[2025-10-24T02:22:02.697839] Epoch 1 finished in 1143.2s; train_loss=0.2770 val_f1=0.4940 acc=0.5783
Saved checkpoint: checkpoints/epoch_1.pth
Saved checkpoint: checkpoints/best.pth


Epoch 2 train: 100%|██████████████████████████| 150/150 [18:41<00:00,  7.48s/it]
eval: 100%|█████████████████████████████████████| 38/38 [00:18<00:00,  2.10it/s]


[2025-10-24T02:41:08.500502] Epoch 2 finished in 1139.9s; train_loss=0.0946 val_f1=0.6077 acc=0.6483
Saved checkpoint: checkpoints/epoch_2.pth
Saved checkpoint: checkpoints/best.pth


Epoch 3 train: 100%|██████████████████████████| 150/150 [18:22<00:00,  7.35s/it]
eval: 100%|█████████████████████████████████████| 38/38 [00:17<00:00,  2.14it/s]


[2025-10-24T02:59:54.835799] Epoch 3 finished in 1120.7s; train_loss=0.0595 val_f1=0.6897 acc=0.7083
Saved checkpoint: checkpoints/epoch_3.pth
Saved checkpoint: checkpoints/best.pth


Epoch 4 train: 100%|██████████████████████████| 150/150 [18:24<00:00,  7.36s/it]
eval: 100%|█████████████████████████████████████| 38/38 [00:18<00:00,  2.11it/s]


[2025-10-24T03:18:42.562466] Epoch 4 finished in 1122.2s; train_loss=0.0555 val_f1=0.6628 acc=0.6883
Saved checkpoint: checkpoints/epoch_4.pth


Epoch 5 train: 100%|██████████████████████████| 150/150 [18:23<00:00,  7.36s/it]
eval: 100%|█████████████████████████████████████| 38/38 [00:17<00:00,  2.14it/s]


[2025-10-24T03:37:26.593273] Epoch 5 finished in 1121.1s; train_loss=0.0639 val_f1=0.7303 acc=0.7417
Saved checkpoint: checkpoints/epoch_5.pth
Saved checkpoint: checkpoints/best.pth


Epoch 6 train: 100%|██████████████████████████| 150/150 [18:57<00:00,  7.58s/it]
eval: 100%|█████████████████████████████████████| 38/38 [00:18<00:00,  2.09it/s]


[2025-10-24T03:56:47.903644] Epoch 6 finished in 1155.9s; train_loss=0.0649 val_f1=0.6679 acc=0.6933
Saved checkpoint: checkpoints/epoch_6.pth


eval: 100%|█████████████████████████████████████| 38/38 [00:18<00:00,  2.10it/s]


Final val metrics: {'macro_precision': 0.7786496701044135, 'macro_recall': 0.6933333333333334, 'macro_f1': 0.6679139945374258, 'accuracy': 0.6933333333333334, 'sensitivity': 0.4166666666666667, 'specificity': 0.97, 'error_rate': 0.30666666666666664}


KeyError: 'phonemes'

In [51]:
# Save model weights
torch.save(model.state_dict(), "trident_max.pt")
print("Model weights saved as trident_max.pt")


Model weights saved as trident_max.pt


In [56]:
import pandas as pd
import torch
from tqdm.auto import tqdm
import random
import numpy as np

# -----------------------------
# Load and fix test dataset
# -----------------------------
TEST_PATH = 'testing.xlsx'
test_df = pd.read_excel(TEST_PATH)

# Standardize columns
test_df.dropna(subset=['text'], inplace=True)
test_df['lang'] = test_df['lang'].str.lower().str.strip()
test_df.rename(columns={'phoneme':'phonemes'}, inplace=True)
test_df['lang_id'] = test_df['lang'].map(lang2idx)

# Precompute translit, phonemes, gloss if not already
for df in [test_df]:
    df['translit'] = df.apply(lambda r: transliterate_text(r['lang'], str(r['text'])), axis=1)
    df['phonemes'] = df.apply(lambda r: g2p_map(r['lang'], r['translit']), axis=1)
    df['gloss'] = df['text']

# -----------------------------
# Load saved model
# -----------------------------
model = AbuseModel(text_model, text_adapter, phoneme_encoder).to(DEVICE)
model.load_state_dict(torch.load("trident_max.pt", map_location=DEVICE))
model.eval()

# -----------------------------
# Perturbation functions
# -----------------------------
def phoneme_dropout(ph_str, drop_prob=0.15):
    toks = ph_str.strip().split()
    new_toks = [t for t in toks if random.random() > drop_prob]
    return ' '.join(new_toks) if new_toks else ' '.join(toks)

def char_swap(text, swap_prob=0.1):
    text = list(text)
    for i in range(len(text)-1):
        if random.random() < swap_prob:
            text[i], text[i+1] = text[i+1], text[i]
    return ''.join(text)

def translit_noise(text, noise_prob=0.1):
    text = list(text)
    for i in range(len(text)):
        if random.random() < noise_prob:
            text[i] = random.choice('abcdefghijklmnopqrstuvwxyz')
    return ''.join(text)

# -----------------------------
# Robustness evaluation
# -----------------------------
def robustness_evaluation(model, df, perturbation='phon_drop', n_samples=None):
    model.eval()
    if n_samples:
        df_sample = df.sample(n_samples, random_state=42)
    else:
        df_sample = df

    all_labels, all_preds = [], []

    for _, r in tqdm(df_sample.iterrows(), total=len(df_sample), desc=f'Perturbation: {perturbation}'):
        text = r['text']
        translit = r['translit']
        phon = r['phonemes']

        if perturbation == 'phon_drop':
            phon_mod = phoneme_dropout(phon)
            translit_mod = translit
        elif perturbation == 'char_swap':
            translit_mod = char_swap(translit)
            phon_mod = phon
        elif perturbation == 'translit_noise':
            translit_mod = translit_noise(translit)
            phon_mod = phon
        else:
            translit_mod = translit
            phon_mod = phon

        enc = text_tokenizer(translit_mod, truncation=True, padding='max_length', max_length=128, return_tensors='pt')
        input_ids = enc['input_ids'].to(DEVICE)
        attention_mask = enc['attention_mask'].to(DEVICE)
        phon_ids = torch.tensor([phoneme_tokenize(phon_mod, phon_tok2idx)], dtype=torch.long).to(DEVICE)

        with torch.no_grad():
            out = model(input_ids=input_ids, attention_mask=attention_mask, phon_ids=phon_ids)
            abuse_prob = torch.sigmoid(out['abuse_logits']).item()
            pred_label = int(abuse_prob >= 0.5)

        all_labels.append(int(r['label']))
        all_preds.append(pred_label)

    from sklearn.metrics import accuracy_score, precision_recall_fscore_support
    acc = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='macro', zero_division=0)
    return {'accuracy': acc, 'macro_precision': precision, 'macro_recall': recall, 'macro_f1': f1}

# Run robustness evaluation
perturbations = ['phon_drop', 'char_swap', 'translit_noise']
for pert in perturbations:
    metrics = robustness_evaluation(model, test_df, perturbation=pert, n_samples=200)
    print(f"Robustness ({pert}):", metrics)

# -----------------------------
# Single-text inference
# -----------------------------
def predict_text(text, lang='en'):
    translit = transliterate_text(lang, text)
    phon = g2p_map(lang, translit)

    enc = text_tokenizer(translit, truncation=True, padding='max_length', max_length=128, return_tensors='pt')
    input_ids = enc['input_ids'].to(DEVICE)
    attention_mask = enc['attention_mask'].to(DEVICE)
    phon_ids = torch.tensor([phoneme_tokenize(phon, phon_tok2idx)], dtype=torch.long).to(DEVICE)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask, phon_ids=phon_ids)
        abuse_prob = torch.sigmoid(out['abuse_logits']).item()
        pred_label = 'Abusive' if abuse_prob >= 0.5 else 'Not Abusive'
        lid_idx = torch.argmax(out['lid_logits'], dim=1).item()
        lang_pred = list(lang2idx.keys())[lid_idx]

    return {'text': text, 'predicted_language': lang_pred, 'abuse_prediction': pred_label, 'abuse_prob': abuse_prob}

# Example usage
example = "നീ very smart!"
result = predict_text(example, lang='ml')
print(result)


Perturbation: phon_drop:   0%|          | 0/200 [00:00<?, ?it/s]

Robustness (phon_drop): {'accuracy': 0.67, 'macro_precision': 0.7662079257171281, 'macro_recall': 0.6606946251626464, 'macro_f1': 0.6296711929076422}


Perturbation: char_swap:   0%|          | 0/200 [00:00<?, ?it/s]

Robustness (char_swap): {'accuracy': 0.65, 'macro_precision': 0.7165209387942936, 'macro_recall': 0.6412771494344911, 'macro_f1': 0.612789025334661}


Perturbation: translit_noise:   0%|          | 0/200 [00:00<?, ?it/s]

Robustness (translit_noise): {'accuracy': 0.64, 'macro_precision': 0.7255411255411255, 'macro_recall': 0.6303673305975378, 'macro_f1': 0.592944369063772}
{'text': 'നീ very smart!', 'predicted_language': 'te', 'abuse_prediction': 'Not Abusive', 'abuse_prob': 0.27650994062423706}


In [57]:
print(test_df.columns)
# Index(['text', 'label', 'lang', 'translit', 'phonemes', 'gloss'], dtype='object')


Index(['text', 'label', 'lang', 'lang_id', 'translit', 'phonemes', 'gloss'], dtype='object')


In [58]:
# Initialize the same model architecture
loaded_model = AbuseModel(text_model, text_adapter, phoneme_encoder).to(DEVICE)

# Load weights
loaded_model.load_state_dict(torch.load("trident_max.pt", map_location=DEVICE))
loaded_model.eval()  # set to evaluation mode
print("trident_max model loaded successfully")


trident_max model loaded successfully


In [63]:
with open('phon_tok2idx.pkl', 'wb') as f:
    pickle.dump(phon_tok2idx, f)


In [66]:
# %%
import torch
import torch.nn as nn
import pickle
from transformers import AutoTokenizer, AutoModel

# -------------------------------
# 1️⃣ Device, tokenizer, phoneme vocab
# -------------------------------
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

TEXT_MODEL_NAME = 'bert-base-multilingual-cased'
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = AutoModel.from_pretrained(TEXT_MODEL_NAME).to(DEVICE)

with open('phon_tok2idx.pkl','rb') as f:
    phon_tok2idx = pickle.load(f)
PHON_VOCAB_SIZE = len(phon_tok2idx)
PHON_EMB_DIM = 256
MAX_TEXT_LEN = 128
MAX_PH_LEN = 128

lang2idx = {'en':0,'te':1,'ml':2}
idx2lang = {v:k for k,v in lang2idx.items()}

# -------------------------------
# 2️⃣ Define classes exactly as during training
# -------------------------------
class Adapter(nn.Module):
    def __init__(self, hidden_size=768, bottleneck=128):
        super().__init__()
        self.down = nn.Linear(hidden_size,bottleneck)
        self.act = nn.ReLU()
        self.up = nn.Linear(bottleneck,hidden_size)
    def forward(self,x):
        return self.up(self.act(self.down(x)))

class PhonemeEncoder(nn.Module):
    def __init__(self, vocab_size, emb_dim=PHON_EMB_DIM, nhead=8, nlayers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=nlayers)
    def forward(self, input_ids):
        emb = self.embedding(input_ids).transpose(0,1)
        out = self.transformer(emb).transpose(0,1)
        mask = (input_ids!=0).unsqueeze(-1).float()
        summed = (out*mask).sum(1)
        lengths = mask.sum(1).clamp(min=1.0)
        return summed/lengths

class CrossAttentionFusion(nn.Module):
    def __init__(self, text_dim=768, phon_dim=PHON_EMB_DIM, hidden_dim=512, nhead=8):
        super().__init__()
        self.proj_text = nn.Linear(text_dim, hidden_dim)
        self.proj_phon = nn.Linear(phon_dim, hidden_dim)
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads=nhead, batch_first=True)
        self.norm = nn.LayerNorm(hidden_dim)
        self.ff = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))
    def forward(self, text_feats, phon_feats):
        q = self.proj_text(text_feats)
        kv = self.proj_phon(phon_feats).unsqueeze(1)
        attn_out,_ = self.cross_attn(q,kv,kv)
        out = self.norm(attn_out + q)
        return out.squeeze(1)

class AbuseModel(nn.Module):
    def __init__(self, text_model, text_adapter, phoneme_encoder, n_langs=3, hidden_dim=512):
        super().__init__()
        self.text_model = text_model
        self.text_adapter = text_adapter
        self.phoneme_encoder = phoneme_encoder
        self.fusion = CrossAttentionFusion(hidden_dim=hidden_dim)
        self.abuse_head = nn.Linear(hidden_dim,1)
        self.lid_head = nn.Linear(hidden_dim,n_langs)
    def forward(self, input_ids, attention_mask, phon_ids):
        out = self.text_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        last = out.last_hidden_state
        mask = attention_mask.unsqueeze(-1)
        pooled = (last*mask).sum(1)/mask.sum(1).clamp(min=1.0)
        adapted = pooled + self.text_adapter(pooled)
        phon_pooled = self.phoneme_encoder(phon_ids.to(pooled.device))
        fused = self.fusion(adapted.unsqueeze(1), phon_pooled)
        abuse_logits = self.abuse_head(fused).squeeze(-1)
        lid_logits = self.lid_head(fused)
        return {'abuse_logits': abuse_logits, 'lid_logits': lid_logits}

# -------------------------------
# 3️⃣ Load model
# -------------------------------
text_adapter = Adapter(hidden_size=text_model.config.hidden_size)
phoneme_encoder = PhonemeEncoder(PHON_VOCAB_SIZE)
model = AbuseModel(text_model, text_adapter, phoneme_encoder).to(DEVICE)

checkpoint_path = 'trident_max.pt'
model.load_state_dict(torch.load(checkpoint_path,map_location=DEVICE))
model.eval()

# -------------------------------
# 4️⃣ Prediction helper
# -------------------------------
def phoneme_tokenize(ph_str):
    toks = ph_str.strip().split()
    ids = [phon_tok2idx.get(t,phon_tok2idx.get('<unk>',0)) for t in toks][:MAX_PH_LEN]
    if len(ids)<MAX_PH_LEN:
        ids += [phon_tok2idx.get('<pad>',0)]*(MAX_PH_LEN-len(ids))
    return ids

def predict_abuse_language(text):
    enc = text_tokenizer(text, truncation=True, padding='max_length', max_length=MAX_TEXT_LEN, return_tensors='pt')
    
    # dummy phonemes: split text into chars or words; replace with actual phoneme extraction if available
    phon_ids = torch.tensor([phoneme_tokenize(' '.join(list(text)))], dtype=torch.long)
    
    with torch.no_grad():
        out = model(enc['input_ids'].to(DEVICE),
                    enc['attention_mask'].to(DEVICE),
                    phon_ids.to(DEVICE))
        abuse_prob = torch.sigmoid(out['abuse_logits']).item()
        abuse_label = "Abusive" if abuse_prob>=0.5 else "Not Abusive"
        lid_pred = torch.argmax(out['lid_logits'], dim=-1).item()
        lang_pred = idx2lang.get(lid_pred,'en')
    return {'text': text,'predicted_language': lang_pred,'abuse_prediction': abuse_label,'abuse_prob': abuse_prob}

# -------------------------------
# 5️⃣ Test examples
# -------------------------------
texts = ["നീ very smart!","You are చాలా bad!","This is fine"]
for t in texts:
    print(predict_abuse_language(t))


RuntimeError: Error(s) in loading state_dict for AbuseModel:
	Unexpected key(s) in state_dict: "reconstructor.0.weight", "reconstructor.0.bias", "reconstructor.2.weight", "reconstructor.2.bias". 
	size mismatch for text_model.embeddings.word_embeddings.weight: copying a param with shape torch.Size([250002, 768]) from checkpoint, the shape in current model is torch.Size([119547, 768]).
	size mismatch for text_model.embeddings.position_embeddings.weight: copying a param with shape torch.Size([514, 768]) from checkpoint, the shape in current model is torch.Size([512, 768]).
	size mismatch for text_model.embeddings.token_type_embeddings.weight: copying a param with shape torch.Size([1, 768]) from checkpoint, the shape in current model is torch.Size([2, 768]).

In [68]:
# %%
import torch
import pickle
from transformers import AutoTokenizer, AutoModel

# -------------------------------
# 0️⃣ Device
# -------------------------------
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# -------------------------------
# 1️⃣ Load tokenizer and phoneme vocab
# -------------------------------
TEXT_MODEL_NAME = 'xlm-roberta-base'
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = AutoModel.from_pretrained(TEXT_MODEL_NAME).to(DEVICE)

with open('phon_tok2idx.pkl', 'rb') as f:
    phon_tok2idx = pickle.load(f)

PHON_VOCAB_SIZE = len(phon_tok2idx)
PHON_EMB_DIM = 256  # same as training

lang2idx = {'en':0, 'te':1, 'ml':2}  # update if you used different

# -------------------------------
# 2️⃣ Define model components
# -------------------------------
class Adapter(torch.nn.Module):
    def __init__(self, hidden_size, bottleneck=128):
        super().__init__()
        self.down = torch.nn.Linear(hidden_size, bottleneck)
        self.act = torch.nn.ReLU()
        self.up = torch.nn.Linear(bottleneck, hidden_size)
    def forward(self, x):
        return self.up(self.act(self.down(x)))

class PhonemeEncoder(torch.nn.Module):
    def __init__(self, vocab_size, emb_dim=PHON_EMB_DIM, nhead=8, nlayers=3):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead)
        self.transformer = torch.nn.TransformerEncoder(encoder_layer, num_layers=nlayers)
    def forward(self, phon_ids):
        emb = self.embedding(phon_ids).transpose(0,1)
        out = self.transformer(emb).transpose(0,1)
        mask = (phon_ids!=0).unsqueeze(-1).float()
        pooled = (out*mask).sum(1)/mask.sum(1).clamp(min=1.0)
        return pooled

class CrossAttentionFusion(torch.nn.Module):
    def __init__(self, text_dim=768, phon_dim=PHON_EMB_DIM, hidden_dim=512, nhead=8):
        super().__init__()
        self.proj_text = torch.nn.Linear(text_dim, hidden_dim)
        self.proj_phon = torch.nn.Linear(phon_dim, hidden_dim)
        self.cross_attn = torch.nn.MultiheadAttention(hidden_dim, num_heads=nhead, batch_first=True)
        self.norm = torch.nn.LayerNorm(hidden_dim)
        self.ff = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim),
                                      torch.nn.ReLU(),
                                      torch.nn.Linear(hidden_dim, hidden_dim))
    def forward(self, text_feats, phon_feats):
        q = self.proj_text(text_feats)
        kv = self.proj_phon(phon_feats).unsqueeze(1)
        attn_out, _ = self.cross_attn(q, kv, kv)
        out = self.norm(attn_out + q)
        return out.squeeze(1)

class AbuseModel(torch.nn.Module):
    def __init__(self, text_model, text_adapter, phoneme_encoder, n_langs=3, hidden_dim=512):
        super().__init__()
        self.text_model = text_model
        self.text_adapter = text_adapter
        self.phoneme_encoder = phoneme_encoder
        self.fusion = CrossAttentionFusion(hidden_dim=hidden_dim)
        self.abuse_head = torch.nn.Linear(hidden_dim, 1)
        self.lid_head = torch.nn.Linear(hidden_dim, n_langs)
        # reconstructor exists in checkpoint but optional here
        self.reconstructor = torch.nn.Sequential(torch.nn.Linear(hidden_dim, text_model.config.hidden_size),
                                                 torch.nn.ReLU(),
                                                 torch.nn.Linear(text_model.config.hidden_size, text_model.config.hidden_size))

    def forward(self, input_ids, attention_mask, phon_ids):
        out = self.text_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        pooled = out.last_hidden_state[:,0,:] + self.text_adapter(out.last_hidden_state[:,0,:])
        phon_pooled = self.phoneme_encoder(phon_ids.to(pooled.device))
        fused = self.fusion(pooled.unsqueeze(1), phon_pooled)
        abuse_logits = self.abuse_head(fused).squeeze(-1)
        lid_logits = self.lid_head(fused)
        return {'abuse_logits': abuse_logits, 'lid_logits': lid_logits}

# -------------------------------
# 3️⃣ Instantiate model
# -------------------------------
text_adapter = Adapter(hidden_size=text_model.config.hidden_size)
phoneme_encoder = PhonemeEncoder(PHON_VOCAB_SIZE)
model = AbuseModel(text_model, text_adapter, phoneme_encoder).to(DEVICE)

# Load checkpoint with strict=False to ignore extra keys like reconstructor
checkpoint = torch.load('trident_max.pt', map_location=DEVICE)
model.load_state_dict(checkpoint, strict=False)
model.eval()

# -------------------------------
# 4️⃣ Helper functions
# -------------------------------
MAX_PH_LEN = 128

def phoneme_tokenize(ph_str):
    toks = ph_str.strip().split()
    ids = [phon_tok2idx.get(t, phon_tok2idx.get('<unk>',0)) for t in toks][:MAX_PH_LEN]
    if len(ids) < MAX_PH_LEN:
        ids += [phon_tok2idx.get('<pad>',0)]*(MAX_PH_LEN-len(ids))
    return ids

def predict_abuse_language(text):
    # --- Example phoneme extraction ---
    phonemes = text.split()  # Replace with your phoneme extraction logic
    enc = text_tokenizer(text, truncation=True, padding='max_length', max_length=128, return_tensors='pt')
    ph_ids = torch.tensor([phoneme_tokenize(" ".join(phonemes))], dtype=torch.long)

    with torch.no_grad():
        out = model(enc['input_ids'].to(DEVICE),
                    enc['attention_mask'].to(DEVICE),
                    ph_ids.to(DEVICE))
        abuse_prob = torch.sigmoid(out['abuse_logits']).item()
        abuse_pred = "Abusive" if abuse_prob>0.5 else "Not Abusive"
        lid_pred = torch.argmax(out['lid_logits'], dim=-1).item()
        lang_pred = [k for k,v in lang2idx.items() if v==lid_pred][0]

    return {
        'text': text,
        'predicted_language': lang_pred,
        'abuse_prediction': abuse_pred,
        'abuse_prob': abuse_prob
    }

# -------------------------------
# 5️⃣ Test
# -------------------------------
texts = [
    "നീ very smart!", 
    "You are చాలా bad!", 
    "This is fine"
]
for t in texts:
    print(predict_abuse_language(t))


{'text': 'നീ very smart!', 'predicted_language': 'en', 'abuse_prediction': 'Abusive', 'abuse_prob': 0.5016085505485535}
{'text': 'You are చాలా bad!', 'predicted_language': 'te', 'abuse_prediction': 'Abusive', 'abuse_prob': 0.5065277218818665}
{'text': 'This is fine', 'predicted_language': 'en', 'abuse_prediction': 'Not Abusive', 'abuse_prob': 0.4728836417198181}


In [71]:
# -------------------------------
# 1️⃣ Imports
# -------------------------------
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# -------------------------------
# 2️⃣ Prediction loop on test set
# -------------------------------
true_langs = []
true_abuse = []
pred_langs = []
pred_abuse = []

for idx, row in test_df.iterrows():
    text = row['text']
    out = predict_abuse_language(text)  # your model inference function

    # True labels
    true_langs.append(row['lang'])
    true_abuse.append(int(row['label']))

    # Predicted labels
    pred_langs.append(out['predicted_language'])
    pred_abuse.append(1 if out['abuse_prediction'] == "Abusive" else 0)

# -------------------------------
# 3️⃣ Language detection metrics
# -------------------------------
print("===== Language Detection Metrics =====")
print("Accuracy:", accuracy_score(true_langs, pred_langs))
print("Macro Precision:", precision_score(true_langs, pred_langs, average='macro'))
print("Macro Recall:", recall_score(true_langs, pred_langs, average='macro'))
print("Macro F1:", f1_score(true_langs, pred_langs, average='macro'))

# -------------------------------
# 4️⃣ Abusive detection metrics
# -------------------------------
print("\n===== Abusive Detection Metrics =====")
print("Accuracy:", accuracy_score(true_abuse, pred_abuse))
print("Precision:", precision_score(true_abuse, pred_abuse))
print("Recall:", recall_score(true_abuse, pred_abuse))
print("F1-score:", f1_score(true_abuse, pred_abuse))


===== Language Detection Metrics =====
Accuracy: 0.9616666666666667
Macro Precision: 0.9634105178790984
Macro Recall: 0.9616666666666666
Macro F1: 0.9613662406752049

===== Abusive Detection Metrics =====
Accuracy: 0.73
Precision: 0.7331081081081081
Recall: 0.7233333333333334
F1-score: 0.7281879194630873
