# Paper 16: A Simple Neural Network Module for Relational Reasoning
## Adam Santoro, David Raposo, David G.T. Barrett, et al., DeepMind (2017)

### Relation Networks (RN)

A plug-and-play neural network module for reasoning about relationships between objects. Key insight: explicitly compute a learned relation for every pair of objects, then aggregate the results.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations

np.random.seed(42)

## Relation Network Architecture

Core idea:
```
RN(O) = f_φ( Σ_{i,j} g_θ(o_i, o_j, q) )
```

- **g_θ**: Relation function (applied to every object pair, together with the query)
- **f_φ**: Aggregation function (maps the summed relations to the output)
- **O**: Set of objects
- **q**: Query/context
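To make the double sum concrete, here is a toy worked example under illustrative assumptions. The objects are scalars, $g_\theta(a, b) = a \cdot b$ is hand-picked (no query), and $f_\phi$ is the identity. With objects $1, 2, 3$, the sum runs over all $3^2 = 9$ ordered pairs, self-pairs included, and factorizes as $(1+2+3)^2 = 36$:

```python
from itertools import product

# Toy stand-ins (assumptions for illustration, not learned networks)
def g_toy(a, b):
    return a * b          # "relation" between two scalar objects

def f_toy(s):
    return s              # identity read-out

objects = [1.0, 2.0, 3.0]
pairs = list(product(objects, repeat=2))   # all ordered pairs (i, j), incl. i == j
rn_value = f_toy(sum(g_toy(a, b) for a, b in pairs))

print(len(pairs), rn_value)   # 9 36.0
```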

In [None]:
def relu(x):
    return np.maximum(0, x)

class MLP:
    """Simple multi-layer perceptron"""
    def __init__(self, input_dim, hidden_dims, output_dim):
        self.layers = []
        
        # Create layers
        dims = [input_dim] + hidden_dims + [output_dim]
        for i in range(len(dims) - 1):
            W = np.random.randn(dims[i+1], dims[i]) * 0.01
            b = np.zeros((dims[i+1], 1))
            self.layers.append((W, b))
    
    def forward(self, x):
        """Forward pass through MLP"""
        if len(x.shape) == 1:
            x = x.reshape(-1, 1)
        
        for i, (W, b) in enumerate(self.layers):
            x = np.dot(W, x) + b
            # ReLU for all but last layer
            if i < len(self.layers) - 1:
                x = relu(x)
        
        return x.flatten()

# Test MLP
mlp = MLP(input_dim=10, hidden_dims=[20, 20], output_dim=5)
test_input = np.random.randn(10)
output = mlp.forward(test_input)
print(f"MLP output shape: {output.shape}")

## Relation Network Module

In [None]:
class RelationNetwork:
    """
    Relation Network for reasoning about object relationships
    
    RN(O) = f_φ( Σ_{i,j} g_θ(o_i, o_j, q) )
    """
    def __init__(self, object_dim, query_dim, g_hidden_dims, f_hidden_dims, output_dim):
        """
        object_dim: dimension of each object representation
        query_dim: dimension of query/question
        g_hidden_dims: hidden dimensions for g_θ (relation function)
        f_hidden_dims: hidden dimensions for f_φ (aggregation function)
        output_dim: final output dimension
        """
        # g_θ: processes pairs of objects + query
        g_input_dim = object_dim * 2 + query_dim
        g_output_dim = g_hidden_dims[-1] if g_hidden_dims else 256
        self.g_theta = MLP(g_input_dim, g_hidden_dims[:-1], g_output_dim)
        
        # f_φ: processes aggregated relations
        f_input_dim = g_output_dim
        self.f_phi = MLP(f_input_dim, f_hidden_dims, output_dim)
    
    def forward(self, objects, query):
        """
        objects: list of object representations (each is a vector)
        query: query/context vector
        
        Returns: output vector
        """
        n_objects = len(objects)
        
        # Compute relations for all n^2 ordered pairs (i, j), including self-pairs i == j
        relations = []
        
        for i in range(n_objects):
            for j in range(n_objects):
                # Concatenate object pair + query
                pair_input = np.concatenate([objects[i], objects[j], query])
                
                # Apply g_θ to compute relation
                relation = self.g_theta.forward(pair_input)
                relations.append(relation)
        
        # Aggregate relations (sum)
        aggregated = np.sum(relations, axis=0)
        
        # Apply f_φ to get final output
        output = self.f_phi.forward(aggregated)
        
        return output

# Create relation network
rn = RelationNetwork(
    object_dim=8,
    query_dim=4,
    g_hidden_dims=[32, 32, 32],
    f_hidden_dims=[64, 32],
    output_dim=10  # e.g., 10 answer classes
)

# Test with sample objects
test_objects = [np.random.randn(8) for _ in range(5)]
test_query = np.random.randn(4)

output = rn.forward(test_objects, test_query)
print(f"\nRelation Network output: {output[:5]}...")
print(f"Output shape: {output.shape}")
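
The double loop in `forward` calls $g_\theta$ once per ordered pair. A common alternative builds all $n^2$ pair inputs at once with NumPy broadcasting, so that $g_\theta$ could run as one batched matrix product. This is sketched below as an assumption, not the notebook's method, and the helper `all_pair_inputs` is hypothetical:

```python
import numpy as np

def all_pair_inputs(objects, query):
    """Return an (n*n, 2*d + q) array whose row i*n + j is [o_i, o_j, query].

    Row order matches the notebook's nested loop (i outer, j inner).
    """
    O = np.asarray(objects)                           # (n, d)
    n, d = O.shape
    left = np.repeat(O, n, axis=0)                    # o_i repeated n times: (n*n, d)
    right = np.tile(O, (n, 1))                        # o_0..o_{n-1} cycled: (n*n, d)
    q = np.broadcast_to(query, (n * n, len(query)))   # same query for every pair
    return np.concatenate([left, right, q], axis=1)

rng = np.random.default_rng(0)
objs = [rng.standard_normal(8) for _ in range(5)]
query = rng.standard_normal(4)

X = all_pair_inputs(objs, query)
# Same rows the nested loop would build, in the same order
loop_rows = np.array([np.concatenate([objs[i], objs[j], query])
                      for i in range(5) for j in range(5)])
print(X.shape, np.allclose(X, loop_rows))   # (25, 20) True
```

With a batched MLP (weights applied as `X @ W.T + b`), the sum over pairs would then reduce to a single `.sum(axis=0)` over the rows of $g_\theta$'s output.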

## Sort-of-CLEVR Dataset

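A simplified generator in the spirit of the paper's Sort-of-CLEVR task. Each scene contains objects with unique colors and random shapes, sizes, and positions. Questions are either non-relational (about a single object) or relational (about how one object relates to the others, e.g. which object is closest).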

In [None]:
class SortOfCLEVR:
    """Generate Sort-of-CLEVR dataset"""
    def __init__(self):
        self.colors = ['red', 'blue', 'green', 'orange', 'yellow', 'purple']
        self.shapes = ['circle', 'square', 'triangle']
        self.sizes = ['small', 'large']
    
    def generate_scene(self, n_objects=6):
        """
        Generate a scene with objects
        Each object: (x, y, color_idx, shape_idx, size_idx)
        """
        objects = []
        used_colors = set()
        
        for i in range(n_objects):
            # Random position
            x = np.random.uniform(0, 1)
            y = np.random.uniform(0, 1)
            
            # Unique color
            available_colors = [c for c in range(len(self.colors)) if c not in used_colors]
            if not available_colors:
                break
            color_idx = np.random.choice(available_colors)
            used_colors.add(color_idx)
            
            # Random shape and size
            shape_idx = np.random.randint(len(self.shapes))
            size_idx = np.random.randint(len(self.sizes))
            
            objects.append({
                'x': x,
                'y': y,
                'color': color_idx,
                'shape': shape_idx,
                'size': size_idx
            })
        
        return objects
    
    def generate_question(self, scene, question_type='relational'):
        """
        Generate questions:
        - Non-relational: "What is the shape of the red object?"
        - Relational: "What is the shape of the object closest to the red object?"
        """
        if question_type == 'relational':
            # Pick a reference object by index (so it can be skipped below)
            ref_idx = np.random.randint(len(scene))
            ref_obj = scene[ref_idx]
            
            # Find the closest *other* object
            min_dist = float('inf')
            closest_obj = None
            for idx, obj in enumerate(scene):
                if idx == ref_idx:
                    continue
                dist = np.hypot(obj['x'] - ref_obj['x'], obj['y'] - ref_obj['y'])
                if dist < min_dist:
                    min_dist = dist
                    closest_obj = obj
            
            question = f"What is the shape of the object closest to the {self.colors[ref_obj['color']]} object?"
            answer = closest_obj['shape']
            
        else:  # non-relational
            # Pick a random object
            obj = scene[np.random.randint(len(scene))]
            question = f"What is the shape of the {self.colors[obj['color']]} object?"
            answer = obj['shape']
        
        return question, answer, question_type

# Generate sample scene
dataset = SortOfCLEVR()
scene = dataset.generate_scene(n_objects=6)

print("Generated scene:")
for i, obj in enumerate(scene):
    print(f"  Object {i}: {dataset.colors[obj['color']]:8s} "
          f"{dataset.shapes[obj['shape']]:8s} {dataset.sizes[obj['size']]:6s} "
          f"at ({obj['x']:.2f}, {obj['y']:.2f})")

# Generate questions
print("\nSample questions:")
for qtype in ['non-relational', 'relational', 'relational']:
    q, a, t = dataset.generate_question(scene, qtype)
    print(f"  [{t:15s}] {q}")
    print(f"  Answer: {dataset.shapes[a]}")
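
The relational answer depends on every other object's position, which is exactly what makes it relational. The sketch below vectorizes the same "closest object" rule using a hypothetical `closest_index` helper on an array of positions. The reference object's distance to itself is masked out with infinity:

```python
import numpy as np

def closest_index(positions, ref):
    """Index of the object nearest to positions[ref], excluding ref itself."""
    P = np.asarray(positions, dtype=float)        # (n, 2) array of (x, y)
    dists = np.linalg.norm(P - P[ref], axis=1)    # Euclidean distance to the reference
    dists[ref] = np.inf                           # an object is not its own neighbour
    return int(np.argmin(dists))

positions = [(0.1, 0.1), (0.9, 0.9), (0.2, 0.15), (0.5, 0.5)]
print(closest_index(positions, ref=0))   # 2
```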

## Visualize Scene

In [None]:
def visualize_scene(scene, dataset):
    """Visualize Sort-of-CLEVR scene"""
    fig, ax = plt.subplots(figsize=(10, 10))
    
    # Matplotlib marker for each shape; the color names are valid Matplotlib colors
    markers = {'circle': 'o', 'square': 's', 'triangle': '^'}
    
    for obj in scene:
        color = dataset.colors[obj['color']]
        shape = dataset.shapes[obj['shape']]
        size = 300 if obj['size'] == 1 else 150  # size index 1 = 'large'
        ax.scatter([obj['x']], [obj['y']], s=size, c=color, marker=markers[shape],
                   edgecolors='black', linewidths=2)
    
    ax.set_xlim(-0.1, 1.1)
    ax.set_ylim(-0.1, 1.1)
    ax.set_aspect('equal')
    ax.set_title('Sort-of-CLEVR Scene', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    plt.show()

visualize_scene(scene, dataset)

## Object Representation Encoder

In [None]:
def encode_object(obj, dataset):
    """
    Encode object as vector:
    [x, y, color_one_hot, shape_one_hot, size_one_hot]
    """
    # Position
    pos = np.array([obj['x'], obj['y']])
    
    # One-hot encodings
    color_oh = np.zeros(len(dataset.colors))
    color_oh[obj['color']] = 1
    
    shape_oh = np.zeros(len(dataset.shapes))
    shape_oh[obj['shape']] = 1
    
    size_oh = np.zeros(len(dataset.sizes))
    size_oh[obj['size']] = 1
    
    # Concatenate
    encoding = np.concatenate([pos, color_oh, shape_oh, size_oh])
    return encoding

def encode_question(question_text, ref_color, dataset):
    """
    Encode question as vector (simplified)
    In practice: use LSTM or embeddings
    """
    # One-hot for reference color
    color_oh = np.zeros(len(dataset.colors))
    if ref_color is not None:
        color_oh[ref_color] = 1
    
    # Question type (simplified: 1 for relational, 0 for non-relational)
    is_relational = 1.0 if 'closest' in question_text else 0.0
    
    return np.concatenate([color_oh, [is_relational]])

# Test encoding
obj_encoding = encode_object(scene[0], dataset)
print(f"Object encoding shape: {obj_encoding.shape}")
print(f"Object encoding: {obj_encoding}")

q_encoding = encode_question("Shape of object closest to red?", 0, dataset)
print(f"\nQuestion encoding shape: {q_encoding.shape}")

## Full Pipeline: Scene → Objects → RN → Answer

In [None]:
# Create relation network with correct dimensions
object_dim = 2 + len(dataset.colors) + len(dataset.shapes) + len(dataset.sizes)
query_dim = len(dataset.colors) + 1

rn_visual = RelationNetwork(
    object_dim=object_dim,
    query_dim=query_dim,
    g_hidden_dims=[64, 64, 32],
    f_hidden_dims=[64, 32],
    output_dim=len(dataset.shapes)  # Predict shape
)

# Encode scene
encoded_objects = [encode_object(obj, dataset) for obj in scene]

# Generate question
question, answer, qtype = dataset.generate_question(scene, 'relational')

# Extract reference color from question (simplified)
ref_color = None
for i, color in enumerate(dataset.colors):
    if color in question.lower():
        ref_color = i
        break

encoded_question = encode_question(question, ref_color, dataset)

# Run relation network
prediction = rn_visual.forward(encoded_objects, encoded_question)
predicted_shape = np.argmax(prediction)

print(f"Question: {question}")
print(f"True answer: {dataset.shapes[answer]}")
print(f"Predicted answer: {dataset.shapes[predicted_shape]}")
print("\n(The model is untrained, so this prediction is arbitrary.)")
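
Training is out of scope here, but the raw `prediction` vector holds unnormalized scores (logits) over the 3 shapes. The sketch below assumes a standard softmax classifier with cross-entropy loss; it does not reproduce the paper's exact training setup. It turns such logits into class probabilities and the loss a trainer would minimize:

```python
import numpy as np

def softmax(logits):
    z = logits - np.max(logits)      # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def cross_entropy(logits, target):
    """Negative log-probability of the correct class index `target`."""
    return -np.log(softmax(logits)[target])

logits = np.array([2.0, 0.5, -1.0])   # example scores for (circle, square, triangle)
probs = softmax(logits)
loss = cross_entropy(logits, target=0)
print(probs.round(3), round(float(loss), 3))
```

This notebook's MLPs use weights scaled by 0.01 and zero biases, so the untrained RN's logits are close to zero and its softmax output is close to uniform.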

## Visualize Relations Between Objects

In [None]:
# Compute pairwise distances (example of relations)
n_objects = len(scene)
distance_matrix = np.zeros((n_objects, n_objects))

for i in range(n_objects):
    for j in range(n_objects):
        dist = np.sqrt((scene[i]['x'] - scene[j]['x'])**2 + 
                      (scene[i]['y'] - scene[j]['y'])**2)
        distance_matrix[i, j] = dist

# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Scene with one line per unordered pair (darker line = closer objects)
for i, j in combinations(range(n_objects), 2):
    obj_i, obj_j = scene[i], scene[j]
    alpha = np.exp(-distance_matrix[i, j] * 2)  # closer objects -> higher alpha
    ax1.plot([obj_i['x'], obj_j['x']], [obj_i['y'], obj_j['y']],
             'k-', alpha=alpha, linewidth=1)

for obj in scene:
    color = dataset.colors[obj['color']]  # color names are valid Matplotlib colors
    ax1.scatter([obj['x']], [obj['y']], s=300, c=color, 
               edgecolors='black', linewidths=3, zorder=5)
    ax1.text(obj['x'], obj['y']-0.08, dataset.colors[obj['color']], 
            ha='center', fontsize=9, fontweight='bold')

ax1.set_xlim(-0.1, 1.1)
ax1.set_ylim(-0.2, 1.1)
ax1.set_aspect('equal')
ax1.set_title('Object Relations (spatial)', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Distance matrix
im = ax2.imshow(distance_matrix, cmap='viridis')
ax2.set_xlabel('Object', fontsize=12)
ax2.set_ylabel('Object', fontsize=12)
ax2.set_title('Pairwise Distances', fontsize=14, fontweight='bold')
plt.colorbar(im, ax=ax2, label='Distance')

plt.tight_layout()
plt.show()

print(f"\nThe RN above evaluates g_θ on all {n_objects**2} ordered pairs ({n_objects * (n_objects - 1)} if self-pairs are excluded)!")
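
Broadcasting can replace the double loop above. Subtracting a `(1, n, 2)` array from an `(n, 1, 2)` array yields all pairwise coordinate differences in one step. A sketch, with random positions standing in for a scene:

```python
import numpy as np

rng = np.random.default_rng(0)
P = rng.uniform(0, 1, size=(6, 2))            # 6 objects, (x, y) each

# diff[i, j] = P[i] - P[j] via broadcasting: (6, 1, 2) - (1, 6, 2) -> (6, 6, 2)
diff = P[:, None, :] - P[None, :, :]
D = np.linalg.norm(diff, axis=-1)             # (6, 6) Euclidean distance matrix

# Same result as the explicit double loop
D_loop = np.array([[np.hypot(*(P[i] - P[j])) for j in range(6)] for i in range(6)])
print(np.allclose(D, D_loop), np.allclose(D, D.T), np.allclose(np.diag(D), 0))
```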

## Permutation Invariance Test

In [None]:
# Test that RN is invariant to object order
test_objects = [np.random.randn(object_dim) for _ in range(4)]
test_query = np.random.randn(query_dim)

# Original order
output1 = rn_visual.forward(test_objects, test_query)

# Shuffled order
shuffled_objects = test_objects.copy()
np.random.shuffle(shuffled_objects)
output2 = rn_visual.forward(shuffled_objects, test_query)

# Check if outputs are the same
diff = np.linalg.norm(output1 - output2)

print("Permutation Invariance Test:")
print(f"Original output: {output1[:4]}...")
print(f"Shuffled output: {output2[:4]}...")
print(f"Difference: {diff:.10f}")
print(f"\n{'✓ PASSED: RN is permutation invariant' if diff < 1e-10 else '✗ FAILED: outputs differ after shuffling'}")

## Compare with Baseline (No Relational Reasoning)

In [None]:
class BaselineNetwork:
    """
    Baseline: just concatenate all objects + query, no explicit relations
    """
    def __init__(self, object_dim, query_dim, max_objects, output_dim):
        # Concatenate all objects + query
        input_dim = object_dim * max_objects + query_dim
        self.mlp = MLP(input_dim, [128, 64], output_dim)
        self.max_objects = max_objects
        self.object_dim = object_dim
    
    def forward(self, objects, query):
        # Pad or truncate to max_objects
        padded = []
        for i in range(self.max_objects):
            if i < len(objects):
                padded.append(objects[i])
            else:
                padded.append(np.zeros(self.object_dim))
        
        # Concatenate everything
        concat = np.concatenate(padded + [query])
        return self.mlp.forward(concat)

# Create baseline
baseline = BaselineNetwork(object_dim, query_dim, max_objects=10, output_dim=len(dataset.shapes))

# Test
baseline_output = baseline.forward(encoded_objects, encoded_question)

print("Baseline Network (no explicit relations):")
print(f"Output: {baseline_output}")
print(f"\nBaseline doesn't explicitly reason about pairs!")
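
The contrast with the RN is easiest to see in a stripped-down setting. This sketch uses toy maps with random weights, not the notebook's models, and sums a shared function over single objects; the RN sums over pairs, but the argument is the same. Summing a shared function gives the same result for any ordering. A weight matrix applied to the concatenation, by contrast, depends on which slot each object lands in:

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 4, 3
objs = rng.standard_normal((n, d))
perm = np.array([2, 0, 3, 1])                 # a fixed reordering of the objects

W_shared = rng.standard_normal((5, d))        # same weights applied to every object
W_concat = rng.standard_normal((5, n * d))    # one weight block per input slot

def sum_model(o):
    return np.tanh(o @ W_shared.T).sum(axis=0)   # order-free: sum over objects

def concat_model(o):
    return W_concat @ o.reshape(-1)              # order-dependent: slot-specific weights

print(np.allclose(sum_model(objs), sum_model(objs[perm])))        # True
print(np.allclose(concat_model(objs), concat_model(objs[perm])))  # False
```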

## Key Takeaways

### Relation Network (RN) Formula:

$$
\text{RN}(O) = f_\phi \left( \sum_{i,j} g_\theta(o_i, o_j, q) \right)
$$

Where:
- $O = \{o_1, o_2, ..., o_n\}$: Set of objects
- $g_\theta$: Relation function (MLP) - reasons about pairs
- $f_\phi$: Aggregation function (MLP) - combines relations
- $q$: Query/context (e.g., question)

### Key Properties:

1. **Explicit Pairwise Relations**: 
   - Considers all $n^2$ ordered pairs (as implemented above), or only the $\binom{n}{2}$ unordered pairs of distinct objects
   - Each pair processed independently by $g_\theta$

2. **Permutation Invariance**:
   - Sum aggregation → order doesn't matter
   - $\text{RN}(o_{\pi(1)}, \ldots, o_{\pi(n)}) = \text{RN}(o_1, \ldots, o_n)$ for any permutation $\pi$ of the objects

3. **Plug-and-play**:
   - Can be added to any architecture that produces a set of object representations
   - Objects can come from a CNN, an LSTM, etc.

### Architecture Details:

**For visual QA**:
```
Image → CNN → Feature maps → Objects (one per feature-map cell, tagged with its coordinates)
Question → LSTM → Query embedding
Objects + Query → RN → Answer
```
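
In the paper's visual pipeline, each cell of the final CNN feature map is treated as one object. Each object is tagged with its spatial coordinates so that $g_\theta$ can reason about position. The sketch below shows that conversion. It assumes a `(k, d, d)` feature map (random here in place of a real CNN) and coordinates scaled to $[-1, 1]$; the exact coordinate scheme is an assumption:

```python
import numpy as np

def feature_map_to_objects(fmap):
    """Turn a (k, d, d) feature map into d*d objects of size k + 2.

    Each object is one cell's k feature values plus its (row, col) coordinate,
    rescaled to [-1, 1] (an illustrative choice of coordinate tag).
    """
    k, d, d2 = fmap.shape
    assert d == d2, "expects a square feature map"
    cells = fmap.reshape(k, d * d).T                       # (d*d, k), row-major over cells
    coords = np.linspace(-1.0, 1.0, d)
    rows, cols = np.meshgrid(coords, coords, indexing="ij")
    tags = np.stack([rows.ravel(), cols.ravel()], axis=1)  # (d*d, 2)
    return np.concatenate([cells, tags], axis=1)           # (d*d, k + 2)

fmap = np.random.default_rng(0).standard_normal((24, 5, 5))  # example sizes: k=24, 5x5 grid
objects = feature_map_to_objects(fmap)
print(objects.shape)   # (25, 26)
```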

**For text**:
```
Support sentences → LSTM (each sentence separately) → Sentence embeddings → Objects
Query → Embedding
Objects + Query → RN → Answer
```

### Computational Complexity:

- **Pairs**: $O(n^2)$ where $n$ = number of objects
- **g_θ evaluations**: $n^2$ forward passes
- Can be expensive for large $n$
- Can use $i \neq j$ to exclude self-pairs → $n(n-1)$ pairs
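The three pair counts mentioned above differ only in whether self-pairs and both orderings are kept. A quick check for $n = 6$, the scene size used earlier, using `itertools`:

```python
from itertools import combinations, permutations, product
from math import comb

n = 6
ordered_with_self = list(product(range(n), repeat=2))    # (i, j) for all i, j: n^2
ordered_distinct = list(permutations(range(n), 2))        # i != j: n(n - 1)
unordered_distinct = list(combinations(range(n), 2))      # i < j: n(n - 1) / 2

print(len(ordered_with_self), len(ordered_distinct), len(unordered_distinct))  # 36 30 15
assert len(unordered_distinct) == comb(n, 2)
```

Keeping only $i < j$ halves the work, but $g_\theta$ then never sees $(o_j, o_i)$. That matters when the relation is asymmetric, such as "left of".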

### Results:

**Sort-of-CLEVR**:
- Relational questions: above 94% (CNN + RN) vs. 63% (CNN + MLP baseline)
- Non-relational questions: above 94% for both models

**CLEVR** (full dataset):
- 95.5% accuracy (superhuman: reported human accuracy is 92.6%)
- Previous best: 68.5%

**bAbI**:
- 18/20 tasks with single model
- Strong performance on relational reasoning tasks

### Why It Works:

1. **Inductive bias**: Explicitly models relations
2. **Data efficiency**: Structured computation → less data needed
3. **Interpretability**: Can visualize $g_\theta$ outputs
4. **Generalization**: Learns relational patterns

### Comparison with Other Approaches:

| Approach | Pairwise Relations | Permutation Invariant | Complexity |
|----------|-------------------|----------------------|------------|
| CNN | Implicit | ✗ | $O(n)$ |
| RNN/LSTM | Sequential | ✗ | $O(n)$ |
| Attention | Weighted pairs | ✓ | $O(n^2)$ |
| **RN** | **Explicit** | **✓** | **$O(n^2)$** |
| Graph NN | Explicit (edges) | ✓ | $O(\lvert E \rvert)$ |

### Extensions:

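- **Self-attention**: closely related; also scores every ordered pair, but weights the aggregation with softmax-normalized scores instead of a plain sum
- **Transformers**: stacks of self-attention layers, so they perform RN-like pairwise computation at every layer
- **Graph NNs**: generalize the RN from all pairs to the edges of a given graph (an RN acts like a GNN on a fully connected graph)
- **Relational recurrent networks**: combine relational reasoning with recurrence (e.g., relational memory)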
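
The connection to attention can be seen in a minimal single-head self-attention sketch. It uses standard scaled dot-product attention with random projection matrices, not anything from the RN paper. Like an RN, it scores every ordered pair $(i, j)$. Unlike the RN's plain sum, it normalizes those scores with a softmax and uses them to weight the aggregation. Without positional information it is permutation-equivariant: reordering the inputs reorders the outputs the same way.

```python
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention over n objects (rows of X)."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[1])          # (n, n): one score per ordered pair
    scores -= scores.max(axis=1, keepdims=True)     # stabilize the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)   # each row sums to 1
    return weights @ V                              # weighted sum over j for every i

rng = np.random.default_rng(0)
n, d = 5, 8
X = rng.standard_normal((n, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))

out = self_attention(X, Wq, Wk, Wv)
perm = rng.permutation(n)
out_perm = self_attention(X[perm], Wq, Wk, Wv)
print(np.allclose(out[perm], out_perm))   # True: outputs permute with the inputs
```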

### Limitations:

- $O(n^2)$ complexity (expensive for large $n$)
- Sum aggregation may lose information
- Requires object extraction (non-trivial for images)
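
The information-loss point can be made concrete with a deliberately simple case. Suppose $g_\theta$ produced scalar relation values (toy numbers below). Different multisets of relations can then share the same sum, so $f_\phi$ cannot tell them apart:

```python
import numpy as np

# Two hypothetical scenes whose pairwise relation values differ...
relations_a = np.array([0.0, 1.0, 1.0, 2.0])
relations_b = np.array([1.0, 1.0, 1.0, 1.0])

# ...but collapse to the same aggregate under sum
print(relations_a.sum() == relations_b.sum())   # True
print(relations_a.max() == relations_b.max())   # False: max separates this particular case
```

Other symmetric aggregators such as mean or max also keep permutation invariance but have collisions of their own. Here max happens to separate the two scenes.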

### Applications:

- Visual QA
- Physics prediction
- Multi-agent systems
- Graph reasoning
- Relational databases
- Any task with structured objects!