<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#训练模块" data-toc-modified-id="训练模块-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>训练模块</a></span></li><li><span><a href="#验证模块" data-toc-modified-id="验证模块-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>验证模块</a></span></li><li><span><a href="#结果写入模块" data-toc-modified-id="结果写入模块-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>结果写入模块</a></span></li></ul></div>

## 训练模块

In [36]:
import cv2
import os
import torch
import numpy as np
from torch.utils.data.dataset import Dataset
import torch.nn as nn
import torchvision.models as models

class MyDataset(Dataset):
    
    def __init__(self,img_path,label_content,img_width,img_height):
        self.img_path = img_path
        self.label_content = label_content
        self.img_width = img_width
        self.img_height = img_height
    
    def __getitem__(self,index):
        img_path = self.img_path[index]
        img_cv2 = cv2.imread(img_path)  #return （H，W，C），C list by B，G，R
        img_array = cv2.resize(img_cv2,(self.img_width,self.img_height))  #cv2.resize的输入为（w,h），宽在前，高在后
        img_array = img_array.transpose((2,0,1))
        
        img_name = os.path.split(img_path)[1]
        img_label = self.label_content[img_name]['label']
        #transform the label 0~9 to index 1~10，index 0 is for blank
        img_label_index_list = [x+1 for x in img_label]
        
        return img_array,img_label_index_list
    
    def __len__(self):
        return len(self.img_path)


class MyModel(nn.Module):
    def __init__(self,T,img_width,dropout=0):
        super().__init__()
        self.img_width = img_width
        self.img_height = int(0.5*self.img_width)
        self.cnn_out_width = self.cal_finnal_width()
        self.cnn_out_height = self.cal_finnal_height()
        #self.cnn: input(N,3,240,480),output(N,512,8,15)
        trained_resnet18 = models.resnet18(weights='DEFAULT')
        self.cnn = nn.Sequential(trained_resnet18.conv1,
                                 trained_resnet18.bn1,
                                 trained_resnet18.relu,
                                 trained_resnet18.maxpool,
                                 trained_resnet18.layer1,
                                 trained_resnet18.layer2,
                                 trained_resnet18.layer3,
                                 trained_resnet18.layer4)
        self.feat_to_seq = nn.Conv2d(self.cnn_out_height,1,kernel_size=(1,1),stride=1)
        self.rnn = nn.GRU(input_size=512,hidden_size=64,batch_first=True,bidirectional=True)
        self.linear = nn.Linear(128,11)
    
    def forward(self,X):
#         print('X:',X.shape)
        X = self.cnn(X)
#         print('cnn:',X.shape)
        X = X.permute((0,2,1,3))
        X = self.feat_to_seq(X)
#         print('feat_to_seq:',X.shape)
        X = X.permute((0,3,1,2))
#         print('permute:',X.shape)
        X = X.squeeze(dim=2)
#         print('squeeze:',X.shape)
        X,_ = self.rnn(X)
#         print('rnn:',X.shape)
        X = self.linear(X)
        X = X.permute((1,0,2))
#         print('linear:',X.shape)
        return X

    def cal_finnal_width(self):
        width = self.img_width
        for i in range(5):
            if width%2 == 0:
                width = int(width/2)
            else:
                width = int(width/2)+1
        return width

    def cal_finnal_height(self):
        height = self.img_height
        for i in range(5):
            if height%2 == 0:
                height = int(height/2)
            else:
                height = int(height/2)+1
        return height

class MyDataLoader():
    def __init__(self,my_dataset,batch_size=20,use_my_collate=True,shuffle=False,):
        from torch.utils.data import DataLoader
        if use_my_collate:
            self.data_loader = DataLoader(my_dataset,batch_size=batch_size,shuffle=shuffle,collate_fn=self.my_collate)
        else:
            self.data_loader = DataLoader(my_dataset,batch_size=batch_size,shuffle=shuffle)
        
    def my_collate(self,batch_data):
        '''
        batch_data shape: [(feat_array,label_list),(feat_array,label_list),……]
        '''
        feat_list =[]
        label_list = []
        label_len_list = []
        for sample in batch_data:
            feat_list.append(torch.from_numpy(sample[0]))
            label_list.extend(sample[1])
            label_len_list.append(len(sample[1]))
            
        feat = torch.stack(feat_list).float()
        label = torch.Tensor(label_list).int()
        label_len = torch.Tensor(label_len_list).int()
        
        return feat,label,label_len
        
