In [0]:
# Mount Google Drive into the Colab runtime and unzip the training data.
# NOTE(review): the relative 'drive/My Drive/...' path below assumes the
# working directory is /content (the parent of the mount point) — confirm.

from google.colab import drive
drive.mount('/content/drive')

# Extract the archive into the current working directory.
# Presumably it contains the 'out' image folder read by later cells — verify.
!unzip 'drive/My Drive/out.zip'

In [0]:
# Import libraries

import os
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms, utils
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

In [0]:
# Initialize hyperparameters and constants

# Directory containing the training images (passed to OCRDataset).
dir_name = 'out'
# Every image is resized to (height, width) before being fed to the CNN.
height = 32
width = 100
# Fraction of the dataset held out for validation.
valid_ratio = 0.02

num_rnn_inp = 512 # features per time step fed to the first LSTM (equals the CNN's final channel count)
num_rnn_hid = 256 # hidden size of each LSTM direction

num_epochs = 5
# Mini-batch size used by both the training and the validation DataLoader.
batch_num = 64
# Learning rate passed to the Adam optimizer.
learning_rate = 0.00005

# Reporting / checkpoint intervals, all measured in mini-batches.
train_loss_step = 20  # print the averaged training loss
val_acc_step = 100  # evaluate word-level accuracy on the validation split
display_label_step = 100  # print one raw prediction next to its ground-truth label
save_mode_step = 1000  # save the model to Drive

# Character used to display the CTC blank (class index 0).
blank_label = '~'
# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [0]:
# Build the vocabulary from the labels of every image in the dataset directory.
# File names are expected to look like '<label>_<anything>.<ext>'; the label is
# the part before the first underscore.
# Use the configured `dir_name` (instead of a second hard-coded 'out') and
# os.path.splitext (instead of slicing off a fixed 4 characters) so the cell
# keeps working if the directory or the file extension changes.
img_names = [os.path.splitext(file_name)[0] for file_name in os.listdir(dir_name)]
labels = [img_name.split('_')[0] for img_name in img_names]
all_chars_combined = ''.join(labels)
# Sorted list of the distinct characters; the order defines the class indices.
vocab = sorted(set(all_chars_combined))

In [0]:
# Number of output classes: one per vocabulary character plus one for the
# CTC blank, which occupies index 0.
num_rnn_out = len(vocab) + 1

In [0]:
# CTC loss reserves index 0 for the blank symbol by default, so the real
# characters are numbered starting from 1.

idx2char = dict(enumerate(vocab, start=1))
char2idx = {symbol: code for code, symbol in idx2char.items()}
# Register the blank last, and only in the decoding direction.
idx2char[0] = blank_label


In [0]:
# Define the PyTorch Dataset

class OCRDataset(Dataset):
    """Dataset of word images whose file names encode their labels.

    Every entry of ``img_dir`` is treated as an image named
    ``<label>_<anything>``; the label is the text before the first underscore.

    Args:
        img_dir: Directory containing the image files.
        transform: Optional callable applied to each loaded PIL image.

    Attributes:
        img_names: Sorted list of full image paths.
        labels: Label strings, aligned index-for-index with ``img_names``.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform

        # Sort for a deterministic sample order regardless of the file system.
        file_names = sorted(os.listdir(img_dir))
        self.labels = [file_name.split('_')[0] for file_name in file_names]
        self.img_names = [os.path.join(img_dir, file_name) for file_name in file_names]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as an ``(image, label)`` tuple."""
        # Image.open is lazy and keeps the file handle open; load inside a
        # `with` block so the handle is released promptly.
        with Image.open(self.img_names[idx]) as img:
            # Force 3 channels: grayscale / palette / RGBA files would otherwise
            # break the 3-channel Normalize and the first convolution.
            img = img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return (img, self.labels[idx])
    
    

In [0]:
# Preprocessing applied to every image: resize to the fixed network input
# size, convert to a tensor, then normalize each of the three channels with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(height, width)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [0]:
# Encode label strings into CTC target indices, and decode model output back into strings

