# core

> Recursive Self-Aggregation (RSA): a general-purpose LLM aggregation algorithm, implemented with litellm and based on the paper at https://rsa-llm.github.io/

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *


In [None]:
#| export
from fastcore.all import *
from fastcore.test import *
from litellm import completion
import random, uuid
from math import comb
from itertools import combinations
from fastprogress import progress_bar

In [None]:
#| export
class RSACandidate:
    "One RSA candidate: the prompt sent to the LLM, the response it produced, and the ids of the candidates it aggregated"
    def __init__(self, id:str, loop_id:int, prompt:str, response:str=None, parent_ids:list=None):
        # fastcore stores every constructor argument as an attribute of the same name
        store_attr()

    def __repr__(self):
        # One labelled field per entry; multi-line values (prompt/response/parents) start on their own line
        fields = (
            f'id:{self.id}',
            f'loop_id:{self.loop_id}',
            f'prompt:\n{self.prompt}',
            f'response:\n{self.response}',
            f'parent_ids:\n{self.parent_ids}',
        )
        return '\n'.join(fields)

In [None]:
# Smoke test: fields come from the constructor and can be set afterwards; the trailing bare `c` displays its repr
c = RSACandidate(id='c1', loop_id=0, prompt='Hi')
c.response = 'Hey'
test_eq(c.id, 'c1')
test_eq(c.prompt, 'Hi')
c

In [None]:
#| export
class RSA:
    "Recursive Self-Aggregation algorithm for LLM response aggregation"
    def __init__(
        self,
        task_prompt:str,  # The main task/question to solve
        agg_prompt:str=None,  # Custom aggregation prompt
        model:str='openrouter/google/gemini-3-flash-preview',  # LLM model to use
        N:int=4,  # Population size (candidates per loop)
        K:int=3,  # Number of candidates to aggregate
        loops:int=2,  # Number of aggregation loops
        history:list=None,  # History of all candidates (any iterable; wrapped in an `L`)
        temperature:float=1.0,  # LLM temperature
        n_workers:int=4  # Parallel workers
    ): 
        if not task_prompt: raise ValueError("task_prompt is required")
        # Zero/negative sizes would silently produce empty populations or candidate-free aggregation prompts
        if N < 1 or K < 1 or loops < 1: raise ValueError(f"N, K and loops must all be >= 1, got N={N}, K={K}, loops={loops}")
        # Each aggregation loop draws N distinct K-subsets of the N-candidate pool, so C(N,K) must be >= N
        if comb(N, K) < N: raise ValueError(f"C({N},{K})={comb(N,K)} < N={N}; need C(N,K) >= N for aggregation loops")
        store_attr()
        # Later code calls `L` methods (`extend`, `filter`) on history, so normalise it to an `L`;
        # an `L` passed in is kept as-is so the caller's object keeps receiving new candidates.
        if history is None: self.history = L()
        elif not isinstance(history, L): self.history = L(history)
        # NOTE(review): this default mentions "training examples and a test input" (ARC-style wording);
        # consider a task-agnostic default for general-purpose use.
        if not self.agg_prompt: self.agg_prompt = """You are given question with training examples and a test input.\nYou are also provided several candidate solutions. Some candidates may be incorrect\nAggregate/consider all the candidates and use their help to produce the improved correct solution"""
    
    def __repr__(self): return f'RSA(model={self.model!r}, \nN={self.N}, \nK={self.K}, \nloops={self.loops}, \nhistory={len(self.history)} candidates, \ntask_prompt={self.task_prompt})'

In [None]:

# Example instance on the bat-and-ball puzzle with default settings; the following cells reuse `a`
a = RSA(task_prompt='A bat and ball cost $1.10 total. The bat costs $1 more than the ball. How much does the ball cost?')
print(a)

In [None]:
#| export
@patch
def _call_llm(self:RSA, prompt, **kwargs):
    "Send `prompt` as a single user message to `self.model` and return the text of the first choice"
    # Instance-level defaults; explicit kwargs (e.g. `temperature`, `num_retries`, `response_format`)
    # override them instead of colliding with them and raising a duplicate-keyword TypeError.
    params = dict(temperature=self.temperature, num_retries=3)
    params.update(kwargs)
    response = completion(
        model=self.model,
        messages=[{"role": "user", "content": prompt}],
        **params
    )
    return response.choices[0].message.content

In [None]:
#|eval: false
# Live LLM call (cell is marked `#|eval: false`): send the bare task prompt once and display the reply text
a._call_llm(a.task_prompt)

## Configuration

