In [1]:
import numpy as np
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import imageio 
import glob

In [1]:
# pip install torchvision==0.14.1

In [2]:
# pip install tensorboard==2.11.1

In [3]:
import datetime
import time
import torchvision
from torch.utils.tensorboard import SummaryWriter
from data.dataloader import create_dataloader
from models.pose_transfer_model import PoseTransferModel

In [4]:
import datetime
from torch.utils.tensorboard import SummaryWriter

In [5]:
# configurations
# -----------------------------------------------------------------------------
# Project root. Defaults to the original SageMaker checkout; set the
# POSE_TRANSFER_ROOT environment variable to run from another location
# without editing the notebook.
root_path = os.environ.get('POSE_TRANSFER_ROOT', '/home/ec2-user/SageMaker/pose-transfer-apts')
dataset_name = 'deepfashion_full'

# Every dataset artefact lives under <root_path>/datasets/<dataset_name>.
dataset_root = f'{root_path}/datasets/{dataset_name}'
# CSV files listing the image pairs of each split (presumably source/target
# pairs -- confirm against create_dataloader).
img_pairs_train = f'{dataset_root}/train_img_pairs.csv'
img_pairs_test = f'{dataset_root}/test_img_pairs.csv'
# Pose-map directories of each split.
pose_maps_dir_train = f'{dataset_root}/train_pose_maps'
pose_maps_dir_test = f'{dataset_root}/test_pose_maps'

In [6]:
# run configuration
# -----------------------------------------------------------------------------
# Device selection passed to PoseTransferModel. None runs on CPU (the setup
# output reports "Using device: CPU"); presumably a list such as [0] selects
# GPUs -- confirm against PoseTransferModel.
gpu_ids = None

# Loop sizes.
n_epoch = 100
batch_size_train = batch_size_test = 8
# Interval, in global training iterations, between checkpoints / TensorBoard logs.
out_freq = 500

# Optional checkpoint used by the setup cell to load pretrained weights.
ckpt_dir = None
ckpt_id = None

# Free-form label appended (after "_") to the run and checkpoint directory names;
# an empty string adds nothing.
run_info = ''
# Base directory for TensorBoard runs and checkpoints.
out_path = f'{root_path}/output/{dataset_name}'

In [7]:
# Run setup: build the logger, data pipeline and model, optionally load a
# checkpoint, then derive the counters / print widths used by the training loop.

# create timestamp and infostamp -- together they name this run's TensorBoard
# and checkpoint directories ("<timestamp>" or "<timestamp>_<run_info>")
timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
infostamp = f'_{run_info.strip()}' if run_info.strip() else ''

# create tensorboard logger
logger = SummaryWriter(f'{out_path}/runs/{timestamp}{infostamp}')

# create transforms
# images: ToTensor, then Normalize(mean=0.5, std=0.5) per channel, which maps
# [0, 1] to [-1, 1] (assuming 8-bit RGB inputs); pose maps are only converted.
img_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
map_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# create dataloaders (training data shuffled, test data in file order)
train_dataloader = create_dataloader(dataset_root, img_pairs_train, pose_maps_dir_train,
                                     img_transform, map_transform,
                                     batch_size=batch_size_train, shuffle=True)
test_dataloader = create_dataloader(dataset_root, img_pairs_test, pose_maps_dir_test,
                                    img_transform, map_transform,
                                    batch_size=batch_size_test, shuffle=False)

# create fixed batch for testing -- the first test batch; since the test loader
# is built with shuffle=False it is presumably the same batch on every run,
# which keeps logged visuals comparable.
fixed_test_batch = next(iter(test_dataloader))

# create model
model = PoseTransferModel(gpuids=gpu_ids)
model.print_networks(verbose=False)

# load pretrained weights into model
# ckpt_id is compared against None explicitly: the training loop saves its first
# checkpoint with id 0 (netG_0.pth / netD_0.pth), which a plain truthiness test
# would silently skip.
if ckpt_id is not None and ckpt_dir:
    model.load_networks(ckpt_dir, ckpt_id, verbose=True)
elif ckpt_id is not None or ckpt_dir:
    # Half-configured resume: make it visible instead of silently training from scratch.
    print('[WARN] Both ckpt_id and ckpt_dir must be set to load a checkpoint; '
          'training from scratch')

# train model
n_batch = len(train_dataloader)   # batches per epoch
w_batch = len(str(n_batch))       # field width for the batch counter in log lines
w_epoch = len(str(n_epoch))       # field width for the epoch counter in log lines
n_iters = 0                       # global iteration counter (step for checkpoints/TensorBoard)

