# import

In [21]:
import torch
from torch import nn
import os
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from os.path import join
import torchvision
from torch.utils.data import DataLoader
from d2l import torch as d2l

In [22]:
def set_seed(seed=0):
    """Seed PyTorch's global random number generator.

    Args:
        seed: Integer seed handed to PyTorch. Defaults to 0.
    """
    torch.random.manual_seed(seed)

In [23]:
class NotAbsPathException(Exception):
    """Raised when a path that must be absolute is not.

    Args:
        path: The offending path; stored on the instance as ``path``.
        *args: Extra positional arguments forwarded to ``Exception``.
    """

    def __init__(self, path, *args: object) -> None:
        super().__init__(*args)
        # Kept so that __str__ can report which path was rejected.
        self.path = path

    def __str__(self):
        # The original message said "abstract path", a typo for "absolute".
        return f"{self.path} is not an absolute path"

class CheckPoint:
    """Save and load one checkpoint file inside an absolute directory."""

    def __init__(self, path, file_name) -> None:
        """Validate the directory and remember where the checkpoint lives.

        Args:
            path: Absolute directory that holds the checkpoint. Missing
                intermediate directories are created.
            file_name: Name of the checkpoint file inside ``path``.

        Raises:
            NotAbsPathException: If ``path`` is not an absolute path.
        """
        if not os.path.isabs(path):
            raise NotAbsPathException(path)

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)

        self.path = path
        self.file_name = file_name
        self.file_path = os.path.join(self.path, self.file_name)

    def save(self, data: dict):
        """Write ``data`` to the checkpoint file.

        The content of ``data`` is entirely up to the caller, because what
        needs to be saved differs from one use case to another.

        Args:
            data: Dictionary serialized with ``torch.save``.
        """
        # Write to a sibling temp file first and then rename it over the
        # target, so an interrupted save never leaves a truncated checkpoint.
        tmp_path = self.file_path + ".tmp"
        torch.save(data, tmp_path)
        os.replace(tmp_path, self.file_path)

    def load(self, map_location=None):
        """Load and return the object stored in the checkpoint file.

        Args:
            map_location: Optional device remapping forwarded to
                ``torch.load``. ``None`` (the default) keeps the previous
                behavior of this method.

        Returns:
            Whatever object was passed to ``save``.
        """
        return torch.load(self.file_path, map_location=map_location)

In [24]:
class MNIST_2NN(torch.nn.Module):
    """Fully connected classifier with two hidden layers.

    Layer sizes:
        input:    flattened to 784 features
        hidden 1: 784 -> 200
        hidden 2: 200 -> 200
        output:   200 -> 10
    """

    def __init__(self) -> None:
        super().__init__()
        self.flat = torch.nn.Flatten()
        self.fc_1 = torch.nn.Linear(784, 200)
        self.fc_2 = torch.nn.Linear(200, 200)
        self.fc_3 = torch.nn.Linear(200, 10)
        self.relu = torch.nn.ReLU()
        # Registered as a submodule so that train()/eval() toggles it.
        # Dropout has no parameters, so state_dict keys are unchanged.
        self.dropout = torch.nn.Dropout(0.5)

    def init_params(self, seed):
        """Re-initialize every Linear/Conv2d child layer reproducibly.

        Args:
            seed: Seed passed to ``set_seed`` before initialization.
        """
        set_seed(seed)
        for layer in self.children():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                # The initializer usually follows the activation function:
                # Kaiming for ReLU, Xavier for sigmoid.
                torch.nn.init.kaiming_normal_(layer.weight.data)
                torch.nn.init.zeros_(layer.bias.data)

    def forward(self, x):
        """Return the 10 output logits for each sample in ``x``.

        Args:
            x: Input batch; everything after the batch dimension is
                flattened and must amount to 784 features.
        """
        x = self.flat(x)
        x = self.fc_1(x)
        x = self.relu(x)
        # Dropout against overfitting. The original built a new
        # nn.Dropout on every call, which is always in training mode and
        # therefore stayed active even after net.eval().
        x = self.dropout(x)
        x = self.fc_2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc_3(x)
        return x


