In [None]:
from tqdm import tqdm
from typing import Union, Tuple, List, Optional
import os
import random
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
from torchvision.utils import make_grid
from matplotlib import pyplot as plt
from dataclasses import dataclass
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss
from itertools import chain

In [None]:
# ! pip install wandb
import wandb

In [None]:
@dataclass
class tox_config:
    """Hyper-parameters and file locations used by the trainer in this file.

    Only the annotated names are dataclass fields (and therefore keyword
    arguments of the generated ``__init__``). The un-annotated names
    (``num_blocks``, ``num_fakes``, ``generator_loss_weights``) are plain class
    attributes shared by every instance.
    """
    batch_size: int = 64
    lr: float = 3e-4
    beta1: float = 0.9  # 0.0 for Adam-like optims
    beta2: float = 0.999  # 0.9 for Adam-like optims
    # Evaluated once, when the class body is executed.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    val_size: float = 0.1

    # base params to create generator and discriminator
    in_channels: int = 1
    base_channels: int = 64
    latent_dim: int = 128
    # Class attribute (no annotation), not a dataclass field.
    num_blocks = [2, 2, 2, 2]  # [1, 1, 1, 1]
    num_epochs: int = 50
    random_seed: int = 42

    checkpoint_interval: int = 10
    checkpoint_dir: str = "checkpoints"
    img_dir: str = "generated"
    dataset_path: str = "triplet_extended_dataset.pt"  # dataset.pt

    # The default is computed from the class-level batch_size above (64) at class
    # creation time; passing a different batch_size to __init__ does not change it.
    generator_size: int = batch_size  # number of generated samples for evaluation (by eyes)
    num_dis: int = 1  # number of discriminator updates per iteration
    generator_batches: int = 2
    # Class attribute (no annotation), fixed to the class-level defaults: 64 * 2.
    num_fakes = batch_size * generator_batches # number of fake examples for single generator update
    generator_loss: str = "mse_loss"  # ['mse_loss', 'cauchy_loss', 'gemanmcclure_loss', 'welsch_loss', 'l1_loss', 'binary_cross_entropy_loss', 'huber_loss', 'smooth_l1', 'another_smooth_l1_loss']
    # Class attribute (no annotation); unpacked by AAEGeneratorLoss as
    # (general, vertical, horizontal) coefficients.
    generator_loss_weights = [0.4, 0.3, 0.3]
    use_gram_matrix: bool = False  # either use gram matrix for perception loss or usual L1Loss

    vgg16_feature_loss_p_norm: int = 1  # [1, 2]

    intensity_mode: str = "mean"  # [mean, max]
    intensity_reduction: str = "mean"  # [mean, sum]
    intensity_loss_coef: float = 0.001

    num_workers: int = 0

    # Free-text description of the experiment.
    description: str = '''aae with triplet loss upon arcface embeddings including general loss upon
                            anchors, positives and negatives (both generator & discriminator)'''
    # Weight of the adversarial term inside AAEGeneratorLoss.
    adversarial_term: float = 0.001
    # Weight of the triplet loss in the trainer's overall generator objective.
    triplet_loss_term: float = 0.1

    # Width of the condition vector concatenated to the latent code.
    condition_dim: int = 2

    # State dict loaded into the Classifier used as feature extractor.
    model_path: str = "ArcFaceLoss_model.pt"


In [None]:
def random_roll(img, label):
    """Cyclically shift an image by a random offset along its two spatial axes.

    The last two dimensions of ``img`` are treated as (height, width), so both
    ``(H, W)`` and ``(C, H, W)`` tensors are supported. For 2-D input the
    behaviour (and the consumed random numbers) is the same as before.

    Args:
        img: tensor with at least two dimensions.
        label: passed through untouched so the function fits the
            ``(image, label) -> (image, label)`` transform protocol.

    Returns:
        Tuple ``(rolled_image, label)``.
    """
    # NOTE: offsets are drawn from NumPy's global RNG. When this runs inside
    # DataLoader worker processes the NumPy state may be identical in every
    # worker unless it is re-seeded per worker (e.g. through worker_init_fn).
    height, width = img.shape[-2:]
    offsets = np.random.randint(low=0, high=(height, width), size=(2,))
    shifts = tuple(int(offset) for offset in offsets)
    return img.roll(shifts=shifts, dims=(-2, -1)), label

# Alias for a plain Python scalar; used as the argument type of BinarizedDataset.one_hot.
_Number = Union[float, int]

