In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list the files in the input directory

import os
print(os.listdir("../input"))

# Any results you write to the current directory are saved as output.

In [None]:
import torch
torch.cuda.is_available()

In [None]:
model_head = '../input/zero-models-attn/'
# os.listdir(model_head)

In [None]:
# Root folder of the first image set; it holds one sub-folder per class name.
head = '../input/another-animals-dataset/ex_imgs/'

# Class folder names in alphabetical order (sorted() already returns a list).
cat_list = sorted(os.listdir(head))
cat_list

### Read Classes

In [None]:
class_csv = pd.read_csv('../input/another-animals-dataset/labels/classes.txt', sep='\t', header=None)
class_csv.sample(5)

In [None]:
# Each row of classes.txt is (1-based index, class name); labels are made 0-based.
class_to_label_dict = {row[1]: row[0] - 1 for row in class_csv.values}

# Reverse lookup: 0-based label -> class name.
label_to_class_dict = {label: name for name, label in class_to_label_dict.items()}

In [None]:
label_to_class_dict[0], class_to_label_dict[label_to_class_dict[0]]

### Read Attributes

In [None]:
attr_csv = pd.read_csv('../input/another-animals-dataset/labels/predicate-matrix-binary.txt', sep=' ', header=None)
attr_csv.sample(5)

In [None]:
attr_matrix = attr_csv.values.astype(np.float32)
attr_matrix[0]

In [None]:
# Parallel lists: img_list[i] is an image path, label_list[i] its 0-based class label.
img_list, label_list = [], []

for cat in cat_list:
    cat_dir = head + cat
    c_img_list = [f'{cat_dir}/{fname}' for fname in os.listdir(cat_dir)]
    img_list.extend(c_img_list)
    # Every file inside a class folder gets that folder's label.
    label_list.extend([class_to_label_dict[cat]] * len(c_img_list))

len(img_list), img_list[0], len(label_list), label_list[0]

In [None]:
a_head = '../input/awa-dataset/imgs_e/'

In [None]:
# Append the second image set (under a_head) to the same parallel lists.
# NOTE(review): this cell extends img_list/label_list in place, so running it
# twice duplicates the a_head images.
for cat in cat_list:
    cat_dir = a_head + cat
    c_img_list = [f'{cat_dir}/{fname}' for fname in os.listdir(cat_dir)]
    img_list.extend(c_img_list)
    label_list.extend([class_to_label_dict[cat]] * len(c_img_list))

len(img_list), img_list[-1], len(label_list), label_list[-1]

In [None]:
from PIL import Image, ImageOps, ImageEnhance

import matplotlib.pyplot as plt

img = Image.open(img_list[0])
plt.imshow(img)
plt.show()

In [None]:
img = Image.open(img_list[-1])
plt.imshow(img)
plt.show()

In [None]:
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import models
import torchvision.transforms as transforms

In [None]:
!pip install torchsummary

In [None]:
from torchsummary import summary

