In [1]:
from tensorflow.keras.datasets import mnist
import torch
from torch import nn 
import numpy as np 
from torch.utils.data import TensorDataset,DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import wandb

In [2]:
# config_class
class base_config():
    def __init__(self,lr,epoch,batch_size):
        self.lr = lr 
        self.epoch = epoch
        self.batch_size = batch_size

In [4]:
# 生成dataloader
def create_dataloader(data,batch_size):
    return DataLoader(data,batch_size=batch_size,shuffle=True)

In [14]:
# trainer_class_classify
class classify_trainer():
    def __init__(self,model,optimizer,loss_function,traindata,testdata,config,device,valdata=None,use_wandb=False):
        self.device = device
        self.model = model.to(device)
        self.config = config
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.use_wandb = use_wandb
        self.trainloader = DataLoader(traindata,batch_size=self.config.batch_size,shuffle=True)
        self.testloader = DataLoader(testdata,batch_size=self.config.batch_size,shuffle=True)
        if valdata is not None:
            self.validloader = DataLoader(valdata,batch_size=self.config.batch_size,shuffle=True)
        else:
            self.validloader = None
        
        # init wandb
        if self.use_wandb:
            project_name = input('please input wandb project name')
            print('wandb is loading')
            config_dict = {
                "lr":self.config.lr,
                "batch_size":self.config.batch_size,
                "epoch":self.config.epoch,
            }
            print("The project name : ",project_name)
            wandb.init(project=project_name,config=config_dict)
    
    # compute accuracy
    @classmethod 
    def compute_acc(self,pred,label):
        return np.equal(pred,label).mean()
    
    @classmethod
    def eval(self,model,dataloader,loss_function,device):
        loss = 0
        acc = 0
        model = model.to(device)
        model.eval() # 禁用dropout和bn层
        for idx,(data,label) in tqdm(enumerate(dataloader)):
            data = data.to(device)
            label = label.to(device)
            out = model(data)
            # 计算loss值
            loss += loss_function(out,label).item()
            # 计算accuracy
            # argmax这块先留个心眼，我不确定这个axis是不是需要微调
            pred = torch.argmax(out,axis=1)
            pred = pred.cpu().detach().numpy()
            label = label.cpu().detach().numpy()
            acc += self.compute_acc(pred,label)
        
        avg_loss = loss/ len(dataloader)
        avg_acc = acc / len(dataloader)
        return avg_loss,avg_acc
        
    
    def train(self,print_message=True):
        for i in tqdm(range(1,self.config.epoch+1)):
            # 进行一个model.train()
            self.model.train()
            # 对于训练过程中每一个epoch的参数记录
            train_loss = 0
            train_acc = 0
            for idx,(data,label) in tqdm(enumerate(self.trainloader)):
                data = data.to(self.device)
                label = label.to(self.device)
                out = self.model(data)
                loss = self.loss_function(out,label)
                
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # 计算出来损失函数
                train_loss += loss.item()
                # 转到cpu上计算准确率
                pred = torch.argmax(out,axis=1)
                pred = np.array(pred.cpu().detach().numpy())
                label = np.array(label.cpu().detach().numpy())
                train_acc += self.compute_acc(pred,label)
            # 一个epoch结束后，计算均值
            train_loss = train_loss / len(self.trainloader)
            train_acc = train_acc  / len(self.trainloader)
            # 剩下的是验证集和测试集
            if self.validloader:
                valid_loss,valid_acc = self.eval(self.model,self.validloader,self.loss_function,self.device)
            else:
                valid_loss =0.0
                valid_acc = 0.0
            
            test_loss,test_acc = self.eval(self.model,self.testloader,self.loss_function,self.device)     
            if self.use_wandb:
                wandb.log(
                    {
                        "Train loss":train_loss,
                        "Train Accuracy":train_acc,
                        "Valid loss":valid_loss,
                        "Valid Accuracy":valid_acc,
                        "Test loss":test_loss,
                        "Test Accuracy":test_acc
                    }
                )
            if print_message:
                print('epoch:',i)
                print('Train loss',train_loss,'train_acc',train_acc)
                print('Valid loss',valid_loss,'valid_acc',valid_acc)
                print('Test loss',test_loss,'test acc',test_acc) 
        print('Finished')
    
    def stop(self):
        if self.use_wandb:
            wandb.finish()
        else:
            print("No wandb used")

In [10]:
# 设置一下config
config = base_config(lr = 0.01, epoch = 20,batch_size=32)

In [9]:
# 读取数据集
train_data,test_data = mnist.load_data()
# 数据集太大，用一小部分试一试
data,label = train_data[0][0:1000],train_data[1][0:1000]
# 考虑到灰度数据用conv2d处理的时候需要增加一个通道维度，需要对数据集进行一个变换
data = np.expand_dims(data,1)

