# In [None]:
"""
Author's system specification:
CPU: Intel Core i9-10900X
RAM: 64 GB
SSD: Samsung SSD 970 PRO 512 GB
GPU: 2 x NVIDIA RTX 3090

OS: Windows 11
PyTorch 2.5.1
Anaconda: Anaconda3-2024.10-1-Windows-x86_64
cuDNN: cudnn-windows-x86_64-8.9.7.29_cuda12-archive
CUDA: cuda_12.4.0_windows_network

Email: sjw007s@korea.ac.kr
"""
import os
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

import torch
import torch.nn.functional as F
import torchvision.transforms.v2 as transforms_v2
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_file, decode_jpeg, ImageReadMode


# GPU setting: make CUDA device 0 the current device, so the later
# device='cuda' / .to('cuda') calls in this file target it, then report it.
_GPU_INDEX = 0
torch.cuda.set_device(_GPU_INDEX)
_current_gpu = torch.cuda.current_device()
print("GPU device currently in use:", _current_gpu)

# Parsing a mapping file (reading a text file) from 2017 ILSVRC kit for target label
def train_parse_mapping_file(mapping_file):
    class_to_idx = {} # Dictionary to store class-to-index mappings
    with open(mapping_file, 'r') as f:
        for line in f:
            folder, idx, _ = line.strip().split(' ', 2) # Split each line by space into folder name and index
            class_to_idx[folder] = int(idx)-1   # Map the folder to its corresponding index (adjusted by -1 for zero-indexing)
    return class_to_idx

# Parsing validation ground truth file
def test_parse_mapping_file(mapping_file):
    class_to_idx = [] # List to store validation labels
    with open(mapping_file, 'r') as f:
        for line in f:
            number = line.strip() # Read each line and strip any extra whitespace
            class_to_idx.append(int(number)-1) # Append the class index to the list (adjusted by -1)
    return class_to_idx

# training data augmentation
transform_train = transforms_v2.Compose([
    transforms_v2.RandomResize(min_size=256, max_size=481), # Randomly resize image between 256 and 481 pixels
    transforms_v2.RandomHorizontalFlip(p=0.5), # 50% chance of horizontally flipping the image
    transforms_v2.ToDtype(torch.float32, scale=True),  # Convert image to float32 and scale to [0, 1]
    transforms_v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize with ImageNet mean and std
    transforms_v2.RandomCrop(224) # Randomly crop to 224x224
])

# test data augmentation
transform_test = transforms_v2.Compose([
    transforms_v2.Resize(256),  # Resize the shorter side to 256 pixels
    transforms_v2.CenterCrop(256), # Center crop the image to 256x256
    transforms_v2.ToDtype(torch.float32, scale=True),   # Convert to float32 and scale
    transforms_v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize
    transforms_v2.TenCrop(224) # Apply ten-crop augmentation (corner and center crops)
])

# training dataset
class ImageNetDataset_train(Dataset): 
    def __init__(self, root_dir, mapping_file, transform):
        self.root_dir = root_dir # Directory containing training images
        self.transform = transform # Transformations to apply to each image
        self.class_to_idx = train_parse_mapping_file(mapping_file) # Parse mapping file for class indices
        self.img_paths = [] # List to store image file paths
        self.labels = [] # List to store labels
        self.lock = Lock() # Lock to ensure thread safety during multithreading
        
        with ThreadPoolExecutor(max_workers=40) as executor: # Use ThreadPoolExecutor to scan the folder concurrently
            executor.map(self._scan_folder, os.listdir(root_dir))

        self.labels = torch.tensor(self.labels, dtype=torch.long) # Convert labels to tensor and apply one-hot encoding
        self.labels = F.one_hot(self.labels, num_classes=1000).float()

        print("training dataset load complete") # Print completion message

    def _scan_folder(self, class_folder): # Scan each folder for images and assign labels
        folder_path = os.path.join(self.root_dir, class_folder) # Get full folder path
        
        for img_file in os.listdir(folder_path): # Iterate through images in the folder
            img_path = os.path.join(folder_path, img_file) # Get full image path
            
            label = self.class_to_idx[class_folder] # Get label from class_to_idx mapping

            with self.lock: # Ensure thread safety
                self.img_paths.append(img_path) # Add image path to list
                self.labels.append(label) # Add label to list

    def __len__(self):  # Return the total number of images in the dataset
        return len(self.img_paths) 
    
    def __getitem__(self, idx): # Get image and label by index
        img_path = self.img_paths[idx] # Get image path
        label = self.labels[idx] # Get label

        img_bytes = read_file(img_path) # Read image as bytes
        img_tensor = decode_jpeg(img_bytes, device='cuda') # Decode JPEG and move to GPU
    
        img_tensor = self.transform(img_tensor) # Apply transformations to the image
        
        return img_tensor, label.to('cuda') # Return image and label (moved to GPU)

# test dataset
class ImageNetDataset_test(Dataset):
    def __init__(self, root_dir, mapping_file, transform):
        self.root_dir = root_dir
        self.transform = transform
        self.img_paths = []
        self.labels = test_parse_mapping_file(mapping_file)
        self._scan_folder()
        self.labels = torch.tensor(self.labels, dtype=torch.long)
        self.labels = F.one_hot(self.labels, num_classes=1000).float()
        
        print("test dataset load complete")

    def _scan_folder(self): 
        for img_file in sorted(os.listdir(self.root_dir)): # Scan images in sorted order
            img_path = os.path.join(self.root_dir, img_file)
            self.img_paths.append(img_path)

    def __len__(self): 
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        label = self.labels[idx]

        img_bytes = read_file(img_path)
        img_tensor = decode_jpeg(img_bytes, device='cuda')
        
        img_tensor = self.transform(img_tensor)
        
        return img_tensor, label.to('cuda')

def test_collate(batch): # Custom collate function for batching test data
    imgs, labels = zip(*batch)   # Unzip batch into images and labels
    imgs = list(imgs) # Convert to list for stacking
    for i in range(50):
        imgs[i] = torch.stack(imgs[i]) # Stack ten-crop images
    imgs = torch.stack(imgs) # Stack into final batch
    imgs = imgs.reshape(500, 3, 224, 224) # Reshape to batch size 500
    
    labels = torch.stack(labels) # Stack labels
    labels = torch.repeat_interleave(labels, 10, dim=0) # Repeat labels for ten-crop
    return imgs, labels

train_dir = r"C:\Users\sjw00\OneDrive\Desktop\dataset\imagenet\ILSVRC2012_img_train"  # training data location
train_mapping_file = r"C:\Users\sjw00\OneDrive\Desktop\dataset\imagenet\map_clsloc.txt"  # training data mapping file location
trainset = ImageNetDataset_train(root_dir=train_dir, mapping_file=train_mapping_file, transform=transform_train) 
train_dataloader = DataLoader(trainset, batch_size=512, shuffle=True)
################################################################
test_dir = r"C:\Users\sjw00\OneDrive\Desktop\dataset\imagenet\ILSVRC2012_img_val"  # test data location
test_mapping_file = r"C:\Users\sjw00\OneDrive\Desktop\dataset\imagenet\ILSVRC2012_validation_ground_truth.txt"  # test data target label location
testset = ImageNetDataset_test(root_dir=test_dir, mapping_file = test_mapping_file, transform=transform_test)  
test_dataloader = DataLoader(testset, batch_size=50, shuffle=False, collate_fn = test_collate) 
