In [1]:
import torch
from torch import nn, optim
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor
import cv2
import numpy as np

# Define the autoencoder model
class Autoencoder(nn.Module):
    """Convolutional autoencoder that blends two frames through their latent codes.

    Each frame is encoded independently to a ``latent_dim`` vector; ``forward`` averages the two
    codes and decodes the result. Calling ``model(x, x)`` is therefore a plain reconstruction of ``x``.

    Args:
        latent_dim: Size of the latent vector.
        in_channels: Channels per input frame (3 for RGB). The decoder emits the same number of
            channels so reconstructions have the same shape as the inputs.
        image_size: Height/width of the square input frames. Must be a positive multiple of 32,
            because the encoder halves the resolution five times and the decoder doubles it five times.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 32.
    """

    # Five stride-2 stages in the encoder (mirrored by five stride-2 stages in the decoder).
    _NUM_STAGES = 5

    def __init__(self, latent_dim=128, in_channels=3, image_size=256):
        super().__init__()
        downsample = 2 ** self._NUM_STAGES
        if image_size <= 0 or image_size % downsample != 0:
            raise ValueError(
                f"image_size must be a positive multiple of {downsample}, got {image_size}"
            )
        self.latent_dim = latent_dim
        # Spatial size of the feature map at the bottleneck (e.g. 256 -> 8).
        bottleneck = image_size // downsample

        # Encoder layers: each Conv2d(k=3, s=2, p=1) halves H and W.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256 * bottleneck * bottleneck, latent_dim)
        )

        # Decoder layers: each ConvTranspose2d(k=4, s=2, p=1) doubles H and W, so five of them
        # restore the original image_size.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256 * bottleneck * bottleneck),
            nn.Unflatten(-1, (256, bottleneck, bottleneck)),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()  # outputs in [0, 1], matching ToTensor-style inputs
        )

    def encode(self, x):
        """Map a batch of frames ``(B, in_channels, image_size, image_size)`` to ``(B, latent_dim)``."""
        return self.encoder(x)

    def decode(self, x):
        """Map latent vectors ``(B, latent_dim)`` back to frames ``(B, in_channels, image_size, image_size)``."""
        return self.decoder(x)

    def forward(self, x1, x2):
        """Decode the midpoint of the latent codes of ``x1`` and ``x2``.

        Returns:
            Tuple ``(decoded, z)`` where ``z`` is the averaged latent vector.
        """
        z1 = self.encode(x1)
        z2 = self.encode(x2)
        z = (z1 + z2) / 2  # Interpolate latent vectors
        return self.decode(z), z


# Define the training function
def train(model, dataloader, optimizer, criterion, num_epochs=10, device='cpu', log_interval=10):
    """Train ``model`` to reconstruct its inputs.

    The model's ``forward(x1, x2)`` decodes the average of the two latent codes, so each batch is
    passed as both arguments (``model(inputs, inputs)``), which reduces to reconstructing ``inputs``.

    Args:
        model: Module whose ``forward(x1, x2)`` returns ``(reconstruction, latent)``.
        dataloader: Sized iterable of ``(inputs, labels)`` batches; labels are ignored.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(reconstruction, inputs)``.
        num_epochs: Number of passes over ``dataloader``.
        device: Device the model and every batch are moved to.
        log_interval: Print the mean loss every ``log_interval`` batches. A trailing partial window
            at the end of an epoch is also reported.

    Returns:
        List with the mean training loss of each epoch (``nan`` for an epoch with no batches).
    """
    model.train()
    model.to(device)
    epoch_losses = []

    for epoch in range(num_epochs):
        running_loss = 0.0
        window_batches = 0
        epoch_loss = 0.0
        num_batches = 0
        for inputs, _ in dataloader:
            inputs = inputs.to(device)

            optimizer.zero_grad()

            # Same frame on both sides -> the latent average is just its own code.
            outputs, _ = model(inputs, inputs)
            loss = criterion(outputs, inputs)
            loss.backward()

            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            epoch_loss += batch_loss
            window_batches += 1
            num_batches += 1
            if window_batches == log_interval:
                print(f"Epoch {epoch+1}, Batch {num_batches}/{len(dataloader)}, Loss: {running_loss/window_batches:.5f}")
                running_loss = 0.0
                window_batches = 0

        # Report the leftover batches that did not fill a whole logging window.
        if window_batches:
            print(f"Epoch {epoch+1}, Batch {num_batches}/{len(dataloader)}, Loss: {running_loss/window_batches:.5f}")

        epoch_losses.append(epoch_loss / num_batches if num_batches else float('nan'))

    return epoch_losses

