In [None]:
!pip install wandb
# !wandb login --relogin

In [None]:
# Mount Google Drive; every dataset, pickle and checkpoint path below lives under /content/drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
!git clone https://github.com/shashimalcse/Gaze_Dual_Attn.git

In [None]:
# Change into the project directory, presumably so the local modules imported below
# (utils_logging, models, dataloader, training) resolve.
# NOTE(review): the clone cell above creates /content/Gaze_Dual_Attn, but this changes into
# /content/RetailGazeDataset/gaze/ — confirm this directory exists on a fresh runtime.
%cd /content/RetailGazeDataset/gaze/

In [None]:
import torch
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import argparse
import os
from datetime import datetime
import shutil
import numpy as np

from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import cv2

from utils_logging import setup_logger

In [None]:
!pip install pytorch_lightning
!pip install git+https://github.com/zhanghang1989/ResNeSt
!pip install timm
!pip install pytorchcvtcolor

In [None]:
from models.dual_attn import Dual_Attn,train_new,train_new_goo
from models.__init__ import save_checkpoint, resume_checkpoint
from dataloader.dual_attn import RetailGaze,GooDataset
from dataloader import chong_imutils
from training.train_chong import train, test, GazeOptimizer

In [None]:
# Logger handed to the training helpers further down. Messages are written to
# ./logs/train_chong_gooreal.log (verbose=True presumably also echoes them — see utils_logging).
# NOTE(review): the file name mentions "chong"/"gooreal" although the model trained here is Dual_Attn.
logger = setup_logger(
    name='first_logger',
    log_dir='./logs/',
    log_file='train_chong_gooreal.log',
    log_format='%(asctime)s %(levelname)s %(message)s',
    verbose=True,
)

In [None]:
# Dataloaders for Retail Gaze
# NOTE(review): the GOO loader cell rebinds the same names (train_set, train_data_loader,
# val_data_loader, test_data_loader, ...). Whichever loader cell ran last is what the training
# and test cells use.
batch_size = 8
# Number of DataLoader worker processes, used by every loader in this cell
# (1 is the value all loaders were actually running with).
workers = 1

images_dir = '/content/drive/MyDrive/RetailGaze/RetailGaze_V2/'
pickle_path = '/content/drive/MyDrive/RetailGaze/RetailGaze_V3_train2.pickle'
test_images_dir = '/content/drive/MyDrive/RetailGaze/RetailGaze_V2/'
test_pickle_path = '/content/drive/MyDrive/RetailGaze/RetailGaze_V3_test2.pickle'
val_images_dir = '/content/drive/MyDrive/RetailGaze/RetailGaze_V2/'
val_pickle_path = '/content/drive/MyDrive/RetailGaze/RetailGaze_V3_valid2.pickle'

eval_images_dir = '/content/drive/MyDrive/RetailGaze/eval_images/'
eval_pickle_path = '/content/drive/MyDrive/RetailGaze/072224.pickle'

train_set = RetailGaze(images_dir, pickle_path, 'train')
train_data_loader = DataLoader(dataset=train_set,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=workers)

test_set = RetailGaze(test_images_dir, test_pickle_path, 'test')
test_data_loader = DataLoader(test_set, batch_size=4,
                              shuffle=False, num_workers=workers)

# NOTE(review): the validation and eval sets are built in 'train' mode, presumably because the
# training helpers expect that sample format — confirm this mode does not enable augmentation.
val_set = RetailGaze(val_images_dir, val_pickle_path, 'train')
val_data_loader = DataLoader(val_set, batch_size=batch_size,
                             shuffle=False, num_workers=workers)

eval_set = RetailGaze(eval_images_dir, eval_pickle_path, 'train')
eval_data_loader = DataLoader(eval_set, batch_size=batch_size,
                              shuffle=False, num_workers=workers)

In [None]:
# Dataloaders for GOO
# NOTE(review): this cell rebinds the same names as the Retail Gaze loader cell (train_set,
# train_data_loader, val_data_loader, test_data_loader, ...). Whichever loader cell ran last is
# what the training and test cells use.
batch_size = 8
# Number of DataLoader worker processes, used by every loader in this cell
# (1 is the value all loaders were actually running with).
workers = 1

images_dir = '/content/drive/MyDrive/gooreal/finalrealdatasetImgsV2'
pickle_path = '/content/drive/MyDrive/gooreal/oneshotrealhumansNew.pickle'
test_images_dir = '/content/drive/MyDrive/gooreal/finalrealdatasetImgsV2'
test_pickle_path = '/content/drive/MyDrive/gooreal/testrealhumansNew.pickle'
val_pickle_path = '/content/drive/MyDrive/gooreal/valrealhumansNew.pickle'