def encodeString(label):
    """Translate a string into the list of its class indices.

    Each character is looked up in the module-level ``char2idx`` table; a
    character missing from the table raises ``KeyError``.
    """
    codes = []
    for symbol in label:
        codes.append(char2idx[symbol])
    return codes

def encodeLabel(label_batch):
    """Encode a batch of label strings in the flat format accepted by CTC loss.

    Returns:
        A pair of ``torch.IntTensor``: the class indices of all labels
        concatenated into one 1-D tensor, and the length of each label.
    """
    words = list(label_batch)
    flat_targets = encodeString(''.join(words))
    word_lengths = [len(word) for word in words]
    return torch.IntTensor(flat_targets), torch.IntTensor(word_lengths)


def decodeLable(pred):
    """Greedy CTC decoding of a batch of network outputs.

    ``pred`` is indexed as (time step, batch item, class). For every batch
    item the most likely class per time step is taken, runs of repeated
    characters are collapsed to a single character, blanks are dropped, and
    the result is lower-cased.

    Returns:
        A list with one decoded string per batch item.
    """
    # argmax over classes, then put the batch dimension first.
    best_paths = pred.argmax(dim=2).permute(1, 0).tolist()

    decoded = []
    for path in best_paths:
        kept = []
        previous = None
        for class_idx in path:
            symbol = idx2char[class_idx]
            # Keep only the first symbol of each run, and never the blank.
            if symbol != previous and symbol != blank_label:
                kept.append(symbol)
            previous = symbol
        decoded.append(''.join(kept).lower())
    return decoded


def getAccuracy(actuals, predicted):
    """Count how many predictions match their ground-truth label.

    The ground truth is lower-cased before comparison; the prediction is
    compared as given. Despite the name, the return value is a count of
    matches, not a ratio.
    """
    return sum(1 for truth, guess in zip(actuals, predicted) if truth.lower() == guess)



In [36]:
# Despite its name, `data_loader` is the full Dataset, not a DataLoader.
data_loader = OCRDataset(dir_name, transform)

# Derive the training size from the validation size so the two always add up
# to len(data_loader). Truncating both products independently can lose a
# sample to float rounding, which makes random_split raise a ValueError.
num_valid = int(len(data_loader) * valid_ratio)
num_train = len(data_loader) - num_valid
train, valid = random_split(data_loader, [num_train, num_valid])

print(len(train))
print(len(valid))

196000
4000


In [0]:
# Define CRNN model

