# Predicting SST

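In this article, we use an autoregressive model to forecast sea surface temperature (SST). The example below trains a small next-frame predictor on randomly generated stand-in frames.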

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 1. Define a simple autoregressive model
class AutoregressiveVideoModel(nn.Module):
    """Two-layer convolutional next-frame predictor.

    A 3x3 convolution, a ReLU and a second 3x3 convolution are applied in
    sequence. Both convolutions use ``padding=1``, so the spatial size of the
    input is preserved.

    Args:
        input_channels (int): Channels of the frame fed to the model.
        hidden_channels (int): Channels of the intermediate feature map.
        output_channels (int): Channels of the predicted frame.
    """

    def __init__(self, input_channels, hidden_channels, output_channels):
        super().__init__()
        # Shared settings: a 3x3 kernel with one pixel of padding keeps H and W.
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(input_channels, hidden_channels, **conv_kwargs)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden_channels, output_channels, **conv_kwargs)

    def forward(self, x):
        """Map (batch_size, input_channels, H, W) to (batch_size, output_channels, H, W)."""
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)

# 2. Generate some dummy video data
def generate_dummy_video_data(batch_size, num_frames, channels, height, width):
    """
    Build a random stand-in video batch for smoke-testing the pipeline.

    Every element is drawn from ``torch.randn`` (standard normal noise), so
    the frames carry no real temporal structure.

    Args:
        batch_size (int): Number of clips in the batch.
        num_frames (int): Number of frames per clip.
        channels (int): Channels per frame (e.g., 3 for RGB).
        height (int): Frame height in pixels.
        width (int): Frame width in pixels.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, num_frames, channels, height, width).
    """
    video_shape = (batch_size, num_frames, channels, height, width)
    return torch.randn(video_shape)

# 3. Prepare data for training
def prepare_data(video_data):
    """
    Builds (current frame, next frame) training pairs from a batch of clips.

    Frame ``t`` of every clip becomes an input and frame ``t + 1`` its target.
    The batch and time axes are then merged, so each pair is an independent
    sample that a 2D convolutional model can consume directly.

    Args:
        video_data (torch.Tensor): Tensor of shape
            (batch_size, num_frames, channels, height, width) with
            num_frames >= 2.

    Returns:
        tuple: (inputs, targets)
            inputs:  Tensor of shape (batch_size * (num_frames - 1), channels, height, width)
            targets: Tensor of the same shape; targets[i] is the frame that
                     immediately follows inputs[i] in the same clip.
            Samples are ordered clip by clip: index ``b * (num_frames - 1) + t``
            holds frame ``t`` of clip ``b``.

    Raises:
        ValueError: If video_data is not 5-dimensional or has fewer than two
            frames (no input/target pair can be formed).
    """
    if video_data.dim() != 5:
        raise ValueError(
            "video_data must have shape (batch_size, num_frames, channels, "
            f"height, width); got {video_data.dim()} dimension(s)"
        )
    num_frames = video_data.shape[1]
    if num_frames < 2:
        # With a single frame both slices below would be empty and training
        # would silently run on zero samples.
        raise ValueError(
            f"video_data needs at least 2 frames to form a pair; got {num_frames}"
        )

    inputs = video_data[:, :-1]   # frames 0 .. T-2
    targets = video_data[:, 1:]   # frames 1 .. T-1
    # Merge the batch and time axes: (B, T-1, C, H, W) -> (B * (T-1), C, H, W).
    inputs = inputs.reshape(-1, *inputs.shape[2:])
    targets = targets.reshape(-1, *targets.shape[2:])
    return inputs, targets

# 4. Train the model
def train_model(model, train_loader, optimizer, loss_fn, epochs=10):
    """
    Trains the autoregressive video model.

    For every epoch the model is put in training mode, each batch from
    ``train_loader`` is run through a zero-grad / forward / backward / step
    cycle, and the mean batch loss is printed.

    Args:
        model (nn.Module): The autoregressive model.
        train_loader (Iterable): Re-iterable source of ``(inputs, targets)``
            batches, e.g. a torch.utils.data.DataLoader. It does not need to
            support ``len()``.
        optimizer (torch.optim.Optimizer): The optimizer.
        loss_fn (Callable): Maps ``(outputs, targets)`` to a scalar loss tensor.
        epochs (int): The number of epochs to train. Defaults to 10.

    Returns:
        list[float]: Mean batch loss of each epoch, in order.

    Raises:
        ValueError: If ``train_loader`` yields no batches in an epoch.
    """
    epoch_losses = []
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        # Count batches ourselves instead of calling len(train_loader): this
        # works for loaders without __len__ and lets us detect an empty one.
        num_batches = 0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        if num_batches == 0:
            raise ValueError("train_loader yielded no batches; nothing to train on")
        mean_loss = total_loss / num_batches
        epoch_losses.append(mean_loss)
        print(f"Epoch: {epoch + 1}, Loss: {mean_loss}")
    return epoch_losses

# 5.  Generate a future frame given a sequence of past frames
def generate_next_frame(model, past_frames):
    """
    Generates the next frame given a sequence of past frames.

    Only the most recent frame of the sequence is fed to the model; earlier
    frames are ignored. The model is switched to eval mode (and left there)
    and the forward pass runs without gradient tracking.

    Args:
        model (nn.Module): Trained autoregressive model.
        past_frames (torch.Tensor): Tensor of shape (1, sequence_length, channels, height, width).
            A batch size larger than 1 also works; the batch axis is passed through.

    Returns:
        torch.Tensor: The predicted next frame of shape (1, channels, height, width),
            i.e. exactly what the model returns for the last input frame.
    """
    model.eval()
    with torch.no_grad():
        # Indexing the time axis keeps the batch axis, so the model receives a
        # (batch, channels, height, width) tensor and its output needs no
        # further reshaping.
        last_frame = past_frames[:, -1]
        return model(last_frame)

if __name__ == '__main__':
    # End-to-end demo: train the model on random frames, predict one frame,
    # and plot the inputs next to the prediction.

    # 0. Set random seed for reproducibility
    torch.manual_seed(0)
    np.random.seed(0)

    # 1. Hyperparameters
    batch_size = 2
    num_frames = 5
    channels = 3  # RGB
    height = 64
    width = 64
    hidden_channels = 16
    epochs = 5

    # 2. Generate dummy video data
    video_data = generate_dummy_video_data(batch_size, num_frames, channels, height, width)

    # 3. Prepare data for training
    inputs, targets = prepare_data(video_data)
    train_dataset = torch.utils.data.TensorDataset(inputs, targets)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # 4. Initialize model, loss function, and optimizer
    model = AutoregressiveVideoModel(channels, hidden_channels, channels)
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # 5. Train the model
    train_model(model, train_loader, optimizer, loss_fn, epochs)

    # 6. Generate a future frame
    # Example: Use the first 3 frames from the first video in the batch to predict the 4th frame
    past_frames = video_data[:1, :3, :, :, :]  # (1, 3, 3, 64, 64)
    next_frame = generate_next_frame(model, past_frames)
    # Normalize the prediction to (1, channels, height, width): any extra
    # leading singleton axis is folded away so the code below can rely on a
    # 4D frame tensor.
    next_frame = next_frame.reshape(-1, *next_frame.shape[-3:])
    print("Predicted next frame shape:", next_frame.shape)  # (1, 3, 64, 64)

    # 7.  Visualize the input frames and the predicted next frame.
    #     (Requires matplotlib)
    import matplotlib.pyplot as plt

    def visualize_frames(frames, title="Frames"):
        """
        Visualizes a sequence of frames using matplotlib.

        Args:
            frames (torch.Tensor): Tensor of shape (1, num_frames, channels, height, width).
            title (str): Title of the plot.
        """
        frames = frames.squeeze(0).permute(0, 2, 3, 1).cpu().numpy() # (num_frames, height, width, channels)
        num_frames_to_plot = frames.shape[0]
        # squeeze=False keeps `axes` a 2D array even when only one frame is
        # plotted; by default a single subplot comes back as a bare Axes
        # object, which cannot be indexed.
        fig, axes = plt.subplots(1, num_frames_to_plot, figsize=(15, 5), squeeze=False)
        fig.suptitle(title)
        # NOTE: imshow clips float RGB data to [0, 1], so the random-normal
        # demo frames are shown clipped.
        for ax, frame in zip(axes[0], frames):
            ax.imshow(frame)
            ax.axis('off')
        plt.show()

    # Visualize the input frames and the prediction
    visualize_frames(past_frames, title="Past Frames")
    # Add a time axis of length 1 so the single predicted frame matches the
    # (1, num_frames, channels, height, width) layout expected above.
    visualize_frames(next_frame.unsqueeze(0), title="Predicted Next Frame")

