In [1]:
import cv2
import numpy as np
import torch
from torchvision import transforms
from torch import nn
import time

VIDEOS_LOCATION = "C:\\Users\\trett\Documents\GitHub\ThirdYearProject\pytorch\\videos\\"

transform = transforms.ToTensor()

width = 640
height = 360


In [2]:
VIDEOS_LOCATION = "C:\\Users\\trett\Documents\GitHub\ThirdYearProject\pytorch\\videos\\"
VIDEOS_TRAIN = [
    {
        "file": "soldiers.mp4",
        "fps": None
    },
    {
        "file": 'running.mp4',
        "fps": None
    }
]

VIDEOS_TEST = [
    {
        "file": 'nato.mp4',
        "fps": None
    }
]

VIDEOS_SKI_TRAIN = [
    {
        "file": 'drone_shot.mp4',
        "fps": None
    },
    {
        "file": "crevace2.mp4",
        "fps": None
    }
]

VIDEOS_SKI_TEST = [
    {
        "file": "powder_maybe2.mp4",
        "fps": None
    }
]

class VideoDataLoader:
    video_index = 0
    videos = None
    cap = None
    frame_batch_buffer = []
    first_frame = None
    def __init__(self, videos):
        self.videos = videos
        self.cap = cv2.VideoCapture(VIDEOS_LOCATION + self.videos[self.video_index]["file"])
        self.videos[self.video_index]["fps"] = self.cap.get(cv2.CAP_PROP_FPS)
        if not self.cap.isOpened():
            print("Error: could not open video file")
        ret, first_frame = self.cap.read()
        for i in range(10):
            ret, middle_frame = self.cap.read()
            if not ret:
                break
            ret, last_frame = self.cap.read()
            if not ret:
                break
            batch_of_frames = [transform(first_frame), transform(middle_frame), transform(last_frame)]
            self.frame_batch_buffer.append(batch_of_frames)
            first_frame = last_frame
        self.first_frame = first_frame

    def hasNext(self):
        return len(self.frame_batch_buffer) != 0

    def nextFile(self):
        self.video_index += 1
        self.cap.release()
        try:
            self.cap = cv2.VideoCapture(VIDEOS_LOCATION + self.videos[self.video_index]["file"])
        except:
            batch = self.frame_batch_buffer.pop(0)
            self.frame_batch_buffer = []
            return batch

        self.videos[self.video_index]["fps"] = self.cap.get(cv2.CAP_PROP_FPS)
        if not self.cap.isOpened():
            print("Error: could not open video file")
        ret, first_frame = self.cap.read()
        for i in range(10):
            ret, middle_frame = self.cap.read()
            if not ret:
                break
            ret, last_frame = self.cap.read()
            if not ret:
                break
            batch_of_frames = [transform(first_frame), transform(middle_frame), transform(last_frame)]
            self.frame_batch_buffer.append(batch_of_frames)
            first_frame = last_frame
        self.first_frame = first_frame
        return self.frame_batch_buffer.pop(0)

    def getNext(self):
        if len(self.frame_batch_buffer) < 5:
            return self.nextFile()
        ret, middle_frame = self.cap.read()
        if not ret:
            return self.nextFile()
        ret, last_frame = self.cap.read()
        if not ret:
            return self.nextFile()

        batch_of_frames = [transform(self.first_frame), transform(middle_frame), transform(last_frame)]
        # batch_of_frames = [self.first_frame, middle_frame, last_frame]
        self.frame_batch_buffer.append(batch_of_frames)
        self.first_frame = last_frame
        # cv2.imshow('batch_of_frames', middle_frame)
        # cv2.imshow('batch_of_frames', last_frame)
        # if cv2.waitKey(25) & 0xFF == ord('q'):
        #     cv2.destroyAllWindows()
        #     self.frame_batch_buffer = []
        return self.frame_batch_buffer.pop(0)


