In [None]:
!pip install fonttools Pillow==9.1.1 torch torchmetrics torchmetrics[image]
!pip install -U torchvision torchaudio

In [None]:
%env OMP_NUM_THREADS=8
%env MKL_NUM_THREADS=8

In [None]:
import os
import shutil

import numpy as np
np.random.seed(24)

from tqdm import tqdm

from fontTools.ttLib import TTFont
from PIL import Image, ImageDraw, ImageFont, ImageChops
from PIL.Image import Resampling

import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams["figure.figsize"] = (15, 15)

In [None]:
def get_glyph_image(glyph, font_name, glyph_size, img_size):
    """Render ``glyph`` centred on a square white grayscale canvas.

    Parameters
    ----------
    glyph : str
        Character to draw.
    font_name : str
        Font file stem; the font is loaded from ``fonts-src/<font_name>.otf``.
    glyph_size : int
        Nominal size (as found by ``get_glyph_size``). The text is drawn at
        ``glyph_size - glyph_size // 10`` to leave a small margin.
    img_size : int
        Side length, in pixels, of the square output image.

    Returns
    -------
    PIL.Image.Image
        Mode ``"L"`` image, white background, glyph drawn in black.
    """
    font = ImageFont.truetype(f"fonts-src/{font_name}.otf", glyph_size - glyph_size // 10)

    image = Image.new("L", (img_size, img_size), "white")
    draw = ImageDraw.Draw(image)

    # textbbox replaces the deprecated getoffset()/textsize() pair (removed in
    # Pillow 10). With the default "la" anchor, (left, top) is the old offset and
    # (right, bottom) the old textsize, so the ink box [left, right] x [top, bottom]
    # is centred exactly as before.
    left, top, right, bottom = draw.textbbox((0, 0), glyph, font = font)
    pos = ((img_size - right - left) / 2, (img_size - bottom - top) / 2)

    draw.text(pos, glyph, "black", font = font)

    return image


In [None]:
def draw_glyph_PNG(image, glyph_name, font_name, output_folder):
    """Save a rendered glyph as ``<output_folder>/<font_name>/<glyph_name>.png``.

    Parameters
    ----------
    image : object with a ``save(path)`` method
        The rendered glyph (a PIL image in this notebook).
    glyph_name : str
        Used as the PNG file stem.
    font_name : str
        Per-font sub-folder inside ``output_folder``.
    output_folder : str
        Root output folder. It is created if absent, along with any missing parents.
    """
    font_folder = f"{output_folder}/{font_name}"
    # makedirs(exist_ok=True) replaces the exists()/mkdir() pairs. It also creates
    # missing parents and does not fail if another writer created the folder first.
    os.makedirs(font_folder, exist_ok = True)

    image.save(f"{font_folder}/{glyph_name}.png")


In [None]:
def get_all_chars_from_font(font_name):
    """Collect every Unicode character mapped by ``fonts-src/<font_name>.otf``.

    Only Unicode ``cmap`` sub-tables are read. A character listed by several
    sub-tables appears once.

    Returns
    -------
    set[tuple[str, str]]
        ``(character, glyph name)`` pairs.
    """
    with TTFont(f"fonts-src/{font_name}.otf") as font:
        unicode_tables = [table for table in font["cmap"].tables if table.isUnicode()]
        return {
            (chr(codepoint), glyph_name)
            for table in unicode_tables
            for codepoint, glyph_name in table.cmap.items()
        }


In [None]:
def get_glyph_size(glyph, font_name, img_size):
    """Binary-search the largest font size at which ``glyph`` fits in both faces.

    Both the regular face (``fonts-src/<font_name>.otf``) and the italic face
    (``fonts-src/<font_name>i.otf``) are measured, so the two renders of a glyph
    share a single size.

    Parameters
    ----------
    glyph : str
        Character to measure.
    font_name : str
        Regular font file stem; the italic stem is ``font_name + "i"``.
    img_size : int
        Maximum allowed width and height, in pixels.

    Returns
    -------
    int
        The largest probed size in ``[1, 4 * img_size)`` whose text extent is at
        most ``img_size`` in both dimensions for both faces. Returns 1 if no larger
        size fits; size 1 itself is never measured. Assumes the extent grows
        monotonically with the font size.
    """
    image = Image.new("L", (img_size, img_size), "white")
    draw = ImageDraw.Draw(image)

    def extent(font):
        # (right, bottom) of the "la"-anchored bbox equals what the deprecated
        # (Pillow 10: removed) draw.textsize() returned.
        _, _, right, bottom = draw.textbbox((0, 0), glyph, font = font)
        return right, bottom

    # Invariant: size `lo` is accepted (or is the floor 1), size `hi` is too big.
    lo, hi = 1, img_size * 4
    while lo + 1 < hi:
        mid = (lo + hi) // 2

        re_w, re_h = extent(ImageFont.truetype(f"fonts-src/{font_name}.otf", mid))
        it_w, it_h = extent(ImageFont.truetype(f"fonts-src/{font_name}i.otf", mid))

        if max(re_w, re_h, it_w, it_h) > img_size:
            hi = mid
        else:
            lo = mid

    return lo


In [None]:
def check_glyph_equality(glyph, font_name, img_size):
    """Return True when ``glyph`` renders pixel-identically in the regular and italic face.

    Both faces are drawn at ``img_size // 2`` onto an ``img_size`` square canvas,
    at offset ``(img_size // 8, img_size // 8)``, and the two images are diffed.
    """
    size = img_size // 2
    origin = (img_size // 8, img_size // 8)

    regular_font = ImageFont.truetype(f"fonts-src/{font_name}.otf", size)
    italic_font = ImageFont.truetype(f"fonts-src/{font_name}i.otf", size)

    def render(font):
        canvas = Image.new("L", (img_size, img_size), "white")
        ImageDraw.Draw(canvas).text(origin, glyph, "black", font = font)
        return canvas

    regular, italic = render(regular_font), render(italic_font)
    # getbbox() is None only when the difference image is entirely zero.
    return ImageChops.difference(regular, italic).getbbox() is None


In [None]:
def draw_font_set_PNG(font_name, re_output, it_output, img_size, max_codepoint = 0x2116):
    """Render every glyph shared by a regular/italic font pair to PNG files.

    A glyph is rendered when the same ``(character, glyph name)`` pair is mapped by
    both ``fonts-src/<font_name>.otf`` and ``fonts-src/<font_name>i.otf``. Both faces
    are drawn at one common size (``get_glyph_size``). The files are written to
    ``<re_output>/<font_name>/<glyph name>.png`` and
    ``<it_output>/<font_name>i/<glyph name>.png``.

    Parameters
    ----------
    font_name : str
        Regular font file stem (without ``.otf``).
    re_output, it_output : str
        Root folders for the regular and italic renders.
    img_size : int
        Side length of each square PNG.
    max_codepoint : int, optional
        Characters with a larger code point are skipped. The default 0x2116 is
        U+2116 NUMERO SIGN ("№").
    """
    shared = get_all_chars_from_font(font_name) & get_all_chars_from_font(font_name + "i")

    for glyph, glyph_name in shared:
        # Skip whitespace, the glyph named ".null", and anything past the cut-off.
        if glyph.isspace() or glyph_name == ".null" or ord(glyph[0]) > max_codepoint:
            continue

        glyph_size = get_glyph_size(glyph, font_name, img_size)

        re_img = get_glyph_image(glyph, font_name, glyph_size, img_size)
        it_img = get_glyph_image(glyph, font_name + "i", glyph_size, img_size)

        draw_glyph_PNG(re_img, glyph_name, font_name, re_output)
        draw_glyph_PNG(it_img, glyph_name, font_name + "i", it_output)


In [None]:
def draw_fonts_PNG(re_output, it_output, img_size):
    """Render every regular/italic font pair in ``fonts-src`` to PNG glyph folders.

    Font files follow the convention ``<name>.otf`` (regular) and ``<name>i.otf``
    (italic). Existing output folders are wiped first, so glyphs from an earlier
    run do not survive.

    Parameters
    ----------
    re_output, it_output : str
        Root folders for regular and italic renders; recreated empty.
    img_size : int
        Side length of each square PNG.

    Raises
    ------
    AssertionError
        If some regular font has no italic counterpart.
    """
    for folder in (re_output, it_output):
        if os.path.exists(folder):
            shutil.rmtree(folder)
        os.makedirs(folder, exist_ok = True)

    font_files = set(os.listdir("fonts-src"))
    # An italic face is "<name>i.otf" whose regular "<name>.otf" also exists.
    # Testing `"i" in filename` instead would misfile any family whose name merely
    # contains an "i" (e.g. "Arial.otf") as italic.
    it_fonts = sorted(
        f for f in font_files
        if f.endswith("i.otf") and f[: -len("i.otf")] + ".otf" in font_files
    )
    it_set = set(it_fonts)
    re_fonts = sorted(f for f in font_files if f.endswith(".otf") and f not in it_set)
    assert re_fonts == sorted(f[: -len("i.otf")] + ".otf" for f in it_fonts), \
        "every regular font in fonts-src needs a matching <name>i.otf italic"

    for font_name in tqdm(re_fonts):
        draw_font_set_PNG(font_name[: -len(".otf")], re_output, it_output, img_size)


In [None]:
IMAGE_SIZE = 128
CHANNELS_CNT = 1

def draw_all_fonts(img_size, force = False):
    """Render the glyph datasets for ``img_size`` unless they already exist.

    Rendering every font is slow, so the work is skipped when both
    ``fonts-re-<img_size>`` and ``fonts-it-<img_size>`` are already on disk.
    The skip is a guard rather than an unconditional ``return``, so Restart &
    Run All on a fresh checkout still produces the folders read by later cells.

    Parameters
    ----------
    img_size : int
        Side length of each square glyph PNG.
    force : bool, optional
        Wipe and regenerate the folders even if they exist.
    """
    re_output = f"fonts-re-{img_size}"
    it_output = f"fonts-it-{img_size}"

    if not force and os.path.isdir(re_output) and os.path.isdir(it_output):
        return

    draw_fonts_PNG(re_output, it_output, img_size)


# Builds (or skips building) the fonts-re-/fonts-it- PNG folders for IMAGE_SIZE;
# whether rendering actually happens is decided inside draw_all_fonts.
draw_all_fonts(img_size = IMAGE_SIZE)

In [None]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0,1,2,3,4,5

In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms as tr
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from enum import Enum
import gc

# Cap PyTorch intra-op CPU threads at 8, the same value given to OMP/MKL_NUM_THREADS earlier.
torch.set_num_threads(8)

# Glyph folders written by draw_all_fonts for IMAGE_SIZE (one sub-folder per font).
# NOTE(review): the trailing "/" plus the "/" added in retrieve_glyphs yields "//" in paths; harmless on POSIX.
RE_FONTS_PATH = f"fonts-re-{IMAGE_SIZE}/"
IT_FONTS_PATH = f"fonts-it-{IMAGE_SIZE}/"

In [None]:
def clear_gpu(target_model = None):
    """Move a model to the CPU and release cached CUDA memory.

    Parameters
    ----------
    target_model : torch.nn.Module, optional
        Model to move off the GPU. Defaults to the notebook-global ``model``, so
        existing ``clear_gpu()`` calls keep working. Pass it explicitly to avoid
        depending on hidden global state.
    """
    if target_model is None:
        target_model = model

    target_model.to("cpu")
    # Collect unreachable objects first so their CUDA tensors are actually freed.
    # Only then can empty_cache hand those blocks back to the driver.
    gc.collect()
    torch.cuda.empty_cache()


In [None]:
def train_val_test_split():
    """Split the regular fonts in ``fonts-src`` into train / val / test lists.

    Val and test each receive ``len(fonts) // 15`` fonts; train gets the rest.
    The listing is sorted before the seeded shuffle because ``os.listdir`` order
    depends on the filesystem. Without sorting, the "seeded" split would differ
    between machines.

    Returns
    -------
    (list, list, list)
        Regular font file names (``<name>.otf``) for train, val and test.
    """
    np.random.seed(24)

    files = set(os.listdir("fonts-src"))
    # Italic faces are "<name>i.otf" with a matching "<name>.otf". A plain
    # `"i" not in f` test would also drop regular families such as "Arial.otf".
    fonts = sorted(
        f for f in files
        if f.endswith(".otf")
        and not (f.endswith("i.otf") and f[: -len("i.otf")] + ".otf" in files)
    )

    val_cnt = len(fonts) // 15
    test_cnt = len(fonts) // 15
    train_cnt = len(fonts) - val_cnt - test_cnt

    fonts = np.array(fonts)
    np.random.shuffle(fonts)
    # Positive split points. Negative slicing breaks at test_cnt == 0:
    # `[:-0]` is empty and `[-0:]` is everything, so all fonts would land in test.
    return (
        list(fonts[:train_cnt]),
        list(fonts[train_cnt : train_cnt + val_cnt]),
        list(fonts[train_cnt + val_cnt :]),
    )


In [None]:
def _load_png_array(path):
    """Read an image file into a NumPy array, closing the file handle immediately."""
    with Image.open(path) as img:
        return np.array(img)


def retrieve_glyphs(fonts):
    """Load the pre-rendered regular and italic glyph PNGs for ``fonts``.

    Parameters
    ----------
    fonts : iterable of str
        Regular font file names (``<name>.otf``), e.g. from ``train_val_test_split``.

    Returns
    -------
    (list[np.ndarray], list[np.ndarray])
        Regular and italic glyph arrays. They are index-aligned: entry ``k`` of
        both lists comes from the same font and the same glyph file name.

    Raises
    ------
    AssertionError
        If a font's regular and italic folders do not contain the same file names.
    """
    re_glyphs, it_glyphs = [], []
    for font in fonts:
        stem = font.replace(".otf", "")
        re_dir = f"{RE_FONTS_PATH}/{stem}"
        it_dir = f"{IT_FONTS_PATH}/{stem}i"

        re_files = sorted(os.listdir(re_dir))
        it_files = sorted(os.listdir(it_dir))
        assert re_files == it_files, f"regular/italic glyph files differ for {stem}"

        # Image.open is lazy and keeps the file open. _load_png_array closes it
        # so a large font set does not exhaust file descriptors.
        re_glyphs.extend(_load_png_array(f"{re_dir}/{name}") for name in re_files)
        it_glyphs.extend(_load_png_array(f"{it_dir}/{name}") for name in it_files)

    return re_glyphs, it_glyphs


In [None]:
# Per-sample transform applied by GlyphDataset to both the input and the target glyph.
transforms = tr.Compose([
    tr.ToTensor(),
    # Normalize is disabled. NOTE(review): channel_mean / channel_std are not
    # defined anywhere in this notebook, so they must be computed before re-enabling.
    #tr.Normalize(channel_mean, channel_std),
])

In [None]:
class Mode(Enum):
    """Which split a GlyphDataset represents (stored on the dataset only)."""
    train = 0
    test = 1
    val = 2

class GlyphDataset(Dataset):
    """Paired (regular, italic) glyph images for one split.

    ``re[k]`` and ``it[k]`` are expected to be the same glyph in the regular and
    italic face. Both go through the module-level ``transforms`` when indexed.
    """

    def __init__(self, re, it, mode):
        assert len(re) == len(it)
        self.re, self.it, self.mode = re, it, mode

    def __len__(self):
        return len(self.re)

    def __getitem__(self, index):
        pair = (self.re[index], self.it[index])
        return tuple(transforms(img) for img in pair)
    

In [None]:
# The split is over whole fonts: retrieve_glyphs loads every glyph of a font into
# one split, so held-out fonts are never seen in training.
train_fonts, val_fonts, test_fonts = train_val_test_split()
print(len(train_fonts), len(val_fonts), len(test_fonts))

In [None]:
# CUDA ordinal 3 among the devices exposed by CUDA_VISIBLE_DEVICES above; CPU if CUDA is unavailable.
# NOTE(review): this fails at first use on machines with fewer than 4 visible GPUs.
device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
%%time

BATCH_SIZE = 128


def _make_loader(fonts, mode, is_train):
    """Load the glyphs of ``fonts`` into a GlyphDataset and wrap it in a DataLoader.

    The training loader shuffles and drops the ragged last batch. Evaluation
    loaders keep file order and every sample. All loaders are single-process.
    """
    dataset = GlyphDataset(*retrieve_glyphs(fonts), mode)
    loader = DataLoader(
        dataset,
        batch_size = BATCH_SIZE,
        shuffle = is_train,
        drop_last = is_train,
        num_workers = 0,
    )
    return dataset, loader


train_ds, train_dl = _make_loader(train_fonts, Mode.train, is_train = True)
val_ds, val_dl = _make_loader(val_fonts, Mode.val, is_train = False)
test_ds, test_dl = _make_loader(test_fonts, Mode.test, is_train = False)

In [None]:
from copy import deepcopy

def correct_picture(pic, lower_bound = 0.0, upper_bound = 0.8):
    """Contrast-stretch an image so ``[lower_bound, upper_bound]`` maps onto ``[0, 1]``.

    Values above ``upper_bound`` become 1 and values below ``lower_bound`` become 0.
    Values in between are rescaled linearly. If the two bounds are equal, values
    exactly at the bound become 0.

    Parameters
    ----------
    pic : array_like
        Image of any shape. It is not modified.
    lower_bound, upper_bound : float
        Input range that is stretched to ``[0, 1]``.

    Returns
    -------
    np.ndarray
        New array with the same shape and dtype as ``pic``.
    """
    src = np.asarray(pic)
    # Compute in float64, then cast back to the input dtype on return.
    vals = src.astype(np.float64)

    if lower_bound == upper_bound:
        scaled = np.zeros_like(vals)
    else:
        scaled = (vals - lower_bound) / (upper_bound - lower_bound)

    # Precedence: "> upper" wins first, then "< lower", then the linear stretch.
    res = np.where(vals > upper_bound, 1.0, np.where(vals < lower_bound, 0.0, scaled))
    return res.astype(src.dtype)


In [None]:
from torchmetrics import MeanSquaredError, MeanAbsoluteError, StructuralSimilarityIndexMeasure
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score

criterion = nn.MSELoss()

def train(model, optimizer, loader, criterion):
    """Run one training epoch.

    Each ``(regular, italic)`` batch is moved to the notebook-global ``device``.
    The model maps regular to italic and is optimised on ``criterion``.

    Returns
    -------
    (model, optimizer, float)
        The (in-place updated) model and optimizer, plus the mean batch loss.
    """
    model.train()
    batch_losses = []

    for regular, italic in tqdm(loader):
        regular, italic = regular.to(device), italic.to(device)

        optimizer.zero_grad()
        loss = criterion(model(regular), italic)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return model, optimizer, np.mean(batch_losses)


def val(model, loader, criterion):
    """Evaluate ``model`` on ``loader``; print and return dataset-level metrics.

    Parameters
    ----------
    model : torch.nn.Module
        Run in eval mode under ``torch.no_grad``.
    loader : DataLoader
        Yields ``(regular, italic)`` batches; moved to the notebook-global ``device``.
    criterion : callable
        Loss on ``(prediction, target)``. Assumed to return a per-batch mean
        (as ``nn.MSELoss()`` does by default), since it is re-weighted by batch size.

    Returns
    -------
    tuple of float
        ``(loss, mse, mae, acc, fsc, ssim)`` over the whole loader; all NaN if
        the loader is empty.

    Notes
    -----
    Metrics are accumulated with torchmetrics ``update``/``compute`` instead of
    averaging per-batch values. As a result, the smaller final batch of a
    ``drop_last=False`` loader is not over-weighted, and accuracy/F1 are true
    dataset-level scores.
    """
    model.eval()

    mse_metric = MeanSquaredError().to(device)
    mae_metric = MeanAbsoluteError().to(device)
    acc_metric = BinaryAccuracy().to(device)
    fsc_metric = BinaryF1Score().to(device)
    # Targets come from ToTensor, so pixels lie in [0, 1]. Pinning data_range stops
    # SSIM from inferring a different range from each batch's min/max.
    ssim_metric = StructuralSimilarityIndexMeasure(data_range = 1.0).to(device)

    loss_sum, sample_cnt = 0.0, 0

    with torch.no_grad():
        for re, it in tqdm(loader):
            re = re.to(device)
            it = it.to(device)

            out = model(re)

            batch_cnt = re.shape[0]
            loss_sum += criterion(out, it).item() * batch_cnt
            sample_cnt += batch_cnt

            mse_metric.update(out, it)
            mae_metric.update(out, it)

            # Binarise: targets are rounded to {0, 1} and predictions clamped to
            # [0, 1]. torchmetrics treats a float batch with *any* value outside
            # [0, 1] as logits and applies a sigmoid to all of it, which would
            # silently move the 0.5 threshold.
            binary_it = it.round().type(torch.int)
            binary_out = out.clamp(0.0, 1.0)
            acc_metric.update(binary_out, binary_it)
            fsc_metric.update(binary_out, binary_it)

            ssim_metric.update(out, it)

    if sample_cnt == 0:
        loss = mse = mae = acc = fsc = ssim = float("nan")
    else:
        loss = loss_sum / sample_cnt
        mse, mae, acc, fsc, ssim = (
            metric.compute().item()
            for metric in (mse_metric, mae_metric, acc_metric, fsc_metric, ssim_metric)
        )

    print(
        f"loss: {'{:.6f}'.format(loss)}; " + \
        f"mse: {'{:.6f}'.format(mse)}; " + \
        f"mae: {'{:.6f}'.format(mae)}; " + \
        f"acc: {'{:.3f}'.format(acc)}; " + \
        f"fsc: {'{:.3f}'.format(fsc)}; " + \
        f"ssim: {'{:.3f}'.format(ssim)} ",
        end = "\n\n", flush = True
    )

    return loss, mse, mae, acc, fsc, ssim


In [None]:
def learning_loop(
    model, optimizer,
    train_loader, val_loader,
    criterion, epochs = 10,
    scheduler = None, min_lr = None,
    val_every = 1, draw_every = 1,
    checkpoint_path = "./model"
):
    """Train for ``epochs`` epochs, validating every ``val_every`` epochs.

    When the validation loss beats the best seen so far, a checkpoint dict with
    ``epoch``, ``model_state_dict`` and ``optimizer_state_dict`` is written to
    ``checkpoint_path``.

    Parameters
    ----------
    scheduler : optional
        Stepped with the validation loss (``scheduler.step(val_loss)``), i.e. a
        ReduceLROnPlateau-style scheduler.
    min_lr, draw_every :
        Accepted for call compatibility; not used here.
    checkpoint_path : str, optional
        Where the best checkpoint is saved.

    Returns
    -------
    (model, optimizer, dict)
        ``metrics["train"]`` holds one mean training loss per epoch, and
        ``metrics["val"]`` one ``val(...)`` tuple per validation epoch.
    """
    metrics = {"train": [], "val": []}

    # +inf guarantees the first validation writes a checkpoint. A finite sentinel
    # would never save if every loss were above it, breaking a later torch.load.
    best_loss = float("inf")
    for epoch in range(1, epochs + 1):
        print(f"#{epoch}/{epochs}:", flush = True)

        model, optimizer, loss = train(model, optimizer, train_loader, criterion)
        metrics["train"].append(loss)

        if epoch % val_every != 0:
            continue

        val_metrics = val(model, val_loader, criterion)
        metrics["val"].append(val_metrics)
        val_loss = val_metrics[0]

        if scheduler:
            scheduler.step(val_loss)

        if val_loss < best_loss:
            best_loss = val_loss

            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()
            }, checkpoint_path)

    return model, optimizer, metrics


