In [20]:
import torch
from torch import nn

In [21]:
class ClassificationNN(nn.Module):
    """Small 2-D convolutional network with pooling layers and a linear head.

    The module owns two convolutions, a max/avg pooling pair, an extra
    max-pool and a linear layer sized for ``num_classes`` outputs.

    NOTE(review): ``forward`` currently returns the two convolutional feature
    maps, not class scores; ``self.fc`` and ``self.pooling`` are constructed
    but never applied. This looks like work in progress - confirm the
    intended output before relying on this class as a classifier.

    Args:
        num_classes: Output width of the (currently unused) linear head.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # 1x1 convolution: 128 -> 256 channels, spatial size unchanged.
        self.conv1 = nn.Conv2d(128, 256, kernel_size=1)
        # 3x3 convolution with stride 2: 256 -> 512 channels.
        self.conv2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)

        # Fixed 64x64 pooling windows; these only act as "global" pooling
        # when the conv2 output is exactly 64x64.
        self.global_max = nn.MaxPool2d(kernel_size=64)
        self.global_avg = nn.AvgPool2d(kernel_size=64)
        # Not referenced by forward().
        self.pooling = nn.MaxPool2d(kernel_size=10)

        # 1024 inputs matches 512 max-pooled + 512 avg-pooled channels,
        # provided each pooled map is 1x1.
        self.fc = nn.Linear(in_features=1024, out_features=num_classes)

    def forward(self, x):
        """Run both convolutions and return their outputs.

        Args:
            x: Input tensor with 128 channels (as required by ``conv1``).

        Returns:
            Tuple ``(conv1_output, conv2_output)``.
        """
        feat_pointwise = self.conv1(x)
        feat_strided = self.conv2(feat_pointwise)

        # The pooled descriptor is computed but then dropped: nothing below
        # uses it. It is kept so that behaviour (including the error raised
        # when the feature map is smaller than the 64x64 window) is unchanged.
        pooled = torch.cat(
            [self.global_max(feat_strided), self.global_avg(feat_strided)],
            dim=1,
        )
        pooled = torch.flatten(pooled, start_dim=1)
        # NOTE(review): the linear head is never applied; presumably
        # ``self.fc(pooled)`` was meant to be the return value - confirm.

        return feat_pointwise, feat_strided

In [22]:
class ClassificationUnet(nn.Module):
    """3-D classifier built from the contracting half of a U-Net.

    ``num_downs + 1`` nested ``DownsamplingBlock`` stages are stacked; the
    channel count starts at ``initial_filter_size`` and doubles at every
    deeper stage. The deepest feature map is reduced by global average and
    global max pooling, the two results are concatenated, and a linear layer
    maps them to ``num_classes`` scores.

    Args:
        num_classes: Number of output scores.
        in_channels: Channels of the input volume.
        initial_filter_size: Output channels of the first (outermost) stage.
        kernel_size: Convolution kernel size forwarded to every stage.
        num_downs: Number of nested stages below the outermost one; must be
            at least 1.
        norm_layer: Normalisation layer class, called as
            ``norm_layer(channels)``.

    Raises:
        ValueError: If ``num_downs`` is smaller than 1 (the channel formula
            would otherwise produce a fractional channel count).
    """

    def __init__(self, num_classes=2, in_channels=1, initial_filter_size=64, kernel_size=3, num_downs=3, norm_layer=nn.InstanceNorm3d):
        super(ClassificationUnet, self).__init__()

        if num_downs < 1:
            raise ValueError("num_downs must be >= 1, got {}".format(num_downs))

        self.num_classes = num_classes

        # The network is assembled inside-out: start with the deepest stage
        # (no submodule) and wrap it in progressively shallower ones.
        block = DownsamplingBlock(in_channels=initial_filter_size * 2 ** (num_downs - 1),
                                  out_channels=initial_filter_size * 2 ** num_downs,
                                  kernel_size=kernel_size, norm_layer=norm_layer)

        for i in range(1, num_downs):
            block = DownsamplingBlock(in_channels=initial_filter_size * 2 ** (num_downs - (i + 1)),
                                      out_channels=initial_filter_size * 2 ** (num_downs - i),
                                      kernel_size=kernel_size, submodule=block, norm_layer=norm_layer)

        # Outermost stage: maps the raw input channels to initial_filter_size.
        block = DownsamplingBlock(in_channels=in_channels, out_channels=initial_filter_size,
                                  kernel_size=kernel_size, submodule=block, norm_layer=norm_layer)

        self.model = block
        # Deepest stage has initial_filter_size * 2**num_downs channels; the
        # avg/max concatenation doubles that, hence the exponent num_downs + 1.
        self.fc = nn.Linear(initial_filter_size * 2 ** (num_downs + 1), num_classes)

    @staticmethod
    def pooling(layer, in_channels=None, kernel_size=None):
        """Concatenate average- and max-pooled features and flatten them.

        Args:
            layer: Feature tensor; assumed to be batched 5-D
                ``(batch, channels, D, H, W)``.
            in_channels: Unused; kept for interface compatibility.
            kernel_size: Pooling window (int or 3-tuple). ``None`` selects the
                full spatial extent of ``layer``, i.e. global pooling.

        Returns:
            2-D tensor ``(batch, features)`` with the average-pooled values
            first and the max-pooled values second.
        """
        if kernel_size is None:
            # Global pooling: one window covering every spatial axis.
            kernel_size = tuple(layer.shape[2:])
        global_max = nn.MaxPool3d(kernel_size=kernel_size)
        global_avg = nn.AvgPool3d(kernel_size=kernel_size)
        layer = torch.cat([global_avg(layer), global_max(layer)], 1)
        layer = torch.flatten(layer, 1)
        return layer

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)`` for ``x``."""
        x = self.model(x)
        # Use the full (D, H, W) extent rather than the depth alone, so that
        # non-cubic volumes are pooled down to a single value per channel and
        # the flattened width always matches ``self.fc``.
        x = self.pooling(x, in_channels=x.size(1), kernel_size=tuple(x.shape[2:]))

        x = self.fc(x)

        return x


