In [62]:
import cv2
import torch

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"  # I use AMD... 😭
device

'cpu'

## Initialize dataset directory

In [63]:
# Dataset directories - fill these in before running. Each is expected to contain the
# *.jpg style images and a words.json label file (see the training-setup cell).
train_dir, test_dir, val_dir = '', '', ''

## Initialize dataset

One of the required inputs is the content image: the target text rendered on a plain white background in the Verily Serif Mono font.

In [64]:
from PIL import Image, ImageDraw, ImageFont
def draw_word(word: str, font_size: int = 50, font_path: str = 'VerilySerifMono.otf') -> Image.Image:
    """Render ``word`` in black, centred on a white RGB canvas.

    The canvas is 256 px high. It is at least 256 px wide, widened to
    ``len(word) * font_size * 0.64`` px for long words.

    Args:
        word: text to render.
        font_size: size passed to ``ImageFont.truetype`` (default 50).
        font_path: font file to load (default: the Verily Serif Mono .otf in the
            working directory).

    Returns:
        The rendered PIL image.
    """
    w = max(256, int(len(word) * font_size * 0.64))
    h = 256

    img = Image.new('RGB', (w, h), color=(255, 255, 255))
    font = ImageFont.truetype(font_path, font_size)
    d = ImageDraw.Draw(img)
    # ImageDraw.textsize was removed in Pillow 10. textbbox gives the ink box
    # (left, top, right, bottom) of the text drawn at the origin.
    left, top, right, bottom = d.textbbox((0, 0), word, font=font)
    # Subtract the box offset so that the ink itself is centred, not the origin.
    position = ((w - (right - left)) / 2 - left, (h - (bottom - top)) / 2 - top)

    d.text(position, word, font=font, fill=0)
    return img

In [65]:
import json
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random

class ImageDataset(Dataset):
    """Style images paired with rendered content images.

    Each item is ``(rgb_img, img_content, content, img_content_style, content_style)``:

    * ``rgb_img`` - a ``*.jpg`` from ``image_dir`` as an RGB tensor resized to 256x256;
    * ``content`` / ``img_content`` - a random label from the words file (filtered to
      ``ALLOWED_SYMBOLS``) and its rendering by ``draw_word``;
    * ``content_style`` / ``img_content_style`` - the style image's own label
      (``words[<file stem>]``, filtered; ``'o'`` if nothing survives) and its rendering.
    """

    # Digits, ASCII letters and ASCII punctuation; every other character (spaces
    # included) is dropped from labels before rendering.
    ALLOWED_SYMBOLS = frozenset('0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~')

    def __init__(self, image_dir, word_dir):
        """
        Args:
            image_dir: directory containing the ``*.jpg`` style images.
            word_dir: path to a JSON file mapping image file stems to label strings
                (despite the name, a file rather than a directory).
        """
        self.image_dir = Path(image_dir)
        with open(word_dir, 'r') as f:
            self.words = json.load(f)
        # Sorted so that index -> file is stable across runs and platforms.
        self.images = sorted(self.image_dir.glob('*.jpg'))
        # Labels that keep at least one allowed character after filtering. This is
        # computed once here instead of on every __getitem__ call.
        self._content_pool = [cleaned for cleaned in map(self._clean, self.words.values()) if cleaned]
        # Images are converted to RGB with PIL before this pipeline, so no channel
        # swapping is needed; ToTensor accepts the PIL image directly.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((256, 256))
        ])

    @classmethod
    def _clean(cls, text):
        """Drop every character of ``text`` that is not in ``ALLOWED_SYMBOLS``."""
        return ''.join(ch for ch in text if ch in cls.ALLOWED_SYMBOLS)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if not -len(self) <= idx < len(self):
            raise IndexError(f'Index {idx} out of range for dataset of size {len(self)}')
        if not self._content_pool:
            raise ValueError('No label in the words file contains an allowed symbol')

        image_path = self.images[idx]
        try:
            # Context manager closes the underlying file handle.
            with Image.open(image_path) as img:
                rgb_img = self.transform(img.convert('RGB'))

            content = random.choice(self._content_pool)
            img_content = self.transform(draw_word(content))

            # 'o' is a placeholder when none of the label's characters are allowed.
            content_style = self._clean(self.words[image_path.stem]) or 'o'
            img_content_style = self.transform(draw_word(content_style))
        except Exception as e:
            # Fail loudly with the offending file. A sentinel return value would break
            # batching and the 5-way unpacking done by the trainer.
            raise RuntimeError(f'Failed to load sample {idx} ({image_path})') from e
        return rgb_img, img_content, content, img_content_style, content_style

