In [19]:
from R3GAN.Networks import Generator,Discriminator
from torch import nn
import torch.nn.functional as F
import torch
import os
from torch.utils.data import Dataset,DataLoader
from pathlib import  Path
from training.loss import R3GANLoss
from dnnlib import EasyDict
import numpy as np  
from time import time
from torch_utils import training_stats,misc
import dnnlib
import torchmetrics
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import grid_sample_gradfix
import psutil

In [2]:
class ImageUpsampler(nn.Module):
    def __init__(self):
        super(ImageUpsampler, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        self.upsample1 = nn.ConvTranspose2d(
            128, 128, kernel_size=(1, 2), stride=(1, 2)
        )  # 64x64 -> 64x128
        self.upsample2 = nn.ConvTranspose2d(
            128, 64, kernel_size=(1, 2), stride=(1, 2)
        )  # 64x128 -> 64x256
        self.upsample3 = nn.ConvTranspose2d(
            64, 32, kernel_size=(1, 2), stride=(1, 2)
        )  # 64x256 -> 64x512
        self.upsample4 = nn.ConvTranspose2d(32, 1, kernel_size=1)  # 64x512 -> 64x576

        self.final_conv = nn.Linear(512, 576)  # Refinement

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))

        x = F.relu(self.upsample1(x))
        x = F.relu(self.upsample2(x))
        x = F.relu(self.upsample3(x))
        x = F.relu(self.upsample4(x))

        x = self.final_conv(x)  # Output single-channel
        return x

In [3]:
class GeneratorModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.c_dim=0
        self.z_dim=256
        self.generator = Generator(
            NoiseDimension=256,
            WidthPerStage=[128,64,32,16,1],
            CardinalityPerStage=[1,1,1,1,1],
            BlocksPerStage=[2 * x for x in [1, 1, 1, 1, 1]],
            ExpansionFactor=2,
            ConditionDimension=None,
            ConditionEmbeddingDimension=0,
            KernelSize=3,
            ResamplingFilter=[1, 2, 1],
        )
        self.upsampler = ImageUpsampler()

    def forward(self, x,c=None):
        x = self.generator(x)
        x = self.upsampler(x)
        return x

In [4]:
class ImageDownsampler(nn.Module):
    def __init__(self):
        super(ImageDownsampler, self).__init__()

        self.conv1 = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)

        self.upsample1 = nn.Conv2d(
            128, 128, kernel_size=(1, 2), stride=(1, 2)
        )  # 64x64 -> 64x128
        self.upsample2 = nn.Conv2d(
            64, 128, kernel_size=(1, 2), stride=(1, 2)
        )  # 64x128 -> 64x256
        self.upsample3 = nn.Conv2d(
            32, 64, kernel_size=(1, 2), stride=(1, 2)
        )  # 64x256 -> 64x512
        self.upsample4 = nn.Conv2d(1, 32, kernel_size=1)  # 64x512 -> 64x576

        self.final_conv = nn.Linear(576, 512)  # Refinement

    def forward(self, x):
        x = self.final_conv(x)
        x = F.relu(self.upsample4(x))
        x = F.relu(self.upsample3(x))
        x = F.relu(self.upsample2(x))
        x = F.relu(self.upsample1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv1(x))
        return x

In [5]:
class DiscriminatorModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.discriminator = Discriminator(
            WidthPerStage=[1, 16, 32, 64, 128],
            CardinalityPerStage=[1, 1, 1, 1],
            BlocksPerStage=[2, 2, 2, 2],
            ExpansionFactor=2,
            ConditionDimension=None,
            ConditionEmbeddingDimension=0,
            KernelSize=3,
            ResamplingFilter=[1, 2, 1],
        )
        self.imagedownsampler=ImageDownsampler()
    
    def forward(self,x,c=None):
        x=self.imagedownsampler(x)
        x=self.discriminator(x)
        return x

