In [6]:
%load_ext autoreload
%autoreload 2

# import cv2
import sys
from IPython.display import display, Markdown
import plotly.express as px
from tensorboardX import SummaryWriter

import torch
from torch.profiler import profile, record_function, ProfilerActivity
from torchvision.io import read_image

sys.path.append('../../')
sys.path.append('../')
from scikitools.deep_learning.dataload import LocalImageLoader
from src.celeb_a import CELEBADiscriminator, CELEBAGenerator
from src.generator import (
    generate_random_image,
    generate_random_seed
)

# Pick the compute device from the CUDA availability check (its result used to
# be computed and thrown away): GPU when present, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# TensorBoard event writer; scalars are logged under the "celebA-GAN-log" directory.
writer = SummaryWriter("celebA-GAN-log")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Build the image dataset from a local directory (the path points at the
# CelebA "img_align_celeba" folder).
# NOTE(review): the deeply nested relative path only resolves from this
# notebook's working directory — consider a configurable data-dir constant.
celeb_datasets = LocalImageLoader(
    '../../../../../datasets/celeba/img_align_celeba/img_align_celeba',
)

In [3]:
# Visual sanity check: plot the dataset entry at index 0.
celeb_datasets.plot_image(0)

In [4]:
# Take the image tensor of the first dataset item (element 0 of the item) and
# show it flattened to a 1-D vector of 218 * 178 * 3 values.
first_image = celeb_datasets[0][0]
first_image.contiguous().view(218 * 178 * 3)

tensor([253, 253, 253,  ...,  20,  24,  24], dtype=torch.uint8)

## 测试鉴别器

In [26]:
# Train the discriminator on its own: dataset images are labelled 1.0 and
# random-noise images of the same (218, 178, 3) shape are labelled 0.0.
# NOTE: `d.train(inputs, targets)` is the project's training-step method, not
# the stock `nn.Module.train(mode)` switch.
d = CELEBADiscriminator()
d.to(device)

# The targets never change, so create them once directly on the device instead
# of building and copying a new tensor for every sample.
real_label = torch.tensor([1.0], dtype=torch.float32, device=device)
fake_label = torch.tensor([0.0], dtype=torch.float32, device=device)

for image_data_tensor, _ in celeb_datasets:
    # Images arrive as uint8; convert to float and scale by 1/255.
    d.train(image_data_tensor.to(device).float() / 255., real_label)
    d.train(generate_random_image((218, 178, 3)).to(device), fake_label)

counter = 10000
counter = 20000
counter = 30000
counter = 40000
counter = 50000
counter = 60000
counter = 70000
counter = 80000
counter = 90000
counter = 100000
counter = 110000
counter = 120000
counter = 130000
counter = 140000
counter = 150000
counter = 160000
counter = 170000
counter = 180000
counter = 190000
counter = 200000
counter = 210000
counter = 220000
counter = 230000
counter = 240000
counter = 250000
counter = 260000
counter = 270000
counter = 280000
counter = 290000
counter = 300000
counter = 310000
counter = 320000
counter = 330000
counter = 340000
counter = 350000
counter = 360000
counter = 370000
counter = 380000
counter = 390000
counter = 400000


In [27]:
# Plot the progress recorded by the discriminator during the training run above.
d.plot_progress()

## profiling model

In [25]:
# Profile two discriminator training steps (one dataset image, one random-noise
# image) on CPU and CUDA, then render the ten rows with the highest total CUDA
# time.
# NOTE(review): depends on `d` and `image_data_tensor` left in the kernel by an
# earlier cell.
traced_activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

with profile(activities=traced_activities, record_shapes=True) as prof:
    with record_function("model_inference"):
        # Step 1: dataset image, target 1.0.
        real_input = image_data_tensor.cuda().float()/255.
        real_target = torch.cuda.FloatTensor([1.0])
        d.train(real_input, real_target)

        # Step 2: random-noise image, target 0.0.
        noise_input = generate_random_image((218, 178, 3)).cuda()
        noise_target = torch.FloatTensor([0.0]).cuda()
        d.train(noise_input, noise_target)

profile_table = prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)
display(Markdown(profile_table))


