In [1]:
import torch
import numpy as np
from torch import nn
import matplotlib.pyplot as plt
from utils import CaptchaDataset2, plot_sample, decode_from_output, accuracy, total_chars

import torch.nn.functional as F
from itertools import groupby

In [2]:
batch_size = 64
lr = 0.0001
epochs = 1000

train_dataset = CaptchaDataset2('../data/train')
val_dataset = CaptchaDataset2('../data/test')
test_dataset = CaptchaDataset2('../data/original')


train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
test_dataset.samples

array([['1', 'ZD5XL9OO'],
       ['2', '4067BWO4'],
       ['3', 'CU2XCT04'],
       ['4', 'WXNFLEBI'],
       ['5', 'KVUQBN8L'],
       ['6', '91J1R3KV'],
       ['7', '3VS0CIZI'],
       ['8', '1FPDJE4Z'],
       ['9', 'JXZ9TRIO'],
       ['10', '5EPPVU4U'],
       ['12', 'NUKODBH2'],
       ['14', '4JY4YW8F'],
       ['15', 'HAVEAP6I'],
       ['17', 'QSTPBWV4'],
       ['18', 'JA8J6HQM'],
       ['19', '7G6AEI1Q'],
       ['20', 'HQZTYUIS'],
       ['21', 'Y86GDHW1'],
       ['22', 'N7HLZDE'],
       ['23', 'VY84SU4I'],
       ['24', 'WEP8QDON'],
       ['25', '9KRGO672'],
       ['27', '9K6DMKEL'],
       ['28', 'L9XA5D2E'],
       ['29', 'DWOEHHVL'],
       ['30', 'DJC4KPXU'],
       ['31', 'I3ZUK1H8'],
       ['33', 'TNFI8GRS'],
       ['34', 'KWEAYPEX'],
       ['35', '5E9B1ZFL'],
       ['36', 'FCUC5M82'],
       ['37', 'EVPNS17C'],
       ['38', 'VXLNVAHI'],
       ['39', 'OUA49VIJ'],
       ['41', 'L6BV8HVP'],
       ['42', 'U9103PGP'],
       ['43', 'U7M18968'],
       ['44

In [5]:
for a,b in test_loader:
    print(b)

tensor([[25.,  3., 31., 23., 11., 35., 14., 14.],
        [30., 26., 32., 33.,  1., 22., 14., 30.],
        [ 2., 20., 28., 23.,  2., 19., 26., 30.],
        [22., 23., 13.,  5., 11.,  4.,  1.,  8.],
        [10., 21., 20., 16.,  1., 13., 34., 11.],
        [35., 27.,  9., 27., 17., 29., 10., 21.],
        [29., 21., 18., 26.,  2.,  8., 25.,  8.],
        [27.,  5., 15.,  3.,  9.,  4., 30., 25.],
        [ 9., 23., 25., 35., 19., 17.,  8., 14.],
        [31.,  4., 15., 15., 21., 20., 30., 20.],
        [13., 20., 10., 14.,  3.,  1.,  7., 28.],
        [30.,  9., 24., 30., 24., 22., 34.,  5.],
        [ 7.,  0., 21.,  4.,  0., 15., 32.,  8.],
        [16., 18., 19., 15.,  1., 22., 21., 30.],
        [ 9.,  0., 34.,  9., 32.,  7., 16., 12.],
        [33.,  6., 32.,  0.,  4.,  8., 27., 16.],
        [ 7., 16., 25., 19., 24., 20.,  8., 18.],
        [24., 34., 32.,  6.,  3.,  7., 22., 27.],
        [13., 33.,  7., 11., 25.,  3.,  4., 36.],
        [21., 24., 34., 30., 18., 20., 30.,  8.],


In [None]:
len(test_loader)

In [None]:
a, b= test_dataset[11]

In [None]:
plot_sample(a)
b

In [None]:
BLANK_LABEL = total_chars

In [None]:
a.shape

In [None]:
total_chars

In [None]:
class StackedLSTM(nn.Module):
    def __init__(self, input_size=30, output_size=total_chars+1, hidden_size=512, num_layers=3):
        super(StackedLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(hidden_size, output_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        
    def forward(self, inputs, hidden):
        batch_size, seq_len, input_size = inputs.shape
        outputs, hidden = self.lstm(inputs, hidden)
        outputs = self.dropout(outputs)
        outputs = torch.stack([self.fc(outputs[i]) for i in range(width)])
        outputs = F.log_softmax(outputs, dim=2)
        return outputs, hidden
    
    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data 
        return (weight.new(self.num_layers, batch_size, self.hidden_size).zero_(),
                weight.new(self.num_layers, batch_size, self.hidden_size).zero_())
    
net = StackedLSTM().to(device)

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
criterion = nn.CTCLoss(blank=BLANK_LABEL)

In [None]:
  # set network to training phase
    
batch_size = batch_size
# for each pass of the training dataset
for epoch in range(1, epochs+1):
    
    net.train()
    
    train_loss, train_correct, train_total = 0, 0, 0
    
    h = net.init_hidden(batch_size)
    
    # for each batch of training examples
    for batch_index, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        h = tuple([each.data for each in h])
        
        batch_size, channels, height, width = inputs.shape
        print(batch_size, channels, height, width)
        
        # reshape inputs: NxCxHxW -> WxNx(HxC)
        inputs = (inputs
                  .permute(3, 0, 2, 1)
                  .contiguous()
                  .view((width, batch_size, -1)))
                
        optimizer.zero_grad()  # zero the parameter gradients
        outputs, h = net(inputs, h)  # forward pass
        
        #print(outputs.shape)

        # compare output with ground truth
        input_lengths = torch.IntTensor(batch_size).fill_(width)
        target_lengths = torch.IntTensor([len(t) for t in targets])
        print(outputs.shape, targets.shape, input_lengths.shape, target_lengths.shape)
        loss = criterion(outputs, targets, input_lengths, target_lengths)

        loss.backward()  # backpropagation
        #nn.utils.clip_grad_norm_(net.parameters(), 10)  # clip gradients
        optimizer.step()  # update network weights
        
        # record statistics
        prob, max_index = torch.max(outputs, dim=2)
        train_loss += loss.item()
        train_total += len(targets)

        for i in range(batch_size):
            raw_pred = list(max_index[:, i].cpu().numpy())
            #print(len(raw_pred))
            pred = [c for c, _ in groupby(raw_pred) if c != BLANK_LABEL]
            target = list(targets[i].cpu().numpy())
            if pred == target:
                train_correct += 1

        # print statistics every 10 batches
        if (batch_index + 1) % 200 == 0:
            print(f'Epoch {epoch }/{epochs}, ' +
                  f'Batch {batch_index + 1}/{len(train_loader)}, ' +
                  f'Train Loss: {(train_loss/1):.5f}, ' +
                  f'Train Accuracy: {(train_correct/train_total):.5f}')
            
            train_loss, train_correct, train_total = 0, 0, 0
            
    
    # validation
    net.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    
    # for each batch of training examples
    for batch_index, (inputs, targets) in enumerate(val_loader):
        inputs = inputs.to(device)
        batch_size, channels, height, width = inputs.shape
        h = net.init_hidden(batch_size)
        h = tuple([each.data for each in h])
        
        
        
        # reshape inputs: NxCxHxW -> WxNx(HxC)
        inputs = (inputs
                  .permute(3, 0, 2, 1)
                  .contiguous()
                  .view((width, batch_size, -1)))
                
        outputs, h = net(inputs, h)  # forward pass
        input_lengths = torch.IntTensor(batch_size).fill_(width)
        target_lengths = torch.IntTensor([len(t) for t in targets])
        
        loss = criterion(outputs, targets, input_lengths, target_lengths)
        
        # record statistics
        prob, max_index = torch.max(outputs, dim=2)
        val_loss += loss.item()
        val_total += len(targets)

        for i in range(batch_size):
            raw_pred = list(max_index[:, i].cpu().numpy())
            #print(len(raw_pred))
            pred = [c for c, _ in groupby(raw_pred) if c != BLANK_LABEL]
            target = list(targets[i].cpu().numpy())
            if pred == target:
                val_correct += 1

        
    print(f'Epoch {epoch }/{epochs}, ' +
          f'Val Loss: {(val_loss/1):.5f}, ' +
          f'Val Accuracy: {(val_correct/val_total):.5f}')

    val_loss, val_correct, val_total = 0, 0, 0
    
    # test
    net.eval()
    test_loss, test_correct, test_total = 0, 0, 0
    
    # for each batch of training examples
    for batch_index, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.to(device)
        
        
        batch_size, channels, height, width = inputs.shape
        h = net.init_hidden(batch_size)
        
        h = tuple([each.data for each in h])
        # reshape inputs: NxCxHxW -> WxNx(HxC)
        inputs = (inputs
                  .permute(3, 0, 2, 1)
                  .contiguous()
                  .view((width, batch_size, -1)))
                
        outputs, h = net(inputs, h)  # forward pass
        

        # compare output with ground truth
        input_lengths = torch.IntTensor(batch_size).fill_(width)
        target_lengths = torch.IntTensor([len(t) for t in targets])
        
        loss = criterion(outputs, targets, input_lengths, target_lengths)

        
        # record statistics
        prob, max_index = torch.max(outputs, dim=2)
        test_loss += loss.item()
        test_total += len(targets)

        for i in range(batch_size):
            raw_pred = list(max_index[:, i].cpu().numpy())
            #print(len(raw_pred))
            pred = [c for c, _ in groupby(raw_pred) if c != BLANK_LABEL]
            target = list(targets[i].cpu().numpy())
            if pred == target:
                test_correct += 1

        
        print(f'Epoch {epoch }/{epochs}, ' +
              f'Test Loss: {(test_loss/1):.5f}, ' +
              f'Test Accuracy: {(test_correct/test_total):.5f}')

        test_loss, test_correct, test_total = 0, 0, 0
    

In [None]:
inputs = inputs.to(device)

batch_size, channels, height, width = inputs.shape
h = net.init_hidden(batch_size)

inputs = (inputs
          .permute(3, 0, 2, 1)
          .contiguous()
          .view((width, batch_size, -1)))

# get prediction
outputs, h = net(inputs, h)  # forward pass
prob, max_index = torch.max(outputs, dim=2)
raw_pred = list(max_index[:, i].cpu().numpy())

# print raw prediction with BLANK_LABEL replaced with "-"
print('Raw Prediction: ' + ''.join([str(c) if c != BLANK_LABEL else '-' for c in raw_pred]))

pred = [str(c) for c, _ in groupby(raw_pred) if c != BLANK_LABEL]
print(f"Prediction: {''.join(pred)}")

In [None]:
new_sample, new_label = test_dataset[1]

In [None]:
plot_sample(new_sample)

In [None]:
new_sample = new_sample.unsqueeze(0)

new_sample = (new_sample
          .permute(3, 0, 2, 1)
          .contiguous()
          .view((140, 1, -1))).to(device)

new_label = new_label.to(device)

In [None]:
h = net.init_hidden(1)
outputs, h = net(new_sample, h)  # forward pass

In [None]:
prob, max_index = torch.max(outputs, dim=2)
raw_pred = list(max_index[:, 0].cpu().numpy())

In [None]:
print('Raw Prediction: ' + ''.join([str(c) if c != BLANK_LABEL else '-' for c in raw_pred]))

pred = [str(c) for c, _ in groupby(raw_pred) if c != BLANK_LABEL]
print(f"Prediction: {''.join(pred)}")

In [None]:
pred

In [None]:
new_label

In [None]:
import string

all_chars = string.ascii_uppercase + '0123456789'
total_chars = len(all_chars)
captcha_length = 8

encoding_dict = {l:e for e,l in enumerate(all_chars)}
decoding_dict = {e:l for l,e in encoding_dict.items()}

In [None]:
aa = ''
for z in pred:
    aa += decoding_dict[int(z)] 

In [None]:
aa

In [None]:
#torch.save(net.state_dict(), 'models/v0.pt')

In [None]:
len(val_dataset)

In [None]:
len(train_dataset)/64