**Motivation**: Improve ResNet with better representational power.

Introduced cardinality (parallel transformations with grouped convs).

# Split-Transform-Merge (Group Convolution)

### In the ResNext architecture code there is this line in the BasicBlock section:

```python
width = int(out_channels * (base_width / 64.)) * cardinality
```

What is its significance? Why do we need it?

---

This line is central to how **ResNeXt** scales its capacity differently from a standard ResNet. Let’s unpack it carefully:

---

### 🔍 The line

```python
width = int(out_channels * (base_width / 64.)) * cardinality
```

appears inside the **Bottleneck block** of ResNeXt (in torchvision's ResNet code, `BasicBlock` only supports `groups=1` and `base_width=64`), and it determines the **intermediate channel width** used by the grouped 3×3 convolution.

---

### 🧩 Background — what ResNeXt changes from ResNet

In **ResNet**, each bottleneck block looks like:

```
1×1 conv (reduce channels)
3×3 conv (same channels)
1×1 conv (expand channels)
```

But in **ResNeXt**, the 3×3 conv is **split into multiple groups** (using *grouped convolution*).
Each group processes a smaller number of channels, and their outputs are concatenated together.

This “split-transform-merge” structure allows increasing representational power **without** drastically increasing computation.
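A quick way to convince yourself that a grouped convolution really is split-transform-merge is to rebuild it by hand. The sketch below assumes PyTorch; `cardinality = 4` and 16 channels are arbitrary illustrative values. It copies each group's slice of a grouped `nn.Conv2d` weight into an ordinary per-path conv, runs each path on its own channel slice of the input, concatenates the results, and compares against the grouped layer.

```python
import torch
from torch import nn

torch.manual_seed(0)

cardinality = 4                      # number of groups / parallel paths (illustrative)
channels = 16                        # channels entering and leaving the 3x3 conv
per_group = channels // cardinality  # 4 channels per path

# One grouped conv: each group only sees its own slice of the input channels.
grouped = nn.Conv2d(channels, channels, kernel_size=3, padding=1,
                    groups=cardinality, bias=False)

# The same computation written out as explicit, independent paths.
paths = nn.ModuleList([
    nn.Conv2d(per_group, per_group, kernel_size=3, padding=1, bias=False)
    for _ in range(cardinality)
])
with torch.no_grad():
    for g, path in enumerate(paths):
        # Output channels g*4..g*4+3 of the grouped weight belong to group g.
        path.weight.copy_(grouped.weight[g * per_group:(g + 1) * per_group])

x = torch.randn(2, channels, 8, 8)
splits = torch.split(x, per_group, dim=1)                          # split
merged = torch.cat([p(s) for p, s in zip(paths, splits)], dim=1)  # transform + merge

print(torch.allclose(grouped(x), merged, atol=1e-5))  # True
```

The grouped layer stores all paths in one weight tensor of shape `(16, 4, 3, 3)`, which is why a single `groups=` argument is enough to express the multi-branch block.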

---

### ⚙️ The parameters

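* **`out_channels`** → the stage's base channel count (torchvision calls it `planes`); the block's actual output has `out_channels * expansion` channels, i.e. 4 × `out_channels` for a bottleneck
* **`base_width`** → number of channels per group in the first stage, where `out_channels = 64` (64 together with `cardinality = 1` reproduces plain ResNet; the code below defaults to 4, the "4d" in ResNeXt 32×4d)
* **`cardinality`** → number of parallel groups (the key idea of ResNeXt)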

---

### 🧮 The formula logic

```python
width = int(out_channels * (base_width / 64.)) * cardinality
```

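1. `(base_width / 64.)` → scales the intermediate width **relative to the original ResNet bottleneck**, whose 3×3 conv has exactly `out_channels` channels (e.g. 64 in the first stage, where the block outputs 256). With `base_width = 64` the factor is 1.
2. `out_channels * (base_width / 64.)` → the number of channels **per group** (the width of one path), which grows with `out_channels` from stage to stage.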
3. `* cardinality` → accounts for the **number of groups**.
   Each group gets its own small slice of the channels, and when combined, the total effective width becomes proportional to `cardinality`.
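The three steps can be checked with plain arithmetic. Below, `resnext_width` is a hypothetical helper that simply wraps the line above; the stage values 64/128/256/512 assume the usual ResNet-50-style stage layout, and the resulting widths should line up with the 128/256/512/1024 grouped widths usually quoted for ResNeXt-50 (32×4d).

```python
def resnext_width(out_channels: int, base_width: int, cardinality: int) -> int:
    """Width of the grouped 3x3 conv, using the same formula as the block code."""
    per_group = int(out_channels * (base_width / 64.))  # channels in one path
    return per_group * cardinality                       # channels across all paths


# One line per stage: (out_channels, grouped-conv width) for a 32x4d setting.
for planes in (64, 128, 256, 512):
    print(planes, resnext_width(planes, base_width=4, cardinality=32))
# 64 128
# 128 256
# 256 512
# 512 1024

# base_width=64 with a single group gives back the plain ResNet bottleneck width.
print(resnext_width(128, base_width=64, cardinality=1))  # 128
```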

---

### 🧠 Intuition

This lets you control **two independent dimensions of capacity**:

* `base_width` → width of each transformation path (number of channels in each group)
* `cardinality` → number of transformation paths (breadth of the block)

Instead of just making the network “wider” (as in Wide ResNet), ResNeXt makes it **multi-branched**, each branch being narrower but collectively powerful.
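To put a number on "narrower but collectively powerful", the sketch below counts convolution weights (biases and BatchNorm excluded) for a plain ResNet bottleneck and a 32×4d ResNeXt bottleneck, both working on a 256-channel input. The layer shapes are assumptions chosen to mirror first-stage blocks of ResNet-50 and ResNeXt-50.

```python
from torch import nn


def n_params(*layers: nn.Module) -> int:
    """Total number of learnable parameters across the given layers."""
    return sum(p.numel() for layer in layers for p in layer.parameters())


# Plain ResNet bottleneck: 256 -> 64 -> 64 -> 256, single path.
resnet = n_params(
    nn.Conv2d(256, 64, 1, bias=False),
    nn.Conv2d(64, 64, 3, padding=1, bias=False),
    nn.Conv2d(64, 256, 1, bias=False),
)

# ResNeXt bottleneck: 256 -> 128 -> 128 -> 256, 32 paths of 4 channels each.
resnext = n_params(
    nn.Conv2d(256, 128, 1, bias=False),
    nn.Conv2d(128, 128, 3, padding=1, groups=32, bias=False),
    nn.Conv2d(128, 256, 1, bias=False),
)

print(resnet, resnext)  # 69632 70144
```

The two budgets differ by less than 1%, even though the ResNeXt block has twice the intermediate width: the grouped 3×3 conv costs only `9 * 128 * 128 / 32 = 4608` weights instead of `9 * 128 * 128`.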

---

### 📊 Example

Say:

* `out_channels = 128`
* `base_width = 4`
* `cardinality = 32`

Then:

```
width = int(128 * (4 / 64.)) * 32
      = int(128 * 0.0625) * 32
      = 8 * 32
      = 256
```

So the grouped 3×3 convolution will have `in_channels = 256`, `groups = 32`, meaning **8 channels per group**.
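The 8-channels-per-group layout is visible directly in the layer's weight tensor. A minimal check, assuming PyTorch and an arbitrary 14×14 input:

```python
import torch
from torch import nn

width, cardinality = 256, 32
conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1,
                  groups=cardinality, bias=False)

# Weight layout is (out_channels, in_channels // groups, kH, kW):
# each filter only looks at the 8 input channels of its own group.
print(tuple(conv2.weight.shape))  # (256, 8, 3, 3)

x = torch.randn(1, width, 14, 14)
print(tuple(conv2(x).shape))      # (1, 256, 14, 14)
```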

---

### ✅ Why we need this

Because this formula:

* Preserves proportional scaling to ResNet baseline (via `/64` normalization),
* Enables flexible scaling via `base_width` and `cardinality`,
* Keeps the computation manageable (since each group handles fewer channels).

In short:

> This line defines the **effective width** of the grouped convolution layer, balancing model capacity and efficiency by controlling how many parallel low-dimensional transformations (groups) exist inside a block.

---


### But look at the block structure below:

```python
self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(width)
self.relu = nn.ReLU(inplace=True)

self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
self.bn2 = nn.BatchNorm2d(width)

self.conv3 = nn.Conv2d(width, out_channels*self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)
```

In the first convolution, say `in_planes = 3` and `width = 256`. How is that 1×1 conv shrinking the channels?

---

Excellent observation — this confusion is *very* common when first reading **ResNeXt** code.
Let’s break it down carefully 👇

---

### 💡 Key point

In **ResNeXt**, unlike the original **ResNet**,
the first `1×1` convolution (`conv1`) doesn’t *always* shrink the channels.
Its purpose is to **project** the input into the *grouped convolution’s internal width (`width`)* —
and that width can actually be **greater or smaller** than the input channels (`in_planes`).

---

### 🧩 Why this happens

Let’s revisit the standard **ResNet bottleneck** first:

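| Layer    | Channels                          | Role                  |
| -------- | --------------------------------- | --------------------- |
| 1×1 conv | `in_planes → out_channels`        | **Reduce** dimensions |
| 3×3 conv | `out_channels → out_channels`     | Process               |
| 1×1 conv | `out_channels → out_channels × 4` | **Expand** dimensions |

Here `out_channels` means the same thing as in the ResNeXt code: the block outputs `out_channels × 4` channels.

So ResNet typically shrinks first, then expands back. The one exception is the very first block, which receives 64 channels from the stem and keeps them at 64.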

---

### ⚙️ What ResNeXt does differently

ResNeXt introduces *cardinality* (number of groups), and the “internal width” (`width`) is computed as:

```python
width = int(out_channels * (base_width / 64.)) * cardinality
```

Then:

```python
self.conv1 = nn.Conv2d(in_planes, width, 1, bias=False)
self.conv2 = nn.Conv2d(width, width, 3, stride=stride, padding=1, groups=cardinality, bias=False)
self.conv3 = nn.Conv2d(width, out_channels * self.expansion, 1, bias=False)
```

Now `width` could be **larger** than `in_planes` (like 256 > 3),
meaning `conv1` **expands** channels — not shrinks.

That’s okay — the 1×1 conv simply *reprojects* features to a new dimension appropriate for the grouped conv that follows.
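A 1×1 convolution is a learned linear map over channels, applied independently at every pixel, so it can map 3 channels up to 256 as easily as 256 down to 64. The sketch below (PyTorch assumed; the 32×32 input size is arbitrary) shows the shape change and checks the per-pixel matrix-multiply view with `torch.einsum`.

```python
import torch
from torch import nn

conv1 = nn.Conv2d(3, 256, kernel_size=1, bias=False)
x = torch.randn(1, 3, 32, 32)
y = conv1(x)
print(tuple(y.shape))  # (1, 256, 32, 32)

# Same result as multiplying every pixel's 3-vector by a 256x3 matrix.
w = conv1.weight[:, :, 0, 0]                        # shape (256, 3)
per_pixel = torch.einsum("oc,bchw->bohw", w, x)
print(torch.allclose(y, per_pixel, atol=1e-5))      # True
```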

---

### 🧠 Intuition

Think of it this way:

* In ResNet: 1×1 conv → dimension reduction for computation savings.
* In ResNeXt: 1×1 conv → **feature projection into a high-dimensional “grouped space”**.
  This projection enables each group in the next 3×3 conv to operate on its own subset of channels.

So the first conv isn’t *forced* to compress.
It’s shaping the feature map to fit the `groups = cardinality` design.

---

### 🧮 Example

Let’s say:

```python
in_planes = 3
out_channels = 128
base_width = 4
cardinality = 32
```

Then:

```
width = int(128 * (4 / 64)) * 32 = 8 * 32 = 256
```

So we get:

| Layer    | Input Channels | Output Channels | Groups | Role                                |
| -------- | -------------- | --------------- | ------ | ----------------------------------- |
| 1×1 conv | 3              | 256             | 1      | Expand/project                      |
| 3×3 conv | 256            | 256             | 32     | Grouped conv (8 channels per group) |
| 1×1 conv | 256            | 128×4=512       | 1      | Expand to match residual output     |

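✅ Here the first conv **expands** only because this example feeds 3 input channels straight into a block. In the full network the 7×7 stem conv first maps the RGB image to 64 channels, so a block never sees `in_planes = 3`. In ResNeXt-50 (32×4d), `conv1` expands only in the very first block (64 → 128); everywhere else it keeps or shrinks the channel count.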
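To see what `conv1` does once a real stem sits in front of the blocks, the sketch below traces its input and output channels through a ResNet-50-style layout. The `[3, 4, 6, 3]` blocks per stage, `expansion = 4`, 64 stem channels, and the 32×4d setting are assumptions matching ResNeXt-50; `resnext_width` is a hypothetical helper wrapping the width formula.

```python
def resnext_width(planes: int, base_width: int = 4, cardinality: int = 32) -> int:
    """Grouped-conv width, same formula as in the block code."""
    return int(planes * (base_width / 64.)) * cardinality


expansion = 4
in_planes = 64   # channels after the 7x7 stem conv and max-pool
trace = []       # (stage, block, conv1 in_channels, conv1 out_channels)
for stage, (planes, blocks) in enumerate(zip((64, 128, 256, 512), (3, 4, 6, 3)), start=1):
    width = resnext_width(planes)
    for block in range(1, blocks + 1):
        trace.append((stage, block, in_planes, width))
        in_planes = planes * expansion  # this block's output feeds the next block

for stage, block, c_in, c_out in trace:
    if block <= 2:  # blocks 3+ of a stage repeat block 2
        print(f"stage {stage} block {block}: conv1 {c_in} -> {c_out}")
# stage 1 block 1: conv1 64 -> 128
# stage 1 block 2: conv1 256 -> 128
# stage 2 block 1: conv1 256 -> 256
# stage 2 block 2: conv1 512 -> 256
# stage 3 block 1: conv1 512 -> 512
# stage 3 block 2: conv1 1024 -> 512
# stage 4 block 1: conv1 1024 -> 1024
# stage 4 block 2: conv1 2048 -> 1024
```

Of the 16 blocks, `conv1` expands in one, keeps the width in three (the first block of stages 2 to 4), and shrinks in the remaining twelve, so in practice it still behaves mostly like a bottleneck.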

---

### 📘 Summary

* In **ResNet**, the first 1×1 conv *shrinks* channels (bottleneck style).
* In **ResNeXt**, the first 1×1 conv *reshapes* to `width` (which can be smaller or larger).
* Its purpose is to prepare for **grouped convolution**, not necessarily to reduce dimensions.

---

```python
# Let's Go Deeper!

import torch
from torch import nn


class BottleneckBlock(nn.Module):
    expansion = 4

    def __init__(self, in_planes, out_channels, stride=1, base_width=4, cardinality=32):
        super().__init__()
        # Width of the grouped 3x3 conv: (channels per group) * (number of groups).
        width = int(out_channels * (base_width / 64.)) * cardinality

        # 1x1 conv: project the input to `width` channels.
        self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.relu = nn.ReLU(inplace=True)

        # 3x3 grouped conv: `cardinality` parallel paths, width // cardinality channels each.
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1,
                               groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(width)

        # 1x1 conv: expand to the block's output channels.
        self.conv3 = nn.Conv2d(width, out_channels * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)

        # Projection shortcut whenever the spatial size or channel count changes.
        self.downsample = None
        if stride != 1 or in_planes != out_channels * self.expansion:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, out_channels * self.expansion, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion),
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNext(nn.Module):

    def __init__(self, block, layers, num_classes, grayscale=False, cardinality=32, base_width=4):
        super().__init__()  # must run before any submodule is assigned
        self.inplanes = 64
        self.cardinality = cardinality
        self.base_width = base_width

        in_dim = 1 if grayscale else 3
        # Stem: 7x7 conv + max-pool maps the image to 64 channels at 1/4 resolution.
        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Global average pooling works for any input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # layer4 outputs 512 * expansion (= 2048) channels.
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He initialisation: normal with std = sqrt(2 / (k * k * out_channels)).
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, blocks, stride=1):
        layers = []
        for i in range(blocks):
            # Only the first block of a stage changes stride / channel count.
            layers.append(block(self.inplanes, planes,
                                stride=stride if i == 0 else 1,
                                base_width=self.base_width,
                                cardinality=self.cardinality))
            self.inplanes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.fc(x)
        return logits


def resnext26_32x4d(num_classes, grayscale=False):
    """ResNeXt with bottleneck blocks, layers [2, 2, 2, 2] (26 weighted layers), 32x4d."""
    model = ResNext(block=BottleneckBlock,
                    layers=[2, 2, 2, 2],
                    num_classes=num_classes,
                    grayscale=grayscale)
    return model
```