CUPTI tracing is not available, falling back to legacy CUDA profiling



------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                           model_inference        20.92%       2.967ms        89.25%      12.656ms      12.656ms       2.966ms        15.69%      16.039ms      16.039ms             1  
                  Optimizer.step#Adam.step         6.16%     874.000us        16.16%       2.292ms       1.146ms     121.000us         0.64%       8.523ms       4.261ms             2  
                                aten::add_         3.20%     454.000us         3.20%     454.000us      11.350us       2.973ms        15.72%       2.973ms      74.325us            40  
                                  aten::to         0.58%      82.000us        29.96%       4.249ms     708.167us      17.000us         0.09%       2.086ms     347.667us             6  
                               aten::copy_        29.28%       4.152ms        29.28%       4.152ms     519.000us       2.077ms        10.98%       2.077ms     259.625us             8  
                            aten::addcmul_         0.56%      79.000us         1.65%     234.000us      19.500us      23.000us         0.12%       1.661ms     138.417us            12  
                             aten::addcmul         1.09%     155.000us         1.09%     155.000us      12.917us       1.638ms         8.66%       1.638ms     136.500us            12  
                                aten::mul_         2.12%     301.000us         2.12%     301.000us      12.542us       1.530ms         8.09%       1.530ms      63.750us            24  
                            aten::addcdiv_         0.68%      97.000us         1.66%     236.000us      19.667us      24.000us         0.13%       1.488ms     124.000us            12  
                             aten::addcdiv         0.98%     139.000us         0.98%     139.000us      11.583us       1.464ms         7.74%       1.464ms     122.000us            12  
------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 14.181ms
Self CUDA time total: 18.909ms


## 测试生成器

In [56]:
# Smoke-test an untrained generator: feed it a random seed of size 100 and
# display the produced image.
g = CELEBAGenerator()
g.to(device)

# Inference only, so skip autograd bookkeeping. Call the module itself rather
# than `.forward` so that any registered hooks run
# (assumes CELEBAGenerator is an nn.Module, as its `.to(device)` use suggests).
with torch.no_grad():
    output = g(generate_random_seed(100).to(device))
img = output.detach().cpu().numpy()
# Scale by 255 for display.
px.imshow(img * 255.)

## 训练GAN的总代码

In [5]:
# Full GAN training loop. For every dataset image:
#   1. train the discriminator on the real image (target 1.0),
#   2. train the discriminator on a generated image (target 0.0),
#   3. train the generator through the discriminator (target 1.0),
#   4. log both losses to TensorBoard.
# NOTE: `d.train` / `g.train` are the project's training-step methods, not the
# stock `nn.Module.train(mode)` switch.
d = CELEBADiscriminator()
g = CELEBAGenerator()

d.to(device)
g.to(device)

epochs = 1

# The targets are constant, so create them once directly on the device instead
# of building and copying new tensors for every sample.
real_label = torch.tensor([1.0], dtype=torch.float32, device=device)
fake_label = torch.tensor([0.0], dtype=torch.float32, device=device)

for epoch in range(epochs):
    print(f"epcho = {epoch+1}")
    for image_data_tensor, _ in celeb_datasets:
        # Real sample: convert to float and scale by 1/255.
        d.train(image_data_tensor.to(device).float() / 255., real_label)
        # Fake sample: detach so this step does not backpropagate into the generator.
        d.train(g(generate_random_seed(100).to(device)).detach(), fake_label)
        # Generator step: it is rewarded when the discriminator outputs 1.0.
        g.train(d, generate_random_seed(100).to(device), real_label)
        # Both curves use the generator's step counter as the x-axis.
        writer.add_scalar('Discriminator/Loss', d.temp_loss, g.counter)
        writer.add_scalar('Generator/loss', g.temp_loss, g.counter)

epcho = 1
counter = 10000
counter = 20000
counter = 30000
counter = 40000
counter = 50000
counter = 60000
counter = 70000
counter = 80000
counter = 90000
counter = 100000
counter = 110000
counter = 120000
counter = 130000
counter = 140000


KeyboardInterrupt: 