In [None]:
def create_model_and_optimizer(model_class, model_params, lr = 1e-3, device = device):
    """Build ``model_class(**model_params)`` on ``device`` together with an Adam optimizer.

    Only parameters with ``requires_grad`` are given to the optimizer. The
    ``device`` default is captured when this cell runs, not at call time.
    """
    model = model_class(**model_params).to(device)
    trainable = [param for param in model.parameters() if param.requires_grad]
    return model, torch.optim.Adam(trainable, lr)


In [None]:
class simple_net(nn.Module):
    """One-hidden-layer MLP that maps a flattened image to an image.

    ``hidden_sizes`` is a single int (the hidden width). ``num_layers``,
    ``activations`` and ``dropouts`` are accepted but unused. The output is
    reshaped with the global ``CHANNELS_CNT`` / ``IMAGE_SIZE``, so
    ``output_size`` must equal ``CHANNELS_CNT * IMAGE_SIZE ** 2``.
    """

    def __init__(self, input_size, num_layers, hidden_sizes, activations, dropouts, output_size):
        super(simple_net, self).__init__()

        layers = OrderedDict()
        layers["flat"] = nn.Flatten()
        layers["in2hid"] = nn.Linear(input_size, hidden_sizes)
        layers["act_last"] = nn.ReLU()
        layers["hid2out"] = nn.Linear(hidden_sizes, output_size)
        layers["sigmoid"] = nn.Sigmoid()
        self.net = nn.Sequential(layers)

    def forward(self, inp):
        flat_out = self.net(inp)
        return flat_out.reshape(-1, CHANNELS_CNT, IMAGE_SIZE, IMAGE_SIZE)


