In [None]:
import sys
sys.path.append('/content/R3GAN/') # Replace with the actual path


In [9]:
from R3GAN.Networks import Generator,Discriminator
from torch import nn
import torch.nn.functional as F
import torch
import os
from torch.utils.data import Dataset,DataLoader
from pathlib import  Path
from training.loss import R3GANLoss

In [2]:
class ImageUpsampler(nn.Module):
    def __init__(self):
        super(ImageUpsampler, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        self.upsample1 = nn.ConvTranspose2d(
            128, 128, kernel_size=(1, 2), stride=(1, 2)
        )  # 64x64 -> 64x128
        self.upsample2 = nn.ConvTranspose2d(
            128, 64, kernel_size=(1, 2), stride=(1, 2)
        )  # 64x128 -> 64x256
        self.upsample3 = nn.ConvTranspose2d(
            64, 32, kernel_size=(1, 2), stride=(1, 2)
        )  # 64x256 -> 64x512
        self.upsample4 = nn.ConvTranspose2d(32, 1, kernel_size=1)  # 64x512 -> 64x576

        self.final_conv = nn.Linear(512, 576)  # Refinement

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))

        x = F.relu(self.upsample1(x))
        x = F.relu(self.upsample2(x))
        x = F.relu(self.upsample3(x))
        x = F.relu(self.upsample4(x))

        x = self.final_conv(x)  # Output single-channel
        return x

In [4]:
class GeneratorModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.c_dim=0
        self.z_dim=256
        self.generator = Generator(
            NoiseDimension=256,
            WidthPerStage=[128, 64, 32, 16, 1],
            CardinalityPerStage=[1, 1, 1, 1, 1],
            BlocksPerStage=[2, 2, 2, 2, 2],
            ExpansionFactor=2,
            ConditionDimension=None,
            ConditionEmbeddingDimension=0,
            KernelSize=3,
            ResamplingFilter=[1, 2, 1],
        )
        self.upsampler = ImageUpsampler()

    def forward(self, x,c=None):
        x = self.generator(x)
        x = self.upsampler(x)
        return x

In [5]:
class ImageDownsampler(nn.Module):
    def __init__(self):
        super(ImageDownsampler, self).__init__()

        self.conv1 = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)

        self.upsample1 = nn.Conv2d(
            128, 128, kernel_size=(1, 2), stride=(1, 2)
        )  # 64x64 -> 64x128
        self.upsample2 = nn.Conv2d(
            64, 128, kernel_size=(1, 2), stride=(1, 2)
        )  # 64x128 -> 64x256
        self.upsample3 = nn.Conv2d(
            32, 64, kernel_size=(1, 2), stride=(1, 2)
        )  # 64x256 -> 64x512
        self.upsample4 = nn.Conv2d(1, 32, kernel_size=1)  # 64x512 -> 64x576

        self.final_conv = nn.Linear(576, 512)  # Refinement

    def forward(self, x):
        x = self.final_conv(x)
        x = F.relu(self.upsample4(x))
        x = F.relu(self.upsample3(x))
        x = F.relu(self.upsample2(x))
        x = F.relu(self.upsample1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv1(x))
        return x

In [23]:
class DiscriminatorModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.discriminator = Discriminator(
            WidthPerStage=[1, 16, 32, 64, 128],
            CardinalityPerStage=[1, 1, 1, 1],
            BlocksPerStage=[2, 2, 2, 2],
            ExpansionFactor=2,
            ConditionDimension=None,
            ConditionEmbeddingDimension=0,
            KernelSize=3,
            ResamplingFilter=[1, 2, 1],
        )
        self.imagedownsampler=ImageDownsampler()
    
    def forward(self,x,c=None):
        x=self.imagedownsampler(x)
        x=self.discriminator(x)
        return x

In [7]:


