In [None]:
import torch

from bret.data_loaders import (
    GenericDataLoader,
    get_text_dataloader,
    get_training_dataloader,
)
from bret.models import model_factory

In [None]:
def preprocess_key(old_key: str) -> str:
    """Rename a checkpoint state-dict key to the BRET model's parameter naming.

    Keys belonging to embeddings, any normalization layer (matched
    case-insensitively on "norm") and the pooler are returned unchanged.
    Any other key whose last component is ``.weight`` or ``.bias`` gets
    ``_mean`` appended (``...query.weight`` -> ``...query.weight_mean``).
    All remaining keys are returned unchanged.

    NOTE(review): presumably BRET stores its (variational) parameter means
    under ``*_mean`` names -- confirm against ``bret.models``.

    Args:
        old_key: Parameter/buffer name from the source checkpoint.

    Returns:
        The name to use when loading the tensor into the BRET model.
    """
    # Parameters of these modules keep their original names.
    if "embeddings" in old_key or "norm" in old_key.lower() or "pooler" in old_key:
        return old_key
    # Rewrite only the trailing suffix; str.replace would also rewrite any
    # ".weight"/".bias" occurring earlier in the dotted path.
    if old_key.endswith(".weight") or old_key.endswith(".bias"):
        return old_key + "_mean"
    return old_key

In [None]:
encoder_ckpt = "../output/trained_encoders/bert-base-dpr.pt"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer, model = model_factory("bert-base", "bret", device)

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
sd = torch.load(encoder_ckpt, map_location=device)
# Rename every checkpoint key to the BRET naming scheme via preprocess_key.
sdnew = {preprocess_key(old_key): v for old_key, v in sd.items()}

# strict=False tolerates keys present on only one side, so explicitly check that
# the renaming matched something; otherwise the model would silently keep
# whatever weights model_factory initialized it with.
load_result = model.load_state_dict(sdnew, strict=False)
n_loaded = len(sdnew) - len(load_result.unexpected_keys)
assert n_loaded > 0, f"no checkpoint tensors from {encoder_ckpt} matched model parameters"
load_result

In [None]:
corpus_dl = get_text_dataloader("../data/msmarco-corpus.jsonl", batch_size=32)
# Take only the first batch of (passage ids, passage texts) from the corpus.
# next(iter(...)) raises StopIteration on an empty loader, whereas a for/break
# loop would silently leave psg_id / psg undefined (or stale from an earlier run).
corpus_sample = next(iter(corpus_dl))
psg_id, psg = corpus_sample

In [None]:
# Tokenize the passage batch into fixed-length (256-token) tensors on `device`
# and ask the model for its uncertainty estimate; autograd is disabled.
# NOTE(review): num_samples presumably sets how many stochastic forward passes
# compute_uncertainty draws -- confirm in bret.models.
with torch.no_grad():
    psg_enc = tokenizer(
        psg,
        truncation=True,
        padding="max_length",
        max_length=256,
        return_tensors="pt",
    ).to(device)
    Upsg = model.compute_uncertainty(
        psg_enc,
        num_samples=10,
    )

In [None]:
# Show the value returned by model.compute_uncertainty for the passage batch.
Upsg