In [1]:
import os
# Move to the repository root so the project-relative imports (`Models`, `Losses`)
# and the `./pretrained/...` paths used below resolve.
# Guarded so that re-running this cell does not climb one directory further each time.
# NOTE(review): assumes the `Models` package directory sits at the repository root — confirm.
if not os.path.isdir('Models'):
    os.chdir('../')

In [2]:
from Losses.AdversarialLoss import calc_Dw_loss
from Models.Encoders.ID_Encoder import ID_Encoder
from Models.Encoders.Attribute_Encoder import Encoder_Attribute
from Models.Discrimanator import Discriminator
from Models.LatentMapper import LatentMapper
import torch
import torch.utils.data
import torchvision.datasets as dset
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import seaborn as sns
# Apply seaborn's default theme to all matplotlib figures drawn after this point.
sns.set()


In [3]:
def plot_w_image(w):
    """Show the image the global ``generator`` synthesises from a single latent.

    Args:
        w: latent tensor for one sample. A leading batch axis is added here and
            the result is passed to the generator with ``input_is_latent=True``.
            The exact shape the generator expects is not visible in this
            notebook — TODO confirm.

    Returns:
        None. The image is displayed with matplotlib.
    """
    w = w.unsqueeze(0).cuda()
    # Pure visualisation: no autograd graph is needed for the synthesis pass.
    with torch.no_grad():
        sample, _ = generator(
            [w], input_is_latent=True, return_latents=True
        )
    # First (only) sample of the batch, channels moved last for imshow.
    new_image = sample[0].detach().cpu().numpy().transpose(1, 2, 0)
    # Rescale assuming the generator output is roughly in [-1, 1]; clip so that
    # imshow receives values inside its valid [0, 1] float range.
    new_image = ((new_image + 1) / 2).clip(0, 1)
    plt.axis('off')
    plt.imshow(new_image)
    plt.show()

In [4]:
import gc
import sys

# Collect unreachable Python objects first so that the CUDA memory they held
# can actually be released by empty_cache().
gc.collect()
torch.cuda.empty_cache()

sys.path.append(".")
sys.path.append("..")
from Models.StyleGan2.model import Generator

# NOTE(review): arguments are presumably (image size, style dim, mapping layers)
# — confirm against Models/StyleGan2/model.py.
generator = Generator(1024,512,8).cuda()
# Deserialize on the CPU: only the 'g_ema' entry is used below, so nothing else
# stored in the checkpoint should occupy GPU memory. load_state_dict copies the
# weights into the generator's CUDA parameters.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load('./pretrained/800000.pt', map_location='cpu')
# strict=False tolerates missing/unexpected keys instead of raising.
generator.load_state_dict(state_dict['g_ema'], strict=False)
generator = generator.eval()

In [5]:
# Instantiate the four networks and move them to the GPU.
# NOTE(review): `discriminator` and `mlp` are instantiated again in a later cell
# (right before the optimizers are built), which replaces the instances created here.
E_id = ID_Encoder().cuda()
E_att = Encoder_Attribute().cuda()
discriminator = Discriminator().cuda()
mlp = LatentMapper().cuda()



In [6]:
# Encoders are switched to eval mode; the discriminator and the latent mapper
# are switched to train mode.
E_id = E_id.eval()
E_att = E_att.eval()
discriminator = discriminator.train()
mlp = mlp.train()

