In [1]:
import os
import shutil

# 1) Define repo URL and local directory name
REPO_URL = "https://github.com/semilleroCV/BreastCATT.git"
REPO_DIR = "BreastCATT"

# 2) Delete the directory if it exists
if os.path.isdir(REPO_DIR):
    print(f"Directory '{REPO_DIR}' exists. Deleting it...")
    shutil.rmtree(REPO_DIR)

# 3) Clone the repository
print(f"Cloning repository into ./{REPO_DIR}...")
os.system(f"git clone {REPO_URL}")

# 4) Change into the repo directory
os.chdir(REPO_DIR)
print("Current working directory:", os.getcwd())

Cloning repository into ./BreastCATT...
Current working directory: /content/BreastCATT


In [2]:
# every time you do a factory reset you have to run this
!pip install --upgrade -r requirements.txt -q

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m963.5/963.5 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m494.8/494.8 kB[0m [31m38.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.1/55.1 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m

In [2]:
# IF YOU ARE WORKING ON KAGGLE OR GOOGLE COLAB DO NOT EXECUTE THIS CELL

import sys
import os

# The notebook is in ‘notebooks/’, but the project (and `breastcatt`) is in the parent directory.
# We add the parent directory to the path so we can import `breastcatt`.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
  sys.path.append(project_root)

In [None]:
import torch
from safetensors.torch import load_file
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Lambda
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score
from torchmetrics.classification import BinarySpecificity, MulticlassSpecificity
from huggingface_hub import hf_hub_download

import numpy as np
import json

from breastcatt import tfvit

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def evaluate_split(model, dataloader, device):
    """
    Evaluates the model on a given dataloader and computes metrics.
    Handles both binary (num_classes=1) and multi-class (num_classes>1) cases.
    """
    model.eval()
    all_preds = []
    all_labels = []
    
    # Initialize metrics based on model's output shape
    is_binary = (model.num_classes == 1)
    
    if is_binary:
        specificity_metric = BinarySpecificity().to(device)
    else:
        # For multi-class, we'll compute metrics per class then macro-average
        specificity_metric = MulticlassSpecificity(num_classes=model.num_classes, average='macro').to(device)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            labels = batch["labels"]
            outputs = model(**batch)
            
            if is_binary:
                # Get probabilities and predictions
                probs = outputs.logits.sigmoid()  # shape [B, 1]
                preds = (probs > 0.5).long().squeeze(-1)  # shape [B]
                
                # Update metrics with proper shapes
                specificity_metric.update(probs.squeeze(-1), labels)  # [B], [B]
            else:
                # Multi-class case
                preds = outputs.logits.argmax(dim=-1)  # shape [B]
                specificity_metric.update(outputs.logits, labels)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Compute metrics
    if is_binary:
        accuracy = accuracy_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds, zero_division=0)
        sensitivity = recall_score(all_labels, all_preds, zero_division=0)
        specificity = specificity_metric.compute().item()
    else:
        accuracy = accuracy_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
        sensitivity = recall_score(all_labels, all_preds, average='macro', zero_division=0)
        specificity = specificity_metric.compute().item()

    return {
        "accuracy": accuracy,
        "precision": precision,
        "sensitivity": sensitivity,
        "specificity": specificity,
    }

from datasets import load_dataset

dataset = load_dataset("SemilleroCV/DMR-IR", revision="69ffd6240b4a50bc4a05c59b70773f3a506054f2", 
                        token="token")

train_val_split = 0.15

# If we don't have a validation split, split off a percentage of train as validation.
train_val_split = None if "validation" in dataset.keys() else train_val_split
if isinstance(train_val_split, float) and train_val_split > 0.0:
    split = dataset["train"].train_test_split(train_val_split, seed=0)
    dataset["train"] = split["train"]
    dataset["validation"] = split["test"]

print(f"validation split size: {len(dataset['validation'])}")

In [20]:
# --- ONLY RUN THIS CELL to Load one of our models ---
MODEL_REPO_ID = "SemilleroCV/tfvit-base-text-dmr-ir" # it can be base, base-text and base-text-seg

print(f"Loading model from {MODEL_REPO_ID}...")
config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="config.json", revision="d3f752bd5b80e2389546f1bf3473ff0768bac02a")
try:
    state_dict_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="pytorch_model.bin",
                                    revision="d3f752bd5b80e2389546f1bf3473ff0768bac02a"  
    )
