In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.utils import save_image

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import time
import os

from monodepthloss import MonodepthLoss
#from depthnet import *
from depthencoder import depthencoder 
from depth_decoder import *

# Device on which the model and loss are placed: the first CUDA GPU when one is
# available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Network being trained (the commented line is an earlier alternative model).
#encoderdecoder = ResnetModel(3).to(DEVICE)
encoderdecoder = DepthDecoder().to(DEVICE)

optimizer = optim.Adam(encoderdecoder.parameters(), lr=0.001)
loss_function = MonodepthLoss(n=4, SSIM_w=0.85, disp_gradient_w=0.1, lr_w=1).to(DEVICE)

# Uncomment to resume from a previously saved checkpoint.
#encoderdecoder.load_state_dict(torch.load('state_dicts/encoderdecoder-1579645491'))

# Memory-map every array file found in `directory`. The listing is sorted so
# that the position of each file inside `data` is the same on every run;
# os.listdir() alone returns entries in arbitrary order.
# NOTE(review): mmap_mode='r+' maps the files read-write, so in-place writes to
# these arrays would modify the files on disk.
# NOTE(review): allow_pickle=True permits loading pickled object arrays, which
# can execute arbitrary code - only point this at trusted files.
data = []
directory = "numpy_img/"
for file in sorted(os.listdir(directory)):
    numpy_file = np.load(os.path.join(directory, file), allow_pickle=True, mmap_mode='r+')
    data.append(numpy_file)


In [3]:
# Run configuration: frames per index set (chunk size used when batching) and
# the number of passes over the training index sets.
n, epochs = 12, 10

# Mutable training state: global step counter and the training losses
# accumulated since the last log line was written.
j, mean = 0, []

# Flag that is not read anywhere in the cells shown here.
testing = False

def test_model(testing_indeces):
    with torch.no_grad():
        mean = []
        for training_index in testing_indeces:
            random_num = random.randint(1,16) # Sample ~1000 samples (so variance matches between testing and training means)
            if random_num == 1:
                imageLEFT = torch.from_numpy(data[training_index[0]][training_index[1],0]).type(torch.cuda.FloatTensor)

                imageRIGHT = torch.from_numpy(data[training_index[0]][training_index[1],1]).type(torch.cuda.FloatTensor)

                inputLEFT = torch.div(imageLEFT, 255).permute(0,3,1,2)

                inputRIGHT = torch.div(imageRIGHT, 255).permute(0,3,1,2)

                output1 = encoderdecoder(inputLEFT.view(-1,3,256,640))
                output2 = encoderdecoder(inputRIGHT.view(-1,3,256,640))
                output = []
                for i in output1.keys():
                    output.insert(0, torch.cat((output1[i], output2[i]), 1))

                loss = loss_function(output,[inputLEFT.view(-1,3,256,640), inputRIGHT.view(-1,3,256,640)])
                mean.append(loss.item())
    return round(sum(mean)/len(mean),5)

# ---------------------------------------------------------------------------
# Training loop.
#
# Reads the module-level `data`, `encoderdecoder`, `optimizer`,
# `loss_function`, `n`, `epochs`, `j` and `mean`; writes one
# "<mean training loss>, <sampled test loss>" line to a log file every 10
# steps and a checkpoint every 10000 steps.
# ---------------------------------------------------------------------------

# Make sure the output locations exist before anything is written to them.
os.makedirs("logs", exist_ok=True)
os.makedirs("state_dicts", exist_ok=True)

# Fix the train/test split ONCE. Each array's frame numbers are shuffled and
# the chunks starting at or beyond 90% of the frames are held out for testing.
# Drawing the split again every epoch would let frames that were trained on in
# one epoch appear as "test" frames in the next, so the reported test loss
# would no longer measure unseen data.
train_frames = []     # per array: frame numbers available for training
testing_indeces = []  # fixed list of [array_number, frame_set]
for number, array in enumerate(data):
    frame_numbers = list(range(len(array)))
    random.shuffle(frame_numbers)
    total = len(frame_numbers)
    # First chunk start that falls into the last 10% (or `total` if none does).
    cut = next((i for i in range(0, total, n) if i >= total * .9), total)
    train_frames.append(frame_numbers[:cut])
    for i in range(cut, total, n):
        testing_indeces.append([number, frame_numbers[i:i+n]])

