# 🔄 Fine-Tune and Serve a BERT LLM with FFT, LoRA, and QLoRA with Union.ai: A Hands-On Tutorial

<a target="_blank" href="https://colab.research.google.com/github/unionai-oss/bert-llm-classification-pipeline/blob/main/tutorial.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

Welcome to this step-by-step tutorial on building a **Large Language Model (LLM) fine-tuning pipeline** using **Hugging Face Transformers**, **PEFT**, and **Union.ai’s AI workflow and inference platform**. In this tutorial, you’ll fine-tune a **BERT-based model for text classification**, serve it for inference, and track every step of your pipeline using **Union’s MLOps capabilities**.



## 📝 What You'll Build  

By the end of this tutorial, you'll have a **fully functional AI pipeline** that:  

1. 📥 **Downloads and processes a dataset**   
2. 🏋️‍♂️ **Fine-tunes a BERT model for classification with full fine-tuning (FFT), LoRA, or QLoRA**
3. 💾 **Saves and versions the trained model**
4. 📊 **Evaluates the model on a test set**
5. 🚀 **Deploys the model for real-time inference**
6. 📈 **Tracks all artifacts and experiments** using Union.ai



## 🧰 Setup 


To get started, sign up for a **Union Serverless** account at [Union.ai](https://union.ai) by clicking the **"Get Started"** button. No credit card is required, and you'll receive **$30 in free credits** to begin experimenting. The signup process takes just a few minutes.  

Alternatively, if you have access to a **[Union BYOC Enterprise](https://www.union.ai/pricing)** account, you can log into your account.  

### 📦 Install Python Packages & Clone Repo

To clone the repo, run the following command in your environment: `git clone https://github.com/unionai-oss/bert-llm-classification-pipeline`

Then install the packages listed in [requirements.txt](requirements.txt) into your local environment with your preferred package manager, for example `pip install -r requirements.txt`.

If you're running this notebook in a Google Colab environment, you can install the packages and clone the GitHub repo directly in the notebook by running the following cell:


In [None]:
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    !git clone https://github.com/unionai-oss/bert-llm-classification-pipeline
    %cd bert-llm-classification-pipeline
    !pip install -r requirements.txt

### 🔐 Authenticate
To use **Union.ai**, you'll need to authenticate your account. Follow the appropriate step based on your setup:  

##### 🔸 **Using Union BYOC Enterprise**  

If you're using a **[Union BYOC Enterprise](https://www.union.ai/pricing)** account, log in with the following command:  
```bash
union create login --host <union-host-url>
```

Replace `<union-host-url>` with your organization's Union instance URL.

##### 🔸 Using Union Serverless
If you're using [Union Serverless](https://www.union.ai/), create a free account at [Union.ai](https://union.ai) if you don't have one yet, then authenticate by running the command below:
 

In [None]:
# 🌟 Authenticate to union serverless
!union create login --serverless --auth device-flow

## 🔀 BERT Fine-Tuning Pipeline  
We’ll create an **end-to-end machine learning pipeline** to train a **BERT model for text classification** using the **IMDB Review Dataset**.

- Run the command below to fine-tune the BERT model using the Union.ai CLI. This command will create a new pipeline and start the training process.

- The first time you run this command, it will take a while to download the model and set up the environment.

- The subsequent runs will be faster as the container, model, and data will be cached.

In [None]:
# 👇 Run this command to start the fine-tuning workflow using lora, qlora or full
!union run --remote workflows/train_pipeline.py train_pipeline --epochs 3 --tuning_method full 
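
To compare the three tuning methods, the same command can be issued once per method. The sketch below only assembles the command strings from the flags used in the cell above (`--epochs` and `--tuning_method`); the helper name `build_run_command` is hypothetical and nothing is executed.

```python
import shlex

# The three tuning methods accepted by the train_pipeline workflow.
TUNING_METHODS = ("full", "lora", "qlora")


def build_run_command(tuning_method: str, epochs: int = 3) -> str:
    """Return the `union run` command line for one tuning method (hypothetical helper)."""
    if tuning_method not in TUNING_METHODS:
        raise ValueError(f"tuning_method must be one of {TUNING_METHODS}")
    args = [
        "union", "run", "--remote",
        "workflows/train_pipeline.py", "train_pipeline",
        "--epochs", str(epochs),
        "--tuning_method", tuning_method,
    ]
    # shlex.join quotes arguments only when the shell would need it.
    return shlex.join(args)


for method in TUNING_METHODS:
    print(build_run_command(method))
```

In a notebook, each printed line can be pasted into a cell with a leading `!`.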

### 🔎 Explore the Code  

- The command above is using files from the [`workflows/`](workflows/train_pipeline.py) and [`tasks`](tasks/) folders that got cloned on setup.

- The code is included in this notebook for reference; the `%%writefile` magic command overwrites the files if you want to make changes.

- You do not need to run the code cells with `%%writefile` unless you want to make changes to the pipeline or tasks.


In [None]:
%%writefile workflows/train_pipeline.py

"""
This file contains the train_pipeline workflow that orchestrates the
training pipeline for BERT classification models
"""

from union import workflow

from tasks.data import download_dataset, visualize_data
from tasks.inference import predict_batch_sentiment
from tasks.model import download_model, evaluate_model, train_model


# ---------------------------
# train pipeline
# ---------------------------
@workflow
def train_pipeline(
    tuning_method: str = "lora",  # options: "full", "lora", "qlora"
    model_name: str = "distilbert-base-uncased",
    epochs: int = 3,
    extra_test_text: list[str] = [
        "This is a great movie!",
        "This is a bad movie!",
    ],
) -> None:

    train_dataset, val_dataset, test_dataset = download_dataset()
    saved_model_dir = download_model(model_name=model_name)

    visualize_data(
        train_dataset=train_dataset, val_dataset=val_dataset, test_dataset=test_dataset
    )

    trained_model_dir = train_model(
        tuning_method=tuning_method,
        model_dir=saved_model_dir,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        epochs=epochs,
    )

    evaluate_model(trained_model_dir=trained_model_dir, test_dataset=test_dataset)

    # Perform batch inference
    predict_batch_sentiment(trained_model_dir=trained_model_dir, texts=extra_test_text)

# Run model training pipeline:
#!union run --remote workflows/train_pipeline.py train_pipeline

> **💡 Note:**  
> In more complex ML workflows, **data pipelines** are often separate from **model training pipelines**.  
> For simplicity, we'll combine them into a single workflow in this example.  
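
One way to see where a data pipeline could be split from the training pipeline is to write out which task consumes which output. The sketch below is only an illustration using Python's standard `graphlib`: the dependency table is read off the calls in `train_pipeline`, and Union builds the real execution graph itself.

```python
from graphlib import TopologicalSorter

# Each task maps to the set of tasks whose outputs it consumes,
# read off the calls inside train_pipeline.
dependencies = {
    "download_dataset": set(),
    "download_model": set(),
    "visualize_data": {"download_dataset"},
    "train_model": {"download_dataset", "download_model"},
    "evaluate_model": {"train_model", "download_dataset"},
    "predict_batch_sentiment": {"train_model"},
}

# static_order() yields one valid execution order for the graph.
execution_order = list(TopologicalSorter(dependencies).static_order())
print(execution_order)
```

Tasks with no path between them, such as `visualize_data` and `train_model`, do not constrain each other's order.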


In [None]:
%%writefile containers.py
"""
This file contains the container image specification for the BERT classification pipeline
"""

from flytekit import ImageSpec, Resources
from union.actor import ActorEnvironment

container_image = ImageSpec(
    name="fine-tune-qlora",
    requirements="requirements.txt",
    pip_extra_index_url=["https://download.pytorch.org/whl/cu118"],  # enables +cu118 builds
    builder="union",
    cuda="11.8",  # ensure GPU + CUDA layer is available
    apt_packages=["gcc", "g++"],  # optional, for packages like bitsandbytes
)

# We can also define a reusable, stateful container environment.
# See this in action near the end of this notebook for faster batch inference!
actor = ActorEnvironment(
    name="my-actor",
    container_image=container_image,
    replica_count=1,
    ttl_seconds=360,
    requests=Resources(
        cpu="2",
        mem="5000Mi",
        gpu="1",
    ),
)



In [None]:
%%writefile requirements.txt

# This file contains the requirements for the BERT classification pipeline

torch==2.5.1+cu118
transformers==4.48.2
datasets>=2.14.0
peft==0.14.0
bitsandbytes==0.45.3
accelerate==1.3.0
flytekit==1.15.0
union==0.1.151
python-dotenv==1.0.1
matplotlib==3.10.0
pandas==2.2.2
scikit-learn==1.6.1
seaborn==0.13.2


In [None]:
%%writefile tasks/data.py

"""
This module contains tasks for downloading the dataset and visualizing the data.
"""

from pathlib import Path

from datasets import Dataset
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from typing_extensions import Annotated
from union import Artifact, Deck, Resources, current_context, task

from containers import container_image

# Define Artifact Specifications
RawImdbDataset = Artifact(name="raw_imdb_dataset")
TrainImdbDataset = Artifact(name="train_imdb_dataset")
ValImdbDataset = Artifact(name="val_imdb_dataset")
TestImdbDataset = Artifact(name="test_imdb_dataset")


# ---------------------------
# download dataset
# ---------------------------
@task(
    container_image=container_image,
    cache=True,
    cache_version="1",
    requests=Resources(cpu="2", mem="2Gi"),
)
def download_dataset() -> tuple[
    Annotated[FlyteFile, TrainImdbDataset],
    Annotated[FlyteFile, ValImdbDataset],
    Annotated[FlyteFile, TestImdbDataset],
]:

    import pandas as pd
    from datasets import load_dataset
    from sklearn.model_selection import train_test_split

    # Load IMDB dataset
    dataset = load_dataset("imdb")
    train_df = dataset["train"].to_pandas()
    test_df = dataset["test"].to_pandas()

    # Split training set into train and validation sets
    train_df, val_df = train_test_split(
        train_df, test_size=0.2, stratify=train_df["label"], random_state=42
    )

    working_dir = Path(current_context().working_directory)
    data_dir = working_dir / "data"
    data_dir.mkdir(parents=True, exist_ok=True)

    # Save datasets as CSV files
    train_path = data_dir / "train.csv"
    val_path = data_dir / "val.csv"
    test_path = data_dir / "test.csv"

    train_df.to_csv(train_path, index=False)
    val_df.to_csv(val_path, index=False)
    test_df.to_csv(test_path, index=False)

    return (
        TrainImdbDataset.create_from(train_path),
        ValImdbDataset.create_from(val_path),
        TestImdbDataset.create_from(test_path),
    )


# ---------------------------
# visualize data
# ---------------------------
@task(
    container_image=container_image,
    enable_deck=True,
    requests=Resources(cpu="2", mem="2Gi"),
)
def visualize_data(
    train_dataset: FlyteFile, val_dataset: FlyteFile, test_dataset: FlyteFile
):
    import base64
    from textwrap import dedent

    import matplotlib.pyplot as plt
    import pandas as pd

    ctx = current_context()

    # Load datasets from CSV files
    train_df = pd.read_csv(train_dataset.download())
    val_df = pd.read_csv(val_dataset.download())
    test_df = pd.read_csv(test_dataset.download())

    # Create the deck for visualization
    deck = Deck("Dataset Analysis")

    # Sample reviews from the datasets
    train_positive_review = train_df[train_df["label"] == 1].iloc[0]["text"]
    train_negative_review = train_df[train_df["label"] == 0].iloc[0]["text"]
    val_positive_review = val_df[val_df["label"] == 1].iloc[0]["text"]
    val_negative_review = val_df[val_df["label"] == 0].iloc[0]["text"]
    test_positive_review = test_df[test_df["label"] == 1].iloc[0]["text"]
    test_negative_review = test_df[test_df["label"] == 0].iloc[0]["text"]

    # Visualization helper
    def plot_label_distribution(df, title, color, output_path):
        plt.figure(figsize=(10, 5))
        df["label"].value_counts().plot(kind="bar", color=color)
        plt.title(title)
        plt.xlabel("Label")
        plt.ylabel("Count")
        plt.tight_layout()
        plt.savefig(output_path)
        plt.close()

    # Plot label distributions
    plot_label_distribution(
        train_df,
        "Train Data Label Distribution",
        "skyblue",
        "/tmp/train_label_distribution.png",
    )
    plot_label_distribution(
        val_df,
        "Validation Data Label Distribution",
        "orange",
        "/tmp/val_label_distribution.png",
    )
    plot_label_distribution(
        test_df,
        "Test Data Label Distribution",
        "lightgreen",
        "/tmp/test_label_distribution.png",
    )

    # Convert images to base64 for embedding
    def image_to_base64(image_path):
        with open(image_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode("utf-8")

    train_image_base64 = image_to_base64("/tmp/train_label_distribution.png")
    val_image_base64 = image_to_base64("/tmp/val_label_distribution.png")
    test_image_base64 = image_to_base64("/tmp/test_label_distribution.png")

    # Create HTML report
    html_report = dedent(
        f"""
    <div style="font-family: Arial, sans-serif; line-height: 1.6;">
        <h2 style="color: #2C3E50;">Dataset Analysis</h2>

        <h3 style="color: #2980B9;">Training Data Summary</h3>
        <img src="data:image/png;base64,{train_image_base64}" alt="Train Data Label Distribution" width="600">
        Shape: {train_df.shape} <br>
        Label Distribution: {train_df['label'].value_counts()} <br>
        <p><strong>Positive Review:</strong> {train_positive_review}</p>
        <p><strong>Negative Review:</strong> {train_negative_review}</p>

        <h3 style="color: #2980B9;">Validation Data Summary</h3>
        <img src="data:image/png;base64,{val_image_base64}" alt="Validation Data Label Distribution" width="600">
        Shape: {val_df.shape} <br>
        Label Distribution: {val_df['label'].value_counts()} <br>
        <p><strong>Positive Review:</strong> {val_positive_review}</p>
        <p><strong>Negative Review:</strong> {val_negative_review}</p>

        <h3 style="color: #2980B9;">Test Data Summary</h3>
        <img src="data:image/png;base64,{test_image_base64}" alt="Test Data Label Distribution" width="600">
        Shape: {test_df.shape} <br>
        Label Distribution: {test_df['label'].value_counts()} <br>
        <p><strong>Positive Review:</strong> {test_positive_review}</p>
        <p><strong>Negative Review:</strong> {test_negative_review}</p>
    </div>
    """
    )

    # Append HTML content to the deck
    deck.append(html_report)

    # Insert the deck into the context
    ctx.decks.insert(0, deck)
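
The `download_dataset` task holds out 20% of the training data for validation and passes `stratify=train_df["label"]`, so both splits keep the same label balance. The following is a minimal standard-library sketch of that idea on toy labels; the function `stratified_split` is hypothetical and is not how scikit-learn implements `train_test_split`.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, test_size=0.2, seed=42):
    """Return (train_indices, val_indices) keeping each label's share roughly equal."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for index, label in enumerate(labels):
        by_label[label].append(index)

    train_idx, val_idx = [], []
    for indices in by_label.values():
        rng.shuffle(indices)
        n_val = round(len(indices) * test_size)  # per-label validation count
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)


# Toy labels: 50 negative (0) and 50 positive (1) reviews.
labels = [0] * 50 + [1] * 50
train_idx, val_idx = stratified_split(labels)
print(len(train_idx), len(val_idx))
print(Counter(labels[i] for i in val_idx))
```

Splitting per label, rather than over the whole list at once, is what keeps the validation set balanced.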


In [None]:
%%writefile tasks/model.py
"""
This file contains the tasks that are used to download, train and evaluate the model.
"""

from pathlib import Path
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from typing_extensions import Annotated
from union import Artifact, Deck, Resources, current_context, task
from containers import container_image

# Define Artifact Specifications
FineTunedImdbModel = Artifact(name="fine_tuned_Imdb_model")

# ---------------------------
# download model
# ---------------------------
@task(
    container_image=container_image,
    cache=True,
    cache_version="1",
    requests=Resources(cpu="2", mem="2Gi"),
)
def download_model(model_name: str) -> FlyteDirectory:
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    working_dir = Path(current_context().working_directory)
    saved_model_dir = working_dir / "saved_model"
    saved_model_dir.mkdir(parents=True, exist_ok=True)


    # update AutoModelForSequenceClassification to "AutoModelForCausalLM" for causal models
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        device_map="cpu",
        torch_dtype="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model.save_pretrained(saved_model_dir)
    tokenizer.save_pretrained(saved_model_dir)

    return FlyteDirectory(saved_model_dir)


# ---------------------------
# full/lora/qlora fine-tune model
# ---------------------------
@task(
    container_image=container_image,
    requests=Resources(cpu="4", mem="12Gi", gpu="1"),
)
def train_model(
    model_dir: FlyteDirectory,
    train_dataset: FlyteFile,
    val_dataset: FlyteFile,
    epochs: int = 3,
    tuning_method: str = "full",  # options: "full", "lora", "qlora"
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.1,
) -> Annotated[FlyteDirectory, FineTunedImdbModel]:
    import pandas as pd
    import torch
    from datasets import Dataset
    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        Trainer,
        TrainingArguments,
    )

    # Load datasets
    #------------------------------------
    local_model_dir = model_dir.download()
    train_df = pd.read_csv(train_dataset.download()).sample(n=500, random_state=42)
    val_df = pd.read_csv(val_dataset.download()).sample(n=100, random_state=42)

    train_dataset_hf = Dataset.from_pandas(train_df)
    val_dataset_hf = Dataset.from_pandas(val_df)

    tokenizer = AutoTokenizer.from_pretrained(local_model_dir)

    def tokenizer_function(example):
        return tokenizer(example["text"], padding="max_length", truncation=True)

    tokenized_train = train_dataset_hf.map(tokenizer_function)
    tokenized_val = val_dataset_hf.map(tokenizer_function)

    # Load & Setup Model
    #------------------------------------
    # qlora will load the model in 4-bit quantization
    if tuning_method == "qlora":
        from transformers import BitsAndBytesConfig
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            llm_int8_skip_modules=["classifier", "pre_classifier"],
        )
        model = AutoModelForSequenceClassification.from_pretrained(
            local_model_dir,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            # device_map="auto", # use this for models if implemented
        )

    else:
        # Load the model normally
        model = AutoModelForSequenceClassification.from_pretrained(local_model_dir)

    # if lora or qlora, set the LoRA config
    if tuning_method in {"lora", "qlora"}:
        from peft import get_peft_model, LoraConfig, TaskType
        lora_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=["q_lin", "k_lin", "v_lin"], # query, Key, Value linear layers in this model
        )
 
        model = get_peft_model(model, lora_config)

    
    # Model fine-tuning
    #------------------------------------
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=epochs,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        eval_strategy="epoch",  # `evaluation_strategy` is deprecated in recent transformers releases
        save_strategy="epoch",
        logging_dir="./logs",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
    )

    trainer.train()

    # Save the fine-tuned model
    #------------------------------------
    # Merge LoRA weights into base model (you could also just save adapter weights)
    if tuning_method in {"lora", "qlora"}:
        model = model.merge_and_unload()

    output_dir = Path(current_context().working_directory) / "trained_model"
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # TODO: Save training type (lora, qlora, full) as artifacts
    return FineTunedImdbModel.create_from(output_dir)



# ---------------------------
# evaluate model
# ---------------------------
@task(
    container_image=container_image,
    enable_deck=True,
    requests=Resources(cpu="2", mem="12Gi", gpu="1"),
)
def evaluate_model(trained_model_dir: FlyteDirectory, test_dataset: FlyteFile) -> dict:
    import numpy as np
    import pandas as pd
    from datasets import Dataset
    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
    from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
    from sklearn.metrics import confusion_matrix, roc_curve, auc
    import seaborn as sns
    import matplotlib.pyplot as plt
    import base64
    from union import current_context
    from union import Deck
    from textwrap import dedent

    # Download model locally
    local_model_dir = trained_model_dir.download()
    ctx = current_context()

    # Load model and tokenizer
    model = AutoModelForSequenceClassification.from_pretrained(
        local_model_dir,
        torch_dtype="auto",
        load_in_4bit=False,  # Important: for evaluation, avoid loading in quantized 4-bit unless you really want to
    )
    tokenizer = AutoTokenizer.from_pretrained(local_model_dir)

    # Load and prepare the test dataset
    test_df = pd.read_csv(test_dataset.download()).sample(n=100, random_state=42)

    # Use a pipeline for evaluation (bypasses Trainer and works for quantized models)
    nlp_pipeline = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        # device=0 if torch.cuda.is_available() else -1,  # auto-select device
        truncation=True,
        padding=True,
    )

    # Perform batch inference
    predictions = nlp_pipeline(test_df["text"].tolist(), batch_size=8)

    # Extract predicted labels
    pred_labels = [int(p["label"].split("_")[-1]) if "label" in p else 0 for p in predictions]
    true_labels = test_df["label"].tolist()

    # Calculate metrics
    metrics = {
        "accuracy": accuracy_score(true_labels, pred_labels),
        "f1": f1_score(true_labels, pred_labels, average="weighted"),
        "precision": precision_score(true_labels, pred_labels, average="weighted"),
        "recall": recall_score(true_labels, pred_labels, average="weighted"),
        # "conf_matrix": confusion_matrix(true_labels, pred_labels)
    }

    # create visualization deck
    deck = Deck("Model Evaluation")

    # Generate Confusion Matrix
    cm = confusion_matrix(true_labels, pred_labels)
    cm_path = "/tmp/confusion_matrix.png"
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=sorted(set(true_labels)), yticklabels=sorted(set(true_labels)))
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.savefig(cm_path)
    plt.close()
    
    # Convert images to base64 for embedding
    def image_to_base64(image_path):
        with open(image_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode("utf-8")
        
    cm_image_base64 = image_to_base64(cm_path)

    # Create HTML report
    html_report = dedent(
        f"""
    <div style="font-family: Arial, sans-serif; line-height: 1.6;">
        <h2 style="color: #2C3E50;">Model Evaluation</h2>

        <h3 style="color: #2980B9;">Confusion Matrix</h3>
        <img src="data:image/png;base64,{cm_image_base64}" alt="Confusion Matrix" width="600">
        <h3 style="color: #2980B9;">Model Metrics</h3>
        <pre>{metrics}</pre>
        
    </div>
        """)

    # Append HTML content to the deck
    deck.append(html_report)
    # Insert the deck into the context
    ctx.decks.insert(0, deck)

    return metrics
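
To get a feel for why LoRA is cheaper than full fine-tuning, it helps to count the parameters the adapters add. The arithmetic below is a rough sketch: the hidden size (768) and layer count (6) are assumptions about `distilbert-base-uncased` that are not stated in this tutorial, and any classification-head parameters that remain trainable are not counted.

```python
def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


# Assumed DistilBERT-base dimensions (not stated in this tutorial).
HIDDEN_SIZE = 768
NUM_LAYERS = 6
TARGET_MODULES = ("q_lin", "k_lin", "v_lin")  # as in the LoraConfig in train_model

lora_r, lora_alpha = 8, 16  # default arguments of train_model

per_module = lora_trainable_params(HIDDEN_SIZE, HIDDEN_SIZE, lora_r)
adapter_total = per_module * len(TARGET_MODULES) * NUM_LAYERS
scaling = lora_alpha / lora_r  # factor applied to the low-rank update

print(per_module, adapter_total, scaling)
```

Under these assumptions the adapters add a few hundred thousand parameters, a small fraction of the full model's weights that FFT would update.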


## 🚀 Serving the Fine-Tuned BERT Model

### Live App Serving (Beta)

Union.ai provides a **simple way to serve your models as a live app**, making it easy to interact with your trained model.  

In this example, we'll deploy the model using **Streamlit**, which provides a **simple web interface** for running predictions.  


📂 Check out the following files for the model-serving code:  
- [`app.py`](app.py) – Handles **loading the model** and serving it via Union.ai.
- [`main.py`](main.py) – Defines the **Streamlit-based UI** for interacting with the model.  

Deploy the model by running the following command:

In [None]:
# 👇 Run this command to serve the model & streamlit application
!union deploy apps app.py bert-sentiment-analysis

Just like the training pipeline, the code is included in this notebook for reference, with the `%%writefile` magic command overwriting the files if you want to make changes directly in the notebook. Running the cells below is not required, since `app.py` and `main.py` are already in the cloned repo.

In [None]:
%%writefile app.py
"""A Union app that uses hugging face and Streamlit"""

import os

from union import Artifact, ImageSpec, Resources
from union.app import App, Input, ScalingMetric
from datetime import timedelta
from flytekit.extras.accelerators import L4, GPUAccelerator


# Define the artifact that holds the BERT model.
FineTunedImdbModel = Artifact(name="fine_tuned_Imdb_model")

# Define the container image including the required packages.
# ---------------------------------------
image_spec = ImageSpec(
    name="union-serve-bert-sentiment-analysis",
    packages=[
        "transformers==4.48.3",
        "union-runtime>=0.1.11",
        "accelerate==1.5.2",
        "streamlit==1.43.2",
        "bitsandbytes==0.45.3"
    ],
    builder="union",
    registry=os.getenv("REGISTRY"),
)

# Create the Union Serving App.
# ---------------------------------------
streamlit_app = App(
    name="bert-sentiment-analysis",
    inputs=[
        Input(
            name="bert_model",
            value=FineTunedImdbModel.query(),
            download=True,  # The model artifact is downloaded when the container starts.
        )
    ],
    container_image=image_spec,
    limits=Resources(cpu="2", mem="24Gi", gpu="1", ephemeral_storage="20Gi"),
    requests=Resources(cpu="2", mem="24Gi", gpu="1", ephemeral_storage="20Gi"),
    accelerator=L4,
    port=8082,
    include=["./main.py"],  # Include your Streamlit code.
    args=["streamlit", "run", "main.py", "--server.port", "8082"],
    min_replicas=0,
    max_replicas=1,
    scaledown_after=timedelta(minutes=5),
    scaling_metric=ScalingMetric.Concurrency(2),
    # requires_auth=False # Uncomment to make app public.
)

# union deploy apps app.py bert-sentiment-analysis


In [None]:
%%writefile main.py
"""
A simple Union app that uses Streamlit to serve a BERT model.
"""

import streamlit as st
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from union_runtime import get_input

# Load the model artifact downloaded by Union.
# ---------------------------------------
model_path = get_input("bert_model")
try:
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
except Exception as e:
    st.error(f"Error loading model: {e}")
    st.stop()


# Create the Streamlit app.
# ---------------------------------------
st.title("Sentiment Analyzer")
st.write("Enter text to predict the sentiment.")

# Input text for sentiment analysis
user_input = st.text_area("Enter your text:", height=400, key="text_input")

if st.button("Analyze"):
    try:
        # Tokenize and predict
        inputs = tokenizer(
            user_input, return_tensors="pt", truncation=True, padding=True
        )
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = F.softmax(logits, dim=-1)
        predictions = logits.argmax(dim=-1)
        labels = ["NEGATIVE", "POSITIVE"]  # Adjust according to your model's labels

        sentiment = labels[predictions.item()]
        score = probabilities[0][predictions.item()].item()

        if sentiment == "NEGATIVE":
            st.error(f"Predicted sentiment: {sentiment} (Confidence: {score:.2f})")
        else:
            st.success(f"Predicted sentiment: {sentiment} (Confidence: {score:.2f})")
    except Exception as e:
        st.error(f"Prediction error: {e}")

# union deploy apps app.py bert-sentiment-analysis


Check the Union platform `Apps` tab to see the status of all apps!

Once the app is live, experiment with different inputs and see how your fine-tuned BERT model performs! 🚀
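
The prediction step in `main.py` turns the model's logits into a label and a confidence score. The sketch below reproduces that step with the standard library only, so it runs without PyTorch; the logits are made up for illustration and the helper names `softmax` and `classify` are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify(logits, labels=("NEGATIVE", "POSITIVE")):
    """Return (label, confidence) for one example's logits."""
    probabilities = softmax(logits)
    best = max(range(len(probabilities)), key=probabilities.__getitem__)
    return labels[best], probabilities[best]


label, confidence = classify([-1.2, 2.3])  # made-up logits
print(f"{label} (Confidence: {confidence:.2f})")
```

The confidence is the softmax probability of the predicted class, which is what the app shows next to the sentiment.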


## Optional: Use workflows to run batch inference

### Batch Serving

Union.ai also provides a way to serve your models in batch mode. This is useful when you have a large number of predictions to make and you want to do them all at once.

In [None]:
# 👇 Run this command to register tasks & workflows on Union.ai
!union register workflows/batch_inference.py

- Below we'll use union remote to run the batch inference pipeline directly in the Notebook. 
- This will create a new pipeline and start the batch inference process.

In [None]:
%%writefile containers.py

from flytekit import ImageSpec, Resources
from union.actor import ActorEnvironment

container_image = ImageSpec(
    name="fine-tune-qlora",
    requirements="requirements.txt",
    pip_extra_index_url=["https://download.pytorch.org/whl/cu118"],  # enables +cu118 builds
    builder="union",
    cuda="11.8",  # ensure GPU + CUDA layer is available
    apt_packages=["gcc", "g++"],  # optional, for packages like bitsandbytes
)

actor = ActorEnvironment(
    name="my-actor",
    container_image=container_image,
    replica_count=1,
    ttl_seconds=360,
    requests=Resources(
        cpu="2",
        mem="5000Mi",
        gpu="1",
    ),
)

In [None]:
from union.remote import UnionRemote
# Create a remote connection
remote = UnionRemote()

In [None]:
def predict_with_container(data):

    inputs = {"texts": data}

    workflow = remote.fetch_workflow(name="workflows.batch_inference.batch_inference_workflow")
    execution = remote.execute(workflow, inputs=inputs, wait=True) # wait=True will block until the execution is complete

    # print(execution.outputs)

    return execution.outputs["o0"]

In [None]:
print(predict_with_container(["I love this movie",
                               "I hate this movie"]
                               ))

### ⚡ Faster batch serving with Union Actors

Union [Actors](https://docs.union.ai/serverless/user-guide/core-concepts/actors/#actors) reduce the cost of cold starts by maintaining long-running, stateful environments that stay ready for use until a defined time-to-live (TTL) expires. This persistent setup eliminates redundant initialization, which is especially useful for AI pipelines that benefit from long-running environments, such as those that use large containers or serve models.
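
The effect of a TTL-based warm environment can be illustrated with a toy class. This is not Union's implementation: `WarmEnvironment`, its `loader` argument, and the simulated clock are all hypothetical, and only the idea of reusing expensive state until the TTL lapses is taken from the description above.

```python
import time


class WarmEnvironment:
    """Toy stand-in for a reusable environment with a time-to-live."""

    def __init__(self, loader, ttl_seconds, clock=time.monotonic):
        self._loader = loader          # expensive setup, e.g. loading a model
        self._ttl = ttl_seconds
        self._clock = clock
        self._state = None
        self._last_used = None
        self.cold_starts = 0

    def run(self, func, *args):
        now = self._clock()
        expired = self._last_used is None or now - self._last_used > self._ttl
        if expired:
            self._state = self._loader()   # cold start: pay the setup cost
            self.cold_starts += 1
        self._last_used = now              # each call renews the TTL
        return func(self._state, *args)


def predict(state, text):
    return f"{state['model']}: {text}"


# Simulated clock so the example runs instantly.
fake_now = [0.0]
env = WarmEnvironment(loader=lambda: {"model": "loaded"}, ttl_seconds=360,
                      clock=lambda: fake_now[0])

env.run(predict, "I love this movie")   # cold start
fake_now[0] = 100.0
env.run(predict, "I hate this movie")   # reuses the warm state
fake_now[0] = 1000.0
env.run(predict, "Again")               # TTL elapsed, cold start again
print(env.cold_starts)
```

Only the first call and the call after the TTL has lapsed pay the setup cost; the call in between reuses the loaded state.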

In [None]:
def predict_with_actor(data):

    inputs = {"texts": data}

    workflow = remote.fetch_workflow(name="workflows.batch_inference.actor_batch_inference_workflow")
    execution = remote.execute(workflow, inputs=inputs, wait=True) # wait=True will block until the execution is complete

    # print(execution.outputs)

    return execution.outputs['o0']

- Run the next three commands in quick succession to see the actor container in action. 
- After the first command, the actor is created and stays alive for 6 minutes (`ttl_seconds=360`) after the last call.
- The commands after the first will be run in the same actor container, so you should see a significant speedup.

In [None]:
print(predict_with_actor(["I love this movie",
                               "I hate this movie"]
                               ))

In [None]:
print(predict_with_actor(["I love this movie",
                               "I hate this movie"]
                               ))

In [None]:
print(predict_with_actor(["I love this movie",
                               "I hate this movie"]
                               ))

## Resources to learn more
- Union.ai & Flyte Documentation: https://www.union.ai/docs/byoc/user-guide/
- Building AI Together Slack: https://slack.flyte.org/
- Flyte Github: https://github.com/flyteorg/flyte
- Union.ai OSS Github: https://github.com/unionai-oss