except:
    state_dict_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.safetensors",
                                    revision="d3f752bd5b80e2389546f1bf3473ff0768bac02a"
    )

# Load config
with open(config_path, "r") as f:
    init_args = json.load(f)

# Init model
model = (tfvit.multimodal_vit_large_patch16(**init_args) if init_args["embed_dim"] == 1024
         else tfvit.multimodal_vit_base_patch16(**init_args))

# Load and adapt state dict
state_dict = load_file(state_dict_path, device="cpu") if state_dict_path.endswith("safetensors") \
            else torch.load(state_dict_path, map_location="cpu")

# Key renaming
state_dict = {k.replace("text_embed_proj.", "language_model.proj."): v 
              for k, v in state_dict.items()}

# Load weights (ignore missing LM weights)
model.load_state_dict(state_dict, strict=False)
model.to(DEVICE)

# --- Transformations ---
min_max_norm = Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + 1e-8) if x.max() > x.min() else x)
val_transforms = Compose(
    [
        Resize((224, 224)),
        ToTensor(),
        min_max_norm,
    ]
)

def preprocess_val(example_batch):
    """Apply _val_transforms across a batch."""
    example_batch["pixel_values"] = [
        val_transforms(image) for image in example_batch["image"]
    ]
    return example_batch

# DataLoaders creation:
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    texts = [example["text"] for example in examples]
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "texts": texts, "labels": labels} 

eval_dataset = dataset["validation"].with_transform(preprocess_val)

BATCH_SIZE = 512
test_dataloader = DataLoader(dataset = eval_dataset, batch_size = BATCH_SIZE,
                             shuffle = False, collate_fn = collate_fn)

metrics = evaluate_split(model, test_dataloader, DEVICE)

print(f"  Accuracy:    {metrics['accuracy']:.4f}")
print(f"  Precision:   {metrics['precision']:.4f}")
print(f"  Sensitivity: {metrics['sensitivity']:.4f}")
print(f"  Specificity: {metrics['specificity']:.4f}")

Loading model from SemilleroCV/tfvit-base-text-dmr-ir...
✅ MAE base weights already exist at: checkpoints/fvit/mae_pretrain_vit_base.pth


Loading checkpoint from: checkpoints/fvit/mae_pretrain_vit_base.pth
Adapting patch_embed.proj.weight from 3 channels to 1 channel...

✅ Loaded weights: 148 layers.
❌ Not found in model: 0 layers.
📋 Details of load_state_dict:
_IncompatibleKeys(missing_keys=['cls_token', 'pos_embed', 'language_model.model_lm.embeddings.word_embeddings.weight', 'language_model.model_lm.embeddings.position_embeddings.weight', 'language_model.model_lm.embeddings.token_type_embeddings.weight', 'language_model.model_lm.encoder.layer.0.attention.ln.weight', 'language_model.model_lm.encoder.layer.0.attention.ln.bias', 'language_model.model_lm.encoder.layer.0.attention.self.query.weight', 'language_model.model_lm.encoder.layer.0.attention.self.query.bias', 'language_model.model_lm.encoder.layer.0.attention.self.key.weight', 'language_model.model_lm.encoder.layer.0.attention.self.key.bias', 'language_model.model_lm.encoder.layer.0.attention.self.value.weight', 'language_model.model_lm.encoder.layer.0.attention.s

Evaluating: 100%|██████████| 2/2 [00:32<00:00, 16.23s/it]

  Accuracy:    0.6153
  Precision:   0.3076
  Sensitivity: 0.5000
  Specificity: 0.5000





In [26]:
# --- ONLY RUN THIS CELL to Load one of our models ---
MODEL_REPO_ID = "SemilleroCV/tfvit-base-text-seg-2" # it can be base, base-text and base-text-seg

print(f"Loading model from {MODEL_REPO_ID}...")
config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="config.json")
try:
    state_dict_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="pytorch_model.bin")
except:
    state_dict_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.safetensors")

# Load config
with open(config_path, "r") as f:
    init_args = json.load(f)