train_set = GooDataset(images_dir, pickle_path, 'train')
train_data_loader = DataLoader(dataset=train_set,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=workers)
# NOTE(review): validation uses 'train' mode and is shuffled. Shuffling should not change aggregate
# validation metrics, but confirm 'train' mode does not enable augmentation.
val_set = GooDataset(images_dir, val_pickle_path, 'train')
val_data_loader = DataLoader(dataset=val_set,
                             batch_size=4,
                             shuffle=True,
                             num_workers=workers)
test_set = GooDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = DataLoader(test_set, batch_size=batch_size,
                              shuffle=False, num_workers=workers)

In [None]:
# Build Dual_Attn on the GPU when one is available, otherwise on the CPU, and create its optimizer.
model_ft = Dual_Attn()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_ft = model_ft.to(device)

# Put the loss on the same device as the model (a bare .cuda() raises on CPU-only runtimes).
# NOTE(review): criterion is not passed to the training helpers below; they presumably compute their own loss.
criterion = nn.MSELoss().to(device)

start_epoch = 0
max_epoch = 5  # NOTE(review): unused by the training calls below, which pass num_epochs explicitly
learning_rate = 1e-4

# GazeOptimizer is a project helper; getOptimizer(start_epoch) returns the optimizer used for training
# (presumably choosing parameter groups / LR for that epoch — confirm in training/train_chong).
gaze_opt = GazeOptimizer(model_ft, learning_rate)
optimizer = gaze_opt.getOptimizer(start_epoch)

In [None]:
# Train on GOO. train_new_goo presumably expects the GOO loaders, so run the GOO loader cell
# (not the Retail Gaze one) before this; both cells bind the same loader names.
goo_save_dir = '/content/drive/MyDrive/DualAttention/shashimal6_new_for_graphs'
model_ft = train_new_goo(
    model_ft,
    train_data_loader,
    val_data_loader,
    goo_save_dir,  # presumably where checkpoints / training graphs are written
    optimizer,
    logger,
    1,  # NOTE(review): meaning of this positional argument is defined in models.dual_attn — confirm
    num_epochs=50,
)

In [None]:
# Train on Retail Gaze. train_new presumably expects the Retail Gaze loaders, so re-run the Retail
# Gaze loader cell before this; the GOO loader cell rebinds the same loader names.
retail_save_dir = '/content/drive/MyDrive/DualAttention/shashimal6_new_retail_for_graphs'
model_ft = train_new(
    model_ft,
    train_data_loader,
    val_data_loader,
    retail_save_dir,  # presumably where checkpoints / training graphs are written
    optimizer,
    logger,
    1,  # NOTE(review): meaning of this positional argument is defined in models.dual_attn — confirm
    num_epochs=40,
)

In [None]:
# Restore model and optimizer state from a saved checkpoint (keys 'state_dict' and 'optimizer').
# map_location remaps tensors onto `device`, so a checkpoint saved on a GPU also loads on CPU-only runtimes.
checkpoint = torch.load("/content/drive/MyDrive/shashimal6_new2_goo_3.pt", map_location=device)
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

# Test Dual_Attn

In [None]:
# Evaluate Dual_Attn on test_data_loader. Printed, in order:
#   mean L2 distance and mean angular error (degrees) between the heatmap argmax and the GT point,
#   pixel-level AUC over the 64x64 heatmaps, then segment accuracy (%): how often the predicted
#   heatmap and a Gaussian drawn at the GT point overlap most with the same segmentation mask.
import glob
import os

from sklearn.metrics import roc_auc_score

OUTPUT_RES = 64   # side of the predicted heatmap grid; argmax indices are divided by this
IMAGE_RES = 224   # head coordinates are divided by this; heatmaps and masks are compared at this size


def _angular_error_deg(pred_point, gt_point, head_point):
    """Return the angle in degrees between head->prediction and head->ground-truth.

    The cosine is clipped to [-1, 1] before arccos so rounding cannot produce NaN. Clipping the
    arccos output instead would cap every error at 1 rad (~57.3 deg). Returns NaN when either
    direction vector has zero length, since the angle is then undefined.
    """
    pred_dir = pred_point - head_point
    gt_dir = gt_point - head_point
    denom = np.linalg.norm(pred_dir) * np.linalg.norm(gt_dir)
    if denom == 0:
        return float('nan')
    cos_sim = np.clip(np.dot(pred_dir, gt_dir) / denom, -1.0, 1.0)
    return float(np.degrees(np.arccos(cos_sim)))


