In [1]:
from tensorflow.keras.datasets import mnist
import torch
from torch import nn 
import numpy as np 
from torch.utils.data import TensorDataset,DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import wandb

In [10]:
# 通用config
config = {
    "lr":0.001,
    "batch_size":16,
    "epoch":20,
}

In [19]:
class base_config():
    def __init__(self,lr,epoch,batch_size):
        self.lr = lr 
        self.epoch = epoch
        self.batch_size = batch_size

In [20]:
def create_dataloader(data,batch_size):
    return DataLoader(data,batch_size=batch_size,shuffle=True)

In [21]:
class classify_trainer():
    def __init__(self,model,optimizer,loss_function,traindata,testdata,config,device,valdata=None,use_wandb=False):
        self.device = device
        self.model = model.to(device)
        self.config = config
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.use_wandb = use_wandb
        self.trainloader = DataLoader(traindata,batch_size=self.config.batch_size,shuffle=True)
        self.testloader = DataLoader(testdata,batch_size=self.config.batch_size,shuffle=True)
        if valdata is not None:
            self.validloader = DataLoader(valdata,batch_size=self.config.batch_size,shuffle=True)
        else:
            self.validloader = None
        
        # init wandb
        if self.use_wandb:
            project_name = input('please input wandb project name')
            print('wandb is loading')
            config_dict = {
                "lr":self.config.lr,
                "batch_size":self.config.batch_size,
                "epoch":self.config.epoch,
            }
            wandb.init(project_name=project_name,config=config_dict)
    
    # compute accuracy
    @classmethod 
    def compute_acc(pred,label):
        return np.equal(pred,label).mean()
    
    @classmethod
    def eval(self,model,dataloader,loss_function,device):
        loss = 0
        acc = 0
        model = model.to(device)
        model.eval() # 禁用dropout和bn层
        for idx,(data,label) in tqdm(enumerate(dataloader)):
            data = data.to(device)
            label = label.to(device)
            out = model(data)
            # 计算loss值
            loss += loss_function(out,label).item()
            # 计算accuracy
            # argmax这块先留个心眼，我不确定这个axis是不是需要微调
            pred = torch.argmax(out,axis=1)
            pred = pred.cpu().detach().numpy()
            label = label.cpu().detach().numpy()
            acc += self.compute_acc(pred,label)
        
        avg_loss = loss/ len(dataloader)
        avg_acc = acc / len(dataloader)
        return avg_loss,avg_acc
        
    
    def train(self,print_message=True):
        for i in tqdm(range(1,self.config.epoch+1)):
            # 进行一个model.train()
            self.model.train()
            # 对于训练过程中每一个epoch的参数记录
            train_loss = 0
            train_acc = 0
            for idx,(data,label) in tqdm(enumerate(self.trainloader)):
                data = data.to(self.device)
                label = label.to(self.device)
                out = self.model(data)
                loss = self.loss_function(out,label)
                
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # 计算出来损失函数
                train_loss += loss.item()
                # 转到cpu上计算准确率
                pred = torch.argmax(out,axis=1)
                pred = pred.cpu().detach().numpy()
                label = label.cpu().detach().numpy()
                train_acc += self.compute_acc(pred,label)
            # 一个epoch结束后，计算均值
            train_loss = train_loss / len(self.trainloader)
            train_acc = train_acc  / len(self.trainloader)
            # 剩下的是验证集和测试集
            if self.validloader:
                valid_loss,valid_acc = eval(self.model,self.validloader,self.loss_function,self.device)
            else:
                valid_loss =0.0
                valid_acc = 0.0
            
            test_loss,test_acc = eval(self.model,self.testloader,self.loss_function,self.device)     
            if self.use_wandb:
                wandb.log(
                    {
                        "Train loss":train_loss,
                        "Train Accuracy":train_acc,
                        "Valid loss":valid_loss,
                        "Valid Accuracy":valid_acc,
                        "Test loss":test_loss,
                        "Test Accuracy":test_acc
                    }
                )
            if print_message:
                print('epoch:',i)
                print('Train loss',train_loss,'train_acc',train_acc)
                print('Valid loss',valid_loss,'valid_acc',valid_acc)
                print('Test loss',test_loss,'test acc',test_acc) 
        print('Finished')
        if self.use_wandb:
            wandb.finish()

In [2]:
!wandb login

wandb: Currently logged in as: zhijiao. Use `wandb login --relogin` to force relogin