# Init model
model = (tfvit.multimodal_vit_large_patch16(**init_args) if init_args["embed_dim"] == 1024
         else tfvit.multimodal_vit_base_patch16(**init_args))

# Load and adapt state dict
state_dict = load_file(state_dict_path, device="cpu") if state_dict_path.endswith("safetensors") \
            else torch.load(state_dict_path, map_location="cpu")

# Key renaming
state_dict = {k.replace("text_embed_proj.", "language_model.proj."): v 
              for k, v in state_dict.items()}

# Load weights (ignore missing LM weights)
model.load_state_dict(state_dict)
model.to(DEVICE)

# --- Transformations ---
min_max_norm = Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + 1e-8) if x.max() > x.min() else x)
val_transforms = Compose(
    [
        Resize((224, 224)),
        ToTensor(),
        min_max_norm,
    ]
)

def preprocess_val(example_batch):
    """Apply _val_transforms across a batch."""
    example_batch["pixel_values"] = [
        val_transforms(image) for image in example_batch["image"]
    ]
    return example_batch

# DataLoaders creation:
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    texts = [example["text"] for example in examples]
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "texts": texts, "labels": labels} 

eval_dataset = dataset["validation"].with_transform(preprocess_val)

BATCH_SIZE = 64
test_dataloader = DataLoader(dataset = eval_dataset, batch_size = BATCH_SIZE,
                             shuffle = False, collate_fn = collate_fn)

metrics = evaluate_split(model, test_dataloader, DEVICE)

print(f"  Accuracy:    {metrics['accuracy']:.4f}")
print(f"  Precision:   {metrics['precision']:.4f}")
print(f"  Sensitivity: {metrics['sensitivity']:.4f}")
print(f"  Specificity: {metrics['specificity']:.4f}")

Loading model from SemilleroCV/tfvit-base-text-seg-2...
✅ MAE base weights already exist at: checkpoints/fvit/mae_pretrain_vit_base.pth
Loading checkpoint from: checkpoints/fvit/mae_pretrain_vit_base.pth
Adapting patch_embed.proj.weight from 3 channels to 1 channel...

✅ Loaded weights: 148 layers.
❌ Not found in model: 0 layers.
📋 Details of load_state_dict:
_IncompatibleKeys(missing_keys=['cls_token', 'pos_embed', 'language_model.model_lm.embeddings.word_embeddings.weight', 'language_model.model_lm.embeddings.position_embeddings.weight', 'language_model.model_lm.embeddings.token_type_embeddings.weight', 'language_model.model_lm.encoder.layer.0.attention.ln.weight', 'language_model.model_lm.encoder.layer.0.attention.ln.bias', 'language_model.model_lm.encoder.layer.0.attention.self.query.weight', 'language_model.model_lm.encoder.layer.0.attention.self.query.bias', 'language_model.model_lm.encoder.layer.0.attention.self.key.weight', 'language_model.model_lm.encoder.layer.0.attention.sel

Evaluating: 100%|██████████| 14/14 [00:46<00:00,  3.35s/it]

  Accuracy:    0.6200
  Precision:   0.0000
  Sensitivity: 0.0000
  Specificity: 1.0000





In [30]:
# --- ONLY RUN THIS CELL to Load one of our models ---
MODEL_REPO_ID = "SemilleroCV/tfvit-base-text-2" # it can be base, base-text and base-text-seg

print(f"Loading model from {MODEL_REPO_ID}...")
config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="config.json")
try:
    state_dict_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="pytorch_model.bin")
except:
    state_dict_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.safetensors")

# Load config
with open(config_path, "r") as f:
    init_args = json.load(f)

# Init model
model = (tfvit.multimodal_vit_large_patch16(**init_args) if init_args["embed_dim"] == 1024
         else tfvit.multimodal_vit_base_patch16(**init_args))

# Load and adapt state dict
state_dict = load_file(state_dict_path, device="cpu") if state_dict_path.endswith("safetensors") \
            else torch.load(state_dict_path, map_location="cpu")

# Key renaming
state_dict = {k.replace("text_embed_proj.", "language_model.proj."): v 
              for k, v in state_dict.items()}

# Load weights (ignore missing LM weights)
model.load_state_dict(state_dict, strict=False)
model.to(DEVICE)