# Loss function

### 1. VGG19 Typeface classifier, C - Text Perceptual Loss
- Perceptual loss: computed from the feature maps $\phi_i$ at layer $i$, normalized by $M_i$, the number of elements in that feature map: $\mathcal{L}_{per} = \sum_i \frac{1}{M_i} \lVert \phi_i(I_s) - \phi_i(\hat{I}_s) \rVert_1$. It is only computed for the output image corresponding to the original content.
- Texture (style) loss: computed from the Gram matrices of the feature maps.
- Embedding-based loss: computed from the feature maps of the penultimate layer of this network. (**TODO**: unclear how this is applied.)

https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49#gistcomment-3347450

In [66]:
# The paper trains this classifier on the Synth-Font dataset, which I could not find,
# so ImageNet-pretrained VGG19 features are used instead.
class VGGPerceptualLoss(torch.nn.Module):
    """Perceptual (feature L1) plus texture (Gram-matrix L1) loss on frozen VGG19 features.

    Adapted from https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49.
    The VGG19 feature extractor is split into four consecutive blocks
    (``features[:4]``, ``[4:9]``, ``[9:16]``, ``[16:23]``). ``feature_layers`` and
    ``style_layers`` select which block outputs contribute to the loss.
    """

    def __init__(self, resize=True):
        """
        Args:
            resize: bilinearly resize both inputs to 224x224 before the VGG blocks.
        """
        super(VGGPerceptualLoss, self).__init__()
        # Imported locally: elsewhere the notebook only does ``from torchvision import ...``,
        # which does not bind the name ``torchvision`` itself.
        import torchvision

        # Load the pretrained network once and slice it, rather than building four copies.
        # IMAGENET1K_V1 is the checkpoint that the deprecated ``pretrained=True`` selected.
        vgg_features = torchvision.models.vgg19(
            weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1
        ).features.eval()
        blocks = [vgg_features[:4], vgg_features[4:9], vgg_features[9:16], vgg_features[16:23]]
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.resize = resize
        # ImageNet channel statistics that the pretrained weights expect.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    # default: feature_layers=[2], style_layers=[0, 1, 2, 3]
    def forward(self, prediction, target, feature_layers=None, style_layers=None):
        """Sum of L1 feature distances and L1 Gram-matrix distances between two batches.

        Args:
            prediction, target: image batches (B, C, H, W). Inputs whose channel count
                is not 3 are repeated 3x along the channel axis (meant for single-channel
                images). They are normalized with ImageNet mean/std, so presumably they
                should be in [0, 1] - TODO confirm for generator outputs.
            feature_layers: block indices (0-3) used for the perceptual term; default [2].
            style_layers: block indices (0-3) used for the Gram/texture term;
                default [0, 1, 2, 3]. Pass ``[]`` to disable either term.

        Returns:
            Scalar loss tensor (the float 0.0 if both index lists are empty).
        """
        if style_layers is None:
            style_layers = [0, 1, 2, 3]
        if feature_layers is None:
            feature_layers = [2]
        if prediction.shape[1] != 3:
            prediction = prediction.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        prediction = (prediction - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            prediction = self.transform(prediction, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = prediction
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Gram matrices (B, C, C) of the flattened activations; note they are not
                # normalized by the number of spatial positions.
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss

### 2. OCR, R - Text Content Loss
- The relevant modules have already been added, along with the saved configuration mentioned in the paper; the PyTorch model can be downloaded [here](https://drive.google.com/file/d/1b59rXuGGmKne1AuHnkgDzoYgKeETNMv9/view?usp=drive_link). I have not yet implemented it to calculate the loss.
- The content loss is the cross entropy between the character sequences of the input strings $c_1, c_2$ and of the predicted strings $c'_1, c'_2$ respectively, with characters represented as one-hot vectors.

https://github.com/clovaai/deep-text-recognition-benchmark

In [67]:
from torch import nn
from ocr.utils import AttnLabelConverter
from ocr.model import Model

class opt:
    """Recognizer options, hard-coded instead of parsed with argparse.

    Mirrors https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/demo.py#L96.
    Read directly as a class (``opt.character``, ``opt.saved_model``) and also
    instantiated by OCRLoss, which fills in ``num_class`` on its instance.
    """

    # --- architecture: TPS -> ResNet -> BiLSTM -> Attn ---
    Transformation = 'TPS'
    FeatureExtraction = 'ResNet'
    SequenceModeling = 'BiLSTM'
    Prediction = 'Attn'
    num_fiducial = 20
    input_channel = 1
    output_channel = 512
    hidden_size = 256
    num_class = 0  # placeholder, replaced per OCRLoss instance

    # --- recognizer input / label format ---
    imgH = 32
    imgW = 100
    rgb = False  # must stay False: loading breaks unless input_channel is 1
    batch_max_length = 25
    character = '0123456789abcdefghijklmnopqrstuvwxyz'
    # NOTE(review): nothing visible in this notebook reads ``sensitive`` - ``character``
    # above is lowercase-only either way. Confirm it matches the checkpoint.
    sensitive = True
    PAD = True

    # --- loading (demo.py settings) ---
    saved_model = 'ocr/TPS-ResNet-BiLSTM-Attn.pth'
    image_folder = 'images'
    workers = 4
    batch_size = 256
    
class OCRLoss(nn.Module):
    """Text-content loss from a frozen, pretrained deep-text-recognition-benchmark model.

    The recognizer (built from ``opt``) predicts per-step character logits for an
    image batch, and these are compared with the target text by cross entropy.
    Gradients flow through the recognizer into the images; its own weights are frozen.
    """

    def __init__(self, preprocess: bool = True):
        """
        Args:
            preprocess: convert incoming images to the recognizer's input format before
                the forward pass:
                - grayscale when ``input_channel == 1``;
                - resized to imgH x imgW;
                - mapped from [0, 1] to [-1, 1].
                Set to False to feed already-prepared images unchanged.
        """
        super(OCRLoss, self).__init__()
        self.preprocess = preprocess

        self.converter = AttnLabelConverter(opt.character)
        # Per-instance config copy, so the shared ``opt`` class is never mutated.
        self.opt = opt()
        self.opt.num_class = len(self.converter.character)

        if self.opt.rgb:
            self.opt.input_channel = 3 # This breaks loading, input_channel has to be 1, or rgb false

        # Build from the configured *instance*: the ``opt`` class itself still has
        # num_class == 0, which cannot match the checkpoint's classifier.
        self.model = Model(self.opt)
        self.model = torch.nn.DataParallel(self.model).to(device)

        mappings = torch.load(self.opt.saved_model, map_location=device)
        self.model.load_state_dict(mappings)
        self.model.eval()
        # Fixed critic: gradients reach the input images, but the weights never train.
        self.model.requires_grad_(False)

        # Upstream train.py uses ignore_index=0 for attention models: index 0 is the
        # [GO] token, which also pads the targets. ``ignore_index=True`` evaluated to 1
        # (the [s] end-of-sentence token), which must be predicted.
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)

    def _prepare_images(self, image):
        """Map an image batch to the recognizer's channel count, size and value range."""
        if image.size(1) == 3 and self.opt.input_channel == 1:
            # ITU-R BT.601 luma weights (as used by PIL's convert('L')).
            weights = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * weights).sum(dim=1, keepdim=True)
        if image.shape[-2:] != (self.opt.imgH, self.opt.imgW):
            image = nn.functional.interpolate(
                image, size=(self.opt.imgH, self.opt.imgW),
                mode='bilinear', align_corners=False, antialias=True,
            )
        # Assumes inputs are in [0, 1] (like ToTensor output) - TODO confirm for
        # generator outputs. The recognizer was trained on images scaled to [-1, 1].
        return image.sub(0.5).div(0.5)

    def _encode_labels(self, labels):
        """Encode a sequence of strings as class-index targets, one row per string.

        Characters outside the recognizer's alphabet are dropped, after lowercasing
        when the alphabet has no uppercase letters. Labels are truncated to
        batch_max_length, so encoding cannot fail on unseen symbols.
        """
        alphabet = set(self.opt.character)
        lowercase_only = not any(ch.isupper() for ch in alphabet)
        cleaned = []
        for text in labels:
            if lowercase_only:
                text = text.lower()
            cleaned.append(''.join(ch for ch in text if ch in alphabet)[:self.opt.batch_max_length])
        encoded, _ = self.converter.encode(cleaned, batch_max_length=self.opt.batch_max_length)
        # Assumes encode() prepends a [GO] column, as upstream AttnLabelConverter does.
        # Drop it so targets line up with the decoder steps (upstream train.py: text[:, 1:]).
        return encoded[:, 1:]

    def forward(self, image, label):
        """Cross entropy between the recognizer's predictions for ``image`` and ``label``.

        Args:
            image: image batch (B, C, H, W).
            label: either a sequence of B strings (such as the label tuple a DataLoader
                yields) or an already-encoded LongTensor target, which is used as is.

        Returns:
            Scalar loss tensor.
        """
        if self.preprocess:
            image = self._prepare_images(image)
        if not torch.is_tensor(label):
            label = self._encode_labels(label).to(image.device)
        batch_size = image.size(0)
        # Placeholder decoder input of zeros, as upstream demo.py passes at inference.
        text = torch.zeros(batch_size, self.opt.batch_max_length + 1, dtype=torch.long, device=image.device)
        # cuDNN cannot backpropagate through an RNN in eval mode, so use the native
        # kernels here. The inference flag is passed positionally, so the call does not
        # depend on its keyword name in ocr.model.
        with torch.backends.cudnn.flags(enabled=False):
            preds = self.model(image, text, False)
        return self.criterion(preds.reshape(-1, preds.shape[-1]), label.reshape(-1))

# 1. Style Encoder
This is a homebrew version of the ResNet34 architecture. I have no idea whether it works, but I am prepared to alter it in the event that it fails.

In [68]:
# class ResidualBlock(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(ResidualBlock, self).__init__()
#         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3)
#         self.bn1 = nn.BatchNorm2d(out_channels)
#         self.relu = nn.ReLU()
#         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3)
#         self.bn2 = nn.BatchNorm2d(out_channels)
#         
#         self.shortcut = nn.Sequential()
#         if in_channels != out_channels:
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_channels, out_channels, kernel_size=1),
#                 nn.BatchNorm2d(out_channels)
#             )
#         
#     def forward(self, x):
#         out = self.relu(self.bn1(self.conv1(x)))
#         out = self.bn2(self.conv2(out))
#         out += self.shortcut(x)
#         out = self.relu(out)
#         return out