# Contracting (encoder) half of a U-Net: the feature-extraction path, without skip connections.
class DownsamplingBlock(nn.Module):
    """One stage of the contracting path.

    A stage is two conv -> norm -> LeakyReLU units. When ``submodule`` is
    given, the stage additionally applies a stride-2 max-pool and then runs
    ``submodule`` on the pooled result, so stages can be nested to any depth.

    Args:
        in_channels: Channels of the stage input.
        out_channels: Channels produced by both convolutions of the stage.
        kernel_size: Convolution kernel size (int or 3-tuple).
        submodule: Optional deeper stage applied after pooling.
        norm_layer: Normalisation layer class, called as
            ``norm_layer(out_channels)``.
    """

    def __init__(self, in_channels=None, out_channels=None, kernel_size=3, submodule=None, norm_layer=nn.InstanceNorm3d):
        super(DownsamplingBlock, self).__init__()

        conv1 = self.contract(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              norm_layer=norm_layer)
        conv2 = self.contract(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
                              norm_layer=norm_layer)

        layers = [conv1, conv2]
        if submodule is not None:
            # Only non-leaf stages downsample; positions 2 and 3 in the
            # Sequential are kept so existing state_dict keys still match.
            layers += [nn.MaxPool3d(2, stride=2), submodule]

        self.model = nn.Sequential(*layers)

    @staticmethod
    def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm3d):
        """Build one conv -> norm -> LeakyReLU unit.

        Padding is derived from ``kernel_size`` (``k // 2`` per axis) so that
        odd kernel sizes keep the spatial size unchanged; for the default
        ``kernel_size=3`` this is the same ``padding=1`` as before.

        Returns:
            ``nn.Sequential`` holding the three layers.
        """
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

        layer = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding),
            norm_layer(out_channels),
            nn.LeakyReLU(inplace=True))
        return layer

    def forward(self, x):
        """Apply the stage (and any nested submodule) to ``x``."""
        return self.model(x)


In [23]:
# Build the 3-D classifier with its default hyper-parameters.
model = ClassificationUnet()
# Keep the model's state_dict around for inspection.
parameter_dict = model.state_dict()

In [24]:
# named_parameters() yields its entries lazily: the printed type is a generator, not a list.
print(type(model.named_parameters()))

<class 'generator'>


In [29]:
# Sum of squared entries of every parameter whose name contains "fc.weight".
loss = torch.zeros(1, dtype=torch.float32)
for name, param in model.named_parameters():
    if "fc.weight" not in name:
        continue
    # `.data` is read here, so autograd does not track this sum.
    # NOTE(review): if this is meant as a trainable L2 penalty, `param`
    # itself would have to be used - confirm intent.
    weight = param.data
    loss += (weight * weight).sum()
    print(name, param.data)

fc.weight tensor([[ 0.0287, -0.0117,  0.0053,  ...,  0.0181, -0.0113, -0.0214],
        [ 0.0165,  0.0036,  0.0296,  ..., -0.0178, -0.0224, -0.0264]])


In [30]:
# Show the squared-weight sum accumulated by the loop above.
print(loss)

tensor([0.6722])


In [19]:
# Scaled L2 penalty on `weight`.
# NOTE(review): `weight` is not defined in this cell; it is whatever an
# earlier-run cell left in the kernel - verify under Restart & Run All.
beta = 2
scaled_squares = beta * (weight * weight)
print(scaled_squares)          # element-wise penalty terms
l2_loss = scaled_squares.sum()  # reduce to a single scalar tensor
print(l2_loss.type())

tensor([[1.8106e-03, 4.0437e-06, 1.4800e-03,  ..., 8.8446e-05, 1.3211e-03,
         7.7822e-05],
        [7.6561e-06, 7.6831e-05, 2.4072e-04,  ..., 1.7067e-03, 3.2952e-04,
         6.9047e-05]])
torch.FloatTensor