class FixedBatchDataset(Dataset):
    def __init__(self, embedding_path,hidden_state_path, fixed_batch_size, transform=None, drop_last=True):
        """
        Args:
            root_dir (str): Directory containing .pt files.
            fixed_batch_size (int): The fixed number of examples per output sample.
            transform (callable, optional): Optional transform to be applied on a sample.
            drop_last (bool): If True, drop leftover samples that don’t form a full batch.
                               If False, pad them with zeros to create a full batch.
        """
        self.embedding_path =embedding_path
        self.hidden_state_path =hidden_state_path
        self.fixed_batch_size = fixed_batch_size
        self.transform = transform
        self.image_shape=[64,576]
        self.num_channels=1
        self.resolution=64
        self.has_labels=False
        self.has_onehot_labels=False
        self.label_shape=1
        self.label_dim=0
        self.name="something"
        
        # Find all .pt files in the root directory.
        self.embedding_files = [os.path.join(embedding_path, fname)
                      for fname in os.listdir(embedding_path) if fname.endswith('.pt')]

        # Build an index mapping from dataset index to (file_index, start_index)
        # Each file might yield multiple chunks.
        self.index_mapping = []
        for file_index, file_path in enumerate(self.embedding_files):
            # Load the tensor to know the number of samples.
            tensor = torch.load(file_path, map_location=torch.device('cpu'))
            num_samples = tensor.size(0)
            # Determine how many full fixed-size groups can be obtained.
            num_groups = num_samples // fixed_batch_size
            
            # For each full group, remember the starting index.
            for group in range(num_groups):
                start_index = group * fixed_batch_size
                self.index_mapping.append((file_index, start_index))
            
            # If drop_last is False and there are leftover samples
            # store the starting index of the remainder for later padding.
            if not drop_last and (num_samples % fixed_batch_size > 0):
                start_index = num_groups * fixed_batch_size
                self.index_mapping.append((file_index, start_index))

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, idx):
        file_index, start_index = self.index_mapping[idx]
        file_path = self.embedding_files[file_index]
        # Load the entire tensor stored in the file.
        embeddings = torch.load(file_path, map_location=torch.device('cpu'))
        file_name=Path(file_path).stem
        hidden_state_path=os.path.join(self.hidden_state_path,file_name+".pt")
        hidden_state = torch.load(hidden_state_path, map_location=torch.device('cpu'))
        # Slice out the chunk that we want.
        batch_embedding = embeddings[start_index : start_index + self.fixed_batch_size]
        batch_hidden_state = hidden_state[start_index : start_index + self.fixed_batch_size]
        # pad the batch along dimension 0 with zeros.
        if batch_embedding.size(0) < self.fixed_batch_size:
            pad_amount = self.fixed_batch_size - batch_embedding.size(0)
            pad_tensor = torch.zeros((pad_amount,) + batch_embedding.shape[1:], dtype=batch_embedding.dtype)
            batch_embedding = torch.cat([batch_embedding, pad_tensor], dim=0)
            batch_hidden_state = torch.cat([batch_hidden_state, pad_tensor], dim=0)
        return batch_embedding,batch_hidden_state
    

In [16]:
generator=GeneratorModel()
discriminator=DiscriminatorModel()

In [11]:
training_set_kwargs={"embedding_path":"D:/Xelpmoc/R3GAN/embeddings","hidden_state_path":"D:/Xelpmoc/R3GAN/hidden_states"}
dataset=FixedBatchDataset(**training_set_kwargs,fixed_batch_size=4)
dataloader=DataLoader(dataset=dataset,batch_size=1)

  tensor = torch.load(file_path, map_location=torch.device('cpu'))


In [17]:
for embeddings,hidden_states in dataloader:
    print(embeddings.squeeze(0).unsqueeze(1).shape)
    print(hidden_states.squeeze(0).unsqueeze(1).shape)
    out=generator(embeddings.squeeze(0).unsqueeze(1))
    out2=discriminator(out)
    print(out.shape)
    print(out2.shape)
    break

  embeddings = torch.load(file_path, map_location=torch.device('cpu'))
  hidden_state = torch.load(hidden_state_path, map_location=torch.device('cpu'))


torch.Size([4, 1, 256])
torch.Size([4, 1, 64, 576])
torch.Size([4, 1, 64, 576])
torch.Size([4])


In [10]:
g_opt=torch.optim.Adam(generator.parameters(),betas=[0,0])
d_opt=torch.optim.Adam(discriminator.parameters(),betas=[0,0])

In [11]:
loss_func=R3GANLoss(G=generator,D=discriminator)

In [24]:
import os
import time
import copy
import json
import pickle
import psutil
import PIL.Image
import numpy as np
import torch
import dnnlib
from torch_utils import misc
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import grid_sample_gradfix

import legacy
from metrics import metric_main

def cosine_decay_with_warmup(cur_nimg, base_value, total_nimg, final_value=0.0, warmup_value=0.0, warmup_nimg=0, hold_base_value_nimg=0):
    decay = 0.5 * (1 + np.cos(np.pi * (cur_nimg - warmup_nimg - hold_base_value_nimg) / float(total_nimg - warmup_nimg - hold_base_value_nimg)))
    cur_value = base_value + (1 - decay) * (final_value - base_value)
    if hold_base_value_nimg > 0:
        cur_value = np.where(cur_nimg > warmup_nimg + hold_base_value_nimg, cur_value, base_value)
    if warmup_nimg > 0:
        slope = (base_value - warmup_value) / warmup_nimg
        warmup_v = slope * cur_nimg + warmup_value
        cur_value = np.where(cur_nimg < warmup_nimg, warmup_v, cur_value)
    return float(np.where(cur_nimg > total_nimg, final_value, cur_value))

