In [1]:
from data_reader import Vocabulary, HWDBDatasetHelper, LMDBReader

import os

# your path to data
# Root folder of the dataset. Set the HWDB_DATA_DIR environment variable to point
# the notebook at another location without editing this cell; the default is the
# folder the notebook was originally written against.
DATA_DIR = os.environ.get('HWDB_DATA_DIR', r'D:\Desktop\abbyy\data')

# All paths are derived from DATA_DIR so that they cannot drift apart and contain
# no doubled path separators.
train_path = os.path.join(DATA_DIR, 'lmdb', 'train.lmdb')
test_path = os.path.join(DATA_DIR, 'lmdb', 'test.lmdb')
gt_path = os.path.join(DATA_DIR, 'gt.txt')

In [2]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import clear_output
from torch import linalg
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision import transforms as tt

In [3]:
# Open the reader on the training database at `train_path` and wrap it in a
# dataset helper, which the following cells use to access samples.
# NOTE(review): the reader is opened here and is not closed in any visible cell.
train_reader = LMDBReader(train_path)
train_reader.open()
train_helper = HWDBDatasetHelper(train_reader)

In [4]:
# Split a freshly built helper rather than `train_helper` itself: the original
# `train_helper = f(train_helper)` form re-split the already reduced training
# part every time this cell was re-run.
train_helper, val_helper = HWDBDatasetHelper(train_reader).train_val_split()

In [5]:
# Number of samples in the training and validation helpers, in that order.
train_helper.size(), val_helper.size()

(2578433, 644609)

In [6]:
# Pre-computed intensity statistics of the training images, hard-coded so that
# restarting the notebook does not require waiting for the two passes below.
mean, std = 220.99527121034433, 17.8814207186396

# Code that produced the constants above (kept for reference).
# NOTE(review): it needs `tqdm`, which is only imported in a later cell.
# NOTE(review): `std` below is the spread of the per-image means around the
# global mean, not a per-pixel standard deviation.

# n_images = train_helper.size()
# multiplier = n_images ** -1
# mean = 0
# for idx in tqdm(range(n_images)):
#     img, label = train_helper.get_item(idx)
#     mean += img.mean() * multiplier

# std = 0
# for idx in tqdm(range(n_images)):
#     img, label = train_helper.get_item(idx)
#     std += (img.mean() - mean) ** 2
# std = np.sqrt(std / n_images)
#
mean, std

(220.99527121034433, 17.8814207186396)

In [7]:
# Side length (in pixels) of the square image fed to the network.
input_shape = 64

# tt.ToTensor() rescales 8-bit pixel values from [0, 255] to [0.0, 1.0].
# `mean` and `std` were measured on the raw 0-255 scale, so they must be divided
# by the same factor. They were previously divided by `input_shape`, which is an
# image size and unrelated to the intensity range.
# NOTE(review): assumes the helper returns single-channel uint8 images - confirm.
PIXEL_MAX = 255.0

# Pre-processing applied to every sample:
# array -> PIL image -> resize -> [0, 1] tensor -> standardise -> 3 channels
# (the single channel is repeated because the ResNet backbone takes 3-channel input).
transform = tt.Compose([
    tt.ToPILImage(),
    tt.Resize((input_shape, input_shape)),
    tt.ToTensor(),
    tt.Normalize([mean / PIXEL_MAX], [std / PIXEL_MAX]),
    tt.Lambda(lambda x: x.repeat(3, 1, 1)),
])

In [8]:
class HWDBDataset(Dataset):
    """Torch ``Dataset`` view over an ``HWDBDatasetHelper``.

    Each sample is fetched from the helper by index; the image goes through the
    notebook-level ``transform`` pipeline and the label is returned as the
    helper provided it.
    """

    def __init__(self, helper: HWDBDatasetHelper):
        # The helper is the only state: it supplies both the size and the samples.
        self.helper = helper

    def __len__(self):
        """Return the number of samples reported by the helper."""
        n_samples = self.helper.size()
        return n_samples

    def __getitem__(self, idx):
        """Return ``(transformed image, label)`` for the sample at ``idx``."""
        sample = self.helper.get_item(idx)
        image, target = sample
        return transform(image), target