In [69]:
# from torch import nn
# import torchvision
# class StyleEncoder(nn.Module):
#     def __init__(self):
#         super(StyleEncoder, self).__init__()
#         self.layer_stack = nn.Sequential(
#             nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3), #256x256
#             nn.ReLU(inplace=True),
#             nn.BatchNorm2d(32),
#             nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
#             nn.ReLU(inplace=True),
#             nn.BatchNorm2d(64),
#             nn.MaxPool2d(kernel_size=2, stride=2), #128x128
#             self._make_layer(ResidualBlock, 64, 128, stride=2),
#             nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3),
#             nn.ReLU(inplace=True),
#             nn.BatchNorm2d(128),
#             nn.MaxPool2d(kernel_size=2, stride=2), #64x64
#             self._make_layer(ResidualBlock, 128, 256, stride=2),
#             nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3),
#             nn.ReLU(inplace=True),
#             nn.BatchNorm2d(256),
#             nn.MaxPool2d(kernel_size=2, stride=2), #32x32
#             self._make_layer(ResidualBlock, 256, 512, stride=2),
#             nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3),
#             nn.ReLU(inplace=True),
#             nn.BatchNorm2d(512),
#             nn.MaxPool2d(kernel_size=2, stride=2), #16x16
#             self._make_layer(ResidualBlock, 512, 512, stride=2),
#             nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1),
#         )
#         
#     def _make_layer(self, block: ResidualBlock, in_channels, out_channels, num_blocks, stride):
#         strides = [stride] + [1]*(num_blocks-1)
#         layers = []
#         for stride in strides:
#             layers.append(block(in_channels, out_channels))
#             self.in_channels = out_channels
#         return nn.Sequential(*layers)
#     
#     def forward(self, x):
#         return torchvision.ops.roi_align(input=self.layer_stack(x)) #1x1

