# core

> Recursive Self-Aggregation (RSA) - A general-purpose LLM aggregation algorithm using litellm

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *


In [None]:
#| export
from fastcore.all import *
from fastcore.test import *
from litellm import completion
import random, uuid

In [None]:
#| export
class RSACandidate:
    "One prompt/response pair produced in an RSA loop, plus the ids of the candidates it aggregates"
    def __init__(self, id:str, loop_id:int, prompt:str, response:str=None, parent_ids:list=None):
        # Keep every constructor argument as a same-named attribute
        store_attr()

    def __repr__(self):
        # One "label:value" section per attribute, multi-line values start on their own line
        sections = (
            f'id:{self.id}',
            f'loop_id:{self.loop_id}',
            f'prompt:\n{self.prompt}',
            f'response:\n{self.response}',
            f'parent_ids:\n{self.parent_ids}',
        )
        return '\n'.join(sections)

In [None]:
# A candidate is created without a response; the answer is attached afterwards
c = RSACandidate(id='c1', loop_id=0, prompt='Hi')
c.response = 'Hey'
test_eq((c.id, c.prompt), ('c1', 'Hi'))
c

id:c1
loop_id:0
prompt:
Hi
response:
Hey
parent_ids:
None

In [None]:
#| export
class RSA:
    "Recursive Self-Aggregation algorithm for LLM response aggregation"
    def __init__(
        self,
        task_prompt:str,  # The main task/question to solve
        agg_prompt:str=None,  # Custom aggregation prompt
        model:str='openrouter/google/gemini-3-flash-preview',  # LLM model to use
        M:int=8,  # Number of candidates per loop
        k:int=4,  # Number of candidates to aggregate
        loops:int=3,  # Number of aggregation loops
        history:list=None,  # History of all candidates; a new list is created when omitted
        temperature:float=1.0,  # LLM temperature
        n_workers:int=4  # Parallel workers
    ): 
        # Validate before storing anything so a rejected instance has no attributes set
        if not task_prompt: raise ValueError("task_prompt is required")
        store_attr()
        # Only create a list when none was supplied: an empty list handed in by the
        # caller must stay the very object that candidates get recorded in.
        if history is None: self.history = []
    
    def __repr__(self): return f'RSA(model={self.model!r}, \nM={self.M}, \nk={self.k}, \nloops={self.loops}, \nhistory={len(self.history)} candidates, \ntask_prompt={self.task_prompt})'

In [None]:

# Build a solver with the default settings and print its `__repr__` summary
a = RSA(task_prompt='A bat and ball cost $1.10 total. The bat costs $1 more than the ball. How much does the ball cost?')
print(a)

RSA(model='openrouter/google/gemini-3-flash-preview', 
M=8, 
k=4, 
loops=3, 
history=0 candidates, 
task_prompt=A bat and ball cost $1.10 total. The bat costs $1 more than the ball. How much does the ball cost?)


In [None]:
#| export
@patch
def _call_llm(self:RSA, prompt):
    "Send `prompt` to `self.model` as a single user message and return the text of the first choice"
    messages = [{"role": "user", "content": prompt}]
    reply = completion(model=self.model, messages=messages, temperature=self.temperature)
    first_choice = reply.choices[0]
    return first_choice.message.content

In [None]:
#|eval: false
# One live LLM call with the raw task prompt; the cell is marked `eval: false`,
# presumably because it needs network/API access
a._call_llm(a.task_prompt)

'The ball costs **5 cents** ($0.05).\n\n**Here is the algebraic breakdown:**\n\n1.  Let $x$ be the cost of the ball.\n2.  The bat costs $1 more than the ball, so the bat is $x + 1.00$.\n3.  The total cost is $1.10.\n\nSo:\n$x + (x + 1.00) = 1.10$\n$2x + 1.00 = 1.10$\n$2x = 0.10$\n**$x = 0.05$**\n\nIf the ball is **$0.05** and the bat is **$1.05**, the total is **$1.10**.'

In [None]:
#| export
@patch
def _agg_prompt(self:RSA, candidates: List[RSACandidate]) -> str:
    "Build the aggregation prompt: instructions, the task prompt, then every candidate's response"
    default_instructions = """You are given question with training examples and a test input.\nYou are also provided several candidate solutions. Some candidates may be incorrect...,\nAggregate/consider all the candidates and use their help to produce the improved correct solution..."""
    # Fall back to the built-in instructions without writing them onto the instance:
    # building a prompt string should not change the solver's configuration.
    instructions = self.agg_prompt or default_instructions
    header = [
        instructions,
        self.task_prompt,
        "\nCANDIDATE ANSWERS (may contain mistakes):",
    ]
    # Candidates are numbered from 1 in the order they were passed in
    answers = [f"---- Candidate {i} ----\n{cand.response}" for i, cand in enumerate(candidates, 1)]
    return "\n".join(header + answers + ["\nYour response:"])

In [None]:
# Two answered candidates are enough to see the layout of the aggregation prompt
c1, c2 = (RSACandidate(id=f'c{n}', loop_id=0, prompt='test', response=f'Answer {letter}') for n, letter in enumerate('AB', 1))

