<img src="imgs/AlexNet.png" alt="AlexNet architecture diagram">

In [None]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage

In [19]:
class LRN(nn.Module):
    """Local Response Normalization.

    Computes ``x / (bias + alpha * mean(x**2 over a local window)) ** beta``.

    The mean is taken either over ``local_size`` neighbouring channels
    (``ACROSS_CHANNELS=True``) or over a ``local_size x local_size`` spatial
    window inside each channel (``ACROSS_CHANNELS=False``). The average-pooling
    layers use their default ``count_include_pad=True``, so the zero padding is
    counted and the divisor is always the full window size. That makes
    ``alpha * mean`` equal to ``alpha / n * sum``.

    Args:
        local_size: Window size. Must be a positive odd integer so the zero
            padding keeps the normalized tensor the same shape as the input.
        alpha: Scale applied to the windowed mean of squares.
        beta: Exponent of the normalizing denominator.
        ACROSS_CHANNELS: Normalize across channels (True) or within a spatial
            neighbourhood of each channel (False).
        bias: Additive constant in the denominator (``k``). Defaults to 1.0,
            the value previously hard-coded.

    Raises:
        ValueError: If ``local_size`` is not a positive odd integer.
    """

    def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=False, bias=1.0):
        super(LRN, self).__init__()
        if local_size < 1 or local_size % 2 == 0:
            # An even window cannot be centred: the (n - 1) // 2 padding shrinks
            # the pooled axis by one, and forward() then either fails or
            # silently broadcasts a wrong-shaped divisor.
            raise ValueError(
                "local_size must be a positive odd integer, got {}".format(local_size))
        self.ACROSS_CHANNELS = ACROSS_CHANNELS
        pad = (local_size - 1) // 2
        if self.ACROSS_CHANNELS:
            # Slide only along the channel axis, which forward() exposes as the
            # depth axis of a 5-D tensor. The `padding` argument of AvgPool3d
            # is not available in PyTorch 0.2.0_4; a newer release is required.
            self.average = nn.AvgPool3d(kernel_size=(local_size, 1, 1),
                                        stride=1,
                                        padding=(pad, 0, 0))
        else:
            self.average = nn.AvgPool2d(kernel_size=local_size,
                                        stride=1,
                                        padding=pad)
        self.alpha = alpha
        self.beta = beta
        self.bias = bias

    def forward(self, x):
        """Return ``x`` divided by its local response normalizer.

        Assumes a 4-D ``(N, C, H, W)`` input when ``ACROSS_CHANNELS`` is True
        (the channel axis is expected at dim 1) -- TODO confirm for other uses.
        """
        squared = x.pow(2)
        if self.ACROSS_CHANNELS:
            # (N, C, H, W) -> (N, 1, C, H, W) so AvgPool3d averages over C,
            # then drop the helper axis again.
            div = self.average(squared.unsqueeze(1)).squeeze(1)
        else:
            div = self.average(squared)
        div = div.mul(self.alpha).add(self.bias).pow(self.beta)
        return x.div(div)

In [48]:
class AlexNet(nn.Module):
    """AlexNet-style CNN: five convolutional stages and three linear stages.

    Stages 1 and 2 run conv -> ReLU -> max-pool -> LRN. NOTE(review): the
    AlexNet paper applies response normalization *before* max-pooling; confirm
    the pool-then-normalize order here is intended.

    Args:
        num_classes: Number of scores produced by the final linear layer
            (default 1000).

    Input is expected as ``(N, 3, 227, 227)``. The flatten step feeds a
    6x6x256 feature map into ``fc1``; other spatial sizes raise a shape error.

    ``forward`` returns raw, unnormalized class scores (logits) of shape
    ``(N, num_classes)``. No ReLU, dropout or softmax is applied to them, so
    they can be passed directly to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        # Stage 1: 227x227 -> conv 55x55 -> pool 27x27, 96 channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=(11, 11), stride=4)
        self.activate1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.LRN1 = LRN(local_size=5, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
        self.layer1 = nn.Sequential(self.conv1, self.activate1, self.pool1, self.LRN1)

        # Stage 2: 27x27 (padding 2 keeps the size) -> pool 13x13, 256 channels.
        # groups=2 splits the filters into two halves, each seeing half of the
        # input channels.
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=(5, 5), stride=1, groups=2, padding=2)
        self.activate2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.LRN2 = LRN(local_size=5, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
        self.layer2 = nn.Sequential(self.conv2, self.activate2, self.pool2, self.LRN2)

        # Stages 3-5 keep 13x13 (3x3 kernels, padding 1); stage 5 pools to 6x6.
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=(3, 3), stride=1, padding=1)
        self.activate3 = nn.ReLU(inplace=True)
        self.layer3 = nn.Sequential(self.conv3, self.activate3)

        self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 3), stride=1, padding=1)
        self.activate4 = nn.ReLU(inplace=True)
        self.layer4 = nn.Sequential(self.conv4, self.activate4)

        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=(3, 3), stride=1, padding=1)
        self.activate5 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.layer5 = nn.Sequential(self.conv5, self.activate5, self.pool5)

        # Hidden fully connected stages: linear -> ReLU -> dropout.
        self.fc1 = nn.Linear(in_features=6*6*256, out_features=4096)
        self.activate6 = nn.ReLU(inplace=True)
        self.dropout6 = nn.Dropout()
        self.layer6 = nn.Sequential(self.fc1, self.activate6, self.dropout6)

        self.fc2 = nn.Linear(in_features=4096, out_features=4096)
        self.activate7 = nn.ReLU(inplace=True)
        self.dropout7 = nn.Dropout()
        self.layer7 = nn.Sequential(self.fc2, self.activate7, self.dropout7)

        # Output layer: a plain linear map producing logits. A ReLU here would
        # clamp every negative score to zero, and dropout would randomly erase
        # class scores during training.
        self.fc3 = nn.Linear(in_features=4096, out_features=num_classes)
        self.layer8 = nn.Sequential(self.fc3)

    def forward(self, x):
        """Map a batch of images to per-class logits of shape (N, num_classes)."""
        x = self.layer5(self.layer4(self.layer3(self.layer2(self.layer1(x)))))
        # Keep the batch dimension explicit. view(-1, 9216) could silently
        # regroup a wrongly sized input into a different batch size; with this
        # view, fc1 raises a shape error instead.
        x = x.view(x.size(0), -1)
        x = self.layer8(self.layer7(self.layer6(x)))
        return x
        

In [49]:
# Smoke test: push one random 3x227x227 image through a freshly initialised
# network and print the shape of the result (recorded output below:
# torch.Size([1, 1000])).
net = AlexNet()
input_data = t.randn((1, 3, 227, 227))
out = net(input_data)
print(out.shape)

torch.Size([1, 1000])