In [None]:
class AwaDataSet(Dataset):
    """Image dataset yielding ``(image, label, attribute_vector)`` triples.

    Each sample is read from disk, randomly augmented, resized to 128 x 128
    RGB, scaled to [0, 1] and finally passed through ``x_transform``.

    The class relies on two notebook-level globals:
      * ``head``        - paths containing this prefix use the simple
                          mirror/rotate augmentation; all other paths use the
                          RGBA "paste onto a zoomed background" augmentation.
      * ``attr_matrix`` - indexed with the label to fetch the attribute vector.

    Args:
        img_path_list: sequence of image file paths.
        label_list: sequence of labels parallel to ``img_path_list``; each
            label is used as a row index into ``attr_matrix``.
        x_transform: callable applied to the float32 H x W x 3 array.
    """

    def __init__(self, img_path_list, label_list, x_transform):
        self.img_path_list = img_path_list
        self.label_list = label_list
        self.x_transform = x_transform

    def __getitem__(self, idx):
        """Return ``(transformed image, label, attr_matrix[label])`` for ``idx``."""
        path = self.img_path_list[idx]

        try:
            img = Image.open(path)
            # Image.open is lazy: force decoding here so that corrupt or
            # truncated files are caught by the fallback below too.
            img.load()
        except Exception:
            # Best-effort fallback: an unreadable file becomes a random-noise
            # image instead of aborting the epoch. `Exception` (not a bare
            # except) lets KeyboardInterrupt / SystemExit propagate.
            img = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))

        if head in path:
            # Horizontal flip for draws 0..125 out of 1000 (~12.6 %).
            if np.random.randint(1000) <= 125:
                img = ImageOps.mirror(img)

            # Small rotation for draws 0..50 out of 1000 (~5.1 %), angle in [-18, 17].
            if np.random.randint(1000) <= 50:
                img = img.rotate(np.random.randint(-18, 18))

        else:
            img = img.convert('RGBA')

            # Untouched copy, later turned into the background.
            blank = img.copy()

            if np.random.randint(1000) <= 125:
                img = ImageOps.mirror(img)

            if np.random.randint(1000) <= 50:
                img = img.rotate(np.random.randint(-18, 18), expand=True)

            width, height = img.size
            target_size = max(img.size)

            # Background: the copy stretched to 3x the target size, then its
            # central target_size x target_size square is kept.
            blank = blank.resize((target_size * 3, target_size * 3))
            blank = blank.crop((target_size, target_size, 2 * target_size, 2 * target_size))

            # Paste the (augmented) image roughly centred, with a small random
            # offset, using the image itself as the paste mask.
            blank.paste(img, (int((max(img.size) - width) / np.random.uniform(1.75, 2.25)),
                              int((max(img.size) - height) / np.random.uniform(1.75, 2.25))), img)
            img = blank

        img = img.convert('RGB')
        img = img.resize((128, 128))
        img = np.array(img, dtype=np.float32) / 255

        # Random brightness scaling (~12.6 % of samples), clipped back to [0, 1].
        if np.random.randint(1000) <= 125:
            img = np.clip(img * np.random.uniform(0.8, 1.28), 0, 1)

        img = self.x_transform(img)

        label = self.label_list[idx]
        attr = attr_matrix[label]

        return img, label, attr

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.img_path_list)

In [None]:
# ToTensor converts the H x W x C array to a C x H x W tensor; Normalize with
# mean = std = 0.5 then maps values from [0, 1] to [-1, 1].
_HALF = (0.5, 0.5, 0.5)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_HALF, std=_HALF),
])