In [25]:
# Settings for the save/load smoke test in the cells below.
reload_flag = True # whether to reload from the checkpoint directory
log_dir = '/home/whr-pc-ubuntu/code/test/checkpoint/log/1111'
file_name = 'models.pth'


In [26]:
# Build the model and re-initialize its layers with seed 0.
net = MNIST_2NN()
net.init_params(0)

In [27]:
# Checkpoint handle for log_dir/file_name; log_dir must be an absolute path.
checkpoint = CheckPoint(log_dir,file_name)

In [28]:
# Smoke test: store the freshly initialized weights together with a step counter.
checkpoint.save({
    "model":net.state_dict(),
    "step":0
    })

In [29]:
# Read the checkpoint back; as the last expression of the cell the result is displayed.
checkpoint.load()

{'model': OrderedDict([('fc_1.weight',
               tensor([[-0.0569, -0.0582, -0.0127,  ..., -0.0799, -0.0297, -0.0058],
                       [ 0.0354, -0.0281, -0.0193,  ...,  0.0201,  0.0130,  0.0101],
                       [-0.0080,  0.0373, -0.0127,  ...,  0.0465, -0.0056, -0.0946],
                       ...,
                       [ 0.0469, -0.0082,  0.0193,  ...,  0.0332, -0.0398, -0.0023],
                       [-0.0562,  0.0057, -0.0071,  ...,  0.0176,  0.0732, -0.1302],
                       [-0.0638,  0.1276, -0.0920,  ..., -0.0135, -0.0364, -0.1337]])),
              ('fc_1.bias',
               tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                       0., 0., 0., 

In [30]:
# NOTE(review): reload_flag is not read by any code visible in this file;
# the training cells below branch on load_flag instead.
reload_flag = False # whether to reload from the checkpoint directory
log_dir = '/home/whr-pc-ubuntu/code/test/checkpoint/log/1111'
file_name = 'models.pth'


# test

先训练一段时间, 将结果记录在 TensorBoard 中

然后将 load_flag 修改为 True 并重新运行训练, 观察 TensorBoard 的曲线是否连贯

In [31]:
def now_str():
    """Return the current local time formatted as ``YYYY-MM-DD-HH-MM-SS``."""
    current = datetime.now()
    return f"{current:%Y-%m-%d-%H-%M-%S}"


def get_writer(*tags):
    """Create a ``SummaryWriter`` whose log directory is built from ``tags``.

    The tags are joined below ``'logs'`` with ``os.path.join``, so an
    absolute tag replaces everything before it (``get_writer('/abs/dir')``
    logs to ``/abs/dir``), while relative tags nest under ``logs/``.

    The original built this path but then ignored it and passed the global
    ``log_dir`` to the writer instead.

    Args:
        *tags: Path components identifying the run.

    Returns:
        A ``SummaryWriter`` writing to the joined path.
    """
    return SummaryWriter(join('logs', *tags))




In [32]:
def load_data(seed=0, batch_size=256, shuffle=True,
              dataset_path='/home/whr-pc-ubuntu/code/dataset'):
    """Build the MNIST train and test data loaders.

    Args:
        seed: Seed applied through ``set_seed`` right before the loaders
            are created.
        batch_size: Batch size used by both loaders.
        shuffle: Shuffle flag passed to both loaders (train and test).
        dataset_path: Directory where the dataset is stored or downloaded
            to. Defaults to the location that was previously hard-coded.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # Both mean and std are 1-tuples (one channel). The original std
        # was written as ``(0.3081)``, which is a plain float.
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_dataset = torchvision.datasets.MNIST(
        dataset_path,
        train=True,
        transform=transform,
        download=True,
    )
    test_dataset = torchvision.datasets.MNIST(
        dataset_path,
        train=False,
        transform=transform,
        download=True,
    )

    set_seed(seed)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle)
    test_dataloader = DataLoader(test_dataset, batch_size, shuffle)

    return train_dataloader, test_dataloader


## first train

In [33]:
# First run: start from scratch instead of restoring a checkpoint.
load_flag = False

In [34]:
# Hyper-parameters and data for the first training run.
epoch = 1000
seed = 0
batch_size = 128
lr = 0.01
# shuffle=False keeps the batch order identical between runs.
train_dataloader,test_dataloader = load_data(seed,batch_size,False)
# Use the first GPU when there is one; otherwise fall back to the CPU
# instead of failing later when the model is moved to the device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [35]:
# The checkpoint file and the writer are both created from the same log_dir.
log_dir = '/home/whr-pc-ubuntu/code/test/checkpoint/log/5555'
file_name = 'models.pth'
checkpoint = CheckPoint(log_dir,file_name)
writer = get_writer(log_dir)

In [36]:
# Build the model and optimizer, optionally restore them from the
# checkpoint, then train and log one set of scalars per epoch.
net = MNIST_2NN()
net.init_params(seed)
net.to(device)
optim = torch.optim.SGD(net.parameters(),lr)
start = 0

if load_flag:
    data = checkpoint.load()
    net.load_state_dict(data['model'])
    # Restore the optimizer as well, matching the "train again" cell.
    optim.load_state_dict(data['optim'])
    # Continue with the epoch after the one that was saved.
    start = data['i'] + 1

loss_func = nn.CrossEntropyLoss()
# Slots: summed loss, number of correct predictions, number of samples.
accumulator = d2l.Accumulator(3)
for i in range(start,epoch):
    # The original called net.eval() once before the loop. Training must
    # run in train mode, and the mode is set every epoch because the
    # evaluation helper below is called between epochs.
    net.train()
    # Start each epoch from zero so the logged values are per-epoch
    # averages rather than averages over everything since the cell began.
    accumulator.reset()
    for x,y in train_dataloader:
        optim.zero_grad()
        x,y = x.to(device),y.to(device)
        y_hat = net(x)

        loss = loss_func(y_hat,y)
        loss.backward()
        optim.step()

        # loss is the batch mean; multiply by the batch size to sum it.
        accumulator.add(loss.item()*x.shape[0],d2l.accuracy(y_hat,y),x.shape[0])

    test_acc = d2l.evaluate_accuracy_gpu(net,test_dataloader,device)

    writer.add_scalar("train loss",accumulator[0]/accumulator[-1],i)
    writer.add_scalar("train acc",accumulator[1]/accumulator[-1],i)
    writer.add_scalar("test acc",test_acc,i)

    # Save a checkpoint every 20 epochs.
    if i%20 == 0:
        checkpoint.save({
            "model":net.state_dict(),
            "optim":optim.state_dict(),
            "i":i
            })

# Push any buffered events to disk once training is over.
writer.flush()
    

## train again

In [17]:
# Second run: restore the checkpoint written by the first run.
load_flag = True


In [18]:
# Hyper-parameters and data for the resumed run (same values as the first run).
epoch = 1000
seed = 0
batch_size = 128
lr = 0.01
# shuffle=False keeps the batch order identical between runs.
train_dataloader,test_dataloader = load_data(seed,batch_size,False)
# Use the first GPU when there is one; otherwise fall back to the CPU
# instead of failing later when the model is moved to the device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [19]:
# Same directory as the first run, so the resumed run reads that run's
# checkpoint and logs into the same location.
log_dir = '/home/whr-pc-ubuntu/code/test/checkpoint/log/5555'
file_name = 'models.pth'
checkpoint = CheckPoint(log_dir,file_name)
writer = get_writer(log_dir)

In [20]:
# Rebuild the model and optimizer, restore them from the checkpoint when
# load_flag is set, then continue training and logging per epoch.
net = MNIST_2NN()
net.init_params(seed)
net.to(device)
optim = torch.optim.SGD(net.parameters(),lr)
start = 0

if load_flag:
    print("reload")
    data = checkpoint.load()
    net.load_state_dict(data['model'])
    optim.load_state_dict(data['optim'])
    # Continue with the epoch after the one that was saved.
    start = data['i'] + 1

loss_func = nn.CrossEntropyLoss()
# Slots: summed loss, number of correct predictions, number of samples.
accumulator = d2l.Accumulator(3)
for i in range(start,epoch):
    # The original called net.eval() once before the loop. Training must
    # run in train mode, and the mode is set every epoch because the
    # evaluation helper below is called between epochs.
    net.train()
    # Start each epoch from zero so the logged values are per-epoch
    # averages; a cumulative average would restart after a resume and
    # break the continuity of the curve.
    accumulator.reset()
    for x,y in train_dataloader:
        optim.zero_grad()
        x,y = x.to(device),y.to(device)
        y_hat = net(x)

        loss = loss_func(y_hat,y)
        loss.backward()
        optim.step()

        # loss is the batch mean; multiply by the batch size to sum it.
        accumulator.add(loss.item()*x.shape[0],d2l.accuracy(y_hat,y),x.shape[0])

    test_acc = d2l.evaluate_accuracy_gpu(net,test_dataloader,device)

    writer.add_scalar("train loss",accumulator[0]/accumulator[-1],i)
    writer.add_scalar("train acc",accumulator[1]/accumulator[-1],i)
    writer.add_scalar("test acc",test_acc,i)

    # Save a checkpoint every 20 epochs.
    if i%20 == 0:
        checkpoint.save({
            "model":net.state_dict(),
            "optim":optim.state_dict(),
            "i":i
            })

# Push any buffered events to disk once training is over.
writer.flush()
    

reload


# epoch等效替代

In [22]:
# Hyper-parameters for the split-run experiment: `epoch` is the length of
# each of the two consecutive training cells below.
epoch = 40
seed = 0
batch_size = 128
lr = 0.01
# shuffle=False keeps the batch order identical between runs.
train_dataloader,test_dataloader = load_data(seed,batch_size,False)
# Use the first GPU when there is one; otherwise fall back to the CPU
# instead of failing later when the model is moved to the device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [23]:
# Separate directory so this experiment does not mix with the earlier runs.
log_dir = '/home/whr-pc-ubuntu/code/test/checkpoint/log/3333'
file_name = 'models.pth'
checkpoint = CheckPoint(log_dir,file_name)
writer = get_writer(log_dir)

In [24]:
# Model, optimizer, loss and metric accumulator shared by the two
# training cells that follow.
net = MNIST_2NN()
net.init_params(seed)
net.to(device)
optim = torch.optim.SGD(net.parameters(),lr)    

loss_func = nn.CrossEntropyLoss()
# Three running sums; the loops below feed it summed loss, accuracy count and sample count.
accumulator = d2l.Accumulator(3)


In [25]:
# First half of the split run: epochs [0, epoch). The original looped to
# a hard-coded 1000, which overlaps the step range [epoch, 2*epoch) that
# the next cell logs.
for i in range(0,epoch):
    # Start each epoch from zero so the logged values are per-epoch averages.
    accumulator.reset()
    for x,y in train_dataloader:
        optim.zero_grad()
        x,y = x.to(device),y.to(device)
        y_hat = net(x)

        loss = loss_func(y_hat,y)
        loss.backward()
        optim.step()

        # loss is the batch mean; multiply by the batch size to sum it.
        accumulator.add(loss.item()*x.shape[0],d2l.accuracy(y_hat,y),x.shape[0])

    writer.add_scalar("train loss",accumulator[0]/accumulator[-1],i)
    writer.add_scalar("train acc",accumulator[1]/accumulator[-1],i)

    # Save a checkpoint every 20 epochs.
    if i%20 == 0:
        checkpoint.save({
            "model":net.state_dict(),
            "optim":optim.state_dict(),
            "i":i
            })

In [25]:
# Second half of the split run: epochs [epoch, 2*epoch), continuing with
# the model and optimizer left in memory by the previous cell.
for i in range(epoch,2*epoch):
    # Start each epoch from zero so the logged values are per-epoch
    # averages rather than averages over every batch seen so far.
    accumulator.reset()
    for x,y in train_dataloader:
        optim.zero_grad()
        x,y = x.to(device),y.to(device)
        y_hat = net(x)

        loss = loss_func(y_hat,y)
        loss.backward()
        optim.step()

        # loss is the batch mean; multiply by the batch size to sum it.
        accumulator.add(loss.item()*x.shape[0],d2l.accuracy(y_hat,y),x.shape[0])

    writer.add_scalar("train loss",accumulator[0]/accumulator[-1],i)
    writer.add_scalar("train acc",accumulator[1]/accumulator[-1],i)

    # Save a checkpoint every 20 epochs.
    if i%20 == 0:
        checkpoint.save({
            "model":net.state_dict(),
            "optim":optim.state_dict(),
            "i":i
            })