# --- Transformations ---
min_max_norm = Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + 1e-8) if x.max() > x.min() else x)
val_transforms = Compose(
    [
        Resize((224, 224)),
        ToTensor(),
        min_max_norm,
    ]
)

def preprocess_val(example_batch):
    """Apply _val_transforms across a batch."""
    example_batch["pixel_values"] = [
        val_transforms(image) for image in example_batch["image"]
    ]
    return example_batch

# DataLoaders creation:
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    texts = [example["text"] for example in examples]
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "texts": texts, "labels": labels} 

eval_dataset = dataset["validation"].with_transform(preprocess_val)

BATCH_SIZE = 512
test_dataloader = DataLoader(dataset = eval_dataset, batch_size = BATCH_SIZE,
                             shuffle = False, collate_fn = collate_fn)

metrics = evaluate_split(model, test_dataloader, DEVICE)

print(f"  Accuracy:    {metrics['accuracy']:.4f}")
print(f"  Precision:   {metrics['precision']:.4f}")
print(f"  Sensitivity: {metrics['sensitivity']:.4f}")
print(f"  Specificity: {metrics['specificity']:.4f}")

Loading model from SemilleroCV/tfvit-base-text-2...
✅ MAE base weights already exist at: checkpoints/fvit/mae_pretrain_vit_base.pth


Loading checkpoint from: checkpoints/fvit/mae_pretrain_vit_base.pth
Adapting patch_embed.proj.weight from 3 channels to 1 channel...

✅ Loaded weights: 148 layers.
❌ Not found in model: 0 layers.
📋 Details of load_state_dict:
_IncompatibleKeys(missing_keys=['cls_token', 'pos_embed', 'language_model.model_lm.embeddings.word_embeddings.weight', 'language_model.model_lm.embeddings.position_embeddings.weight', 'language_model.model_lm.embeddings.token_type_embeddings.weight', 'language_model.model_lm.encoder.layer.0.attention.ln.weight', 'language_model.model_lm.encoder.layer.0.attention.ln.bias', 'language_model.model_lm.encoder.layer.0.attention.self.query.weight', 'language_model.model_lm.encoder.layer.0.attention.self.query.bias', 'language_model.model_lm.encoder.layer.0.attention.self.key.weight', 'language_model.model_lm.encoder.layer.0.attention.self.key.bias', 'language_model.model_lm.encoder.layer.0.attention.self.value.weight', 'language_model.model_lm.encoder.layer.0.attention.s

Evaluating: 100%|██████████| 2/2 [00:32<00:00, 16.48s/it]

  Accuracy:    0.9847
  Precision:   0.9749
  Sensitivity: 0.9842
  Specificity: 0.9850





## evaluating on test

In [27]:
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Lambda
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score
from torchmetrics.classification import BinarySpecificity, MulticlassSpecificity
from huggingface_hub import hf_hub_download
# from safetensors.torch import load_file

import numpy as np
import json

from breastcatt import tfvit

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def evaluate_split(model, dataloader, device):
    """
    Evaluates the model on a given dataloader and computes metrics.
    Handles both binary (num_classes=1) and multi-class (num_classes>1) cases.
    """
    model.eval()
    all_preds = []
    all_labels = []
    
    # Initialize metrics based on model's output shape
    is_binary = (model.num_classes == 1)
    
    if is_binary:
        specificity_metric = BinarySpecificity().to(device)
    else:
        # For multi-class, we'll compute metrics per class then macro-average
        specificity_metric = MulticlassSpecificity(num_classes=model.num_classes, average='macro').to(device)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            labels = batch["labels"]
            outputs = model(**batch)
            
            if is_binary:
                # Get probabilities and predictions
                probs = outputs.logits.sigmoid()  # shape [B, 1]
                preds = (probs > 0.5).long().squeeze(-1)  # shape [B]
                
                # Update metrics with proper shapes
                specificity_metric.update(probs.squeeze(-1), labels)  # [B], [B]
            else:
                # Multi-class case
                preds = outputs.logits.argmax(dim=-1)  # shape [B]
                specificity_metric.update(outputs.logits, labels)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Compute metrics
    if is_binary:
        accuracy = accuracy_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds, zero_division=0)
        sensitivity = recall_score(all_labels, all_preds, zero_division=0)
        specificity = specificity_metric.compute().item()
    else:
        accuracy = accuracy_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
        sensitivity = recall_score(all_labels, all_preds, average='macro', zero_division=0)
        specificity = specificity_metric.compute().item()

    return {
        "accuracy": accuracy,
        "precision": precision,
        "sensitivity": sensitivity,
        "specificity": specificity,
    }

from datasets import load_dataset

dataset = load_dataset("SemilleroCV/DMR-IR", revision="69ffd6240b4a50bc4a05c59b70773f3a506054f2", split='test', 
                        token="token")

In [28]:
# --- ONLY RUN THIS CELL to Load one of our models ---
MODEL_REPO_ID = "SemilleroCV/tfvit-base-text-2" # it can be base, base-text and base-text-seg

print(f"Loading model from {MODEL_REPO_ID}...")
config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="config.json")
try:
    state_dict_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="pytorch_model.bin")