In [3]:
wandb.init(project='Temp_trainer_2')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mzhijiao[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666656966, max=1.0…

In [4]:
train_data,test_data = mnist.load_data()

In [5]:
# 数据集太大，用一小部分试一试
data,label = train_data[0][0:1000],train_data[1][0:1000]
# 考虑到灰度数据用conv2d处理的时候需要增加一个通道维度，需要对数据集进行一个变换
data = np.expand_dims(data,1)

In [6]:
batch_size = 32

In [7]:
# 撸个简单的卷积神经网络
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,3,3,1,1)
        self.pooling = nn.MaxPool2d(2)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc =  nn.Linear(14*14*3,10)
    def forward(self,x):
        x = self.conv1(x)
        x = self.pooling(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.flatten(x)
        x= self.fc(x)
        return x


In [8]:
# 设置一下Batch_size


batch_size = 32
# 构造数据集
data = torch.Tensor(data)
label = torch.LongTensor(label)
# 划分数据集
x_train,x_test,y_train,y_test = train_test_split(data,label)
x_train = x_train[150:]
y_train = y_train[150:]

x_valid = x_train[0:150]
y_valid = y_train[0:150]

TrainData = TensorDataset(x_train,y_train)
ValidData = TensorDataset(x_valid,y_valid)
TestData = TensorDataset(x_test,y_test)

# 搞个loader
TrainLoader = DataLoader(TrainData,batch_size=batch_size,shuffle=True)
ValidLoader = DataLoader(ValidData,batch_size=batch_size,shuffle=True)
TestLoader = DataLoader(TestData,batch_size=batch_size,shuffle=True)


In [9]:
# 训练(我觉得我有必要写一个Trainer的库，方便之后进行调用，但是这次主要是为了搞wandb)
# 在写主要的训练函数之前，需要先写一个eval函数，用于对验证集和测试集进行评估。
# 再补一个函数，对于分类的数据，计算accuracy,输入为np.array类型的输出值（因此需要在训练过程中进行一个数据转换）
def compute_acc(pred,label):
    return np.equal(pred,label).mean()
    
def eval(model,dataloader,loss_function,device):
    loss = 0
    acc = 0
    model.eval() # 禁用dropout和bn层
    for idx,(data,label) in tqdm(enumerate(dataloader)):
        data = data.to(device)
        label = label.to(device)
        out = model(data)
        # 计算loss值
        loss += loss_function(out,label).item()
        # 计算accuracy
        # argmax这块先留个心眼，我不确定这个axis是不是需要微调
        pred = torch.argmax(out,axis=1)
        pred = pred.cpu().detach().numpy()
        label = label.cpu().detach().numpy()
        acc += compute_acc(pred,label)
    
    avg_loss = loss/ len(dataloader)
    avg_acc = acc / len(dataloader)
    return avg_loss,avg_acc
         
# 现在可以写train函数了   (model需要先放到GPU上才可以)
def train(model,optimizer,loss_function,epoch,trainloader,validloader,testloader,device,use_wandb=False):
    # 开始训练
    for i in tqdm(range(1,epoch+1)):
        # 进行一个model.train()
        model.train()
        # 对于训练过程中每一个epoch的参数记录
        train_loss = 0
        train_acc = 0
        for idx,(data,label) in tqdm(enumerate(trainloader)):
            data = data.to(device)
            label = label.to(device)
            out = model(data)
            loss = loss_function(out,label)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # 计算出来损失函数
            train_loss += loss.item()
            # 转到cpu上计算准确率
            pred = torch.argmax(out,axis=1)
            pred = pred.cpu().detach().numpy()
            label = label.cpu().detach().numpy()
            train_acc += compute_acc(pred,label)
        # 一个epoch结束后，计算均值
        train_loss = train_loss / len(trainloader)
        train_acc = train_acc  / len(trainloader)
        # 剩下的是验证集和测试集
        valid_loss,valid_acc = eval(model,validloader,loss_function,device=device)
        test_loss,test_acc = eval(model,testloader,loss_function,device=device)       
        if use_wandb:
            wandb.log(
                {
                    "Train loss":train_loss,
                    "Train Accuracy":train_acc,
                    "Valid loss":valid_loss,
                    "Valid Accuracy":valid_acc,
                    "Test loss":test_loss,
                    "Test Accuracy":test_acc
                }
            )
             
        print('epoch:',i)
        print('Train loss',train_loss,'train_acc',train_acc)
        print('Valid loss',valid_loss,'valid_acc',valid_acc)
        print('Test loss',test_loss,'test acc',test_acc)
    

In [10]:
# 调用cuda,同时设置一些参数
device = torch.device(0)
epoch = 20
model = Net()
optimizer = torch.optim.Adam(model.parameters())
loss_function = nn.CrossEntropyLoss()
model = model.to(device)

In [11]:
# 开始第一次debug
train(model,optimizer,loss_function,epoch,TrainLoader,ValidLoader,TestLoader,device,use_wandb=True)

19it [00:04,  4.04it/s]00:00<?, ?it/s]
5it [00:00, 380.33it/s]
8it [00:00, 457.75it/s]
  5%|▌         | 1/20 [00:04<01:30,  4.74s/it]