class CNNFeatureGenerator(nn.Module):
    """Convolutional front end that turns an image batch into a feature sequence.

    Seven convolutions (two of them followed by batch norm) are interleaved
    with four max-pooling layers. The last two poolings use stride 1 along
    the width, so they shrink the height only. The forward pass returns the
    features ordered as (width position, batch, channels), the layout the
    recurrent layers consume.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))
        self.conv5 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(512)
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))
        self.conv7 = nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0)

    def forward(self, x):
        """Map images of shape (batch, 3, H, W) to a (W', batch, 512) sequence."""
        feat = self.pool1(F.relu(self.conv1(x)))
        feat = self.pool2(F.relu(self.conv2(feat)))

        feat = F.relu(self.conv3(feat))
        feat = F.relu(self.conv4(feat))
        feat = self.pool3(feat)

        feat = F.relu(self.bn1(self.conv5(feat)))
        feat = F.relu(self.bn2(self.conv6(feat)))

        feat = F.relu(self.conv7(feat))
        feat = self.pool4(feat)

        # Drop the height axis (size 1 for 32-pixel-high inputs), then move
        # the width axis to the front so it acts as the time dimension.
        sequence = feat.squeeze(2).permute(2, 0, 1)
        return sequence


class BidirectionalLSTM(nn.Module):
    """A bidirectional LSTM followed by a per-time-step linear projection.

    Args:
        inp_size: Number of input features per time step.
        hid_size: Hidden size of each LSTM direction.
        out_size: Number of output features per time step.
    """

    def __init__(self, inp_size, hid_size, out_size):
        super().__init__()
        self.inp2hid = nn.LSTM(inp_size, hid_size, bidirectional=True)
        # The two directions are concatenated, hence 2 * hid_size inputs.
        self.hid2out = nn.Linear(2 * hid_size, out_size)

    def forward(self, x):
        """Map a (seq, batch, inp_size) tensor to (seq, batch, out_size)."""
        recurrent, _state = self.inp2hid(x)
        steps, batch, feats = recurrent.size()
        # Fold time and batch together so the linear layer sees a 2-D input.
        flat = recurrent.view(steps * batch, feats)
        projected = self.hid2out(flat)
        return projected.view(steps, batch, -1)

class CRNN(nn.Module):
    """CNN feature extractor followed by two stacked bidirectional LSTM blocks.

    Args:
        rnn_inp_size: Feature size the CNN hands to the first recurrent block.
        rnn_hid_size: Hidden size of both recurrent blocks; also the output
            size of the first block.
        rnn_out_size: Number of classes produced by the second block.
    """

    def __init__(self, rnn_inp_size, rnn_hid_size, rnn_out_size):
        super().__init__()
        self.cnn = CNNFeatureGenerator()
        self.rnn1 = BidirectionalLSTM(rnn_inp_size, rnn_hid_size, rnn_hid_size)
        self.rnn2 = BidirectionalLSTM(rnn_hid_size, rnn_hid_size, rnn_out_size)

    def forward(self, x):
        """Return per-time-step log-probabilities over the classes (last axis)."""
        features = self.cnn(x)
        scores = self.rnn2(self.rnn1(features))
        return F.log_softmax(scores, dim=2)

In [40]:
# Display which device type was selected ('cuda' when a GPU is available, else 'cpu').
device.type

'cuda'

In [41]:
# Build the CRNN with the configured layer sizes and move its parameters to the selected device.
crnn = CRNN(num_rnn_inp, num_rnn_hid, num_rnn_out).to(device)


# When a label is infeasible for CTC (for example a zero-length label), the
# loss can become inf and the resulting gradients NaN.
# This hook keeps such a batch from corrupting the weights during training.
def backward_hook(self, grad_input, grad_output):
    """Zero out NaN entries of a module's input gradients, in place.

    Args:
        self: The module the hook is attached to (unused).
        grad_input: Tuple of gradients w.r.t. the module inputs; entries may
            be ``None`` for inputs that do not require a gradient.
        grad_output: Tuple of gradients w.r.t. the module outputs (unused).
    """
    for g in grad_input:
        # Skip missing gradients: indexing None would raise a TypeError.
        if g is None:
            continue
        # NaN is the only value that differs from itself, so this mask picks
        # exactly the NaN entries (infinite values are left untouched).
        g[g != g] = 0


# Attach the NaN-clearing hook to the top-level CRNN module; the call returns a removable handle.
crnn.register_backward_hook(backward_hook)

<torch.utils.hooks.RemovableHandle at 0x7f435a4c1668>

In [0]:
# zero_infinity=True makes CTCLoss replace infinite losses (labels that cannot
# be aligned to the input, e.g. longer than the output sequence) and their
# gradients with zero, instead of feeding inf/NaN into the optimizer.
criterion = nn.CTCLoss(zero_infinity=True).to(device)
optimizer = optim.Adam(crnn.parameters(), lr=learning_rate)

In [0]:
# Word level accuracy on validation set

def validationAccuracy(valid_data):
    """Return the word-level accuracy of the global ``crnn`` on ``valid_data``.

    The model is switched to evaluation mode and autograd is disabled while
    predicting, so validation neither updates the BatchNorm running
    statistics nor builds computation graphs. The previous training/eval
    mode is restored afterwards.

    Args:
        valid_data: Dataset yielding ``(image_tensor, label_string)`` pairs.

    Returns:
        Fraction of samples whose decoded prediction equals the label
        (compared case-insensitively); ``0.0`` for an empty dataset.
    """
    total = len(valid_data)
    if total == 0:
        # Avoid a ZeroDivisionError on an empty validation split.
        return 0.0

    valid_loader = DataLoader(valid_data, batch_size=batch_num)
    was_training = crnn.training
    crnn.eval()
    num_correct = 0
    try:
        with torch.no_grad():
            for inputs, labels in valid_loader:
                preds = crnn(inputs.to(device))
                num_correct += getAccuracy(labels, decodeLable(preds))
    finally:
        # Hand the model back in the mode the caller left it in.
        crnn.train(was_training)

    return num_correct / total

In [0]:
loss_hist = []   # averaged training loss, one entry every train_loss_step mini-batches
valid_hist = []  # validation accuracy, one entry every val_acc_step mini-batches


# Training begins here

# Create the checkpoint directory up front so the first torch.save cannot fail
# on a missing folder after a thousand iterations of training.
checkpoint_dir = 'drive/My Drive/saved_models_crnn'
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    # Make sure the model is in training mode at the start of every epoch.
    crnn.train()
    # drop_last=True skips the final, smaller batch: the code below assumes
    # every batch holds exactly batch_num samples.
    train_loader = DataLoader(train, batch_size=batch_num, shuffle=True, drop_last=True)
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs = inputs.to(device)

        # forward
        preds = crnn(inputs)
        # Every sample uses the full output sequence length.
        input_lengths = torch.IntTensor([preds.size(0)] * batch_num)
        targets, target_lengths = encodeLabel(labels)

        # Inputs expected by PyTorch's CTC loss:
        #   log_probs      - tensor of size (T, N, C): T = input length,
        #                    N = batch size, C = number of classes incl. blank
        #   targets        - tensor of size (N, S) or (sum of target lengths)
        #   input_lengths  - tuple or tensor of size N
        #   target_lengths - tuple or tensor of size N
        #
        # Example for a batch of 2 with labels ('theatticalism', 'jails'):
        #   input_lengths  = tensor([26, 26], dtype=torch.int32)
        #   target_lengths = tensor([13,  5], dtype=torch.int32)
        #   targets = tensor([30, 18, 15, 11, 30, 30, 19, 13, 11, 22, 19, 29,
        #                     23, 20, 11, 19, 22, 29], dtype=torch.int32)

        # zero the parameter gradients
        optimizer.zero_grad()

        # backward + optimize
        # NOTE(review): CTCLoss defaults to reduction='mean', which already
        # averages over the batch; the extra division by batch_num scales the
        # loss (and gradients) down further. Kept as-is because the learning
        # rate was chosen with it — confirm this is intended.
        loss = criterion(preds, targets, input_lengths, target_lengths) / batch_num
        loss.backward()
        optimizer.step()

        # print the raw (un-collapsed) prediction and the label of the first sample
        if i % display_label_step == (display_label_step - 1):
            first_sample_path = torch.argmax(preds, dim=2)[:, 0].tolist()
            out = ''.join(idx2char[class_idx] for class_idx in first_sample_path)
            print ('-------------------------------------------------------------------')
            print (out  + " ||||| " + labels[0])
            print ('-------------------------------------------------------------------')

        # print statistics
        running_loss += loss.item()
        if i % train_loss_step == (train_loss_step - 1):    # print every train_loss_step mini-batches
            print('[%d, %5d] loss: %f' %
                  (epoch + 1, i + 1, running_loss / train_loss_step))
            loss_hist.append(running_loss / train_loss_step)
            running_loss = 0.0

        # print accuracy
        if i % val_acc_step == (val_acc_step - 1):
            # Evaluate once and reuse the value; a full validation pass is expensive.
            val_acc = validationAccuracy(valid)
            print ('-------------------------------------------------------------------')
            print('validation Accuracy = ' + str(val_acc))
            valid_hist.append(val_acc)
            print ('-------------------------------------------------------------------')

        # save a checkpoint of the whole model
        if i % save_mode_step == (save_mode_step - 1):
            torch.save(crnn, checkpoint_dir + '/epoch_' + str(epoch) + ' ' + 'iter_' + str(i) + '.pt')