In [None]:
import gradio as gr
import transformers
import torch
import os

# Run on the GPU when CUDA is available; otherwise fall back to the CPU so the
# script still starts on machines without a CUDA device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Checkpoint directories to serve. The last path component of each directory
# is used as the key under which its model and tokenizer are stored.
MODEL_DIRS = [
    './trained_model/epoch_0_batch_17619/',
    './trained_model/epoch_0_batch_52986/',
    './trained_model/epoch_0_batch_120933/',
]

# Load models and tokenizers from directories
models = {}      # checkpoint name -> GPT2LMHeadModel moved to DEVICE
tokenizers = {}  # checkpoint name -> tokenizer loaded from the same directory
for model_dir in MODEL_DIRS:
    # normpath drops the trailing slash so basename returns the directory
    # name rather than an empty string.
    model_name = os.path.basename(os.path.normpath(model_dir))
    models[model_name] = transformers.GPT2LMHeadModel.from_pretrained(model_dir).to(DEVICE)
    tokenizers[model_name] = transformers.AutoTokenizer.from_pretrained(model_dir)

def _truncate_at_eos(text, eos_token):
    """Return *text* up to, but not including, the first *eos_token*."""
    if not eos_token:
        return text
    # partition() yields the whole string as the head when the separator is
    # absent, unlike slicing with find(), whose -1 would drop the last char.
    return text.partition(eos_token)[0]


def generate_text(prompt, max_length=20, selected_models=None):
    """Generate a sampled continuation of *prompt* with each selected model.

    Args:
        prompt: Text used as the start of every generated sequence.
        max_length: Number of tokens allowed on top of the prompt's own
            token count. Coerced to ``int`` before use.
        selected_models: Iterable of keys into the module-level ``models``
            and ``tokenizers`` dicts. ``None`` or an empty iterable selects
            no model.

    Returns:
        A dict holding the prompt under ``"original"`` plus one entry per
        selected model, mapping the model's name to its decoded output
        (prompt included) cut at the first end-of-sequence token, if any.

    Raises:
        KeyError: If a selected name is not present in ``models``.
    """
    outputs = {"original": prompt}
    # UI widgets may deliver the length as a float; generate() needs an int.
    extra_tokens = int(max_length)

    for model_name in selected_models or ():
        model = models[model_name].eval()  # Set the model to evaluation mode
        tokenizer = tokenizers[model_name]
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(DEVICE)
        # Single unpadded sequence, so every position is attended to.
        attention_mask = torch.ones(input_ids.shape, device=DEVICE)
        output_sequences = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # max_length counts the prompt tokens too, hence the offset.
            max_length=extra_tokens + input_ids.shape[1],
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.0,
            do_sample=True,
            num_return_sequences=1,
        )

        generated_sequence = output_sequences[0].tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
        outputs[model_name] = _truncate_at_eos(text, tokenizer.eos_token)

    # Return the generated texts as a dictionary
    return outputs


# Checkbox labels: the final path component of every checkpoint directory.
checkbox_choices = [
    os.path.basename(os.path.normpath(checkpoint_dir))
    for checkpoint_dir in MODEL_DIRS
]

# Input widgets, listed in the positional order generate_text receives them:
# prompt, max_length, selected_models.
prompt_input = gr.Textbox(lines=2, placeholder="Enter your prompt here...")
length_input = gr.Slider(minimum=10, maximum=300, value=50)
# Every checkpoint is ticked by default.
model_picker = gr.CheckboxGroup(
    choices=checkbox_choices,
    label="Select Models",
    value=checkbox_choices,
)

# Define Gradio interface
iface = gr.Interface(
    fn=generate_text,
    inputs=[prompt_input, length_input, model_picker],
    outputs=gr.JSON(label="Generated Texts"),
    title="GPT-2 Text Generation | Multi-Model Support",
    description="This interface generates text based on the input prompt using selected models. Each model is fine-tuned from GPT-2.",
)

# Launch the interface; share=True also requests a public share link.
iface.launch(share=True)