In [9]:
# Wrap both helpers in HWDBDataset so they can be fed to torch DataLoaders.
train_dataset = HWDBDataset(train_helper)
val_dataset = HWDBDataset(val_helper)

In [10]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook can still be run (slowly) on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
# Number of classes reported by the helper's vocabulary; it is used below as the
# output size of the classification head.
n_classes = train_helper.vocabulary.num_classes()
n_classes

7330

In [12]:
'''
source: https://github.com/shyhyawJou/ArcFace-Pytorch/blob/main/arcface.py
'''


class ArcFace(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    The layer computes the cosine between each input feature vector and each
    class weight vector. When labels are supplied, the target-class cosine
    ``cos(theta)`` is replaced by ``cos(theta + m)`` and all values are
    multiplied by ``s``, giving logits suitable for cross-entropy. Without
    labels the plain (unscaled) cosines are returned.

    Args:
        cin: size of the input feature vectors.
        cout: number of classes.
        s: scale applied to the logits on the labelled path.
        m: angular margin, in radians, added to the target-class angle.
        eps: lower bound used to keep the norms and ``1 - cos^2`` away from
            zero, so that neither the forward nor the backward pass can
            produce NaN/inf.
    """

    def __init__(self, cin, cout, s=8, m=0.5, eps=1e-7):
        super().__init__()
        self.s = s
        # Stored as plain Python floats (same float32 values as before) so they
        # combine with tensors on any device without being moved along with the
        # module.
        self.sin_m = torch.sin(torch.tensor(float(m))).item()
        self.cos_m = torch.cos(torch.tensor(float(m))).item()
        self.cout = cout
        self.eps = eps
        self.fc = nn.Linear(cin, cout, bias=False)

    def forward(self, x, label=None):
        """Return cosines (``label is None``) or margin-adjusted, scaled logits.

        Args:
            x: feature batch of shape ``(batch, cin)``.
            label: optional integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, cout)``.
        """
        # The weight norm is taken on a detached copy, as in the original code:
        # no gradient flows through the weight normalisation itself.
        w_L2 = linalg.norm(self.fc.weight.detach(), dim=1, keepdim=True).T
        x_L2 = linalg.norm(x, dim=1, keepdim=True)
        # Guard against an all-zero feature or weight vector (division by zero).
        cos = self.fc(x) / (x_L2.clamp_min(self.eps) * w_L2.clamp_min(self.eps))

        if label is None:
            return cos

        # F.one_hot only accepts int64 indices; also keep the mask on cos's device.
        index = label.to(device=cos.device, dtype=torch.long)
        is_target = F.one_hot(index, num_classes=self.cout).bool()

        # Rounding can push |cos| to or slightly above 1. An unclamped sqrt then
        # yields NaN (negative argument) or an infinite gradient (zero argument).
        sin = torch.sqrt(torch.clamp(1.0 - cos ** 2, min=self.eps))
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        angle_sum = cos * self.cos_m - sin * self.sin_m

        # Select per element instead of blending with a 0/1 mask, so a bad value
        # in a non-target position can never leak in through `value * 0`.
        return torch.where(is_target, angle_sum, cos) * self.s

In [13]:
class resnet_arcface(nn.Module):
    """ResNet-18 feature extractor followed by an ArcFace classification head.

    The backbone is created with torchvision's default pretrained weights, all
    of its parameters are frozen, and its final fully connected layer is
    replaced by an identity so that it outputs feature vectors. Only the
    ArcFace head keeps trainable parameters.
    """

    def __init__(self, n_classes):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Freeze every backbone parameter in one call.
        backbone.requires_grad_(False)
        # Remember the feature size before the original classifier is dropped.
        embedding_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.net = backbone
        self.classifier = ArcFace(embedding_dim, n_classes, m=0.5)

    def forward(self, imgs, label=None):
        """Return the head's output for ``imgs``; ``label`` is forwarded as is."""
        # ArcFace.forward defaults `label` to None, so passing it through
        # unconditionally covers both the labelled and the unlabelled call.
        return self.classifier(self.net(imgs), label)

In [14]:
# Build the network for `n_classes` outputs and move it to the selected device.
model = resnet_arcface(n_classes).to(device)

