In [2]:
import torch
import numpy as np
import pickle

In [3]:
class UnifontModule(torch.nn.Module):
    '''Map symbol indices to flattened glyph vectors, optionally projected by a linear layer.

    Adapted from
    https://github.com/aimagelab/VATr/blob/7952b16e4549811c442fb46ed49ac2585908e832/models/unifont_module.py#L6

    The glyph tables are loaded from ``configs/<input_type>.pickle`` (path is
    relative to the current working directory). Row 0 of each table is an
    all-zero vector; row ``i + 1`` holds the glyph of ``alphabet[i]``.

    Args:
        alphabet: iterable of single characters; defines the row order of the tables.
        out_dim: output size of the linear projection (ignored when ``linear`` is False).
        device: device the glyph tables are initially placed on.
        input_type: stem of the pickle file used for ``symbols_repr``.
        linear: if True, project glyph vectors with ``torch.nn.Linear``;
            otherwise return them unchanged (``torch.nn.Identity``).

    Attributes:
        symbols: table built from ``configs/unifont.pickle``.
        symbols_repr: table built from ``configs/<input_type>.pickle``; this is
            the one indexed in ``forward``.
    '''
    def __init__(self,
                 alphabet,
                 out_dim,
                 device='cpu',
                 input_type='unifont',
                 linear=True):
        super().__init__()
        self.device = device
        self.alphabet = alphabet

        symbols = self.get_symbols('unifont')
        if input_type == 'unifont':
            # Same file: reuse the loaded table instead of unpickling it a second
            # time. Clone so the two attributes do not alias the same storage.
            symbols_repr = symbols.clone()
        else:
            symbols_repr = self.get_symbols(input_type)

        # Non-persistent buffers follow the module through .to()/.cuda()/.double()
        # and DataParallel replication, but are NOT written to state_dict, so
        # existing checkpoints (which only hold linear.weight / linear.bias)
        # still load without missing/unexpected keys.
        self.register_buffer('symbols', symbols, persistent=False)
        self.register_buffer('symbols_repr', symbols_repr, persistent=False)

        if linear:
            self.linear = torch.nn.Linear(symbols_repr.shape[1], out_dim)
        else:
            self.linear = torch.nn.Identity()

    def get_symbols(self, input_type):
        '''Load ``configs/<input_type>.pickle`` and build the glyph table for ``self.alphabet``.

        Each pickled entry is read as ``entry['idx'][0]`` (matched against
        ``ord(char)``) and ``entry['mat']`` (cast to float32 and flattened).

        Returns:
            Float tensor on ``self.device`` with ``len(self.alphabet) + 1`` rows;
            row 0 is all zeros.

        Raises:
            KeyError: if a character of ``self.alphabet`` has no entry in the file.
        '''
        # NOTE(security): pickle.load can execute arbitrary code; only point this
        # at config files you trust.
        with open(f"configs/{input_type}.pickle", "rb") as f:
            entries = pickle.load(f)

        glyph_by_codepoint = {
            entry['idx'][0]: entry['mat'].astype(np.float32).flatten()
            for entry in entries
        }

        missing = [char for char in self.alphabet if ord(char) not in glyph_by_codepoint]
        if missing:
            raise KeyError(
                f"characters not found in configs/{input_type}.pickle: {missing!r}"
            )

        rows = [glyph_by_codepoint[ord(char)] for char in self.alphabet]
        # Index 0 is reserved for an all-zero glyph, shifting every character by one.
        rows.insert(0, np.zeros_like(rows[0]))
        return torch.from_numpy(np.stack(rows)).float().to(self.device)

    def forward(self, QR):
        '''Look up the rows of ``symbols_repr`` selected by ``QR`` and apply ``self.linear``.

        ``QR`` is used directly as an index into the table, so 0 selects the
        zero row and ``i + 1`` selects ``alphabet[i]``.
        '''
        return self.linear(self.symbols_repr[QR])

In [4]:
# Build the module for the 52 upper/lower-case Latin letters with a 100-dim projection.
# Relies on the constructor defaults (device='cpu', input_type='unifont', linear=True),
# so it reads configs/unifont.pickle relative to the current working directory.
module = UnifontModule(
    alphabet='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz',
    out_dim=100,
)

In [13]:
# Sanity check: number of characters in the alphabet string passed to UnifontModule.
len('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')

52

In [14]:
# Position 26 of the alphabet string is where the lower-case letters start.
'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'[26]

'a'

In [12]:
# One row per alphabet character plus the all-zero row that get_symbols inserts at index 0.
module.symbols_repr.shape

torch.Size([53, 256])

In [11]:
# Row 0 is the inserted zero row, so row 2 is alphabet[1] ('B').
# Each row has 256 values; view it as a 16x16 grid to inspect the glyph.
module.symbols_repr[2].reshape(16,16)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 1., 1

In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the combined checkpoint for inspection only.
# map_location='cpu' remaps every tensor to CPU, so this cell also works on a
# kernel without CUDA and does not allocate GPU memory just to list keys.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(
    '/data/ocr/namvt17/FontDiffuser/outputs/FontDiffuser/global_step_40000/total_model.pth',
    map_location='cpu',
)

In [7]:
# Show only the checkpoint entries that sit outside the unet, style-encoder
# and content-encoder sub-modules.
for key in state_dict.keys():
    if not key.startswith(('module.unet', 'module.style_encoder', 'module.content_encoder')):
        print(key)

module.label_encoder.linear.weight
module.label_encoder.linear.bias


In [8]:
# Load the standalone UNet checkpoint for inspection only.
# map_location='cpu' remaps every tensor to CPU, so this cell also works on a
# kernel without CUDA and does not allocate GPU memory just to list keys.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
unet = torch.load(
    '/data/ocr/namvt17/FontDiffuser/outputs/FontDiffuser/global_step_40000/unet.pth',
    map_location='cpu',
)

In [9]:
# List the entry names stored in the object loaded from unet.pth.
unet.keys()

odict_keys(['conv_in.weight', 'conv_in.bias', 'time_embedding.linear_1.weight', 'time_embedding.linear_1.bias', 'time_embedding.linear_2.weight', 'time_embedding.linear_2.bias', 'down_blocks.0.resnets.0.norm1.weight', 'down_blocks.0.resnets.0.norm1.bias', 'down_blocks.0.resnets.0.conv1.weight', 'down_blocks.0.resnets.0.conv1.bias', 'down_blocks.0.resnets.0.time_emb_proj.weight', 'down_blocks.0.resnets.0.time_emb_proj.bias', 'down_blocks.0.resnets.0.norm2.weight', 'down_blocks.0.resnets.0.norm2.bias', 'down_blocks.0.resnets.0.conv2.weight', 'down_blocks.0.resnets.0.conv2.bias', 'down_blocks.0.resnets.1.norm1.weight', 'down_blocks.0.resnets.1.norm1.bias', 'down_blocks.0.resnets.1.conv1.weight', 'down_blocks.0.resnets.1.conv1.bias', 'down_blocks.0.resnets.1.time_emb_proj.weight', 'down_blocks.0.resnets.1.time_emb_proj.bias', 'down_blocks.0.resnets.1.norm2.weight', 'down_blocks.0.resnets.1.norm2.bias', 'down_blocks.0.resnets.1.conv2.weight', 'down_blocks.0.resnets.1.conv2.bias', 'down_bloc