class MaterialPercentDataset(Dataset):
    """Dataset yielding ``(image, material_fraction, label)`` triples.

    The material fraction is the sum of the (transformed) image divided by the
    fixed ``image_area``.
    """

    # Denominator used by calc_percent.
    image_area = 64 * 64

    def __init__(self,
                 images: torch.Tensor,
                 labels: torch.Tensor,
                 transform=[random_roll,]) -> None:
        """Store the tensors and the transform.

        ``transform`` is either a single callable or a list of callables, each
        mapping ``(image, label)`` to ``(image, label)``; any other value
        disables transformation.
        """
        super().__init__()

        self.images, self.labels = images, labels
        self.transform = transform

    def calc_percent(self, image: torch.Tensor) -> torch.Tensor:
        """Return ``image.sum() / image_area`` as a 0-d tensor."""
        return image.sum() / self.image_area

    def __len__(self) -> int:
        """Number of samples (size of the first image dimension)."""
        return self.images.shape[0]

    def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the transformed image, its material fraction and its label."""
        sample, target = self.images[index], self.labels[index]

        # Normalise the two supported transform forms to one sequence.
        if callable(self.transform):
            pipeline = [self.transform]
        elif isinstance(self.transform, list):  # ersatz for transforms.Compose
            pipeline = self.transform
        else:
            pipeline = []

        for step in pipeline:
            sample, target = step(sample, target)

        fraction = self.calc_percent(sample)

        # Add a leading channel axis to 2-D images.
        if len(sample.shape) == 2:
            sample = sample.unsqueeze(0)

        return sample, fraction.unsqueeze(-1), target


class BinarizedDataset(MaterialPercentDataset):
    """Dataset whose condition is a one-hot material-fraction bin plus a flag.

    Each item is ``(image, condition, label)`` where ``condition`` has
    ``bins_number + 1`` entries: the one-hot encoding of the bin nearest to the
    image's material fraction, followed by the last element of the label.
    """

    def __init__(self,
                 images: torch.Tensor,
                 labels: torch.Tensor,
                 bins_number: int = 15,
                 transform=[random_roll,]) -> None:
        """Store the data and build ``bins_number`` evenly spaced bin centres on [0, 1]."""
        super().__init__(images, labels, transform)

        self.bins_number = bins_number

        self.linspace = torch.linspace(0, 1, steps=bins_number)

    def one_hot(self, scalar: _Number) -> torch.Tensor:
        """Return a one-hot vector marking the bin centre closest to ``scalar``."""
        nearest = torch.min(torch.abs(self.linspace - scalar), dim=0).indices
        encoded = torch.zeros(self.bins_number)
        encoded[nearest] = 1
        return encoded

    def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(image, condition, label)`` for the sample at ``index``.

        The return annotation previously listed four entries although three
        values are returned; it now matches the actual result.
        """
        x = self.images[index]
        y = self.labels[index]

        # Accept either one callable or a list of callables as the transform.
        if callable(self.transform):
            x, y = self.transform(x, y)
        elif isinstance(self.transform, list):  # ersatz for transforms.Compose
            for tr in self.transform:
                x, y = tr(x, y)

        # The last label entry is appended to the one-hot bin as an extra flag.
        is_broken = y[-1].unsqueeze(0)
        percent = self.calc_percent(x).item()
        condition = torch.cat((self.one_hot(percent), is_broken), dim=0)

        # Add a leading channel axis to 2-D images.
        if len(x.shape) == 2:
            x = x.unsqueeze(0)

        return x, condition, y


class TripletBinsDataset(BinarizedDataset):
    '''
        Dataset yielding (anchor, positive, negative) samples, each paired with
        the one-hot material-fraction bin of its image.

        bins_indexes Tensor is supposed to be ids from the extended_clean_dataset.pt
    '''
    image_area = 64 * 64

    def __init__(self,
                 images: torch.Tensor,
                 labels: torch.Tensor,
                 bins_indexes: torch.Tensor,
                 transform=[random_roll,],
                 bins_number: int = 15) -> None:
        """Store the data, build the bin centres and derive binary labels.

        Args:
            images: image tensor, first dimension indexes samples.
            labels: raw labels; only ``label[0]`` is inspected (see
                ``make_appropriate_labels``).
            bins_indexes: per-sample ids used to pick positives and negatives.
            transform: list of ``(image, label) -> (image, label)`` callables.
            bins_number: number of material-fraction bins; it is also the
                length of every returned condition vector, so models consuming
                the condition must be sized accordingly.
        """
        # The parent constructor was previously called without arguments, which
        # raised TypeError and left `linspace` / `bins_number` undefined.
        super().__init__(images, labels, bins_number=bins_number, transform=transform)

        self._labels = labels
        self.bins_indexes = bins_indexes

        self.make_appropriate_labels()

    def make_appropriate_labels(self):
        """Replace ``self.labels`` by binary labels: 0 where ``label[0] == 1.0``, else 1."""
        binary = [0 if label[0].item() == 1.0 else 1 for label in self._labels]
        self.labels = torch.tensor(binary, dtype=torch.int8)

    def __len__(self) -> int:
        """Number of samples (size of the first image dimension)."""
        return self.images.shape[0]

    def make_condition(self,
                       x: torch.Tensor) -> torch.Tensor:
        """Return the one-hot bin of the material fraction of ``x``."""
        percent = self.calc_percent(x).item()
        return self.one_hot(percent)

    def __getitem__(self, index) -> Tuple[Tuple[torch.Tensor, torch.Tensor],
                                          Tuple[torch.Tensor, torch.Tensor],
                                          Tuple[torch.Tensor, torch.Tensor]]:
        """Return ``((anchor, cond), (positive, cond), (negative, cond))``.

        Positives share the anchor's binary label, negatives have the other one.
        ``random.choice`` raises IndexError when no candidate exists.
        """
        anchor = self.images[index]
        anchor_label = self.labels[index]

        # NOTE(review): the masks compare the stored ids with the dataset
        # position `index`, and the selected ids are later used to index
        # `self.images`. This is only meaningful if ids and positions coincide
        # -- confirm against the code that builds the dataset file.
        other_mask = self.bins_indexes != index
        same_mask = self.bins_indexes == index

        positive_list = self.bins_indexes[other_mask]
        positive_list = positive_list[self.labels[other_mask] == anchor_label]

        # TODO: sample negatives that may be built from the same skeleton; those
        # will lie in other bins. If that first option is not available, simply
        # sample from the other bins.
        negative_list = self.bins_indexes[same_mask]
        negative_list = negative_list[self.labels[same_mask] != anchor_label]

        if not len(negative_list):
            # Fallback: candidates from the other mask. The label filter now
            # uses the same mask as the candidate list (it used the opposite
            # one before, giving a boolean index of the wrong length).
            negative_list = self.bins_indexes[other_mask]
            negative_list = negative_list[self.labels[other_mask] != anchor_label]

        positive = self.images[random.choice(positive_list)]
        negative = self.images[random.choice(negative_list)]

        # Conditions are computed before the transforms are applied.
        anchor_condition = self.make_condition(anchor)
        pos_condition = self.make_condition(positive)
        neg_condition = self.make_condition(negative)

        for tr in self.transform:
            anchor, _ = tr(anchor, None)
            positive, _ = tr(positive, None)
            negative, _ = tr(negative, None)

        # Add a leading channel axis to 2-D images.
        if anchor.ndim == 2:
            anchor = anchor.unsqueeze(0)
            positive = positive.unsqueeze(0)
            negative = negative.unsqueeze(0)

        # The commas were missing here, which made Python try to call a tuple.
        return (anchor, anchor_condition), (positive, pos_condition), (negative, neg_condition)