# 构造数据集
data = torch.Tensor(data)
label = torch.LongTensor(label)
# 划分数据集
x_train,x_test,y_train,y_test = train_test_split(data,label)
x_train = x_train[150:]
y_train = y_train[150:]

x_valid = x_train[0:150]
y_valid = y_train[0:150]

TrainData = TensorDataset(x_train,y_train)
ValidData = TensorDataset(x_valid,y_valid)
TestData = TensorDataset(x_test,y_test)

In [None]:
# 测试用神经网络
# 撸个简单的卷积神经网络
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,3,3,1,1)
        self.pooling = nn.MaxPool2d(2)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc =  nn.Linear(14*14*3,10)
    def forward(self,x):
        x = self.conv1(x)
        x = self.pooling(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.flatten(x)
        x= self.fc(x)
        return x

In [11]:
# 调用cuda,同时设置一些参数
device = torch.device(0)
model = Net()
optimizer = torch.optim.Adam(model.parameters(),lr = config.lr)
loss_function = nn.CrossEntropyLoss() 

In [12]:
trainer = classify_trainer(model=model,device=device,config=config,loss_function=loss_function,optimizer=optimizer,traindata=TrainData,testdata=TestData,valdata=ValidData,use_wandb=True)

wandb is loading
The project name :  my_temp_trainer_1


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mzhijiao[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

In [13]:
trainer.train()

19it [00:05,  3.37it/s]00:00<?, ?it/s]
5it [00:00, 555.58it/s]
8it [00:00, 726.88it/s]
  5%|▌         | 1/20 [00:05<01:47,  5.67s/it]

epoch: 1
Train loss 1.1325401538296749 train_acc 0.6398026315789473
Valid loss 0.27536901384592055 valid_acc 0.909659090909091
Test loss 0.5699038505554199 test acc 0.8146033653846154


19it [00:00, 283.20it/s]
5it [00:00, 625.01it/s]
8it [00:00, 666.62it/s]
 10%|█         | 2/20 [00:05<00:43,  2.40s/it]

epoch: 2
Train loss 0.26620716954532425 train_acc 0.9199561403508772
Valid loss 0.14407578259706497 valid_acc 0.975
Test loss 0.5744858831167221 test acc 0.8353365384615384


19it [00:00, 216.24it/s]
5it [00:00, 833.63it/s]
8it [00:00, 727.29it/s]
 15%|█▌        | 3/20 [00:05<00:23,  1.36s/it]

epoch: 3
Train loss 0.12384537313329547 train_acc 0.9769736842105263
Valid loss 0.07092574387788772 valid_acc 0.98125
Test loss 0.5311756376177073 test acc 0.8674879807692307


19it [00:00, 316.68it/s]
5it [00:00, 555.95it/s]
8it [00:00, 533.25it/s]
 20%|██        | 4/20 [00:06<00:13,  1.16it/s]

epoch: 4
Train loss 0.05074793688560787 train_acc 0.993421052631579
Valid loss 0.03144514244049788 valid_acc 0.99375
Test loss 0.5509920343756676 test acc 0.8731971153846154


19it [00:00, 217.12it/s]
5it [00:00, 555.95it/s]
8it [00:00, 195.12it/s]
 25%|██▌       | 5/20 [00:06<00:09,  1.63it/s]

epoch: 5
Train loss 0.023161401531021846 train_acc 0.9983552631578947
Valid loss 0.019441003445535898 valid_acc 1.0
Test loss 0.5504506155848503 test acc 0.8753004807692307


19it [00:00, 294.52it/s]
5it [00:00, 714.78it/s]
8it [00:00, 371.84it/s]
 30%|███       | 6/20 [00:06<00:06,  2.25it/s]

epoch: 6
Train loss 0.013494530178018306 train_acc 1.0
Valid loss 0.013389047048985959 valid_acc 1.0
Test loss 0.5860956944525242 test acc 0.8762019230769231


19it [00:00, 256.93it/s]
5it [00:00, 277.85it/s]
8it [00:00, 533.37it/s]
 35%|███▌      | 7/20 [00:06<00:04,  2.90it/s]

epoch: 7
Train loss 0.009188119175010607 train_acc 1.0
Valid loss 0.008925470244139432 valid_acc 1.0
Test loss 0.5844008019194007 test acc 0.88671875


19it [00:00, 222.19it/s]
5it [00:00, 500.25it/s]
8it [00:00, 727.29it/s]
 40%|████      | 8/20 [00:06<00:03,  3.63it/s]

epoch: 8
Train loss 0.006765083238286407 train_acc 1.0
Valid loss 0.006974199786782265 valid_acc 1.0
Test loss 0.6112926742061973 test acc 0.8762019230769231


19it [00:00, 233.09it/s]
5it [00:00, 625.16it/s]
8it [00:00, 280.52it/s]
 45%|████▌     | 9/20 [00:06<00:02,  4.27it/s]

epoch: 9
Train loss 0.005388473172819144 train_acc 1.0
Valid loss 0.004838978406041861 valid_acc 1.0
Test loss 0.6100473236292601 test acc 0.8858173076923077


19it [00:00, 231.58it/s]
5it [00:00, 416.68it/s]
8it [00:00, 500.02it/s]
 50%|█████     | 10/20 [00:06<00:02,  4.92it/s]

epoch: 10
Train loss 0.004168477301534854 train_acc 1.0
Valid loss 0.004258430376648903 valid_acc 1.0
Test loss 0.6276752213016152 test acc 0.8810096153846154


19it [00:00, 173.46it/s]
5it [00:00, 323.58it/s]
8it [00:00, 500.39it/s]
 55%|█████▌    | 11/20 [00:07<00:01,  5.20it/s]

epoch: 11
Train loss 0.0034315689301461375 train_acc 1.0
Valid loss 0.0035156967118382455 valid_acc 1.0
Test loss 0.6575650675222278 test acc 0.8792067307692307


19it [00:00, 238.97it/s]
5it [00:00, 833.16it/s]
8it [00:00, 800.15it/s]
 60%|██████    | 12/20 [00:07<00:01,  5.93it/s]

epoch: 12
Train loss 0.003023978757770046 train_acc 1.0
Valid loss 0.00350077694747597 valid_acc 1.0
Test loss 0.6782620195299387 test acc 0.8704927884615384


19it [00:00, 327.57it/s]
5it [00:00, 833.76it/s]
8it [00:00, 889.07it/s]


epoch: 13
Train loss 0.0025740358181984015 train_acc 1.0
Valid loss 0.0025504824705421926 valid_acc 1.0
Test loss 0.6490727886557579 test acc 0.8801081730769231


19it [00:00, 322.04it/s]
5it [00:00, 555.68it/s]
8it [00:00, 888.84it/s]
 70%|███████   | 14/20 [00:07<00:00,  7.57it/s]

epoch: 14
Train loss 0.002258430823291603 train_acc 1.0
Valid loss 0.0023655532044358552 valid_acc 1.0
Test loss 0.6641375478357077 test acc 0.8810096153846154


19it [00:00, 309.38it/s]
5it [00:00, 714.29it/s]
8it [00:00, 939.90it/s]


epoch: 15
Train loss 0.0019458536132189788 train_acc 1.0
Valid loss 0.0019646560307592154 valid_acc 1.0
Test loss 0.6908159088343382 test acc 0.8801081730769231


19it [00:00, 144.45it/s]
5it [00:00, 357.14it/s]
8it [00:00, 270.97it/s]
 80%|████████  | 16/20 [00:07<00:00,  7.21it/s]

epoch: 16
Train loss 0.0017636455227865984 train_acc 1.0
Valid loss 0.0017546661430969835 valid_acc 1.0
Test loss 0.680209930986166 test acc 0.8801081730769231


19it [00:00, 235.97it/s]
5it [00:00, 525.00it/s]
8it [00:00, 800.08it/s]
 85%|████████▌ | 17/20 [00:07<00:00,  7.42it/s]

epoch: 17
Train loss 0.001513707718052166 train_acc 1.0
Valid loss 0.0017149405437521636 valid_acc 1.0
Test loss 0.6799017079174519 test acc 0.8810096153846154


19it [00:00, 339.24it/s]
5it [00:00, 705.04it/s]
8it [00:00, 800.46it/s]


epoch: 18
Train loss 0.001414583447560864 train_acc 1.0
Valid loss 0.0013643806567415595 valid_acc 1.0
Test loss 0.6816367506980896 test acc 0.8849158653846154


19it [00:00, 283.26it/s]
5it [00:00, 795.46it/s]
8it [00:00, 888.51it/s]
 95%|█████████▌| 19/20 [00:07<00:00,  8.38it/s]

epoch: 19
Train loss 0.0012879216504332266 train_acc 1.0
Valid loss 0.0013248325791209935 valid_acc 1.0
Test loss 0.6807376639917493 test acc 0.8819110576923077


19it [00:00, 198.07it/s]
5it [00:00, 999.98it/s]
8it [00:00, 665.66it/s]
100%|██████████| 20/20 [00:08<00:00,  2.49it/s]

epoch: 20
Train loss 0.0011371760537210655 train_acc 1.0
Valid loss 0.0011660270625725388 valid_acc 1.0
Test loss 0.7008891480509192 test acc 0.8810096153846154
Finished





0,1
Test Accuracy,▁▃▆▇▇▇█▇█▇▇▆▇▇▇▇▇██▇
Test loss,▃▃▁▂▂▃▃▄▄▅▆▇▆▆█▇▇▇▇█
Train Accuracy,▁▆██████████████████
Train loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Valid Accuracy,▁▆▇█████████████████
Valid loss,█▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,0.88101
Test loss,0.70089
Train Accuracy,1.0
Train loss,0.00114
Valid Accuracy,1.0
Valid loss,0.00117
