## Setup

### Add API keys

In [None]:
# %pip install trulens_eval==0.18.3 openai==1.3.7 torch transformers peft==0.6.2 gdown "llama_index<0.10"

In [None]:
import os
from getpass import getpass

# Read the OpenAI key from the environment and prompt only if it is missing or empty.
# Assigning a literal here would hardcode a secret in the notebook, and it would
# also clobber a key that was already exported in the shell.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

In [None]:
# `TruLlama` wraps the LlamaIndex app for recording. `Tru` is used later to run
# the dashboard and to fetch the recorded results.
from trulens_eval import TruLlama, Tru

tru = Tru()

In [None]:
from llama_index import VectorStoreIndex
from llama_index.readers.web import SimpleWebPageReader

# Fetch the essay page as plain text (HTML converted to text) and index it.
web_reader = SimpleWebPageReader(html_to_text=True)
documents = web_reader.load_data(["http://paulgraham.com/worked.html"])
index = VectorStoreIndex.from_documents(documents)

# Default query engine over the index; this is the app TruLens records below.
query_engine = index.as_query_engine()

In [None]:
# Quick smoke test of the query engine before instrumenting it.
sample_question = "What did the author do growing up?"
response = query_engine.query(sample_question)
print(response)

## Initialize Feedback Function(s)

In [None]:
import gdown
import torch
from torch import nn
from transformers import AutoModel
from peft.tuners.lora.config import LoraConfig
from peft.mapping import get_peft_model
from peft.utils.peft_types import TaskType

### Load model

In [None]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

In [None]:
class ClassifierHead(nn.Module):
    """Two-layer classification head: dropout -> dense -> tanh -> dropout -> out_proj.

    Args:
        input_dim: Size of the incoming feature vector (last dimension of ``x``).
        inner_dim: Width of the hidden dense layer.
        num_classes: Number of output logits.
        p_dropout: Dropout probability; one Dropout module is applied twice.
    """

    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        p_dropout: float,
    ):
        super().__init__()
        self.num_classes = num_classes
        # The attribute names below become state_dict keys, so keep them stable
        # or saved checkpoints will no longer load.
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=p_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features ``x`` (last dim == input_dim) to raw logits (last dim == num_classes)."""
        hidden = torch.tanh(self.dense(self.dropout(x)))
        return self.out_proj(self.dropout(hidden))
    

class FeedbackClassifier(nn.Module):
    """Backbone encoder plus a ClassifierHead that produces feedback logits.

    The head reads the backbone's ``pooler_output``. ``compute_probits`` turns
    logits into probabilities: sigmoid for a single class, softmax otherwise.

    Args:
        backbone_model: Encoder whose ``config`` exposes ``eos_token_id``,
            ``hidden_size`` and ``hidden_dropout_prob``, and whose forward output
            has a ``pooler_output`` attribute.
        n_classes: Number of output logits (default 1, a single score).
    """

    def __init__(self, backbone_model: nn.Module, n_classes: int = 1):
        super().__init__()
        self.n_classes = n_classes
        # `_backbone` and `_classifier` are state_dict key prefixes; keep them
        # stable so saved checkpoints still load.
        self._backbone = backbone_model

        self.eos_token_id = self._backbone.config.eos_token_id

        self._classifier = ClassifierHead(
            input_dim=self._backbone.config.hidden_size,
            inner_dim=self._backbone.config.hidden_size,
            num_classes=n_classes,
            p_dropout=self._backbone.config.hidden_dropout_prob,
        )
        if n_classes == 1:
            self.probits_proj = nn.Sigmoid()
        else:
            self.probits_proj = nn.Softmax(dim=-1)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return raw classifier logits for a tokenized batch.

        Only ``pooler_output`` is consumed, so the backbone is no longer asked
        for ``output_hidden_states``. Requesting them only made it build and
        return an unused tuple of per-layer states.
        """
        outputs = self._backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.pooler_output
        return self._classifier(pooled)

    def compute_probits(self, logits: torch.Tensor) -> torch.Tensor:
        """Convert logits to probabilities.

        When ``n_classes == 1`` and the input has more than one dimension, dim 1
        (the class dim for 2-D logits) is squeezed away.
        """
        probits: torch.Tensor = self.probits_proj(logits)
        if len(probits.shape) > 1 and self.n_classes == 1:
            probits = probits.squeeze(1)
        return probits