In [None]:
train_dataset = AwaDataSet(img_list, label_list, transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [None]:
# Pull a single batch and show its first image together with its class name.
for img, label, attr in train_loader:
    print(label_to_class_dict[label[0].item()])

    # Back to H x W x C and from [-1, 1] to [0, 1] for matplotlib.
    img = np.transpose(img[0].cpu().data.numpy(), (1, 2, 0))
    plt.imshow(img * 0.5 + 0.5)
    plt.show()

    break

### Weight initialization

In [None]:
# Custom weight initialisation, meant to be used through `model.apply(weights_init)`.
def weights_init(m):
    """Initialise conv and batch-norm layers with small normal weights.

    * Modules whose class name contains ``Conv``: weights ~ N(0, 0.02).
      A conv wrapped by ``SpectralNorm`` keeps its trainable tensor in
      ``weight_bar`` (``weight`` is either missing or a derived tensor that is
      recomputed on every forward pass), so ``weight_bar`` is preferred when
      it exists.
    * Modules whose class name contains ``BatchNorm``: weights ~ N(1, 0.02)
      and bias = 0; layers created with ``affine=False`` are skipped.

    Any other module is left untouched.

    Args:
        m: the module handed over by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        weight = getattr(m, 'weight_bar', None)
        if weight is None:
            weight = getattr(m, 'weight', None)
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias.data, 0)

### Define spectral norm

In [None]:
from torch.nn import Parameter

def l2normalize(v, eps=1e-12):
    """Scale ``v`` to unit Euclidean length; ``eps`` keeps the division finite."""
    length = v.norm() + eps
    return v / length


class SpectralNorm(nn.Module):
    """Wrap a module and divide one of its weights by a spectral-norm estimate.

    The wrapped module's parameter ``name`` is replaced by three parameters:
    ``<name>_bar`` (the trainable weight) and ``<name>_u`` / ``<name>_v``
    (non-trainable vectors used by the power iteration). Before every forward
    pass ``<name>`` is rebuilt as ``<name>_bar / sigma``.

    Args:
        module: the module to wrap.
        name: name of the parameter to normalise.
        power_iterations: power-iteration steps performed per forward pass.
    """

    _SUFFIXES = ("_u", "_v", "_bar")

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Refresh u and v by power iteration and rebuild the normalised weight."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        # 2-D view of the weight: (first dimension, everything else).
        w_mat = w.view(w.data.shape[0], -1)
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data))
            u.data = l2normalize(torch.mv(w_mat.data, v.data))

        # sigma is computed from `w` itself (not .data) so gradients reach <name>_bar.
        sigma = u.dot(w_mat.mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """True when the wrapped module already carries the three helper tensors."""
        return all(hasattr(self.module, self.name + suffix) for suffix in self._SUFFIXES)

    def _make_params(self):
        """Replace ``<name>`` with ``<name>_u``, ``<name>_v`` and ``<name>_bar``."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        # Random unit vectors; u is drawn before v.
        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data, v.data = l2normalize(u.data), l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Drop the original parameter; forward() re-creates it as a plain tensor.
        del self.module._parameters[self.name]

        for suffix, tensor in zip(self._SUFFIXES, (u, v, w_bar)):
            self.module.register_parameter(self.name + suffix, tensor)

    def forward(self, *args):
        """Update the normalised weight, then run the wrapped module."""
        self._update_u_v()
        return self.module.forward(*args)

### Self Attention Layer

In [None]:
class SelfAttn(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Query and key use ``in_dim // 8`` channels, the value keeps ``in_dim``
    channels (all 1x1 convolutions). The attended features are added to the
    input through the learnable scalar ``gamma``, which starts at zero, so a
    freshly created layer is the identity.

    Args:
        in_dim: number of input (and output) channels.
    """

    def __init__(self, in_dim):
        super(SelfAttn, self).__init__()
        self.Q = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.K = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.V = nn.Conv2d(in_dim, in_dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        '''
            input_size: B * C * W * H
            return:
                output: gamma * self attn value + input, same shape as the input
        '''

        b_size, C, W, H = x.size()

        # (B, N, C//8) x (B, C//8, N) -> (B, N, N) with N = W * H
        proj_Q = self.Q(x).view(b_size, -1, W * H).permute(0, 2, 1)
        proj_K = self.K(x).view(b_size, -1, W * H)
        energy = torch.bmm(proj_Q, proj_K)

        # Row i holds the weights of query position i over all key positions.
        attn = self.softmax(energy)
        proj_V = self.V(x).view(b_size, -1, W * H)

        output = torch.bmm(proj_V, attn.permute(0, 2, 1))
        output = output.view(b_size, C, W, H)

        # Fix: the original returned `x`, discarding the attention result and
        # leaving gamma, Q, K and V without any gradient.
        output = self.gamma * output + x
        return output

### Define Generator

In [None]:
class generator(nn.Module):
    """Generator producing 3 x 128 x 128 images from a 512-d latent code.

    ``forward`` supports three modes:
      * ``'attr'`` - input ``(noise, attrs)``: 256-d noise (passed through a
        learnable affine map) concatenated with the 256-d encoding of the
        85-d attribute vector; tanh output.
      * ``'rf'``   - input: 512-d noise (passed through a learnable affine
        map); tanh output.
      * ``'rec'``  - input ``(code, attrs)``: a 256-d code used as is,
        concatenated with the attribute encoding; sigmoid output.

    NOTE(review): ``output_rf`` is created but never used - every mode goes
    through ``output_awa``. Confirm whether mode 'rf' was meant to use it.
    NOTE(review): any other ``mode`` skips both the input preparation and the
    output head and returns the raw features of ``main``.
    """

    def __init__(self):
        super(generator, self).__init__()

        # Learnable scale / shift applied to the noise ('rf' and 'attr' modes).
        self.sigma = nn.Parameter(torch.ones(512))
        self.myu = nn.Parameter(torch.zeros(512))

        self.sigma_256 = nn.Parameter(torch.ones(256))
        self.myu_256 = nn.Parameter(torch.zeros(256))

        # Attribute vector (85) -> 256-d embedding.
        self.encoder = nn.Sequential(
            nn.Linear(85, 128, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.25),

            nn.Linear(128, 192, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.25),

            nn.Linear(192, 256, bias=False),
        )

        def up_block(c_in, c_out, stride, padding, dropout=False, attn=False):
            """Spectral-normed transposed conv + BatchNorm + LeakyReLU [+ Dropout] [+ SelfAttn]."""
            layers = [
                SpectralNorm(nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False)),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            if dropout:
                layers.append(nn.Dropout(0.25))
            if attn:
                layers.append(SelfAttn(c_out))
            return layers

        self.main = nn.Sequential(
            *up_block(256 + 256, 1024, 1, 0),                      # 1 x 1   -> 4 x 4
            *up_block(1024, 512, 2, 1),                            # 4 x 4   -> 8 x 8
            *up_block(512, 256, 2, 1, dropout=True),               # 8 x 8   -> 16 x 16
            *up_block(256, 128, 2, 1, dropout=True, attn=True),    # 16 x 16 -> 32 x 32
            *up_block(128, 64, 2, 1, dropout=True, attn=True),     # 32 x 32 -> 64 x 64
        )

        # Output heads: 64 x 64 features -> 3 x 128 x 128 image.
        self.output_awa = nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False)
        self.output_rf = nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False)

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, mode='attr'):
        """Generate images; see the class docstring for the input of each mode."""
        if mode in ('attr', 'rec'):
            latent, attrs = x[0], x[1]

            # Attribute embedding first, then the (optionally rescaled) latent.
            attr_code = self.encoder(attrs.view(-1, 85)).view(-1, 256)

            latent = latent.view(-1, 256)
            if mode == 'attr':
                latent = self.sigma_256 * latent + self.myu_256

            x = torch.cat((latent, attr_code), dim=1).view(-1, 512, 1, 1)

        elif mode == 'rf':
            noise = x.view(-1, 512)
            x = (self.sigma * noise + self.myu).view(-1, 512, 1, 1)

        x = self.main(x)

        if mode == 'rec':
            return self.sigmoid(self.output_awa(x))

        if mode in ('attr', 'rf'):
            return self.tanh(self.output_awa(x))

        return x

In [None]:
G = generator().cuda()

In [None]:
tesy = torch.ones(4, 85, 1, 1).cuda()
tesn = torch.ones(4, 256, 1, 1).cuda()
tes_output = G((tesn, tesy), mode='attr')
tes_output.size()

In [None]:
tesz = torch.ones(4, 512, 1, 1).cuda()
tes_output_ = G(tesz, mode='rf')
tes_output_.size()

In [None]:
tes_output_ = G((tesn, tesy), mode='rec')
tes_output_.size()

In [None]:
# Resume the generator from a saved checkpoint when present, else initialise it.
g_ckpt = model_head + 'g_model.pth'

if not os.path.exists(g_ckpt):
    G.apply(weights_init)
    print('init...')
else:
    G.load_state_dict(torch.load(g_ckpt))
    print('load...')

### Define Discriminator

In [None]:
class discriminator(nn.Module):
    """Discriminator with a real/fake head, an embedding head and a classifier.

    ``main`` halves the spatial size six times (128 x 128 -> 2 x 2, see the
    size comments). On top of it:
      * ``mode='rf'``                   - one real/fake score per image (flattened).
      * ``mode='encode', clsfy=False``  - L2-normalised 256-d embedding.
      * ``mode='encode', clsfy=True``   - 50 class logits computed from the
        un-normalised embedding.

    NOTE(review): any other ``mode`` returns the raw output of ``main``.
    """

    def __init__(self):
        super(discriminator, self).__init__()
        self.main = nn.Sequential(
            SpectralNorm(nn.Conv2d(3, 32, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            # size: 64 x 64

            SpectralNorm(nn.Conv2d(32, 64, 4, 2, 1, bias=False)),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # size: 32 x 32

            SpectralNorm(nn.Conv2d(64, 128, 4, 2, 1, bias=False)),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # size: 16 x 16

            SpectralNorm(nn.Conv2d(128, 256, 4, 2, 1, bias=False)),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.25),
            # size: 8 x 8

            SpectralNorm(nn.Conv2d(256, 512, 4, 2, 1, bias=False)),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.25),
            SelfAttn(512),
            # size: 4 x 4

            SpectralNorm(nn.Conv2d(512, 1024, 4, 2, 1, bias=False)),
            nn.InstanceNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.25),
            SelfAttn(1024),
            # size: 2 x 2
        )

        # Heads: each conv reduces the 2 x 2 map to 1 x 1.
        self.rf = nn.Conv2d(1024, 1, 4, 2, 1, bias=False)
        self.encoder = nn.Conv2d(1024, 256, 4, 2, 1, bias=False)
        self.clsfy = nn.Linear(256, 50, bias=False)

    def forward(self, x, mode='rf', clsfy=False):
        """Run the shared trunk and the head selected by ``mode`` / ``clsfy``.

        Args:
            x: image batch.
            mode: ``'rf'`` or ``'encode'`` (see class docstring).
            clsfy: with ``mode='encode'``, return class logits instead of the
                normalised embedding.
        """
        x = self.main(x)

        if mode == 'rf':
            x = self.rf(x).view(-1)
        elif mode == 'encode':
            x = self.encoder(x).view(-1, 256)
            if clsfy:
                x = self.clsfy(x)
            else:
                # Fix: divide by the L2 norm. The original divided by the
                # *squared* norm (missing sqrt), which gives vectors of length
                # 1 / ||x|| rather than unit vectors. The eps sits inside the
                # sqrt so the gradient stays finite for an all-zero embedding.
                x = x / torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True) + 1e-8)

        return x

In [None]:
D = discriminator().cuda()

In [None]:
tesx_ = torch.ones(4, 3, 128, 128).cuda()
D(tesx_, mode='rf').size()

In [None]:
tesy_ = torch.ones(4, 3, 128, 128).cuda()
D(tesy_, mode='encode', clsfy=False).size()

In [None]:
tesz_ = torch.ones(4, 3, 128, 128).cuda()
D(tesz_, mode='encode', clsfy=True).size()

In [None]:
# Resume the discriminator from a saved checkpoint when present, else initialise it.
d_ckpt = model_head + 'd_model.pth'

if not os.path.exists(d_ckpt):
    D.apply(weights_init)
    print('init...')
else:
    D.load_state_dict(torch.load(d_ckpt))
    print('load..')

### Attribute Encoder

In [None]:
class attr_encoder(nn.Module):
    """Map an 85-d attribute vector to an L2-normalised 256-d embedding."""

    def __init__(self):
        super(attr_encoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(85, 128, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.25),

            nn.Linear(128, 192, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.25),

            nn.Linear(192, 256, bias=False),
        )

    def forward(self, x):
        """Encode ``x`` (batch x 85) and return unit-length rows (batch x 256)."""
        x = self.encoder(x)
        # Fix: divide by the L2 norm. The original divided by the *squared*
        # norm (missing sqrt), producing vectors of length 1 / ||x||. The eps
        # sits inside the sqrt so the gradient stays finite at zero.
        x = x / torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True) + 1e-8)

        return x

In [None]:
E = attr_encoder().cuda()
summary(E, (85,))

In [None]:
# Resume the attribute encoder from a checkpoint when present; otherwise it
# keeps PyTorch's default initialisation.
e_ckpt = model_head + 'e_model.pth'

if os.path.exists(e_ckpt):
    E.load_state_dict(torch.load(e_ckpt))
    print('load...')

### Visual Encoder (for rec)

In [None]:
class visual_encoder(nn.Module):
    """VAE-style image encoder: 3 x 128 x 128 image -> 256-d latent sample.

    ``forward`` returns ``(z, mu, logvar)`` where ``z`` is a reparameterised
    sample that is standardised per row and shaped ``(B, 256, 1, 1)``.
    """

    def __init__(self):
        super(visual_encoder, self).__init__()

        def down_block(c_in, c_out, norm=True, dropout=False):
            """Spectral-normed stride-2 conv [+ InstanceNorm] + LeakyReLU [+ Dropout]."""
            layers = [SpectralNorm(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))]
            if norm:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            if dropout:
                layers.append(nn.Dropout(0.25))
            return layers

        self.main = nn.Sequential(
            *down_block(3, 16, norm=False),           # size: 64 x 64
            *down_block(16, 32),                      # size: 32 x 32
            *down_block(32, 64),                      # size: 16 x 16
            *down_block(64, 128, dropout=True),       # size: 8 x 8
            *down_block(128, 256, dropout=True),      # size: 4 x 4
            *down_block(256, 512, dropout=True),      # size: 2 x 2
        )

        # 2 x 2 x 512 -> 1 x 1 x 1024, then two linear heads.
        self.encoder_output = nn.Conv2d(512, 1024, 4, 2, 1, bias=False)

        self.mu = nn.Linear(1024, 256)
        self.logvar = nn.Linear(1024, 256)

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape ``(B, 256)``."""
        feats = self.encoder_output(self.main(x)).view(-1, 1024)
        return self.mu(feats), self.logvar(feats)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        """Encode ``x`` and return ``(z, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)

        # Standardise each row to zero mean / unit std over its 256 entries.
        row_mean = z.mean(dim=1, keepdim=True)
        row_std = z.std(dim=1, keepdim=True)
        z = (z - row_mean) / (row_std + 1e-8)

        return z.view(-1, 256, 1, 1), mu, logvar

In [None]:
VE = visual_encoder().cuda()

In [None]:
tesx = torch.ones(4, 3, 128, 128).cuda()
[x.size() for x in VE(tesx)]

In [None]:
# Resume the visual encoder from a saved checkpoint when present, else initialise it.
ve_ckpt = model_head + 've_model.pth'

if not os.path.exists(ve_ckpt):
    VE.apply(weights_init)
    print('init...')
else:
    VE.load_state_dict(torch.load(ve_ckpt))
    print('load...')

### VAE loss

In [None]:
def criterion_vae(rec_imgs, imgs, mu, logvar):
    """VAE loss: summed binary cross-entropy plus KL divergence, per sample.

    Both image tensors must hold values in [0, 1] (binary cross-entropy is
    undefined outside that range), so images normalised to (-1, 1) have to be
    rescaled before they are passed in.

    Args:
        rec_imgs: reconstructed images, shape ``(B, ...)``.
        imgs: target images, same shape as ``rec_imgs``.
        mu: latent means.
        logvar: latent log-variances.

    Returns:
        Scalar tensor ``(BCE_sum + KLD_sum) / B``.
    """
    batch = rec_imgs.size()[0]

    # Flatten per sample; any image size works (was hard-coded to 3 * 128 * 128).
    bce_loss = F.binary_cross_entropy(rec_imgs.reshape(batch, -1),
                                      imgs.reshape(batch, -1),
                                      reduction='sum')

    # KL divergence between N(mu, exp(logvar)) and N(0, I), summed over everything.
    kld_loss = - 0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp())

    return (bce_loss + kld_loss) / batch

In [None]:
class KanCosineSimilarity(nn.Module):
    """Half the mean squared gap between target values and cosine similarities."""

    def __init__(self):
        super(KanCosineSimilarity, self).__init__()

    def forward(self, x, y, labels):
        """Return ``0.5 * mean((labels - cos(x, y)) ** 2)``.

        Args:
            x, y: batches of vectors compared row by row.
            labels: target cosine similarity for each row.
        """
        similarity = F.cosine_similarity(x, y, eps=1e-6)
        gap = labels - similarity
        return 1/2 * torch.mean(gap ** 2)

In [None]:
# Summary statistics of the learnable scale / shift applied to the 256-d noise.
# (The second entry previously read G.sigma.std(), i.e. the 512-d parameter.)
G.sigma_256.mean(), G.sigma_256.std(), G.myu_256.mean(), G.myu_256.std()

In [None]:
# Loss functions used by the training loop.
criterion_bce, criterion_mse = nn.BCELoss(), nn.MSELoss()
criterion_ce, criterion_cos = nn.CrossEntropyLoss(), KanCosineSimilarity()

# Adversarial optimizers.
_GAN_BETAS = (0.5, 0.999)
g_optimizer = optim.Adam(G.parameters(), lr=5e-5, betas=_GAN_BETAS)
d_optimizer = optim.Adam(D.parameters(), lr=2e-4, betas=_GAN_BETAS)
e_optimizer = optim.Adam(E.parameters(), lr=2e-4, betas=_GAN_BETAS)

# Auto-encoder optimizers. Note that ag_optimizer and g_optimizer both update
# G's parameters, each keeping its own Adam state.
_AE_BETAS = (0.5, 0.9)
ag_optimizer = optim.Adam(G.parameters(), lr=1e-4, betas=_AE_BETAS)
ae_optimizer = optim.Adam(VE.parameters(), lr=1e-4, betas=_AE_BETAS)

In [None]:
import matplotlib.pyplot as plt
from pylab import rcParams
rcParams['figure.figsize'] = 6.4, 6.4

In [None]:
ae_epoch = 1
d_epoch = 1
g_epoch = 1

In [None]:
import sys

In [None]:
nc = 2e-1       # std of the Gaussian noise added to the images fed to D
decay = 0.9999  # nc is multiplied by this after every discriminator round

t = 50          # batches between progress reports
lim = 1.0       # weight of the classification loss


def save_checkpoints():
    """Write the current weights of G, D, E and VE to the working directory."""
    torch.save(G.state_dict(), 'g_model.pth')
    torch.save(D.state_dict(), 'd_model.pth')
    torch.save(E.state_dict(), 'e_model.pth')
    torch.save(VE.state_dict(), 've_model.pth')

    sys.stdout.write('model saved...')


def perturb_attrs(attrs_pm):
    """Scale every attribute by a random factor >= 0.7, keeping its sign."""
    return torch.abs(torch.randn_like(attrs_pm)) * attrs_pm * 0.55 + attrs_pm * 0.7


def show_image_row(images, titles=None, n_cols=6):
    """Plot a (N, C, H, W) batch with values in [0, 1] as one row of images."""
    arr = np.transpose(images.detach().cpu().numpy(), (0, 2, 3, 1))

    plt.subplots_adjust(wspace=0.025, hspace=0.025)
    plt.grid(False)
    for k in range(len(arr)):
        plt.subplot(1, n_cols, k + 1)
        if titles is not None:
            plt.title(titles[k])
        plt.axis('off')

        plt.imshow(arr[k])
    plt.show()


for epoch in range(20):
    d_running_loss_rf = 0.0
    d_running_loss_cos = 0.0
    d_running_loss_cls = 0.0

    g_running_loss_rf = 0.0
    g_running_loss_cos = 0.0

    ae_running_loss = 0.0

    # Batches accumulated in the running losses since the last report.
    batches_since_report = 0

    for i, data_batch in enumerate(train_loader):

        if i % 50 == 49:
            print('nc:', nc)

        real_imgs, labels, attrs = data_batch
        b_size = real_imgs.size()[0]

        # Guard kept from the original: a batch holding a single sample ends
        # the epoch (only the final batch of an epoch can be that short).
        if b_size == 1:
            break

        real_imgs = real_imgs.cuda()
        attrs = attrs.cuda()
        attrs_ = (attrs - 0.5) / 0.5  # {0, 1} -> {-1, +1}

        label_real = labels.cuda().long()

        # Real / fake targets; the *_s variants are randomly smoothed.
        rf_labels_real = torch.ones(b_size).cuda()
        rf_labels_fake = torch.zeros(b_size).cuda()

        rf_labels_real_s = 0.7 * rf_labels_real + 0.55 * torch.rand_like(rf_labels_real)
        rf_labels_fake_s = 0.25 * torch.rand_like(rf_labels_fake)

        rf_labels_real_ = torch.ones(b_size).cuda()

        # Cosine-similarity targets; the *_s variants are randomly smoothed.
        cos_labels_real_ = torch.ones(b_size).cuda()

        cos_labels_real_s = (torch.ones(b_size) - 0.2 * torch.rand(b_size)).cuda()
        cos_labels_fake_s = (0.375 * torch.rand(b_size) - 0.25).cuda()

        # Occasionally swap the discriminator's targets (2.5 % / 5 % of batches).
        if np.random.randint(1000) < 25:
            rf_labels_real_s, rf_labels_fake_s = rf_labels_fake_s, rf_labels_real_s

        if np.random.randint(1000) < 50:
            cos_labels_real_s, cos_labels_fake_s = cos_labels_fake_s, cos_labels_real_s

        # ---------------- auto-encoder (VE + G in 'rec' mode) ----------------
        for _ in range(ae_epoch):
            VE.zero_grad()
            G.zero_grad()

            # G reconstructs from the encoded latent vector and the attributes.
            enc_outputs, mu, logvar = VE(real_imgs)
            f_attr_matrix = perturb_attrs(attrs_)

            rec_imgs = G((enc_outputs, f_attr_matrix), mode='rec')

            # Targets are rescaled from [-1, 1] to [0, 1] for the BCE term.
            ae_loss = 1e-2 * criterion_vae(rec_imgs, real_imgs * 0.5 + 0.5, mu, logvar)
            ae_loss.backward()

            ag_optimizer.step()
            ae_optimizer.step()

            ae_running_loss += ae_loss.item()

        # ---------------- discriminator / attribute encoder ------------------
        for _ in range(d_epoch):
            # (1) real / fake, least-squares (MSE) objective on noisy inputs
            D.zero_grad()

            d_real_outputs_rf = D(real_imgs + nc * torch.randn_like(real_imgs), mode='rf').view(-1)
            d_real_loss_rf = criterion_mse(d_real_outputs_rf, rf_labels_real_s)

            z = torch.randn(b_size, 512, 1, 1).cuda()
            fake_imgs = G(z, mode='rf')

            d_fake_outputs_rf = D(fake_imgs.detach() + nc * torch.randn_like(fake_imgs), mode='rf').view(-1)
            d_fake_loss_rf = criterion_mse(d_fake_outputs_rf, rf_labels_fake_s)

            d_loss_rf = d_real_loss_rf + d_fake_loss_rf
            d_loss_rf.backward()
            d_optimizer.step()

            d_running_loss_rf += d_loss_rf.item()

            # (2) visual embedding vs. attribute embedding (cosine objective)
            D.zero_grad()
            E.zero_grad()

            d_real_encode_v = D(real_imgs + nc * torch.randn_like(real_imgs), mode='encode', clsfy=False)
            d_real_encode_s = E(attrs_)

            d_real_loss_cos = criterion_cos(d_real_encode_v, d_real_encode_s, cos_labels_real_s)

            z = torch.randn(b_size, 256, 1, 1).cuda()
            f_attr_matrix = perturb_attrs(attrs_)

            fake_imgs = G((z, f_attr_matrix), mode='attr')

            # detach(): G is not updated in this step, so there is no need to
            # back-propagate into it (same as in step (1)).
            d_fake_encode_v = D(fake_imgs.detach() + nc * torch.randn_like(real_imgs), mode='encode', clsfy=False)

            d_fake_loss_cos = criterion_cos(d_fake_encode_v, d_real_encode_s, cos_labels_fake_s)

            d_loss_cos = d_real_loss_cos + d_fake_loss_cos
            d_loss_cos.backward()
            d_optimizer.step()
            e_optimizer.step()

            d_running_loss_cos += d_loss_cos.item()

            # (3) class prediction on clean real images
            # Fix: clear D's gradients first. The original skipped this, so the
            # gradients of step (2) were still stored and got applied a second
            # time together with the classification gradients.
            D.zero_grad()

            d_outputs_cls = D(real_imgs, mode='encode', clsfy=True)

            d_loss_cls = lim * criterion_ce(d_outputs_cls, label_real)
            d_loss_cls.backward()
            d_optimizer.step()

            d_running_loss_cls += d_loss_cls.item() / lim

            nc *= decay

        # ---------------- generator -----------------------------------------
        for _ in range(g_epoch):
            # (1) fool the real / fake head
            G.zero_grad()

            z = torch.randn(b_size, 512, 1, 1).cuda()
            fake_imgs = G(z, mode='rf')

            d_outputs_rf = D(fake_imgs, mode='rf').view(-1)
            g_loss_rf = criterion_mse(d_outputs_rf, rf_labels_real_)

            g_loss_rf.backward()
            g_optimizer.step()

            g_running_loss_rf += g_loss_rf.item()

            # (2) make attribute-conditioned images match their attribute embedding
            G.zero_grad()
            z = torch.randn(b_size, 256, 1, 1).cuda()
            f_attr_matrix = perturb_attrs(attrs_)

            fake_imgs = G((z, f_attr_matrix), mode='attr')

            d_encode_v = D(fake_imgs, mode='encode', clsfy=False)
            d_encode_s = E(attrs_)

            g_loss_cos = criterion_cos(d_encode_v, d_encode_s, cos_labels_real_)
            g_loss_cos.backward()
            g_optimizer.step()

            g_running_loss_cos += g_loss_cos.item()

        # ---------------- progress report + previews ------------------------
        batches_since_report += 1

        if i % t == t-1 or i == 0:
            seen = (i+1) * train_loader.batch_size
            n_avg = batches_since_report

            print(epoch+1, seen, 'd_loss_rf:', d_running_loss_rf / n_avg,
                  ', d_loss_cos:', d_running_loss_cos / n_avg,
                  ', d_loss_cls:', d_running_loss_cls / n_avg)

            print(epoch+1, seen, 'g_loss_rf:', g_running_loss_rf / n_avg,
                  ', g_loss_cos:', g_running_loss_cos / n_avg)

            print('ae_loss:', ae_running_loss / n_avg)

            d_running_loss_rf = 0.0
            d_running_loss_cos = 0.0
            d_running_loss_cls = 0.0

            g_running_loss_rf = 0.0
            g_running_loss_cos = 0.0

            ae_running_loss = 0.0
            batches_since_report = 0

            # Never ask for more preview images than the batch contains.
            n_show = min(6, b_size)

            # Previews are display-only, so no autograd graph is needed.
            with torch.no_grad():
                # Row 1: real images.
                real_samples = real_imgs[:n_show]
                show_image_row(real_samples * 0.5 + 0.5)

                # Row 2: reconstructions (sigmoid output, already in [0, 1]).
                enc_samples, _, _ = VE(real_samples)
                f_attr_matrix = perturb_attrs(attrs_)
                rec_samples = G((enc_samples, f_attr_matrix[:n_show]), mode='rec')
                show_image_row(rec_samples)

                # Row 3: attribute-conditioned samples, titled with the class
                # whose attributes were used.
                z = torch.randn(n_show, 256, 1, 1).cuda()
                f_attr_matrix = perturb_attrs(attrs_)
                fake_imgs = G((z, f_attr_matrix[:n_show]), mode='attr')
                show_image_row(fake_imgs * 0.5 + 0.5,
                               titles=[label_to_class_dict[lbl] for lbl in labels[:n_show].tolist()])

                # Row 4: unconditional samples.
                z = torch.randn(n_show, 512, 1, 1).cuda()
                fake_imgs = G(z, mode='rf')
                show_image_row(fake_imgs * 0.5 + 0.5)

    # Fix: save at the END of every epoch. The original saved at the start of
    # each epoch only, so the weights of the final epoch were never written.
    save_checkpoints()