#----------------------------------------------------------------------------

def setup_snapshot_image_grid(training_set, random_seed=0):
    rnd = np.random.RandomState(random_seed)
    gw = np.clip(7680 // training_set.image_shape[2], 7, 32)
    gh = np.clip(4320 // training_set.image_shape[1], 4, 32)

    # No labels => show random subset of training samples.
    if not training_set.has_labels:
        all_indices = list(range(len(training_set)))
        rnd.shuffle(all_indices)
        grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)]

    else:
        # Group training samples by label.
        label_groups = dict() # label => [idx, ...]
        for idx in range(len(training_set)):
            label = tuple(training_set.get_details(idx).raw_label.flat[::-1])
            if label not in label_groups:
                label_groups[label] = []
            label_groups[label].append(idx)

        # Reorder.
        label_order = sorted(label_groups.keys())
        for label in label_order:
            rnd.shuffle(label_groups[label])

        # Organize into grid.
        grid_indices = []
        for y in range(gh):
            label = label_order[y % len(label_order)]
            indices = label_groups[label]
            grid_indices += [indices[x % len(indices)] for x in range(gw)]
            label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))]

    # Load data.
    images, labels = zip(*[training_set[i] for i in grid_indices])
    return (gw, gh), np.stack(images), np.stack(labels)

#----------------------------------------------------------------------------

def save_image_grid(img, fname, drange, grid_size):
    lo, hi = drange
    img = np.asarray(img, dtype=np.float32)
    img = (img - lo) * (255 / (hi - lo))
    img = np.rint(img).clip(0, 255).astype(np.uint8)

    gw, gh = grid_size
    _N, C, H, W = img.shape
    img = img.reshape([gh, gw, C, H, W])
    img = img.transpose(0, 3, 1, 4, 2)
    img = img.reshape([gh * H, gw * W, C])

    assert C in [1, 3]
    if C == 1:
        PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
    if C == 3:
        PIL.Image.fromarray(img, 'RGB').save(fname)

#----------------------------------------------------------------------------

def remap_optimizer_state_dict(state_dict, device):
    state_dict = copy.deepcopy(state_dict)
    for param in state_dict['state'].values():
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)
    return state_dict

#----------------------------------------------------------------------------