epoch: 1
Train loss 2.176876469662315 train_acc 0.2538377192982456
Valid loss 1.9441388845443726 valid_acc 0.5409090909090909
Test loss 2.006058171391487 test acc 0.4423076923076923


19it [00:00, 210.00it/s]
5it [00:00, 447.84it/s]
8it [00:00, 634.76it/s]
 10%|█         | 2/20 [00:04<00:36,  2.03s/it]

epoch: 2
Train loss 1.7185200390062834 train_acc 0.6260964912280702
Valid loss 1.4825405120849608 valid_acc 0.7022727272727273
Test loss 1.5489587932825089 test acc 0.6781850961538461


19it [00:00, 129.92it/s]
5it [00:00, 384.62it/s]
8it [00:00, 571.08it/s]
 15%|█▌        | 3/20 [00:05<00:20,  1.19s/it]

epoch: 3
Train loss 1.1835577111495168 train_acc 0.7834429824561403
Valid loss 0.9786724090576172 valid_acc 0.8198863636363637
Test loss 1.0841639265418053 test acc 0.7605168269230769


19it [00:00, 209.59it/s]
5it [00:00, 491.04it/s]
8it [00:00, 503.62it/s]
 20%|██        | 4/20 [00:05<00:12,  1.29it/s]

epoch: 4
Train loss 0.7799321914974012 train_acc 0.8245614035087719
Valid loss 0.6690926909446716 valid_acc 0.847159090909091
Test loss 0.7781154215335846 test acc 0.8001802884615384


19it [00:00, 201.74it/s]
5it [00:00, 550.61it/s]
8it [00:00, 419.15it/s]
 25%|██▌       | 5/20 [00:05<00:08,  1.83it/s]

epoch: 5
Train loss 0.5695752219149941 train_acc 0.8695175438596491
Valid loss 0.5092377185821533 valid_acc 0.93125
Test loss 0.6232075542211533 test acc 0.8578725961538461


19it [00:00, 211.37it/s]
5it [00:00, 555.99it/s]
8it [00:00, 333.32it/s]
 30%|███       | 6/20 [00:05<00:05,  2.43it/s]

epoch: 6
Train loss 0.4442730702851948 train_acc 0.9078947368421053
Valid loss 0.3965911388397217 valid_acc 0.959659090909091
Test loss 0.5429770946502686 test acc 0.8686899038461539


19it [00:00, 162.40it/s]
5it [00:00, 416.59it/s]
8it [00:00, 265.56it/s]
 35%|███▌      | 7/20 [00:05<00:04,  2.95it/s]

epoch: 7
Train loss 0.35020308588680465 train_acc 0.9314692982456141
Valid loss 0.31420656442642214 valid_acc 0.96875
Test loss 0.47877057641744614 test acc 0.8722956730769231


19it [00:00, 139.25it/s]
5it [00:00, 555.67it/s]
8it [00:00, 666.75it/s]
 40%|████      | 8/20 [00:05<00:03,  3.47it/s]

epoch: 8
Train loss 0.28443452637446553 train_acc 0.9523026315789473
Valid loss 0.2602283746004105 valid_acc 0.972159090909091
Test loss 0.4388983026146889 test acc 0.8888221153846154


19it [00:00, 215.41it/s]
5it [00:00, 416.49it/s]
8it [00:00, 571.52it/s]
 45%|████▌     | 9/20 [00:05<00:02,  4.17it/s]

epoch: 9
Train loss 0.23758901028256668 train_acc 0.9665570175438596
Valid loss 0.21063620448112488 valid_acc 0.975
Test loss 0.4235890004783869 test acc 0.8822115384615384


19it [00:00, 205.24it/s]
5it [00:00, 624.36it/s]
8it [00:00, 514.85it/s]
 50%|█████     | 10/20 [00:06<00:02,  4.82it/s]

epoch: 10
Train loss 0.19892563514019312 train_acc 0.9775219298245614
Valid loss 0.1730757921934128 valid_acc 0.98125
Test loss 0.3952592685818672 test acc 0.8870192307692307