In [None]:
def create_model(
    model_name: str,
    n_classes: int = 1,
    device: str | None = None,
    dtype: torch.dtype = torch.bfloat16,
):
    """Build a LoRA-wrapped backbone topped with a FeedbackClassifier head (untrained).

    Args:
        model_name: Hugging Face model id or path passed to ``AutoModel.from_pretrained``.
        n_classes: Number of classifier outputs (1 = single score).
        device: Optional device to move the model to (e.g. "cuda" or "cpu").
            Falsy values leave the model where it was created.
        dtype: Floating-point dtype for the backbone weights and the final model.
            Defaults to bfloat16, the previously hardcoded value.

    Returns:
        The ``FeedbackClassifier``. The LoRA settings here determine the adapter
        parameter shapes, so they must match any checkpoint later loaded into it.
    """
    backbone = AutoModel.from_pretrained(model_name, torch_dtype=dtype)
    peft_config = LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        # target_modules is left unset so PEFT uses its default target modules
        # for this architecture (the commented-out alternative was ["query", "value"]).
    )
    backbone = get_peft_model(backbone, peft_config)

    model = FeedbackClassifier(backbone, n_classes=n_classes)
    if device:
        model = model.to(device)
    # The classifier head is created in torch's default dtype; cast the whole
    # model to one dtype so it matches the backbone.
    model = model.to(dtype)
    return model

def download_checkpoint(checkpoint_path: str = './checkpoint.ckpt', auto_download: bool = False) -> str:
    """Return the local path of the distilled model checkpoint.

    By default this only checks that the file is already present. With
    ``auto_download=True`` it first tries to fetch the file from Google Drive
    via ``gdown`` (imported earlier in the notebook).

    Args:
        checkpoint_path: Where the checkpoint is, or should be, stored.
        auto_download: Attempt a download when the file is missing.

    Returns:
        ``checkpoint_path``.

    Raises:
        ValueError: If the checkpoint is still missing. The message includes
            manual download steps.
    """
    drive_file_id = "1dmQqr7K3OL8TVNYPj8yF9OnhNtzroOO0"
    if auto_download and not os.path.exists(checkpoint_path):
        # gdown accepts the direct `uc?id=` form of a Drive link.
        gdown.download(f"https://drive.google.com/uc?id={drive_file_id}", checkpoint_path, quiet=False)
    if not os.path.exists(checkpoint_path):
        raise ValueError(
            f"Checkpoint not found at {checkpoint_path}. "
            f"1) Download the model checkpoint from "
            f"https://drive.google.com/file/d/{drive_file_id}/view?usp=drive_link "
            f"and 2) move it to {checkpoint_path}"
        )
    return checkpoint_path

def load_model_from_checkpoint(base_model_name: str):
    """Build the classifier for ``base_model_name`` and load the distilled weights.

    The model is created on the notebook-level ``device`` and returned in eval
    mode, ready for inference.
    """
    model = create_model(base_model_name, device=device)
    checkpoint_path = download_checkpoint()
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
    # Strip the 6-character prefix that the training wrapper put on every key
    # (presumably "model." -- confirm against the checkpoint).
    state_dict = {k[6:]: v for k, v in state_dict['state_dict'].items()}
    model.load_state_dict(state_dict)
    # Without this the model stays in training mode, so ClassifierHead dropout and
    # LoRA dropout remain active and repeated scoring of the same input differs.
    model.eval()
    return model

In [None]:
# The distilled checkpoint's weights are loaded onto this backbone architecture.
base_model_name: str = "roberta-base"
model = load_model_from_checkpoint(base_model_name=base_model_name)