In [3]:
class Autoencoder(nn.Module):

    def __init__(self):
        super().__init__()

        self.lefthand_first = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True)
        )

        self.righthand_first = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True)
        )

        self.lefthand_second = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True)
        )

        self.righthand_second = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True)
        )

        self.lefthand_third = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True)
        )

        self.righthand_third = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True)
        )

        self.decoder_third = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(True)
        )

        self.decoder_second = nn.Sequential(
            nn.ConvTranspose2d(in_channels=144, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(True)
        )

        self.decoder_first = nn.Sequential(
            nn.ConvTranspose2d(in_channels=320, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(True),
            nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True)
        )


    def forward(self, left, right):
        left_first = self.lefthand_first(left)
        right_first = self.righthand_first(right)

        first_concat_256c = torch.cat((left_first, right_first), 0)

        left_second = self.lefthand_second(left_first)
        right_second = self.righthand_second(right_first)

        second_concat_128c = torch.cat((left_second, right_second), 0)

        left_third = self.lefthand_third(left_second)
        right_third = self.righthand_third(right_second)

        encoded = torch.cat((left_third, right_third), 0)

        third = self.decoder_third(encoded)

        second_skip = torch.cat((third, second_concat_128c), 0)

        second_no_skip = self.decoder_second(second_skip)

        first_skip = torch.cat((second_no_skip, first_concat_256c), 0)

        x = self.decoder_first(first_skip)
        return x


In [4]:
loss_fn = torch.nn.MSELoss()
# torch.manual_seed(0)

autoencoder = Autoencoder()
params_to_optimize = [
    {'params': autoencoder.parameters()}
]

optim = torch.optim.Adam(params_to_optimize, lr=0.0001, weight_decay=1e-08)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Selected device: {device}')

autoencoder = autoencoder.to(device)


Selected device: cuda


In [5]:
autoencoder.train()
epochs = 5
train_loss_epochs = []
for i in range(epochs):
    train_loss = []
    start = time.time()
    videoDataLoader = VideoDataLoader(VIDEOS_TRAIN)
    while videoDataLoader.hasNext():
        batch = videoDataLoader.getNext()
        first_frame = batch[0].to(device)
        middle_frame = batch[1].to(device)
        last_frame = batch[2].to(device)

        res_frame = autoencoder(first_frame, last_frame)
        # image_np = torchvision.transforms.functional.invert(res_frame.cpu().data).numpy()

        # Evaluate loss
        loss = loss_fn(res_frame, middle_frame)
        # Backward pass
        optim.zero_grad()
        loss.backward()
        optim.step()
        # Print batch loss
        print('\t %d partial train loss (single batch): %f' % (i, loss.data))
        train_loss.append(loss.detach().cpu().numpy())
    train_loss_epochs.append(train_loss)

	 0 partial train loss (single batch): 0.268362
	 0 partial train loss (single batch): 0.261266
	 0 partial train loss (single batch): 0.254208
	 0 partial train loss (single batch): 0.246720
	 0 partial train loss (single batch): 0.239176
	 0 partial train loss (single batch): 0.231392
	 0 partial train loss (single batch): 0.223537
	 0 partial train loss (single batch): 0.215440
	 0 partial train loss (single batch): 0.207270
	 0 partial train loss (single batch): 0.198940
	 0 partial train loss (single batch): 0.190604
	 0 partial train loss (single batch): 0.181665
	 0 partial train loss (single batch): 0.172875
	 0 partial train loss (single batch): 0.163741
	 0 partial train loss (single batch): 0.154588
	 0 partial train loss (single batch): 0.145054
	 0 partial train loss (single batch): 0.135885
	 0 partial train loss (single batch): 0.126661
	 0 partial train loss (single batch): 0.117473
	 0 partial train loss (single batch): 0.108362
	 0 partial train loss (single batch): 0

In [6]:
videoDataLoader = VideoDataLoader(VIDEOS_SKI_TEST)
fps = videoDataLoader.videos[0]["fps"]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# out = cv2.VideoWriter('24fpsFromModel.mp4', fourcc, fps, (width, height), isColor=True)
autoencoder.eval()
i = 0
with torch.no_grad():
    test_loss = []
    while videoDataLoader.hasNext():
        batch = videoDataLoader.getNext()
        first_f = np.transpose(batch[0].numpy(), (1, 2, 0))
        original_f = np.transpose(batch[1].numpy(), (1, 2, 0))

        first_frame = batch[0].to(device)
        middle_frame = batch[1].to(device)
        last_frame = batch[2].to(device)

        res_frame = autoencoder(first_frame, last_frame)
        # image_np = torchvision.transforms.functional.invert(res_frame.cpu().data).numpy()
        middle_f = np.transpose(res_frame.cpu().data.numpy(), (1, 2, 0))
        cv2.imshow('generated frame', cv2.hconcat([first_f, first_f]))
        cv2.imshow('generated frame', cv2.hconcat([first_f, middle_f]))
        # out.write(cv2.hconcat([first_f, first_f]))
        # out.write(cv2.hconcat([first_f, middle_f]))
        # out.write(first_f)
        # out.write(middle_f)
        if cv2.waitKey(25) & 0xFF == ord('q'):
            break

        loss = loss_fn(res_frame, middle_frame)
        print('\t %d partial test loss (single batch): %f' % (i, loss.data))
        i = i + 1
        test_loss.append(loss.detach().cpu().numpy())
    cv2.destroyAllWindows()
# out.release()

	 0 partial test loss (single batch): 0.000359
	 1 partial test loss (single batch): 0.000244
	 2 partial test loss (single batch): 0.000426
	 3 partial test loss (single batch): 0.000466
	 4 partial test loss (single batch): 0.000320
	 5 partial test loss (single batch): 0.000403
	 6 partial test loss (single batch): 0.000366
	 7 partial test loss (single batch): 0.000239
	 8 partial test loss (single batch): 0.000382
	 9 partial test loss (single batch): 0.000472
	 10 partial test loss (single batch): 0.000336
	 11 partial test loss (single batch): 0.000281
	 12 partial test loss (single batch): 0.000368
	 13 partial test loss (single batch): 0.000247
	 14 partial test loss (single batch): 0.000238
	 15 partial test loss (single batch): 0.000430
	 16 partial test loss (single batch): 0.000446
	 17 partial test loss (single batch): 0.000368
	 18 partial test loss (single batch): 0.000420
	 19 partial test loss (single batch): 0.000372
	 20 partial test loss (single batch): 0.000311
	 