In [None]:
import cv2
import os
import torch
import numpy as np
from torch.utils.data.dataset import Dataset
import torch.nn as nn
import torchvision.models as models

class MyDataset(Dataset):
    """Image/label dataset for CTC-based digit-sequence recognition.

    Args:
        img_path: list of image file paths, indexed directly by ``__getitem__``.
        label_content: mapping from image file name (basename) to a dict whose
            ``'label'`` entry is the digit sequence for that image.
            NOTE(review): schema inferred from the lookup in ``__getitem__`` —
            confirm against the label JSON.
        img_width: width every image is resized to.
        img_height: height every image is resized to.
    """

    def __init__(self, img_path, label_content, img_width, img_height):
        self.img_path = img_path
        self.label_content = label_content
        self.img_width = img_width
        self.img_height = img_height

    def __getitem__(self, index):
        """Return ``(img_array, label_index_list)`` for sample ``index``.

        ``img_array`` is the resized image as returned by OpenCV: an (H, W, C)
        array with channels in B, G, R order. ``label_index_list`` is the label
        sequence shifted by +1 so digits 0-9 become class indices 1-10, leaving
        index 0 free for the CTC blank.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image. ``cv2.imread``
                signals failure by returning ``None`` rather than raising.
        """
        img_path = self.img_path[index]
        img_cv2 = cv2.imread(img_path)  # (H, W, C), channels ordered B, G, R
        if img_cv2 is None:
            raise FileNotFoundError(f'cv2.imread could not read image: {img_path}')
        # cv2.resize takes dsize as (width, height): width first, then height.
        img_array = cv2.resize(img_cv2, (self.img_width, self.img_height))

        img_name = os.path.split(img_path)[1]
        img_label = self.label_content[img_name]['label']
        # Shift labels 0~9 to indices 1~10; index 0 is reserved for the CTC blank.
        img_label_index_list = [x + 1 for x in img_label]

        return img_array, img_label_index_list

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.img_path)


class MyModel(nn.Module):
    """CRNN: a ResNet-18 convolutional backbone followed by a bidirectional GRU.

    Args:
        T: number of output time steps produced for each image.
        img_width: input image width, used to size the column-to-sequence projection.
        dropout: dropout probability applied to the CNN feature map.

    ``forward`` returns raw, un-normalized class scores of shape (T, N, 11)
    (sequence-first); no softmax / log-softmax is applied here.
    """

    def __init__(self, T, img_width, dropout):
        super().__init__()
        self.img_width = img_width
        self.cnn_out_width = self.cal_finnal_width()

        backbone = models.resnet18(weights='DEFAULT')
        # Every ResNet-18 stage up to layer4 (avgpool/fc are dropped).
        # For an (N, 3, 100, 200) input the output is (N, 512, 4, 7).
        self.cnn = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
            backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4,
        )
        self.dropout = nn.Dropout(p=dropout)
        # 1x1 conv over the feature-map columns (treated as channels) -> T sequence steps.
        self.feat_to_seq = nn.Conv2d(self.cnn_out_width, T, kernel_size=(1, 1), stride=1)
        self.rnn = nn.GRU(input_size=512, hidden_size=64, bidirectional=True)
        # 2 directions * 64 hidden units -> 11 classes.
        self.linear = nn.Linear(128, 11)

    def forward(self, X):
        """Map an image batch (N, 3, H, W) to per-step class scores (T, N, 11)."""
        feats = self.dropout(self.cnn(X))                 # (N, 512, h, w)
        feats = feats.mean(dim=2, keepdim=True)           # average over height -> (N, 512, 1, w)
        feats = feats.permute((0, 3, 1, 2))               # columns become channels -> (N, w, 512, 1)
        seq = self.feat_to_seq(feats)                     # (N, T, 512, 1)
        seq = seq.permute((1, 0, 2, 3)).squeeze(dim=3)    # sequence-first for the GRU -> (T, N, 512)
        rnn_out, _ = self.rnn(seq)                        # (T, N, 128)
        return self.linear(rnn_out)

    def cal_finnal_width(self):
        """Width of the CNN output map for an input of width ``self.img_width``.

        The width is halved five times, rounding odd values up — presumably one
        halving per stride-2 stage of the ResNet-18 backbone. For
        ``img_width=200`` this yields 7, matching the feature-map shape noted in
        ``__init__``.
        """
        width = self.img_width
        for _ in range(5):
            half = int(width / 2)
            width = half if width % 2 == 0 else half + 1
        return width