### Inference helpers

In [None]:
import torch
from typing import Mapping, Optional
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

def tokenize_batch(
    batch,
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    max_length: Optional[int] = None,
    pad_seq: bool = True,
    truncate_seq: bool = True,
    return_type: str = "pt"
) -> Mapping[str, torch.Tensor]:
    """Tokenize ``batch['text']`` with ``tokenizer``.

    Args:
        batch: Mapping with a ``'text'`` entry (presumably a list of strings).
        tokenizer: Hugging Face tokenizer to apply.
        max_length: Length used for truncation and for ``"max_length"`` padding.
        pad_seq: Pad every sequence to ``max_length`` when True; no padding otherwise.
        truncate_seq: Whether to truncate sequences longer than ``max_length``.
        return_type: Tensor type requested from the tokenizer (``"pt"`` = PyTorch).

    Returns:
        The tokenizer's output mapping (e.g. ``input_ids`` and ``attention_mask``).
    """
    padding_strategy = "max_length" if pad_seq else "do_not_pad"
    return tokenizer(
        batch['text'],
        return_tensors=return_type,
        padding=padding_strategy,
        truncation=truncate_seq,
        max_length=max_length,
    )

def combine_premise_hypothesis(batch, sep_token: str):
    premises = batch['premise']
    hypotheses = batch['hypothesis']    
    return {"text": [f"{premise}{sep_token}{hypothesis}" for premise, hypothesis in zip(premises, hypotheses)]}

def collate_batch(record_batch):
    """Normalize a record dict so that every value is a list.

    String values are wrapped in a one-element list, and list values are kept
    as-is. A new dict is returned; the caller's dict is not modified.

    Raises:
        TypeError: If ``record_batch`` is not a dict, or if a value is neither a
            ``list`` nor a ``str``. The original used ``assert``, which
            ``python -O`` strips.
    """
    if not isinstance(record_batch, dict):
        raise TypeError(f"record_batch must be a dict, got {type(record_batch).__name__}")
    collated = {}
    for key, value in record_batch.items():
        if isinstance(value, list):
            collated[key] = value
        elif isinstance(value, str):
            collated[key] = [value]
        else:
            raise TypeError(
                f"record_batch[{key!r}] must be a list or str, got {type(value).__name__}"
            )
    return collated

### Create Provider

In [None]:
from torch import nn
import numpy as np
import torch.nn.functional as F
from trulens_eval import Provider, Feedback

# Tokenizers are loaded once per model name and reused across feedback calls.
# The original re-ran `AutoTokenizer.from_pretrained` on every evaluation.
_TOKENIZER_CACHE: dict = {}


def _get_tokenizer(model_name: str):
    """Return the (cached) ``AutoTokenizer`` for ``model_name``."""
    tokenizer = _TOKENIZER_CACHE.get(model_name)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        _TOKENIZER_CACHE[model_name] = tokenizer
    return tokenizer