print(a._agg_prompt([c1, c2]))

You are given question with training examples and a test input.
You are also provided several candidate solutions. Some candidates may be incorrect...,
Aggregate/consider all the candidates and use their help to produce the improved correct solution...
A bat and ball cost $1.10 total. The bat costs $1 more than the ball. How much does the ball cost?

CANDIDATE ANSWERS (may contain mistakes):
---- Candidate 1 ----
Answer A
---- Candidate 2 ----
Answer B

Your response:


In [None]:
#| export
@patch
def get_prompts(self:RSA, loop_id, cands=None):
    """Create the `M` unanswered candidates for loop `loop_id`.

    Without prior candidates every new candidate carries the raw task prompt.
    Otherwise each new candidate gets an aggregation prompt built from `k`
    candidates sampled from `cands`, and records their ids as `parent_ids`.
    Raises `ValueError` when `k` is larger than the number of prior candidates.
    """
    # Materialise once: `random.sample` needs a sized, indexable population, and this
    # lets callers pass any iterable of candidates.
    pool = list(cands or [])
    if not pool:
        return L(RSACandidate(id=str(uuid.uuid4()), loop_id=loop_id, prompt=self.task_prompt) for _ in range(self.M))
    # Fail with the same exception type `random.sample` would use, but say what is wrong
    if self.k > len(pool):
        raise ValueError(f"k={self.k} exceeds the number of available candidates ({len(pool)})")
    sel_cands = L.range(self.M).map(lambda _: L(random.sample(pool, self.k)))
    return sel_cands.map(lambda x: RSACandidate(id=str(uuid.uuid4()), loop_id=loop_id, prompt=self._agg_prompt(x), parent_ids=x.attrgot('id')))

In [None]:
# Test loop 0: without prior candidates we get M candidates, and the first one
# carries the raw task prompt
cands = a.get_prompts(loop_id=0)
test_eq(len(cands), a.M)
test_eq(cands[0].prompt, a.task_prompt)

In [None]:
# Test loop 1+: a prior pool of 8 answered candidates still yields M new candidates
prior = L.range(8).map(lambda i: RSACandidate(id=str(uuid.uuid4()), loop_id=0, prompt='test', response=f'Answer {i}'))
cands = a.get_prompts(loop_id=1, cands=prior)
test_eq(len(cands), a.M)

In [None]:
# Inspect one aggregation prompt; which answers appear, and in what order, comes
# from the random sample drawn for this candidate
print(cands[0].prompt)

You are given question with training examples and a test input.
You are also provided several candidate solutions. Some candidates may be incorrect...,
Aggregate/consider all the candidates and use their help to produce the improved correct solution...
A bat and ball cost $1.10 total. The bat costs $1 more than the ball. How much does the ball cost?

CANDIDATE ANSWERS (may contain mistakes):
---- Candidate 1 ----
Answer 7
---- Candidate 2 ----
Answer 0
---- Candidate 3 ----
Answer 3
---- Candidate 4 ----
Answer 4

Your response:


In [None]:
#| export
@patch
def _run_loop(self:RSA, loop_id, pool=None):
    "Run one loop: build the prompts for `loop_id`, query the LLM for each, and return the answered candidates"
    prompts = self.get_prompts(loop_id, pool)
    # The LLM calls wait on a remote service, so use threads rather than fastcore's
    # default process pool; this also avoids pickling the bound method for workers.
    responses = parallel(self._call_llm, prompts.attrgot('prompt'), n_workers=self.n_workers, threadpool=True)
    # Responses are paired with prompts by position
    for p, r in zip(prompts, responses): p.response = r
    return prompts

In [None]:
#|eval: false
# Live run of loop 0: expect M candidates, each with a response filled in
cands = a._run_loop(loop_id=0)
test_eq(len(cands), a.M)
assert all(c.response is not None for c in cands)
# NOTE(review): assumes the model returns different text for two identical prompts;
# this could fail if sampling happens to produce the same answer — confirm
assert cands[0].response != cands[1].response

In [None]:
#| export
@patch
def run(self:RSA):
    "Execute the loops in order, record each loop's candidates in `history`, and return the last loop's candidates"
    pool = None
    # Loop 0 always runs and starts without a pool; each later loop aggregates the previous one
    for loop_id in range(max(1, self.loops)):
        pool = self._run_loop(loop_id, pool)
        self.history.extend(pool)
    return pool

In [None]:
#|eval: false
# End-to-end run with 2 loops: the final pool has the default M=8 candidates and
# the history holds both loops' candidates (2 x 8 = 16)
a = RSA(task_prompt='A bat and ball cost $1.10 total. The bat costs $1 more than the ball. How much does the ball cost?', loops=2)
result = a.run()
print(f"Final pool: {len(result)}, History: {len(a.history)}")

Final pool: 8, History: 16


In [None]:
#| hide
# Write the cells marked for export to the library module
import nbdev
nbdev.nbdev_export()