def _best_segments(mask_paths, pred_map, gt_map):
    """Return (mask path best matching pred_map, mask path best matching gt_map).

    Each mask is read as grayscale and resized to IMAGE_RES x IMAGE_RES. Its score is
    sum(mask * heatmap). On ties the earlier path in mask_paths wins.

    Raises:
        FileNotFoundError: if a mask file cannot be read.
    """
    best_pred = best_gt = None
    best_pred_score = best_gt_score = float('-inf')
    for path in mask_paths:
        mask = cv2.imread(path, 0)
        if mask is None:
            raise FileNotFoundError(f'could not read segmentation mask {path}')
        mask = cv2.resize(mask, dsize=(IMAGE_RES, IMAGE_RES), interpolation=cv2.INTER_CUBIC)
        pred_score = np.sum(mask * pred_map)
        gt_score = np.sum(mask * gt_map)
        if pred_score > best_pred_score:
            best_pred, best_pred_score = path, pred_score
        if gt_score > best_gt_score:
            best_gt, best_gt_score = path, gt_score
    return best_pred, best_gt


model_ft.eval()
total_error = []   # per sample: [L2 distance, angular error in degrees]
b_acc = []         # per sample with masks: 1 if prediction and GT pick the same mask, else 0
all_predmap = []   # predicted heatmaps, flattened for AUC below
all_gtmap = []     # one-hot GT maps, flattened for AUC below
samples_without_masks = 0
with torch.no_grad():
    for img, face, head_channel, object_channel, fov, eye, head_for_mask, gt_label, gaze_heatmap, image_path in test_data_loader:
        heatmap = model_ft(img.to(device), face.to(device), object_channel.to(device),
                           head_for_mask.to(device), fov.to(device))
        heatmap = heatmap.squeeze(1).cpu().numpy()
        gt_label = gt_label.cpu().numpy()
        head = head_for_mask.cpu().numpy()
        for batch in range(img.shape[0]):
            output = heatmap[batch].clip(0)
            # presumably a normalized (x, y) in [0, 1]: it is compared with argmax / 64 below
            gt = gt_label[batch]
            head_position = head[batch] / [IMAGE_RES, IMAGE_RES]

            # Point metrics from the heatmap argmax, in normalized (x, y) coordinates.
            h_index, w_index = np.unravel_index(output.argmax(), output.shape)
            f_point = np.array([w_index / OUTPUT_RES, h_index / OUTPUT_RES])
            f_dist = np.linalg.norm(f_point - gt)
            ae = _angular_error_deg(f_point, gt, head_position)
            total_error.append([f_dist, ae])

            # Segment accuracy. The reference Gaussian is drawn at the ground-truth point;
            # drawing it at the predicted argmax would compare the prediction with itself.
            gt_gauss = chong_imutils.draw_labelmap(
                torch.zeros(OUTPUT_RES, OUTPUT_RES),
                [gt[0] * OUTPUT_RES, gt[1] * OUTPUT_RES], 3, type='Gaussian')
            gt_gauss = gt_gauss.cpu().numpy()
            mask_dir = os.path.join(os.path.dirname(image_path[batch]), 'masks')
            seg_masks = sorted(glob.glob(os.path.join(mask_dir, '*.png')))
            if seg_masks:
                p_heatmap = cv2.resize(output, dsize=(IMAGE_RES, IMAGE_RES), interpolation=cv2.INTER_CUBIC)
                g_heatmap = cv2.resize(gt_gauss, dsize=(IMAGE_RES, IMAGE_RES), interpolation=cv2.INTER_CUBIC)
                p_seg, g_seg = _best_segments(seg_masks, p_heatmap, g_heatmap)
                b_acc.append(int(p_seg == g_seg))
            else:
                # With no masks both picks would be None and count as a hit; exclude instead.
                samples_without_masks += 1

            # One-hot GT map for pixel-level AUC; clamp so a coordinate of exactly 1.0 lands in the last cell.
            g_heatmap_auc = np.zeros((OUTPUT_RES, OUTPUT_RES))
            x, y = np.clip((gt * OUTPUT_RES).astype(int), 0, OUTPUT_RES - 1)
            g_heatmap_auc[y, x] = 1.0
            all_predmap.append(output)
            all_gtmap.append(g_heatmap_auc)

# nanmean: the angular error is NaN when the predicted point coincides with the head position.
l2, ang = np.nanmean(np.array(total_error, dtype=float), axis=0)
all_predmap = np.stack(all_predmap).reshape([-1])
all_gtmap = np.stack(all_gtmap).reshape([-1])
auc = roc_auc_score(all_gtmap, all_predmap)
print(l2, ang, auc)
if b_acc:
    print((sum(b_acc) / len(b_acc)) * 100)
else:
    print('segment accuracy: no test sample had segmentation masks')
if samples_without_masks:
    print(f'{samples_without_masks} samples had no masks/*.png and were excluded from segment accuracy')