# 2. Content Encoder

In [70]:
# from torch import nn
# import torchvision
# class ContentEncoder(nn.Module):
#     def __init__(self):
#         super(ContentEncoder, self).__init__()
#         self.layer_stack = nn.Sequential(
#             nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3), #256x256
#             nn.ReLU(inplace=True),
#             nn.BatchNorm2d(32),
#             nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
#             nn.ReLU(inplace=True),
#             nn.BatchNorm2d(64),
#             nn.MaxPool2d(kernel_size=2, stride=2), #128x128
#             self._make_layer(self, ResidualBlock, 64, 128, stride=2),
#             nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3),
#             nn.ReLU(inplace=True),
#             nn.BatchNorm2d(128),
#             nn.MaxPool2d(kernel_size=2, stride=2), #64x64
#             self._make_layer(self, ResidualBlock, 128, 256, stride=2),
#             nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3),
#             nn.ReLU(inplace=True),
#             nn.BatchNorm2d(256),
#             nn.MaxPool2d(kernel_size=2, stride=2), #32x32
#             self._make_layer(self, ResidualBlock, 256, 512, stride=2),
#             nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3),
#             nn.ReLU(inplace=True),
#             nn.BatchNorm2d(512),
#             nn.MaxPool2d(kernel_size=2, stride=2), #16x16
#             self._make_layer(self, ResidualBlock, 512, 512, stride=2),
#             nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1),
#         )
#         
#     def _make_layer(self, block, in_channels, out_channels, num_blocks, stride):
#         strides = [stride] + [1]*(num_blocks-1)
#         layers = []
#         for stride in strides:
#             layers.append(block(in_channels, out_channels, stride))
#             self.in_channels = out_channels
#         return nn.Sequential(*layers)
#     
#     def forward(self, x):
#         return self.layer_stack(x)

