In [2]:
import torch
import torch.nn as nn

In [3]:
class Inception(nn.Module):
    def __init__(self, in_channels, c1, c2, c3, c4, **kwargs):
        super(Inception, self).__init__(**kwargs)
        self.p1=nn.Conv2d(in_channels, c1, kernel_size=1)
        self.p2=nn.Sequential(nn.Conv2d(in_channels, c2[0], kernel_size=1),nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1))
        self.p3=nn.Sequential(nn.Conv2d(in_channels, c3[0], kernel_size=1),nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2))
        self.p4=nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=1, padding=1),nn.Conv2d(in_channels, c4, kernel_size=1))
        
    def forward(self, x):
        p1=self.p1(x)
        p2=self.p2(x)
        p3=self.p3(x)
        p4=self.p4(x)
        return torch.cat((p1, p2, p3, p4), dim=1)

In [4]:
class GoogLeNet(nn.Sequential):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.b1=nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        # 7x7 conv,3x3 maxpool,
        self.b2=nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),nn.Conv2d(64, 192, kernel_size=3, padding=1),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        # 1x1 conv,3x3 conv,3x3 maxpool
        self.b3=nn.Sequential(
            Inception(192, 64, (96, 128), (16, 32), 32),
            Inception(256, 128, (128, 192), (32, 96), 64),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            )
        # inception, inception, 3x3 maxpool
        self.b4=nn.Sequential(
            Inception(480, 192, (96, 208), (16, 48), 64),
            Inception(512, 160, (112, 224), (24, 64), 64),
            Inception(512, 128, (128, 256), (24, 64), 64),
            Inception(512, 112, (144, 288), (32, 64), 64),
            Inception(528, 256, (160, 320), (32, 128), 128),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            )
        self.b5=nn.Sequential(
            Inception(832, 256, (160, 320), (32, 128), 128),
            Inception(832, 384, (192, 384), (48, 128), 128),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
            )
    
    def forward(self, x):
        x=self.b1(x)
        x=self.b2(x)
        x=self.b3(x)
        x=self.b4(x)
        x=self.b5(x)
        return x

In [5]:
net=GoogLeNet()
X = torch.rand(size=(1, 1, 96, 96))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)

Sequential output shape:	 torch.Size([1, 64, 24, 24])
Sequential output shape:	 torch.Size([1, 192, 12, 12])
Sequential output shape:	 torch.Size([1, 480, 6, 6])
Sequential output shape:	 torch.Size([1, 832, 3, 3])
Sequential output shape:	 torch.Size([1, 1024])


In [6]:
import torchvision.datasets
import torchvision.transforms as transforms
data=torchvision.datasets.FashionMNIST(root='data', train=True, transform=transforms.ToTensor())
train_set,test_set=torch.utils.data.random_split(data,[50000,10000])
train_data=torch.utils.data.DataLoader(train_set,batch_size=128,shuffle=True)
test_data=torch.utils.data.DataLoader(test_set,batch_size=128,shuffle=True)
del data,train_set,test_set

In [7]:
net.cuda()
def train(model,epoch,train_data,test_data,loss,optimizer):
    train_data.cuda()
    test_data.cuda()
    for i in range(epoch):
        train_loss,train_acc,test_acc=0.0,0.0,0.0
        for X,y in train_data:
            optimizer.zero_grad()
            y_hat=model(X)
            l=loss(y_hat,y)
            l.backward()
            optimizer.step()
            train_loss+=l.item()
            train_acc+=(y_hat.argmax(dim=1)==y).sum().item()
        for X,y in test_data:
            test_acc+=(model(X).argmax(dim=1)==y).sum().item()
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'%(i+1,train_loss/len(train_data),train_acc/len(train_data.dataset),test_acc/len(test_data.dataset)))
           
loss=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(net.parameters(),lr=0.0001)
train(net,10,train_data,test_data,loss,optimizer)

AttributeError: 'DataLoader' object has no attribute 'cuda'