# The context manager closes the log file even if training is interrupted.
with open(f"logs/results-{int(time.time())}.txt", "w+") as f:
    for epoch in range(epochs):
        print("Epoch: "+ str(epoch))

        # Re-batch the training frames every epoch: shuffle within each array,
        # cut into sets of `n` frames, then shuffle the sets across arrays.
        training_indeces = []
        for number, frames in enumerate(train_frames):
            random.shuffle(frames)
            for i in range(0, len(frames), n):
                training_indeces.append([number, frames[i:i+n]])
        random.shuffle(training_indeces)
        # REMOVES 95% OF TRAINING DATA::: (TESTING)
        training_indeces = training_indeces[0:5000]

        for training_index in tqdm(training_indeces):
            frames = data[training_index[0]]
            frame_set = training_index[1]
            imageLEFT = torch.from_numpy(frames[frame_set, 0]).to(DEVICE, dtype=torch.float32)
            imageRIGHT = torch.from_numpy(frames[frame_set, 1]).to(DEVICE, dtype=torch.float32)

            # Scale to [0, 1] and move channels-last -> channels-first.
            inputLEFT = torch.div(imageLEFT, 255).permute(0, 3, 1, 2).view(-1, 3, 256, 640)
            inputRIGHT = torch.div(imageRIGHT, 255).permute(0, 3, 1, 2).view(-1, 3, 256, 640)

            encoderdecoder.zero_grad()

            output1 = encoderdecoder(inputLEFT)
            output2 = encoderdecoder(inputRIGHT)

            # Pair the left/right predictions per output key along the channel
            # axis; inserting at the front reverses the key order.
            output = []
            for key in output1.keys():
                output.insert(0, torch.cat((output1[key], output2[key]), 1))

            loss = loss_function(output, [inputLEFT, inputRIGHT])
            loss.backward()
            # Apply the update right away so that the test loss and the
            # checkpoint below reflect the weights after this step.
            optimizer.step()

            mean.append(loss.item())
            j += 1
            if j % 10 == 0:
                # Record the average training loss over the last 10 steps next
                # to a sampled test loss.
                trueloss = test_model(testing_indeces)
                f.write(f"{round(sum(mean)/len(mean),5)}, {trueloss}\n")
                f.flush()
                mean = []
            if j % 10000 == 0:
                thetime = int(time.time())
                torch.save(encoderdecoder.state_dict(), f"state_dicts/encoderdecoder-{thetime}")

Epoch: 0
  6%|▌         | 301/5000 [00:36<09:27,  8.27it/s]

In [0]:
def display_results(movie, frame):
    """Run the network on the left image of one frame.

    Args:
        movie: index into the module-level ``data`` list.
        frame: frame number inside that array.

    Returns:
        ``[imageLEFT, output]`` - the unscaled left image as a float tensor on
        ``DEVICE`` (channels-last, as stored) and the raw network output.
    """
    with torch.no_grad():
        imageLEFT = torch.from_numpy(data[movie][frame,0]).to(DEVICE, dtype=torch.float32)
        print(imageLEFT.shape)

        # The stored images are channels-last. They must be permuted to
        # channels-first exactly as in training; reinterpreting the memory
        # with .view() alone would feed the network scrambled pixels.
        batch = imageLEFT if imageLEFT.dim() == 4 else imageLEFT.unsqueeze(0)
        inputLEFT = torch.div(batch, 255).permute(0, 3, 1, 2)

        output = encoderdecoder(inputLEFT.view(-1,3,256,640))
    return [imageLEFT, output]

# Run the network on frame 6000 of array 1 and show the input next to the
# prediction.
result = display_results(1, 6000)
left_image, prediction = result

# First channel of the left input image, in grayscale.
first_channel = left_image[:, :, 0].view(256, 640)
plt.imshow(first_channel.cpu(), cmap='gray')
plt.show()

# First channel of the first sample of output entry 0.
# NOTE(review): presumably the full-resolution disparity map - confirm
# against the decoder's output keys.
predicted_map = prediction[0][0, 0, :, :].view(256, 640)
plt.imshow(predicted_map.cpu().detach().numpy())
plt.show()