In [None]:
# Default floating-point dtype constant; not referenced in the visible part of this file.
DEFAULT_DTYPE = torch.float32


def seed_everything(seed: int = tox_config.random_seed) -> None:
    """Seed every random number generator used in this file.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA) and configures
    cuDNN for reproducible behaviour.

    Args:
        seed: value used for all generators.
    """
    random.seed(seed)
    # Was misspelled 'PYTHONASSEED'. The variable is read at interpreter
    # start-up, so it only affects processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN pick algorithms by timing them, which can vary
    # between runs; it has to be off for the deterministic flag to be useful.
    torch.backends.cudnn.benchmark = False


def make_dir(dir_path):
    """Create ``dir_path`` (including missing parents) if it does not exist.

    An already existing directory is not an error. Other failures (for example
    missing permissions, or the path being an existing file) now raise
    ``OSError`` instead of being silently ignored.

    Args:
        dir_path: directory to create.

    Returns:
        ``dir_path`` unchanged, so the call can be used inline.
    """
    os.makedirs(dir_path, exist_ok=True)
    return dir_path


def l2_normalize(x: torch.Tensor, eps=1e-12):
    """Divide ``x`` by the L2 norm taken over all of its elements.

    ``eps`` is added to the squared norm before the square root so that an
    all-zero tensor does not cause a division by zero.
    """
    squared_norm = x.pow(2).sum() + eps
    return x / squared_norm.sqrt()