class MyDataLoader():
    """Thin wrapper that builds a ``torch.utils.data.DataLoader`` in ``self.data_loader``.

    With ``use_my_collate=True`` (the default), batches come from ``my_collate``.
    That layout is what ``nn.CTCLoss`` consumes: stacked images, one flat
    concatenated target vector, and per-sample target lengths. Otherwise the
    DataLoader's default collation is used.
    """

    def __init__(self, my_dataset, batch_size=20, use_my_collate=True, shuffle=False):
        from torch.utils.data import DataLoader
        # collate_fn=None is DataLoader's default, i.e. identical to omitting it.
        collate_fn = self.my_collate if use_my_collate else None
        self.data_loader = DataLoader(my_dataset, batch_size=batch_size, shuffle=shuffle,
                                      collate_fn=collate_fn)

    def my_collate(self, batch_data):
        """Collate ``[(feat_array, label_list), ...]`` into CTC-ready tensors.

        Args:
            batch_data: list of samples whose element 0 is assumed to be an
                (H, W, C) numpy image and element 1 a list of integer label indices.

        Returns:
            feat: float32 tensor of shape (N, C, H, W).
            label: int32 1-D tensor with all samples' labels concatenated.
            label_len: int32 tensor of shape (N,) with each sample's label length.
        """
        feat_list = []
        label_list = []
        label_len_list = []
        for sample in batch_data:
            # HWC -> CHW, the layout torch conv layers expect.
            feat_list.append(torch.from_numpy(sample[0].transpose((2, 0, 1))))
            label_list.extend(sample[1])
            label_len_list.append(len(sample[1]))

        feat = torch.stack(feat_list).float()
        # Build integer tensors directly. torch.Tensor(...).int() would round-trip
        # through float32 and silently lose precision for large integers.
        label = torch.tensor(label_list, dtype=torch.int32)
        label_len = torch.tensor(label_len_list, dtype=torch.int32)

        return feat, label, label_len
        
