In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2
from torch.utils.data import DataLoader
# Use the first CUDA GPU when one is present; otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without a usable GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [2]:
# Root folders of the training and validation image splits, given relative to
# the notebook's working directory.
train_path, valid_path = (
    r'./dataset/ImageNet/train',
    r'./dataset/ImageNet/valid',
)

## 数据加载器

In [3]:
# Edge length (in pixels) of the square images fed to the network.
size = 128

# Per-channel normalisation constants (these match the widely used ImageNet
# RGB statistics).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Preprocessing steps, applied in order to every loaded sample.
preprocessing_steps = [
    v2.ToImage(),                            # convert the loaded image to a tensor image
    v2.Resize(size),                         # resize using a single int target
    v2.CenterCrop(size),                     # cut out the central size x size square
    v2.RandomHorizontalFlip(p=0.5),          # random mirror augmentation
    v2.ToDtype(torch.float32, scale=True),   # cast to float32 with value scaling
    v2.Normalize(mean=norm_mean, std=norm_std),
]
transforms = v2.Compose(preprocessing_steps)

In [4]:
# Training split read from `train_path`; every sample is passed through the
# `transforms` pipeline when it is loaded.
train_dataset = ImageFolder(root=train_path, transform=transforms)

In [5]:
# Mini-batches of 32 training samples, reshuffled on every pass over the data.
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)

## 残差神经网络