pthA:  img/WOMEN/Blouses_Shirts/id_00003372/03_3_back.jpg
pthA:  img/MEN/Tees_Tanks/id_00007301/04_7_additional.jpg
pthA:  img/WOMEN/Shorts/id_00005138/03_4_full.jpg
pthA:  img/WOMEN/Dresses/id_00007132/04_4_full.jpg
pthA:  img/WOMEN/Blouses_Shirts/id_00001530/01_3_back.jpg
pthA:  img/WOMEN/Blouses_Shirts/id_00007873/01_7_additional.jpg
pthA:  img/WOMEN/Jackets_Coats/id_00004502/02_1_front.jpg
pthA:  img/WOMEN/Blouses_Shirts/id_00004556/02_7_additional.jpg
[INFO] Using device: CPU
[INFO] Network netG initialized
[INFO] Network netD initialized
--------------------------------------------------------------------------------
[INFO] Total parameters of network netG: 90.17M
[INFO] Total parameters of network netD: 2.77M
--------------------------------------------------------------------------------


In [None]:
# Main training loop. Every `out_freq` global iterations, and once more on the
# very last batch, the weights are checkpointed, the current losses are written
# to TensorBoard, and the model renders visuals for the fixed test batch.
ckpt_path = f'{out_path}/ckpt/{timestamp}{infostamp}'  # loop-invariant
for epoch in range(n_epoch):
    for batch, data in enumerate(train_dataloader):
        # NOTE: timing starts after the batch has been fetched, so the reported
        # time covers the optimisation step only, not data loading.
        time_0 = time.time()
        model.set_inputs(data)
        model.optimize_parameters()
        losses = model.get_losses()
        loss_G = losses['lossG']
        loss_D = losses['lossD']
        time_1 = time.time()
        print(f'[TRAIN] Epoch: {epoch+1:{w_epoch}d}/{n_epoch} | Batch: {batch+1:{w_batch}d}/{n_batch} |',
              f'LossG: {loss_G:7.4f} | LossD: {loss_D:7.4f} | Time: {time_1 - time_0:.2f} sec |')

        is_last_step = (batch + 1 == n_batch) and (epoch + 1 == n_epoch)
        if n_iters % out_freq == 0 or is_last_step:
            model.save_networks(ckpt_path, n_iters, verbose=True)
            for loss_name, loss in losses.items():
                # names starting with 'lossG' go to the generator group, all others to LossD
                loss_group = 'LossG' if loss_name.startswith('lossG') else 'LossD'
                logger.add_scalar(f'{loss_group}/{loss_name}', loss, n_iters)
            # Swap in the fixed test batch to render comparable visuals; the next
            # training step calls set_inputs with fresh training data again.
            model.set_inputs(fixed_test_batch)
            visuals = model.compute_visuals()
            # One constant tag with global_step gives a single TensorBoard image
            # card with a step slider; a per-iteration tag would create a new
            # card for every logged iteration.
            logger.add_image('Visuals/fixed_test_batch', visuals, n_iters)
            # Persist events now: this cell runs for a long time and may be interrupted.
            logger.flush()

        n_iters += 1


pthA:  img/WOMEN/Blouses_Shirts/id_00003299/01_1_front.jpg
pthA:  img/WOMEN/Cardigans/id_00006702/07_3_back.jpg
pthA:  img/WOMEN/Blouses_Shirts/id_00007060/01_4_full.jpg
pthA:  img/WOMEN/Tees_Tanks/id_00001212/06_1_front.jpg
pthA:  img/MEN/Jackets_Vests/id_00002754/01_2_side.jpg
pthA:  img/WOMEN/Graphic_Tees/id_00002865/01_2_side.jpg
pthA:  img/WOMEN/Tees_Tanks/id_00002957/04_7_additional.jpg
pthA:  img/WOMEN/Tees_Tanks/id_00001665/08_2_side.jpg
[TRAIN] Epoch:   1/100 | Batch:     1/12746 | LossG: 68.5791 | LossD:  1.9068 | Time: 48.20 sec |
[INFO] Network netG weights saved to /home/ec2-user/SageMaker/pose-transfer-apts/output/deepfashion_full/ckpt/2023-04-21-14-22-38/netG_0.pth
[INFO] Network netD weights saved to /home/ec2-user/SageMaker/pose-transfer-apts/output/deepfashion_full/ckpt/2023-04-21-14-22-38/netD_0.pth
pthA:  img/WOMEN/Dresses/id_00004588/03_2_side.jpg
pthA:  img/MEN/Tees_Tanks/id_00005703/01_4_full.jpg
pthA:  img/WOMEN/Tees_Tanks/id_00001567/03_1_front.jpg
pthA:  img/W