In [6]:
class FixedBatchDataset(Dataset):
    def __init__(self, embedding_path,hidden_state_path, fixed_batch_size, transform=None, drop_last=True):
        """
        Args:
            root_dir (str): Directory containing .pt files.
            fixed_batch_size (int): The fixed number of examples per output sample.
            transform (callable, optional): Optional transform to be applied on a sample.
            drop_last (bool): If True, drop leftover samples that don’t form a full batch.
                               If False, pad them with zeros to create a full batch.
        """
        self.embedding_path =embedding_path
        self.hidden_state_path =hidden_state_path
        self.fixed_batch_size = fixed_batch_size
        self.transform = transform
        self.image_shape=[1,64,576]
        self.num_channels=1
        self.resolution=64
        self.has_labels=False
        self.has_onehot_labels=False
        self.label_shape=1
        self.label_dim=0
        self.name="something"
        
        # Find all .pt files in the root directory.
        self.embedding_files = [os.path.join(embedding_path, fname)
                      for fname in os.listdir(embedding_path) if fname.endswith('.pt')]

        # Build an index mapping from dataset index to (file_index, start_index)
        # Each file might yield multiple chunks.
        self.index_mapping = []
        for file_index, file_path in enumerate(self.embedding_files):
            # Load the tensor to know the number of samples.
            tensor = torch.load(file_path, map_location=torch.device('cpu'))
            num_samples = tensor.size(0)
            # Determine how many full fixed-size groups can be obtained.
            num_groups = num_samples // fixed_batch_size
            
            # For each full group, remember the starting index.
            for group in range(num_groups):
                start_index = group * fixed_batch_size
                self.index_mapping.append((file_index, start_index))
            
            # If drop_last is False and there are leftover samples
            # store the starting index of the remainder for later padding.
            if not drop_last and (num_samples % fixed_batch_size > 0):
                start_index = num_groups * fixed_batch_size
                self.index_mapping.append((file_index, start_index))

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, idx):
        file_index, start_index = self.index_mapping[idx]
        file_path = self.embedding_files[file_index]
        # Load the entire tensor stored in the file.
        embeddings = torch.load(file_path, map_location=torch.device('cpu'))
        file_name=Path(file_path).stem
        hidden_state_path=os.path.join(self.hidden_state_path,file_name+".pt")
        hidden_state = torch.load(hidden_state_path, map_location=torch.device('cpu'))
        # Slice out the chunk that we want.
        batch_embedding = embeddings[start_index : start_index + self.fixed_batch_size]
        batch_hidden_state = hidden_state[start_index : start_index + self.fixed_batch_size]
        # pad the batch along dimension 0 with zeros.
        if batch_embedding.size(0) < self.fixed_batch_size:
            pad_amount = self.fixed_batch_size - batch_embedding.size(0)
            pad_tensor = torch.zeros((pad_amount,) + batch_embedding.shape[1:], dtype=batch_embedding.dtype)
            batch_embedding = torch.cat([batch_embedding, pad_tensor], dim=0)
            batch_hidden_state = torch.cat([batch_hidden_state, pad_tensor], dim=0)
        return batch_embedding.to(torch.float32),batch_hidden_state.to(torch.float32)
    

In [7]:
generator=GeneratorModel().to('cuda')
discriminator=DiscriminatorModel().to('cuda')

In [8]:
opt_g=torch.optim.Adam(generator.parameters(),betas=[0.0,0.0])
opt_d=torch.optim.Adam(discriminator.parameters(),betas=[0.0,0.0])

In [9]:
loss_func=R3GANLoss(G=generator,D=discriminator)

In [20]:
training_set_kwargs={"embedding_path":"/mnt/d/work/R3GAN/embeddings","hidden_state_path":"/mnt/d/work/R3GAN/hidden_states","fixed_batch_size": 4}
train_dataset=FixedBatchDataset(**training_set_kwargs)
training_set_sampler = misc.InfiniteSampler(dataset=train_dataset, rank=0, num_replicas=1, seed=0)
train_dataloader=iter(DataLoader(dataset=train_dataset,batch_size=1,sampler=training_set_sampler))
test_dataset=FixedBatchDataset(**training_set_kwargs)
test_dataloader=DataLoader(dataset=test_dataset,batch_size=1)

In [11]:
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [12]:
def cosine_decay_with_warmup(cur_nimg, base_value, total_nimg, final_value=0.0, warmup_value=0.0, warmup_nimg=0, hold_base_value_nimg=0):
    decay = 0.5 * (1 + np.cos(np.pi * (cur_nimg - warmup_nimg - hold_base_value_nimg) / float(total_nimg - warmup_nimg - hold_base_value_nimg)))
    cur_value = base_value + (1 - decay) * (final_value - base_value)
    if hold_base_value_nimg > 0:
        cur_value = np.where(cur_nimg > warmup_nimg + hold_base_value_nimg, cur_value, base_value)
    if warmup_nimg > 0:
        slope = (base_value - warmup_value) / warmup_nimg
        warmup_v = slope * cur_nimg + warmup_value
        cur_value = np.where(cur_nimg < warmup_nimg, warmup_v, cur_value)
    return float(np.where(cur_nimg > total_nimg, final_value, cur_value))


