In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import DataParallel

import time
import os
import numpy as np
import json
import cv2
from PIL import Image, ImageOps
import random
from tqdm import tqdm
import operator
import itertools
from scipy.io import  loadmat
import logging
from scipy import signal

from utils import data_transforms
from utils import get_paste_kernel, kernel_map
from utils_logging import setup_logger

from models.recasens import GazeNet
from dataloader.recasens import GooDataset, parse_GooPickle
from models.__init__ import save_checkpoint, resume_checkpoint
from training.train_recasens import test
#from training.train_recasens import Optimizer

In [4]:
# Logging configuration for this training run. Settings are gathered in one
# dict so they are easy to tweak; they are passed straight to
# utils_logging.setup_logger (log_dir/log_file presumably combine into
# ./logs/train_recasens.log — see setup_logger for the exact behaviour).
log_settings = dict(
    name='first_logger',
    log_dir='./logs/',
    log_file='train_recasens.log',
    log_format='%(asctime)s %(levelname)s %(message)s',
    verbose=True,
)
logger = setup_logger(**log_settings)

In [5]:
# Dataloaders
# Batch sizes and number of DataLoader worker processes.
batch_size = 32
workers = 12
testbatchsize = 16

# Root of the GOO-synth data. Overridable through the GOO_DATA_ROOT environment
# variable so the notebook is not tied to one machine; defaults to the original
# location. Directory paths keep their trailing separator (as in the original
# literals) in case the dataset concatenates them directly with file names.
goo_root = os.environ.get('GOO_DATA_ROOT', '/hdd/HENRI/goosynth')
images_dir = os.path.join(goo_root, '1person', 'GazeDatasets', '')
pickle_path = os.path.join(goo_root, 'picklefiles', 'trainpickle2to19human.pickle')
test_images_dir = os.path.join(goo_root, 'test', '')
test_pickle_path = os.path.join(goo_root, 'picklefiles', 'testpickle120.pickle')

# NOTE(review): the annotations are .pickle files; if GooDataset unpickles them,
# only point these paths at trusted files (unpickling can execute code).
train_set = GooDataset(images_dir, pickle_path, 'train')
train_data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=workers)

# Evaluation split: not shuffled, so evaluation order is deterministic.
val_set = GooDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = DataLoader(val_set, batch_size=testbatchsize, num_workers=workers, shuffle=False)

Total images in this set: 172800
Total images in this set: 19200


In [6]:
# Loads model
# GazeNet is built with the Places365 AlexNet checkpoint path (placesmodel_path)
# and moved to the GPU. nn.Module.cuda() returns the module itself, so `net` is
# the same object either way; the bare `net` on the last line displays the
# architecture. For multi-GPU training, `DataParallel(net)` can wrap it
# (currently not used).
places_weights_path = './alexnet_places365.pth'
net = GazeNet(placesmodel_path=places_weights_path).cuda()
net

GazeNet(
  (salpath): AlexSal(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (relu): ReLU()
    (sigmoid): Sigmoid()
    (conv6): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
  )
  (gazepath): AlexGaze(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace)
   

In [7]:
# Optimisation schedule and optional resumption from a saved checkpoint.
start_epoch = 0
max_epoch = 25
learning_rate = 0.0001

optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=1e-6)

resume_training = True
# Running `test` right after loading iterates the whole test loader (slow);
# set to False to resume without the immediate evaluation.
evaluate_on_resume = True
resume_path = './saved_models/temp/gazenet_gazefollow_0.1647_11epoch_0modIdx.pth.tar'
if resume_training:
    net, optimizer = resume_checkpoint(net, optimizer, resume_path)
    if evaluate_on_resume:
        test(net, test_data_loader, logger)
    # Resume from the epoch stored in the checkpoint rather than a hard-coded
    # number: a value equal to max_epoch makes range(start_epoch, max_epoch)
    # empty, so no further training would happen. The training loop calls
    # save_checkpoint with `epoch + 1`, so the stored value is the next epoch
    # to run.
    # NOTE(review): assumes the checkpoint is a dict with an 'epoch' entry (as
    # resume_checkpoint's "(epoch N)" message suggests) — confirm. The file is
    # reloaded on CPU only to read that value.
    start_epoch = torch.load(resume_path, map_location='cpu')['epoch']

=> loading checkpoint './saved_models/temp/gazenet_gazefollow_0.1647_11epoch_0modIdx.pth.tar'


  0%|          | 0/1200 [00:00<?, ?it/s]

=> loaded checkpoint './saved_models/temp/gazenet_gazefollow_0.1647_11epoch_0modIdx.pth.tar' (epoch 12)


  "See the documentation of nn.Upsample for details.".format(mode))
 14%|█▍        | 173/1200 [01:17<07:38,  2.24it/s]


KeyboardInterrupt: 

In [6]:
# Training loop: for each epoch, take the optimizer from staged_opt.update,
# train one epoch, checkpoint near the end of the schedule, then evaluate.
# NOTE(review): `train` and `StagedOptimizer` are not imported in the imports
# cell (only `test` is imported from training.train_recasens and the
# `Optimizer` import there is commented out), so this cell raises NameError on
# a fresh kernel — add the correct import.
# NOTE(review): StagedOptimizer is built from `net` and `learning_rate` only, so
# the optimizer state restored by resume_checkpoint is not passed to it —
# confirm that discarding it is intended.
staged_opt = StagedOptimizer(net, learning_rate)
save_path = './saved_models/temp/'

for epoch in range(start_epoch, max_epoch):
    optimizer = staged_opt.update(epoch)
    train(net, train_data_loader, optimizer, epoch, logger)

    # Checkpoint only the final four epochs (equivalent to epoch > max_epoch - 5);
    # the epoch number handed to save_checkpoint is epoch + 1.
    epochs_remaining = max_epoch - epoch
    if epochs_remaining < 5:
        save_checkpoint(net, optimizer, epoch + 1, save_path)

    test(net, test_data_loader, logger)

  1%|          | 32/3924 [00:14<29:58,  2.16it/s] 


KeyboardInterrupt: 