except:
    state_dict_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.safetensors")

# Load config
with open(config_path, "r") as f:
    init_args = json.load(f)

# Init model
model = (tfvit.multimodal_vit_large_patch16(**init_args) if init_args["embed_dim"] == 1024
         else tfvit.multimodal_vit_base_patch16(**init_args))

# Load and adapt state dict
state_dict = load_file(state_dict_path, device="cpu") if state_dict_path.endswith("safetensors") \
            else torch.load(state_dict_path, map_location="cpu")

# Key renaming
state_dict = {k.replace("text_embed_proj.", "language_model.proj."): v 
              for k, v in state_dict.items()}

# Load weights (ignore missing LM weights)
model.load_state_dict(state_dict, strict=False)
model.to(DEVICE)

# --- Transformations ---
min_max_norm = Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + 1e-8) if x.max() > x.min() else x)
val_transforms = Compose(
    [
        Resize((224, 224)),
        ToTensor(),
        min_max_norm,
    ]
)

def preprocess_val(example_batch):
    """Apply _val_transforms across a batch."""
    example_batch["pixel_values"] = [
        val_transforms(image) for image in example_batch["image"]
    ]
    return example_batch

# DataLoaders creation:
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    texts = [example["text"] for example in examples]
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "texts": texts, "labels": labels} 

eval_dataset = dataset.with_transform(preprocess_val)

BATCH_SIZE = 512
test_dataloader = DataLoader(dataset = eval_dataset, batch_size = BATCH_SIZE,
                             shuffle = False, collate_fn = collate_fn)

metrics = evaluate_split(model, test_dataloader, DEVICE)

print(f"  Accuracy:    {metrics['accuracy']:.4f}")
print(f"  Precision:   {metrics['precision']:.4f}")
print(f"  Sensitivity: {metrics['sensitivity']:.4f}")
print(f"  Specificity: {metrics['specificity']:.4f}")

Loading model from SemilleroCV/tfvit-base-text-2...
✅ MAE base weights already exist at: checkpoints/fvit/mae_pretrain_vit_base.pth
Loading checkpoint from: checkpoints/fvit/mae_pretrain_vit_base.pth
Adapting patch_embed.proj.weight from 3 channels to 1 channel...

✅ Loaded weights: 148 layers.
❌ Not found in model: 0 layers.
📋 Details of load_state_dict:
_IncompatibleKeys(missing_keys=['cls_token', 'pos_embed', 'language_model.model_lm.embeddings.word_embeddings.weight', 'language_model.model_lm.embeddings.position_embeddings.weight', 'language_model.model_lm.embeddings.token_type_embeddings.weight', 'language_model.model_lm.encoder.layer.0.attention.ln.weight', 'language_model.model_lm.encoder.layer.0.attention.ln.bias', 'language_model.model_lm.encoder.layer.0.attention.self.query.weight', 'language_model.model_lm.encoder.layer.0.attention.self.query.bias', 'language_model.model_lm.encoder.layer.0.attention.self.key.weight', 'language_model.model_lm.encoder.layer.0.attention.self.ke

Evaluating: 100%|██████████| 3/3 [00:54<00:00, 18.20s/it]

  Accuracy:    0.8762
  Precision:   0.8333
  Sensitivity: 0.8364
  Specificity: 0.9000



