In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.utils import save_image

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import time
import os

from monodepthloss import MonodepthLoss
from depthnet import *

# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU so
# that device-aware code (`.to(DEVICE)`) still works on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [10]:
# Build the network, its optimizer and the training loss, all placed on DEVICE.
encoderdecoder = ResnetModel(3).to(DEVICE)

optimizer = optim.Adam(encoderdecoder.parameters(), lr=0.0001)
loss_function = MonodepthLoss(n=4, SSIM_w=0.85,
                              disp_gradient_w=0.1, lr_w=1).to(DEVICE)

# Resume from a previously saved checkpoint. `map_location` remaps the stored
# tensors onto DEVICE, so a checkpoint written on a different device still loads.
encoderdecoder.load_state_dict(
    torch.load('state_dicts/encoderdecoder3-1574426512', map_location=DEVICE))

# Memory-map every file in `directory` instead of reading it fully into RAM.
# - sorted(): os.listdir() returns entries in arbitrary order; sorting keeps the
#   position of each array in `data` stable between runs.
# - mmap_mode='r': the files are opened read-only so an accidental in-place
#   operation cannot modify the dataset on disk.
#   NOTE(review): assumes no other cell writes into these arrays - confirm.
# - NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
#   only point `directory` at files you trust.
data = []
directory = "numpy_img/"
for file in sorted(os.listdir(directory)):
    training_data = np.load(os.path.join(directory, file),
                            allow_pickle=True, mmap_mode='r')
    data.append(training_data)

In [11]:
testing = False  # not read in this cell; kept in case other cells check it
j = 0            # optimisation-step counter, shared across epochs
mean = []        # loss values collected since the last log line was written
n = 12           # number of consecutive frame indices grouped into one batch
epochs = 10

HISTORY = 2              # each sample also reads frames x-2 and x-1
LOG_EVERY = 100          # steps between averaged-loss log lines
CHECKPOINT_EVERY = 5000  # steps between saved state dicts


def _load_frames(sequence, indices, camera):
    """Fetch a batch of images and return it as a scaled NCHW float tensor.

    Args:
        sequence: one array from `data`, indexed as ``sequence[frames, camera]``.
        indices: list of frame indices to gather (fancy indexing copies them
            out of the memory map).
        camera: second-axis index; 0 is used for the ``Lin*`` tensors and 1 for
            ``RoutN`` (presumably left / right camera - verify against the
            dataset writer).

    Returns:
        A float32 tensor on DEVICE, divided by 255, permuted from
        channels-last to channels-first and viewed as (-1, 3, 256, 640).
    """
    frames = torch.from_numpy(sequence[indices, camera])
    frames = frames.to(DEVICE, dtype=torch.float32)
    return torch.div(frames, 255).permute(0, 3, 1, 2).view(-1, 3, 256, 640)


os.makedirs("logs", exist_ok=True)

# The context manager guarantees the log is closed even when training is
# interrupted (e.g. KeyboardInterrupt from the notebook's stop button).
with open(f"logs/results-{int(time.time())}.txt", "w+") as f:
    for epoch in range(epochs):
        samps = []
        print("Epoch: " + str(epoch))
        for img, training_data in enumerate(data):
            # Start at HISTORY so that x-2 and x-1 are never negative; negative
            # indices would silently wrap around to the end of the sequence.
            nums = list(range(HISTORY, len(training_data)))
            for i in range(0, len(nums), n):
                batch = nums[i:i + n]
                random.shuffle(batch)
                samps.append([img, batch])
        random.shuffle(samps)

        for seq_idx, batch in tqdm(samps):
            sequence = data[seq_idx]

            # Three consecutive camera-0 frames: x-2, x-1 and x.
            LinNM = _load_frames(sequence, [x - 2 for x in batch], 0)
            LinN = _load_frames(sequence, [x - 1 for x in batch], 0)
            LinNP = _load_frames(sequence, batch, 0)

            # NOTE(review): RoutN is read at index x while LinN, which it is
            # paired with in the loss below, is read at x-1. If both are meant
            # to show the same time step this should use x-1 - confirm.
            RoutN = _load_frames(sequence, batch, 1)

            encoderdecoder.zero_grad()
            output = encoderdecoder([LinNM, LinN, LinNP])
            loss = loss_function(output, [LinN, RoutN])
            loss.backward()
            # Step before logging/checkpointing so a saved state dict already
            # contains the update computed from this batch.
            optimizer.step()

            mean.append(loss.item())
            j += 1
            if j % LOG_EVERY == 0:
                f.write(f"{round(sum(mean)/len(mean),5)}\n")
                f.flush()
                mean = []
            if j % CHECKPOINT_EVERY == 0:
                thetime = int(time.time())
                torch.save(encoderdecoder.state_dict(),
                           f"state_dicts/encoderdecoder3-{thetime}")

                # Disabled preview dump, kept for reference:
                # save_image(RoutN[0,:,:,:].view(3,256,640).cpu(),
                #            f'imageout/{thetime}-left.png')
                # save_image(output[0][0,0,:,:].view(256,640).cpu(),
                #            f'imageout/{thetime}-depth.png')


  0%|          | 0/11277 [00:00<?, ?it/s][AEpoch: 0

  0%|          | 1/11277 [00:06<20:14:21,  6.46s/it][A
  0%|          | 2/11277 [00:12<20:06:28,  6.42s/it][A
  0%|          | 3/11277 [00:19<20:24:39,  6.52s/it][A
  0%|          | 4/11277 [00:26<20:22:18,  6.51s/it][A
  0%|          | 5/11277 [00:32<20:23:05,  6.51s/it][A

KeyboardInterrupt: 

In [0]:
# Visual check using `output` and `RoutN` left over from the last training step.

# First channel of the first sample in the first element of the model output.
prediction = output[0][0, 0, :, :].view(256, 640)
plt.imshow(prediction.detach().cpu().numpy(), cmap='gray')
plt.show()

# First image of the RoutN batch, moved back to channels-last for imshow.
reference = RoutN[0, :, :, :].view(3, 256, 640).permute(1, 2, 0)
plt.imshow(reference.cpu())
plt.show()