In [1]:
import torch
import torch.nn as nn

class StandardBatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(StandardBatchNorm, self).__init__()
        self.momentum = momentum
        self.running_mean = 0
        self.running_var = 0
        self.eps = torch.tensor(eps)
        self.num_features = num_features
        self.affine = affine
        shape = (1, self.num_features, 1, 1)

        if self.affine:
            self.gamma = nn.Parameter(torch.empty(shape))
            self.beta = nn.Parameter(torch.empty(shape))

        self._param_init()

    def _param_init(self):
        nn.init.zeros_(self.beta)
        nn.init.ones_(self.gamma)

    def forward(self, x):
        if self.training:
            n = x.numel() / x.size(1)
            dimensions = (0, 2, 3)
            var = x.var(dim=dimensions, keepdim=True, unbiased=False)
            mean = x.mean(dim=dimensions, keepdim=True)

            with torch.no_grad():
                self.running_mean = self.momentum * mean + (1 - self.momentum) * self.running_mean
                self.running_var = self.momentum * (n / (n - 1)) * var + (1 - self.momentum) * self.running_var

        else:
            mean = self.running_mean
            var = self.running_var
        dn = torch.sqrt(var + self.eps)
        x = (x - mean) / dn
        if self.affine:
            x = x * self.gamma + self.beta

        return x


class Block(nn.Module):
    def __init__(self, in_channels, intermediate_channels, identity_downsample=None, stride=1, norm_type='batch'):
        super().__init__()
        self.expansion = 2
        self.conv1 = nn.Conv2d(in_channels, intermediate_channels, kernel_size=3, stride=stride, padding=1, bias=False)

        if norm_type == 'bn':
            self.nl1 = StandardBatchNorm(intermediate_channels)
            self.nl2 = StandardBatchNorm(intermediate_channels * self.expansion)
        # Add other normalization types here if needed

        self.conv2 = nn.Conv2d(intermediate_channels, intermediate_channels * self.expansion, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample
        self.stride = stride

    def forward(self, x):
        identity = x.clone()

        x = self.conv1(x)
        x = self.nl1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.nl2(x)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        x += identity
        x = self.relu(x)
        return x


class ResNet(nn.Module):
    def __init__(self, block, layers, image_channels, num_classes, norm_type='bn'):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv1 = nn.Conv2d(image_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        if norm_type == 'bn':
            self.norm = StandardBatchNorm(16)
            self.gamma = self.norm.gamma
            self.beta = self.norm.beta
        # Add other normalization types here if needed
        else:
            raise ValueError("Invalid normalization type. Choose from 'bn'.")

        self.relu = nn.ReLU()

        self.layer1 = self._make_layer(block, layers[0], intermediate_channels=16, stride=1, norm_type=norm_type)
        self.layer2 = self._make_layer(block, layers[1], intermediate_channels=32, stride=2, norm_type=norm_type)
        self.layer3 = self._make_layer(block, layers[2], intermediate_channels=64, stride=2, norm_type=norm_type)

        self.avgpool = nn.AdaptiveAvgPool2d((8, 8))
        self.fc = nn.Linear(128 * 8 * 8, num_classes)  # Adjusted this line

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # Adjusted this line
        x = self.fc(x)

        return x

    def _make_layer(self, block, num_residual_blocks, intermediate_channels, stride, norm_type):
        identity_downsample = None
        layers = []

        if stride != 1 or self.in_channels != intermediate_channels * 2:
            if norm_type == 'bn':
                identity_downsample = nn.Sequential(
                    nn.Conv2d(self.in_channels, intermediate_channels * 2, kernel_size=1, stride=stride, bias=False),
                    StandardBatchNorm(intermediate_channels * 2)
                )
            # Add other normalization types here if needed
            else:
                raise ValueError("Invalid normalization type. Choose from 'bn'.")

        layers.append(block(self.in_channels, intermediate_channels, identity_downsample, stride, norm_type))
        self.in_channels = intermediate_channels * 2

        for _ in range(1, num_residual_blocks):  # Adjusted this line
            layers.append(block(self.in_channels, intermediate_channels, norm_type=norm_type))

        return nn.Sequential(*layers)


# Example usage:
resnet = ResNet(Block, [2, 2, 2], image_channels=3, num_classes=10, norm_type='bn')
print(resnet)


ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (norm): StandardBatchNorm()
  (relu): ReLU()
  (layer1): Sequential(
    (0): Block(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (nl1): StandardBatchNorm()
      (nl2): StandardBatchNorm()
      (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (relu): ReLU()
      (identity_downsample): Sequential(
        (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): StandardBatchNorm()
      )
    )
    (1): Block(
      (conv1): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (nl1): StandardBatchNorm()
      (nl2): StandardBatchNorm()
      (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (relu): ReLU()
    )
  )
  (layer2): Sequential(
    (0): Block(
      (conv1): Conv2d(32, 32, kernel_size=

In [4]:
# Print all parameters
for name, param in resnet.named_parameters():
    print(f"Parameter name: {name}, Shape: {param.shape}")

Parameter name: gamma, Shape: torch.Size([1, 16, 1, 1])
Parameter name: beta, Shape: torch.Size([1, 16, 1, 1])
Parameter name: conv1.weight, Shape: torch.Size([16, 3, 3, 3])
Parameter name: layer1.0.conv1.weight, Shape: torch.Size([16, 16, 3, 3])
Parameter name: layer1.0.nl1.gamma, Shape: torch.Size([1, 16, 1, 1])
Parameter name: layer1.0.nl1.beta, Shape: torch.Size([1, 16, 1, 1])
Parameter name: layer1.0.nl2.gamma, Shape: torch.Size([1, 32, 1, 1])
Parameter name: layer1.0.nl2.beta, Shape: torch.Size([1, 32, 1, 1])
Parameter name: layer1.0.conv2.weight, Shape: torch.Size([32, 16, 3, 3])
Parameter name: layer1.0.identity_downsample.0.weight, Shape: torch.Size([32, 16, 1, 1])
Parameter name: layer1.0.identity_downsample.1.gamma, Shape: torch.Size([1, 32, 1, 1])
Parameter name: layer1.0.identity_downsample.1.beta, Shape: torch.Size([1, 32, 1, 1])
Parameter name: layer1.1.conv1.weight, Shape: torch.Size([16, 32, 3, 3])
Parameter name: layer1.1.nl1.gamma, Shape: torch.Size([1, 16, 1, 1])
Pa

In [6]:
torch.save(resnet.state_dict(),'resent.pth')

In [7]:
model = ResNet(Block, [2, 2, 2], image_channels=3, num_classes=10, norm_type='bn')

In [11]:
model.load_state_dict(torch.load('resent.pth'))

<All keys matched successfully>