# Load the video as a dataset of frames.
# ImageFolder cannot read a video file (it expects a directory of class subfolders with images),
# so frames are decoded directly with OpenCV.
class VideoFrameDataset(torch.utils.data.Dataset):
    """All frames of a video file, served as float RGB tensors in [0, 1].

    Frames are decoded once with OpenCV, converted BGR -> RGB, resized to ``size`` and kept in
    memory as uint8. Each item is converted to a ``(3, H, W)`` float32 tensor on access. Items are
    ``(frame, 0)`` pairs so the ``(inputs, _)`` unpacking used by ``train`` works unchanged.

    Args:
        path: Path of the video file.
        size: Target ``(width, height)`` passed to ``cv2.resize``.

    Raises:
        FileNotFoundError: If OpenCV cannot open ``path``.
        ValueError: If no frame could be decoded.
    """

    def __init__(self, path, size=(256, 256)):
        capture = cv2.VideoCapture(path)
        if not capture.isOpened():
            raise FileNotFoundError(f"Could not open video file: {path}")
        frames = []
        try:
            while True:
                ok, frame_bgr = capture.read()
                if not ok:
                    break
                frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                frames.append(cv2.resize(frame_rgb, size, interpolation=cv2.INTER_AREA))
        finally:
            capture.release()
        if not frames:
            raise ValueError(f"No frames could be decoded from {path}")
        self.frames = np.stack(frames)  # (N, H, W, 3) uint8

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, index):
        # HWC uint8 -> CHW float in [0, 1], the same convention as torchvision's ToTensor.
        frame = torch.from_numpy(self.frames[index]).permute(2, 0, 1).float().div(255.0)
        return frame, 0


video_filename = 'input_video.mkv'
video_dataset = VideoFrameDataset(video_filename, size=(256, 256))

# Create a data loader for the video dataset
video_dataloader = DataLoader(video_dataset, batch_size=32, shuffle=True)

# Define the autoencoder model
autoencoder = Autoencoder()



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

Epoch [1/10], Batch [1/469], Loss: 549.2940
Epoch [1/10], Batch [101/469], Loss: 192.5926
Epoch [1/10], Batch [201/469], Loss: 157.3862
Epoch [1/10], Batch [301/469], Loss: 142.5435
Epoch [1/10], Batch [401/469], Loss: 134.6329
Epoch [2/10], Batch [1/469], Loss: 125.4808
Epoch [2/10], Batch [101/469], Loss: 127.6829
Epoch [2/10], Batch [201/469], Loss: 125.1700
Epoch [2/10], Batch [301/469], Loss: 118.9432
Epoch [2/10], Batch [401/469], Loss: 113.4957
Epoch [3/10], Batch [1/469], Loss: 119.2539
Epoch [3/10], Batch [101/469], Loss: 114.0507
Epoch [3/10], Batch [201/469], Loss: 111.7247
Epoch [3/10], Batch [301/469], Loss: 111.7111
Epoch [3/10], Batch [401/469], Loss: 116.4308
Epoch [4/10], Batch [1/469], Loss: 110.2689
Epoch [4/10], Batch [101/469], Loss: 112.5125
Epoch [4/10], Batch [201/469], Loss: 108.3331
Epoch [4/10], Batch [301/469], Loss: 112.6111
Epoch [4/10], Batch [401/469], Loss: 114.8788
Epoch [5/10],