In [6]:
import os
import time
import copy
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm
import cv2
from torchvision import transforms, datasets
from Datasets.train_data import SaliencyDataset

from utils.loss_function import SaliencyLoss
#from utils.data_process import MyDataset
from TranSalNet_Res import TranSalNet

In [7]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [8]:
train_dataset = SaliencyDataset(
    image_dir=r"E:\Digital Image Processing\Assignment 3\TranSalNet-Res\Datasets\Images\images",
    saliency_dir=r"E:\Digital Image Processing\Assignment 3\TranSalNet-Res\Datasets\TD_FixMaps",
    fixation_dir=r"E:\Digital Image Processing\Assignment 3\TranSalNet-Res\Datasets\TD_FixPts",  
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = TranSalNet()
model = model.to(device)

In [10]:
optimizer = optim.Adam(model.parameters(),lr=1e-5)
scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
loss_fn = SaliencyLoss()

save_dir = 'E:\Digital Image Processing\Assignment 3\TranSalNet-Res\Datasets\Saliency_map_TD'
os.makedirs(save_dir, exist_ok=True)

'''Training'''
best_model_wts = copy.deepcopy(model.state_dict())
num_epochs =10  
best_loss = 100
for k,v in model.named_parameters():
    print('{}: {}'.format(k, v.requires_grad))

total_loss = 0.0

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch + 1, num_epochs))
    print('-' * 10)
    model.train()
    running_loss = 0.0

    for images, TD_maps, TD_FixPts in train_loader:
        images, TD_maps, TD_FixPts = images.to(device), TD_maps.to(device), TD_FixPts.to(device)

        optimizer.zero_grad()
        outputs = model(images)  # 🔹 Predicted saliency map
        
        loss = -2*loss_fn(outputs,TD_maps,loss_type='cc')\
                        -1*loss_fn(outputs,TD_maps,loss_type='sim')+\
                        10*loss_fn(outputs,TD_maps,loss_type='kldiv')-1*loss_fn(outputs,TD_FixPts,loss_type='nss')
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")

backbone.0.weight: True
backbone.1.weight: True
backbone.1.bias: True
backbone.4.0.conv1.weight: True
backbone.4.0.bn1.weight: True
backbone.4.0.bn1.bias: True
backbone.4.0.conv2.weight: True
backbone.4.0.bn2.weight: True
backbone.4.0.bn2.bias: True
backbone.4.0.conv3.weight: True
backbone.4.0.bn3.weight: True
backbone.4.0.bn3.bias: True
backbone.4.0.downsample.0.weight: True
backbone.4.0.downsample.1.weight: True
backbone.4.0.downsample.1.bias: True
backbone.4.1.conv1.weight: True
backbone.4.1.bn1.weight: True
backbone.4.1.bn1.bias: True
backbone.4.1.conv2.weight: True
backbone.4.1.bn2.weight: True
backbone.4.1.bn2.bias: True
backbone.4.1.conv3.weight: True
backbone.4.1.bn3.weight: True
backbone.4.1.bn3.bias: True
backbone.4.2.conv1.weight: True
backbone.4.2.bn1.weight: True
backbone.4.2.bn1.bias: True
backbone.4.2.conv2.weight: True
backbone.4.2.bn2.weight: True
backbone.4.2.bn2.bias: True
backbone.4.2.conv3.weight: True
backbone.4.2.bn3.weight: True
backbone.4.2.bn3.bias: True
backb

In [12]:
import matplotlib.pyplot as plt
import numpy as np
import torch

model.eval()  # Set the model to evaluation mode to disable dropout

# Ensure the save directory exists
os.makedirs(save_dir, exist_ok=True)

# Get a batch from the validation set (or train set) to predict the final saliency map
with torch.no_grad():
    for idx, (images, saliency_maps, fixation_maps) in enumerate(train_loader):
        images = images.to(device)

        outputs = model(images)

        for i in range(images.size(0)):
            output_map = outputs[i].cpu().detach().numpy()
            output_map = np.squeeze(output_map)

            plt.imshow(output_map, cmap='jet')
            plt.axis('off')
            plt.colorbar()

            image_filename = f"{idx * len(images) + i + 1}_predicted_saliency_map_epoch_{epoch+1}.png"
            plt.savefig(os.path.join(save_dir, image_filename), bbox_inches='tight')
            plt.close()
