### 1. Dependencies

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import DataParallel

import time
import os
import numpy as np
import json
import cv2
from PIL import Image, ImageOps
import random
from tqdm import tqdm
import operator
import itertools
from scipy.io import  loadmat
import logging
from scipy import signal

from utils import data_transforms
from utils import get_paste_kernel, kernel_map
from utils_logging import setup_logger

### 2. Choose between Recasens or GazeNet

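- The notebook is written so that the two models are interchangeable. To use Recasens, import from
`models.recasens`, `dataloader.recasens`, `training.train_recasens`, etc.
- To use GazeNet (as done below), import instead from
`models.gazenet`, `dataloader.gazenet`, `training.train_gazenet`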

In [2]:
from models.gazenet import GazeNet
from models.__init__ import save_checkpoint, resume_checkpoint
from dataloader.gazenet import GooDataset, GazeDataset
from training.train_gazenet import train, test, GazeOptimizer

In [3]:
# Training and test errors are written to a .log file through this logger
logger = setup_logger(
    name='first_logger',
    log_dir='./logs/',
    log_file='train_gazenet.log',
    log_format='%(asctime)s %(levelname)s %(message)s',
    verbose=True,
)

### 3. Dataloaders
- Choose between GazeDataset (Gazefollow dataset) or GooDataset (GooSynth/GooReal)
- Set the paths to the image directories and to the annotation files (the `pickle_path` variables). For Gazefollow, `images_dir` and `test_images_dir` should be the same, both pointing to the directory that contains the train and test folders.

In [4]:
# GazeFollow: build the train / test datasets and wrap each one in a DataLoader
batch_size = 32
workers = 12
testbatchsize = 16

# Image roots and annotation files (train and test share the same image root)
images_dir = '/home/eee198/Documents/datasets/GazeFollowData/'
pickle_path = '/home/eee198/Documents/datasets/GazeFollowData/train_annotations.mat'
test_images_dir = '/home/eee198/Documents/datasets/GazeFollowData/'
test_pickle_path = '/home/eee198/Documents/datasets/GazeFollowData/test_annotations.mat'

# Training split: shuffled batches
train_set = GazeDataset(images_dir, pickle_path, 'train')
train_data_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

# Test split: fixed order
val_set = GazeDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = DataLoader(
    val_set,
    batch_size=testbatchsize,
    shuffle=False,
    num_workers=workers,
)

In [4]:
# GOO: build the train / test datasets and wrap each one in a DataLoader
batch_size = 32
workers = 12
testbatchsize = 32

# Image roots and pickle annotation files for the train and test splits
images_dir = '/hdd/HENRI/goosynth/1person/GazeDatasets/'
pickle_path = '/hdd/HENRI/goosynth/picklefiles/trainpickle2to19human.pickle'
test_images_dir = '/hdd/HENRI/goosynth/test/'
test_pickle_path = '/hdd/HENRI/goosynth/picklefiles/testpickle120.pickle'

# Training split: shuffled batches
train_set = GooDataset(images_dir, pickle_path, 'train')
train_data_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

# Test split: fixed order
val_set = GooDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = DataLoader(
    val_set,
    batch_size=testbatchsize,
    shuffle=False,
    num_workers=workers,
)

172800
19200


### 4. Load Model and Set Training Hyperparameters
- For Gazefollow, the model requires the alexnet_places365 pretrained model, provided here: https://urlzs.com/ytKK3
- When resuming training, set `resume_training` to True and set `resume_path` to the path of the saved model.
- The logger initialized earlier (`logger`) is used here to save training and testing errors.

In [5]:
# Instantiate the network and move it to the GPU
net = GazeNet()
net.cuda()

# Training schedule
start_epoch = 0
max_epoch = 25
learning_rate = 0.0001

# Optimizer wrapper, plus the optimizer it hands out for the starting epoch
gaze_opt = GazeOptimizer(net, learning_rate)
optimizer = gaze_opt.getOptimizer(start_epoch)

# Resume switch: when True, the checkpoint at resume_path replaces the model,
# optimizer and starting epoch, and the restored model is evaluated right away.
# The same switch therefore doubles as a way to evaluate a saved model.
resume_training = True
resume_path = './saved_models/gazenet_goo/model_epoch25.pth.tar'
if resume_training:
    net, optimizer, start_epoch = resume_checkpoint(net, optimizer, resume_path)
    test(net, test_data_loader, logger)

  0%|          | 0/600 [00:00<?, ?it/s]

=> loading checkpoint './saved_models/gazenet_goo/model_epoch25.pth.tar'
=> loaded checkpoint './saved_models/gazenet_goo/model_epoch25.pth.tar' (epoch 25)


loss: 0.08515, 0.14883, 0.23398
  0%|          | 1/600 [00:13<2:17:54, 13.81s/it]loss: 0.06691, 0.10966, 0.17657
  0%|          | 2/600 [00:13<1:36:52,  9.72s/it]loss: 0.07236, 0.07628, 0.14864
  0%|          | 3/600 [00:14<1:08:10,  6.85s/it]loss: 0.06304, 0.14304, 0.20608
  1%|          | 4/600 [00:14<48:06,  4.84s/it]  loss: 0.06044, 0.09855, 0.15899
  1%|          | 5/600 [00:15<36:30,  3.68s/it]loss: 0.06605, 0.10682, 0.17287
  1%|          | 6/600 [00:15<26:02,  2.63s/it]loss: 0.06807, 0.04573, 0.11379
  1%|          | 7/600 [00:15<18:43,  1.90s/it]loss: 0.08135, 0.23735, 0.31869
  1%|▏         | 8/600 [00:15<13:41,  1.39s/it]loss: 0.06671, 0.09293, 0.15964
  2%|▏         | 9/600 [00:17<13:24,  1.36s/it]loss: 0.07307, 0.13993, 0.21300
  2%|▏         | 10/600 [00:17<09:52,  1.00s/it]loss: 0.07696, 0.16930, 0.24626
  2%|▏         | 11/600 [00:17<07:31,  1.31it/s]loss: 0.07425, 0.09283, 0.16708
  2%|▏         | 12/600 [00:17<05:48,  1.69it/s]loss: 0.05499, 0.04176, 0.09674
  2%|▏   

 17%|█▋        | 103/600 [01:38<03:48,  2.18it/s]loss: 0.06452, 0.11493, 0.17945
 17%|█▋        | 104/600 [01:38<03:06,  2.66it/s]loss: 0.05939, 0.13521, 0.19461
 18%|█▊        | 105/600 [01:38<02:41,  3.07it/s]loss: 0.06041, 0.08667, 0.14707
 18%|█▊        | 106/600 [01:38<02:21,  3.50it/s]loss: 0.06664, 0.13968, 0.20632
 18%|█▊        | 107/600 [01:38<02:06,  3.89it/s]loss: 0.07950, 0.21536, 0.29487
 18%|█▊        | 108/600 [01:39<01:52,  4.35it/s]loss: 0.07573, 0.09895, 0.17468
 18%|█▊        | 109/600 [01:46<19:01,  2.33s/it]loss: 0.06432, 0.14244, 0.20676
 18%|█▊        | 110/600 [01:46<13:56,  1.71s/it]loss: 0.06216, 0.07192, 0.13407
 18%|█▊        | 111/600 [01:48<15:33,  1.91s/it]loss: 0.05721, 0.10703, 0.16423
 19%|█▊        | 112/600 [01:49<11:33,  1.42s/it]loss: 0.05554, 0.07878, 0.13431
 19%|█▉        | 113/600 [01:49<08:36,  1.06s/it]loss: 0.08457, 0.11060, 0.19516
 19%|█▉        | 114/600 [01:49<06:41,  1.21it/s]loss: 0.07957, 0.22934, 0.30891
 19%|█▉        | 115/600 [01

 34%|███▍      | 204/600 [02:59<04:09,  1.59it/s]loss: 0.07283, 0.17488, 0.24772
 34%|███▍      | 205/600 [02:59<03:21,  1.96it/s]loss: 0.07090, 0.06312, 0.13403
 34%|███▍      | 206/600 [03:00<02:50,  2.31it/s]loss: 0.08098, 0.22499, 0.30596
 34%|███▍      | 207/600 [03:03<08:45,  1.34s/it]loss: 0.06719, 0.12874, 0.19593
 35%|███▍      | 208/600 [03:03<06:29,  1.01it/s]loss: 0.07595, 0.14367, 0.21962
 35%|███▍      | 209/600 [03:04<04:59,  1.30it/s]loss: 0.08239, 0.14891, 0.23130
 35%|███▌      | 210/600 [03:04<03:57,  1.64it/s]loss: 0.06783, 0.06907, 0.13690
 35%|███▌      | 211/600 [03:04<03:10,  2.04it/s]loss: 0.08996, 0.17247, 0.26242
 35%|███▌      | 212/600 [03:04<02:38,  2.45it/s]loss: 0.07069, 0.09668, 0.16737
 36%|███▌      | 213/600 [03:04<02:18,  2.80it/s]loss: 0.07445, 0.12542, 0.19987
 36%|███▌      | 214/600 [03:08<07:54,  1.23s/it]loss: 0.07743, 0.17231, 0.24974
 36%|███▌      | 215/600 [03:08<05:49,  1.10it/s]loss: 0.06550, 0.13328, 0.19878
 36%|███▌      | 216/600 [03

 51%|█████     | 305/600 [04:21<03:06,  1.58it/s]loss: 0.07474, 0.11269, 0.18743
 51%|█████     | 306/600 [04:22<02:34,  1.90it/s]loss: 0.07043, 0.08093, 0.15136
 51%|█████     | 307/600 [04:22<02:14,  2.18it/s]loss: 0.06955, 0.06174, 0.13129
 51%|█████▏    | 308/600 [04:22<01:50,  2.64it/s]loss: 0.07920, 0.16022, 0.23942
 52%|█████▏    | 309/600 [04:22<01:35,  3.05it/s]loss: 0.07561, 0.08462, 0.16023
 52%|█████▏    | 310/600 [04:24<03:08,  1.53it/s]loss: 0.06521, 0.07180, 0.13700
 52%|█████▏    | 311/600 [04:24<02:28,  1.94it/s]loss: 0.07045, 0.07607, 0.14651
 52%|█████▏    | 312/600 [04:24<02:04,  2.31it/s]loss: 0.05692, 0.10296, 0.15988
 52%|█████▏    | 313/600 [04:30<09:48,  2.05s/it]loss: 0.06631, 0.05826, 0.12457
 52%|█████▏    | 314/600 [04:30<07:12,  1.51s/it]loss: 0.07672, 0.17534, 0.25207
 52%|█████▎    | 315/600 [04:30<05:20,  1.12s/it]loss: 0.07190, 0.06497, 0.13687
 53%|█████▎    | 316/600 [04:31<04:00,  1.18it/s]loss: 0.07356, 0.29500, 0.36856
 53%|█████▎    | 317/600 [04

 68%|██████▊   | 406/600 [05:41<04:33,  1.41s/it]loss: 0.07221, 0.08191, 0.15413
 68%|██████▊   | 407/600 [05:41<03:22,  1.05s/it]loss: 0.06476, 0.07089, 0.13565
 68%|██████▊   | 408/600 [05:41<02:34,  1.25it/s]loss: 0.07061, 0.16508, 0.23568
 68%|██████▊   | 409/600 [05:43<03:05,  1.03it/s]loss: 0.06935, 0.07238, 0.14173
 68%|██████▊   | 410/600 [05:43<02:24,  1.31it/s]loss: 0.08027, 0.12821, 0.20849
 68%|██████▊   | 411/600 [05:43<01:57,  1.61it/s]loss: 0.08502, 0.20458, 0.28960
 69%|██████▊   | 412/600 [05:43<01:32,  2.03it/s]loss: 0.07986, 0.12210, 0.20196
 69%|██████▉   | 413/600 [05:44<01:17,  2.40it/s]loss: 0.08273, 0.19412, 0.27685
 69%|██████▉   | 414/600 [05:44<01:04,  2.89it/s]loss: 0.07448, 0.21907, 0.29355
 69%|██████▉   | 415/600 [05:44<00:54,  3.37it/s]loss: 0.06272, 0.06358, 0.12630
 69%|██████▉   | 416/600 [05:44<00:48,  3.77it/s]loss: 0.07402, 0.07353, 0.14755
 70%|██████▉   | 417/600 [05:44<00:45,  4.01it/s]loss: 0.07208, 0.18812, 0.26021
 70%|██████▉   | 418/600 [05

 84%|████████▍ | 507/600 [06:58<00:48,  1.94it/s]loss: 0.07440, 0.09700, 0.17140
 85%|████████▍ | 508/600 [06:58<00:38,  2.37it/s]loss: 0.07624, 0.06374, 0.13998
 85%|████████▍ | 509/600 [06:59<00:45,  2.00it/s]loss: 0.06163, 0.15636, 0.21798
 85%|████████▌ | 510/600 [06:59<00:43,  2.09it/s]loss: 0.06922, 0.17370, 0.24292
 85%|████████▌ | 511/600 [07:00<00:36,  2.44it/s]loss: 0.07002, 0.05705, 0.12707
 85%|████████▌ | 512/600 [07:00<00:31,  2.77it/s]loss: 0.06985, 0.06850, 0.13835
 86%|████████▌ | 513/600 [07:00<00:27,  3.16it/s]loss: 0.08163, 0.14323, 0.22486
 86%|████████▌ | 514/600 [07:09<04:05,  2.85s/it]loss: 0.08132, 0.11958, 0.20091
 86%|████████▌ | 515/600 [07:09<02:56,  2.07s/it]loss: 0.07281, 0.12339, 0.19620
 86%|████████▌ | 516/600 [07:09<02:08,  1.53s/it]loss: 0.07338, 0.13682, 0.21020
 86%|████████▌ | 517/600 [07:10<01:33,  1.13s/it]loss: 0.07798, 0.15798, 0.23596
 86%|████████▋ | 518/600 [07:10<01:10,  1.16it/s]loss: 0.05337, 0.04003, 0.09340
 86%|████████▋ | 519/600 [07

In [6]:
# One-off conversion: load weights stored in a .pkl file (expected to be a mapping
# of parameter name -> tensor) into `net`, then re-save them through save_checkpoint.

# map_location='cpu' keeps the load independent of the GPU the file was saved from;
# load_state_dict below copies the values into net's existing parameters.
pretrained_dict = torch.load('./saved_models/gazenet_goo/model_epoch25.pkl',
                             map_location='cpu')
model_dict = net.state_dict()

# Keep only entries whose names exist in the current model, but report what is
# dropped or left untouched instead of discarding it silently.
skipped_keys = sorted(k for k in pretrained_dict if k not in model_dict)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
missing_keys = sorted(k for k in model_dict if k not in pretrained_dict)

if not pretrained_dict:
    # Nothing matched (e.g. differently prefixed parameter names): saving now would
    # write the unchanged model under the converted checkpoint's name.
    raise RuntimeError(
        'No parameter names in the .pkl file match the model; refusing to save. '
        'Unmatched names start with: {}'.format(skipped_keys[:5])
    )
if skipped_keys:
    print('=> ignoring {} entries not present in the model: {}'.format(
        len(skipped_keys), skipped_keys))
if missing_keys:
    print('=> {} model entries keep their current values: {}'.format(
        len(missing_keys), missing_keys))

model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

# Write the converted weights (with the current optimizer) as the epoch-25 checkpoint
save_path = './saved_models/gazenet_goo/'
save_checkpoint(net, optimizer, 25, save_path)

### 5. Training the Model
- Decide in which epochs you want to save the model, as you might not want to save every epoch
- Training and test errors can be accessed in the logs directory set up earlier

In [6]:
for epoch in range(start_epoch, max_epoch):

    # Ask the optimizer wrapper for the optimizer to use in this epoch
    optimizer = gaze_opt.getOptimizer(epoch)

    # Run training for this epoch
    train(net, train_data_loader, optimizer, epoch, logger)

    # Checkpoint model and optimizer only near the end of the schedule
    # (epoch > max_epoch - 5, i.e. at most the last four epochs)
    if epoch + 5 > max_epoch:
        save_path = './saved_models/temp/'
        save_checkpoint(net, optimizer, epoch + 1, save_path)

    # Evaluate after every epoch
    test(net, test_data_loader, logger)

  2%|▏         | 98/5400 [01:54<2:46:40,  1.89s/it]2.9881935048103334
  3%|▎         | 156/5400 [02:50<1:35:42,  1.10s/it]


KeyboardInterrupt: 