# rsa_dspy

> Recursive Self-Aggregation (RSA) implemented with DSPy modules, based on the paper at **https://rsa-llm.github.io/**

In [None]:
#| default_exp rsa_dspy

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import dspy
from fastcore.all import *
from fastcore.test import *
import uuid
from math import comb
from itertools import combinations
from fastprogress import progress_bar

In [None]:
#| export
class RSACandidate:
    "One member of an RSA population: the inputs for a single call plus the response and lineage attached to it"
    def __init__(
        self,
        id:str,  # Identifier of this candidate
        loop_id:int,  # Index of the loop this candidate belongs to
        task_prompt:str,  # The task/question text
        signature:type=None,  # Signature associated with this candidate
        candidates_str:str=None,  # Formatted text of the candidates to aggregate, if any
        response:str=None,  # Response attached to this candidate
        parent_ids:list=None,  # Ids of the candidates this one was built from
    ):
        # fastcore stores every constructor argument as a same-named attribute
        store_attr()

    def __repr__(self):
        return (
            f'id:{self.id}\n'
            f'loop_id:{self.loop_id}\n'
            f'task_prompt:\n{self.task_prompt}\n'
            f'response:\n{self.response}\n'
            f'parent_ids:\n{self.parent_ids}'
        )

In [None]:
#| hide
# Default signature for the initial solve step: a `task` goes in, a `response` comes out.
# NOTE(review): the docstring and the `desc` texts are presumably used by DSPy when building the
# prompt (see the DSPy signatures docs), so treat edits to them as behavior changes.
class TaskSolver(dspy.Signature):
    """Solve the given task/question."""
    task = dspy.InputField(desc="The task or question to solve")
    response = dspy.OutputField(desc="Your solution to the task")

# Default signature for the aggregation step: the original `task` plus a block of formatted
# `candidates` go in, a single improved `response` comes out.
# NOTE(review): the docstring and the `desc` texts are presumably used by DSPy when building the
# prompt (see the DSPy signatures docs), so treat edits to them as behavior changes.
class AggregateResponses(dspy.Signature):
    """Aggregate multiple candidate solutions into an improved answer."""
    task = dspy.InputField(desc="The original task/question")
    candidates = dspy.InputField(desc="Candidate solutions (may contain mistakes)")
    response = dspy.OutputField(desc="The improved aggregated solution")

In [None]:
#| export
class RSA:
    "Recursive Self-Aggregation algorithm using DSPy"
    def __init__(
        self,
        task_prompt:str,  # The main task/question to solve
        solver=None,  # task signature
        aggregator=None,  # aggregator signature
        N:int=4,  # Population size (candidates per loop)
        K:int=3,  # Number of candidates to aggregate
        loops:int=2,  # Number of aggregation loops
        history:list=None,  # History of all candidates
    ):
        # Raises ValueError when the task is missing or when there are fewer than N distinct
        # K-subsets of an N-sized pool to aggregate.
        if not task_prompt: raise ValueError("task_prompt is required")
        if comb(N, K) < N: raise ValueError(f"C({N},{K})={comb(N,K)} < N={N}; need C(N,K) >= N for aggregation loops")
        store_attr()
        # Always hold history as a fastcore `L`: a non-empty plain list used to be stored as-is,
        # which lacks the `L`-only methods (e.g. `.filter`) that are called on `self.history`.
        self.history = L(history) if history else L()
        # Both signatures are wrapped in chain-of-thought modules
        self.solve = dspy.ChainOfThought(solver)
        self.aggregate = dspy.ChainOfThought(aggregator)

    def __repr__(self): return f'RSA(N={self.N}, \nK={self.K}, \nloops={self.loops}, \nhistory={len(self.history)} candidates, \ntask_prompt={self.task_prompt})'

In [None]:
# Build a demo instance with the default signatures and show its repr
a = RSA(
    task_prompt='A bat and ball cost $1.10 total. The bat costs $1 more than the ball. How much does the ball cost?',
    solver=TaskSolver,
    aggregator=AggregateResponses,
)
print(a)

In [None]:
#| export
@patch
def _mk_candidate_str(self:RSA, candidates):
    "Format candidate responses as numbered sections (1-based) joined by newlines, for use as the aggregator's `candidates` input"
    sections = []
    for idx, cand in enumerate(candidates, start=1):
        sections.append(f"---- Candidate {idx} ----\n{cand.response}")
    return '\n'.join(sections)

In [None]:
# Two hand-made candidates to show the formatted aggregation text
c1, c2 = (
    RSACandidate(id=cid, loop_id=0, task_prompt='test', response=text)
    for cid, text in (('c1', 'Answer A'), ('c2', 'Answer B'))
)

print(a._mk_candidate_str([c1, c2]))

In [None]:
#| export
@patch
def get_prompts(self:RSA, loop_id, cands=None):
    "Build the prompts for `loop_id`: N fresh solver prompts when `cands` is empty/None, otherwise up to N randomly selected K-subsets of `cands` to aggregate"
    def _fresh_id(): return str(uuid.uuid4())

    # No prior candidates: every prompt is an independent attempt at the task
    if not cands:
        initial = [
            RSACandidate(id=_fresh_id(), loop_id=loop_id, task_prompt=self.task_prompt,
                         signature=self.solver, candidates_str=None)
            for _ in range(self.N)
        ]
        return L(initial)

    def _to_candidate(group):
        # `group` is one K-tuple of parent candidates
        return RSACandidate(
            id=_fresh_id(), loop_id=loop_id, task_prompt=self.task_prompt,
            signature=self.aggregator, candidates_str=self._mk_candidate_str(group),
            parent_ids=L(group).attrgot('id'))

    # Enumerate every K-subset, shuffle, and keep the first N
    all_groups = L(combinations(cands, self.K))
    picked = all_groups.shuffle()[:self.N]
    return picked.map(_to_candidate)

