In [1]:
from transformers import T5ForConditionalGeneration, AutoTokenizer

# Smoke test for the multitask Sanskrit checkpoint: load it once, then push a
# handful of task-prefixed sentences through it and print what comes back.
# The task is selected by the first token of the input text:
#   S = word segmentation, L = lemmatization, LM = lemmatization + morphosyntax.
model_name = "chronbmm/sanskrit5-multitask"
print(f"Loading model: {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
print("Model loaded successfully!\n")

print(f"Vocab size: {len(tokenizer)}")
print(f"Tokenizer class: {tokenizer.__class__.__name__}\n")

SEPARATOR = "=" * 70
# Upper bound on generated sequence length, shared by every call below.
MAX_LENGTH = 512


def print_banner(title):
    """Print *title* framed by two separator rules."""
    print(SEPARATOR)
    print(title)
    print(SEPARATOR)


def run_task(input_text, num_beams=None):
    """Run one task-prefixed input through the model and return the decoded text.

    Args:
        input_text: Full model input, i.e. task prefix, a space, then the text.
        num_beams: Beam width forwarded to ``model.generate``. When ``None``
            the argument is not passed at all, so the model's own generation
            defaults apply.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    generate_kwargs = {"max_length": MAX_LENGTH}
    if num_beams is not None:
        generate_kwargs["num_beams"] = num_beams

    encoded = tokenizer(input_text, return_tensors="pt")
    generated = model.generate(**encoded, **generate_kwargs)
    return tokenizer.decode(generated[0], skip_special_tokens=True)


# (banner title, model input, beam width) for the individually titled tests.
# Beam widths are kept exactly as they were when each test was written inline:
# Test 1 passes none, Tests 2-4 use 5.
# NOTE(review): the differing beam widths look accidental; unify them if the
# tests are meant to be comparable.
titled_tests = [
    (
        "TEST 1: Word Segmentation (S)",
        "S yajñopavītaprācīnāvītayor adhvaryum anuvidadhīta",
        None,
    ),
    (
        "TEST 2: Lemmatization (L)",
        "L agnaye vaiśvānarāya dvādaśakapālaḥ",
        5,
    ),
    (
        "TEST 3: Lemmatization + Morphosyntax (LM)",
        "LM somam indrābhaspati pibataṃ dāśuṣo gṛhe",
        5,
    ),
    # NOTE(review): in the recorded run this case returned only underscores,
    # while the romanized inputs returned sensible analyses. Presumably the
    # checkpoint expects romanized (IAST) text; confirm against the model
    # card. The string also starts with an extra "स" compared with the Test 1
    # sentence, possibly a typo; it is left unchanged here.
    (
        "TEST 4: Devanagari Input (S)",
        "S सयज्ञोपवीतप्राचीनावीतयोरध्वर्युमनुविदधीत",
        5,
    ),
]

for title, input_text, beams in titled_tests:
    print_banner(title)
    print(f"Input: {input_text}\n")
    result = run_task(input_text, num_beams=beams)
    print(f"Output: {result}\n")

# Test 5: additional (prefix, text) pairs, printed in a more compact layout.
print_banner("TEST 5: More Examples")

test_cases = [
    ("S", "agnau devatāḥ pratyakṣaṃ śrāvayati"),
    ("L", "somam indrābhaspati"),
    ("LM", "agnau devatāḥ"),
]

for prefix, text in test_cases:
    input_text = f"{prefix} {text}"
    print(f"\nInput: {input_text}")

    # This loop has always used a beam width of 4 (not 5 as in Tests 2-4).
    result = run_task(input_text, num_beams=4)
    print(f"Output: {result}")

print("\n" + SEPARATOR)
print("Testing complete!")
print(SEPARATOR)
print("\nNote: This model should properly handle:")
print("  S  = Word Segmentation")
print("  L  = Lemmatization")
print("  LM = Lemmatization + Morphosyntax Tagging")

Loading model: chronbmm/sanskrit5-multitask...
Model loaded successfully!

Vocab size: 384
Tokenizer class: ByT5Tokenizer

TEST 1: Word Segmentation (S)
Input: S yajñopavītaprācīnāvītayor adhvaryum anuvidadhīta

Output: yajñopavīta_prācīnāvītayoḥ_adhvaryum_anuvidadhīta_

TEST 2: Lemmatization (L)
Input: L agnaye vaiśvānarāya dvādaśakapālaḥ

Output: agni_vaiśvānara_dvādaśan_kapāla_

TEST 3: Lemmatization + Morphosyntax (LM)
Input: LM somam indrābhaspati pibataṃ dāśuṣo gṛhe

Output: soma_SAM indrābhaspati_SVM pā_DuPr2Im dāś_SGPaPsM gṛha_SLNe

TEST 4: Devanagari Input (S)
Input: S सयज्ञोपवीतप्राचीनावीतयोरध्वर्युमनुविदधीत

Output: _____________

TEST 5: More Examples

Input: S agnau devatāḥ pratyakṣaṃ śrāvayati
Output: agnau_devatāḥ_pratyakṣam_śrāvayati_

Input: L somam indrābhaspati
Output: soma_indrābhaspati_

Input: LM agnau devatāḥ
Output: agni_SLM devatā_PNF

Testing complete!

Note: This model should properly handle:
  S  = Word Segmentation
  L  = Lemmatization
  LM = Lemmatization 