class MyTrain():
    """Training driver for the CRNN + CTC digit-sequence recognizer.

    ``__init__`` only stores hyper-parameters. ``train()`` then:
    - builds the dataset from ``<data_dir>/train/*.png`` and ``<data_dir>/train.json``;
    - trains ``MyModel`` with ``nn.CTCLoss`` and Adam;
    - saves the weights of the epoch with the best mean training-batch accuracy into ``out_dir``.
    """

    def __init__(self, max_epoch=1, batch_size=20, lr=0.001, random_seed=None, grad_threshold=10,
                 T=30, img_width=200, out_dir='./', dropout=0, data_dir='../input'):
        """
        Args:
            max_epoch: upper bound on the number of training epochs.
            batch_size: samples per batch.
            lr: base learning rate for Adam (decayed per epoch inside ``train``).
            random_seed: if not None, RNGs are seeded via ``fix_seed`` before training.
            grad_threshold: maximum total gradient norm passed to ``clip_grad_norm_``.
            T: number of CTC time steps produced by the model.
            img_width: width images are resized to; the height is half of it.
            out_dir: directory the best checkpoint is written to.
            dropout: dropout probability passed to ``MyModel``.
            data_dir: directory holding ``train/*.png`` and ``train.json``.
                The default keeps the previously hard-coded ``../input``.
        """
        self.max_epoch = max_epoch
        self.batch_size = batch_size
        self.lr = lr
        self.random_seed = random_seed
        self.grad_threshold = grad_threshold
        self.T = T
        self.img_width = img_width
        self.img_height = int(0.5 * img_width)
        self.out_dir = out_dir
        self.dropout = dropout
        self.data_dir = data_dir

    def train(self):
        """Train until mean training accuracy reaches 0.90 or ``max_epoch`` epochs have run.

        Raises:
            FileNotFoundError: if no ``*.png`` images are found under ``<data_dir>/train``.
        """
        import json
        import glob
        from tqdm import tqdm

        max_epoch, batch_size, grad_threshold = self.max_epoch, self.batch_size, self.grad_threshold

        # Fix random seeds for reproducibility.
        if self.random_seed is not None:
            self.fix_seed()

        # Create the dataset; sorting makes the sample order independent of the filesystem.
        train_dir = os.path.join(self.data_dir, 'train')
        img_path = sorted(glob.glob(os.path.join(train_dir, '*.png')))
        if not img_path:
            raise FileNotFoundError(f'no training images (*.png) found in {train_dir}')
        with open(os.path.join(self.data_dir, 'train.json')) as label_file:
            label_content = json.load(label_file)
        my_dataset = MyDataset(img_path, label_content, self.img_width, self.img_height)

        # Create the model and loss.
        crnn_model = MyModel(self.T, self.img_width, self.dropout)
        ctc_loss = nn.CTCLoss()
        if torch.cuda.is_available():
            crnn_model.cuda()

        # Create the optimizer.
        adam_optimizer = torch.optim.Adam(crnn_model.parameters(), lr=self.lr, weight_decay=0.0001)
        batch_device = next(iter(crnn_model.parameters())).device
        print('train device:', batch_device)

        acc_max = 0
        epoch_index = 0

        # Stop early once the mean training accuracy of an epoch reaches 90%.
        while acc_max < 0.90 and epoch_index < max_epoch:
            # Learning-rate schedule: decay by 0.8 per epoch, resetting to self.lr every
            # 10 epochs. It is set before the first batch so the whole epoch uses one LR.
            # NOTE(review): confirm the every-10-epochs reset (`% 10`) is intended.
            epoch_lr = self.lr * (0.8 ** (epoch_index % 10))
            for param_group in adam_optimizer.param_groups:
                param_group['lr'] = epoch_lr

            loss_list = []
            acc_list = []
            batch_acc_mean = 0.0

            my_dataloader = MyDataLoader(my_dataset, batch_size=batch_size, shuffle=True).data_loader
            my_dataloader = tqdm(my_dataloader, ncols=120)

            # Train on each batch.
            for batch_index, (batch_feat, batch_label, batch_label_len) in enumerate(my_dataloader):
                if torch.cuda.is_available():
                    batch_feat = batch_feat.cuda()
                    batch_label = batch_label.cuda()
                    batch_label_len = batch_label_len.cuda()

                adam_optimizer.zero_grad()

                batch_y_hat = crnn_model(batch_feat)
                batch_loss, batch_log_probs = self._ctc_loss_for_batch(
                    ctc_loss, batch_y_hat, batch_label, batch_label_len)
                batch_loss.backward()

                # Clip gradients to avoid exploding gradients.
                nn.utils.clip_grad_norm_(crnn_model.parameters(), grad_threshold)

                # Update the model parameters.
                adam_optimizer.step()

                # Update the tqdm progress-bar description.
                batch_acc = self.get_accurancy(batch_log_probs.detach(), batch_label, batch_label_len)
                # Store a Python float rather than the loss tensor: keeping the tensor
                # would keep every batch's autograd graph alive until the epoch ends.
                loss_list.append(batch_loss.item())
                acc_list.append(batch_acc)
                batch_lr = adam_optimizer.param_groups[0]['lr']
                batch_loss_mean = sum(loss_list) / len(loss_list)
                batch_acc_mean = round(sum(acc_list) / len(acc_list), 3)

                my_dataloader.set_description(f'epoch{epoch_index}| batch{batch_index}| loss: {round(batch_loss.item(),3)}| lr: {round(batch_lr,4)}| loss_mean: {round(batch_loss_mean,3)}| batch_acc: {batch_acc_mean}')

            # Save this epoch's weights if its mean accuracy beats all previous epochs.
            # NOTE(review): this is accuracy on the shuffled training set, not a held-out split.
            if batch_acc_mean > acc_max:
                acc_max = batch_acc_mean
                torch.save(crnn_model.state_dict(), os.path.join(self.out_dir, f'crnn_resnet18_ctc_dropout{self.dropout}_best'))

            epoch_index += 1

    def _ctc_loss_for_batch(self, ctc_loss, y_hat, label, label_len):
        """Compute the CTC loss for one batch of raw model scores.

        Args:
            ctc_loss: an ``nn.CTCLoss`` instance.
            y_hat: raw (un-normalized) scores of shape (T, N, C).
            label: 1-D tensor of concatenated target indices for the batch.
            label_len: per-sample target lengths, shape (N,).

        Returns:
            ``(loss, log_probs)``: ``log_probs`` is the float64 log-softmax of
            ``y_hat`` over the class dimension.
        """
        # float64: the original author observed CTCLoss errors on GPU with float32
        # inputs and found that float64 avoids them.
        # nn.CTCLoss expects log-probabilities, so normalize the raw scores here.
        log_probs = y_hat.to(torch.float64).log_softmax(dim=2)
        # Every sample uses all T steps. Size by the actual batch dimension, not
        # self.batch_size: the last batch of an epoch can be smaller.
        seq_len, n_samples = log_probs.shape[0], log_probs.shape[1]
        input_lengths = torch.full((n_samples,), seq_len, dtype=torch.long)
        return ctc_loss(log_probs, label, input_lengths, label_len), log_probs

    def fix_seed(self):
        """Seed Python, NumPy and torch (CPU and all CUDA devices) with ``self.random_seed``.

        Also requests deterministic cuDNN behaviour.
        """
        import random

        random.seed(self.random_seed)
        np.random.seed(self.random_seed)
        torch.random.manual_seed(self.random_seed)
        torch.cuda.manual_seed_all(self.random_seed)
        torch.backends.cudnn.deterministic = True
        # cuDNN autotuning may select different kernels from run to run.
        torch.backends.cudnn.benchmark = False

    def get_accurancy(self, y_hat, label, label_len):
        """Fraction of samples whose greedy CTC decoding exactly matches the target.

        Args:
            y_hat: per-step class scores of shape (T, N, C); class 0 is the blank.
            label: 1-D tensor of concatenated target indices (labels shifted by +1).
            label_len: tensor of shape (N,) giving each sample's target length.

        Returns:
            Accuracy in [0, 1] as a float.
        """
        # (T, N, C) -> (N, T, C), then the best class per step -> pred (N, T).
        pred = torch.argmax(y_hat.detach().permute(1, 0, 2).cpu(), dim=2)
        label = label.cpu()
        label_len = label_len.cpu()
        batch_size = pred.shape[0]
        acc = 0
        label_start_index = 0
        for i in range(batch_size):
            # Greedy CTC decode: collapse consecutive repeats, drop blanks (0),
            # and undo the +1 label shift.
            raw_pred_list = pred[i].tolist()
            pred_data = [
                p - 1 for j, p in enumerate(raw_pred_list)
                if p != 0 and (j == 0 or p != raw_pred_list[j - 1])
            ]

            # Targets are concatenated; walk them with a running offset.
            label_end_index = label_start_index + int(label_len[i].item())
            label_data = [x - 1 for x in label[label_start_index:label_end_index].int().tolist()]
            label_start_index = label_end_index

            if pred_data == label_data:
                acc += 1
        acc /= batch_size
        return acc

    def model_save(self, model):
        """Placeholder; not implemented."""
        pass

    def model_load(self):
        """Placeholder; not implemented."""
        pass
        

In [None]:
if __name__ == '__main__':
    # Dropout sweep: each rate trains a fresh model and saves its own best checkpoint.
    for dropout in [0.7]:
        print('Dropout:', dropout)
        trainer = MyTrain(max_epoch=50, T=30, img_width=200, dropout=dropout)
        trainer.train()