In [6]:
class ResBlock(nn.Module):
    """Residual block: a three-convolution branch added back onto its input.

    The branch maps ``in_channels -> hid_channels -> hid_channels ->
    in_channels`` with 3x3, 7x7 and 3x3 kernels. Each of the first two
    convolutions is followed by BatchNorm2d and SiLU; the last one is left
    bare. Every convolution is padded by half its kernel size, so the branch
    output has the same shape as the input and the two can be summed.

    Args:
        in_channels: number of channels of the block's input and output.
        hid_channels: number of channels used inside the branch.
    """

    def __init__(self, in_channels, hid_channels):
        super().__init__()
        self.in_channels, self.hid_channels = in_channels, hid_channels

        def same_conv(c_in, c_out, kernel):
            # kernel // 2 gives padding 1 for a 3x3 kernel and 3 for a 7x7 kernel.
            return nn.Conv2d(c_in, c_out, kernel_size=kernel, padding=kernel // 2)

        layers = []
        # Two conv -> batch-norm -> SiLU stages: 3x3 first, then 7x7.
        for stage_in, kernel in ((in_channels, 3), (hid_channels, 7)):
            layers.append(same_conv(stage_in, hid_channels, kernel))
            layers.append(nn.BatchNorm2d(hid_channels))
            layers.append(nn.SiLU())
        # Final 3x3 projection back to the input channel count.
        layers.append(same_conv(hid_channels, in_channels, 3))

        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch_out = self.cnn(x)
        return x + branch_out
    
class ResNet(nn.Module):
    """Fully convolutional classifier built from a stack of ``ResBlock`` modules.

    A 3x3 convolution lifts the input to ``res_channels`` channels, ``num_res``
    residual blocks process it, and a SiLU + 3x3 convolution produces one map
    per class. The maps are averaged over the last two dimensions to give the
    per-class scores.

    Args:
        size: stored on the instance as ``self.size``; not used by any layer
            or by ``forward`` in this class.
        in_channels: number of channels of the input.
        num_classes: number of output channels of the final convolution.
        res_channels: number of channels entering and leaving each ResBlock.
        hid_channels: number of channels inside each ResBlock.
        num_res: number of ResBlocks stacked between encoder and decoder.
    """

    def __init__(self, size, in_channels, num_classes, res_channels, hid_channels, num_res):
        super().__init__()

        # Keep the constructor arguments available on the instance.
        self.size = size
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.res_channels = res_channels
        self.hid_channels = hid_channels
        self.num_res = num_res

        # Submodules are created in the order encoder -> decoder -> residual stack.
        self.encoder = nn.Conv2d(in_channels, res_channels, kernel_size=3, padding=1)

        class_conv = nn.Conv2d(res_channels, num_classes, kernel_size=3, padding=1)
        self.decoder = nn.Sequential(nn.SiLU(), class_conv)

        res_blocks = [ResBlock(res_channels, hid_channels) for _ in range(num_res)]
        self.hid_res_layers = nn.Sequential(*res_blocks)

    def forward(self, x):
        """Return the class maps averaged over the last two dimensions.

        For a 4-D input this yields one score per class for each batch element
        (assumes ``x`` is laid out as batch, channel, height, width).
        """
        features = self.hid_res_layers(self.encoder(x))
        class_maps = self.decoder(features)
        return class_maps.mean(dim=[-2, -1])

In [7]:
# Sanity check: build the model and push one random, image-sized batch through it.
x = torch.randn([1, 3, size, size], device=device)
resnet = ResNet([size, size], 3, 20, 32, 64, 5).to(device)

# Run the check in eval mode with autograd disabled. In training mode this
# forward pass would update the BatchNorm running statistics from pure noise
# and build an autograd graph that is never used.
resnet.eval()
with torch.no_grad():
    y = resnet(x)
resnet.train()  # restore the module's default (training) mode

# One score per class for the single input sample.
assert y.shape == (1, 20), y.shape

In [8]:
def train(model, dataloader, epoch, lr, log_every=1):
    """Train ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: module mapping a batch of inputs to per-class scores; its
            parameters are updated in place.
        dataloader: iterable yielding ``(data, label)`` batches; it is
            iterated once per epoch.
        epoch: number of passes to make over ``dataloader``.
        lr: learning rate handed to ``torch.optim.Adam``.
        log_every: print ``"<epoch index> <batch loss>"`` every ``log_every``
            batches within an epoch. The default of 1 prints every batch;
            ``0`` or ``None`` disables printing.

    Returns:
        None.
    """
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fun = nn.CrossEntropyLoss()

    # Send batches to wherever the model's parameters live instead of relying
    # on a notebook-level `device` global, so the function works for any model
    # regardless of which device it was moved to.
    model_device = next(model.parameters()).device

    for i in range(epoch):
        model.train()
        for step, (data, label) in enumerate(dataloader, start=1):
            optim.zero_grad()
            pred = model(data.to(model_device))
            loss = loss_fun(pred, label.to(model_device))
            loss.backward()
            optim.step()
            # `.item()` is only evaluated for the batches that are reported.
            if log_every and step % log_every == 0:
                print(i, loss.item())

In [9]:
# Fit the network on the training loader: 10 epochs, learning rate 1e-3.
train(
    model=resnet,
    dataloader=train_data_loader,
    epoch=10,
    lr=1e-3,
)

0 3.031346082687378
0 2.9801924228668213
0 3.127178430557251
0 3.014630079269409
0 2.927363872528076
0 2.870143175125122
0 3.1752493381500244
0 3.1752028465270996
0 2.9695963859558105
0 2.919929027557373
0 2.9906697273254395
0 2.873068332672119
0 2.6914920806884766
0 2.7166731357574463
0 2.7960896492004395
0 2.7284178733825684
0 2.7157068252563477
0 2.579662799835205
0 2.7923645973205566
0 2.8793718814849854
0 2.880093574523926
0 2.788482189178467
0 2.905705451965332
0 2.6335535049438477
0 2.6261863708496094
0 2.682499408721924
0 2.658428907394409
0 2.5761618614196777
0 2.6496028900146484
0 2.5789384841918945
0 2.4638962745666504
0 2.5612387657165527
0 2.4635744094848633
0 2.3985018730163574
0 2.8103926181793213
0 2.559324264526367
0 2.396188259124756
0 2.487492322921753
0 2.4883055686950684
0 3.073068618774414
0 2.333211898803711
0 2.5940237045288086
0 2.5793745517730713
0 2.7073426246643066
0 2.6884121894836426
0 2.479959011077881
0 3.0489673614501953
0 2.526317596435547
0 2.57900571