In [71]:
from torchvision.ops import roi_align
from torchvision import models
from torchvision.models.resnet import BasicBlock

class ContentEncoder(models.ResNet):
    """ResNet-18 trunk that returns the spatial feature map produced by ``layer4``.

    Follows torchvision's ResNet forward, but stops before the average pool, flatten
    and ``fc`` head, so the output keeps its spatial dimensions. (``avgpool``/``fc`` are
    still constructed by ResNet.__init__ and appear in ``parameters()``, but are unused.)
    """

    def __init__(self):
        # BasicBlock with two blocks per stage is the ResNet-18 layout.
        super().__init__(BasicBlock, [2, 2, 2, 2])

    def _forward_impl(self, x):
        # Stem, then the four residual stages; the classification head is skipped.
        stages = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
        )
        for stage in stages:
            x = stage(x)
        return x


class StyleEncoder(models.ResNet):
    """ResNet-18 whose classification head is replaced by an identity.

    The standard forward (stem, four stages, average pool, flatten) is kept, so the
    output is the pooled feature vector instead of class logits.
    """

    def __init__(self):
        super().__init__(block=BasicBlock, layers=[2, 2, 2, 2])  # ResNet-18 layout
        self.fc = torch.nn.Identity()  # no 1000-way classifier

# 3. Style Mapping Network
- Converts style representations into layer-specific style representations, which are then fed as AdaIN normalization coefficients to each layer of the generator.
- Eliminates the noise-vector input of the standard StyleGAN2.
- Note: the cell below defines the PatchGAN discriminator used for the adversarial loss, not the mapping network.

In [72]:
import functools
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator: a conv stack that outputs a 1-channel map of per-patch scores."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer (later layers use up to 8 * ndf)
            n_layers (int)  -- the number of stride-2 conv layers
            norm_layer      -- callable mapping a channel count to a normalization module
                               (a class, a functools.partial, or e.g. ``lambda c: nn.Identity()``)
        """
        super(NLayerDiscriminator, self).__init__()
        # A conv bias is only redundant when the following norm layer has its own
        # learnable shift (``affine=True``, BatchNorm2d's default). InstanceNorm2d
        # (affine=False by default) and norm-free setups such as ``lambda c: nn.Identity()``
        # keep the bias. A throw-away module is built so that classes, partials and
        # lambdas are all handled the same way.
        probe = norm_layer(ndf * min(2 ** n_layers, 8))
        use_bias = not getattr(probe, 'affine', False)

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # One more stride-1 conv with normalization before the prediction head.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward: (B, input_nc, H, W) -> (B, 1, H', W') patch scores."""
        return self.model(input)

