In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch import nn, optim
import torch.nn.functional as F



In [None]:
import numpy as np
bs = 128  # mini-batch size passed to both DataLoaders below


# NOTE(review): absolute, machine-specific path — consider a configurable data dir.
# np.load on an .npz archive returns a lazy NpzFile that keeps the file open;
# the context manager closes it as soon as the four arrays have been read.
with np.load("/mnt/c/Users/cck20/pytorch/data/dataset.npz") as data:
    x_train = data["x_train"]
    y_train = data["y_train"]
    x_valid = data["x_valid"]
    y_valid = data["y_valid"]
print(x_train.shape)


In [None]:
# Move axis 3 to position 1 so the channel axis comes second, which is where
# nn.Conv2d(3, ...) expects it (presumably NHWC -> NCHW; TODO confirm the
# on-disk layout), then convert the images to float32 tensors.
x_train = torch.tensor(x_train.transpose(0, 3, 1, 2), dtype=torch.float32)
x_valid = torch.tensor(x_valid.transpose(0, 3, 1, 2), dtype=torch.float32)

# Labels are converted to int64 tensors.
y_train = torch.tensor(y_train, dtype=torch.int64)
y_valid = torch.tensor(y_valid, dtype=torch.int64)

# Pair images with labels and batch them; both loaders reshuffle every pass.
train_ds = TensorDataset(x_train, y_train)
valid_ds = TensorDataset(x_valid, y_valid)
train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_loader = DataLoader(valid_ds, batch_size=bs, shuffle=True)

In [None]:
lr = 0.001  # learning rate handed to optim.Adam inside train_loop
epochs = 200  # number of full passes over train_loader made by train_loop
from tqdm import tqdm
class CNN_captcha_model(nn.Module):
    """Small CNN with four parallel classifier heads, one per captcha character.

    The shared trunk is three 3x3 convolutions (padding 2, so each one grows
    the spatial size by 2) with 2x2 max-pooling after the second and third.
    The flattened trunk output feeds four independent two-layer heads.

    Args:
        num_classes: number of logits produced by each head (default 62).
        flat_features: length of the flattened trunk output per sample
            (default 12672 = 64 channels * 198 positions; a 36x64 input is one
            size that produces this — TODO confirm against the dataset).
        hidden_features: width of each head's hidden layer (default 1000).

    The defaults reproduce the original hard-coded architecture, and layers are
    created in the original order so parameter names and ordering are unchanged.
    """

    def __init__(self, num_classes=62, flat_features=12672, hidden_features=1000):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=2)
        # Head k uses fc{k}0 (hidden layer) followed by fc{k}1 (logits).
        self.fc10 = nn.Linear(flat_features, hidden_features)
        self.fc11 = nn.Linear(hidden_features, num_classes)
        self.fc20 = nn.Linear(flat_features, hidden_features)
        self.fc21 = nn.Linear(hidden_features, num_classes)
        self.fc30 = nn.Linear(flat_features, hidden_features)
        self.fc31 = nn.Linear(hidden_features, num_classes)
        self.fc40 = nn.Linear(flat_features, hidden_features)
        self.fc41 = nn.Linear(hidden_features, num_classes)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.25)

    def _head(self, features, hidden, out):
        """Run one classifier head: hidden layer -> ReLU -> dropout2 -> logits."""
        return out(self.dropout2(F.relu(hidden(features))))

    def forward(self, xb):
        """Return a 4-tuple of logit tensors, one per character position.

        Args:
            xb: image batch with 3 channels on dim 1.

        Returns:
            Tuple ``(c1, c2, c3, c4)``; each entry has ``num_classes`` logits
            per sample.
        """
        xb = F.relu(self.conv1(xb))
        xb = F.max_pool2d(F.relu(self.conv2(xb)), kernel_size=2, stride=2)
        xb = F.max_pool2d(F.relu(self.conv3(xb)), kernel_size=2, stride=2)

        xb = self.dropout1(xb.flatten(start_dim=1))

        # Heads are evaluated in order 1..4, so dropout2 is called in the same
        # sequence as before.
        return (
            self._head(xb, self.fc10, self.fc11),
            self._head(xb, self.fc20, self.fc21),
            self._head(xb, self.fc30, self.fc31),
            self._head(xb, self.fc40, self.fc41),
        )
    
