This repository has been archived by the owner on Oct 31, 2022. It is now read-only.

Commit
Initial contrastive temporal loss
rharish101 committed Sep 2, 2021
1 parent 97c7a9b commit d63ff27
Showing 4 changed files with 47 additions and 96 deletions.

config.py: 2 changes (1 addition & 1 deletion)

@@ -98,7 +98,7 @@ class Config:
     gan_wt: float = 1.0
     temp_nn_wt: float = 0.05
     disc_steps: int = 1
-    temp_disc_steps: int = 2
+    temp_nn_temperature: float = 1.0

     seed: int = 0

models/discriminator.py: 21 changes (5 additions & 16 deletions)

@@ -70,23 +70,12 @@ def __init__(self, config: Config) -> None:

         model = resnet18(pretrained=True)
         layers = list(model.children())
+        self.net = nn.Sequential(*layers[:-1], nn.Flatten())

-        pretrained_weights = layers[0].weight
-        layers[0] = nn.Conv2d(
-            8,
-            64,
-            kernel_size=(7, 7),
-            stride=(2, 2),
-            padding=(3, 3),
-            bias=False,
-        )
-        layers[0].weight.data[:, :3, :, :] = nn.Parameter(pretrained_weights)
-        layers[0].weight.data[:, 4:7, :, :] = nn.Parameter(pretrained_weights)
-
-        classify = nn.Linear(512, 1)
-        net = nn.Sequential(*layers[:-1], nn.Flatten(), classify)
-        self.net = spectralize(net)
+        # Freeze the pre-trained model
+        for param in self.net.parameters():
+            param.requires_grad = False

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         with autocast(enabled=self.config.mixed_precision):
-            return self.ckpt_run(self.net, 2, inputs).flatten()
+            return self.ckpt_run(self.net, 2, inputs)
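
With this change, the network edited above is no longer a trainable classifier over stacked frame pairs (the 8-channel input conv, the nn.Linear(512, 1) head and the spectralize wrapper are removed); it now acts as a frozen, ImageNet-pretrained ResNet-18 that maps each 3-channel frame to a pooled 512-dimensional embedding. A minimal standalone sketch of the equivalent construction, outside the repository's gradient-checkpointing and mixed-precision wrappers (the input size below is an arbitrary example, not taken from the repository):

import torch
import torch.nn as nn
from torchvision.models import resnet18

# ResNet-18 trunk up to (but not including) the final fc layer, plus Flatten,
# so each 3-channel frame maps to a 512-dimensional embedding.
model = resnet18(pretrained=True)
layers = list(model.children())
net = nn.Sequential(*layers[:-1], nn.Flatten())

# Freeze the pre-trained weights; the network is used purely as a feature extractor.
for param in net.parameters():
    param.requires_grad = False

# Example: 4 RGB frames at 128x128 -> a 4x512 tensor of embeddings.
with torch.no_grad():
    embeddings = net(torch.randn(4, 3, 128, 128))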

models/loss.py: 64 changes (37 additions & 27 deletions)

@@ -1,4 +1,3 @@
-import random
 from typing import Optional

 import torch
@@ -141,37 +140,48 @@ class TemporalNNLoss(nn.Module):
     def __init__(self, config: Config, reduction: str = "none"):
         super().__init__()
         self.config = config
-        self.loss_fn = nn.BCEWithLogitsLoss(reduction=reduction)
+        self._log_softmax_fn = nn.LogSoftmax(dim=1)

     def forward(
         self, renders: torch.Tensor, discriminator: nn.Module
     ) -> torch.Tensor:
-        renders = renders[:, : self.config.fmo_train_steps, :4]
-
-        loss = 0.0
-
-        for frame_num in range(self.config.fmo_train_steps - 1):
-            # Offset from the current index
-            offset = random.choice(range(2, self.config.fmo_train_steps))
-            choice = (frame_num + offset) % self.config.fmo_train_steps
-
-            correct = torch.cat(
-                (renders[:, frame_num], renders[:, frame_num + 1]), 1
-            )
-            incorrect = torch.cat(
-                (renders[:, frame_num], renders[:, choice]), 1
-            )
-
-            correct_out = discriminator(correct)
-            incorrect_out = discriminator(incorrect)
-            loss += self.loss_fn(
-                incorrect_out, torch.zeros_like(incorrect_out)
-            )
-            loss += self.loss_fn(correct_out, torch.ones_like(correct_out))
-
-        loss /= self.config.fmo_train_steps - 1
+        renders = renders[:, : self.config.fmo_train_steps]  # BxTxCxHxW
+
+        latents_list = [
+            discriminator(renders[:, frame_num])
+            for frame_num in range(renders.shape[1])
+        ]
+        latents = torch.stack(latents_list, 1)  # BxTxD
+
+        left = latents.unsqueeze(-1)  # BxTxDx1
+        right = left.permute(0, 3, 2, 1)  # Bx1xDxT
+        similarity = nn.functional.cosine_similarity(
+            left, right, dim=-2, eps=torch.finfo(left.dtype).eps
+        )  # BxTxT
+
+        # Mask out the self values
+        mask = torch.eye(
+            self.config.fmo_train_steps, device=similarity.device
+        ).bool()
+        mask_nd = mask.unsqueeze(0).tile(similarity.shape[0], 1, 1)
+        neg_inf = float("-inf") * torch.ones_like(similarity)
+        similarity = torch.where(mask_nd, neg_inf, similarity)
+
+        log_softmax = self._log_softmax_fn(
+            similarity / self.config.temp_nn_temperature
+        )

-        return loss
+        # All positive pairs are (i, i+1)
+        # - x - - - -
+        # - - x - - -
+        # - - - x - -
+        # - - - - x -
+        # - - - - - x
+        # - - - - - -
+        positive_pairs = torch.diagonal(
+            log_softmax, offset=1, dim1=-2, dim2=-1
+        )  # Bx(T-1)
+        return -positive_pairs.mean(dim=-1)  # B dimensional vector


 def oflow_loss(renders: torch.Tensor) -> torch.Tensor:
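
The rewritten TemporalNNLoss.forward is an InfoNCE-style contrastive objective over per-frame embeddings: within each clip, consecutive frames (i, i+1) form the positive pairs, every other frame in the clip is a negative, and self-similarities are masked out before the temperature-scaled softmax. A self-contained sketch of the same computation for reference (the function name and the random inputs are illustrative, not part of the repository):

import torch
import torch.nn.functional as F

def contrastive_temporal_loss(latents: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Contrastive temporal loss over per-frame latents of shape BxTxD.

    Adjacent frames are positives, all other frames in the clip are negatives.
    Returns a B-dimensional loss vector (one value per clip).
    """
    steps = latents.shape[1]
    # Pairwise cosine similarities between all frames of a clip: BxTxT
    similarity = F.cosine_similarity(
        latents.unsqueeze(2), latents.unsqueeze(1), dim=-1
    )
    # Mask the diagonal so a frame is never its own positive or negative
    eye = torch.eye(steps, dtype=torch.bool, device=latents.device)
    similarity = similarity.masked_fill(eye, float("-inf"))
    log_probs = F.log_softmax(similarity / temperature, dim=-1)
    # Positive pairs (i, i+1) sit on the first superdiagonal
    positives = torch.diagonal(log_probs, offset=1, dim1=-2, dim2=-1)  # Bx(T-1)
    return -positives.mean(dim=-1)

# Example: 2 clips, 6 sub-frames, 512-dimensional embeddings per frame
loss = contrastive_temporal_loss(torch.randn(2, 6, 512), temperature=1.0)

This matches the new forward pass up to minor details (the explicit eps passed to cosine_similarity and the tiled boolean mask); in the diff the embeddings come from the frozen discriminator applied to each rendered frame.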

train.py: 56 changes (4 additions & 52 deletions)

@@ -53,15 +53,13 @@ class Trainer:
     ENC_PREFIX: Final = "encoder"
     RENDER_PREFIX: Final = "rendering"
     DISC_PREFIX: Final = "discriminator"
-    TEMP_DISC_PREFIX: Final = "temp_disc"
     BEST_SUFFIX: Final = "_best"

     # Used when saving training state
     STATE_PREFIX: Final = "train_state"
     GLOBAL_STEP_KEY: Final = "global_step"
     MODEL_OPT_KEY: Final = "model_optim"
     DISC_OPT_KEY: Final = "disc_optim"
-    TEMP_DISC_OPT_KEY: Final = "temp_disc_optim"

     def __init__(
         self,
@@ -230,15 +228,6 @@ def _init_optimizers(self, load_folder: Optional[Path]) -> None:
             step_size=self.config.sched_step_size,
             gamma=0.5,
         )
-        if self.config.use_nn_timeconsistency:
-            self.temp_disc_optim = torch.optim.Adam(
-                self.temp_disc.parameters(), lr=self.config.temp_disc_lr
-            )
-            self.temp_disc_sched = torch.optim.lr_scheduler.StepLR(
-                self.temp_disc_optim,
-                step_size=self.config.sched_step_size,
-                gamma=0.5,
-            )

         self.scaler = torch.cuda.amp.GradScaler(
             enabled=self.config.mixed_precision
@@ -264,11 +253,6 @@ def load_weights(self, load_folder: Path) -> None:
                 torch.load(load_folder / f"{self.DISC_PREFIX}.pt")
             )

-        if self.config.use_nn_timeconsistency:
-            self.temp_disc.load_state_dict(
-                torch.load(load_folder / f"{self.TEMP_DISC_PREFIX}.pt")
-            )
-
     def load_state(self, load_folder: Path) -> None:
         """Load training state from a previous checkpoint."""
         load_folder = load_folder.expanduser()
@@ -278,8 +262,6 @@ def load_state(self, load_folder: Path) -> None:
         self.model_optim.load_state_dict(state[self.MODEL_OPT_KEY])
         if self.config.use_gan_loss:
             self.disc_optim.load_state_dict(state[self.DISC_OPT_KEY])
-        if self.config.use_nn_timeconsistency:
-            self.temp_disc_optim.load_state_dict(state[self.TEMP_DISC_OPT_KEY])

     def save_weights(self, save_best: bool = False) -> None:
         """Save weights to disk."""
@@ -297,11 +279,6 @@ def save_weights(self, save_best: bool = False) -> None:
                 self.discriminator.module.state_dict(),
                 self.save_folder / f"{self.DISC_PREFIX}{suffix}",
             )
-        if self.config.use_nn_timeconsistency:
-            torch.save(
-                self.temp_disc.module.state_dict(),
-                self.save_folder / f"{self.TEMP_DISC_PREFIX}{suffix}",
-            )

     def save_state(self, global_step: int) -> None:
         """Save training state to disk."""
@@ -313,8 +290,6 @@ def save_state(self, global_step: int) -> None:
         }
         if self.config.use_gan_loss:
             state[self.DISC_OPT_KEY] = self.disc_optim.state_dict()
-        if self.config.use_nn_timeconsistency:
-            state[self.TEMP_DISC_OPT_KEY] = self.temp_disc_optim.state_dict()

         torch.save(state, self.save_folder / f"{self.STATE_PREFIX}.pt")

@@ -328,8 +303,6 @@ def train(self, log_steps: int, save_steps: int) -> None:
             self.model_sched.step()
             if self.config.use_gan_loss:
                 self.disc_sched.step()
-            if self.config.use_nn_timeconsistency:
-                self.temp_disc_sched.step()

         best_val_loss = float("inf")
         running_losses = _Losses()
@@ -347,8 +320,6 @@ def train(self, log_steps: int, save_steps: int) -> None:
                 )
                 if self.config.use_gan_loss:
                     self._train_step_disc(renders, inputs[2])
-                if self.config.use_nn_timeconsistency:
-                    self._train_step_temp_disc(renders)

                 global_step += 1

@@ -376,8 +347,6 @@ def train(self, log_steps: int, save_steps: int) -> None:
             self.model_sched.step()
             if self.config.use_gan_loss:
                 self.disc_sched.step()
-            if self.config.use_nn_timeconsistency:
-                self.temp_disc_sched.step()

         torch.cuda.empty_cache()
         self.save_state(global_step)
@@ -392,8 +361,6 @@ def _train_step(
         self.rendering.train()
         if self.config.use_gan_loss:
             self.discriminator.module.freeze()
-        if self.config.use_nn_timeconsistency:
-            self.temp_disc.module.freeze()

         self.model_optim.zero_grad()

@@ -435,7 +402,10 @@ def _train_step(
                 jloss += self.config.gan_wt * gen_loss

             if self.config.use_nn_timeconsistency:
-                temp_nn_loss = self.temp_nn_fn(renders, self.temp_disc)
+                mask = renders[:, :, 3:4]  # for broadcasting
+                fg = renders[:, :, :3] * mask
+                bg = input_batch[:, 3:6].unsqueeze(1) * (1 - mask)
+                temp_nn_loss = self.temp_nn_fn(fg + bg, self.temp_disc)
                 running_losses.temp_nn += temp_nn_loss.mean().item()
                 jloss += self.config.temp_nn_wt * temp_nn_loss

@@ -473,24 +443,6 @@ def _train_step_disc(
             self.scaler.step(self.disc_optim)
             self.scaler.update()

-    def _train_step_temp_disc(self, renders: torch.Tensor) -> None:
-        """Train the temporal discriminator."""
-        outputs = renders.detach()
-        # Needed for gradient checkpointing
-        outputs.requires_grad = True
-
-        self.temp_disc.module.unfreeze()
-
-        for _ in range(self.config.temp_disc_steps):
-            self.temp_disc_optim.zero_grad()
-
-            with autocast(enabled=self.config.mixed_precision):
-                temp_nn_loss = self.temp_nn_fn(outputs, self.temp_disc).mean()
-
-            self.scaler.scale(temp_nn_loss).backward()
-            self.scaler.step(self.temp_disc_optim)
-            self.scaler.update()
-
     def save_logs(
         self,
         global_step: int,
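
In _train_step, the RGBA renders are now alpha-composited onto the background taken from the input batch before being passed to the temporal loss, so the frozen 3-channel encoder sees full RGB frames rather than 4-channel renders. A small sketch of that blending step with made-up shapes (2 clips, 6 sub-frames, 128x128 frames; the assumption that channels 3:6 of the input batch hold the background follows the indexing in the diff):

import torch

renders = torch.rand(2, 6, 4, 128, 128)   # BxTx4xHxW RGBA renders
input_batch = torch.rand(2, 6, 128, 128)  # Bx6xHxW; channels 3:6 assumed to hold the background

mask = renders[:, :, 3:4]                            # alpha channel, kept as a size-1 dim for broadcasting
fg = renders[:, :, :3] * mask                        # foreground weighted by alpha
bg = input_batch[:, 3:6].unsqueeze(1) * (1 - mask)   # background weighted by (1 - alpha), broadcast over T
composited = fg + bg                                 # BxTx3xHxW frames fed to the temporal loss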