In [7]:
def get_w_by_index(idx, root_dir = r"./pretrained/Dataset/small_w/0/"):
    """Load one saved latent and return it as a tensor.

    The file is looked up at ``<root_dir>/<idx // 1000>/<idx>.npy``, i.e. the
    latents are grouped into sub-directories of 1000 consecutive indices.

    Args:
        idx: sample index, either a Python int or a tensor holding one.
        root_dir: directory that contains the numbered sub-directories.

    Returns:
        A new ``torch.Tensor`` holding the contents of the ``.npy`` file.
    """
    index = idx.tolist() if torch.is_tensor(idx) else idx
    bucket = str(index // 1000)
    latent = np.load(os.path.join(root_dir, bucket, f"{index}.npy"))
    return torch.tensor(latent)

In [8]:
class WDataSet(Dataset):
    """Dataset of latent vectors stored as ``.npy`` files under ``root_dir``.

    Items are loaded lazily by ``get_w_by_index``.
    """

    def __init__(self, root_dir, length=6999):
        """
        Args:
            root_dir (string): Directory with all the w's.
            length (int): Number of samples the dataset reports. Defaults to
                6999, the value that was previously hard-coded in ``__len__``.

        Raises:
            ValueError: if ``length`` is negative.
        """
        if length < 0:
            raise ValueError(f"length must be non-negative, got {length}")
        self.root_dir = root_dir
        self.length = length

    def __len__(self):
        # Configurable instead of hard-coded; see `length` in __init__.
        return self.length

    def __getitem__(self, idx):
        return get_w_by_index(idx, self.root_dir)

In [9]:
class ConcatDataset(torch.utils.data.Dataset):
    """Zip several indexable datasets together.

    Item ``i`` is the tuple of every wrapped dataset's item ``i``; the length
    is that of the shortest wrapped dataset.
    """

    def __init__(self, datasets):
        self.datasets = datasets

    def __len__(self):
        return min(map(len, self.datasets))

    def __getitem__(self, i):
        return tuple([source[i] for source in self.datasets])

In [10]:
data_dir = r"./pretrained/Dataset/image_med_res/"


def _make_image_dataset(root):
    """Build an ImageFolder over ``root`` with the shared 299x299 preprocessing."""
    return dset.ImageFolder(root=root,
                            transform=transforms.Compose([
                                transforms.Resize(299),
                                transforms.CenterCrop(299),
                                transforms.ToTensor(),
                                # The usual ImageNet channel statistics.
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                            ]))


# The attribute and identity branches read the same folder with identical
# preprocessing; they are kept as two separate dataset objects as before.
attr_dataset = _make_image_dataset(data_dir)
id_dataset = _make_image_dataset(data_dir)

# NOTE(review): ImageFolder orders samples by sorted class/file name, while the
# w files are addressed by integer index — verify that image i and w i line up
# if any later loss relies on that pairing.
w_dataset = WDataSet(r"./pretrained/Dataset/w_med_res/")

In [11]:
def make_concat_loaders(batch_size, datasets):
    """Wrap ``datasets`` in a ``ConcatDataset`` and return a shuffling DataLoader.

    Args:
        batch_size: number of zipped samples per batch.
        datasets: sequence of datasets to be zipped index-by-index.

    Returns:
        A ``torch.utils.data.DataLoader`` over the zipped dataset with
        ``shuffle=True``.
    """
    return torch.utils.data.DataLoader(
        dataset=ConcatDataset(datasets),
        batch_size=batch_size,
        shuffle=True,
    )

In [12]:
# Hyper-parameters shared by the cells below.
config = dict(
    # Adam betas used for both optimizers.
    beta1=0.5,
    beta2=0.999,
    # Learning rates: discriminator, latent mapper.
    lrD=4e-4,
    lrMLP=3e-5,
    # Not referenced in the cells shown here.
    lrAttr=1e-4,
    # Every N-th batch the attribute images are shifted within the batch. # 1/3
    IdDiffersAttrTrainRatio=3,
    # DataLoader batch size.
    batchSize=8,
    # The remaining entries are not referenced in the cells shown here.
    R1Param=10,
    lambdaID=1,
    lambdaLND=1,
    lambdaREC=1,
)

In [13]:
# One shuffled loader over the zipped datasets; each batch is unpacked by the
# cells below as (id images, attribute images, w latents), in that order.
train_loader = make_concat_loaders(config['batchSize'],(id_dataset, attr_dataset,w_dataset))

In [14]:
# Fresh discriminator and latent mapper (these replace any instances created
# earlier), each with its own Adam optimizer configured from `config`.
discriminator = Discriminator().cuda()
mlp = LatentMapper().cuda()
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=config['lrD'], betas=(config['beta1'], config['beta2']))
optimizerMLP = torch.optim.Adam(mlp.parameters(), lr=config['lrMLP'], betas=(config['beta1'], config['beta2']))