def training_loop(
    run_dir                 = '.',      # Output directory.
    training_set_kwargs     = {},       # Options for training set.
    data_loader_kwargs      = {},       # Options for torch.utils.data.DataLoader.      # Options for discriminator network.
    G_opt_kwargs            = {},       # Options for generator optimizer.
    D_opt_kwargs            = {},       # Options for discriminator optimizer.
    lr_scheduler            = None,
    beta2_scheduler         = None,
    augment_kwargs          = None,     # Options for augmentation pipeline. None = disable.
    loss_kwargs             = {},       # Options for loss function.
    gamma_scheduler         = None,
    metrics                 = [],       # Metrics to evaluate during training.
    random_seed             = 0,        # Global random seed.
    num_gpus                = 1,        # Number of GPUs participating in the training.
    rank                    = 0,        # Rank of the current process in [0, num_gpus[.
    batch_size              = 4,        # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    g_batch_gpu             = 4,        # Number of samples processed at a time by one GPU.
    d_batch_gpu             = 4,        # Number of samples processed at a time by one GPU.
    ema_scheduler           = None,
    aug_scheduler           = None,
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    kimg_per_tick           = 4,        # Progress snapshot interval.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = disable.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable.
    resume_pkl              = None,     # Network pickle to resume training from.
    cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
    abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
):
    # Initialize.
    start_time = time.time()
    # device = torch.device('cuda', rank)
    device=torch.device("cpu")
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = False       # Improves numerical accuracy.
    torch.backends.cudnn.allow_tf32 = False             # Improves numerical accuracy.
    conv2d_gradfix.enabled = True                       # Improves training speed.
    grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set =FixedBatchDataset(embedding_path=training_set_kwargs['embedding_path'],hidden_state_path=training_set_kwargs['hidden_state_path'],fixed_batch_size=batch_size) # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution)
    G = GeneratorModel().train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    D = DiscriminatorModel().train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()

    # Resume from existing pickle.
    if resume_pkl is not None:
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        if rank == 0:
            print(f'Resuming from "{resume_pkl}"')
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
                misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([min(g_batch_gpu, d_batch_gpu), G.z_dim], device=device)
        c = torch.empty([min(g_batch_gpu, d_batch_gpu), G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Setup augmentation.
    if rank == 0:
        print('Setting up augmentation...')
    augment_pipe = None

    if (augment_kwargs is not None) and (aug_scheduler is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
        
    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    for module in [G, D, G_ema]:
        if module is not None and num_gpus > 1:
            for param in misc.params_and_buffers(module):
                torch.distributed.broadcast(param, src=0)

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(G=G, D=D, augment_pipe=augment_pipe, **loss_kwargs) # subclass of training.loss.Loss
    phases = []
    
    opt = dnnlib.util.construct_class_by_name(params=D.parameters(), **D_opt_kwargs)
    if resume_pkl is not None:
        opt.load_state_dict(remap_optimizer_state_dict(resume_data['D_opt_state'], device))
    phases += [dnnlib.EasyDict(name='D', module=D, opt=opt, batch_gpu=d_batch_gpu)]
    
    opt = dnnlib.util.construct_class_by_name(params=G.parameters(), **G_opt_kwargs)
    if resume_pkl is not None:
        opt.load_state_dict(remap_optimizer_state_dict(resume_data['G_opt_state'], device))
    phases += [dnnlib.EasyDict(name='G', module=G, opt=opt, batch_gpu=g_batch_gpu)]
    
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
        save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
        grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(g_batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(g_batch_gpu)
        images = torch.cat([G_ema(z, c).cpu() for z, c in zip(grid_z, grid_c)]).to(torch.float).numpy()
        save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = resume_data['cur_nimg'] if resume_pkl is not None else 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)
        
    # Dummy Timing, required to fix phase shift
    for phase in phases:
        if phase.start_event is not None:
            phase.start_event.record(torch.cuda.current_stream(device))
        if phase.end_event is not None:
            phase.end_event.record(torch.cuda.current_stream(device))
        
    while True:
        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            D_z, D_img = next(training_set_iterator)
            D_z, D_img= D_z.squeeze(0).unsqueeze(1), D_img.squeeze(0).unsqueeze(1)
            # D_z = torch.randn([batch_size, G.z_dim], device=device)
            D_img_c=torch.zeros(size=[batch_size])
            
            G_z,G_img = next(training_set_iterator)
            G_z,G_img=G_z.squeeze(0).unsqueeze(1),G_img.squeeze(0).unsqueeze(1)
            # G_z = torch.randn([batch_size, G.z_dim], device=device)
            G_img_c=torch.zeros(size=[batch_size])
            
            all_real_img = []
            all_real_c = []
            all_gen_z = []
            
            # D
            all_real_img += [(D_img.detach().clone().to(device).to(torch.float32)).split(d_batch_gpu)]
            all_real_c += [D_img_c.detach().clone().to(device).split(d_batch_gpu)]
            all_gen_z += [D_z.detach().clone().split(d_batch_gpu)]
            
            # G
            all_real_img += [(G_img.detach().clone().to(device).to(torch.float32)).split(g_batch_gpu)]
            all_real_c += [G_img_c.detach().clone().to(device).split(g_batch_gpu)]
            all_gen_z += [G_z.detach().clone().split(g_batch_gpu)]
            
        cur_lr = cosine_decay_with_warmup(cur_nimg, **lr_scheduler)
        cur_beta2 = cosine_decay_with_warmup(cur_nimg, **beta2_scheduler)
        cur_gamma = cosine_decay_with_warmup(cur_nimg, **gamma_scheduler)
        cur_ema_nimg = cosine_decay_with_warmup(cur_nimg, **ema_scheduler)
        cur_aug_p = cosine_decay_with_warmup(cur_nimg, **aug_scheduler)
        
        if augment_pipe is not None:
            augment_pipe.p.copy_(misc.constant(cur_aug_p, device=device))
        
        # Execute training phases.
        for phase, phase_gen_z, phase_real_img, phase_real_c in zip(phases, all_gen_z, all_real_img, all_real_c):
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))

            # Accumulate gradients.
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)
            for real_img, real_c, gen_z in zip(phase_real_img, phase_real_c, phase_gen_z):
                loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gamma=cur_gamma, gain=num_gpus * phase.batch_gpu / batch_size)
            phase.module.requires_grad_(False)
        
            # Update weights.  
            for g in phase.opt.param_groups:
                g['lr'] = cur_lr
                g['betas'] = (0, cur_beta2)
                      
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                params = [param for param in phase.module.parameters() if param.grad is not None]
                if len(params) > 0:
                    flat = torch.cat([param.grad.flatten() for param in params])
                    if num_gpus > 1:
                        torch.distributed.all_reduce(flat)
                        flat /= num_gpus
                    grads = flat.split([param.numel() for param in params])
                    for param, grad in zip(params, grads):
                        param.grad = grad.reshape(param.shape)
                phase.opt.step()

            # Phase done.
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_beta = 0.5 ** (batch_size / max(cur_ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in training_stats.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
        training_stats.report0('Progress/lr', cur_lr)
        training_stats.report0('Progress/ema_mimg', cur_ema_nimg / 1e6)
        training_stats.report0('Progress/beta2', cur_beta2)
        training_stats.report0('Progress/gamma', cur_gamma)
        training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([G_ema(z, c).cpu() for z, c in zip(grid_z, grid_c)]).to(torch.float).numpy()
            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:09d}.png'), drange=[-1,1], grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(G=G, D=D, G_ema=G_ema, training_set_kwargs=dict(training_set_kwargs), cur_nimg=cur_nimg)
            for phase in phases:
                snapshot_data[phase.name + '_opt_state'] = remap_optimizer_state_dict(phase.opt.state_dict(), 'cpu')
            for key, value in snapshot_data.items():
                if isinstance(value, torch.nn.Module):
                    value = copy.deepcopy(value).eval().requires_grad_(False)
                    if num_gpus > 1:
                        misc.check_ddp_consistency(value, ignore_regex=r'.*\.[^.]+_(avg|ema)')
                        for param in misc.params_and_buffers(value):
                            torch.distributed.broadcast(param, src=0)
                    snapshot_data[key] = value.cpu()
                del value # conserve memory
            snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:09d}.pkl')
            if rank == 0:
                with open(snapshot_pkl, 'wb') as f:
                    pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        del snapshot_data # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')