In [None]:
class ResidualBlock(nn.Module):
    """Basic residual block: conv3x3-BN-ReLU-conv3x3-BN, add shortcut, ReLU.

    When ``in_channels != out_channels`` the block also downsamples. The first
    conv uses stride 2, and the shortcut becomes a stride-2 1x1 conv + BatchNorm
    so both branches have the same shape (spatial size becomes ceil(H/2)).
    Otherwise the shortcut is the identity and the spatial size is unchanged.

    NOTE: attribute names and their creation order define the state_dict keys and
    the parameter order seen by the optimizer. Keep them stable so saved
    checkpoints (model + optimizer state) keep loading.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        # Downsample exactly when the channel count changes.
        stride = 2 if (in_channels != out_channels) else 1

        if (in_channels != out_channels):
            # Projection shortcut; its stride (2) must match `stride` used by conv1.
            self.shortcut = nn.Sequential(OrderedDict([
                ("downsample_conv", nn.Conv2d(
                    in_channels, out_channels,
                    kernel_size = 1, stride = 2,
                    bias = False
                )),
                ("downsample_norm", nn.BatchNorm2d(out_channels))
            ]))
        
        else:
            self.shortcut = nn.Identity()
        
        # One ReLU module, reused after bn1 and after the residual addition.
        self.activation = nn.ReLU()
        
        # Convs have no bias because each is followed by a BatchNorm.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size = 3, stride = stride,
            padding = 1, dilation = 1,
            groups = 1, bias = False
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size = 3, stride = 1,
            padding = 1, dilation = 1,
            groups = 1, bias = False
        )

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Kept for inspection only; forward() does not read it.
        self.stride = stride


    def forward(self, x):
        # Shortcut branch (identity, or strided 1x1 projection).
        residual = self.shortcut(x)
        out = x

        out = self.conv1(out)
        out = self.bn1(out)
        out = self.activation(out)

        out = self.conv2(out)
        out = self.bn2(out)
        
        # In-place add is safe: `out` is a fresh tensor produced by bn2.
        out += residual
        out = self.activation(out)

        return out


class ResNetLayer(nn.Module):
    """Two stacked ResidualBlocks: ``in -> out`` channels, then ``out -> out``.

    The first block downsamples when the channel count changes; the second
    keeps the shape.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = OrderedDict()
        stages["residual1"] = ResidualBlock(in_channels, out_channels)
        stages["residual2"] = ResidualBlock(out_channels, out_channels)
        self.blocks = nn.Sequential(stages)

    def forward(self, x):
        y = self.blocks(x)
        return y


