In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from models.models import ObjectDetectionModel

In [2]:
# Custom dataset class for Moving MNIST
class MovingMNISTDataset(Dataset):
    def __init__(self, num_samples, seq_length, image_size):
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.image_size = image_size
        self.data = self.generate_dataset()

    def generate_dataset(self):
        data = []
        for _ in range(self.num_samples):
            sample = []
            for _ in range(self.seq_length):
                mnist_image = np.zeros((self.image_size, self.image_size), dtype=np.uint8)
                x, y = random.randint(0, self.image_size - 28), random.randint(0, self.image_size - 28)
                mnist_image[x:x+28, y:y+28] = np.random.randint(0, 256, size=(28, 28))
                sample.append(mnist_image)
            data.append(sample)
        return data

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx], dtype=torch.float32)

In [3]:
# Create datasets and data loaders
train_dataset = MovingMNISTDataset(num_samples=10000, seq_length=5, image_size=64)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

model = ObjectDetectionModel()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [4]:
# Train the model
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for batch in train_loader:
        # Flatten the sequence and batch dimensions
        batch_size, seq_length, height, width = batch.size()
        batch = batch.view(batch_size * seq_length, 1, height, width)  # reshaping for individual frame processing

        optimizer.zero_grad()
        outputs = model(batch)  # Process the batch which now includes all frames individually
        targets = torch.zeros_like(outputs)  # Assuming target bounding boxes are all zeros for simplicity

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/(len(train_loader) * seq_length)}")


Epoch 1/10, Loss: 0.10373876318661793
Epoch 2/10, Loss: 5.871439337801592e-05
Epoch 3/10, Loss: 1.826754688310502e-05
Epoch 4/10, Loss: 7.116681412600165e-06


KeyboardInterrupt: 