# In [None]:
import torch
from transformers import AutoTokenizer
from transformers import T5ForConditionalGeneration

# In [None]:
class TextGenerate:
    """Wrapper around a pretrained T5 checkpoint for encoding and generating text.

    The tokenizer and the model are both loaded from the same checkpoint
    name/path via ``from_pretrained``.
    """

    def __init__(self, model):
        """Load the tokenizer and the T5 model.

        Args:
            model: Model id or local path handed to both
                ``AutoTokenizer.from_pretrained`` and
                ``T5ForConditionalGeneration.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = T5ForConditionalGeneration.from_pretrained(model)

    def encode_text(self, text_input, text_ouput, max_length=16):
        """Tokenize one (input, target) pair into a seq2seq training example.

        Both sides are truncated/padded to exactly ``max_length`` tokens. In
        ``labels`` every padding position is replaced with ``-100`` (the ignore
        index used by the Hugging Face / PyTorch cross-entropy loss), so padding
        does not contribute to the loss.

        Args:
            text_input: Source text (a single example).
            text_ouput: Target text. The misspelled name is kept for caller
                compatibility.
            max_length: Token length both sides are padded/truncated to.

        Returns:
            dict with 1-D tensors ``input_ids``, ``attention_mask`` and
            ``labels``. The leading batch dimension added by
            ``return_tensors="pt"`` is stripped with ``[0]``.
        """
        inputs = self.tokenizer(
            text_input,
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        labels = self.tokenizer(
            text_ouput,
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        ).input_ids[0]
        # Mask padding using the tokenizer's own pad id rather than a hard-coded
        # 0, as a single tensor op instead of a per-element Python loop.
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)
        return {
            "input_ids": inputs["input_ids"][0],
            "attention_mask": inputs["attention_mask"][0],
            "labels": labels,
        }

    def generate_prediction(
        self,
        ctext,
        summ_len=16,
        beam_search=10,
        repetition_penalty=2.5,
        temperature=0.7,
    ):
        """Sample an output sequence for ``ctext`` and return it as text.

        Args:
            ctext: Input text (a single string).
            summ_len: ``max_length`` passed to ``generate``.
            beam_search: Despite its name, this is used as ``top_k`` for
                sampling (``do_sample=True``); no beam search is performed.
            repetition_penalty: Repetition penalty applied during generation.
            temperature: Sampling temperature.

        Returns:
            The decoded first generated sequence with special tokens removed.
            Because sampling is enabled, repeated calls may return different
            strings.
        """
        input_ids = self.tokenizer(ctext, return_tensors="pt").input_ids

        generated_ids = self.model.generate(
            input_ids,
            do_sample=True,
            max_length=summ_len,
            top_k=beam_search,
            temperature=temperature,
            # Must be given to generate(); decode() silently ignores it.
            repetition_penalty=repetition_penalty,
        )

        # Index the single sequence instead of squeeze(), which would collapse
        # a length-1 result into a 0-d tensor.
        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# In [None]:
# Run the 'shaankhosla/digit_conversion' checkpoint on a spread of positive and
# negative integer strings, printing one prediction per input. generate_prediction
# samples (do_sample=True), so the printed text can vary between runs.
custom_model = TextGenerate('shaankhosla/digit_conversion')
for number in (
    "5", "-8", "11", "24", "-112", "-236", "-7965",
    "32043", "34986", "430895", "435641", "-43968", "-328493", "-352501",
):
    output = custom_model.generate_prediction(number)
    print(output)