RSA uses [litellm](https://docs.litellm.ai/) for LLM calls, which automatically reads API keys from environment variables:

- `OPENAI_API_KEY` for OpenAI models
- `ANTHROPIC_API_KEY` for Anthropic models  
- `OPENROUTER_API_KEY` for OpenRouter models
- etc.

You can also set a custom endpoint globally:

```python
import litellm
litellm.api_base = "https://your-endpoint.com/v1"
```

See [litellm's provider docs](https://docs.litellm.ai/docs/providers) for the full list of supported providers and their environment variables.

In [None]:
#| export
@patch
def _build_agg_prompt(self:RSA, candidates: list[RSACandidate]) -> str:
    "Assemble an aggregation prompt: instructions, the task, each candidate's response under a numbered header, then a response cue"
    header = [self.agg_prompt, self.task_prompt, "\nCANDIDATE ANSWERS (may contain mistakes):"]
    # Candidates are numbered from 1 in the order given
    sections = [f"---- Candidate {idx} ----\n{cand.response}" for idx, cand in enumerate(candidates, start=1)]
    return "\n".join([*header, *sections, "\nYour response:"])

In [None]:
# Two hand-made candidates to inspect the aggregation prompt layout without calling the LLM
c1 = RSACandidate(id='c1', loop_id=0, prompt='test', response='Answer A')
c2 = RSACandidate(id='c2', loop_id=0, prompt='test', response='Answer B')

print(a._build_agg_prompt([c1, c2]))

In [None]:
#| export
def _sample_combinations(n, k, m, max_enum=10_000):
    "Return `min(m, C(n,k))` distinct k-combinations of `range(n)` (as sorted index tuples) in random order"
    total = comb(n, k)
    m = min(m, total)
    # Small spaces: enumerate everything and sample exactly
    if total <= max_enum: return random.sample(list(combinations(range(n), k)), m)
    # Large spaces: rejection-sample instead of materialising all C(n,k) tuples; with m << total duplicates are rare
    seen, picks = set(), []
    while len(picks) < m:
        idxs = tuple(sorted(random.sample(range(n), k)))
        if idxs not in seen:
            seen.add(idxs)
            picks.append(idxs)
    return picks

@patch
def get_prompts(self:RSA, loop_id, cands=None):
    """Generate candidate prompts for loop `loop_id`.

    With no `cands`, return N candidates that each carry the raw task prompt.
    Otherwise return up to N candidates, each aggregating a distinct, randomly chosen K-subset of `cands`
    (at most C(len(cands), K) of them), with `parent_ids` set to the ids of that subset.
    Raises ValueError when `cands` is non-empty but has fewer than K candidates."""
    if not cands: return L(RSACandidate(id=str(uuid.uuid4()), loop_id=loop_id, prompt=self.task_prompt) for _ in range(self.N))
    pool = list(cands)
    # Fewer than K candidates gives no subsets at all, which would silently yield an empty loop
    if len(pool) < self.K: raise ValueError(f"need at least K={self.K} candidates to aggregate, got {len(pool)}")
    groups = L([pool[i] for i in idxs] for idxs in _sample_combinations(len(pool), self.K, self.N))
    return groups.map(lambda x: RSACandidate(id=str(uuid.uuid4()), loop_id=loop_id, prompt=self._build_agg_prompt(x), parent_ids=L(x).attrgot('id')))

In [None]:
# Test loop 0
# With no prior candidates, get_prompts returns N candidates whose prompt is the raw task prompt
cands = a.get_prompts(loop_id=0)
test_eq(len(cands), a.N)
test_eq(cands[0].prompt, a.task_prompt)

In [None]:
# Test loop 1+ (with prior candidates)
# 8 prior candidates give C(8, K) subsets to draw from (56 for the default K=3), so N aggregation prompts are expected
prior = L(RSACandidate(id=str(uuid.uuid4()), loop_id=0, prompt='test', response=f'Answer {i}') for i in range(8))
cands = a.get_prompts(loop_id=1, cands=prior)
test_eq(len(cands), a.N)

In [None]:
# Inspect one aggregation prompt built from the prior candidates of the previous cell
print(cands[0].prompt)

In [None]:
#| export
@patch
def _run_loop(self:RSA, loop_id, pool=None):
    "Execute one loop: build this loop's prompts from `pool`, query the LLM for each in parallel, and attach the responses"
    prompts = self.get_prompts(loop_id, pool)
    # LLM calls are network-bound, so run them on threads: this avoids pickling `self` (and its growing
    # history) into worker processes, and keeps process-global litellm settings (e.g. `litellm.api_base`)
    # visible to every call regardless of the multiprocessing start method.
    responses = parallel(self._call_llm, prompts.attrgot('prompt'), n_workers=self.n_workers, threadpool=True)
    for p, r in zip(prompts, responses): p.response = r
    return prompts

In [None]:
#|eval: false
# Live loop 0 (`#|eval: false`): every candidate should get a response, and sampling should make responses differ
cands = a._run_loop(loop_id=0)
test_eq(len(cands), a.N)
assert all(c.response is not None for c in cands)
assert cands[0].response != cands[1].response

In [None]:
#| export
@patch
def run(self:RSA):
    "Run `self.loops` RSA loops, recording each loop's candidates in `self.history`, and return the final loop's candidates"
    # NOTE: history is only ever extended here, so calling `run` again appends a second run whose loop_ids overlap the first
    bar = progress_bar(range(self.loops))
    current = None  # no pool for loop 0, so get_prompts starts from the raw task prompt
    for loop_id in bar:
        bar.comment = f"Loop {loop_id+1}"
        current = self._run_loop(loop_id, current)
        self.history.extend(current)
    return current

In [None]:
#|eval: false
# Live end-to-end run (`#|eval: false`): `result` is the final loop's pool, `a.history` holds every loop's candidates
a = RSA(task_prompt='A bat and ball cost $1.10 total. The bat costs $1 more than the ball. How much does the ball cost?', loops=2)
result = a.run()
print(f"Final pool: {len(result)}, History: {len(a.history)}")

In [None]:
#| export
@patch
def aggregate(self:RSA, method='llm', final_agg_prompt=None, response_format=None):
    """Final aggregation over the last loop's candidates.

    `method='llm'`: one LLM call over all final-loop candidates, using `final_agg_prompt` (or `self.agg_prompt`)
    and an optional structured `response_format`; returns `(prompt, result)`.
    `method='random'`: returns `('', response)` for one randomly chosen final-loop candidate.
    Runs the algorithm first if `self.history` is empty.
    Raises ValueError for an unknown `method`, or when history has no candidates with loop_id == self.loops-1."""
    m = method.lower()
    if m not in ('llm', 'random'): raise ValueError(f"method must be 'llm' or 'random', got {method!r}")
    if not self.history: self.run()
    final_loop = self.loops - 1
    # Wrap in `L` so a plain-list history still supports `.filter`
    candidates = L(self.history).filter(lambda x: x.loop_id == final_loop)
    # Fail clearly instead of an IndexError ('random') or a wasted LLM call with no candidates ('llm')
    if not candidates: raise ValueError(f"history has no candidates from the final loop (loop_id={final_loop})")
    if m == 'random': return '', random.choice(candidates).response
    custom_agg_prompt = final_agg_prompt or self.agg_prompt
    responses = '\n'.join(f"---- Candidate {i} ----\n{c.response}" for i, c in enumerate(candidates, 1))
    prompt = f"{custom_agg_prompt}\n\nPrompt:\n{self.task_prompt}\n\nCANDIDATE ANSWERS:\n{responses}\n\nProvide the best aggregated answer:"
    # Only forward response_format when given, so the provider's default output mode is otherwise untouched
    extra = {'response_format': response_format} if response_format else {}
    return prompt, self._call_llm(prompt, **extra)

In [None]:
#|eval: false
# Live `aggregate` checks (`#|eval: false`) on the instance `a` created in the `run` cell
# Test with custom aggregation prompt
custom_prompt = "Combine these answers into one final answer:"
prompt, result = a.aggregate(final_agg_prompt=custom_prompt)
assert custom_prompt in prompt
assert isinstance(result, str)
assert len(result) > 0

# Test with response_format (structured output)
from pydantic import BaseModel

class Answer(BaseModel):
    answer: str
    confidence: float

prompt, result = a.aggregate(response_format=Answer)
assert isinstance(result, str)  # litellm returns JSON string
print(prompt, result)

In [None]:
#|eval: false
# Compare both aggregation methods: 'llm' returns the prompt it sent, 'random' returns '' in the prompt slot
# Test 'llm' aggregation
prompt, result = a.aggregate(method='llm')
assert isinstance(prompt, str)
assert isinstance(result, str)
assert len(result) > 0
assert 'CANDIDATE ANSWERS' in prompt

# Test 'random' aggregation
prompt, result = a.aggregate(method='random')
test_eq(prompt, '')
assert isinstance(result, str)
assert len(result) > 0

In [None]:
#| hide
# Write this notebook's `#| export` cells out to the library module
import nbdev
nbdev.nbdev_export()

How can I improve mine?

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->

🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠

Here are the key improvements to consider, roughly in order of impact:

1. **Different temperatures per phase** — Use a higher temperature (e.g., 1.0-1.2) for initial generation to maximize diversity, then a lower temperature (e.g., 0.7) for aggregation steps, where you want more focused synthesis.

2. **Tracking diversity** — The paper notes diversity should decrease over loops as errors get pruned. You could add a simple diversity metric (e.g., embedding similarity between candidates) to monitor this and potentially adjust parameters.

3. **Smarter candidate selection** — Instead of random `shuffle()[:N]` for selecting which K-combinations to use, you could prioritize combinations that maximize diversity (picking candidates that are more different from each other).

4. **Configurable aggregation prompt per loop** — The paper suggests the aggregation task may benefit from different prompting strategies at different stages. You could allow `agg_prompt` to be a callable that takes `loop_id`.

5. **Early stopping** — If all candidates converge to very similar answers before `loops` is reached, you could stop early.

Which of these interests you most? I can help you think through the implementation details.