# Training

In [73]:
class StyleGanTrainer:
    """Trains the style-based generator, both encoders and the discriminator.

    For every batch ``(style_img, desired_content, desired_labels, style_content, style_labels)``:

    * ``preds``: ``desired_content`` rendered in the style of ``style_img``;
    * ``reconstructed``: ``style_content`` (the style image's own text) rendered in that
      style; it should reproduce ``style_img``;
    * ``cycled``: ``style_content`` rendered with the style re-extracted from
      ``reconstructed`` (cycle consistency).

    The generator loss is the weighted sum of the terms in ``loss_weights``. The
    discriminator learns to score ``style_img`` as real (1) and ``reconstructed`` as fake (0).
    """

    # Batches containing a longer label are skipped; matches the OCR loss's batch_max_length.
    max_label_length = 25

    # Weights of the generator loss terms. The typeface-embedding term had weight 0
    # and is not computed (no typeface loss module yet).
    loss_weights = {
        "ocr": 0.07,
        "cycle": 2.0,
        "reconstruction": 2.0,
        "perceptual": 25.0,
        "texture": 7.0,
        "adversarial": 0.06,
    }

    def __init__(self,
                 model_G: nn.Module,
                 model_D: nn.Module,
                 style_encoder: nn.Module,
                 content_encoder: nn.Module,
                 optimizer_G: torch.optim.Optimizer,
                 optimizer_D: torch.optim.Optimizer,
                 scheduler_G: torch.optim.lr_scheduler.LRScheduler,
                 scheduler_D: torch.optim.lr_scheduler.LRScheduler,
                 train_dataloader: torch.utils.data.DataLoader,
                 val_dataloader: torch.utils.data.DataLoader,
                 total_epochs: int,
                 ocr_loss: nn.Module,
                 # typeface_loss: nn.Module,
                 perc_loss: nn.Module,
                 cons_loss: nn.Module,
                 adv_loss: nn.Module,
                 device: str):
        # Plain assignments only: a trailing comma (``x = y,``) would store the tuple ``(y,)``.
        self.style_encoder = style_encoder
        self.content_encoder = content_encoder
        self.model_G = model_G
        self.model_D = model_D
        self.optimizer_G = optimizer_G
        self.optimizer_D = optimizer_D
        # Either scheduler may be None; run() then skips it.
        self.scheduler_G = scheduler_G
        self.scheduler_D = scheduler_D
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.device = device
        self.total_epochs = total_epochs
        # self.typeface_loss = typeface_loss.to(device)
        self.ocr_loss = ocr_loss.to(device)
        self.perc_loss = perc_loss.to(device)
        self.cons_loss = cons_loss.to(device)
        self.adv_loss = adv_loss.to(device)

    def set_requires_grad(self, net: nn.Module, requires_grad: bool = False):
        """Enable/disable gradients for every parameter of ``net`` (no-op for None)."""
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

    def model_D_loss(self, style_img: torch.Tensor, content_encodes: torch.Tensor,
                     style_encodes: torch.Tensor, fake_img: torch.Tensor = None):
        """Discriminator loss: mean of adv_loss(D(real), 1) and adv_loss(D(fake), 0).

        ``fake_img`` can be passed to reuse an already generated image. Otherwise it is
        generated from ``content_encodes`` and ``style_encodes``. It is detached, so no
        gradient reaches the generator.
        """
        if fake_img is None:
            fake_img = self.model_G(content_encodes, style_encodes)
        pred_D_fake = self.model_D(fake_img.detach())
        pred_D_real = self.model_D(style_img)
        loss_real = self.adv_loss(pred_D_real, torch.ones_like(pred_D_real))
        loss_fake = self.adv_loss(pred_D_fake, torch.zeros_like(pred_D_fake))
        return (loss_real + loss_fake) / 2.

    def model_G_loss(self, preds: torch.Tensor):
        """Adversarial generator loss: push D's score on ``preds`` towards 1 (real)."""
        pred_D_fake = self.model_D(preds)
        return self.adv_loss(pred_D_fake, torch.ones_like(pred_D_fake))

    def _labels_fit(self, *label_batches):
        """True when every label in every batch fits within ``max_label_length``."""
        return all(len(label) <= self.max_label_length for labels in label_batches for label in labels)

    def _to_device(self, *tensors):
        """Move each tensor to the trainer's device."""
        return [tensor.to(self.device) for tensor in tensors]

    def _generator_losses(self, style_img, desired_content, desired_labels, style_content, style_labels):
        """Run encoders and generator on one batch.

        Returns:
            (weighted total loss, dict of unweighted terms, ``reconstructed`` images).
        """
        style_encodes = self.style_encoder(style_img)
        content_encodes = self.content_encoder(desired_content)
        style_content_encodes = self.content_encoder(style_content)

        preds = self.model_G(content_encodes, style_encodes)
        reconstructed = self.model_G(style_content_encodes, style_encodes)
        cycled = self.model_G(style_content_encodes, self.style_encoder(reconstructed))

        terms = {
            "ocr": (self.ocr_loss(preds, desired_labels) + self.ocr_loss(reconstructed, style_labels)) / 2.,
            "cycle": self.cons_loss(style_img, cycled),
            "reconstruction": self.cons_loss(style_img, reconstructed),
            # The perceptual loss module returns one scalar, so the feature (perceptual)
            # and Gram (texture) parts are requested separately to weight them independently.
            "perceptual": self.perc_loss(style_img, reconstructed, feature_layers=[2], style_layers=[]),
            "texture": self.perc_loss(style_img, reconstructed, feature_layers=[], style_layers=[0, 1, 2, 3]),
            "adversarial": self.model_G_loss(reconstructed),
        }
        total = sum(self.loss_weights[name] * value for name, value in terms.items())
        return total, terms, reconstructed

    @staticmethod
    def _accumulate(totals, terms, loss_G, loss_D):
        """Add this batch's loss values (as floats) to the running ``totals``."""
        for name, value in (*terms.items(), ("G", loss_G), ("D", loss_D)):
            totals[name] = totals.get(name, 0.0) + float(value)

    @staticmethod
    def _mean(totals, num_batches):
        """Per-term averages, or an empty dict if no batch was processed."""
        return {name: total / num_batches for name, total in totals.items()} if num_batches else {}

    @staticmethod
    def _format(losses):
        """Compact one-line rendering of a loss dict."""
        return ", ".join(f"{name}={value:.4f}" for name, value in losses.items()) or "no batches"

    def train(self):
        """Run one epoch over ``train_dataloader``; return the mean of each loss term."""
        print("Start training...")
        for module in (self.model_G, self.model_D, self.style_encoder, self.content_encoder):
            module.train()

        totals, num_batches = {}, 0
        for style_img, desired_content, desired_labels, style_content, style_labels in self.train_dataloader:
            if not self._labels_fit(desired_labels, style_labels):
                continue
            style_img, desired_content, style_content = self._to_device(style_img, desired_content, style_content)

            self.optimizer_G.zero_grad()
            self.optimizer_D.zero_grad()

            # Generator/encoder update. D is frozen so the adversarial term only produces
            # gradients for G and the encoders.
            self.set_requires_grad(self.model_D, False)
            loss_G, terms, reconstructed = self._generator_losses(
                style_img, desired_content, desired_labels, style_content, style_labels)
            loss_G.backward()
            self.optimizer_G.step()

            # Discriminator update, using this step's reconstruction as the fake sample.
            self.set_requires_grad(self.model_D, True)
            loss_D = self.model_D_loss(style_img, None, None, fake_img=reconstructed)
            loss_D.backward()
            self.optimizer_D.step()

            self._accumulate(totals, terms, loss_G, loss_D)
            num_batches += 1
        return self._mean(totals, num_batches)

    def validate(self, epoch: int):
        """Evaluate on ``val_dataloader`` without updating weights; return mean losses.

        ``epoch`` is kept for interface compatibility and is not used.
        """
        print("Start validating...")
        for module in (self.model_G, self.model_D, self.style_encoder, self.content_encoder):
            module.eval()

        totals, num_batches = {}, 0
        with torch.no_grad():
            for style_img, desired_content, desired_labels, style_content, style_labels in self.val_dataloader:
                if not self._labels_fit(desired_labels, style_labels):
                    continue
                style_img, desired_content, style_content = self._to_device(style_img, desired_content, style_content)

                loss_G, terms, reconstructed = self._generator_losses(
                    style_img, desired_content, desired_labels, style_content, style_labels)
                loss_D = self.model_D_loss(style_img, None, None, fake_img=reconstructed)

                self._accumulate(totals, terms, loss_G, loss_D)
                num_batches += 1
        return self._mean(totals, num_batches)

    def run(self):
        """Train for ``total_epochs`` epochs, validating after each one.

        Returns:
            A list with one ``{"epoch", "train", "val"}`` dict of mean losses per epoch.
        """
        history = []
        for epoch in range(self.total_epochs):
            train_losses = self.train()
            with torch.no_grad():
                val_losses = self.validate(epoch)
            if self.scheduler_G is not None:
                self.scheduler_G.step()
            if self.scheduler_D is not None:
                self.scheduler_D.step()
            print(f"Epoch {epoch + 1}/{self.total_epochs} | train: {self._format(train_losses)} | val: {self._format(val_losses)}")
            history.append({"epoch": epoch, "train": train_losses, "val": val_losses})
        return history