class TruEraDistill(Provider):
    """TruLens feedback provider backed by locally loaded distilled classifiers.

    ``models`` maps an objective name (currently only ``"context_relevance"``)
    to its torch model. ``base_models`` maps the same objective to the Hugging
    Face model name whose tokenizer is used for it.
    """
    # Each instance overwrites these in __init__.
    # NOTE(review): `Provider` may be a pydantic model that relies on these class-level
    # declarations to accept the assignments below; confirm before removing them.
    models = {}
    base_models = {}

    def __init__(self, context_relevance_model: nn.Module, context_base_model: str):
        super().__init__("TruEraDistill")
        # Scoring is inference-only, so disable dropout to make scores deterministic.
        context_relevance_model.eval()
        self.models = {
            "context_relevance": context_relevance_model
        }
        self.base_models = {
            "context_relevance": context_base_model
        }

    def _prepare_inputs(self, premise: str, hypothesis: str, objective: str) -> dict:
        """Tokenize one ``premise``/``hypothesis`` pair for ``objective``'s model.

        The pair is joined with the tokenizer's separator token, then truncated
        (not padded) to the tokenizer's ``model_max_length``.
        """
        model_name = self.base_models[objective]
        tokenizer = _get_tokenizer(model_name)

        record_batch = collate_batch({"premise": premise, "hypothesis": hypothesis})
        # Combine premise and hypothesis into a single text field
        record_batch = combine_premise_hypothesis(record_batch, sep_token=tokenizer.sep_token)

        return tokenize_batch(
            record_batch,
            tokenizer=tokenizer,
            max_length=tokenizer.model_max_length,
            pad_seq=False,
            truncate_seq=True,
        )

    def _call_model(self, inputs, objective: str):
        """Run ``objective``'s model on the tokenized ``inputs``.

        Returns a dict of float32 numpy arrays: ``"logits"`` and their sigmoid,
        ``"probits"``.
        """
        model = self.models[objective]
        # Send the inputs to wherever the model's weights live instead of relying
        # on a notebook-global `device`.
        model_device = next(model.parameters()).device
        # Scoring needs no gradients, so skip building the autograd graph.
        with torch.no_grad():
            logits = model(
                inputs['input_ids'].to(model_device),
                attention_mask=inputs['attention_mask'].to(model_device),
            )
            probits = torch.sigmoid(logits)
        return {
            "probits": probits.detach().cpu().float().numpy(),
            "logits": logits.detach().cpu().float().numpy()
        }

    def context_relevance(self, instruction: str, context: str) -> float:
        """Score how relevant ``context`` is to ``instruction``, as a sigmoid value in [0, 1]."""
        objective = "context_relevance"
        inputs = self._prepare_inputs(instruction, context, objective=objective)
        return float(self._call_model(inputs, objective)['probits'].squeeze())


In [None]:
truera_distill = TruEraDistill(model, base_model_name)

# Context relevance: score the app input against the text of each retrieved
# source node, then average the per-node scores.
f_ctx_relevance = (
    Feedback(truera_distill.context_relevance)
    .on_input()
    .on(TruLlama.select_source_nodes().node.text)
    .aggregate(np.mean)
)

## Instrument the query engine for logging with TruLens

In [None]:
# Questions to record. The first four are about the essay; the last four are
# deliberately irrelevant and are expected to get low context-relevance scores.
test_evaluations = [
    "What did the author do growing up?",
    "Where did the author work?",
    "What notable achievements did the author have?",
    "What did the author do after graduating?",
    # irrelevant questions
    "Where can I order a Big Mac?",
    "Who was the first President of PepsiCo?",
    "How many children did Napoleon have?",
    "What is the capital of California?",
]
# Backward-compatible alias for the original (misspelled) name used by later cells.
test_evalautions = test_evaluations

In [None]:
# Wrap the query engine so each query is recorded and scored with f_ctx_relevance.
tru_query_engine_recorder = TruLlama(
    query_engine,
    app_id='LlamaIndex_App1',
    feedbacks=[f_ctx_relevance],
)

In [None]:
# Every query issued inside this context is captured by the TruLens recorder.
with tru_query_engine_recorder as recording:
    for question in test_evalautions:
        query_engine.query(question)

## Retrieve records and feedback

In [None]:
# The records of the app invocations can be retrieved from the `recording`.
# Several queries were recorded above, so use `.records` (one entry per query);
# `recording.get()` is for a recording that holds exactly one record.
recs = recording.records

display(recs)

In [None]:
# Open the local streamlit dashboard to explore the recorded runs.
tru.run_dashboard()

# When you are done exploring, stop it with:
# tru.stop_dashboard()

## Or view results directly in your notebook

In [None]:
# An empty app_ids list selects all apps. Element [0] of the result is the
# records/feedback table (presumably a DataFrame; confirm).
records_and_feedback = tru.get_records_and_feedback(app_ids=[])
records_and_feedback[0]