In [15]:
# Training batches are reshuffled on every pass and the last incomplete batch is
# dropped; validation keeps the dataset order and uses a larger batch size.
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False)

In [16]:
# Adam is handed every model parameter; the backbone parameters are frozen in
# resnet_arcface (requires_grad = False), so in practice only the head is updated.
# The loss is cross-entropy over the logits returned by the model.
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

In [17]:
from tqdm import tqdm


def _module_device(model: nn.Module) -> torch.device:
    """Return the device holding the model's parameters (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def run_validation(val_loader: DataLoader, model: nn.Module, n_steps=None):
    """Compute classification accuracy of ``model`` over ``val_loader``.

    Args:
        val_loader: iterable of ``(inputs, labels)`` batches.
        model: network called as ``model(inputs)``; it must return one row of
            class scores per sample.
        n_steps: number of batches to evaluate. ``None`` evaluates the whole
            loader and shows a tqdm progress bar.

    Returns:
        Fraction of correctly classified samples as a float, or ``0.0`` when no
        batch was evaluated.
    """
    model.eval()
    # Run on whatever device the model lives on instead of assuming CUDA.
    device = _module_device(model)

    show_progress = n_steps is None
    if show_progress:
        n_steps = len(val_loader)
    batches = tqdm(val_loader) if show_progress else val_loader

    n_good = 0
    n_all = 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(batches):
            if batch == n_steps:
                break
            logits = model(X.to(device=device, dtype=torch.float32))
            classes = torch.argmax(logits, dim=1)
            # Count matches with a tensor reduction rather than Python's sum().
            n_good += (classes == y.to(device)).sum().item()
            n_all += classes.shape[0]

    return n_good / n_all if n_all else 0.0


def train_epoch(train_loader: DataLoader, val_loader: DataLoader, model: nn.Module, optim, loss_fn):
    """Run one optimisation pass over ``train_loader``.

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches.
        val_loader: unused; kept so existing calls keep working.
        model: network called as ``model(inputs, labels)``.
        optim: optimiser that updates the model's parameters.
        loss_fn: callable ``loss_fn(logits, labels)`` returning a scalar tensor.

    Returns:
        Mean batch loss over the epoch as a float (``nan`` for an empty loader).
    """
    device = _module_device(model)
    # Nothing in the loop leaves training mode, so it is enough to set it once.
    model.train()

    # Accumulated on the device to avoid a host synchronisation on every step.
    loss_sum = torch.zeros((), device=device)
    n_batches = 0
    for X, y in tqdm(train_loader):
        X = X.to(device=device, dtype=torch.float32)
        # One int64 copy of the labels serves both the model and the loss.
        y = y.to(device=device, dtype=torch.long)

        logits = model(X, y)
        loss = loss_fn(logits, y)

        optim.zero_grad()
        loss.backward()
        optim.step()

        loss_sum += loss.detach()
        n_batches += 1

    return loss_sum.item() / n_batches if n_batches else float('nan')

In [18]:
# Length of the run and how often a checkpoint is written.
N_EPOCHS = 50
CHECKPOINT_EVERY = 5

# Validation accuracy after each epoch, re-plotted as training progresses.
accuracies = []
for epoch in range(N_EPOCHS):
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn)
    accuracy = run_validation(val_loader, model)

    # wait=True replaces the previous output only when the new one is ready,
    # which avoids flicker between epochs.
    clear_output(wait=True)
    print(f'Epoch {epoch}:')
    print(f'accuracy: {accuracy}')
    accuracies.append(accuracy)
    plt.plot(accuracies)
    plt.show()

    # Save periodically and always after the last epoch: with the periodic rule
    # alone the final weights (epoch N_EPOCHS - 1) were never written to disk.
    is_last_epoch = epoch == N_EPOCHS - 1
    if epoch % CHECKPOINT_EVERY == 0 or is_last_epoch:
        torch.save(model.state_dict(), f'resnet_arcface_epoch{epoch}.pth')

Epoch 0:


  1%|██▎                                                                                                                                                                                                  | 60/5036 [00:12<17:41,  4.69it/s]


KeyboardInterrupt: 