In [1]:
# %rm -r datasets

In [1]:
# !mkdir utils
# !mkdir networks
# !mkdir network_utils
# !mkdir datasets
# %cd datasets
# !git clone https://github.com/bearpaw/clothing-co-parsing.git
# %cd clothing-co-parsing/
# %cd annotations
# %mv pixel-level ../
# %cd ..
# %rm -r annotations
# %mv pixel-level ../
# %mv photos ../
# %cd ../..

In [1]:
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchsummary.torchsummary import summary
from tqdm import tqdm

import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import torch.nn as nn 
import numpy as np
import torch
import os

from utils.utils import load_specific_image, read_image, read_mask, add_mask
from utils.utils import Visualize, Fashion_swapper_dataset
from CONSTS import IMAGEPATH, MASKPATH

from networks.discriminator import PatchDiscriminator
from networks.generator import ResNetGenerator
# %matplotlib inline
# from IPython.display import clear_output

In [2]:
def imshow(img):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: torch tensor whose axes are permuted from (0, 1, 2) to
            (1, 2, 0) before plotting, i.e. a channel-first (C, H, W) image.
            It may require gradients and may live on any device.
    """
    # Tensor.numpy() raises for tensors that require grad or live on the GPU,
    # so drop the autograd history and copy to host memory first.
    npimg = img.detach().cpu().numpy()
    # matplotlib expects the channel axis last.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

In [3]:
def set_requires_grad(nets, requires_grad=False):
    """Switch gradient tracking on or off for every parameter of the given networks.

    Args:
        nets: a single network or a list of networks; anything that is not a
            ``list`` is treated as one network. ``None`` entries are skipped.
        requires_grad: value assigned to ``param.requires_grad`` for each
            parameter (defaults to ``False``).
    """
    networks = nets if isinstance(nets, list) else [nets]
    for network in networks:
        if network is None:
            continue
        for weight in network.parameters():
            weight.requires_grad = requires_grad

In [4]:
def createswapper_loader(image_path, mask_path, object_one=31, object_two=40, batch_size=4):
    """Build a shuffled DataLoader over the samples for two object classes.

    Args:
        image_path: location of the photos, forwarded to ``load_specific_image``.
        mask_path: location of the masks, forwarded to ``load_specific_image``.
        object_one: first object id to select (default 31).
        object_two: second object id to select (default 40).
        batch_size: number of samples per batch (default 4).

    Returns:
        A ``DataLoader`` over a ``Fashion_swapper_dataset`` with ``shuffle=True``.
    """
    # Resize every sample to 300x200 and convert it to a tensor. The second
    # positional argument of Resize is the interpolation mode; 2 is PIL's
    # integer code for bilinear.
    trans = transforms.Compose([transforms.Resize((300, 200), 2), transforms.ToTensor()])

    # Use the paths supplied by the caller rather than the module-level
    # IMAGEPATH / MASKPATH constants, so the arguments actually take effect.
    samples = load_specific_image(image_path, mask_path, objects=[object_one, object_two])

    dataset = Fashion_swapper_dataset(samples, objects=[object_one, object_two], transform=trans)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [7]:
def create_model(c_dim=4, g_conv_dim=64, d_conv_dim=64, n_res_blocks=6):
    """Build one generator and one discriminator, moving them to the GPU if present.

    Args:
        c_dim: channel count passed to both networks (default 4).
        g_conv_dim: ``conv_dim`` of the generator (default 64).
        d_conv_dim: ``conv_dim`` of the discriminator (default 64).
        n_res_blocks: ``repeat_num`` of the generator (default 6).

    Returns:
        Tuple ``(G_XtoY, D_X)``: the ``ResNetGenerator`` and the
        ``PatchDiscriminator``.
    """
    G_XtoY = ResNetGenerator(conv_dim=g_conv_dim, c_dim=c_dim, repeat_num=n_res_blocks)
    D_X = PatchDiscriminator(c_dim=c_dim, conv_dim=d_conv_dim)

    # Move models to GPU, if available. Only the networks constructed above
    # are moved; referencing any other model here would raise NameError on
    # CUDA machines.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        for model in (G_XtoY, D_X):
            model.to(device)
        print('Models moved to GPU.')
    else:
        print('Only CPU available.')
    return G_XtoY, D_X

In [8]:
# Instantiate the generator/discriminator pair and the data loader used by the
# cells below. create_model() returns two networks, so only two names are bound
# (an earlier variant unpacked four: G_AtoB, G_BtoA, D_A, D_B).
G_AtoB, D_A = create_model()
loader_pants = createswapper_loader(IMAGEPATH, MASKPATH)

  1%|          | 8/1004 [00:00<00:13, 74.41it/s]

Only CPU available.


100%|██████████| 1004/1004 [00:13<00:00, 76.71it/s]


In [9]:
# Print a per-layer summary of the discriminator for a 4-channel 300x200 input
# (the channel count matches create_model's default c_dim and the size matches
# the Resize used by the loader).
summary(D_A, (4, 300, 200))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 150, 100]           4,160
         LeakyReLU-2         [-1, 64, 150, 100]               0
      SpectralNorm-3          [-1, 128, 75, 50]               0
       BatchNorm2d-4          [-1, 128, 75, 50]             256
         LeakyReLU-5          [-1, 128, 75, 50]               0
      SpectralNorm-6          [-1, 256, 37, 25]               0
       BatchNorm2d-7          [-1, 256, 37, 25]             512
         LeakyReLU-8          [-1, 256, 37, 25]               0
      SpectralNorm-9          [-1, 512, 36, 24]               0
      BatchNorm2d-10          [-1, 512, 36, 24]           1,024
        LeakyReLU-11          [-1, 512, 36, 24]               0
           Conv2d-12            [-1, 1, 36, 24]           4,609
Total params: 10,561
Trainable params: 10,561
Non-trainable params: 0
---------------------------------

In [12]:
def merge_masks(segs):
    """Collapse the masks stacked along dim 1 into a single mask.

    Each value is mapped with ``(v + 1) / 2`` (so -1 -> 0 and 1 -> 1), the
    masks are summed along dim 1 (kept as a size-1 axis), the sum is clamped
    to [0, 1] and finally mapped back with ``v * 2 - 1``.

    Args:
        segs: tensor of stacked masks, (B, N, ...); assumes values in [-1, 1].

    Returns:
        Tensor of shape (B, 1, ...).
    """
    unit_masks = (segs + 1) / 2
    union = unit_masks.sum(dim=1, keepdim=True)
    return torch.clamp(union, min=0, max=1) * 2 - 1

def get_weight_for_ctx(x, y):
    """Compute the per-pixel weight used by the context preserving loss.

    The two mask tensors are concatenated along dim 1 and merged with
    ``merge_masks``; the result ``z`` is turned into ``(1 - z) / 2``, so a
    merged value of 1 yields weight 0 and a merged value of -1 yields weight 1.

    Args:
        x: first mask tensor.
        y: second mask tensor, concatenated with ``x`` along dim 1.

    Returns:
        Weight tensor with a size-1 axis at dim 1.
    """
    stacked = torch.cat((x, y), dim=1)
    merged = merge_masks(stacked)
    complement = 1 - merged
    return complement / 2

def weighted_L1_loss(src, tgt, weight):
    """Mean of the element-wise absolute difference scaled by ``weight``.

    Used for the context preserving loss.

    Args:
        src: source tensor.
        tgt: target tensor.
        weight: multiplier applied to ``|src - tgt|`` before averaging.

    Returns:
        Scalar tensor.
    """
    abs_diff = (src - tgt).abs()
    return (weight * abs_diff).mean()
    
def real_mse_loss(D_out):
    """Mean squared distance of the discriminator output from 1."""
    distance_from_one = D_out - 1
    return distance_from_one.pow(2).mean()

def fake_mse_loss(D_out):
    """Mean squared distance of the discriminator output from 0."""
    return D_out.pow(2).mean()

def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight):
    """Mean absolute error between two tensors, scaled by ``lambda_weight``.

    Args:
        real_im: reference tensor.
        reconstructed_im: tensor compared against ``real_im``.
        lambda_weight: factor the L1 loss is multiplied by.

    Returns:
        Scalar tensor ``lambda_weight * mean(|real_im - reconstructed_im|)``.
    """
    l1_distance = torch.nn.functional.l1_loss(real_im, reconstructed_im)
    return lambda_weight * l1_distance

In [11]:
# NOTE(review): within the visible notebook, `optim` and `itertools` are only
# referenced by the commented-out two-generator variants below; the active
# code uses torch.optim.Adam directly.
import torch.optim as optim
import itertools

# Adam hyper-parameters: separate learning rates for the generator and the
# discriminator; (beta1, beta2) are passed as Adam's `betas`.
lr_g=0.0003
lr_d=0.0001
beta1=0.5
beta2=0.999

# Earlier variant (two generators):
# g_params = list(G_XtoY.parameters()) + list(G_YtoX.parameters())  # Get generator parameters

# Create optimizers for the generator and the discriminator.
# Earlier variant (optimizes both generators jointly):
# optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, itertools.chain(
#     G_AtoB.parameters(), 
#     G_BtoA.parameters())),
#     lr=lr_g, 
#     betas=(beta1, beta2))
optimizer_G = torch.optim.Adam(
    G_AtoB.parameters(),
    lr=lr_g, 
    betas=(beta1, beta2))
# Earlier variant (optimizes both discriminators jointly):
# optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, itertools.chain(D_A.parameters(), 
#                                                                                  D_B.parameters())), 
#                                lr=lr_d, 
#                                betas=(beta1, beta2))

optimizer_D = torch.optim.Adam(D_A.parameters(),
                               lr=lr_d, 
                               betas=(beta1, beta2))

In [13]:
def scale(x, feature_range=(-1, 1)):
    ''' Linearly rescale x into feature_range (default: -1 to 1).

       Computes x * (high - low) + low, where (low, high) = feature_range.
       This function assumes that the input x is already scaled from 0-1.'''

    low, high = feature_range
    span = high - low
    return x * span + low

In [14]:
def split(x):
    """Separate a 4-D batch into its image and mask channels.

    The first three channels along dim 1 are returned as the image (only
    3-channel images are assumed); every remaining channel is returned as the
    mask.

    Returns:
        Tuple ``(image, mask)``.
    """
    image = x[:, :3, :, :]
    mask = x[:, 3:, :, :]
    return image, mask

In [None]:
def training_loop(dataloader, test_dataloader, n_epochs=1000):
    """Skeleton of the training loop (the per-batch step is not written yet).

    Args:
        dataloader: iterable of training batches.
        test_dataloader: iterable of evaluation batches (not used yet).
        n_epochs: number of passes over ``dataloader``. Defaults to 1000, the
            default of the earlier version of this function, so callers that
            omit it keep working.

    Raises:
        NotImplementedError: on the first batch, because the training step
            has not been implemented.
    """
    print_every = 10  # intended logging interval in epochs (not used yet)
    vis = Visualize()  # visualisation helper (not used yet)
    for epoch in tqdm(range(1, n_epochs + 1)):
        for object_A in tqdm(dataloader):
            # TODO: implement the generator / discriminator update for one
            # batch. Fail loudly instead of silently iterating over the data.
            raise NotImplementedError("training_loop: per-batch training step is not implemented")

In [0]:
# def training_loop(dataloader, test_dataloader, n_epochs=1000):
    
#     print_every=10
#     losses = []

#     for epoch in tqdm(range(1, n_epochs+1)):
        
#         for object_A, object_B in tqdm(dataloader):

#             # =========================================
#             #            TRAIN THE GENERATORS
#             # =========================================
#             # torch.train()
#             device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#             object_A = object_A.to(device)
#             object_A = scale(object_A)
#             object_A.requires_grad_(True)
#             # print(object_A)
#             # print(object_B)
#             object_B = object_B.to(device)
#             object_B = scale(object_B)
#             object_B.requires_grad_(True)
            
#             fake_B = G_AtoB(object_A)
#             rec_A = G_BtoA(fake_B)
            
#             image_b_fake, mask_b_fake = split(fake_B)
#             image_a_rec, mask_a_rec = split(rec_A)
            
#             fake_A = G_BtoA(object_B)
#             rec_B = G_AtoB(fake_A)
            
#             image_a_fake, mask_a_fake = split(fake_A)
#             image_b_rec, mask_b_rec = split(rec_B)
            
#             set_requires_grad([D_A, D_B], False)
#             optimizer_G.zero_grad()
            
#             out_a = D_B(fake_B)
#             loss_G_A = real_mse_loss(out_a)
#             # return out_a, loss_G_A
#             loss_cyc_A = cycle_consistency_loss(object_A, rec_A, 10)
#             loss_idt_B = cycle_consistency_loss(G_BtoA(object_A), object_A.detach(), 10)
#             weight_A = get_weight_for_ctx(object_A[:,3:,:,:], mask_b_fake)
#             loss_ctx_A = weighted_L1_loss(object_A[:, :3,:,:], image_b_fake, weight=weight_A) * 10. * 10.

#             out_b = D_A(fake_A)
#             loss_G_B = real_mse_loss(out_a)
#             loss_cyc_B = cycle_consistency_loss(object_B, rec_B, 10)
#             loss_idt_A = cycle_consistency_loss(G_AtoB(object_B), object_B.detach(), 10)
#             weight_B = get_weight_for_ctx(object_B[:,3:,:,:], mask_a_fake)
#             loss_ctx_B = weighted_L1_loss(object_A[:, :3,:,:], image_b_fake, weight=weight_B) * 10. * 10.
#             # return loss_G_A , loss_cyc_A , loss_idt_B , loss_ctx_A , loss_G_B , loss_cyc_B , loss_idt_A , loss_ctx_B
            
#             g_total_loss = loss_G_A + loss_cyc_A + loss_idt_B + loss_ctx_A + loss_G_B + loss_cyc_B + loss_idt_A + loss_ctx_B
#             g_total_loss.backward()
#             optimizer_G.step()
            
#             # ============================================
#             #            TRAIN THE DISCRIMINATORS
#             # ============================================
            
#             set_requires_grad([D_A, D_B], True)
#             optimizer_D.zero_grad()
            
#             pred_real_B = D_B(object_B)
#             loss_D_real_B = real_mse_loss(object_B)
#             pred_fake_B = D_B(fake_B.detach())
#             loss_D_fake_B = fake_mse_loss(pred_fake_B)
#             loss_D_B = (loss_D_real_B + loss_D_fake_B) / 2
#             loss_D_B.backward()
            
#             pred_real_A = D_A(object_A)
#             loss_D_real_A = real_mse_loss(object_A)
#             pred_fake_A = D_A(fake_A.detach())
#             loss_D_fake_A = fake_mse_loss(pred_fake_A)
#             loss_D_A = (loss_D_real_A + loss_D_fake_A) / 2
#             loss_D_A.backward()
            
#             optimizer_D.step()
            
#         if epoch % print_every == 0:
#           # append real and fake discriminator losses and the generator loss
#           losses.append((loss_D_A.item(), loss_D_B.item(), g_total_loss.item()))
#           print('Epoch [{:5d}/{:5d}] | loss_D_A: {:6.4f} | loss_D_B: {:6.4f} | g_total_loss: {:6.4f}'.format(
#                   epoch, n_epochs, loss_D_A.item(), loss_D_B.item(), g_total_loss.item()))


#         sample_every=10
#         # Save the generated samples
#         if epoch % sample_every == 0:
#             G_BtoA.eval() # set generators to eval mode for sample generation
#             G_AtoB.eval()
#             A, B = next(iter(loader_pants))
#             device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#             A = A.to(device)
#             A = scale(object_A)
#             A.requires_grad_(True)

#             B = B.to(device)
#             B = scale(object_B)
#             B.requires_grad_(True)

#             save_image(G_AtoB(A)[0,:3,:,:], f'G_AtoB_{epoch}.jpg')
#             save_image(G_BtoA(B)[0,:3,:,:], f'G_BtoA_{epoch}.jpg')
#             G_BtoA.train()
#             G_AtoB.train()

#       # uncomment these lines, if you want to save your model
# #         checkpoint_every=1000
# #         # Save the model parameters
# #         if epoch % checkpoint_every == 0:
# #             checkpoint(epoch, G_XtoY, G_YtoX, D_X, D_Y)

#     return losses

In [0]:
# Smoke test: push one batch through the generator.
G_AtoB.eval()
# Take the first batch from the loader; only A is used in this cell.
A, B = next(iter(loader_pants))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
A = A.to(device)
# Rescale with scale()'s default feature_range of (-1, 1).
A = scale(A)
# NOTE(review): gradients w.r.t. the input are enabled here although nothing in
# this cell calls backward() — confirm whether this is needed.
A.requires_grad_(True)
res = G_AtoB(A)


In [0]:
# NOTE: a bare expression is only rendered when it is the last line of a cell,
# so this line displays nothing.
res.shape
# Write the first three channels of the first sample in the batch to disk.
save_image(res[0,:3,:,:], "dfdf.jpg")

In [0]:
# Run training; the same loader is passed as both the training and the test
# loader. n_epochs is not passed, so this call depends on training_loop
# providing a default for it.
loss = training_loop(loader_pants, loader_pants)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m



 96%|█████████▌| 73/76 [02:28<00:06,  2.05s/it][A[A[A[A[A[A[A






 97%|█████████▋| 74/76 [02:30<00:04,  2.06s/it][A[A[A[A[A[A[A






 99%|█████████▊| 75/76 [02:32<00:02,  2.06s/it][A[A[A[A[A[A[A






100%|██████████| 76/76 [02:33<00:00,  1.75s/it][A[A[A[A[A[A[A






[A[A[A[A[A[A[A





  0%|          | 1/1000 [02:33<42:42:16, 153.89s/it][A[A[A[A[A[A






  0%|          | 0/76 [00:00<?, ?it/s][A[A[A[A[A[A[A






  1%|▏         | 1/76 [00:02<02:34,  2.06s/it][A[A[A[A[A[A[A






  3%|▎         | 2/76 [00:04<02:32,  2.06s/it][A[A[A[A[A[A[A






  4%|▍         | 3/76 [00:06<02:30,  2.06s/it][A[A[A[A[A[A[A






  5%|▌         | 4/76 [00:08<02:28,  2.06s/it][A[A[A[A[A[A[A






  7%|▋         | 5/76 [00:10<02:25,  2.06s/it][A[A[A[A[A[A[A






  8%|▊         | 6/76 [00:12<02:23,  2.06s/it][A[A[A[A[A[A[A






  9%|▉         | 7

Epoch [   10/ 1000] | loss_D_A: 0.8770 | loss_D_B: 0.9243 | g_total_loss: 4.1687


[1;30;43mStreaming output truncated to the last 5000 lines.[0m



 96%|█████████▌| 73/76 [02:29<00:06,  2.05s/it][A[A[A[A[A[A[A






 97%|█████████▋| 74/76 [02:32<00:04,  2.05s/it][A[A[A[A[A[A[A






 99%|█████████▊| 75/76 [02:34<00:02,  2.06s/it][A[A[A[A[A[A[A






100%|██████████| 76/76 [02:35<00:00,  1.75s/it][A[A[A[A[A[A[A






[A[A[A[A[A[A[A





  1%|          | 11/1000 [28:26<42:37:43, 155.17s/it][A[A[A[A[A[A






  0%|          | 0/76 [00:00<?, ?it/s][A[A[A[A[A[A[A






  1%|▏         | 1/76 [00:02<02:36,  2.09s/it][A[A[A[A[A[A[A






  3%|▎         | 2/76 [00:04<02:34,  2.09s/it][A[A[A[A[A[A[A






  4%|▍         | 3/76 [00:06<02:31,  2.08s/it][A[A[A[A[A[A[A






  5%|▌         | 4/76 [00:08<02:29,  2.07s/it][A[A[A[A[A[A[A






  7%|▋         | 5/76 [00:10<02:26,  2.07s/it][A[A[A[A[A[A[A






  8%|▊         | 6/76 [00:12<02:24,  2.06s/it][A[A[A[A[A[A[A






  9%|▉         | 

Epoch [   20/ 1000] | loss_D_A: 0.8742 | loss_D_B: 0.9799 | g_total_loss: 3.5007


[1;30;43mStreaming output truncated to the last 5000 lines.[0m



 96%|█████████▌| 73/76 [02:29<00:06,  2.07s/it][A[A[A[A[A[A[A






 97%|█████████▋| 74/76 [02:32<00:04,  2.06s/it][A[A[A[A[A[A[A






 99%|█████████▊| 75/76 [02:34<00:02,  2.06s/it][A[A[A[A[A[A[A






100%|██████████| 76/76 [02:35<00:00,  1.76s/it][A[A[A[A[A[A[A






[A[A[A[A[A[A[A





  2%|▏         | 21/1000 [54:19<42:13:52, 155.29s/it][A[A[A[A[A[A






  0%|          | 0/76 [00:00<?, ?it/s][A[A[A[A[A[A[A






  1%|▏         | 1/76 [00:02<02:34,  2.06s/it][A[A[A[A[A[A[A






  3%|▎         | 2/76 [00:04<02:32,  2.06s/it][A[A[A[A[A[A[A






  4%|▍         | 3/76 [00:06<02:30,  2.06s/it][A[A[A[A[A[A[A






  5%|▌         | 4/76 [00:08<02:28,  2.06s/it][A[A[A[A[A[A[A






  7%|▋         | 5/76 [00:10<02:25,  2.06s/it][A[A[A[A[A[A[A






  8%|▊         | 6/76 [00:12<02:23,  2.05s/it][A[A[A[A[A[A[A






  9%|▉         | 

Epoch [   30/ 1000] | loss_D_A: 0.7472 | loss_D_B: 0.9618 | g_total_loss: 3.3890


[1;30;43mStreaming output truncated to the last 5000 lines.[0m



 96%|█████████▌| 73/76 [02:29<00:06,  2.05s/it][A[A[A[A[A[A[A






 97%|█████████▋| 74/76 [02:31<00:04,  2.05s/it][A[A[A[A[A[A[A






 99%|█████████▊| 75/76 [02:34<00:02,  2.06s/it][A[A[A[A[A[A[A






100%|██████████| 76/76 [02:35<00:00,  1.76s/it][A[A[A[A[A[A[A






[A[A[A[A[A[A[A





  3%|▎         | 31/1000 [1:20:11<41:45:33, 155.14s/it][A[A[A[A[A[A






  0%|          | 0/76 [00:00<?, ?it/s][A[A[A[A[A[A[A






  1%|▏         | 1/76 [00:02<02:33,  2.04s/it][A[A[A[A[A[A[A






  3%|▎         | 2/76 [00:04<02:31,  2.05s/it][A[A[A[A[A[A[A






  4%|▍         | 3/76 [00:06<02:29,  2.05s/it][A[A[A[A[A[A[A






  5%|▌         | 4/76 [00:08<02:27,  2.05s/it][A[A[A[A[A[A[A






  7%|▋         | 5/76 [00:10<02:25,  2.05s/it][A[A[A[A[A[A[A






  8%|▊         | 6/76 [00:12<02:23,  2.05s/it][A[A[A[A[A[A[A






  9%|▉         

Epoch [   40/ 1000] | loss_D_A: 0.7944 | loss_D_B: 0.9905 | g_total_loss: 3.2953


[1;30;43mStreaming output truncated to the last 5000 lines.[0m



 96%|█████████▌| 73/76 [02:30<00:06,  2.05s/it][A[A[A[A[A[A[A






 97%|█████████▋| 74/76 [02:32<00:04,  2.06s/it][A[A[A[A[A[A[A






 99%|█████████▊| 75/76 [02:34<00:02,  2.05s/it][A[A[A[A[A[A[A






100%|██████████| 76/76 [02:35<00:00,  1.75s/it][A[A[A[A[A[A[A






[A[A[A[A[A[A[A





  4%|▍         | 41/1000 [1:46:03<41:21:59, 155.29s/it][A[A[A[A[A[A






  0%|          | 0/76 [00:00<?, ?it/s][A[A[A[A[A[A[A






  1%|▏         | 1/76 [00:02<02:33,  2.05s/it][A[A[A[A[A[A[A






  3%|▎         | 2/76 [00:04<02:31,  2.05s/it][A[A[A[A[A[A[A






  4%|▍         | 3/76 [00:06<02:29,  2.05s/it][A[A[A[A[A[A[A






  5%|▌         | 4/76 [00:08<02:27,  2.05s/it][A[A[A[A[A[A[A






  7%|▋         | 5/76 [00:10<02:25,  2.05s/it][A[A[A[A[A[A[A






  8%|▊         | 6/76 [00:12<02:23,  2.05s/it][A[A[A[A[A[A[A






  9%|▉         

Epoch [   50/ 1000] | loss_D_A: 0.8198 | loss_D_B: 0.9119 | g_total_loss: 3.2176


[1;30;43mStreaming output truncated to the last 5000 lines.[0m



 96%|█████████▌| 73/76 [02:29<00:06,  2.06s/it][A[A[A[A[A[A[A






 97%|█████████▋| 74/76 [02:32<00:04,  2.06s/it][A[A[A[A[A[A[A






 99%|█████████▊| 75/76 [02:34<00:02,  2.06s/it][A[A[A[A[A[A[A






100%|██████████| 76/76 [02:35<00:00,  1.76s/it][A[A[A[A[A[A[A






[A[A[A[A[A[A[A





  5%|▌         | 51/1000 [2:11:54<40:54:00, 155.15s/it][A[A[A[A[A[A






  0%|          | 0/76 [00:00<?, ?it/s][A[A[A[A[A[A[A






  1%|▏         | 1/76 [00:02<02:33,  2.05s/it][A[A[A[A[A[A[A






  3%|▎         | 2/76 [00:04<02:31,  2.05s/it][A[A[A[A[A[A[A






  4%|▍         | 3/76 [00:06<02:29,  2.05s/it][A[A[A[A[A[A[A






  5%|▌         | 4/76 [00:08<02:27,  2.05s/it][A[A[A[A[A[A[A






  7%|▋         | 5/76 [00:10<02:25,  2.05s/it][A[A[A[A[A[A[A






  8%|▊         | 6/76 [00:12<02:23,  2.05s/it][A[A[A[A[A[A[A






  9%|▉         

Epoch [   60/ 1000] | loss_D_A: 1.1208 | loss_D_B: 0.9086 | g_total_loss: 3.1691


[1;30;43mStreaming output truncated to the last 5000 lines.[0m



 96%|█████████▌| 73/76 [02:29<00:06,  2.05s/it][A[A[A[A[A[A[A






 97%|█████████▋| 74/76 [02:31<00:04,  2.05s/it][A[A[A[A[A[A[A






 99%|█████████▊| 75/76 [02:33<00:02,  2.05s/it][A[A[A[A[A[A[A






100%|██████████| 76/76 [02:34<00:00,  1.75s/it][A[A[A[A[A[A[A






[A[A[A[A[A[A[A





  6%|▌         | 61/1000 [2:37:46<40:29:06, 155.21s/it][A[A[A[A[A[A






  0%|          | 0/76 [00:00<?, ?it/s][A[A[A[A[A[A[A






  1%|▏         | 1/76 [00:02<02:34,  2.06s/it][A[A[A[A[A[A[A






  3%|▎         | 2/76 [00:04<02:32,  2.06s/it][A[A[A[A[A[A[A






  4%|▍         | 3/76 [00:06<02:30,  2.06s/it][A[A[A[A[A[A[A






  5%|▌         | 4/76 [00:08<02:27,  2.05s/it][A[A[A[A[A[A[A






  7%|▋         | 5/76 [00:10<02:25,  2.05s/it][A[A[A[A[A[A[A






  8%|▊         | 6/76 [00:12<02:23,  2.05s/it][A[A[A[A[A[A[A






  9%|▉         

Epoch [   70/ 1000] | loss_D_A: 0.8011 | loss_D_B: 1.0138 | g_total_loss: 3.3275








  7%|▋         | 70/1000 [3:01:01<40:03:30, 155.07s/it][A[A[A[A[A[A






  0%|          | 0/76 [00:00<?, ?it/s][A[A[A[A[A[A[A






  1%|▏         | 1/76 [00:02<02:34,  2.05s/it][A[A[A[A[A[A[A






  3%|▎         | 2/76 [00:04<02:31,  2.05s/it][A[A[A[A[A[A[A






  4%|▍         | 3/76 [00:06<02:29,  2.05s/it][A[A[A[A[A[A[A






  5%|▌         | 4/76 [00:08<02:28,  2.06s/it][A[A[A[A[A[A[A






  7%|▋         | 5/76 [00:10<02:26,  2.06s/it][A[A[A[A[A[A[A






  8%|▊         | 6/76 [00:12<02:24,  2.06s/it][A[A[A[A[A[A[A






  9%|▉         | 7/76 [00:14<02:21,  2.05s/it][A[A[A[A[A[A[A






 11%|█         | 8/76 [00:16<02:19,  2.05s/it][A[A[A[A[A[A[A






 12%|█▏        | 9/76 [00:18<02:17,  2.05s/it][A[A[A[A[A[A[A






 13%|█▎        | 10/76 [00:20<02:15,  2.05s/it][A[A[A[A[A[A[A






 14%|█▍        | 11/76 [00:22<02:13,  2.06s/it][A[A[A[A[A[A[A






 16%|█▌        | 12/76 [00:24<02

In [0]:
# NOTE(review): `loss_ctx_B` is never assigned at top level in the visible
# notebook (it only appears inside the commented-out training loop), so this
# cell presumably depends on leftover kernel state and would fail after
# Restart & Run All — confirm.
loss_ctx_B.shape

torch.Size([4, 1, 300, 200])

In [0]:
# a = next(iter(loader_pants))[0][:,3:,:,:]

In [0]:
# NOTE(review): this loop never terminates. Every pass appends three items and
# then extends the list with itself (doubling it), so memory use grows without
# bound until the kernel runs out of memory. Presumably a deliberate way to
# crash the runtime — confirm, and keep this cell out of "Run All".
a = []
while True:
  a += [1,2,3]
  a.extend(a)