Here we explore the ideas from `EFFICIENTLY LEARNING AT TEST-TIME: ACTIVE FINE-TUNING OF LLMS` by Jonas Hübotter, Sascha Bongni, Ido Hakimi, and Andreas Krause. The paper is [here](https://arxiv.org/pdf/2410.08020) and the repo is [here](https://github.com/jonhue/activeft). The idea is to fine-tune the model at test time on a small sample of data. Of particular interest is how that sample is chosen: rather than typical retrieval, where the nearest neighbors are selected from the embedding space, the authors introduce a more data-efficient method of selecting data. This is useful not only for retraining, but also for RAG-style applications.

In [1]:
import tiktoken
import torch
from torch.utils.data import DataLoader, Dataset
from utils import json_to_dataframe, json_to_string_list

# Define dataloader

def encode_and_decode_example(list_of_strings):
    """Round-trip a list of texts through the GPT-2 BPE tokenizer.

    Every text is encoded and followed by one <|endoftext|> token ID; the
    concatenated ID stream is then decoded back into a single string.

    Args:
        list_of_strings: Iterable of texts to tokenize.

    Returns:
        A ``(all_tokens, decoded)`` tuple: the flat list of token IDs and
        the string obtained by decoding that list.
    """
    enc = tiktoken.get_encoding("gpt2")
    special = {"<|endoftext|>"}

    # ID of the separator appended after each text.
    eot_id = enc.encode("<|endoftext|>", allowed_special=special)[0]

    token_stream = []
    for text in list_of_strings:
        token_stream.extend(enc.encode(text, allowed_special=special))
        token_stream.append(eot_id)

    return token_stream, enc.decode(token_stream)

In [5]:
class GPTDatasetV1(Dataset):
    """Next-token-prediction dataset built from a list of articles.

    All articles are tokenized, each followed by an <|endoftext|> token, and
    concatenated into one token stream. The stream is then cut into windows
    of ``max_length`` tokens, starting every ``stride`` tokens; each target
    window is the input window shifted right by one token.
    """

    def __init__(self, articles, tokenizer, max_length, stride):
        """Tokenize ``articles`` and pre-compute all input/target windows.

        Args:
            articles: Iterable of article strings.
            tokenizer: Object exposing ``encode(text, allowed_special=...)``
                that returns a list of token IDs.
            max_length: Number of tokens per input (and target) window.
            stride: Step between the start positions of successive windows.
        """
        special = {"<|endoftext|>"}
        eot_id = tokenizer.encode("<|endoftext|>", allowed_special=special)[0]

        # One flat token stream: article, separator, article, separator, ...
        stream = []
        for doc in articles:
            stream.extend(tokenizer.encode(doc, allowed_special=special))
            stream.append(eot_id)

        # Stop early enough that the shifted target window still fits.
        starts = range(0, len(stream) - max_length, stride)
        self.input_ids = [
            torch.tensor(stream[s:s + max_length]) for s in starts
        ]
        self.target_ids = [
            torch.tensor(stream[s + 1:s + max_length + 1]) for s in starts
        ]

    def __len__(self):
        """Return the number of windows."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair at position ``idx``."""
        return self.input_ids[idx], self.target_ids[idx]

In [6]:
def create_dataloader_v1(
    articles,
    batch_size=4,
    max_length=256,
    stride=128,
    shuffle=True,
    drop_last=True,
    num_workers=0,
):
    """Build a ``DataLoader`` of sliding-window token batches.

    The articles are tokenized with tiktoken's "gpt2" encoding and wrapped
    in a ``GPTDatasetV1``; the remaining arguments are forwarded to
    ``torch.utils.data.DataLoader``.

    Args:
        articles: Iterable of article strings.
        batch_size: Number of windows per batch.
        max_length: Tokens per window.
        stride: Step between successive window starts.
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    windows = GPTDatasetV1(
        articles,
        tiktoken.get_encoding("gpt2"),
        max_length,
        stride,
    )
    return DataLoader(
        windows,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )

## Load radiology reports dataset

In [7]:
# Path to the radiology-report JSON file, relative to this notebook.
filepath = '../../data/vector_veterinary_imaging_2.json'

# Load the same file twice via the project helpers in `utils`: once as a
# table (displayed below) and once as a list of report strings.
df = json_to_dataframe(filepath) 
rad_strings = json_to_string_list(filepath)

In [8]:
# Display the loaded reports table.
df

Unnamed: 0,case_identifier,findings,conclusions_and_recommendations
0,181153,Orthogonal pelvis and orthogonal right shoulde...,1. Medial right mildly comminuted acetabular f...
1,181413,"Three view whole body images dated April 14, 2...",The material within the stomach and small inte...
2,181821,Three view thoracic radiographs (total of 5 th...,No aggressive osseous changes are noted. The b...
3,181886,Orthogonal images of the right pelvic limb are...,"1. Chronic right calcaneal tendonopathy, with ..."
4,181911,Lateral abdomen and pelvis images are provided...,"1. Numerous small urinary cystoliths, non-obst..."
...,...,...,...
1756,274208,Three view thorax and three view abdomen image...,Aggressive osseous change of the L6 vertebral ...
1757,274229,Orthogonal thorax and three view abdomen image...,Right cranial pulmonary mass. This is most lik...
1758,274244,"Liver: Diffusely homogenously hyperechoic, oth...",At least one gastric mural nodule extending in...
1759,274249,Ventrodorsal pelvis and orthogonal stifles ima...,"Right coxofemoral subluxation, progressive fro..."


In [9]:
# Number of report strings loaded.
len(rad_strings)

1761

In [10]:
# Peek at the first report string.
rad_strings[0]

'Findings: Orthogonal pelvis and orthogonal right shoulder and lateral left shoulder images dated April 17, 2023 are provided for review (total of 5 images). Shoulders: A sagittal plane fracture is present through the right scapular body, where the spine meets the body, extending cranially through the cranial margin of the acromion. this fracture does not articulate with the glenoid rims or the scapulohumeral joint. This fracture is visualized on the craniocaudal image, not visualized on the lateral image, thought due to superimposition. The fracture is non-displaced. Small fissures are suspected extending into the scapular spine. The right first rib is fractured in the body. A non-displaced fracture is also suspected in the body of the right second rib. The visible scapula, scapulohumeral margins, and humerus of the left shoulder are normal. The included cervical and thoracic spine is normal. Pelvis: A mildly comminuted segment fracture is present through the medial and cranial third 

## Create data pipeline

In [11]:
# Number of rows in the token embedding table.
# NOTE(review): presumably the vocabulary size of tiktoken's "gpt2"
# encoding used by the dataloader — confirm.
vocab_size = 50257
# Width of every token and position embedding vector.
output_dim = 256
# Number of rows in the positional embedding table, i.e. the largest
# sequence length it can index.
max_len = 1024
context_length = max_len

In [12]:
import torch
import torch.nn as nn

# Lookup tables: one vector per token ID and one per sequence position.
# Both use the `nn` alias; the token table is created first so the order of
# random initialisation is unchanged.
token_embedding_layer = nn.Embedding(
    num_embeddings=vocab_size, embedding_dim=output_dim
)
pos_embedding_layer = nn.Embedding(
    num_embeddings=context_length, embedding_dim=output_dim
)

In [13]:
# Window size (in tokens) for this demonstration.
max_length = 4
# stride == max_length, so successive input windows do not overlap.
dataloader = create_dataloader_v1(rad_strings, batch_size=8, max_length=max_length, stride=max_length)

In [14]:
def inspect_batch(x, y, n_samples=2, tokenizer=None):
    """Print the first few input/target pairs of a batch as text and IDs.

    Args:
        x: Batch of input token-ID sequences; each ``x[i]`` must provide
            ``.tolist()``.
        y: Batch of target token-ID sequences, indexed in step with ``x``.
        n_samples: Maximum number of samples to print; fewer are printed
            when the batch is smaller.
        tokenizer: Optional object exposing ``decode(list_of_ids)``. When
            omitted, tiktoken's "gpt2" encoding is used.
    """
    # Build the tokenizer once, not once per printed sample.
    if tokenizer is None:
        tokenizer = tiktoken.get_encoding("gpt2")

    for i in range(min(n_samples, len(x))):
        print(f"\nSample {i+1}:")

        # Decode and print the input sequence
        input_ids = x[i].tolist()
        input_text = tokenizer.decode(input_ids)
        print(f"Input text: {input_text}")
        print(f"Input encoding: {input_ids}")

        # Decode and print the target sequence
        target_ids = y[i].tolist()
        target_text = tokenizer.decode(target_ids)
        print(f"Target text: {target_text}")
        print(f"Target encoding: {target_ids}")

        print("-" * 50)

In [15]:
# When True, the batch loop prints decoded samples via inspect_batch.
INSPECT = True

In [16]:
# Embed a single batch to check the pipeline end to end.
for batch in dataloader:
    x, y = batch

    if INSPECT:
        # Visual inspection
        inspect_batch(x, y)

    # One vector per token ID in the batch.
    token_embeddings = token_embedding_layer(x)

    # Take the sequence length from the batch itself rather than from the
    # notebook-global `max_length`, so this cell stays correct if the
    # dataloader is rebuilt with a different window size.
    seq_len = x.shape[1]
    pos_embeddings = pos_embedding_layer(torch.arange(seq_len, device=x.device))

    # Positional vectors broadcast across the batch dimension.
    input_embeddings = token_embeddings + pos_embeddings

    # Only the first batch is needed for this check.
    break


Sample 1:
Input text:  mild left shoulder oste
Input encoding: [11607, 1364, 8163, 32674]
Target text:  left shoulder osteo
Target encoding: [1364, 8163, 32674, 78]
--------------------------------------------------

Sample 2:
Input text: , though it is
Input encoding: [11, 996, 340, 318]
Target text:  though it is unclear
Target encoding: [996, 340, 318, 10061]
--------------------------------------------------


In [17]:
# Shape of the combined token + position embeddings for the inspected batch.
print(input_embeddings.shape)

torch.Size([8, 4, 256])


# Undirected data selection

Explore using SIFT to select the most informative data when no specific downstream task is given.

In [2]:
import faiss
from activeft.sift import Retriever

In [3]:
# Before Test-Time
embeddings = torch.randn(100, 64)
index = faiss.IndexFlatIP(embeddings.size(1))
index.add(embeddings)
retriever = Retriever(index)

In [None]:
# At Test-Time, given query
query_embeddings = torch.randn(1, 64)
indices = retriever.search(query_embeddings, N=10, K=50)
data = embeddings[indices]