class MyTrain():
    def __init__(self,max_epoch=1,train_batch_size=120,valid_batch_size=40,lr=0.001,random_seed=None,grad_threshold=10,T=30,img_width=200,out_dir='./',dropout=0):
        self.max_epoch = max_epoch
        self.lr = lr
        self.random_seed = random_seed
        self.grad_threshold = grad_threshold
        self.T = T
        self.img_width = img_width
        self.img_height = int(0.5*img_width)
        self.out_dir = out_dir
        self.dropout = dropout
        self.train_batch_size = train_batch_size
        self.valid_batch_size = valid_batch_size
    
    def train(self):
        import json
        import glob
        
        max_epoch,grad_threshold = self.max_epoch,self.grad_threshold
        
        #fix random seed
        if self.random_seed is not None:
            self.fix_seed()
            
        #creat dataset instance
        train_img_path = glob.glob('../input/train/*.png')
        train_label_content = json.load(open('../input/train.json'))
        
        valid_img_path = glob.glob('../input/val/*.png')
        valid_label_content = json.load(open('../input/val.json'))
        
        train_img_path.sort()
        valid_img_path.sort()
        
        train_dataset = MyDataset(train_img_path,train_label_content,self.img_width,self.img_height)
        valid_dataset = MyDataset(valid_img_path,valid_label_content,self.img_width,self.img_height)

        #creat model and loss instance
        crnn_model = MyModel(self.T,self.img_width,self.dropout)
        ctc_loss = nn.CTCLoss()
        if torch.cuda.is_available():
            crnn_model.cuda()
            ctc_loss.cuda()
        
        #creat optim instance
        adam_optimizer = torch.optim.Adam(crnn_model.parameters(),lr=self.lr)
        batch_device = next(iter(crnn_model.parameters())).device
        print('train device:',batch_device)
        
        acc_max = 0
        epoch_index = 0

        while epoch_index < max_epoch :
        
            train_loss_list=[]
            train_acc_list =[]
            valid_acc_list =[]
            
            from tqdm import tqdm
            train_dataloader = MyDataLoader(train_dataset,batch_size=self.train_batch_size,shuffle=True).data_loader
            valid_dataloader = MyDataLoader(valid_dataset,batch_size=self.valid_batch_size,shuffle=True).data_loader
            
            train_dataloader = tqdm(train_dataloader,ncols=140)
            
            batch_index = 0
            
            #train each batch data
            for batch_train_data,batch_valid_data in zip(train_dataloader,valid_dataloader):
                
                crnn_model.train()
                
                batch_train_feat = batch_train_data[0]
                batch_train_label = batch_train_data[1]
                batch_train_label_len = batch_train_data[2]
                batch_valid_feat = batch_valid_data[0]
                batch_valid_label = batch_valid_data[1]
                batch_valid_label_len = batch_valid_data[2]
                
                if torch.cuda.is_available():
                    batch_train_feat = batch_train_feat.cuda()
                    batch_train_label = batch_train_label.cuda()
                    batch_train_label_len = batch_train_label_len.cuda()
                    batch_valid_feat = batch_valid_feat.cuda()
                    batch_valid_label = batch_valid_label.cuda()
                    batch_valid_label_len = batch_valid_label_len.cuda()
                    
                adam_optimizer.zero_grad()
                
                batch_train_y_hat = crnn_model(batch_train_feat) 
                batch_train_y_hat = batch_train_y_hat.to(torch.float64)  #经过crnn计算出来结果是float32,传入ctcloss使用GPU时会报错，把这个结果改成float64就不会报错
                batch_train_y_hat_len = torch.LongTensor([len(batch_train_y_hat)]*self.train_batch_size)
                
