From d63ff27583ff1efa7129a450213559fa65e8a48e Mon Sep 17 00:00:00 2001
From: Harish Rajagopal
Date: Sun, 27 Jun 2021 21:43:12 +0200
Subject: [PATCH] Initial contrastive temporal loss

---
 config.py               |  2 +-
 models/discriminator.py | 21 ++++----------
 models/loss.py          | 64 ++++++++++++++++++++++++-----------------
 train.py                | 56 +++----------------------------------
 4 files changed, 47 insertions(+), 96 deletions(-)

diff --git a/config.py b/config.py
index b92ec7d..4af7b64 100644
--- a/config.py
+++ b/config.py
@@ -98,7 +98,7 @@ class Config:
     gan_wt: float = 1.0
     temp_nn_wt: float = 0.05
     disc_steps: int = 1
-    temp_disc_steps: int = 2
+    temp_nn_temperature: float = 1.0
     seed: int = 0
diff --git a/models/discriminator.py b/models/discriminator.py
index 1870d8e..e48b16c 100644
--- a/models/discriminator.py
+++ b/models/discriminator.py
@@ -70,23 +70,12 @@ def __init__(self, config: Config) -> None:
         model = resnet18(pretrained=True)
         layers = list(model.children())
+        self.net = nn.Sequential(*layers[:-1], nn.Flatten())
-        pretrained_weights = layers[0].weight
-        layers[0] = nn.Conv2d(
-            8,
-            64,
-            kernel_size=(7, 7),
-            stride=(2, 2),
-            padding=(3, 3),
-            bias=False,
-        )
-        layers[0].weight.data[:, :3, :, :] = nn.Parameter(pretrained_weights)
-        layers[0].weight.data[:, 4:7, :, :] = nn.Parameter(pretrained_weights)
-
-        classify = nn.Linear(512, 1)
-        net = nn.Sequential(*layers[:-1], nn.Flatten(), classify)
-        self.net = spectralize(net)
+        # Freeze the pre-trained model
+        for param in self.net.parameters():
+            param.requires_grad = False

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         with autocast(enabled=self.config.mixed_precision):
-            return self.ckpt_run(self.net, 2, inputs).flatten()
+            return self.ckpt_run(self.net, 2, inputs)
diff --git a/models/loss.py b/models/loss.py
index 8b26108..6cab1f8 100644
--- a/models/loss.py
+++ b/models/loss.py
@@ -1,4 +1,3 @@
-import random
 from typing import Optional

 import torch
@@ -141,37 +140,48 @@ class TemporalNNLoss(nn.Module):
     def __init__(self, config: Config, reduction: str = "none"):
         super().__init__()
         self.config = config
-        self.loss_fn = nn.BCEWithLogitsLoss(reduction=reduction)
+        self._log_softmax_fn = nn.LogSoftmax(dim=1)

     def forward(
         self, renders: torch.Tensor, discriminator: nn.Module
     ) -> torch.Tensor:
-        renders = renders[:, : self.config.fmo_train_steps, :4]
-
-        loss = 0.0
-
-        for frame_num in range(self.config.fmo_train_steps - 1):
-            # Offset from the current index
-            offset = random.choice(range(2, self.config.fmo_train_steps))
-            choice = (frame_num + offset) % self.config.fmo_train_steps
-
-            correct = torch.cat(
-                (renders[:, frame_num], renders[:, frame_num + 1]), 1
-            )
-            incorrect = torch.cat(
-                (renders[:, frame_num], renders[:, choice]), 1
-            )
-
-            correct_out = discriminator(correct)
-            incorrect_out = discriminator(incorrect)
-            loss += self.loss_fn(
-                incorrect_out, torch.zeros_like(incorrect_out)
-            )
-            loss += self.loss_fn(correct_out, torch.ones_like(correct_out))
-
-        loss /= self.config.fmo_train_steps - 1
+        renders = renders[:, : self.config.fmo_train_steps]  # BxTxCxHxW
+
+        latents_list = [
+            discriminator(renders[:, frame_num])
+            for frame_num in range(renders.shape[1])
+        ]
+        latents = torch.stack(latents_list, 1)  # BxTxD
+
+        left = latents.unsqueeze(-1)  # BxTxDx1
+        right = left.permute(0, 3, 2, 1)  # Bx1xDxT
+        similarity = nn.functional.cosine_similarity(
+            left, right, dim=-2, eps=torch.finfo(left.dtype).eps
+        )  # BxTxT
+
+        # Mask out the self values
+        mask = torch.eye(
+            self.config.fmo_train_steps, device=similarity.device
+        ).bool()
+        mask_nd = mask.unsqueeze(0).tile(similarity.shape[0], 1, 1)
+        neg_inf = float("-inf") * torch.ones_like(similarity)
+        similarity = torch.where(mask_nd, neg_inf, similarity)
+
+        log_softmax = self._log_softmax_fn(
+            similarity / self.config.temp_nn_temperature
+        )

-        return loss
+        # All positive pairs are (i, i+1)
+        # - x - - - -
+        # - - x - - -
+        # - - - x - -
+        # - - - - x -
+        # - - - - - x
+        # - - - - - -
+        positive_pairs = torch.diagonal(
+            log_softmax, offset=1, dim1=-2, dim2=-1
+        )  # Bx(T-1)
+        return -positive_pairs.mean(dim=-1)  # B dimensional vector


 def oflow_loss(renders: torch.Tensor) -> torch.Tensor:
diff --git a/train.py b/train.py
index f02c88d..424358a 100755
--- a/train.py
+++ b/train.py
@@ -53,7 +53,6 @@ class Trainer:
     ENC_PREFIX: Final = "encoder"
     RENDER_PREFIX: Final = "rendering"
     DISC_PREFIX: Final = "discriminator"
-    TEMP_DISC_PREFIX: Final = "temp_disc"
     BEST_SUFFIX: Final = "_best"

     # Used when saving training state
@@ -61,7 +60,6 @@ class Trainer:
     GLOBAL_STEP_KEY: Final = "global_step"
     MODEL_OPT_KEY: Final = "model_optim"
     DISC_OPT_KEY: Final = "disc_optim"
-    TEMP_DISC_OPT_KEY: Final = "temp_disc_optim"

     def __init__(
         self,
@@ -230,15 +228,6 @@ def _init_optimizers(self, load_folder: Optional[Path]) -> None:
                 step_size=self.config.sched_step_size,
                 gamma=0.5,
             )
-        if self.config.use_nn_timeconsistency:
-            self.temp_disc_optim = torch.optim.Adam(
-                self.temp_disc.parameters(), lr=self.config.temp_disc_lr
-            )
-            self.temp_disc_sched = torch.optim.lr_scheduler.StepLR(
-                self.temp_disc_optim,
-                step_size=self.config.sched_step_size,
-                gamma=0.5,
-            )

         self.scaler = torch.cuda.amp.GradScaler(
             enabled=self.config.mixed_precision
@@ -264,11 +253,6 @@ def load_weights(self, load_folder: Path) -> None:
                 torch.load(load_folder / f"{self.DISC_PREFIX}.pt")
             )

-        if self.config.use_nn_timeconsistency:
-            self.temp_disc.load_state_dict(
-                torch.load(load_folder / f"{self.TEMP_DISC_PREFIX}.pt")
-            )
-
     def load_state(self, load_folder: Path) -> None:
         """Load training state from a previous checkpoint."""
         load_folder = load_folder.expanduser()
@@ -278,8 +262,6 @@ def load_state(self, load_folder: Path) -> None:
         self.model_optim.load_state_dict(state[self.MODEL_OPT_KEY])
         if self.config.use_gan_loss:
             self.disc_optim.load_state_dict(state[self.DISC_OPT_KEY])
-        if self.config.use_nn_timeconsistency:
-            self.temp_disc_optim.load_state_dict(state[self.TEMP_DISC_OPT_KEY])

     def save_weights(self, save_best: bool = False) -> None:
         """Save weights to disk."""
@@ -297,11 +279,6 @@ def save_weights(self, save_best: bool = False) -> None:
                 self.discriminator.module.state_dict(),
                 self.save_folder / f"{self.DISC_PREFIX}{suffix}",
             )
-        if self.config.use_nn_timeconsistency:
-            torch.save(
-                self.temp_disc.module.state_dict(),
-                self.save_folder / f"{self.TEMP_DISC_PREFIX}{suffix}",
-            )

     def save_state(self, global_step: int) -> None:
         """Save training state to disk."""
@@ -313,8 +290,6 @@ def save_state(self, global_step: int) -> None:
         }
         if self.config.use_gan_loss:
             state[self.DISC_OPT_KEY] = self.disc_optim.state_dict()
-        if self.config.use_nn_timeconsistency:
-            state[self.TEMP_DISC_OPT_KEY] = self.temp_disc_optim.state_dict()

         torch.save(state, self.save_folder / f"{self.STATE_PREFIX}.pt")

@@ -328,8 +303,6 @@ def train(self, log_steps: int, save_steps: int) -> None:
             self.model_sched.step()
             if self.config.use_gan_loss:
                 self.disc_sched.step()
-            if self.config.use_nn_timeconsistency:
-                self.temp_disc_sched.step()

         best_val_loss = float("inf")
         running_losses = _Losses()
@@ -347,8 +320,6 @@ def train(self, log_steps: int, save_steps: int) -> None:
                 )
                 if self.config.use_gan_loss:
                     self._train_step_disc(renders, inputs[2])
-                if self.config.use_nn_timeconsistency:
-                    self._train_step_temp_disc(renders)

                 global_step += 1

@@ -376,8 +347,6 @@ def train(self, log_steps: int, save_steps: int) -> None:
             self.model_sched.step()
             if self.config.use_gan_loss:
                 self.disc_sched.step()
-            if self.config.use_nn_timeconsistency:
-                self.temp_disc_sched.step()

             torch.cuda.empty_cache()
             self.save_state(global_step)
@@ -392,8 +361,6 @@ def _train_step(
         self.rendering.train()
         if self.config.use_gan_loss:
             self.discriminator.module.freeze()
-        if self.config.use_nn_timeconsistency:
-            self.temp_disc.module.freeze()

         self.model_optim.zero_grad()

@@ -435,7 +402,10 @@ def _train_step(
             jloss += self.config.gan_wt * gen_loss

         if self.config.use_nn_timeconsistency:
-            temp_nn_loss = self.temp_nn_fn(renders, self.temp_disc)
+            mask = renders[:, :, 3:4]  # for broadcasting
+            fg = renders[:, :, :3] * mask
+            bg = input_batch[:, 3:6].unsqueeze(1) * (1 - mask)
+            temp_nn_loss = self.temp_nn_fn(fg + bg, self.temp_disc)
             running_losses.temp_nn += temp_nn_loss.mean().item()
             jloss += self.config.temp_nn_wt * temp_nn_loss

@@ -473,24 +443,6 @@ def _train_step_disc(
             self.scaler.step(self.disc_optim)
             self.scaler.update()

-    def _train_step_temp_disc(self, renders: torch.Tensor) -> None:
-        """Train the temporal discriminator."""
-        outputs = renders.detach()
-        # Needed for gradient checkpointing
-        outputs.requires_grad = True
-
-        self.temp_disc.module.unfreeze()
-
-        for _ in range(self.config.temp_disc_steps):
-            self.temp_disc_optim.zero_grad()
-
-            with autocast(enabled=self.config.mixed_precision):
-                temp_nn_loss = self.temp_nn_fn(outputs, self.temp_disc).mean()
-
-            self.scaler.scale(temp_nn_loss).backward()
-            self.scaler.step(self.temp_disc_optim)
-            self.scaler.update()
-
     def save_logs(
         self,
         global_step: int,
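
Note (review aid, not part of the patch): the sketch below walks through the same steps as the new contrastive TemporalNNLoss on toy tensors, so the shapes can be sanity-checked in isolation: per-frame embeddings, cosine-similarity matrix, self-mask, temperature-scaled log-softmax, and adjacent-frame positives. The frozen ResNet18 discriminator is replaced by a small stand-in embedding, and the sizes B/T/D/H/W and the Linear projection are invented purely for illustration.

import torch
import torch.nn as nn

# Made-up sizes for illustration: batch, train steps, feature dim, frame size.
B, T, D, H, W = 2, 6, 512, 16, 16
temperature = 1.0  # plays the role of config.temp_nn_temperature

# Stand-in for the frozen ResNet18 feature extractor used as `discriminator`.
embed = nn.Sequential(nn.Flatten(), nn.Linear(3 * H * W, D))

composites = torch.rand(B, T, 3, H, W)  # toy fg + bg composites, BxTxCxHxW

# Embed each frame separately, then stack along the time axis.
latents = torch.stack([embed(composites[:, t]) for t in range(T)], dim=1)  # BxTxD

left = latents.unsqueeze(-1)      # BxTxDx1
right = left.permute(0, 3, 2, 1)  # Bx1xDxT
similarity = nn.functional.cosine_similarity(left, right, dim=-2)  # BxTxT

# Mask the diagonal so a frame is never compared with itself.
mask = torch.eye(T, dtype=torch.bool).unsqueeze(0).expand(B, -1, -1)
similarity = similarity.masked_fill(mask, float("-inf"))

# Softmax over dim 1, as in the patch (the cosine matrix is symmetric).
log_softmax = torch.log_softmax(similarity / temperature, dim=1)

# Positives are the adjacent-frame pairs (i, i+1) on the super-diagonal.
positive_pairs = torch.diagonal(log_softmax, offset=1, dim1=-2, dim2=-1)  # Bx(T-1)
loss = -positive_pairs.mean(dim=-1)
print(loss.shape)  # torch.Size([2]), one loss value per batch element

Since every positive is a pair of consecutive composited frames, minimising this loss pulls adjacent renders together in the frozen feature space and pushes the remaining frames apart, which is what replaces the BCE-trained temporal discriminator removed from train.py.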