# The model returns four separate logit tensors (one per character position,
# 62 classes each by default). nn.CrossEntropyLoss takes integer class indices
# as targets, so each label column is passed as-is — no one-hot encoding needed.
loss_func = nn.CrossEntropyLoss()


def _captcha_loss(preds, targets):
    """Sum loss_func over the character positions.

    Args:
        preds: sequence of logit tensors, one per character position.
        targets: integer label tensor; column ``i`` holds the class index for
            position ``i``.

    Returns:
        Scalar tensor: the per-position losses added in position order.
    """
    total = None
    for position, logits in enumerate(preds):
        term = loss_func(logits, targets[:, position])
        total = term if total is None else total + term
    return total


@torch.no_grad()
def _validation_loss(model):
    """Return the mean total loss per sample over valid_loader.

    Each batch is weighted by its size so a smaller final batch does not skew
    the average. Returns NaN when the loader yields no samples.
    """
    loss_sum = 0.0
    n_samples = 0
    for xv, yv in valid_loader:
        batch_size = xv.shape[0]
        loss_sum += _captcha_loss(model(xv), yv).item() * batch_size
        n_samples += batch_size
    return loss_sum / n_samples if n_samples else float("nan")


def train_loop(model):
    """Train ``model`` in place with Adam for ``epochs`` passes over train_loader.

    Every 100 optimiser steps the model is switched to eval mode, the mean
    validation loss is printed, and training mode is restored.

    Args:
        model: module returning one logit tensor per character position.
    """
    opt = optim.Adam(model.parameters(), lr=lr)
    step = 0
    model.train()
    for _ in tqdm(range(epochs)):
        for xb, yb in train_loader:
            step += 1
            loss = _captcha_loss(model(xb), yb)
            loss.backward()

            opt.step()
            opt.zero_grad()
            if step % 100 == 0:
                model.eval()
                try:
                    print(f"{_validation_loss(model)}")
                finally:
                    # Restore training mode even if validation raises.
                    model.train()


# Build the network with its default constructor arguments; the cells below
# train it and then use it for prediction.
model = CNN_captcha_model()




In [None]:
# Train `model` in place; the validation loss is printed every 100 optimiser steps.
train_loop(model)

In [None]:
# show the picture and the prediction
import string
import matplotlib.pyplot as plt
def decode(char_pred: torch.Tensor):
    """Turn one head's score vector into the character it predicts.

    The lookup table is ``string.digits + string.ascii_letters`` (0-9, a-z,
    A-Z); the position of the highest score selects the character.
    """
    alphabet = string.digits + string.ascii_letters
    best = int(char_pred.argmax())
    return alphabet[best]
    

# train_loop leaves the model in training mode, so switch to eval mode here:
# otherwise dropout stays active and the predictions are randomly perturbed.
# no_grad skips autograd bookkeeping, which inference does not need.
model.eval()
with torch.no_grad():
    for xv, yv in valid_loader:
        c1v, c2v, c3v, c4v = model(xv)
        for i in range(xv.shape[0]):
            # One decoded character per head, in position order.
            text = "".join(decode(head[i]) for head in (c1v, c2v, c3v, c4v))
            print(text)
            # Channels-first -> channels-last, as matplotlib's imshow expects.
            img = xv[i].permute(1, 2, 0)
            img = img.clamp(0, 255).numpy().astype("uint8")
            plt.imshow(img)
            plt.axis("off")
            plt.show()
        
    


    

In [None]:
# Scratch check of Tensor.transpose: swapping dims 0 and 2 reverses the shape
# of a (3, 2, 1) tensor.
y = torch.tensor([[[1], [2]]] * 3)
print(y.shape)
y = y.transpose(2, 0)
print(y.shape)