#                 print('1:',batch_y_hat,'\n','2:',batch_label,'\n','3:',batch_y_hat_len,'\n','4:',batch_label_len)                
                batch_train_loss = ctc_loss(batch_train_y_hat,batch_train_label,batch_train_y_hat_len,batch_train_label_len)
                batch_train_loss.backward()
                
                #进行梯度裁剪，避免梯度爆炸
                nn.utils.clip_grad_norm_(crnn_model.parameters(),grad_threshold)
                                
                #更新模型参数
                adam_optimizer.step()
                
                #调整学习率
                adam_optimizer.param_groups[0]['lr'] = self.lr*(0.8**(epoch_index%10))
                
                #计算验证数据准确率
                crnn_model.eval()
                with torch.no_grad():
                    batch_valid_y_hat = crnn_model(batch_valid_feat) 
                    batch_valid_y_hat = batch_valid_y_hat.to(torch.float64)  
                    
                    #更新tqdm进度条的描述内容
                    batch_valid_acc = self.get_accurancy(batch_valid_y_hat,batch_valid_label,batch_valid_label_len)
                    valid_acc_list.append(batch_valid_acc)
                    epoch_valid_acc_mean = round(sum(valid_acc_list)/len(valid_acc_list),3)

                    train_loss_list.append(batch_train_loss.item())
                    batch_train_loss_mean = round(sum(train_loss_list)/len(train_loss_list),3)
                    batch_train_acc = self.get_accurancy(batch_train_y_hat,batch_train_label,batch_train_label_len)
                    train_acc_list.append(batch_train_acc)
                    epoch_train_acc_mean = round(sum(train_acc_list)/len(train_acc_list),3)
                    
                    batch_lr = adam_optimizer.param_groups[0]['lr']
                    
                batch_index+=1
                
                train_dataloader.set_description(f'epoch{epoch_index},batch{batch_index},train_loss: {round(batch_train_loss.item(),3)},lr: {round(batch_lr,4)},train_loss_mean: {batch_train_loss_mean},train_acc: {epoch_train_acc_mean},valid_acc: {epoch_valid_acc_mean}')
           
            #保存最近一个epoch的模型参数，如果比之前的模型都好，则将其设置为最佳模型
            if epoch_valid_acc_mean > acc_max:
                acc_max = epoch_valid_acc_mean
                torch.save(crnn_model.state_dict(),os.path.join(self.out_dir,f'crnn_resnet18_ctc_valid_{epoch_valid_acc_mean}'))
            
            epoch_index+=1
    
    def fix_seed(self):
        import random
        import numpy as np
        import torch
        
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)
        torch.random.manual_seed(self.random_seed)
        torch.cuda.manual_seed_all(self.random_seed)
        torch.backends.cudnn.deterministic = True
        
    def get_accurancy(self,y_hat,label,label_len):
        #y_hat(T,N,C)
        y_hat = y_hat.permute(1,0,2).cpu()
#         print('y_hat:',y_hat[0])
        label =label.cpu()
        label_len = label_len.cpu()
        #pred(N,T)
        pred = torch.argmax(y_hat,dim=2)
#         print('pred:',pred[0])
        batch_size = pred.shape[0]
        acc = 0
        #decode
        for i in range(batch_size):
            raw_pred_list = list(pred[i].numpy())
#             print('raw_pred_list:',raw_pred_list)
            pred_data = []
            for j in range(len(raw_pred_list)):
                if j == 0 and raw_pred_list[0] != 0:
                    pred_data.append(raw_pred_list[0]-1)
                if j != 0 and raw_pred_list[j] != raw_pred_list[j-1] and raw_pred_list[j] != 0:
                    pred_data.append(raw_pred_list[j]-1)
#             print('pred_data:',pred_data)
            
            label_start_index = int(torch.sum(label_len[:i]).item())
            label_end_index = int(label_len[i].item())+label_start_index
            label_data = list(label[label_start_index:label_end_index].int().numpy())
            label_data = [x-1 for x in label_data]
#             print(f'pred_data:{pred_data},label_data:{label_data}\n')
            
            if pred_data == label_data:
                acc+=1
        acc/=batch_size
        return acc
          

In [None]:
if __name__ == '__main__':
    test_train = MyTrain(max_epoch=50,train_batch_size=30,valid_batch_size=10,lr=0.001,random_seed=1,grad_threshold=10,T=15,img_width=480,out_dir='./',dropout=0)
    test_train.train()

## 验证模块

In [88]:
import glob
import json
from torch.utils.data import DataLoader 
from collections import OrderedDict
from tqdm import tqdm
import pandas as pd

