# DenseNet

In [1]:
import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
import time
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt

## 稠密层

In [2]:
# 一个稠密块
def conv_block(input_features, out_features):
    return nn.Sequential(
        nn.BatchNorm2d(input_features), nn.ReLU(),
        nn.Conv2d(in_channels=input_features, out_channels=out_features, kernel_size=3, padding=1)
    )

In [3]:
class DenseBlock(nn.Module):
    def __init__(self, num_convs: int, input_features, num_features):
        super(DenseBlock, self).__init__()
        layers = []
        for i in range(num_convs):
            layers.append(
                conv_block(num_features * i + input_features, num_features)
            )
        self.net = nn.Sequential(*layers)

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)
        return X

In [4]:
blk = DenseBlock(2, 3, 10)
X = torch.randn(size=(4, 3, 8, 8))
y = blk(X)
y.shape

torch.Size([4, 23, 8, 8])

## 过渡层

由于稠密块增加通道数，过多通道数则过于复杂模型，所以可以通过过渡层来控制模型复杂度。他通过1*1卷积控制通道数，并且步幅为2的平均池化层来减少宽和高，进一步控制模型复杂度

In [5]:
def transition_block(input_features, output_features):
    return nn.Sequential(
        nn.BatchNorm2d(input_features), nn.ReLU(),
        nn.Conv2d(input_features, output_features, kernel_size=1),
        nn.AvgPool2d(kernel_size=2, stride=2)
    )

In [6]:
tblk = transition_block(23, 10)
tblk(y).shape


torch.Size([4, 10, 4, 4])

## DenseNet模型

In [7]:
# 单卷积层和最大池化层
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))



In [8]:
num_channels, growth_rate = 64, 32
num_convs_in_dense_net = [4 for i in range(4)]
blks = []


In [9]:
for i, num_convs in enumerate(num_convs_in_dense_net):
    blks.append(DenseBlock(num_convs, num_channels, growth_rate))
    num_channels += growth_rate * num_convs
    if i != len(num_convs_in_dense_net) - 1:
        blks.append(transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2


In [10]:
net = nn.Sequential(
    b1, *blks,
    nn.BatchNorm2d(num_channels), nn.ReLU(),
    nn.AdaptiveMaxPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(num_channels, 10),
)

In [11]:
num_epoths, batch_size, lr = 10, 256, .1

transform = transforms.Compose([
    transforms.RandomSizedCrop(96), transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

dataset = datasets.FashionMNIST(root='././data/', download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)




In [12]:
opt = optim.SGD(net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [18]:
net = net.to(device)
writer = SummaryWriter('./runs/loss_1')
dataloader_size = len(dataloader)

for epoth in range(num_epoths):
    net.train()
    total_loss = .0
    running_size = 0
    for i, (features, labels) in enumerate(dataloader):
        features, labels = features.to(device), labels.to(device)
        time1 = time.localtime(time.time())
        opt.zero_grad()
        outputs = net(features)
        loss = criterion(outputs, labels)
        loss.backward()
        opt.step()
        time2 = time.localtime(time.time())
        total_loss += loss.item()
        running_size += 1
        if (i + 1) % 10 == 0:
            writer.add_scalar("Loss/train", loss, (i + 1) * (epoth + 1))
            writer.flush()
            print(f'epoch:{epoth + 1}/{num_epoths}  i: {i + 1}/{dataloader_size}  loss: {loss.item()}')
        print(
            f'epoch {epoth + 1} begin time: {time.strftime("%h:%m:%S", time1)}, end time{time.strftime("%h:%m:%S", time2)}, loss: {total_loss / running_size}')

epoch 1 begin time: Aug:08:59, end timeAug:08:08, loss: 2.1107988357543945
epoch 1 begin time: Aug:08:08, end timeAug:08:19, loss: 1.9577456712722778
epoch 1 begin time: Aug:08:19, end timeAug:08:31, loss: 1.9122936725616455
epoch 1 begin time: Aug:08:31, end timeAug:08:41, loss: 1.8272144496440887
epoch 1 begin time: Aug:08:41, end timeAug:08:50, loss: 1.7798974990844727
epoch 1 begin time: Aug:08:50, end timeAug:08:58, loss: 1.7524544795354207
epoch 1 begin time: Aug:08:59, end timeAug:08:06, loss: 1.7353826761245728
epoch 1 begin time: Aug:08:06, end timeAug:08:17, loss: 1.7451502233743668
epoch 1 begin time: Aug:08:17, end timeAug:08:24, loss: 1.7121010091569688
epoch:1/10  i: 10/235  loss: 1.3787015676498413
epoch 1 begin time: Aug:08:24, end timeAug:08:31, loss: 1.678761065006256
epoch 1 begin time: Aug:08:31, end timeAug:08:39, loss: 1.6498191356658936
epoch 1 begin time: Aug:08:40, end timeAug:08:49, loss: 1.6595563193162282
epoch 1 begin time: Aug:08:49, end timeAug:08:58, los

KeyboardInterrupt: 

In [None]:
writer.close()