# Networks

This guide provides technical reference for implementing and customizing networks in Starsim. For learning-oriented examples, see the Networks tutorial.

## Network architecture overview

Starsim networks inherit from the base `Network` class and manage collections of edges (connections between agents). All networks must implement methods for creating, updating, and removing connections between agents.

### Key network components
- **Edges**: The fundamental unit of a network; each edge connects two agents and carries transmission parameters
- **Parameters**: Network-specific settings that control connection behavior
- **State management**: Tracking which agents can form connections
- **Update methods**: Logic for adding/removing connections over time
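
As a rough mental model of the edge concept described above, an edge list can be pictured as parallel arrays with one entry per connection. The sketch below uses plain NumPy; the array names (`p1`, `p2`, `beta`) and the helper `contacts_of` are illustrative assumptions, not a description of Starsim's internal storage.

```python
import numpy as np

# Illustrative edge list stored as parallel arrays (one entry per edge)
p1 = np.array([0, 0, 2, 3])            # first agent in each edge
p2 = np.array([1, 2, 3, 4])            # second agent in each edge
beta = np.array([1.0, 1.0, 0.5, 1.0])  # per-edge transmission multiplier

def contacts_of(uid, p1, p2):
    """Return the sorted unique partners of `uid`, treating edges as undirected."""
    partners = np.concatenate([p2[p1 == uid], p1[p2 == uid]])
    return np.unique(partners)

print(contacts_of(2, p1, p2))  # agents sharing an edge with agent 2
```

Storing edges as arrays rather than per-agent lists keeps lookups like this vectorized.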

## Network class hierarchy

```
Network (base class)
├── RandomNet - Random connections between agents
├── MFNet - Male-female sexual partnerships  
├── MSMNet - Male-male sexual partnerships
├── MaternalNet - Mother-child connections during pregnancy
└── Custom networks (inherit from above)
```

## Core implementation methods

| Method | Purpose | When to override |
|--------|---------|------------------|
| `__init__()` | Initialize parameters and state | Always for custom networks |
| `init_pre()` | Pre-simulation initialization | Rarely |
| `init_post()` | Initialization once the sim's people exist, before the run starts | When you need access to people |
| `add_pairs()` | Create new connections | Key method for custom logic |
| `remove_pairs()` | Remove existing connections | For dynamic networks |
| `update()` | Main update method each timestep | Rarely - calls add/remove pairs |
| `find_contacts()` | Query network connections | Override for custom queries |

## Mixing Pools

Mixing pools simulate well-mixed transmission between groups rather than individual agent-to-agent connections. They are computationally efficient for large populations and ideal when you have contact matrices from epidemiological studies.

### When to use mixing pools
- **Large populations**: More efficient than individual contact networks
- **Group-based mixing**: Age groups, risk groups, geographic regions
- **Contact matrices**: When you have empirical mixing data (e.g., Prem et al.)
- **Homogeneous mixing assumptions**: Within-group transmission is well-mixed

### Mixing pool architecture

#### Single mixing pool (`MixingPool`)
Models transmission from a source group to a destination group:

```python
mp = ss.MixingPool(
    diseases = 'sir',  # Which diseases use this pool
    src = lambda sim: sim.people.age < 15,  # Source group (children)  
    dst = ss.AgeGroup(low=15, high=None),   # Destination group (adults)
    contacts = ss.poisson(lam=2),           # Contact distribution
    beta = ss.TimeProb(0.2),               # Transmission probability
)
```

#### Multiple mixing pools (`MixingPools`)
Models transmission between multiple groups using contact matrices:

```python
# Define age groups
groups = ss.ndict([ss.AgeGroup(low=low, high=high) 
                   for low, high in zip([0,20,40,60], [20,40,60,100])])

# Contact matrix (source=rows, destination=columns)
contact_matrix = np.array([
    [5, 2, 1, 0],  # Children -> [Children, Young adults, Adults, Elderly]
    [2, 4, 3, 1],  # Young adults -> [...]
    [1, 3, 4, 2],  # Adults -> [...]
    [0, 1, 2, 3],  # Elderly -> [...]
])

mps = ss.MixingPools(
    src = groups,
    dst = groups, 
    contacts = contact_matrix,
    beta = ss.TimeProb(0.15),
)
```

### Transmission mechanics

Mixing pools calculate transmission using average infectious levels:

1. **Transmission potential**: `trans = mean(infectious * rel_trans)` for source group
2. **Acquisition potential**: `acq = contacts * susceptible * rel_sus` for each destination agent  
3. **Transmission probability**: `p = beta * trans * acq`

This differs from contact networks where transmission occurs on specific edges between individuals.
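
The three steps above can be worked through numerically. This is a minimal NumPy sketch of the formulas as written in this section, with made-up values for a five-agent source group and a three-agent destination group; it does not reproduce Starsim's actual code path.

```python
import numpy as np

beta = 0.2  # per-contact transmission probability for this timestep

# Source group: 1 = infectious, 0 = not; rel_trans scales infectiousness
src_infectious = np.array([1, 0, 0, 1, 0])
src_rel_trans = np.array([1.0, 1.0, 1.0, 0.5, 1.0])
trans = np.mean(src_infectious * src_rel_trans)  # (1.0 + 0.5) / 5 = 0.3

# Destination group: contacts per agent; 1 = susceptible, 0 = not
dst_contacts = np.array([2, 0, 3])
dst_susceptible = np.array([1, 1, 0])
dst_rel_sus = np.array([1.0, 1.0, 1.0])
acq = dst_contacts * dst_susceptible * dst_rel_sus  # [2, 0, 0]

# Per-agent infection probability; only the first agent has both
# contacts and susceptibility, giving roughly 0.2 * 0.3 * 2 = 0.12
p = beta * trans * acq
print(p)
```

Note that an agent with zero contacts or no susceptibility gets probability zero regardless of how infectious the source group is.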

### Implementation patterns

#### Pattern 1: Age-structured mixing
```python
# Define age bins
age_bins = np.arange(0, 81, 5)  # Bin edges 0, 5, ..., 80, giving sixteen 5-year groups
groups = ss.ndict([ss.AgeGroup(low=low, high=high) 
                   for low, high in zip(age_bins[:-1], age_bins[1:])])

# Load contact matrix (e.g., from Prem et al.)
# `load_contact_matrix` is a placeholder for your own loader; it is not defined in this guide
contact_matrix = load_contact_matrix('country_name')

mps = ss.MixingPools(
    src=groups, dst=groups,
    contacts=contact_matrix,
    beta=ss.TimeProb(0.1)
)
```

#### Pattern 2: Multi-attribute mixing  
```python
# Mixing by both age and risk
risk_levels = ['low', 'medium', 'high']
age_groups = ['young', 'adult', 'elderly']

# Create combined groups
combined_groups = {}
for risk in risk_levels:
    for age in age_groups:
        key = f'{risk}_{age}'
        combined_groups[key] = lambda sim, r=risk, a=age: (
            (sim.people.risk_level == r) & (sim.people.age_group == a)
        )

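# Define block-structured contact matrix
# Rows/columns must follow the key order above (age groups nested within risk levels)
# `create_block_contact_matrix` is a placeholder for your own function; it is not defined in this guide
contact_matrix = create_block_contact_matrix(risk_levels, age_groups)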
```
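
The helper `create_block_contact_matrix` is not defined in this guide. One possible way to build such a matrix, assuming that mixing by age and mixing by risk are independent (a modeling assumption, not something the source states), is a Kronecker product of two smaller matrices. All numeric values below are hypothetical.

```python
import numpy as np

risk_levels = ['low', 'medium', 'high']
age_groups = ['young', 'adult', 'elderly']

# Hypothetical mixing values, for illustration only
age_mix = np.array([
    [4.0, 2.0, 0.5],
    [2.0, 3.0, 1.0],
    [0.5, 1.0, 2.0],
])
risk_mix = np.array([
    [1.0, 0.2, 0.1],
    [0.2, 1.0, 0.5],
    [0.1, 0.5, 2.0],
])

# Kronecker product: with n_age = 3, entry [r1*n_age + a1, r2*n_age + a2]
# equals risk_mix[r1, r2] * age_mix[a1, a2]
contact_matrix = np.kron(risk_mix, age_mix)

# Labels in the same order as the nested loop (risk outer, age inner)
labels = [f'{r}_{a}' for r in risk_levels for a in age_groups]
i, j = labels.index('high_young'), labels.index('medium_adult')
print(contact_matrix.shape, contact_matrix[i, j])
```

The argument order of `np.kron` matters: putting `risk_mix` first matches keys generated with risk as the outer loop.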

#### Pattern 3: Dynamic group membership
```python
class DynamicMixingPools(ss.MixingPools):
    def update_pre(self):
        """Update group membership before transmission"""
        # Recalculate group memberships based on current states
        self.src_uids = {
            'school': ss.uids(self.sim.people.in_school),
            'work': ss.uids(self.sim.people.employed),
            'home': ss.uids(self.sim.people.alive),  # Everyone has household contacts
        }
        self.dst_uids = self.src_uids.copy()
```

### Performance optimization

#### Efficient group calculations
```python
# Good: Use AgeGroup class for caching
age_groups = ss.ndict([ss.AgeGroup(low=i*10, high=(i+1)*10) for i in range(8)])

# Less efficient: Lambda functions recalculate each time  
age_groups = {f'age_{i}': lambda sim, i=i: (sim.people.age >= i*10) & (sim.people.age < (i+1)*10)
              for i in range(8)}
```

#### Contact matrix considerations
- Pre-calculate matrices when possible
- Use sparse matrices for mostly-zero contact patterns
- Consider symmetry constraints where appropriate
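
One symmetry constraint often applied to contact matrices is reciprocity: the total number of contacts group i makes with group j should equal the total that j makes with i. The sketch below assumes each row holds mean contacts per person in that row's group; the function name `enforce_reciprocity` is hypothetical, and whether this adjustment suits your data is a modeling decision.

```python
import numpy as np

def enforce_reciprocity(contacts, group_sizes):
    """Adjust per-person contact rates so total contacts between groups balance."""
    contacts = np.asarray(contacts, dtype=float)
    n = np.asarray(group_sizes, dtype=float)
    totals = contacts * n[:, None]     # total contacts from each row group
    totals = (totals + totals.T) / 2   # average the two directions
    return totals / n[:, None]         # convert back to per-person rates

raw = np.array([
    [5.0, 2.0],
    [1.0, 4.0],
])
sizes = np.array([100, 300])
balanced = enforce_reciprocity(raw, sizes)
print(balanced)
```

After balancing, `balanced * sizes[:, None]` is a symmetric matrix even though `balanced` itself generally is not.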

### Debugging mixing pools

#### Common issues
1. **Mismatched dimensions**: Contact matrix size must match number of groups
2. **Empty groups**: Check that source/destination groups are non-empty
3. **Beta scaling**: Remember to use `ss.TimeProb()` for time unit conversion
4. **Group overlap**: Ensure groups don't double-count individuals
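
Several of the issues above can be checked before running a simulation. The following is a minimal sketch that works on boolean membership masks (one per group) rather than on Starsim objects; the function `check_mixing_inputs` and the example groups are hypothetical.

```python
import numpy as np

def check_mixing_inputs(contacts, src_masks, dst_masks):
    """Return a list of human-readable problems (empty if none are found)."""
    problems = []
    contacts = np.asarray(contacts)
    expected = (len(src_masks), len(dst_masks))
    if contacts.shape != expected:
        problems.append(f'contact matrix shape {contacts.shape} != {expected}')
    for kind, masks in [('src', src_masks), ('dst', dst_masks)]:
        for name, mask in masks.items():
            if not np.any(mask):
                problems.append(f'{kind} group "{name}" is empty')
        # Count how many groups of this kind each agent belongs to
        membership = np.sum([np.asarray(m, dtype=int) for m in masks.values()], axis=0)
        n_overlap = int(np.sum(membership > 1))
        if n_overlap:
            problems.append(f'{n_overlap} agent(s) appear in multiple {kind} groups')
    return problems

age = np.array([3, 12, 25, 62, 70])
groups = {
    'infant': age < 1,                  # empty in this population
    'child': age < 15,
    'adult': (age >= 15) & (age < 65),
    'senior': age >= 60,                # overlaps with 'adult' for the 62-year-old
}
problems = check_mixing_inputs(np.ones((3, 3)), groups, groups)
for problem in problems:
    print(problem)
```

Overlap is not always an error (for example, household and workplace pools can legitimately share agents), so treat that message as a prompt to double-check rather than a hard failure.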

#### Diagnostic tools
```python
# Check group sizes
for name, group_func in groups.items():
    size = group_func(sim).sum()
    print(f'{name}: {size} agents')

# Visualize contact matrix
plt.imshow(contact_matrix, cmap='viridis')
plt.colorbar(label='Contacts per day')

# Track transmission by group (custom analyzer)
class MixingPoolAnalyzer(ss.Analyzer):
    def step(self):
        # Track new infections by source group
        # Implementation depends on specific needs
        pass
```

## Performance considerations

### Efficient pairing algorithms
- Use boolean indexing: `eligible = people.alive & condition`
- Batch operations when possible
- Avoid loops over individual agents
- Pre-calculate weights/probabilities outside loops
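
The points above can be combined into a loop-free random pairing step. This sketch uses a NumPy random generator for simplicity; in a real Starsim network you would likely draw from Starsim's own distributions instead, and the array names here are illustrative.

```python
import numpy as np

rng = np.random.default_rng(seed=1)

alive = np.array([True, True, False, True, True, True, True])
condition = np.array([True, True, True, True, False, True, True])

# Boolean indexing selects all eligible agents at once
eligible = np.flatnonzero(alive & condition)

# Shuffle once, then split into two halves to form pairs
shuffled = rng.permutation(eligible)
n_pairs = len(shuffled) // 2
p1 = shuffled[:n_pairs]
p2 = shuffled[n_pairs:2 * n_pairs]  # any odd agent left over stays unpaired
print(list(zip(p1.tolist(), p2.tolist())))
```

Because each agent appears at most once in the shuffled array, no agent is paired with itself or assigned two partners in the same step.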

### Memory management
- Remove expired edges regularly
- Use appropriate data types (boolean arrays vs integer arrays)
- Consider edge limits for very large populations
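
Removing expired edges can be done with a single boolean mask applied to every edge array. The sketch below assumes edges are stored as a dictionary of parallel arrays with a remaining-duration field `dur`; both the layout and the function `step_and_prune` are assumptions for illustration.

```python
import numpy as np

edges = dict(
    p1=np.array([0, 1, 2, 3]),
    p2=np.array([4, 5, 6, 7]),
    dur=np.array([0.5, 0.0, 2.0, -1.0]),  # remaining duration of each edge
)
dt = 0.5  # timestep length, in the same units as `dur`

def step_and_prune(edges, dt):
    """Advance edge durations by one timestep and drop edges that have expired."""
    remaining = edges['dur'] - dt
    keep = remaining > 0
    return dict(p1=edges['p1'][keep], p2=edges['p2'][keep], dur=remaining[keep])

edges = step_and_prune(edges, dt)
print(edges)
```

Applying the same mask to every array keeps the edge fields aligned, which is the main invariant to protect when pruning.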

### Scaling considerations
```python
# Good: Vectorized operations
eligible_uids = ss.uids(people.alive & people.eligible)
n_pairs = len(eligible_uids) // 2

# Bad: Individual loops
for uid in all_uids:
    if people.alive[uid] and people.eligible[uid]:
        pass  # ... individual processing
```

## Common implementation gotchas

### Parameter handling
- Define defaults with `self.define_pars(param=default_value)` in `__init__`
- Then call `self.update_pars(**kwargs)` so user-supplied values override the defaults
- Use `self.pars.parameter_name` to access parameters

### Edge array management
- Edges are automatically managed - don't modify directly unless needed
- Use provided methods: `append()`, `remove()`, `find_contacts()`
- Check edge array sizes before operations: `if len(self.edges) > 0:`

### Agent state consistency
- Always check `people.alive` before forming connections
- Consider agent capacity (e.g., maximum number of partnerships)
- Handle birth/death events properly in dynamic networks
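
The alive and capacity checks above can be expressed as one vectorized mask. This is a sketch under the assumption that edges are stored as `p1`/`p2` index arrays and that each agent has a hypothetical `max_partners` limit; neither name is taken from the Starsim API.

```python
import numpy as np

n_agents = 6
alive = np.array([True, True, True, False, True, True])
max_partners = np.array([1, 2, 1, 1, 2, 1])  # hypothetical per-agent capacity

# Existing edges: agent 0 with 1, and agent 1 with 4
p1 = np.array([0, 1])
p2 = np.array([1, 4])

# Count current partners per agent using both ends of every edge
n_partners = np.bincount(np.concatenate([p1, p2]), minlength=n_agents)

# Agents who are alive and still below their capacity
available = alive & (n_partners < max_partners)
print(np.flatnonzero(available))
```

Agents removed by death should also have their existing edges dropped, otherwise `n_partners` will keep counting partnerships that no longer exist.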

### Network naming conventions
- Network names in simulations are automatically lowercase class names
- Access via `sim.networks.networkname` (e.g., `sim.networks.mfnet`)
- Disease beta parameters should use the lowercase network name as key
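
The naming convention above can be illustrated without Starsim. The classes below are empty stand-ins that only borrow the class names, and the beta values are hypothetical; the point is that dictionary keys must match the lowercase class names exactly.

```python
# Stand-in classes: only the class names matter for this sketch
class MFNet:
    pass

class MaternalNet:
    pass

def default_name(obj):
    """Lowercase class name, mirroring the naming convention described above."""
    return type(obj).__name__.lower()

networks = [MFNet(), MaternalNet()]
names = [default_name(net) for net in networks]

# Hypothetical per-network beta values, keyed by the lowercase names
beta = {'mfnet': 0.05, 'maternalnet': 0.5}

# A quick check that every network has a matching beta entry
missing = [name for name in names if name not in beta]
print(names, missing)
```

A check like this catches typos such as `MFNet` or `mf_net` in the beta dictionary before they surface as a confusing error at run time.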

### Example 1: Basic mixing pool comparison

```python
import starsim as ss

# Transmission using a mixing pool
mp = ss.MixingPool(beta=ss.TimeProb(0.1), contacts=ss.poisson(lam=3))
sir = ss.SIR() # Beta doesn't matter for mixing pools
sim1 = ss.Sim(diseases=sir, networks=mp, verbose=0, label='Mixing Pool')

# Transmission using a contact network  
net = ss.RandomNet(n_contacts=ss.poisson(lam=3))
sir = ss.SIR(beta=ss.TimeProb(0.1))
sim2 = ss.Sim(diseases=sir, networks=net, verbose=0, label='Contact Network')

msim = ss.MultiSim([sim1, sim2]).run()
msim.plot()
```

### Example 2: Single mixing pool with group targeting

```python
# Target specific groups with a single mixing pool
mp = ss.MixingPool(
    diseases = 'sir', # Use this pool only for SIR, not other diseases
    src = lambda sim: sim.people.age < 15, # Sources: children under 15
    dst = ss.AgeGroup(low=15, high=None), # Destinations: adults 15+
    contacts = ss.poisson(lam=2), # Contact distribution
    beta = ss.TimeProb(0.2), # Transmission probability
)

sim = ss.Sim(diseases=['sir', 'hiv'], networks=mp) # Two diseases, only SIR transmits
sim.run()
sim.plot()
```

Key parameters:
- **`src`/`dst`**: Arrays of UIDs, or callables that take the sim and select agents (e.g., a lambda or `ss.AgeGroup`)
- **`contacts`**: Distribution of effective contacts per agent
- **`beta`**: Uses `TimeProb` for automatic time unit adjustment
- **`diseases`**: Specify which diseases use this pool

### Example 3: Age-structured mixing with contact matrices

```python
import numpy as np
import sciris as sc
import matplotlib.pyplot as plt

# Define age groups (5-year bins)
bin_size = 5
lows = np.arange(0, 80, bin_size)
highs = sc.cat(lows[1:], 100)
groups = ss.ndict([ss.AgeGroup(low=low, high=high) for low, high in zip(lows, highs)])
n_groups = len(groups)

# Create contact matrix (could be from Prem et al. or other sources)
# Diagonal elements are higher (within-group mixing)
cm = np.random.random((n_groups, n_groups)) + 3*np.diag(np.random.rand(n_groups))

print('Contact matrix structure:')
print('Rows = SOURCE groups, Columns = DESTINATION groups')
plt.imshow(cm, cmap='viridis')
plt.colorbar(label='Contacts per day')
plt.xlabel('Destination age group')
plt.ylabel('Source age group')

mps = ss.MixingPools(
    contacts = cm,
    beta = ss.TimeProb(0.2),
    src = groups,  # Dictionary of group definitions
    dst = groups,  # Same groups for symmetric mixing
)
```

```python
# Track infections by age group using a custom analyzer
class InfectionsByAge(ss.Analyzer):
    def __init__(self, bins, **kwargs):
        super().__init__()
        self.bins = bins
        self.update_pars(**kwargs)

    def init_post(self):
        super().init_post()
        self.new_cases = np.zeros((len(self), len(self.bins)-1))

    def step(self):
        new_inf = self.sim.diseases.sir.ti_infected == self.ti
        if not new_inf.any():
            return
        self.new_cases[self.ti, :] = np.histogram(
            self.sim.people.age[new_inf], bins=self.bins
        )[0]

    def plot(self):
        fig, ax = plt.subplots()
        colors = plt.cm.nipy_spectral(np.linspace(0, 1, len(self.bins)))

        for i, (b1, b2) in enumerate(zip(self.bins[:-1], self.bins[1:])):
            ax.plot(self.timevec, self.new_cases[:,i],
                   label=f'Age {b1}-{b2}', color=colors[i])
        ax.legend()
        ax.set_xlabel('Year')
        ax.set_ylabel('New Infections')

# Run simulation with age tracking
az = InfectionsByAge(np.concatenate([lows, [1000]]))
sir = ss.SIR()
sim = ss.Sim(diseases=sir, networks=mps, analyzers=az,
             dur=5, dt=1/4, n_agents=1000, verbose=0)
sim.run()
sim.analyzers[0].plot()
print('Note: Default age distribution is uniform, so 75+ group has more people')
```

### Example 4: Multi-attribute mixing (SES-based)

```python
# Define socioeconomic status groups
ses = sc.dictobj(low=0, mid=1, high=2)

# Create population with SES distribution
ses_arr = ss.FloatArr('ses', default=ss.choice(a=list(ses.values()), p=[0.5, 0.3, 0.2]))
ppl = ss.People(n_agents=5_000, extra_states=ses_arr)

# Create mixing pools between SES groups
mps = ss.MixingPools(
    src = {k: lambda sim, s=v: ss.uids(sim.people.ses == s) for k,v in ses.items()},
    dst = {k: lambda sim, s=v: ss.uids(sim.people.ses == s) for k,v in list(ses.items())[:-1]},
    # Note: HIGH group cannot acquire infections (for demonstration)
    
    # Contact matrix: src on rows, dst on columns
    contacts = np.array([
        [2.50, 0.00], # low→low,  low→mid
        [0.05, 1.75], # mid→low,  mid→mid  
        [0.00, 0.15], # high→low, high→mid
    ]),
    
    beta = ss.TimeProb(0.2),
)
```

```python
# Analyzer to track infections by SES group
class InfectionsBySES(ss.Analyzer):
    def init_results(self):
        self.new_cases = np.zeros((len(self), len(ses)))

    def step(self):
        new_inf = self.sim.diseases.sir.ti_infected == self.ti
        if not new_inf.any():
            return

        for value in ses.values():
            mask = new_inf & (self.sim.people.ses == value)
            self.new_cases[self.ti, value] = np.count_nonzero(mask)

# Custom seeding function targeting high-SES individuals
def seeding(self, sim, uids):
    p = np.zeros(len(uids))
    high_ses = ss.uids(sim.people.ses == ses.high)
    p[high_ses] = 0.1 # Seed 10% of high-SES individuals
    return p

# Run simulation
az = InfectionsBySES()
sir = ss.SIR(init_prev=ss.bernoulli(p=seeding))
sim = ss.Sim(people=ppl, diseases=sir, networks=mps, analyzers=az,
             dt=1/12, dur=35, verbose=0)
sim.run()

# Plot results
fig, ax = plt.subplots()
new_cases = sim.analyzers[0].new_cases
for key, value in ses.items():
    ax.plot(sim.results.timevec, new_cases[:,value], label=key)
ax.legend()
ax.set_xlabel('Year')
ax.set_ylabel('New Infections')
plt.show()
```

This example demonstrates:
- **Directional transmission**: High SES → Mid SES → Low SES
- **Custom population**: Using `extra_states` to add SES attributes
- **Targeted seeding**: Starting infection in specific groups
- **Group-specific analysis**: Tracking outcomes by SES level

The wave pattern shows infections moving from high → mid → low SES groups based on the contact matrix structure.