class ValidModel():
    def __init__(self,saved_model_path=None,T=30,img_width=200):
        self.img_width = img_width
        self.img_height = int(0.5*img_width)
        self.model = MyModel(T,img_width)
        if saved_model_path is not None:
            best_state_dict = torch.load(os.path.join('./',saved_model_path))
            self.model.load_state_dict(best_state_dict)
            if torch.cuda.is_available():
                self.model.cuda()
                print(saved_model_path,' cuda model load success!')
            else:
                print(saved_model_path,' model load success!')
        else:
            return('model loading failed')

    def predict(self,img_fold_path):
        img_path = glob.glob(os.path.join(img_fold_path,'*.png'))
        
        #创造一个虚拟的label_content，以正常使用前面写的Dataset类
        train_label_content = json.load(open('../input/train.json'))
        virtul_label_content={}
        for i in range(len(img_path)):
            virtul_label_content[os.path.split(img_path[i])[1]] = train_label_content[os.path.split(img_path[0])[1]]
            
        my_dataset = MyDataset(img_path,virtul_label_content,self.img_width,self.img_height)
        my_dataloader = DataLoader(my_dataset,batch_size=1)

        self.prediction_dict = OrderedDict()
        my_dataloader = tqdm(my_dataloader,ncols=60)
        
        with torch.no_grad():
            self.model.eval()
            
            for batch_index,batch_data in enumerate(my_dataloader):
                X,_ = batch_data
                X = X.float()
                if torch.cuda.is_available:
                    X = X.cuda()
                y_hat = self.model(X)
                y_hat = y_hat.permute(1,0,2).cpu()
                pred = torch.argmax(y_hat,dim=2)
                batch_size = pred.shape[0]
                pred_list = []
                
                for i in range(batch_size):
                    raw_pred_list = list(pred[i].numpy())
                    pred_data = []
                    for j in range(len(raw_pred_list)):
                        if j == 0 and raw_pred_list[0] != 0:
                            pred_data.append(raw_pred_list[0]-1)
                        if j != 0 and raw_pred_list[j] != raw_pred_list[j-1] and raw_pred_list[j] != 0:
                            pred_data.append(raw_pred_list[j]-1)
                    if len(pred_data) != 0:
                        pred_num = ''.join([str(int(num)) for num in pred_data])
                        pred_num = int(pred_num)
                        img_name = os.path.split(img_path[batch_index])[1]
                        self.prediction_dict[img_name] = pred_num
                    else:
                        img_name = os.path.split(img_path[batch_index])[1]
                        self.prediction_dict[img_name] = -1
        
        return self.prediction_dict
            
    def get_accurancy(self,label_path):
        predict_right_num = 0 
        prediction_dict = self.prediction_dict
        label_content = json.load(open(label_path))
        if len(prediction_dict.keys()) == len(label_content.keys()):
            for img_name in prediction_dict.keys():
                if prediction_dict[img_name] == int(''.join([str(x) for x in label_content[img_name]['label']])):
                    predict_right_num+=1
            acc = (predict_right_num/len(prediction_dict))
            return acc
        else:
            print('label path error or may be need to call predict first')
            return None
    
    def write_result(self,model_name):
        tmp_data = {'file_name':self.prediction_dict.keys(),'file_code':self.prediction_dict.values()}
        out_file_path = os.path.join('./',str(model_name)+'_preds.csv')
        write_data = pd.DataFrame(tmp_data)
        write_data.to_csv(out_file_path,sep='\t',index=False)
        print('test preds write down')
        
def main(model_name,pred_name='val',T=15,img_width=480):
    my_valid = ValidModel(saved_model_path=model_name,T=15,img_width=480)
    if pred_name == 'train' or pred_name == 'val':
        valid_dataset_prediction = my_valid.predict(f'../input/{pred_name}')
        print(f'{pred_name} len:',len(valid_dataset_prediction.keys()))
        print(f'{pred_name} acc:',my_valid.get_accurancy(f'../input/{pred_name}.json'))
    if pred_name == 'test_a':
        valid_dataset_prediction = my_valid.predict(f'../input/{pred_name}')
        my_valid.write_result(model_name)
        return valid_dataset_prediction

In [89]:
if __name__ == '__main__':
    model_name = 'crnn_resnet18_ctc_valid_0.655'
    prediction_dict = main(model_name,pred_name='test_a',T=15,img_width=480)

crnn_resnet18_ctc_valid_0.655  cuda model load success!


100%|█████████████████| 40000/40000 [09:58<00:00, 66.81it/s]


test preds write down


In [90]:
prediction_dict

OrderedDict([('000000.png', 159),
             ('000001.png', 290),
             ('000002.png', 113),
             ('000003.png', 97),
             ('000004.png', 63),
             ('000005.png', 6399),
             ('000006.png', 126),
             ('000007.png', 14751),
             ('000008.png', 4),
             ('000009.png', 18),
             ('000010.png', 281),
             ('000011.png', 610),
             ('000012.png', 60),
             ('000013.png', 772),
             ('000014.png', 836),
             ('000015.png', 40),
             ('000016.png', 793),
             ('000017.png', 60),
             ('000018.png', 15),
             ('000019.png', 204),
             ('000020.png', 284),
             ('000021.png', 14),
             ('000022.png', 245),
             ('000023.png', 2001),
             ('000024.png', 374),
             ('000025.png', 23),
             ('000026.png', 11),
             ('000027.png', 157),
             ('000028.png', 98),
             ('000029.p

In [91]:
tmp_data = {'file_name':prediction_dict.keys(),'file_code':prediction_dict.values()}
write_data = pd.DataFrame(tmp_data)
write_data

Unnamed: 0,file_name,file_code
0,000000.png,159
1,000001.png,290
2,000002.png,113
3,000003.png,97
4,000004.png,63
...,...,...
39995,039995.png,2123
39996,039996.png,31
39997,039997.png,1167
39998,039998.png,235


In [95]:
write_data.to_csv('crnn_resnet18_ctc_valid_0.655_preds.csv',index=False)