<a href="https://colab.research.google.com/github/zzkzzkjsw/hinton/blob/main/mylenet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/zzkzzkjsw/hinton.git

fatal: destination path 'hinton' already exists and is not an empty directory.


In [2]:
# Move the kernel's working directory into the cloned repo so the relative
# paths used by later cells (requirements.txt, ./data) resolve inside it.
# NOTE(review): the path is relative, so this cell is meant to run once per
# session from the directory that holds the clone.
%cd hinton

/content/hinton


In [3]:
!pip install -r requirements.txt



In [4]:
import sys

# Make the cloned repo importable for the later `from my_config import ...` /
# `from my_utils import ...` cells (presumably those modules live in the repo).
# The membership check keeps the cell idempotent: re-running it does not add
# duplicate entries to sys.path.
REPO_DIR = '/content/hinton'
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

In [5]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
from torch.nn import functional as F

import wandb

from my_config import Config
from my_utils import set_seed

In [6]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Store the chosen device on the run configuration, then start the W&B run
# with that configuration attached.
config = Config()
config.train.device = device
wandb.init(project="lenet", config=config)

# set_seed comes from my_utils; it is called with the seed held in the config
# (presumably to seed the random number generators - confirm in my_utils).
set_seed(config.train.seed)

cuda


[34m[1mwandb[0m: Currently logged in as: [33mzzk5678[0m ([33mzzkzzkjsw[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
# Despite the `mnist_*` names, both splits are FashionMNIST. Images are
# converted to tensors by ToTensor; files are downloaded into ./data if needed.
mnist_train = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
mnist_test = torchvision.datasets.FashionMNIST(
    root='./data',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

In [8]:
def custom_collate_fn(batch):
    """Collate a list of dataset samples into a pair of batched tensors.

    Args:
        batch: list of samples; element 0 of each sample is a tensor (the
            input) and element 1 is an integer label.

    Returns:
        A tuple ``(xs, ys)``: ``xs`` is the inputs stacked along a new leading
        dimension, ``ys`` is a ``torch.long`` tensor of the labels.
    """
    inputs = []
    labels = []
    for sample in batch:
        inputs.append(sample[0])
        labels.append(sample[1])
    # Convert the list of integer labels into a single long tensor.
    return torch.stack(inputs), torch.tensor(labels, dtype=torch.long)

In [9]:
# Settings shared by both loaders; only the shuffle flag differs between them.
loader_kwargs = dict(
    batch_size=config.train.batch_size,
    num_workers=config.data.num_workers,
    collate_fn=custom_collate_fn,
)
# Training order follows the config; the test set is always read in order.
train_iter = DataLoader(mnist_train, shuffle=config.data.shuffle, **loader_kwargs)
test_iter = DataLoader(mnist_test, shuffle=False, **loader_kwargs)

In [10]:
class LeNet(nn.Module):
    """LeNet-style CNN for single-channel 28x28 images.

    Layers, in order: conv(1->6, 5x5, padding 2) -> ReLU -> pool(2) ->
    conv(6->16, 5x5) -> ReLU -> pool(2) -> flatten to 400 features ->
    linear(400->128) -> ReLU -> linear(128->out_dim).

    Args:
        out_dim: size of the final linear layer's output.
        pooling: ``'avg'`` for ``nn.AvgPool2d(2)`` or ``'max'`` for
            ``nn.MaxPool2d(2)``; the same pooling module is used after both
            convolutions.

    Raises:
        ValueError: if ``pooling`` is neither ``'avg'`` nor ``'max'``.
    """

    def __init__(self, out_dim=10, pooling='avg'):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        if pooling == 'max':
            self.pooling = nn.MaxPool2d(2)
        elif pooling == 'avg':
            self.pooling = nn.AvgPool2d(2)
        else:
            # Fail at construction time; otherwise self.pooling would be
            # missing and forward() would die with an AttributeError.
            raise ValueError(
                f"pooling must be 'max' or 'avg', got {pooling!r}"
            )
        self.lin1 = nn.Linear(400, 128)
        # The extra hidden layer (lin2, 120 -> 84) is left disabled.
        # self.lin2 = nn.Linear(120,84)
        self.lin3 = nn.Linear(128, out_dim)

    def forward(self, x):
        """Compute logits for a batch of images.

        Args:
            x: tensor of shape ``(batch, 1, 28, 28)``; the shape is asserted.

        Returns:
            Tensor of shape ``(batch, out_dim)`` (no softmax applied).
        """
        bsz = x.shape[0]
        # The asserts document and check the expected shape after each stage.
        assert list(x.shape) == [bsz, 1, 28, 28]
        x = self.relu(self.conv1(x))
        assert list(x.shape) == [bsz, 6, 28, 28]
        x = self.pooling(x)
        assert list(x.shape) == [bsz, 6, 14, 14]
        x = self.relu(self.conv2(x))
        assert list(x.shape) == [bsz, 16, 10, 10]
        x = self.pooling(x)
        assert list(x.shape) == [bsz, 16, 5, 5]
        # Flatten the 16 x 5 x 5 feature maps into 400 features per sample.
        x = x.reshape(bsz, -1)
        assert list(x.shape) == [bsz, 400]
        x = self.relu(self.lin1(x))
        # x = self.relu(self.lin2(x))
        x = self.lin3(x)
        return x

In [11]:
# Build the model with the constructor's defaults spelled out explicitly.
net = LeNet(out_dim=10, pooling='avg')

In [12]:
# net = nn.Sequential(
#     nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
#     nn.AvgPool2d(kernel_size=2, stride=2),
#     nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
#     nn.AvgPool2d(kernel_size=2, stride=2),
#     nn.Flatten(),
#     nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
#     nn.Linear(120, 84), nn.Sigmoid(),
#     nn.Linear(84, 10))

In [13]:
# Cross-entropy criterion, applied to the network outputs and integer labels.
loss = nn.CrossEntropyLoss()
# Adam over every parameter of the network; the learning rate is read from the config.
optimizer = torch.optim.Adam(
    params=net.parameters(),
    lr=config.train.learning_rate,
)

In [14]:
def accuracy(y_pred, y):
    """Return the fraction of samples whose arg-max prediction equals the label.

    Args:
        y_pred: tensor of scores; the predicted class is the arg-max over the
            last dimension.
        y: tensor of target class indices.

    Returns:
        A scalar tensor: number of matches divided by ``len(y)``.
    """
    predicted = y_pred.argmax(-1)
    n_correct = (predicted == y).sum()
    return n_correct / len(y)

In [15]:
# def init_weights(m):
#     if type(m) == nn.Linear or type(m) == nn.Conv2d:
#         nn.init.xavier_uniform_(m.weight)
# net.apply(init_weights)
# Train for config.train.num_epochs epochs; after each epoch, evaluate on the
# test loader. Metrics are sent to the active W&B run.
net = net.to(config.train.device)
device = config.train.device  # loop-invariant, so read it once
for epoch in range(config.train.num_epochs):
    net.train()
    for idx, (X, y) in enumerate(train_iter):

        optimizer.zero_grad()

        X, y = X.to(device), y.to(device)
        y_pred = net(X)
        l = loss(y_pred, y)

        acc = accuracy(y_pred, y)
        # Log plain Python scalars rather than tensors attached to the graph.
        wandb.log({"Iteration loss": l.item(), "acc": acc.item(), "idx": idx, "epoch": epoch})

        l.backward()

        # Clip the global gradient norm to 10.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(net.parameters(), 10.0)
        optimizer.step()

    net.eval()
    # No gradients are needed for evaluation; skipping autograd saves memory.
    with torch.no_grad():
        for X, y in test_iter:
            X, y = X.to(device), y.to(device)
            y_pred = net(X)
            test_acc = accuracy(y_pred, y)
            # Log the accuracy of THIS test batch. The original logged `acc`,
            # i.e. the last training batch's accuracy, under "test_acc".
            # `idx` is still the index of the last training batch of the epoch.
            wandb.log({"test_acc": test_acc.item(), "idx": idx, "epoch": epoch})
