In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class DenseBlock(nn.Module):
    """Stack of 3x3 convolutions with DenseNet-style feature concatenation.

    Each layer receives the concatenation of the block input and every earlier
    layer's output. It applies a 3x3 convolution (padding 1, stride 1, so the
    spatial size is preserved) followed by ReLU. It then appends its
    ``growth_rate`` new channels to the running feature map.

    Args:
        in_channels: Number of channels in the block input.
        growth_rate: Number of channels each layer adds.
        num_layers: Number of convolution layers in the block.

    Attributes:
        num_layers: Number of layers, as passed in.
        layers: The convolutions, in application order.
        out_channels: Number of channels in the block output,
            ``in_channels + num_layers * growth_rate``. Use it to size the
            layer that follows (e.g. a transition layer) without recomputing
            it by hand.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        channels = in_channels
        for _ in range(num_layers):
            self.layers.append(
                nn.Conv2d(channels, growth_rate, kernel_size=3, padding=1)
            )
            # The next layer also sees the features this one just produced.
            channels += growth_rate
        # Previously this running total was computed and then discarded.
        self.out_channels = channels

    def forward(self, x):
        """Apply every layer, concatenating its output onto the running input.

        Args:
            x: Batched input of shape (N, C, H, W) with ``C == in_channels``.
                A batch dimension is required because new features are
                concatenated along dim 1.

        Returns:
            Tensor of shape (N, out_channels, H, W). It holds the original
            input channels first, followed by each layer's ReLU-activated new
            features in layer order.
        """
        for layer in self.layers:
            new_features = F.relu(layer(x))
            x = torch.cat([x, new_features], dim=1)
        return x

In [None]:
class TransitionLayer(nn.Module):
    """1x1 convolution followed by 2x2 max pooling.

    Remaps the channel count from ``in_channels`` to ``out_channels``. Then it
    halves each spatial dimension: ``nn.MaxPool2d(2)`` uses stride 2 by
    default and rounds down.

    Args:
        in_channels: Number of channels in the input.
        out_channels: Number of channels produced by the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 kernel mixes channels only; it adds no spatial context.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # kernel_size=2 with the default stride (equal to kernel_size) halves H and W.
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``pool(conv(x))``: channel remap, then spatial downsampling."""
        return self.pool(self.conv(x))