In [None]:
# Test loop 0: with no prior candidates, get_prompts returns N prompts carrying the solver signature
cands = a.get_prompts(loop_id=0)
test_eq(len(cands), a.N)
test_eq(cands[0].signature, a.solver)

In [None]:
# Loop 1+: given 8 prior candidates, get_prompts should still return exactly N prompts
prior = L(range(8)).map(
    lambda i: RSACandidate(id=str(uuid.uuid4()), loop_id=0, task_prompt='test', response=f'Answer {i}')
)
cands = a.get_prompts(loop_id=1, cands=prior)
test_eq(len(cands), a.N)
print(cands[0].task_prompt)

In [None]:
#| export
@patch
def _run_loop(self:RSA, loop_id, pool=None):
    "Execute one loop: build the prompts, run them through the solver (no prior pool) or the aggregator (prior pool given), and attach each result as the prompt's `response`"
    prompts = self.get_prompts(loop_id, pool)
    # Use the same emptiness test as `get_prompts` (`not cands`): with `pool is None`, an empty
    # non-None pool produced solver prompts (`candidates_str=None`) that were sent to the aggregator.
    if not pool:
        exec_pairs = [(self.solve, dict(task=p.task_prompt)) for p in prompts]
    else:
        exec_pairs = [(self.aggregate, dict(task=p.task_prompt, candidates=p.candidates_str)) for p in prompts]
    results = dspy.Parallel()(exec_pairs)
    # Results are matched to prompts by position
    # NOTE(review): assumes dspy.Parallel returns one result per pair, in input order — confirm
    for p, r in zip(prompts, results): p.response = r
    return prompts

## Configuration

RSA-DSPy uses [DSPy](https://dspy.ai/) for all LLM calls. Configure your LM globally before running the algorithm:

```python
dspy.configure(lm=dspy.LM('openrouter/google/gemini-3-flash-preview', temperature=1.0))
```

See [DSPy's LM documentation](https://dspy.ai/learn/language_models/) for supported providers.

In [None]:
#|eval: false
# Configure the LM with caching turned off — presumably so the N identical solver prompts are
# not all answered from one cached completion (the last assert expects distinct responses)
dspy.configure(lm=dspy.LM('openrouter/google/gemini-3-flash-preview', temperature=1.0, cache=False))

# Loop 0 against the configured LM: N candidates come back, each with a response attached
cands = a._run_loop(loop_id=0)
test_eq(len(cands), a.N)
assert all(c.response is not None for c in cands)
assert cands[0].response != cands[1].response

In [None]:
#|eval: false
# Display the response attached to the first candidate of the loop run above
cands[0].response

In [None]:
#| export
@patch
def run(self:RSA):
    "Execute every configured loop in order, appending each loop's candidates to `history`; returns the pool produced by the last loop (None if `loops` is 0)"
    bar = progress_bar(range(self.loops))
    current = None  # no pool yet, so the first loop generates fresh candidates
    for step in bar:
        bar.comment = f"Loop {step+1}"
        # Each loop consumes the previous loop's pool and produces the next one
        current = self._run_loop(step, current)
        self.history.extend(current)
    return current

In [None]:
#|eval: false
# End-to-end run with 2 loops; history accumulates the candidates of every loop,
# while `result` holds only the pool from the last loop
a = RSA(task_prompt='A bat and ball cost $1.10 total. The bat costs $1 more than the ball. How much does the ball cost?', solver=TaskSolver, aggregator=AggregateResponses, loops=2)
result = a.run()
print(f"Final pool: {len(result)}, History: {len(a.history)}")

In [None]:
#| export
@patch
def final_aggregate(self:RSA, method='llm', signature=None):
    "Reduce the last loop's candidates to one answer: `method='llm'` aggregates them with a single LLM call (using `signature` if given, else the configured aggregator); `method='random'` returns the stored response of one candidate picked at random. Raises ValueError for any other `method`."
    mode = method.lower()  # normalise once; matching is case-insensitive
    if mode not in ('llm', 'random'): raise ValueError(f"method must be 'llm' or 'random', got {method!r}")
    # Nothing generated yet: run the full algorithm first
    if not self.history: self.run()
    # Wrap in `L` so a history held as a plain list still supports `.filter`
    final_pool = L(self.history).filter(lambda c: c.loop_id == (self.loops-1))
    if mode == 'random': return final_pool.shuffle()[0].response
    agg = dspy.ChainOfThought(signature) if signature else self.aggregate
    return agg(task=self.task_prompt, candidates=self._mk_candidate_str(final_pool))

In [None]:
#|eval: false
# Test 'llm' aggregation: the aggregator module's output is returned directly
result = a.final_aggregate(method='llm')
assert isinstance(result, dspy.Prediction)
assert len(result.response) > 0
print(result)

# Test 'random' aggregation: returns the `response` stored on one final-loop candidate,
# which this test also expects to be a dspy.Prediction
result = a.final_aggregate(method='random')
assert isinstance(result, dspy.Prediction)
assert len(result.response) > 0

In [None]:
#| hide
# Run nbdev's export step (presumably writes the `#| export` cells to the library module)
import nbdev; nbdev.nbdev_export()