In [74]:
from stylegan import StyleBased_Generator

device = "cuda" if torch.cuda.is_available() else "cpu"
total_epochs = 500
batch_size = 16

# Each split directory holds its *.jpg images plus a words.json label file. Joining
# with Path avoids turning an unset ('') directory into the filesystem-root path
# '/words.json'.
train_dataloader = DataLoader(ImageDataset(train_dir, Path(train_dir) / 'words.json'), shuffle=True, batch_size=batch_size)
val_dataloader = DataLoader(ImageDataset(val_dir, Path(val_dir) / 'words.json'), batch_size=batch_size)

model_G = StyleBased_Generator(dim_latent=512)
model_G.to(device)
style_encoder = StyleEncoder().to(device)
content_encoder = ContentEncoder().to(device)
# norm_layer returns Identity, i.e. the discriminator has no normalization layers.
model_D = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, norm_layer=(lambda x: torch.nn.Identity()))
model_D.to(device)

# The generator and BOTH encoders are optimized together, each parameter listed once.
optimizer_G = torch.optim.AdamW(
    list(model_G.parameters()) +
    list(style_encoder.parameters()) +
    list(content_encoder.parameters()),
    lr=1e-3,
    weight_decay=1e-6
)
# NOTE(review): the schedulers are stepped once per epoch. With gamma=0.9 the learning
# rate is multiplied by 0.9**500 (about 1e-23) over 500 epochs, so it is negligible after
# ~100 epochs. Confirm this decay is intended.
scheduler_G = torch.optim.lr_scheduler.ExponentialLR(
    optimizer_G,
    gamma=0.9
)
optimizer_D = torch.optim.AdamW(model_D.parameters(), lr=1e-4)
scheduler_D = torch.optim.lr_scheduler.ExponentialLR(
    optimizer_D,
    gamma=0.9
)

trainer = StyleGanTrainer(
    model_G,
    model_D,
    style_encoder,
    content_encoder,
    optimizer_G,
    optimizer_D,
    scheduler_G,
    scheduler_D,
    train_dataloader,
    val_dataloader,
    total_epochs=total_epochs,
    ocr_loss=OCRLoss(),
    perc_loss=VGGPerceptualLoss(),
    cons_loss=torch.nn.L1Loss(),
    adv_loss=torch.nn.MSELoss(),
    device=device,
)

history = trainer.run()

C:\Users\Ryan Chin\Documents\Projects\AI_dataset\cropped\train
C:\Users\Ryan Chin\Documents\Projects\AI_dataset\cropped\train




FileNotFoundError: [Errno 2] No such file or directory: 'ocr/TPS-ResNet-BiLSTM-Attn.pth'