19it [00:00, 208.61it/s]
5it [00:00, 504.95it/s]
8it [00:00, 216.24it/s]
 55%|█████▌    | 11/20 [00:06<00:01,  5.20it/s]

epoch: 11
Train loss 0.16968969608608045 train_acc 0.9808114035087719
Valid loss 0.14408986866474152 valid_acc 0.99375
Test loss 0.37651101872324944 test acc 0.8801081730769231


19it [00:00, 153.84it/s]
5it [00:00, 333.29it/s]
8it [00:00, 444.37it/s]
 60%|██████    | 12/20 [00:06<00:01,  5.20it/s]

epoch: 12
Train loss 0.1411876792186185 train_acc 0.9835526315789473
Valid loss 0.12095648199319839 valid_acc 1.0
Test loss 0.36550818011164665 test acc 0.8810096153846154


19it [00:00, 182.58it/s]
5it [00:00, 454.11it/s]
8it [00:00, 400.00it/s]
 65%|██████▌   | 13/20 [00:06<00:01,  5.44it/s]

epoch: 13
Train loss 0.12149415674962495 train_acc 0.9884868421052632
Valid loss 0.10261463820934295 valid_acc 1.0
Test loss 0.35833899676799774 test acc 0.8783052884615384


19it [00:00, 209.83it/s]
5it [00:00, 714.56it/s]
8it [00:00, 701.24it/s]
 70%|███████   | 14/20 [00:06<00:00,  6.04it/s]

epoch: 14
Train loss 0.10566178866122898 train_acc 0.9901315789473685
Valid loss 0.08906531482934951 valid_acc 1.0
Test loss 0.35306747630238533 test acc 0.8801081730769231


19it [00:00, 230.16it/s]
5it [00:00, 624.97it/s]
8it [00:00, 500.05it/s]
 75%|███████▌  | 15/20 [00:06<00:00,  6.54it/s]

epoch: 15
Train loss 0.09222753581247832 train_acc 0.9895833333333333
Valid loss 0.07527245432138444 valid_acc 1.0
Test loss 0.3416645349934697 test acc 0.8792067307692307


19it [00:00, 213.08it/s]
5it [00:00, 624.99it/s]
8it [00:00, 456.45it/s]
 80%|████████  | 16/20 [00:07<00:00,  6.82it/s]

epoch: 16
Train loss 0.07939997256586426 train_acc 0.9967105263157895
Valid loss 0.06924290880560875 valid_acc 1.0
Test loss 0.3413748424500227 test acc 0.8822115384615384


19it [00:00, 176.82it/s]
5it [00:00, 500.17it/s]
8it [00:00, 176.59it/s]


epoch: 

 85%|████████▌ | 17/20 [00:07<00:00,  6.28it/s]

17
Train loss 0.07091004009309568 train_acc 0.9983552631578947
Valid loss 0.06164337694644928 valid_acc 1.0
Test loss 0.3320395862683654 test acc 0.8879206730769231


19it [00:00, 213.45it/s]
5it [00:00, 475.62it/s]
8it [00:00, 500.07it/s]
 90%|█████████ | 18/20 [00:07<00:00,  6.57it/s]

epoch: 18
Train loss 0.06238919251451367 train_acc 1.0
Valid loss 0.05552132464945316 valid_acc 1.0
Test loss 0.3244030363857746 test acc 0.8897235576923077


19it [00:00, 228.28it/s]
5it [00:00, 600.11it/s]
8it [00:00, 795.56it/s]
 95%|█████████▌| 19/20 [00:07<00:00,  7.11it/s]

epoch: 19
Train loss 0.055658355355262756 train_acc 1.0
Valid loss 0.04844651743769646 valid_acc 1.0
Test loss 0.3232556041330099 test acc 0.8849158653846154


19it [00:00, 192.87it/s]
5it [00:00, 500.13it/s]
8it [00:00, 487.52it/s]
100%|██████████| 20/20 [00:07<00:00,  2.63it/s]

epoch: 20
Train loss 0.05072746661148573 train_acc 1.0
Valid loss 0.04303951561450958 valid_acc 1.0
Test loss 0.3227849081158638 test acc 0.8879206730769231





In [12]:
wandb.finish()

0,1
Test Accuracy,▁▅▆▇████████████████
Test loss,█▆▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
Train Accuracy,▁▄▆▆▇▇▇█████████████
Train loss,█▆▅▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
Valid Accuracy,▁▃▅▆▇▇██████████████
Valid loss,█▆▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,0.88792
Test loss,0.32278
Train Accuracy,1.0
Train loss,0.05073
Valid Accuracy,1.0
Valid loss,0.04304