In [13]:
def calculate_cosine_similarity(real_images, generated_images):
    """
    real_images: Tensor of shape (batch_size, 1, 64, 576)
    generated_images: Tensor of shape (batch_size, 1, 64, 576)
    
    Returns:
    Mean cosine similarity score for the batch
    """
    # Flatten the images to (batch_size, features)
    real_flattened = real_images.view(real_images.size(0), -1)  # Shape: (batch_size, 64*576)
    generated_flattened = generated_images.view(generated_images.size(0), -1)  # Shape: (batch_size, 64*576)

    # Compute cosine similarity
    cosine_similarity = torchmetrics.functional.pairwise_cosine_similarity(real_flattened, generated_flattened)

    # Return mean cosine similarity
    return cosine_similarity.mean()

In [18]:
metric=torchmetrics.CosineSimilarity(reduction="mean")
for i in range(10):
    a=torch.rand(size=[2,1,64,576]).flatten(start_dim=0,end_dim=2)
    b=torch.rand(size=[2,1,64,576]).flatten(start_dim=0,end_dim=2)
    # acc=metric(a,b)
    metric.update(a,b)
    # print(acc)
print(metric.compute())
metric.reset()

tensor(0.7499)


In [15]:
cur_nimg=0
ema_nimg = 5000 * 1000
decay_nimg = 2e7
lr_scheduler={ 'base_value': 2e-4, 'final_value': 5e-5, 'total_nimg': decay_nimg }
beta2_scheduler={ 'base_value': 0.9, 'final_value': 0.99, 'total_nimg': decay_nimg }
gamma_scheduler={ 'base_value': 0.05, 'final_value': 0.005, 'total_nimg': decay_nimg }
ema_scheduler={ 'base_value': 0, 'final_value': ema_nimg, 'total_nimg': decay_nimg }

cur_lr = cosine_decay_with_warmup(cur_nimg,**lr_scheduler)
cur_beta2 = cosine_decay_with_warmup(cur_nimg, **beta2_scheduler)
cur_gamma = cosine_decay_with_warmup(cur_nimg, **gamma_scheduler)
cur_ema_nimg = cosine_decay_with_warmup(cur_nimg, **ema_scheduler)

In [16]:
phases=[]
phases+=[EasyDict(name='D', module=discriminator, opt=opt_d, batch_gpu=4)]
phases+=[EasyDict(name='G', module=generator, opt=opt_g, batch_gpu=4)]
for phase in phases:
    phase.start_event = torch.cuda.Event(enable_timing=True)
    phase.end_event = torch.cuda.Event(enable_timing=True)

In [21]:
num_gpus=1
run_dir="./"
total_kimg=25000
kimg_per_tick=4
batch_size=4
rank=0
random_seed=0
cudnn_benchmark=True
start_time=time()
tick_start_time = time()
similarity_metric=torchmetrics.CosineSimilarity(reduction="mean")
device = torch.device('cuda', rank)
# device=torch.device("cpu")
np.random.seed(random_seed * num_gpus + rank)
torch.manual_seed(random_seed * num_gpus + rank)
torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
torch.backends.cuda.matmul.allow_tf32 = False       # Improves numerical accuracy.
torch.backends.cudnn.allow_tf32 = False             # Improves numerical accuracy.
conv2d_gradfix.enabled = True                       # Improves training speed.
grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.
stats_collector = training_stats.Collector(regex='.*')
stats_metrics = dict()
stats_jsonl = None
stats_tfevents = None
if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

# Train.
if rank == 0:
    print(f'Training for {total_kimg} kimg...')
    print()
cur_nimg = 0
cur_tick = 0
tick_start_nimg = cur_nimg
tick_start_time = time()
maintenance_time = tick_start_time - start_time
batch_idx = 0
for phase in phases:
    if phase.start_event is not None:
        phase.start_event.record(torch.cuda.current_stream(device))
    if phase.end_event is not None:
        phase.end_event.record(torch.cuda.current_stream(device))