#----------------------------------------------------------------------------


In [25]:
ema_nimg = 5000 * 1000
decay_nimg = 2e7
training_loop(run_dir="./",
              training_set_kwargs={"embedding_path":"D:\\Xelpmoc\\R3GAN\\embeddings","hidden_state_path":"D:\\Xelpmoc\\R3GAN\\hidden_states"},
              data_loader_kwargs=dnnlib.EasyDict(pin_memory=True, prefetch_factor=2,num_workers=2),
              G_opt_kwargs=dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0], eps=1e-8),
              D_opt_kwargs=dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0], eps=1e-8),
              lr_scheduler={ 'base_value': 2e-4, 'final_value': 5e-5, 'total_nimg': decay_nimg },
              beta2_scheduler={ 'base_value': 0.9, 'final_value': 0.99, 'total_nimg': decay_nimg },
              loss_kwargs=dnnlib.EasyDict(class_name='training.loss.R3GANLoss'),
              gamma_scheduler={ 'base_value': 0.05, 'final_value': 0.005, 'total_nimg': decay_nimg },
              metrics=['fid50k_full'],
              ema_scheduler={ 'base_value': 0, 'final_value': ema_nimg, 'total_nimg': decay_nimg })

Loading training set...

Num images:  33
Image shape: [64, 576]
Label shape: 1

Constructing networks...


  tensor = torch.load(file_path, map_location=torch.device('cpu'))



GeneratorModel                   Parameters  Buffers  Output shape       Datatype
---                              ---         ---      ---                ---     
generator.MainLayers.0.Layers.0  34816       -        [4, 128, 4, 4]     float32 
generator.MainLayers.0.Layers.1  655872      -        [4, 128, 4, 4]     float32 
generator.MainLayers.0.Layers.2  655872      -        [4, 128, 4, 4]     float32 
generator.MainLayers.1.Layers.0  8192        16       [4, 64, 8, 8]      float32 
generator.MainLayers.1.Layers.1  164096      -        [4, 64, 8, 8]      float32 
generator.MainLayers.1.Layers.2  164096      -        [4, 64, 8, 8]      float32 
generator.MainLayers.2.Layers.0  2048        16       [4, 32, 16, 16]    float32 
generator.MainLayers.2.Layers.1  41088       -        [4, 32, 16, 16]    float32 
generator.MainLayers.2.Layers.2  41088       -        [4, 32, 16, 16]    float32 
generator.MainLayers.3.Layers.0  512         16       [4, 16, 32, 32]    float32 
generator.MainL

RuntimeError: Tried to instantiate dummy base class Event