In [15]:
# NOTE(review): identical to the optimizer construction in the previous cell.
# Running this cell replaces both optimizers with new Adam instances (any
# accumulated optimizer state is discarded) without re-creating the models.
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=config['lrD'], betas=(config['beta1'], config['beta2']))
optimizerMLP = torch.optim.Adam(mlp.parameters(), lr=config['lrMLP'], betas=(config['beta1'], config['beta2']))

In [16]:
def train_discriminator(optimizer, real_w, generated_w):
    """Run one update of the global ``discriminator``.

    Gradients from the real batch (scored against target 1) and from a
    detached copy of the generated batch (scored against target 0) are
    accumulated, then a single optimizer step is taken.

    Args:
        optimizer: optimizer to zero and step.
        real_w: batch of real latents.
        generated_w: batch of generated latents; cloned and detached here, so
            no gradient reaches whatever produced it.

    Returns:
        ``(error_real, prediction_real, error_fake, prediction_fake)`` with the
        predictions flattened to 1-D.
    """
    optimizer.zero_grad()

    # Real half: forward, loss, backward.
    real_scores = discriminator(real_w).view(-1)
    real_loss = calc_Dw_loss(real_scores, 1)
    real_loss.backward()

    # Fake half: same steps on a copy cut off from its producer's graph.
    fake_scores = discriminator(generated_w.clone().detach()).view(-1)
    fake_loss = calc_Dw_loss(fake_scores, 0)
    fake_loss.backward()

    # Single step over the gradients accumulated from both halves.
    optimizer.step()

    return real_loss, real_scores, fake_loss, fake_scores

In [17]:
def train_mapper(optimizer, generated_w):
    """Run one adversarial update using the global ``discriminator`` as critic.

    The generated latents are scored by the discriminator and that score is
    evaluated against target 1; the loss is back-propagated and only
    ``optimizer`` is stepped.

    Args:
        optimizer: optimizer to zero and step.
        generated_w: batch of generated latents, still attached to the graph
            of the network that produced them.

    Returns:
        ``(error, prediction)`` with the prediction flattened to 1-D.
    """
    optimizer.zero_grad()

    scores = discriminator(generated_w).view(-1)
    loss = calc_Dw_loss(scores, 1)
    loss.backward()

    optimizer.step()
    return loss, scores

In [18]:
# Build one fixed conditioning batch (`test_vec`) from the first batch of the
# loader; it is used later to preview the mapper's output during training.
id_images, attr_images, ws = next(iter(train_loader))

torch.cuda.empty_cache()
# ImageFolder batches are (images, labels); keep the images only.
id_images = id_images[0].cuda()
attr_images = attr_images[0].cuda()
ws_single = ws.cuda()

# The first batch has index 0, so the "different attribute image" case always
# applies here: shift the attribute images by one position within the batch.
# torch.roll does this for any batch size instead of assuming exactly 8 samples.
attr_images = torch.roll(attr_images, shifts=1, dims=0)

with torch.no_grad():
  id_vec = E_id(id_images)
  attr_vec = E_att(attr_images)
  # different image to id and attr
  id_vec = torch.squeeze(id_vec)
  attr_vec = torch.squeeze(attr_vec)
  encoded_vec = torch.cat((id_vec,attr_vec), dim=1)
test_vec = encoded_vec

In [19]:
# Per-iteration loss histories: appended to by the training loop below and
# plotted in the final cell.
MLP_losses = []
D_losses = []

# Training only the mapper and discriminator