for _ in range(2):
    
    with torch.autograd.profiler.record_function("data_fetch"):
        D_z, D_img = next(train_dataloader)
        D_z, D_img= D_z.squeeze(0).unsqueeze(1), D_img.squeeze(0).unsqueeze(1)
        G_z,G_img = next(train_dataloader)
        G_z,G_img=G_z.squeeze(0).unsqueeze(1),G_img.squeeze(0).unsqueeze(1)
        all_real_img = []
        all_gen_z = []
        all_real_img += [(D_img.detach().clone().to(device).to(torch.float32)).split(batch_size)]
        all_gen_z += [D_z.detach().clone().to(device).split(batch_size)]
        all_real_img += [(G_img.detach().clone().to(device).to(torch.float32)).split(batch_size)]
        all_gen_z += [G_z.detach().clone().to(device).split(batch_size)]
    cur_lr = cosine_decay_with_warmup(cur_nimg,**lr_scheduler)
    cur_beta2 = cosine_decay_with_warmup(cur_nimg, **beta2_scheduler)
    cur_gamma = cosine_decay_with_warmup(cur_nimg, **gamma_scheduler)
    cur_ema_nimg = cosine_decay_with_warmup(cur_nimg, **ema_scheduler)

    for phase, phase_gen_z, phase_real_img in zip(phases, all_gen_z, all_real_img):
        if phase.start_event is not None:
            phase.start_event.record(torch.cuda.current_stream(device))

        # Accumulate gradients.
        phase.opt.zero_grad(set_to_none=True)
        phase.module.requires_grad_(True)
        for real_img, gen_z in zip(phase_real_img, phase_gen_z):
            loss_func.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=None, gen_z=gen_z, gamma=cur_gamma, gain=1)
        phase.module.requires_grad_(False)

        # Update weights.  
        for g in phase.opt.param_groups:
            g['lr'] = cur_lr
            g['betas'] = (0, cur_beta2)
                    
        with torch.autograd.profiler.record_function(phase.name + '_opt'):
            params = [param for param in phase.module.parameters() if param.grad is not None]
            if len(params) > 0:
                flat = torch.cat([param.grad.flatten() for param in params])
                if num_gpus > 1:
                    torch.distributed.all_reduce(flat)
                    flat /= num_gpus
                grads = flat.split([param.numel() for param in params])
                for param, grad in zip(params, grads):
                    param.grad = grad.reshape(param.shape)
            phase.opt.step()

        # Phase done.
        if phase.end_event is not None:
            phase.end_event.record(torch.cuda.current_stream(device))
    done = (cur_nimg >= total_kimg * 1000)
    if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
        continue
    cur_nimg += batch_size
    tick_end_time = time()
    fields = []
    fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
    fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
    fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
    fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
    fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
    fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
    fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
    fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
    fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
    torch.cuda.reset_peak_memory_stats()
    # fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
    training_stats.report0('Progress/lr', cur_lr)
    training_stats.report0('Progress/ema_mimg', cur_ema_nimg / 1e6)
    training_stats.report0('Progress/beta2', cur_beta2)
    training_stats.report0('Progress/gamma', cur_gamma)
    training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
    training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
    with torch.no_grad():
        for embedding,hidden_state in test_dataloader:
            generated_hidden_state=generator(embedding.squeeze(0).unsqueeze(1).to(device))
            # simi=similarity_metric(generated_hidden_state.flatten(start_dim=0,end_dim=2),hidden_state.squeeze(0).unsqueeze(1).to(device).flatten(start_dim=0,end_dim=2))
            similarity_metric.update(generated_hidden_state.flatten(start_dim=0,end_dim=2),hidden_state.squeeze(0).unsqueeze(1).to(device).flatten(start_dim=0,end_dim=2))
    # fields += [f"similarity {training_stats.report0('Progress/tick', similarity_metric.compute().item()):<5d}"]
    acc=similarity_metric.compute()
    print(acc)
    similarity_metric.reset()
    if rank == 0:
        print(' '.join(fields))

Training for 25000 kimg...

tensor(-0.0207, device='cuda:0')
tick 0     kimg 0.0      time 2s           sec/tick 0.7     sec/kimg 175.92  maintenance 1.0    cpumem 1.32   gpumem 0.40   reserved 0.49  
tensor(-0.0206, device='cuda:0')
tick 0     kimg 0.0      time 5s           sec/tick 4.1     sec/kimg 512.18  maintenance 1.0    cpumem 1.46   gpumem 0.40   reserved 0.49  
