<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#模型实现" data-toc-modified-id="模型实现-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Model Implementation</a></span><ul class="toc-item"><li><span><a href="#数据集制作" data-toc-modified-id="数据集制作-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Building the Dataset</a></span><ul class="toc-item"><li><span><a href="#图像预处理" data-toc-modified-id="图像预处理-1.1.1"><span class="toc-item-num">1.1.1&nbsp;&nbsp;</span>Image Preprocessing</a></span></li><li><span><a href="#建立数据集类" data-toc-modified-id="建立数据集类-1.1.2"><span class="toc-item-num">1.1.2&nbsp;&nbsp;</span>Defining the Dataset Class</a></span></li></ul></li><li><span><a href="#网络模型搭建" data-toc-modified-id="网络模型搭建-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>Building the Network Model</a></span><ul class="toc-item"><li><span><a href="#网络搭建" data-toc-modified-id="网络搭建-1.2.1"><span class="toc-item-num">1.2.1&nbsp;&nbsp;</span>Network Construction</a></span></li></ul></li><li><span><a href="#DataLoader搭建" data-toc-modified-id="DataLoader搭建-1.3"><span class="toc-item-num">1.3&nbsp;&nbsp;</span>Building the DataLoader</a></span><ul class="toc-item"><li><span><a href="#DataLoader类创建" data-toc-modified-id="DataLoader类创建-1.3.1"><span class="toc-item-num">1.3.1&nbsp;&nbsp;</span>Creating the DataLoader Class</a></span></li></ul></li><li><span><a href="#训练模块" data-toc-modified-id="训练模块-1.4"><span class="toc-item-num">1.4&nbsp;&nbsp;</span>Training Module</a></span><ul class="toc-item"><li><span><a href="#训练逻辑" data-toc-modified-id="训练逻辑-1.4.1"><span class="toc-item-num">1.4.1&nbsp;&nbsp;</span>Training Logic</a></span></li><li><span><a href="#创建训练类" data-toc-modified-id="创建训练类-1.4.2"><span class="toc-item-num">1.4.2&nbsp;&nbsp;</span>Creating the Training Class</a></span></li><li><span><a href="#训练类测试" data-toc-modified-id="训练类测试-1.4.3"><span class="toc-item-num">1.4.3&nbsp;&nbsp;</span>Testing the Training Class</a></span></li></ul></li><li><span><a href="#预测模块" data-toc-modified-id="预测模块-1.5"><span class="toc-item-num">1.5&nbsp;&nbsp;</span>Prediction Module</a></span></li></ul></li></ul></div>

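## Model Implementation

### Building the Dataset
#### Image Preprocessing
Preprocessing steps:  

· Resize all images to a fixed width and height

#### Defining the Dataset Class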

In [1]:
import cv2
import os
import torch
import numpy as np
from torch.utils.data.dataset import Dataset

class MyDataset(Dataset):
    
    def __init__(self,img_path,label_content,img_width,img_height):
        self.img_path = img_path
        self.label_content = label_content
        self.img_width = img_width
        self.img_height = img_height
    
    def __getitem__(self,index):
        img_path = self.img_path[index]
        img_cv2 = cv2.imread(img_path)  #returns an array of shape (H,W,C); channel order is B,G,R
        img_array = cv2.resize(img_cv2,(self.img_width,self.img_height))  #cv2.resize takes the target size as (w,h): width first, then height
        
        img_name = os.path.split(img_path)[1]
        img_label = self.label_content[img_name]['label']
        #map the digit labels 0~9 to indices 1~10; index 0 is reserved for the CTC blank
        img_label_index_list = [x+1 for x in img_label]
        
        return img_array,img_label_index_list
    
    def __len__(self):
        return len(self.img_path)
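
The label shift in `__getitem__` can be checked in isolation. A minimal sketch, assuming `train.json` maps each file name to a dict whose `label` entry is a list of digits, as the lookup above implies. The file name `000000.png` and the helper names are made up for illustration.

```python
# Hypothetical excerpt of train.json: file name -> {'label': [digits]}
label_content = {'000000.png': {'label': [1, 9]}}

def encode_label(digits):
    """Shift digits 0~9 to indices 1~10 so that index 0 stays free for the CTC blank."""
    return [d + 1 for d in digits]

def decode_label(indices):
    """Inverse of encode_label: map indices 1~10 back to digits 0~9."""
    return [i - 1 for i in indices]

encoded = encode_label(label_content['000000.png']['label'])
print(encoded)                # [2, 10]
print(decode_label(encoded))  # [1, 9]
```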


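### Building the Network Model

#### Network Construction
The network is designed around slicing each image vertically into 15 strips for the RNN to recognize.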
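
The number of strips matters because CTC needs at least one time step per label symbol, plus one blank step between any two identical neighbouring symbols. A small sketch of that lower bound follows; the function name `min_ctc_steps` is made up for this illustration and is not part of the notebook.

```python
def min_ctc_steps(label):
    """Smallest sequence length T for which CTC can emit `label`.

    Each symbol needs one step, and two equal adjacent symbols need a
    blank step between them, otherwise collapsing repeats merges them.
    """
    repeats = sum(1 for a, b in zip(label, label[1:]) if a == b)
    return len(label) + repeats

print(min_ctc_steps([2, 10]))       # 2
print(min_ctc_steps([3, 3, 5]))     # 4
print(min_ctc_steps([7, 7, 7, 7]))  # 7
```

Whatever number of strips is chosen has to be at least this large for every label in the training set.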

In [2]:
import torch.nn as nn
import torchvision.models as models

class MyModel(nn.Module):
    def __init__(self,T,img_width):
        super().__init__()
        self.img_width = img_width
        self.cnn_out_width = self.cal_finnal_width()
        #self.cnn: e.g. input(N,3,240,480),output(N,512,8,15); height and width each shrink by a factor of 32, rounded up
        trained_resnet18 = models.resnet18(weights='DEFAULT')
        self.cnn = nn.Sequential(trained_resnet18.conv1,
                                 trained_resnet18.bn1,
                                 trained_resnet18.relu,
                                 trained_resnet18.maxpool,
                                 trained_resnet18.layer1,
                                 trained_resnet18.layer2,
                                 trained_resnet18.layer3,
                                 trained_resnet18.layer4)
        #self.feat_to_seq: input(N,cnn_out_width,512,1),output(N,T,512,1)
        self.feat_to_seq = nn.Conv2d(self.cnn_out_width,T,kernel_size=(1,1),stride=1)
        #self.rnn: input(T,N,512),output(T,N,64*2)
        self.rnn = nn.GRU(input_size=512,hidden_size=64,bidirectional=True)
        #self.linear: input(T,N,64*2),output(T,N,11), i.e. 10 digits plus the blank
        self.linear = nn.Linear(128,11)
        #self.softmax: input(T,N,11),output(T,N,11) 
        self.softmax = nn.Softmax(dim=2)
    
    def forward(self,X):
        X = self.cnn(X)
        X = X.mean(dim=2,keepdim=True)
        X = X.permute((0,3,1,2))
        X = self.feat_to_seq(X)
        X = X.permute((1,0,2,3))
        X = X.squeeze(dim=3)
        X,_ = self.rnn(X)
        X = self.linear(X)
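#         y_hat = self.softmax(X)
#         return y_hat
        #returns the raw linear output; note that nn.CTCLoss is documented to take log-probabilities (e.g. from log_softmax over dim=2), which is not applied here
        return X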

    def cal_finnal_width(self):
        #resnet18 halves the width five times (conv1, maxpool, layer2, layer3, layer4), rounding up each time
        width = self.img_width
        for i in range(5):
            width = (width+1)//2
        return width
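
To make the permutes in `forward` easier to follow, the shapes can be traced step by step without running the network. This is bookkeeping only, a sketch assuming the stock torchvision resnet18 layers used above; `halve_ceil` and `trace_shapes` are illustrative names.

```python
def halve_ceil(size, times=5):
    """Halve `size` `times` times, rounding up, as resnet18's five stride-2 stages do."""
    for _ in range(times):
        size = (size + 1) // 2
    return size

def trace_shapes(N, H, W, T):
    """Return the tensor shape after each step of MyModel.forward (no torch needed)."""
    h, w = halve_ceil(H), halve_ceil(W)
    return [
        ('cnn',              (N, 512, h, w)),
        ('mean over height', (N, 512, 1, w)),
        ('permute(0,3,1,2)', (N, w, 512, 1)),   # width moves to the channel axis
        ('feat_to_seq',      (N, T, 512, 1)),   # 1x1 conv maps w channels to T
        ('permute(1,0,2,3)', (T, N, 512, 1)),
        ('squeeze(dim=3)',   (T, N, 512)),
        ('rnn',              (T, N, 128)),      # bidirectional GRU: 64*2
        ('linear',           (T, N, 11)),       # 10 digits + blank
    ]

for name, shape in trace_shapes(N=20, H=100, W=200, T=30):
    print(f'{name:18s}{shape}')
```

For a 200-pixel-wide input the CNN output width is 7, which the 1x1 convolution then maps to `T` time steps.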

### Building the DataLoader

#### Creating the DataLoader Class

In [3]:
class MyDataLoader():
    def __init__(self,my_dataset,batch_size=20,use_my_collate=True,shuffle=False,):
        from torch.utils.data import DataLoader
        if use_my_collate:
            self.data_loader = DataLoader(my_dataset,batch_size=batch_size,shuffle=shuffle,collate_fn=self.my_collate)
        else:
            self.data_loader = DataLoader(my_dataset,batch_size=batch_size,shuffle=shuffle)
        
    def my_collate(self,batch_data):
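        '''
        batch_data shape: [(feat_array,label_list),(feat_array,label_list),...]
        returns: feat (N,C,H,W) float tensor, label (all labels of the batch concatenated into one 1-D int tensor), label_len (N,) int tensor
        '''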
        feat_list =[]
        label_list = []
        label_len_list = []
        for sample in batch_data:
            feat_list.append(torch.from_numpy(sample[0].transpose((2,0,1))))
            label_list.extend(sample[1])
            label_len_list.append(len(sample[1]))
            
        feat = torch.stack(feat_list).float()
        label = torch.Tensor(label_list).int()
        label_len = torch.Tensor(label_len_list).int()
        
        return feat,label,label_len
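
`my_collate` concatenates all labels of a batch into one flat tensor and returns the per-sample lengths, which is one of the target layouts `nn.CTCLoss` accepts. A plain-Python sketch of that layout, and of how the per-sample labels can be recovered from it, is given here; the helper names are illustrative.

```python
from itertools import accumulate

def flatten_labels(labels):
    """Concatenate per-sample label lists and record each sample's length."""
    flat = [x for label in labels for x in label]
    lengths = [len(label) for label in labels]
    return flat, lengths

def split_labels(flat, lengths):
    """Recover the per-sample label lists from the flat list and the lengths."""
    ends = list(accumulate(lengths))
    starts = [0] + ends[:-1]
    return [flat[s:e] for s, e in zip(starts, ends)]

batch_labels = [[2, 10], [3, 4, 1], [6]]
flat, lengths = flatten_labels(batch_labels)
print(flat)                         # [2, 10, 3, 4, 1, 6]
print(lengths)                      # [2, 3, 1]
print(split_labels(flat, lengths))  # [[2, 10], [3, 4, 1], [6]]
```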
        

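### Training Module

#### Training Logic
Main features:  
· Progress display  
· Fixed random seed  
· Automatic learning-rate adjustment  
· Gradient clipping  
· Model saving  
· Model loading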
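
The learning-rate adjustment in the training class multiplies the base rate by 0.8 for each epoch and restarts every 10 epochs. A standalone sketch of that rule (the function name and the `decay`/`period` parameters are illustrative):

```python
def lr_at_epoch(base_lr, epoch_index, decay=0.8, period=10):
    """Learning rate for an epoch: decays by `decay` per epoch, restarts every `period` epochs."""
    return base_lr * decay ** (epoch_index % period)

for epoch_index in (0, 1, 2, 9, 10):
    print(epoch_index, round(lr_at_epoch(0.001, epoch_index), 6))
```

With a base rate of 0.001 this gives 0.001, 0.0008 and 0.00064 for the first three epochs, and returns to 0.001 at epoch 10.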

#### Creating the Training Class

In [4]:
class MyTrain():
    def __init__(self,max_epoch=1,batch_size=20,lr=0.001,random_seed=None,grad_threshold=10,T=30,img_width=200,out_dir='./'):
        self.max_epoch = max_epoch
        self.batch_size = batch_size
        self.lr = lr
        self.random_seed = random_seed
        self.grad_threshold = grad_threshold
        self.T = T
        self.img_width = img_width
        self.img_height = int(0.5*img_width)
        self.out_dir = out_dir
    
    def train(self):
        import json
        import glob
        
        max_epoch,batch_size,grad_threshold = self.max_epoch,self.batch_size,self.grad_threshold
        
        #fix random seed
        if self.random_seed is not None:
            self.fix_seed()
            
        #create dataset instance
        img_path = glob.glob('../input/train/*.png')
        with open('../input/train.json') as f:
            label_content = json.load(f)
        img_path.sort()
        my_dataset = MyDataset(img_path,label_content,self.img_width,self.img_height)

        #create model and loss instance
        crnn_model = MyModel(self.T,self.img_width)
        ctc_loss = nn.CTCLoss()
        if torch.cuda.is_available():
            crnn_model.cuda()
#             ctc_loss.cuda()
        
        #create optim instance
        adam_optimizer = torch.optim.Adam(crnn_model.parameters(),lr=self.lr,weight_decay=0.0001)
        batch_device = next(iter(crnn_model.parameters())).device
        print('train device:',batch_device)
        
        acc_max = 0
        
#         train each epoch
        for epoch_index in range(max_epoch):
            
            loss_list=[]
            acc_list =[]
            
            #create dataloader instance
            from tqdm import tqdm
            my_dataloader = MyDataLoader(my_dataset,batch_size=batch_size,shuffle=True).data_loader
#             a = next(iter(my_dataloader))
#             print('img:',a[0].shape)
#             print('label:\n',a[1])
#             print('label_len:',a[2])
#             break
            my_dataloader = tqdm(my_dataloader,ncols=120)
            
            #train each batch data
            for batch_index,batch_data in enumerate(my_dataloader):
                
                batch_feat = batch_data[0]
                batch_label = batch_data[1]
                batch_label_len = batch_data[2]
                if torch.cuda.is_available():
                    batch_feat = batch_feat.cuda()
                    batch_label = batch_label.cuda()
                    batch_label_len = batch_label_len.cuda()
                    
                adam_optimizer.zero_grad()
                
                batch_y_hat = crnn_model(batch_feat) 
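                batch_y_hat = batch_y_hat.to(torch.float64)  #the crnn output is float32; passing it to ctcloss on the GPU raised an error, and casting it to float64 avoids the error
                batch_y_hat_len = torch.LongTensor([batch_y_hat.shape[0]]*batch_y_hat.shape[1])  #batch_y_hat is (T,N,C); N is read from the output so a final batch smaller than batch_size still works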
                
#                 print('1:',batch_y_hat,'\n','2:',batch_label,'\n','3:',batch_y_hat_len,'\n','4:',batch_label_len)                
                batch_loss = ctc_loss(batch_y_hat,batch_label,batch_y_hat_len,batch_label_len)
                batch_loss.backward()
                
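                #clip gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(crnn_model.parameters(),grad_threshold)
                                
                #update the model parameters
                adam_optimizer.step()
                
                #adjust the learning rate
                adam_optimizer.param_groups[0]['lr'] = self.lr*(0.8**(epoch_index%10))
                
                #update the description shown on the tqdm progress bar
                batch_acc = self.get_accurancy(batch_y_hat,batch_label,batch_label_len)
                loss_list.append(batch_loss.detach())  #detach so the stored losses do not keep the autograd graph alive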
                acc_list.append(batch_acc)
                batch_lr = adam_optimizer.param_groups[0]['lr']
                batch_loss_mean = sum(loss_list)/len(loss_list)
                batch_acc_mean = round(sum(acc_list)/len(acc_list),3)
                                
                my_dataloader.set_description(f'epoch{epoch_index}| batch{batch_index}| loss: {round(batch_loss.item(),3)}| lr: {round(batch_lr,4)}| loss_mean: {round(batch_loss_mean.item(),3)}| batch_acc: {batch_acc_mean}')
           
            #save the parameters from the latest epoch; if they beat every earlier epoch, also save them as the best model
            torch.save(crnn_model.state_dict(),os.path.join(self.out_dir,'crnn_resnet18_ctc_last'))
            if batch_acc_mean > acc_max:
                acc_max = batch_acc_mean
                torch.save(crnn_model.state_dict(),os.path.join(self.out_dir,'crnn_resnet18_ctc_best'))
    
    def fix_seed(self):
        import random
        import numpy as np
        import torch
        
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)
        torch.random.manual_seed(self.random_seed)
        torch.cuda.manual_seed_all(self.random_seed)
        torch.backends.cudnn.deterministic = True
        
    def get_accurancy(self,y_hat,label,label_len):
        #y_hat(T,N,C)
        y_hat = y_hat.permute(1,0,2).cpu()
#         print('y_hat:',y_hat[0])
        label =label.cpu()
        label_len = label_len.cpu()
        #pred(N,T)
        pred = torch.argmax(y_hat,dim=2)
#         print('pred:',pred[0])
        batch_size = pred.shape[0]
        acc = 0
        #decode
        for i in range(batch_size):
            raw_pred_list = list(pred[i].numpy())
#             print('raw_pred_list:',raw_pred_list)
            pred_data = []
            for j in range(len(raw_pred_list)):
                if j == 0 and raw_pred_list[0] != 0:
                    pred_data.append(raw_pred_list[0]-1)
                if j != 0 and raw_pred_list[j] != raw_pred_list[j-1] and raw_pred_list[j] != 0:
                    pred_data.append(raw_pred_list[j]-1)
#             print('pred_data:',pred_data)
            
            label_start_index = int(torch.sum(label_len[:i]).item())
            label_end_index = int(label_len[i].item())+label_start_index
            label_data = list(label[label_start_index:label_end_index].int().numpy())
            label_data = [x-1 for x in label_data]
#             print(f'pred_data:{pred_data},label_data:{label_data}\n')
            
            if pred_data == label_data:
                acc+=1
        acc/=batch_size
        return acc
        
    def model_save(self,model):
        pass
    def model_load(self):
        pass
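
The decoding loop inside `get_accurancy` is a greedy CTC decode: take the argmax per time step, merge consecutive repeats, drop blanks, and shift the indices back to digits. The same rule as a standalone function, for reference (a sketch; `greedy_ctc_decode` is not part of the class above):

```python
from itertools import groupby

def greedy_ctc_decode(raw_pred, blank=0):
    """Collapse a per-step argmax sequence into digits.

    Consecutive repeats are merged first, then blanks are dropped,
    then index k (1~10) is mapped back to digit k-1.
    """
    collapsed = [index for index, _ in groupby(raw_pred)]
    return [index - 1 for index in collapsed if index != blank]

print(greedy_ctc_decode([0, 2, 2, 0, 2, 10, 10, 0]))  # [1, 1, 9]
```

The blank between the two runs of `2` is what lets the decoder output the same digit twice in a row.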
        

In [15]:
import torch
import os
import torch.nn as nn
import torchvision.models as models
test_model = models.resnet18(weights = 'DEFAULT')
torch.save(test_model.state_dict(),os.path.join('./','crnn_resnet18_ctc_best'))

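#### Testing the Training Class

Training failed.  

Symptom: after 1-2 batches, every predicted index becomes 0, i.e. blank. No matter how long training runs, the predictions stay at 0, and the loss is negative.  

Troubleshooting:  

1. Whether the images are normalized. The reference example does not normalize; I did. After removing the normalization the model still could not train, so this is not the cause.  
2. Different image preprocessing. In the reference example the dataset returns each image as an array of shape (H,W,C), read with cv2, with channels in B,G,R order. My dataset returned a tensor of shape (C,H,W) with the channels reordered to RGB. This made no difference to whether the model trains, but it did make the overall structure clearer.  
3. Whether the preds passed to ctcloss should go through softmax. The example does not; I did. After changing this to match the example, the loss went from negative to positive, but it still did not decrease normally and stayed stuck at one value. (nn.CTCLoss treats its input as log-probabilities, which are at most 0; softmax outputs are positive, which is how a negative loss can arise.)  
4. The model architecture differs in the part connecting the CNN to the RNN, especially where the CNN output is reduced to a height of 1. There are two differences: first, the reduction method; second, the number of vertical strips the image is split into afterwards. I tested the reduction method first. The example uses mean over the height dimension to reduce the height to 1, then a convolution over the width dimension to expand the width from 7 to 30. Concretely, it permutes the mean result so that the width sits on the channel axis, applies Conv2d to expand the channels, then permutes the result back. Because my input images are larger than the example's, following the example's approach, mean alone already gave my intended output width of 15. This did not help learning; the predictions were still 0. At this point I suspected that the output width was too small and was affecting the result, so I raised the output width per image to 30 to match the example. It still did not work.  
5. Changed the input image size to 100*200 and followed the example's CNN-to-RNN conversion exactly, with a final output of T=30. The model started to learn normally.  
6. Kept the image size at 100*200 but set the final output T to one of (13,15,18,20,23,25,28,30) and compared the results.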
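
The sign change seen in item 3 can be reproduced without any training. CTC loss is a negative log-likelihood, so it expects log-probabilities (values at most 0). The sketch below uses only the standard library and covers the simplest case, one time step and one target symbol, where the loss reduces to minus the target's log-probability. It is a toy calculation, not the `nn.CTCLoss` implementation, and the logits are made-up numbers.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(logits):
    """Log of the softmax probabilities; every entry is <= 0."""
    return [math.log(p) for p in softmax(logits)]

def one_step_ctc_loss(log_probs, target):
    """CTC loss for T=1 and a single target symbol: the only valid path emits the target."""
    return -log_probs[target]

logits = [0.5, 2.0, -1.0]   # made-up scores for blank, class 1, class 2
target = 1

correct = one_step_ctc_loss(log_softmax(logits), target)
wrong = one_step_ctc_loss(softmax(logits), target)   # probabilities passed as if they were log-probabilities
print(correct > 0, wrong < 0)   # True True
```

Feeding probabilities where log-probabilities are expected makes the "log-likelihood" positive and the loss negative, which matches the symptom described above.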


In [73]:
if __name__ == '__main__':
    img_width_list = [100,200,300,400,500]
    T_list = [13,15,18,20,23,25,28,30]
    for T in T_list:
        for width in img_width_list:
            print('T:',T,'width:',width)
            test_train = MyTrain(max_epoch=3,T=T,img_width=width)
            test_train.train()

T: 13 width: 100
train device: cuda:0


epoch0| batch1499| loss: 1.751| lr: 0.001| loss_mean: 2.201| batch_acc: 0.122: 100%|█| 1500/1500 [05:35<00:00,  4.48it/s
epoch1| batch1499| loss: 1.336| lr: 0.0008| loss_mean: 1.344| batch_acc: 0.418: 100%|█| 1500/1500 [02:08<00:00, 11.66it/
epoch2| batch1499| loss: 0.672| lr: 0.0006400000000000002| loss_mean: 1.001| batch_acc: 0.566: 100%|█| 1500/1500 [02:08<


T: 13 width: 200
train device: cuda:0


epoch0| batch1499| loss: 0.741| lr: 0.001| loss_mean: 1.839| batch_acc: 0.258: 100%|█| 1500/1500 [02:40<00:00,  9.33it/s
epoch1| batch1499| loss: 0.546| lr: 0.0008| loss_mean: 0.94| batch_acc: 0.601: 100%|█| 1500/1500 [02:40<00:00,  9.35it/s
epoch2| batch1499| loss: 0.931| lr: 0.0006400000000000002| loss_mean: 0.708| batch_acc: 0.682: 100%|█| 1500/1500 [02:40<


T: 13 width: 300
train device: cuda:0


epoch0| batch1499| loss: 1.318| lr: 0.001| loss_mean: 1.68| batch_acc: 0.335: 100%|█| 1500/1500 [03:53<00:00,  6.41it/s]
epoch1| batch1499| loss: 0.861| lr: 0.0008| loss_mean: 0.804| batch_acc: 0.644: 100%|█| 1500/1500 [03:53<00:00,  6.41it/
epoch2| batch1499| loss: 0.708| lr: 0.0006400000000000002| loss_mean: 0.618| batch_acc: 0.717: 100%|█| 1500/1500 [03:54<


T: 13 width: 400
train device: cuda:0


epoch0| batch1499| loss: 0.682| lr: 0.001| loss_mean: 1.713| batch_acc: 0.307: 100%|█| 1500/1500 [05:32<00:00,  4.51it/s
epoch1| batch1499| loss: 0.974| lr: 0.0008| loss_mean: 0.806| batch_acc: 0.646: 100%|█| 1500/1500 [05:33<00:00,  4.50it/
epoch2| batch1499| loss: 0.767| lr: 0.0006400000000000002| loss_mean: 0.625| batch_acc: 0.71: 100%|█| 1500/1500 [08:54<0


T: 13 width: 500
train device: cuda:0


epoch0| batch1499| loss: 0.874| lr: 0.001| loss_mean: 1.648| batch_acc: 0.335: 100%|█| 1500/1500 [07:18<00:00,  3.42it/s
epoch1| batch1499| loss: 0.329| lr: 0.0008| loss_mean: 0.789| batch_acc: 0.662: 100%|█| 1500/1500 [07:11<00:00,  3.47it/
epoch2| batch1499| loss: 1.12| lr: 0.0006400000000000002| loss_mean: 0.617| batch_acc: 0.72: 100%|█| 1500/1500 [07:13<00


T: 15 width: 100
train device: cuda:0


epoch0| batch1499| loss: 1.918| lr: 0.001| loss_mean: 2.216| batch_acc: 0.119: 100%|█| 1500/1500 [02:08<00:00, 11.71it/s
epoch1| batch1499| loss: 1.214| lr: 0.0008| loss_mean: 1.287| batch_acc: 0.434: 100%|█| 1500/1500 [02:07<00:00, 11.78it/
epoch2| batch1499| loss: 1.087| lr: 0.0006400000000000002| loss_mean: 0.969| batch_acc: 0.565: 100%|█| 1500/1500 [02:06<


T: 15 width: 200
train device: cuda:0


epoch0| batch1499| loss: 1.856| lr: 0.001| loss_mean: 2.25| batch_acc: 0.116: 100%|█| 1500/1500 [02:39<00:00,  9.40it/s]
epoch1| batch1499| loss: 0.952| lr: 0.0008| loss_mean: 1.119| batch_acc: 0.527: 100%|█| 1500/1500 [02:40<00:00,  9.34it/
epoch2| batch1499| loss: 1.602| lr: 0.0006400000000000002| loss_mean: 0.791| batch_acc: 0.644: 100%|█| 1500/1500 [02:40<


T: 15 width: 300
train device: cuda:0


epoch0| batch1499| loss: 1.572| lr: 0.001| loss_mean: 2.252| batch_acc: 0.113: 100%|█| 1500/1500 [03:53<00:00,  6.42it/s
epoch1| batch1499| loss: 0.813| lr: 0.0008| loss_mean: 1.053| batch_acc: 0.557: 100%|█| 1500/1500 [03:54<00:00,  6.39it/
epoch2| batch1499| loss: 1.406| lr: 0.0006400000000000002| loss_mean: 0.722| batch_acc: 0.681: 100%|█| 1500/1500 [03:53<


T: 15 width: 400
train device: cuda:0


epoch0| batch1499| loss: 0.551| lr: 0.001| loss_mean: 1.829| batch_acc: 0.258: 100%|█| 1500/1500 [06:55<00:00,  3.61it/s
epoch1| batch1499| loss: 1.45| lr: 0.0008| loss_mean: 0.886| batch_acc: 0.61: 100%|█| 1500/1500 [07:37<00:00,  3.28it/s]
epoch2| batch1499| loss: 1.014| lr: 0.0006400000000000002| loss_mean: 0.672| batch_acc: 0.693: 100%|█| 1500/1500 [05:28<


T: 15 width: 500
train device: cuda:0


epoch0| batch1499| loss: 1.911| lr: 0.001| loss_mean: 2.569| batch_acc: 0.037: 100%|█| 1500/1500 [07:14<00:00,  3.45it/s
epoch1| batch1499| loss: 1.66| lr: 0.0008| loss_mean: 1.223| batch_acc: 0.499: 100%|█| 1500/1500 [07:12<00:00,  3.47it/s
epoch2| batch1499| loss: 0.871| lr: 0.0006400000000000002| loss_mean: 0.751| batch_acc: 0.669: 100%|█| 1500/1500 [07:11<


T: 18 width: 100
train device: cuda:0


epoch0| batch1499| loss: 1.919| lr: 0.001| loss_mean: 2.165| batch_acc: 0.119: 100%|█| 1500/1500 [02:05<00:00, 11.96it/s
epoch1| batch1499| loss: 1.092| lr: 0.0008| loss_mean: 1.36| batch_acc: 0.397: 100%|█| 1500/1500 [02:05<00:00, 12.00it/s
epoch2| batch1499| loss: 0.335| lr: 0.0006400000000000002| loss_mean: 0.996| batch_acc: 0.565: 100%|█| 1500/1500 [02:04<


T: 18 width: 200
train device: cuda:0


epoch0| batch1499| loss: 1.446| lr: 0.001| loss_mean: 2.213| batch_acc: 0.126: 100%|█| 1500/1500 [02:36<00:00,  9.61it/s
epoch1| batch1499| loss: 1.249| lr: 0.0008| loss_mean: 1.116| batch_acc: 0.527: 100%|█| 1500/1500 [02:36<00:00,  9.58it/
epoch2| batch1499| loss: 0.503| lr: 0.0006400000000000002| loss_mean: 0.794| batch_acc: 0.647: 100%|█| 1500/1500 [02:35<


T: 18 width: 300
train device: cuda:0


epoch0| batch1499| loss: 1.41| lr: 0.001| loss_mean: 1.768| batch_acc: 0.285: 100%|█| 1500/1500 [03:49<00:00,  6.53it/s]
epoch1| batch1499| loss: 0.725| lr: 0.0008| loss_mean: 0.863| batch_acc: 0.636: 100%|█| 1500/1500 [03:58<00:00,  6.28it/
epoch2| batch1499| loss: 0.807| lr: 0.0006400000000000002| loss_mean: 0.659| batch_acc: 0.708: 100%|█| 1500/1500 [07:07<


T: 18 width: 400
train device: cuda:0


epoch0| batch1499| loss: 0.584| lr: 0.001| loss_mean: 1.817| batch_acc: 0.279: 100%|█| 1500/1500 [05:28<00:00,  4.57it/s
epoch1| batch1499| loss: 0.522| lr: 0.0008| loss_mean: 0.848| batch_acc: 0.635: 100%|█| 1500/1500 [05:29<00:00,  4.55it/
epoch2| batch1499| loss: 0.774| lr: 0.0006400000000000002| loss_mean: 0.655| batch_acc: 0.702: 100%|█| 1500/1500 [05:28<


T: 18 width: 500
train device: cuda:0


epoch0| batch1499| loss: 1.783| lr: 0.001| loss_mean: 2.083| batch_acc: 0.157: 100%|█| 1500/1500 [07:12<00:00,  3.47it/s
epoch1| batch1499| loss: 0.677| lr: 0.0008| loss_mean: 0.986| batch_acc: 0.597: 100%|█| 1500/1500 [07:12<00:00,  3.47it/
epoch2| batch1499| loss: 0.554| lr: 0.0006400000000000002| loss_mean: 0.705| batch_acc: 0.696: 100%|█| 1500/1500 [07:13<


T: 20 width: 100
train device: cuda:0


epoch0| batch1499| loss: 1.901| lr: 0.001| loss_mean: 2.264| batch_acc: 0.092: 100%|█| 1500/1500 [02:04<00:00, 12.03it/s
epoch1| batch1499| loss: 1.417| lr: 0.0008| loss_mean: 1.362| batch_acc: 0.399: 100%|█| 1500/1500 [02:04<00:00, 12.03it/
epoch2| batch1499| loss: 0.988| lr: 0.0006400000000000002| loss_mean: 0.984| batch_acc: 0.568: 100%|█| 1500/1500 [02:04<


T: 20 width: 200
train device: cuda:0


epoch0| batch1499| loss: 0.484| lr: 0.001| loss_mean: 1.998| batch_acc: 0.186: 100%|█| 1500/1500 [02:36<00:00,  9.60it/s
epoch1| batch1499| loss: 1.082| lr: 0.0008| loss_mean: 0.989| batch_acc: 0.579: 100%|█| 1500/1500 [02:36<00:00,  9.58it/
epoch2| batch1499| loss: 1.006| lr: 0.0006400000000000002| loss_mean: 0.733| batch_acc: 0.671: 100%|█| 1500/1500 [02:36<


T: 20 width: 300
train device: cuda:0


epoch0| batch1499| loss: 0.962| lr: 0.001| loss_mean: 1.859| batch_acc: 0.256: 100%|█| 1500/1500 [07:00<00:00,  3.57it/s
epoch1| batch1499| loss: 0.564| lr: 0.0008| loss_mean: 0.897| batch_acc: 0.607: 100%|█| 1500/1500 [04:09<00:00,  6.02it/
epoch2| batch1499| loss: 0.666| lr: 0.0006400000000000002| loss_mean: 0.671| batch_acc: 0.689: 100%|█| 1500/1500 [03:50<


T: 20 width: 400
train device: cuda:0


epoch0| batch1499| loss: 0.658| lr: 0.001| loss_mean: 1.611| batch_acc: 0.355: 100%|█| 1500/1500 [05:27<00:00,  4.58it/s
epoch1| batch1499| loss: 0.504| lr: 0.0008| loss_mean: 0.769| batch_acc: 0.668: 100%|█| 1500/1500 [05:28<00:00,  4.57it/
epoch2| batch1499| loss: 0.647| lr: 0.0006400000000000002| loss_mean: 0.596| batch_acc: 0.729: 100%|█| 1500/1500 [05:28<


T: 20 width: 500
train device: cuda:0


epoch0| batch1499| loss: 1.291| lr: 0.001| loss_mean: 2.126| batch_acc: 0.182: 100%|█| 1500/1500 [07:12<00:00,  3.47it/s
epoch1| batch1499| loss: 0.729| lr: 0.0008| loss_mean: 0.941| batch_acc: 0.599: 100%|█| 1500/1500 [07:14<00:00,  3.45it/
epoch2| batch1499| loss: 0.557| lr: 0.0006400000000000002| loss_mean: 0.703| batch_acc: 0.688: 100%|█| 1500/1500 [07:14<


T: 23 width: 100
train device: cuda:0


epoch0| batch1499| loss: 1.826| lr: 0.001| loss_mean: 2.28| batch_acc: 0.081: 100%|█| 1500/1500 [02:04<00:00, 12.04it/s]
epoch1| batch1499| loss: 0.924| lr: 0.0008| loss_mean: 1.461| batch_acc: 0.356: 100%|█| 1500/1500 [02:04<00:00, 12.00it/
epoch2| batch1499| loss: 1.403| lr: 0.0006400000000000002| loss_mean: 1.029| batch_acc: 0.558: 100%|█| 1500/1500 [02:04<


T: 23 width: 200
train device: cuda:0


epoch0| batch1499| loss: 2.117| lr: 0.001| loss_mean: 2.136| batch_acc: 0.151: 100%|█| 1500/1500 [04:36<00:00,  5.43it/s
epoch1| batch1499| loss: 1.17| lr: 0.0008| loss_mean: 1.055| batch_acc: 0.56: 100%|█| 1500/1500 [03:52<00:00,  6.45it/s]
epoch2| batch1499| loss: 0.487| lr: 0.0006400000000000002| loss_mean: 0.782| batch_acc: 0.659: 100%|█| 1500/1500 [02:37<


T: 23 width: 300
train device: cuda:0


epoch0| batch1499| loss: 1.103| lr: 0.001| loss_mean: 1.817| batch_acc: 0.264: 100%|█| 1500/1500 [03:50<00:00,  6.51it/s
epoch1| batch1499| loss: 0.915| lr: 0.0008| loss_mean: 0.9| batch_acc: 0.613: 100%|█| 1500/1500 [03:51<00:00,  6.48it/s]
epoch2| batch1499| loss: 0.359| lr: 0.0006400000000000002| loss_mean: 0.681| batch_acc: 0.697: 100%|█| 1500/1500 [03:50<


T: 23 width: 400
train device: cuda:0


epoch0| batch1499| loss: 1.392| lr: 0.001| loss_mean: 1.622| batch_acc: 0.345: 100%|█| 1500/1500 [05:28<00:00,  4.57it/s
epoch1| batch1499| loss: 0.59| lr: 0.0008| loss_mean: 0.784| batch_acc: 0.654: 100%|█| 1500/1500 [05:29<00:00,  4.56it/s
epoch2| batch1499| loss: 0.659| lr: 0.0006400000000000002| loss_mean: 0.602| batch_acc: 0.724: 100%|█| 1500/1500 [05:29<


T: 23 width: 500
train device: cuda:0


epoch0| batch1499| loss: 1.782| lr: 0.001| loss_mean: 2.279| batch_acc: 0.095: 100%|█| 1500/1500 [07:13<00:00,  3.46it/s
epoch1| batch1499| loss: 0.79| lr: 0.0008| loss_mean: 1.092| batch_acc: 0.552: 100%|█| 1500/1500 [07:14<00:00,  3.45it/s
epoch2| batch1499| loss: 0.475| lr: 0.0006400000000000002| loss_mean: 0.723| batch_acc: 0.688: 100%|█| 1500/1500 [07:13<


T: 25 width: 100
train device: cuda:0


epoch0| batch1499| loss: 1.305| lr: 0.001| loss_mean: 2.183| batch_acc: 0.124: 100%|█| 1500/1500 [05:10<00:00,  4.83it/s
epoch1| batch1499| loss: 1.494| lr: 0.0008| loss_mean: 1.297| batch_acc: 0.441: 100%|█| 1500/1500 [02:18<00:00, 10.80it/
epoch2| batch1499| loss: 0.774| lr: 0.0006400000000000002| loss_mean: 0.975| batch_acc: 0.575: 100%|█| 1500/1500 [02:05<


T: 25 width: 200
train device: cuda:0


epoch0| batch1499| loss: 1.55| lr: 0.001| loss_mean: 2.154| batch_acc: 0.137: 100%|█| 1500/1500 [02:36<00:00,  9.57it/s]
epoch1| batch1499| loss: 1.06| lr: 0.0008| loss_mean: 1.079| batch_acc: 0.539: 100%|█| 1500/1500 [02:37<00:00,  9.52it/s
epoch2| batch1499| loss: 0.544| lr: 0.0006400000000000002| loss_mean: 0.764| batch_acc: 0.662: 100%|█| 1500/1500 [02:37<


T: 25 width: 300
train device: cuda:0


epoch0| batch1499| loss: 0.924| lr: 0.001| loss_mean: 1.893| batch_acc: 0.255: 100%|█| 1500/1500 [03:50<00:00,  6.51it/s
epoch1| batch1499| loss: 0.601| lr: 0.0008| loss_mean: 0.854| batch_acc: 0.637: 100%|█| 1500/1500 [03:50<00:00,  6.51it/
epoch2| batch1499| loss: 0.484| lr: 0.0006400000000000002| loss_mean: 0.644| batch_acc: 0.712: 100%|█| 1500/1500 [03:50<


T: 25 width: 400
train device: cuda:0


epoch0| batch1499| loss: 1.178| lr: 0.001| loss_mean: 2.037| batch_acc: 0.168: 100%|█| 1500/1500 [05:28<00:00,  4.56it/s
epoch1| batch1499| loss: 1.13| lr: 0.0008| loss_mean: 0.98| batch_acc: 0.576: 100%|█| 1500/1500 [05:29<00:00,  4.55it/s]
epoch2| batch1499| loss: 1.295| lr: 0.0006400000000000002| loss_mean: 0.712| batch_acc: 0.673: 100%|█| 1500/1500 [05:28<


T: 25 width: 500
train device: cuda:0


epoch0| batch1499| loss: 0.944| lr: 0.001| loss_mean: 1.639| batch_acc: 0.35: 100%|█| 1500/1500 [07:13<00:00,  3.46it/s]
epoch1| batch1499| loss: 1.126| lr: 0.0008| loss_mean: 0.807| batch_acc: 0.659: 100%|█| 1500/1500 [07:14<00:00,  3.46it/
epoch2| batch1499| loss: 0.527| lr: 0.0006400000000000002| loss_mean: 0.622| batch_acc: 0.722: 100%|█| 1500/1500 [10:49<


T: 28 width: 100
train device: cuda:0


epoch0| batch1499| loss: 1.407| lr: 0.001| loss_mean: 2.137| batch_acc: 0.129: 100%|█| 1500/1500 [02:11<00:00, 11.42it/s
epoch1| batch1499| loss: 1.379| lr: 0.0008| loss_mean: 1.281| batch_acc: 0.441: 100%|█| 1500/1500 [02:05<00:00, 11.94it/
epoch2| batch1499| loss: 1.271| lr: 0.0006400000000000002| loss_mean: 0.938| batch_acc: 0.59: 100%|█| 1500/1500 [02:04<0


T: 28 width: 200
train device: cuda:0


epoch0| batch1499| loss: 1.849| lr: 0.001| loss_mean: 2.343| batch_acc: 0.088: 100%|█| 1500/1500 [02:36<00:00,  9.57it/s
epoch1| batch1499| loss: 0.753| lr: 0.0008| loss_mean: 1.236| batch_acc: 0.475: 100%|█| 1500/1500 [02:37<00:00,  9.54it/
epoch2| batch1499| loss: 0.817| lr: 0.0006400000000000002| loss_mean: 0.847| batch_acc: 0.63: 100%|█| 1500/1500 [02:36<0


T: 28 width: 300
train device: cuda:0


epoch0| batch1499| loss: 1.341| lr: 0.001| loss_mean: 2.151| batch_acc: 0.138: 100%|█| 1500/1500 [03:50<00:00,  6.51it/s
epoch1| batch1499| loss: 1.006| lr: 0.0008| loss_mean: 1.084| batch_acc: 0.529: 100%|█| 1500/1500 [03:51<00:00,  6.49it/
epoch2| batch1499| loss: 1.229| lr: 0.0006400000000000002| loss_mean: 0.762| batch_acc: 0.663: 100%|█| 1500/1500 [03:51<


T: 28 width: 400
train device: cuda:0


epoch0| batch1499| loss: 1.296| lr: 0.001| loss_mean: 1.793| batch_acc: 0.28: 100%|█| 1500/1500 [05:28<00:00,  4.56it/s]
epoch1| batch1499| loss: 0.592| lr: 0.0008| loss_mean: 0.857| batch_acc: 0.639: 100%|█| 1500/1500 [05:28<00:00,  4.57it/
epoch2| batch1499| loss: 0.384| lr: 0.0006400000000000002| loss_mean: 0.654| batch_acc: 0.706: 100%|█| 1500/1500 [05:28<


T: 28 width: 500
train device: cuda:0


epoch0| batch1499| loss: 1.171| lr: 0.001| loss_mean: 2.271| batch_acc: 0.111: 100%|█| 1500/1500 [07:23<00:00,  3.38it/s
epoch1| batch1499| loss: 0.915| lr: 0.0008| loss_mean: 0.976| batch_acc: 0.59: 100%|█| 1500/1500 [10:52<00:00,  2.30it/s
epoch2| batch1499| loss: 0.741| lr: 0.0006400000000000002| loss_mean: 0.688| batch_acc: 0.7: 100%|█| 1500/1500 [07:14<00


T: 30 width: 100
train device: cuda:0


epoch0| batch1499| loss: 1.689| lr: 0.001| loss_mean: 2.163| batch_acc: 0.152: 100%|█| 1500/1500 [02:05<00:00, 11.97it/s
epoch1| batch1499| loss: 1.685| lr: 0.0008| loss_mean: 1.286| batch_acc: 0.454: 100%|█| 1500/1500 [02:04<00:00, 12.00it/
epoch2| batch1499| loss: 0.644| lr: 0.0006400000000000002| loss_mean: 0.984| batch_acc: 0.567: 100%|█| 1500/1500 [02:05<


T: 30 width: 200
train device: cuda:0


epoch0| batch1499| loss: 1.236| lr: 0.001| loss_mean: 2.014| batch_acc: 0.195: 100%|█| 1500/1500 [02:36<00:00,  9.56it/s
epoch1| batch1499| loss: 0.647| lr: 0.0008| loss_mean: 0.981| batch_acc: 0.585: 100%|█| 1500/1500 [02:37<00:00,  9.55it/
epoch2| batch1499| loss: 0.612| lr: 0.0006400000000000002| loss_mean: 0.713| batch_acc: 0.686: 100%|█| 1500/1500 [02:36<


T: 30 width: 300
train device: cuda:0


epoch0| batch1499| loss: 0.637| lr: 0.001| loss_mean: 1.74| batch_acc: 0.306: 100%|█| 1500/1500 [03:49<00:00,  6.53it/s]
epoch1| batch1499| loss: 0.582| lr: 0.0008| loss_mean: 0.809| batch_acc: 0.649: 100%|█| 1500/1500 [03:49<00:00,  6.53it/
epoch2| batch1499| loss: 0.967| lr: 0.0006400000000000002| loss_mean: 0.616| batch_acc: 0.719: 100%|█| 1500/1500 [03:50<


T: 30 width: 400
train device: cuda:0


epoch0| batch1499| loss: 0.681| lr: 0.001| loss_mean: 1.878| batch_acc: 0.243: 100%|█| 1500/1500 [05:28<00:00,  4.57it/s
epoch1| batch1499| loss: 0.745| lr: 0.0008| loss_mean: 0.842| batch_acc: 0.637: 100%|█| 1500/1500 [05:28<00:00,  4.57it/
epoch2| batch1499| loss: 0.423| lr: 0.0006400000000000002| loss_mean: 0.632| batch_acc: 0.718: 100%|█| 1500/1500 [05:54<


T: 30 width: 500
train device: cuda:0


epoch0| batch1499| loss: 1.064| lr: 0.001| loss_mean: 2.047| batch_acc: 0.185: 100%|█| 1500/1500 [10:37<00:00,  2.35it/s
epoch1| batch1499| loss: 1.031| lr: 0.0008| loss_mean: 0.913| batch_acc: 0.622: 100%|█| 1500/1500 [07:14<00:00,  3.45it/
epoch2| batch1499| loss: 0.81| lr: 0.0006400000000000002| loss_mean: 0.665| batch_acc: 0.708: 100%|█| 1500/1500 [07:14<0


In [5]:
if __name__ == '__main__':
    img_width_list = [200]
    T_list = [30]
    for T in T_list:
        for width in img_width_list:
            print('T:',T,'width:',width)
            test_train = MyTrain(max_epoch=50,T=T,img_width=width)
            test_train.train()

T: 30 width: 200
train device: cuda:0


epoch0| batch1499| loss: 1.402| lr: 0.001| loss_mean: 1.884| batch_acc: 0.244: 100%|█| 1500/1500 [07:06<00:00,  3.51it/s
epoch1| batch1499| loss: 0.608| lr: 0.0008| loss_mean: 0.969| batch_acc: 0.584: 100%|█| 1500/1500 [02:36<00:00,  9.58it/
epoch2| batch1499| loss: 0.764| lr: 0.0006| loss_mean: 0.712| batch_acc: 0.688: 100%|█| 1500/1500 [02:35<00:00,  9.64it/
epoch3| batch1499| loss: 0.557| lr: 0.0005| loss_mean: 0.564| batch_acc: 0.743: 100%|█| 1500/1500 [02:34<00:00,  9.68it/
epoch4| batch1499| loss: 0.649| lr: 0.0004| loss_mean: 0.438| batch_acc: 0.795: 100%|█| 1500/1500 [02:34<00:00,  9.70it/
epoch5| batch1499| loss: 0.465| lr: 0.0003| loss_mean: 0.338| batch_acc: 0.834: 100%|█| 1500/1500 [02:34<00:00,  9.72it/
epoch6| batch1499| loss: 0.387| lr: 0.0003| loss_mean: 0.244| batch_acc: 0.878: 100%|█| 1500/1500 [02:34<00:00,  9.69it/
epoch7| batch1499| loss: 0.111| lr: 0.0002| loss_mean: 0.169| batch_acc: 0.915: 100%|█| 1500/1500 [02:34<00:00,  9.71it/
epoch8| batch1499| loss: 0.015| 

### Prediction Module