# Train the predictor with an autoregressive loss

In [None]:
#| default_exp autoregressive_trainer

In [None]:
#| export
import os
import random

import lightning as pl
import torch
import wandb

from maskpredformer.mask_simvp import MaskSimVP
from maskpredformer.simvp_dataset import DLDataset
from maskpredformer.vis_utils import show_video_line, show_gif

## MaskSimVPAutoRegressiveModule

In [None]:
#| export
class MaskSimVPAutoRegressiveModule(pl.LightningModule):
    """Lightning module that trains a ``MaskSimVP`` with an autoregressive rollout loss.

    Starting from the context frames ``x``, the model is called once per rollout
    step. Each prediction is the ``argmax`` over dim 2 of the logits, so no gradient
    flows through it. The prediction is appended to a sliding input window that
    feeds the next call. Only the steps listed in ``backprop_indices`` contribute
    their logits to the cross-entropy loss against the matching frames of ``y``.

    Args:
        in_shape, hid_S, hid_T, N_S, N_T, model_type, drop_path, downsample,
        pre_seq_len, aft_seq_len: forwarded unchanged to ``MaskSimVP``.
            ``pre_seq_len`` is also the number of context frames requested from
            ``DLDataset``.
        batch_size: batch size of the train/val dataloaders.
        lr: ``max_lr`` of the OneCycleLR schedule (also passed to Adam).
        weight_decay: Adam weight decay.
        max_epochs: multiplied by the train-loader length to size the LR schedule.
        data_root: root directory passed to ``DLDataset`` for both splits.
        backprop_indices: 0-based rollout steps whose logits enter the loss.
            Stored sorted and de-duplicated so logits and targets stay paired.
        unlabeled: forwarded to the training ``DLDataset`` only.

    Raises:
        ValueError: if ``backprop_indices`` is empty.
    """

    # Number of future frames requested from DLDataset per sample; `step` runs one
    # rollout step per target frame.
    ROLLOUT_LEN = 11

    def __init__(self, in_shape, hid_S, hid_T, N_S, N_T, model_type,
                 batch_size, lr, weight_decay, max_epochs,
                 data_root, backprop_indices=(10,), pre_seq_len=11, aft_seq_len=1,
                 drop_path=0.0, unlabeled=False, downsample=False):
        super().__init__()
        self.save_hyperparameters()
        # `step` collects logits in ascending step order, so targets must be gathered
        # in that same order (and without duplicates) to pair each logit with its frame.
        self.backprop_indices = sorted(set(backprop_indices))
        if not self.backprop_indices:
            raise ValueError("backprop_indices must contain at least one rollout step")
        self.model = MaskSimVP(
            in_shape, hid_S, hid_T, N_S, N_T, model_type, downsample=downsample, drop_path=drop_path,
            pre_seq_len=pre_seq_len, aft_seq_len=aft_seq_len
        )
        # The model consumes `pre_seq_len` input frames, so ask the dataset for exactly
        # that many context frames rather than a hard-coded count.
        self.train_set = DLDataset(
            data_root, "train", unlabeled=unlabeled,
            pre_seq_len=pre_seq_len, aft_seq_len=self.ROLLOUT_LEN
        )
        self.val_set = DLDataset(
            data_root, "val", pre_seq_len=pre_seq_len, aft_seq_len=self.ROLLOUT_LEN
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return torch.utils.data.DataLoader(
            self.train_set, batch_size=self.hparams.batch_size,
            num_workers=8, shuffle=True, pin_memory=True
        )

    def val_dataloader(self):
        """Deterministic (unshuffled) loader over the validation split."""
        return torch.utils.data.DataLoader(
            self.val_set, batch_size=self.hparams.batch_size,
            num_workers=8, shuffle=False, pin_memory=True
        )

    def calculate_loss(self, logits, target):
        """Cross-entropy between per-frame logits and integer class targets.

        The two leading dims (batch, step) are merged so every frame is one sample
        for ``CrossEntropyLoss``. The class scores therefore sit on dim 2 of
        ``logits``, which becomes dim 1 after merging. ``flatten`` also accepts
        non-contiguous inputs, unlike ``view``.
        """
        return self.criterion(logits.flatten(0, 1), target.flatten(0, 1))

    def step(self, x, y):
        """Roll the model out for ``y.size(1)`` steps and compute the loss.

        Args:
            x: context frames fed to the model; dim 1 is time (frames are
                concatenated along it).
            y: target frames; ``y[:, i]`` is compared with the prediction of
                rollout step ``i``.

        Returns:
            ``(loss, cur_seq)``: the cross-entropy over the steps in
            ``self.backprop_indices``, and the final input window (the last
            ``x.size(1)`` frames of context followed by predictions).

        Raises:
            ValueError: if a backprop index is outside ``[0, y.size(1))``, or if
                the model emits more than one frame per call.
        """
        n_steps = y.size(1)
        indices = self.backprop_indices
        if not indices or indices[0] < 0 or indices[-1] >= n_steps:
            raise ValueError(
                f"backprop_indices {indices} must lie in [0, {n_steps}) for {n_steps} target frames"
            )
        keep = set(indices)
        # Remember the caller's grad mode so we never switch gradients *on*
        # (e.g. inside validation, where Lightning disables them).
        grad_enabled = torch.is_grad_enabled()

        kept_logits = []
        cur_seq = x.clone()
        for i in range(n_steps):
            keep_step = i in keep
            # Steps whose logits are discarded do not need an autograd graph.
            with torch.set_grad_enabled(grad_enabled and keep_step):
                logits_t = self.model(cur_seq)
            if keep_step:
                kept_logits.append(logits_t)
            y_hat = torch.argmax(logits_t, dim=2)  # hard prediction for the next input
            # Slide the window: drop the oldest frame, append the prediction.
            # NOTE(review): assumes the model emits one frame per call (aft_seq_len=1).
            cur_seq = torch.cat([cur_seq[:, 1:], y_hat], dim=1)

        logits = torch.cat(kept_logits, dim=1)
        if logits.size(1) != len(indices):
            raise ValueError(
                "expected one predicted frame per rollout step (aft_seq_len=1), "
                f"got {logits.size(1)} logit frames for {len(indices)} backprop steps"
            )
        loss = self.calculate_loss(logits, y[:, indices])
        return loss, cur_seq

    def training_step(self, batch, batch_idx):
        """Rollout loss on one training batch, logged as ``train_loss``."""
        x, y = batch
        loss, _ = self.step(x, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Rollout loss on one validation batch, logged as ``val_loss``."""
        x, y = batch
        loss, _ = self.step(x, y)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam with a per-step OneCycleLR schedule spanning ``max_epochs`` epochs."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay
        )
        # NOTE(review): the schedule length assumes one optimizer step per batch of the
        # full train loader. Under DDP or gradient accumulation the real count differs;
        # consider `self.trainer.estimated_stepping_batches`.
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.hparams.lr,
            total_steps=self.hparams.max_epochs*len(self.train_dataloader()),
            final_div_factor=1e4
        )
        opt_dict = {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

        return opt_dict
        

**Test the `MaskSimVPAutoRegressiveModule`**

In [None]:
# Move up one directory (presumably the project root — confirm) so the relative
# "checkpoints/..." path used in the next cell resolves.
%cd ..

In [None]:
ckpt_path = "checkpoints/in_shape=11-49-160-240_hid_S=64_hid_T=512_N_S=4_N_T=8_model_type=gSTA_batch_size=4_lr=0.001_weight_decay=0.0_max_epochs=20_pre_seq_len=11_aft_seq_len=1_unlabeled=True_downsample=True/simvp_epoch=16-val_loss=0.014.ckpt"
# map_location="cpu" loads tensors onto the CPU whatever device they were saved
# from, so this also works on a machine without that GPU. The module built below
# lives on the CPU, and load_state_dict copies into it anyway.
mask_sim_vp_ckpt = torch.load(ckpt_path, map_location="cpu")

# Reuse the checkpoint's saved hyper-parameters (presumably from the single-step
# MaskSimVP run) for the autoregressive module, but turn off the `unlabeled`
# option of the training dataset.
autoregressive_params = mask_sim_vp_ckpt['hyper_parameters']
autoregressive_params['unlabeled'] = False

pl_module = MaskSimVPAutoRegressiveModule(**autoregressive_params)
pl_module.load_state_dict(mask_sim_vp_ckpt["state_dict"])

In [None]:
def test_prior_model_results():
    """Run the first validation sample through ``pl_module.step``.

    Returns a tuple of the batched target frames followed by every value
    returned by ``pl_module.step`` (the loss and the final rollout window).
    """
    sample_x, sample_y = pl_module.val_set[0]
    batch_x = sample_x.unsqueeze(0).to(pl_module.device)
    batch_y = sample_y.unsqueeze(0).to(pl_module.device)
    step_outputs = pl_module.step(batch_x, batch_y)
    return (batch_y, *step_outputs)
    
# Smoke test: one rollout on a validation sample, then make sure backward()
# runs on the rollout loss before printing it.
y, loss, cur_seq = test_prior_model_results()
loss.backward()
print(loss)

In [None]:
# Final autoregressive window (the predicted frames); .cpu() keeps this working
# if the module has been moved to a GPU.
show_video_line(cur_seq.squeeze().cpu().numpy(), 11)

In [None]:
# Ground-truth target frames for comparison; .cpu() keeps this working if the
# module (and therefore y) lives on a GPU.
show_video_line(y.squeeze().cpu().numpy(), 11)

## Sample Autoregressive (AR) Video Callback

In [None]:
#| export
class SampleAutoRegressiveVideoCallback(pl.Callback):
    """Save and log a GIF of one autoregressive rollout after each validation epoch.

    Args:
        val_set: indexable dataset returning ``(x, y)`` tensor pairs. A random
            item is rolled out with ``pl_module.step``.
        video_path: directory the GIFs are written to (created if missing).
    """

    def __init__(self, val_set, video_path="./val_videos/"):
        super().__init__()
        self.val_set = val_set
        self.val_count = 0  # running counter used in the GIF file names
        self.val_path = video_path
        # exist_ok avoids the check-then-create race when several processes
        # (e.g. one per DDP rank) construct this callback at the same time.
        os.makedirs(self.val_path, exist_ok=True)

    def generate_video(self, pl_module):
        """Roll out a random validation sample, save it as a GIF and return the path.

        The module runs in eval mode without gradients. Its previous train/eval
        mode is restored afterwards, so calling this outside a Trainer does not
        leave the model stuck in eval mode.
        """
        sample_idx = random.randint(0, len(self.val_set) - 1)
        x, y = self.val_set[sample_idx]
        x = x.unsqueeze(0).to(pl_module.device)
        y = y.unsqueeze(0).to(pl_module.device)

        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():  # rendering a video needs no autograd graph
                _, cur_seq = pl_module.step(x, y)
        finally:
            pl_module.train(was_training)

        # Drop the batch dim and move to host memory for plotting.
        x = x.squeeze(0).cpu().numpy()
        y = y.squeeze(0).cpu().numpy()
        y_hat = cur_seq.squeeze(0).cpu().numpy()

        gif_path = os.path.join(self.val_path, f"val_ar_video_{self.val_count}.gif")

        show_gif(x, y, y_hat, out_path=gif_path)
        self.val_count += 1

        return gif_path

    def on_validation_epoch_end(self, trainer, pl_module):
        """On global rank 0, render a rollout GIF and log it under ``val_video``."""
        if trainer.global_rank != 0:
            return
        gif_path = self.generate_video(pl_module)
        # Logging assumes a W&B-style logger (`experiment.log` + `wandb.Video`).
        # Without any logger the GIF is still written to disk, just not logged.
        if trainer.logger is not None:
            trainer.logger.experiment.log({
                "val_video": wandb.Video(gif_path, fps=4, format="gif")
            })

In [None]:
# Exercise the callback outside a Trainer: render one validation rollout to a GIF
# and keep the path it was written to.
sample_video_cb = SampleAutoRegressiveVideoCallback(val_set=pl_module.val_set)
gif_path = sample_video_cb.generate_video(pl_module)

In [None]:
#| hide
# Run nbdev's exporter so `#| export` cells (including this notebook's) are written
# to the Python package; this notebook's target module is set by `#| default_exp`.
import nbdev; nbdev.nbdev_export()