In [1]:
import os
from tqdm.auto import tqdm
import torch
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from model import PretrainedGenerator, ResnetEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The base class is spelled out in full because this class rebinds the name
# `Dataset`; with `class Dataset(Dataset)` re-running the cell would subclass
# the previous definition of itself instead of the PyTorch base class.
class Dataset(torch.utils.data.Dataset):
    """Paired low-quality / high-quality image dataset backed by two directories.

    Item ``i`` is built from the ``i``-th entry of each *sorted* directory
    listing, so the two directories are expected to contain the same number of
    files whose names sort into matching order (e.g. identical file names).

    Args:
        hq_path: Directory containing the high-quality images.
        lq_path: Directory containing the low-quality images.
        lq_transforms: Callable applied to each low-quality PIL image.
        hq_transforms: Callable applied to each high-quality PIL image.

    Raises:
        ValueError: If the two directories contain a different number of entries.
    """

    def __init__(self, hq_path, lq_path, lq_transforms, hq_transforms):
        # `super(Dataset).__init__()` only built an unbound super object and
        # never reached the base class initializer.
        super().__init__()
        self.hq_path = hq_path
        self.lq_path = lq_path
        # os.listdir returns entries in arbitrary order; sorting both listings
        # makes index i refer to the same position in each directory.
        self.hq = sorted(os.listdir(hq_path))
        self.lq = sorted(os.listdir(lq_path))
        if len(self.hq) != len(self.lq):
            raise ValueError(
                f"Cannot pair images: {hq_path!r} has {len(self.hq)} entries "
                f"but {lq_path!r} has {len(self.lq)}"
            )
        self.lq_transforms = lq_transforms
        self.hq_transforms = hq_transforms

    def __len__(self):
        """Return the number of (low-quality, high-quality) pairs."""
        return len(self.hq)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(lq_image, hq_image)`` after transforms."""
        # `convert` returns a new image, so the file handle opened by
        # Image.open can be released as soon as the conversion is done.
        with Image.open(os.path.join(self.lq_path, self.lq[idx])) as lq_file:
            lq_image = lq_file.convert('RGB')
        with Image.open(os.path.join(self.hq_path, self.hq[idx])) as hq_file:
            hq_image = hq_file.convert('RGB')
        hq_image = self.hq_transforms(hq_image)
        lq_image = self.lq_transforms(lq_image)
        return lq_image, hq_image


# Low-quality images are only converted to tensors and normalized with a
# per-channel mean and std of 0.5.
lq_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# High-quality images are first resized to 256x256 and then go through the
# same tensor conversion and normalization steps as the low-quality ones.
hq_transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    *lq_transform.transforms,
])

# Paired training set and its shuffled loader (batches of 4, 4 worker processes).
train_ds = Dataset(
    hq_path="hq_train_images",
    lq_path="lq_train_images",
    lq_transforms=lq_transform,
    hq_transforms=hq_transform,
)
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4)

In [3]:
# Autograd anomaly detection is a debugging aid: it adds extra checks to the
# forward and backward passes and noticeably slows training. Keep it off for
# real runs and flip this flag only while tracking down NaNs or a failing
# backward.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f5e615ddd90>

In [None]:
# Resizes generator outputs to 256x256 before they are compared with the inputs.
downscale = transforms.Resize(size=(256, 256))


def l_rep_loss(hq_image, lq_image, encoder_hq, encoder_lq, generator):
    """Return the summed L1 reconstruction loss of the HQ and LQ branches.

    Each image batch is encoded by its own encoder, decoded by the shared
    generator, resized with ``downscale`` and compared to the original batch
    with an L1 loss; the two losses are added.

    Args:
        hq_image: High-quality image batch fed to ``encoder_hq``.
        lq_image: Low-quality image batch fed to ``encoder_lq``.
        encoder_hq: Callable mapping ``hq_image`` to a code for ``generator``.
        encoder_lq: Callable mapping ``lq_image`` to a code for ``generator``.
        generator: Callable mapping a code to an image batch.

    Returns:
        ``L1(hq_image, hq_reconstruction) + L1(lq_image, lq_reconstruction)``.
    """
    l1 = nn.L1Loss()

    # Encode both batches first (HQ, then LQ), then decode in the same order.
    codes = (encoder_hq(hq_image), encoder_lq(lq_image))
    hq_recon, lq_recon = [downscale(generator(code)) for code in codes]

    return l1(hq_image, hq_recon) + l1(lq_image, lq_recon)

# --- Models -----------------------------------------------------------------
# Two separately trained encoders (HQ first, then LQ) and one generator loaded
# from a checkpoint, all moved to the GPU.
encoder_hq, encoder_lq = ResnetEncoder().cuda(), ResnetEncoder().cuda()
generator = PretrainedGenerator(model_path="stylegan2_pytorch/G.pth").cuda()

# The generator stays frozen: no gradients for its parameters, eval mode.
for param in generator.parameters():
    param.requires_grad_(False)
generator.eval()

# Only the encoders are trained.
encoder_hq.train()
encoder_lq.train()

# --- Optimizer --------------------------------------------------------------
# A single Adam instance updates the parameters of both encoders.
optimizer = optim.Adam(
    [*encoder_hq.parameters(), *encoder_lq.parameters()],
    lr=1e-4,
)

# --- Training loop ----------------------------------------------------------
num_epochs = 6
for epoch in range(num_epochs):
    running_loss = 0.0  # sum of per-batch losses for this epoch

    for lq_image, hq_image in tqdm(train_dl, desc=f"Epoch {epoch+1}/{num_epochs}"):
        lq_image = lq_image.cuda()
        hq_image = hq_image.cuda()

        # Fresh gradients -> forward pass -> backward pass -> parameter update.
        optimizer.zero_grad()
        loss = l_rep_loss(hq_image, lq_image, encoder_hq, encoder_lq, generator)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean batch loss of the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_dl):.4f}")

  return _deserialize(torch.load(fpath, map_location=map_location))
Epoch 1/6:   1%|▏         | 27/1873 [14:29<16:25:51, 32.04s/it]

In [None]:
# Persist the weights of both encoders (HQ first, then LQ) as state dicts.
for checkpoint_name, trained_encoder in (
    ("encoder_hq.pt", encoder_hq),
    ("encoder_lq.pt", encoder_lq),
):
    torch.save(trained_encoder.state_dict(), checkpoint_name)