class Reshape(nn.Module):
    """Module wrapper around ``Tensor.reshape`` with a fixed target shape."""

    def __init__(self, shape) -> None:
        """Remember ``shape``; it is passed to ``reshape`` on every call."""
        super().__init__()

        self.shape = shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` reshaped to the stored shape."""
        target_shape = self.shape
        return x.reshape(target_shape)

In [None]:
class AAEDiscriminatorLoss(_Loss):
    """Binary cross-entropy loss for the latent-space discriminator.

    Scores of samples drawn from the target prior are pushed towards 1, scores
    of encoder outputs towards 0; the two terms are averaged.

    The ``reduction`` argument used to be accepted but ignored; it is now
    forwarded to the underlying ``BCEWithLogitsLoss``. The default ('mean')
    gives the same result as before.
    """

    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super().__init__(size_average, reduce, reduction)

        # self.reduction is the value resolved by _Loss from all three arguments.
        self.loss = nn.BCEWithLogitsLoss(reduction=self.reduction)

    def forward(self,
                target_scores: torch.Tensor,
                encoder_scores: torch.Tensor) -> torch.Tensor:
        """Return the average of the "real" and "fake" cross-entropy terms.

        Args:
            target_scores: discriminator logits for prior samples.
            encoder_scores: discriminator logits for encoder outputs.
        """
        labels_true = torch.ones_like(target_scores)
        labels_fake = torch.zeros_like(encoder_scores)

        real_term = self.loss(target_scores, labels_true)
        fake_term = self.loss(encoder_scores, labels_fake)
        return (real_term + fake_term) / 2


class AAEGeneratorLoss(_Loss):
    """Encoder/decoder objective: adversarial + reconstruction + tiling terms.

    The adversarial and L1 reconstruction terms are blended with weight
    ``adversarial_term`` and scaled by the first entry of ``weights``. Two
    further L1 terms compare opposite borders of the decoded image and are
    scaled by the second (vertical) and third (horizontal) entries.
    """

    # Index of the last column / row is derived from these sizes.
    width_dim = 64
    height_dim = 64

    def __init__(self,
                 adversarial_term: float,
                 weights: List[float],
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean') -> None:
        super().__init__(size_average, reduce, reduction)

        self.adversarial_loss = nn.BCEWithLogitsLoss()
        self.pixelwise_loss = nn.L1Loss()
        self.tiling_loss = nn.L1Loss()

        self.alpha = adversarial_term

        general, vertical, horizontal = weights
        self.general_coef = general
        self.vertical_coef = vertical
        self.horizontal_coef = horizontal

    def forward(self,
                encoder_scores: torch.Tensor,
                decoded_images: torch.Tensor,
                real_images: torch.Tensor) -> torch.Tensor:
        """Return the weighted sum of the blended base loss and both border terms."""
        # The encoder wants its codes to be classified as "real" (all ones).
        wanted_labels = torch.ones_like(encoder_scores)
        adversarial = self.adversarial_loss(encoder_scores, wanted_labels)
        reconstruction = self.pixelwise_loss(decoded_images, real_images)

        blended = self.alpha * adversarial + (1 - self.alpha) * reconstruction

        total = self.general_coef * blended
        total = total + self.vertical_loss(decoded_images)
        total = total + self.horizontal_loss(decoded_images)
        return total

    def vertical_loss(self, tile: torch.Tensor) -> torch.Tensor:
        """Weighted L1 distance between row 0 and row ``height_dim - 1``."""
        last_row = self.height_dim - 1
        top = tile[:, :, 0, :]
        bottom = tile[:, :, last_row, :]

        return self.vertical_coef * self.tiling_loss(top, bottom)

    def horizontal_loss(self, tile: torch.Tensor) -> torch.Tensor:
        """Weighted L1 distance between column 0 and column ``width_dim - 1``."""
        last_column = self.width_dim - 1
        left = tile[:, :, :, 0]
        right = tile[:, :, :, last_column]

        return self.horizontal_coef * self.tiling_loss(left, right)


In [None]:
class Reshape(nn.Module):
    """Layer that reshapes its input to a shape fixed at construction time.

    NOTE: a class with the same name is declared earlier in this file; this
    later definition is the one in effect for the code that follows.
    """

    def __init__(self, shape: Tuple[int, int, int, int]) -> None:
        super().__init__()

        # Target shape handed to Tensor.reshape in forward().
        self.shape = shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x.reshape(self.shape)``."""
        reshaped = x.reshape(self.shape)
        return reshaped


class DecoderBlock(nn.Module):
    """Residual block: two (BatchNorm, ReLU, 3x3 conv) stages plus a skip path.

    The skip path is the identity when the channel count is unchanged and a
    bias-free 1x1 convolution otherwise. Spatial size is preserved.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int) -> None:
        super().__init__()

        conv3x3 = dict(kernel_size=3, stride=1, padding=1, bias=False)

        self.block = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, **conv3x3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, **conv3x3),
        )

        # Project the input only when the channel counts differ.
        if in_channels != out_channels:
            self.skip_connection = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False
            )
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``block(x) + skip_connection(x)``."""
        main_path = self.block(x)
        shortcut = self.skip_connection(x)
        return main_path + shortcut


class DiscriminatorBlock(nn.Module):
    """Residual block of two 3x3 convolutions with optional 2x downsampling.

    The main path applies ``first_activation`` and then conv, ReLU, conv; with
    ``downsample`` the first convolution uses stride 2. The skip path is the
    identity when neither the resolution nor the channel count changes;
    otherwise it is an optional 2x average pooling followed by a 1x1 convolution.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 downsample: bool = False) -> None:
        super().__init__()

        self.first_activation = nn.ReLU()

        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1 + downsample, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        needs_projection = downsample or in_channels != out_channels
        if needs_projection:
            shortcut_layers = []
            if downsample:
                shortcut_layers.append(nn.AvgPool2d(2))
            shortcut_layers.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
            )
            self.skip_connection = nn.Sequential(*shortcut_layers)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``block(first_activation(x)) + skip_connection(x)``."""
        main_path = self.block(self.first_activation(x))
        return main_path + self.skip_connection(x)

In [None]:
class Encoder(nn.Module):
    '''
        Convolutional encoder mapping an image batch to latent codes.

        A stem block and four stages (each halving the resolution and doubling
        the channel count) are followed by ReLU, global average pooling and a
        linear projection to ``latent_dim``.
    '''
    def __init__(self,
                 in_channels: int,
                 base_channels: int = 64,
                 latent_dim: int = 128,
                 num_blocks: List[int] = [1, 1, 1, 1]) -> None:
        super().__init__()

        self.latent_dim = latent_dim

        # The stem receives raw pixels, so its leading activation is disabled.
        stem = DiscriminatorBlock(in_channels, base_channels, downsample=True)
        stem.first_activation = nn.Identity()

        # Channel widths: base, 2*base, 4*base, 8*base, 16*base.
        widths = [base_channels * 2 ** level for level in range(5)]
        stages = [
            self.make_layer(widths[stage], widths[stage + 1], num_blocks[stage])
            for stage in range(4)
        ]

        self.net = nn.Sequential(
            stem,
            *stages,
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(start_dim=1),
            nn.Linear(widths[-1], latent_dim),
        )

        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the latent codes produced by ``self.net``."""
        return self.net(x)

    def reset_parameters(self):
        """Re-initialise every Linear/Conv2d weight with Kaiming-normal values."""
        for module in self.modules():
            if not isinstance(module, (nn.Linear, nn.Conv2d)):
                continue
            nn.init.kaiming_normal_(
                module.weight, a=0.0, mode='fan_in', nonlinearity="leaky_relu"
            )

    @staticmethod
    def make_layer(in_channels: int,
                   out_channels: int,
                   num_blocks: int) -> nn.Sequential:
        """Build one stage: a downsampling block plus ``num_blocks - 1`` plain blocks."""
        blocks = [DiscriminatorBlock(in_channels, out_channels, downsample=True)]
        for _ in range(num_blocks - 1):
            blocks.append(DiscriminatorBlock(out_channels, out_channels))
        return nn.Sequential(*blocks)

    def get_model_size(self) -> int:
        """Return the number of trainable parameters."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total


class Decoder(nn.Module):
    """Conditional decoder mapping ``[latent, condition]`` vectors to images.

    A linear layer expands the input to a 4x4 feature map, four stages each
    double the resolution, and a 1x1 convolution with ``tanh`` produces
    ``in_channels`` output channels.

    The first layer has always expected ``latent_dim + condition_dim`` input
    features, but ``forward`` and ``sample`` offered no way to pass a condition.
    Both now take an optional ``condition`` that is concatenated to the latent
    code; calls without it behave exactly as before.
    """

    def __init__(self,
                 condition_dim: int,
                 in_channels: int,
                 base_channels: int = 64,
                 latent_dim: int = 128,
                 num_blocks: List[int] = [1, 1, 1, 1]) -> None:
        super().__init__()

        self.latent_dim = latent_dim
        self.condition_dim = condition_dim

        self.net = nn.Sequential(
            # 256 * base_channels == 16 * base_channels * 4 * 4
            nn.Linear(latent_dim + condition_dim, 256 * base_channels, bias=False),
            Reshape((-1, 16 * base_channels, 4, 4)),
            self.make_layer(16 * base_channels, 8 * base_channels, num_blocks[0]),
            self.make_layer(8 * base_channels, 4 * base_channels, num_blocks[1]),
            self.make_layer(4 * base_channels, 2 * base_channels, num_blocks[2]),
            self.make_layer(2 * base_channels, base_channels, num_blocks[3]),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(),
            nn.Conv2d(base_channels, in_channels, 1, 1, 0),
            nn.Tanh()
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every Linear/Conv2d weight with Kaiming-normal values."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_normal_(
                    m.weight, a=0.0, mode='fan_in', nonlinearity="leaky_relu"
                )

    @staticmethod
    def make_layer(in_channels: int,
                   out_channels: int,
                   num_blocks: int) -> nn.Sequential:
        """Build one stage: 2x bilinear upsampling followed by ``num_blocks`` DecoderBlocks."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            DecoderBlock(in_channels, out_channels),
            *[DecoderBlock(out_channels, out_channels) for _ in range(num_blocks - 1)]
        )

    def forward(self,
                x: torch.Tensor,
                condition: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode a batch of vectors.

        Args:
            x: latent codes; when ``condition`` is omitted ``x`` must already
                contain the condition features.
            condition: optional condition batch, concatenated to ``x`` along
                dimension 1 after being moved to ``x``'s device and dtype.
        """
        if condition is not None:
            x = torch.cat((x, condition.to(x)), dim=1)
        return self.net(x)

    def sample(self,
               num_samples: int,
               noise: Optional[torch.Tensor] = None,
               condition: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode ``noise`` (standard-normal if omitted) on the module's device.

        Args:
            num_samples: number of noise vectors to draw when ``noise`` is None.
            noise: optional latent batch to decode instead of fresh noise.
            condition: optional condition batch, see ``forward``.
        """
        device = next(self.parameters()).device

        if noise is None:
            noise = torch.randn((num_samples, self.latent_dim))

        return self.forward(noise.to(device), condition)

    def get_model_size(self) -> int:
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


class AAEDiscriminator(nn.Module):
    """Three-layer MLP scoring ``[latent, condition]`` vectors with one logit.

    The first layer expects ``input_dim + condition_dim`` features. ``forward``
    now accepts the condition as an optional second argument and concatenates
    it; calling it with a single, already concatenated tensor still works.
    """

    def __init__(self,
                 condition_dim: int,
                 input_dim: int,
                 hidden_dim: int = 64) -> None:
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(input_dim + condition_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self,
                z: torch.Tensor,
                condition: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return one logit per row.

        Args:
            z: latent codes; must already include the condition features when
                ``condition`` is omitted.
            condition: optional condition batch, concatenated to ``z`` along
                dimension 1 after being moved to ``z``'s device and dtype.
        """
        if condition is not None:
            z = torch.cat((z, condition.to(z)), dim=1)
        return self.net(z)

    def enable_grads(self):
        """Mark every parameter as trainable."""
        for p in self.parameters():
            p.requires_grad = True

    def disable_grads(self):
        """Freeze every parameter (used while the encoder/decoder is updated)."""
        for p in self.parameters():
            p.requires_grad = False

In [None]:
class Flatten(nn.Module):
    """Module form of ``torch.flatten`` over the range ``[start_dim, end_dim]``.

    Unlike ``nn.Flatten``, the default ``start_dim`` here is 0, so the batch
    dimension is flattened too unless a start dimension is given.
    """

    def __init__(self,
                 start_dim: int = 0,
                 end_dim: int = -1) -> None:
        super().__init__()

        self.start_dim, self.end_dim = start_dim, end_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with the configured dimensions merged into one."""
        return x.flatten(self.start_dim, self.end_dim)


class ResidualBlock(nn.Module):
    """
        Pre-activation residual block with a 4x channel bottleneck.

        https://arxiv.org/abs/1603.05027
    """

    def __init__(self,
                 num_channels: int) -> None:
        super().__init__()
        self.num_channels = num_channels

        bottleneck = num_channels // 4
        layers = [
            nn.ReLU(),
            nn.Conv2d(num_channels, bottleneck, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, num_channels, kernel_size=1, bias=False),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the input plus the output of the bottleneck branch."""
        branch = self.net(x)
        return x + branch


class Classifier(nn.Module):
    """Convolutional network producing a 3-dimensional embedding per image.

    Five strided stages (each a convolution followed by a ResidualBlock) are
    followed by flattening, ReLU and a linear layer with 1024 input features.
    """

    def __init__(self,
                 in_channels: int = 1,
                 hidden_channels: int = 256) -> None:
        super().__init__()

        # Channel widths of the five stages: hidden/16, /8, /4, /2, hidden.
        widths = [in_channels]
        widths += [hidden_channels // divisor for divisor in (16, 8, 4, 2)]
        widths.append(hidden_channels)

        stages = [
            self.make_block(source, target)
            for source, target in zip(widths[:-1], widths[1:])
        ]

        self.projection_net = nn.Sequential(
            *stages,
            Flatten(1),
            nn.ReLU(),
            nn.Linear(1024, 3),
        )
        # No score head is attached; forward() returns None in its place.

    @staticmethod
    def make_block(in_channels: int,
                   out_channels: int,
                   kernel_size: int = 4,
                   stride: int = 2,
                   padding: int = 1):
        """Return a convolution followed by a ResidualBlock on its output channels."""
        conv = nn.Conv2d(in_channels, out_channels,
                         kernel_size=kernel_size, stride=stride, padding=padding)
        return nn.Sequential(conv, ResidualBlock(out_channels))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(None, embeddings)``; the first slot is a placeholder for scores."""
        embeddings = self.projection_net(x)
        return None, embeddings

    def get_model_size(self) -> int:
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [None]:
class TripletEmbeddingLoss(_Loss):
    '''
        Triplet loss realized over embeddings of some representative network.

        ``feature_extractor`` must return a pair whose second element is the
        embedding batch; its parameters are frozen on construction.

        Changes compared with the previous version:
        * ``reduction='none'`` returns the per-sample losses (``forward`` used
          to fall through and return ``None``); unknown values raise ValueError.
        * each of anchor/positive/negative is embedded once per call instead of
          up to twice, which halves the number of extractor forward passes.
    '''
    def __init__(self,
                 feature_extractor: nn.Module,
                 margin: float = 1.0,
                 swap: bool = True,
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean') -> None:
        super().__init__(size_average, reduce, reduction)

        self.feature_extractor = feature_extractor
        self.disable_feature_extractor_grads()

        self.margin = margin
        self.swap = swap

    def disable_feature_extractor_grads(self):
        """Freeze the extractor; gradients still flow to its inputs."""
        for p in self.feature_extractor.parameters():
            p.requires_grad = False

    def _embed(self, x: torch.Tensor) -> torch.Tensor:
        """Return the embedding (second output) of the feature extractor."""
        _, embedding = self.feature_extractor(x)
        return embedding

    @staticmethod
    def _squared_distance(e1: torch.Tensor, e2: torch.Tensor) -> torch.Tensor:
        """Squared Euclidean distance, summed over dimension 1."""
        return (e1 - e2).pow(2).sum(1)

    def calc_euclidean(self,
                       x1: torch.Tensor,
                       x2: torch.Tensor) -> torch.Tensor:
        """Return the squared Euclidean distance between the embeddings of ``x1`` and ``x2``."""
        return self._squared_distance(self._embed(x1), self._embed(x2))

    def casual_loss(self,
                    anchor: torch.Tensor,
                    positive: torch.Tensor,
                    negative: torch.Tensor) -> torch.Tensor:
        """Return the unreduced hinge loss ``relu(d_pos - d_neg + margin)``."""
        anchor_emb = self._embed(anchor)
        positive_emb = self._embed(positive)
        negative_emb = self._embed(negative)

        distance_positive = self._squared_distance(anchor_emb, positive_emb)
        distance_negative = self._squared_distance(anchor_emb, negative_emb)

        # this ensures that the hardest negative inside the triplet is used for backpropagation
        if self.swap:
            distance_negative_p = self._squared_distance(positive_emb, negative_emb)
            distance_negative = torch.minimum(distance_negative, distance_negative_p)

        return torch.relu(distance_positive - distance_negative + self.margin)

    def forward(self,
                anchor: torch.Tensor,
                positive: torch.Tensor,
                negative: torch.Tensor) -> torch.Tensor:
        """Return the triplet loss reduced according to ``self.reduction``.

        Raises:
            ValueError: if ``self.reduction`` is not 'mean', 'sum' or 'none'.
        """
        loss = self.casual_loss(anchor, positive, negative)

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        if self.reduction == "none":
            return loss
        raise ValueError(f"Unsupported reduction: {self.reduction!r}")


class TripletImageLoss(_Loss):
    '''
        Triplet loss upon generations from the generator, computed directly on
        image tensors (squared Euclidean distance) or on their Gram matrices.

        The class previously had no ``forward`` method, so calling an instance
        raised NotImplementedError; ``forward`` now reduces ``casual_loss``
        according to ``reduction``.
    '''
    def __init__(self,
                 feature_extractor=None,
                 margin: float = 1.0,
                 swap: bool = True,
                 use_gram_matrix: bool = False,
                 size_average=None,
                 reduce=None,
                 reduction: str = 'mean') -> None:
        # `feature_extractor` is accepted for signature compatibility but unused.
        super().__init__(size_average, reduce, reduction)

        self.margin = margin
        self.use_gram_matrix = use_gram_matrix
        self.swap = swap

        self.loss_fn = self.gram_matrix_loss if use_gram_matrix else self.calc_euclidean

    @staticmethod
    def calc_euclidean(x1: torch.Tensor,
                       x2: torch.Tensor,
                       dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        """Return the squared distance summed over ``dims``, with a trailing axis of size 1."""
        return (x1 - x2).pow(2).sum(dim=dims).unsqueeze(1)

    def casual_loss(self,
                    anchor: torch.Tensor,
                    positive: torch.Tensor,
                    negative: torch.Tensor) -> torch.Tensor:
        """Return the unreduced hinge loss ``relu(d_pos - d_neg + margin)``."""
        distance_positive = self.loss_fn(anchor, positive)
        distance_negative = self.loss_fn(anchor, negative)

        # this ensures that the hardest negative inside the triplet is used for backpropagation
        if self.swap:
            distance_negative_p = self.loss_fn(positive, negative)
            distance_negative = torch.minimum(distance_negative, distance_negative_p)

        return torch.relu(distance_positive - distance_negative + self.margin)

    def forward(self,
                anchor: torch.Tensor,
                positive: torch.Tensor,
                negative: torch.Tensor) -> torch.Tensor:
        """Return the triplet loss reduced according to ``self.reduction``.

        Raises:
            ValueError: if ``self.reduction`` is not 'mean', 'sum' or 'none'.
        """
        loss = self.casual_loss(anchor, positive, negative)

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        if self.reduction == "none":
            return loss
        raise ValueError(f"Unsupported reduction: {self.reduction!r}")

    @staticmethod
    def compute_gram_matrix(y: torch.Tensor) -> torch.Tensor:
        """Return the per-sample channel Gram matrix of a (B, C, H, W) tensor, divided by H * W."""
        _, _, h, w = y.shape
        return torch.einsum('bchw,bdhw->bcd', y, y) / (h * w)

    def gram_matrix_loss(self,
                         x: torch.Tensor,
                         y: torch.Tensor) -> torch.Tensor:
        """Return the squared Gram-matrix difference summed over dimension 1.

        NOTE(review): only dimension 1 is summed, so the result has shape
        (B, C) whereas ``calc_euclidean`` returns (B, 1) -- confirm whether a
        sum over both matrix dimensions was intended.
        """
        assert len(x.shape) == 4

        G = self.compute_gram_matrix(x)
        A = self.compute_gram_matrix(y)

        return (G - A).pow(2).sum(1)

In [None]:
class TripletAAETrainer:
    """Trains a conditional adversarial autoencoder with an extra triplet loss.

    Each batch holds (anchor, positive, negative) image/condition pairs. The
    encoder and decoder are updated with reconstruction, adversarial, tiling
    and triplet terms; the latent discriminator is then updated on prior
    samples versus (detached) encoder outputs.

    Fixes compared with the previous version:
    * ``checkpoint_path`` is defined, so ``save()`` no longer raises AttributeError;
    * the condition is concatenated to the latent code before it is passed to
      the decoder / discriminator, whose ``forward`` takes a single tensor;
    * positives and negatives are decoded with their own condition (the
      anchor's condition was used for all three);
    * the progress bar wraps the loader that is actually iterated;
    * the feature-extractor weights are loaded with ``map_location``.
    """

    def __init__(self,
                 cfg=tox_config()) -> None:
        """Build models, optimisers, criteria and the data loader from ``cfg``."""
        self.cfg = cfg
        self.device = cfg.device
        seed_everything(cfg.random_seed)
        self.checkpoint_path = os.path.join(cfg.checkpoint_dir, "conditional_aae.pt")

        make_dir(cfg.checkpoint_dir)
        make_dir(cfg.img_dir)

        self.encoder = Encoder(cfg.in_channels, cfg.base_channels, cfg.latent_dim, cfg.num_blocks).to(self.device)
        self.generator = Decoder(cfg.condition_dim, cfg.in_channels, cfg.base_channels, cfg.latent_dim, cfg.num_blocks).to(self.device)
        self.discriminator = AAEDiscriminator(cfg.condition_dim, cfg.latent_dim).to(self.device)

        # Encoder and decoder share one optimiser.
        self.gen_optim = torch.optim.AdamW(chain(self.encoder.parameters(), self.generator.parameters()),
                                           lr=cfg.lr,
                                           betas=(cfg.beta1, cfg.beta2))
        self.discr_optim = torch.optim.AdamW(self.discriminator.parameters(), lr=cfg.lr, betas=(cfg.beta1, cfg.beta2))

        self.gen_criterion = AAEGeneratorLoss(cfg.adversarial_term, cfg.generator_loss_weights)
        self.discr_criterion = AAEDiscriminatorLoss()

        # init dataset and dataloader
        # NOTE(review): the condition vectors produced by the dataset must have
        # cfg.condition_dim entries for the models above -- confirm.
        imgs, labels, indexes = torch.load(cfg.dataset_path)
        traindata = TripletBinsDataset(imgs, labels, indexes)
        self.train_loader = DataLoader(
            traindata, shuffle=True, batch_size=cfg.batch_size, drop_last=True, pin_memory=False, num_workers=cfg.num_workers
        )

        feature_extractor = Classifier().to(self.device)
        # map_location allows weights saved on another device to be loaded here.
        arcface_state_dict = torch.load(cfg.model_path, map_location=self.device)
        feature_extractor.load_state_dict(arcface_state_dict)
        feature_extractor.eval()

        self.triplet_criterion = TripletEmbeddingLoss(feature_extractor)

    @staticmethod
    def _with_condition(latent: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Concatenate ``condition`` to ``latent`` along dimension 1."""
        return torch.cat((latent, condition.to(latent)), dim=1)

    def _unpack_batch(self, batch):
        """Move a batch of three (image, condition) pairs to the device.

        Returns two lists ordered anchor, positive, negative.
        """
        images, conditions = [], []
        for image, condition in batch:
            images.append(image.to(self.device))
            conditions.append(condition.to(self.device))
        return images, conditions

    def _generator_step(self, images, conditions):
        """Run one encoder/decoder update.

        Returns ``(generator_loss, triplet_loss, overall_loss, latents)`` where
        the losses are floats and ``latents`` are detached encoder outputs.
        """
        self.gen_optim.zero_grad()
        self.discriminator.disable_grads()

        latents, decoded, losses = [], [], []
        for image, condition in zip(images, conditions):
            latent = self.encoder(image)
            reconstruction = self.generator(self._with_condition(latent, condition))
            scores = self.discriminator(self._with_condition(latent, condition))

            losses.append(self.gen_criterion(scores, reconstruction, image))
            latents.append(latent)
            decoded.append(reconstruction)

        generator_loss = sum(losses) / len(losses)
        triplet_loss = self.triplet_criterion(*decoded)

        weight = self.cfg.triplet_loss_term
        overall_generator_loss = (1 - weight) * generator_loss + weight * triplet_loss

        overall_generator_loss.backward()
        self.gen_optim.step()
        self.discriminator.enable_grads()

        detached = [latent.detach() for latent in latents]
        return generator_loss.item(), triplet_loss.item(), overall_generator_loss.item(), detached

    def _discriminator_step(self, latents, conditions) -> float:
        """Run one discriminator update on prior samples versus encoder outputs."""
        self.discr_optim.zero_grad()

        size = latents[0].shape[0]
        # One prior sample per row, shared by the three conditions.
        z = torch.randn(size, self.cfg.latent_dim).to(self.device)

        losses = []
        for latent, condition in zip(latents, conditions):
            target_scores = self.discriminator(self._with_condition(z, condition))
            encoder_scores = self.discriminator(self._with_condition(latent, condition))
            losses.append(self.discr_criterion(target_scores, encoder_scores))

        discriminator_loss = sum(losses) / len(losses)
        discriminator_loss.backward()
        self.discr_optim.step()

        return discriminator_loss.item()

    def _save_samples(self, epoch: int, noise: torch.Tensor, condition: torch.Tensor) -> None:
        """Decode the fixed noise/condition and write the image grid to ``img_dir``."""
        self.generator.eval()

        with torch.no_grad():
            generated_images = self.generator(self._with_condition(noise, condition)).cpu()
        grid = make_grid(
            generated_images, nrow=8, normalize=True, value_range=(-1, 1)).numpy().transpose(1, 2, 0)
        plt.imsave(os.path.join(self.cfg.img_dir, f"{epoch}.jpg"), grid)

    def fit(self):
        """Train for ``cfg.num_epochs`` epochs, logging to wandb and saving a sample grid per epoch."""
        run_name = f"conditional_tiling_bce_{self.cfg.random_seed}"
        print(f"Training starts on {self.cfg.device} 🚀")

        with wandb.init(project="topology_topxpy", group="triplet_embedding", name=run_name, job_type="training"):
            wandb.config.update({k: v for k, v in self.cfg.__dict__.items() if not k.startswith('__')})

            # Fixed inputs so that the per-epoch sample grids are comparable.
            gen_noise = torch.randn((self.cfg.generator_size, self.generator.latent_dim)).to(self.device)

            fixed_condition = torch.zeros(self.cfg.generator_size, self.cfg.condition_dim)
            fixed_condition[:32] = torch.tensor([1.0, 0.0])
            fixed_condition[32:] = torch.tensor([0.0, 1.0])
            fixed_condition = fixed_condition.to(self.device)

            for e in range(self.cfg.num_epochs):
                self.encoder.train()
                self.generator.train()
                self.discriminator.train()

                total_discriminator_loss = 0
                total_generator_loss = 0
                total_triplet_loss = 0
                total_overall_generator_loss = 0
                total_discriminator_counter = 0
                total_generator_counter = 0

                with tqdm(self.train_loader, desc=f"{e + 1}/{self.cfg.num_epochs} epochs") as t:
                    for i, batch in enumerate(t):
                        images, conditions = self._unpack_batch(batch)
                        size = images[0].shape[0]

                        # generator step
                        generator_loss, triplet_loss, overall_generator_loss, latents = \
                            self._generator_step(images, conditions)

                        total_generator_loss += generator_loss * size
                        total_generator_counter += size
                        total_triplet_loss += triplet_loss * size
                        total_overall_generator_loss += overall_generator_loss * size

                        # discriminator step
                        discriminator_loss = self._discriminator_step(latents, conditions)

                        total_discriminator_loss += discriminator_loss * size
                        total_discriminator_counter += size

                        t.set_postfix({
                            "discr_loss": total_discriminator_loss / total_discriminator_counter,
                            "gen_loss": total_generator_loss / total_generator_counter,
                        })

                        if not i % self.cfg.checkpoint_interval:
                            wandb.log({
                                "discriminator_loss": total_discriminator_loss / total_discriminator_counter,
                                "generator_loss": total_generator_loss / total_generator_counter,
                                "overall_generator_loss": total_overall_generator_loss / total_generator_counter,
                                "triplet_loss": total_triplet_loss / total_generator_counter
                            })

                self._save_samples(e + 1, gen_noise, fixed_condition)
                # Periodic checkpointing stays disabled, as before; call save() explicitly.

    def save(self):
        """Write model and optimiser state dicts to ``self.checkpoint_path``."""
        state_dict = {
            "generator": self.generator.state_dict(),
            "discriminator": self.discriminator.state_dict(),
            "generator_optim": self.gen_optim.state_dict(),
            "discriminator_optim": self.discr_optim.state_dict(),
            "encoder": self.encoder.state_dict()
            }
        torch.save(state_dict, self.checkpoint_path)

    def load(self, filename):
        """Restore model and optimiser state dicts from ``filename``."""
        state_dict = torch.load(filename, map_location=self.device)

        self.generator.load_state_dict(state_dict["generator"])
        self.discriminator.load_state_dict(state_dict["discriminator"])
        self.gen_optim.load_state_dict(state_dict["generator_optim"])
        self.discr_optim.load_state_dict(state_dict["discriminator_optim"])
        self.encoder.load_state_dict(state_dict["encoder"])