In [20]:
####### Discriminator back pass #######
# Adversarial training of the latent mapper against the discriminator: for each
# batch the discriminator is updated first, then the mapper.
#
# NOTE(review): the recorded run of this cell aborted with a CUDA device-side
# assert (`input_val >= zero && input_val <= one`) raised from a loss kernel.
# Presumably calc_Dw_loss passes the discriminator output to a loss that needs
# values in [0, 1]; check for NaNs or an un-squashed discriminator output —
# confirm against Losses/AdversarialLoss.py.
epochs = 4
for epoch in range(epochs):
  for idx, data in enumerate(train_loader):

    id_images, attr_images, ws = data

    torch.cuda.empty_cache()
    # ImageFolder batches are (images, labels); keep the images only.
    id_images = id_images[0].cuda()
    attr_images = attr_images[0].cuda()
    ws = ws.cuda()

    if idx % config['IdDiffersAttrTrainRatio'] == 0:
      # Pair each identity image with another sample's attribute image by
      # shifting the attribute batch one position. torch.roll works for any
      # batch size, so a short final batch no longer raises an IndexError.
      attr_images = torch.roll(attr_images, shifts=1, dims=0)

    # Only the mapper and the discriminator have optimizers in this notebook,
    # so both encoders run without building an autograd graph.
    with torch.no_grad():
      id_vec = E_id(id_images)
      attr_vec = E_att(attr_images)

    # NOTE(review): torch.squeeze drops every size-1 axis, so a batch holding a
    # single sample would also lose its batch axis and break the cat below.
    id_vec = torch.squeeze(id_vec)
    attr_vec = torch.squeeze(attr_vec)
    encoded_vec = torch.cat((id_vec,attr_vec), dim=1)

    fake_data = mlp(encoded_vec)
    error_real, prediction_real, error_fake, prediction_fake = train_discriminator(optimizerD, ws, fake_data)
    print(f"\n error_real: {error_real}, error_fake: {error_fake} \n prediction_real: {torch.mean(prediction_real)}, prediction_fake: {torch.mean(prediction_fake)}")
    g_error, g_pred = train_mapper(optimizerMLP, fake_data)
    print(f"\n g_error: {g_error}, g_pred: {torch.mean(g_pred)}")

    # Store plain Python floats: keeping the loss tensors would hold on to GPU
    # memory and autograd history, and matplotlib cannot plot them.
    MLP_losses.append(g_error.item())
    D_losses.append(((error_real + error_fake) / 2).item())

    if idx % 5 == 0:
      # Preview the mapper's output for the first sample of the fixed test batch.
      with torch.no_grad():
        plot_w_image(mlp(test_vec)[0])



/opt/conda/conda-bld/pytorch_1666642975993/work/aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [0,0,0], thread: [0,0,0] Assertion `input_val >= zero && input_val <= one` failed.
/opt/conda/conda-bld/pytorch_1666642975993/work/aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [0,0,0], thread: [2,0,0] Assertion `input_val >= zero && input_val <= one` failed.
/opt/conda/conda-bld/pytorch_1666642975993/work/aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [0,0,0], thread: [4,0,0] Assertion `input_val >= zero && input_val <= one` failed.
/opt/conda/conda-bld/pytorch_1666642975993/work/aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [0,0,0], thread: [6,0,0] Assertion `input_val >= zero && input_val <= one` failed.
/opt/conda/conda-bld/pytorch_1666642975993/work/aten/src/ATen/native/cuda/Loss.cu:92: operator(): block: [0,0,0], thread: [7,0,0] Assertion `input_val >= zero && input_val <= one` failed.


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
# Loss curves of the latent mapper and the discriminator, one point per iteration.
plt.figure(figsize=(10,5))
plt.title("Mapper and Discriminator Loss During Training")
# Entries are converted to plain floats first: matplotlib cannot convert CUDA
# or autograd-tracked tensors, and float() also accepts values that already
# are Python numbers.
plt.plot([float(loss) for loss in MLP_losses],label="MLP")
plt.plot([float(loss) for loss in D_losses],label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()