class MyResNet64(nn.Module):
    """ResNet-style encoder plus a linear decoder for 64x64 single-channel glyphs.

    Spatial flow for a 64x64 input: conv1 (stride 1) 64 -> maxpool 32 ->
    resnet2 16 -> resnet3 8 -> resnet4 4. The 512x4x4 features are then flattened
    and a linear layer emits 64*64 values, reshaped to (N, 1, 64, 64). There is no
    output activation. The fc in_features fixes the input size to 64x64.

    Not used by the visible training cells (they build MyResNet128).
    """
    def __init__(self):
        super().__init__()
        
        self.image_size = 64
        self.channels_cnt = 1

        # Stem: 3x3 stride-1 conv keeps resolution; the max-pool halves it.
        self.conv1 = nn.Conv2d(
            self.channels_cnt, 64,
            kernel_size = 3, stride = 1,
            padding = 1, dilation = 1,
            groups = 1, bias = False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.activation = nn.ReLU()
        self.maxpool = nn.MaxPool2d(
            kernel_size = 3, stride = 2,
            padding = 1, dilation = 1
        )

        # Each layer doubles channels and halves resolution. The 64->64 stage is
        # disabled (commented out), so naming starts at resnet2.
        self.layers = nn.Sequential(OrderedDict([
            #("resnet1", ResNetLayer(64, 64)),
            ("resnet2", ResNetLayer(64, 128)),
            ("resnet3", ResNetLayer(128, 256)),
            ("resnet4", ResNetLayer(256, 512))
        ]))

        self.flatten = nn.Flatten()
        # 512 channels x 4 x 4 spatial -> one value per output pixel.
        self.fc = nn.Linear(in_features = 512 * 4 * 4, out_features = self.image_size ** 2, bias = True)

        
    def forward(self, x):
        out = x

        out = self.conv1(out)
        out = self.bn1(out)
        out = self.activation(out)
        out = self.maxpool(out)

        out = self.layers(out)

        out = self.flatten(out)
        out = self.fc(out)

        # Back to image layout: (N, 1, 64, 64).
        return torch.reshape(out, (-1, self.channels_cnt, self.image_size, self.image_size))


class MyResNet128(nn.Module):
    """ResNet-style encoder plus a linear decoder for 128x128 single-channel glyphs.

    Spatial flow for a 128x128 input: conv1 (stride 1) 128 -> maxpool 64 ->
    resnet1 32 -> resnet2 16 -> resnet3 8 -> avgpool 4. The 512x4x4 features are
    then flattened and a linear layer emits 128*128 values, reshaped to
    (N, 1, 128, 128). There is no output activation, so predictions are unbounded.
    The fc in_features fixes the input size to 128x128.
    """
    def __init__(self):
        super().__init__()
        
        self.image_size = 128
        self.channels_cnt = 1

        # Stem: 3x3 stride-1 conv keeps resolution; the max-pool halves it.
        self.conv1 = nn.Conv2d(
            self.channels_cnt, 64,
            kernel_size = 3, stride = 1,
            padding = 1, dilation = 1,
            groups = 1, bias = False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.activation = nn.ReLU()
        self.maxpool = nn.MaxPool2d(
            kernel_size = 3, stride = 2,
            padding = 1, dilation = 1
        )

        # Each layer doubles channels and halves resolution.
        self.layers = nn.Sequential(OrderedDict([
            ("resnet1", ResNetLayer(64, 128)),
            ("resnet2", ResNetLayer(128, 256)),
            ("resnet3", ResNetLayer(256, 512))
        ]))

        # Extra 2x pooling brings 8x8 down to 4x4 so fc matches MyResNet64's width.
        self.avgpool = nn.AvgPool2d(kernel_size = 2)
        self.flatten = nn.Flatten()
        # 512 x 4 x 4 -> 128*128 outputs (~134M weights: the bulk of the model).
        self.fc = nn.Linear(in_features = 512 * 4 * 4, out_features = self.image_size ** 2, bias = True)

        
    def forward(self, x):
        out = x

        out = self.conv1(out)
        out = self.bn1(out)
        out = self.activation(out)
        out = self.maxpool(out)

        out = self.layers(out)

        out = self.avgpool(out)
        out = self.flatten(out)
        out = self.fc(out)

        # Back to image layout: (N, 1, 128, 128).
        return torch.reshape(out, (-1, self.channels_cnt, self.image_size, self.image_size))


In [None]:
%%time

# Alternative MLP baseline (not run); swap it in for the ResNet below if needed:
# model, optimizer = create_model_and_optimizer(
#     simple_net,
#     {
#         "input_size": IMAGE_SIZE ** 2,
#         "num_layers": 0,
#         "hidden_sizes": 2 ** 12,
#         "activations": 0,
#         "dropouts": 0,
#         "output_size": IMAGE_SIZE ** 2
#     },
#     lr = 1e-4
# )

model, optimizer = create_model_and_optimizer(MyResNet128, {}, lr = 1e-4)

In [None]:
%%time

# 100 epochs; the lowest-val-loss checkpoint is written by learning_loop to ./model.
# `losses` is learning_loop's metrics dict ({"train": [...], "val": [...]}).
model, optimizer, losses = learning_loop(model, optimizer, train_dl, val_dl, criterion, epochs = 100)

In [None]:
def get_model_params_cnt(model):
    """Return the number of trainable (``requires_grad``) scalar parameters in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


In [None]:
# Restore the best (lowest validation loss) checkpoint saved by learning_loop,
# replacing the final-epoch weights and optimizer state left in memory by training.
checkpoint = torch.load("./model", map_location = device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

In [None]:
# Trainable parameter count, then the module tree (rich repr of the last expression).
print(get_model_params_cnt(model))
model

In [None]:
def show_results(model, loader, output_dir = "samples-png-in"):
    """Plot one random (input, prediction, target) triple per batch and save the predictions.

    For every batch in ``loader``:

    1. A random sample is picked; the NumPy RNG is reseeded to 24 first, so picks
       repeat for an unshuffled loader.
    2. The prediction is contrast-stretched with ``correct_picture``.
    3. It is shown between the regular input and the italic target.
    4. It is saved as ``<output_dir>/<batch index>.png``.

    Parameters
    ----------
    model : torch.nn.Module
        Evaluated in eval mode on the notebook-global ``device``.
    loader : DataLoader
        Yields ``(regular, italic)`` batches.
    output_dir : str, optional
        Folder for the saved predictions; created if it does not exist.
    """
    np.random.seed(24)
    model.eval()
    os.makedirs(output_dir, exist_ok = True)

    with torch.no_grad():
        for i, (re, it) in enumerate(loader):
            index = np.random.randint(0, re.shape[0])
            # Only the selected sample needs to go to the device.
            re = re[index].unsqueeze(0).to(device)
            it = it[index].unsqueeze(0)

            out = model(re)

            ini_pic = re.cpu().squeeze().numpy()
            res_pic = correct_picture(out.cpu().squeeze().numpy())
            tgt_pic = it.cpu().squeeze().numpy()

            fig, axarr = plt.subplots(1, 3)
            # "none" disables resampling so individual pixels stay visible.
            # interpolation=None would fall back to rcParams["image.interpolation"].
            for ax, pic in zip(axarr, (ini_pic, res_pic, tgt_pic)):
                ax.imshow(pic, cmap = "gray", interpolation = "none")
            plt.show()
            # One figure per batch: close it so figures do not accumulate in memory.
            plt.close(fig)

            res_img = Image.fromarray(np.uint8(res_pic * 255), mode = "L")
            res_img.save(f"{output_dir}/{i}.png")


In [None]:
# Visualise one sample per test batch and write the corrected predictions as PNGs
# (see show_results for the output folder).
show_results(model, test_dl)