In [None]:
"""
My computer system specification
i9-10900X
RAM 64GB
Samsung SSD 970 PRO 512GB
RTX 3090 2 units

Windows 11
PyTorch 2.5.1
Anaconda3-2024.10-1-Windows-x86_64
cudnn-windows-x86_64-8.9.7.29_cuda12-archive
cuda_12.4.0_windows_network

Email: sjw007s@korea.ac.kr
"""
import os
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

import torch
import torch.nn.functional as F
import torchvision.transforms.v2 as transforms_v2
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_file, decode_jpeg, ImageReadMode


# Make CUDA device 0 the current device for this process, then report which
# device PyTorch considers active.
_gpu_index = 0
torch.cuda.set_device(_gpu_index)
print("GPU device currently in use:", torch.cuda.current_device())

# Parse the class-mapping text file from the ILSVRC kit into training targets.
def train_parse_mapping_file(mapping_file):
    """Return a dict mapping each class folder name to a zero-based class index.

    Each non-blank line of *mapping_file* must start with
    ``<folder> <one-based index>``; anything after the second whitespace-
    separated field (e.g. a human-readable class name) is ignored.
    Blank lines are skipped.

    Raises:
        ValueError: if a non-blank line has fewer than two fields, or its
            second field is not an integer.
    """
    class_to_idx = {}
    with open(mapping_file, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            # split() without a separator tolerates tabs / repeated spaces.
            fields = line.split(maxsplit=2)
            if not fields:
                continue  # blank line (e.g. a trailing newline at EOF)
            if len(fields) < 2:
                raise ValueError(
                    f"{mapping_file}:{line_no}: expected '<folder> <index> ...', got {line!r}"
                )
            folder, idx = fields[0], fields[1]
            # The file is one-based; training targets are zero-based.
            class_to_idx[folder] = int(idx) - 1
    return class_to_idx

# Parse the validation ground-truth file (one one-based class index per line).
def test_parse_mapping_file(mapping_file):
    """Return the zero-based validation labels in file order.

    Each non-blank line of *mapping_file* holds the one-based class index of
    the corresponding validation image. Blank lines (e.g. a trailing newline
    at EOF) are skipped.

    Raises:
        ValueError: if a non-blank line is not an integer.
    """
    labels = []
    with open(mapping_file, 'r', encoding='utf-8') as f:
        for line in f:
            number = line.strip()
            if not number:
                continue
            # The file is one-based; targets are zero-based.
            labels.append(int(number) - 1)
    return labels


# Despite the "test_" prefix this is not a pytest test; stop pytest from
# collecting it (it would fail looking for a `mapping_file` fixture).
test_parse_mapping_file.__test__ = False

# Training-time augmentation, applied to the decoded uint8 image tensor.
transform_train = transforms_v2.Compose([
    # Resize so the smaller edge is a random length drawn from [256, 481).
    transforms_v2.RandomResize(min_size=256, max_size=481),
    transforms_v2.RandomHorizontalFlip(p=0.5),
    # Crop BEFORE the dtype conversion and normalisation. Cropping is pure
    # slicing and both later steps are element-wise, so the result is
    # identical (and random draws still happen in the same order), but only
    # the 224x224 window is converted/normalised instead of the whole image.
    transforms_v2.RandomCrop((224, 224)),
    transforms_v2.ToDtype(torch.float32, scale=True),
    transforms_v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Evaluation pipeline: resize, centre-crop to 256x256, convert to float and
# normalise, then split into ten 224x224 crops (four corners + centre, and
# their horizontal flips). TenCrop returns a tuple of ten tensors, which
# test_collate stacks into one batch. TenCrop is kept last: v2 transforms
# handed a container of plain tensors only transform the first one.
_test_pipeline = [
    transforms_v2.Resize(256),
    transforms_v2.CenterCrop(256),
    transforms_v2.ToDtype(torch.float32, scale=True),
    transforms_v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms_v2.TenCrop(224),
]
transform_test = transforms_v2.Compose(_test_pipeline)

# Training dataset: one sub-folder per class, images decoded on the GPU.
class ImageNetDataset_train(Dataset):
    """ImageNet-style training split stored as ``root_dir/<class_folder>/<file>``.

    Args:
        root_dir: directory whose sub-directories are the class folders.
            Plain files directly inside ``root_dir`` are ignored.
        mapping_file: text file parsed by ``train_parse_mapping_file``; every
            class sub-directory must appear in it.
        transform: callable applied to the decoded uint8 CUDA image tensor.

    Attributes:
        img_paths: image paths, ordered by class folder name then file name.
        labels: int64 CUDA tensor of zero-based class indices aligned with
            ``img_paths`` (the one-hot vector is built per sample).

    ``__getitem__`` returns ``(transform(image), label)``, where ``label`` is
    a float32 one-hot CUDA tensor of length 1000.

    Raises (from ``__init__``):
        KeyError: if a class sub-directory is missing from the mapping file.
    """

    def __init__(self, root_dir, mapping_file, transform):
        self.root_dir = root_dir
        self.transform = transform
        self.class_to_idx = train_parse_mapping_file(mapping_file)
        self.img_paths = []
        self.labels = []

        class_folders = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        # Listing the folders is I/O work, so it is spread over threads.
        # list() drains executor.map: results come back in input order (so the
        # sample order is deterministic), and any exception raised inside a
        # worker is re-raised here instead of being silently discarded.
        with ThreadPoolExecutor(max_workers=40) as executor:
            per_folder = list(executor.map(self._scan_folder, class_folders))
        for samples in per_folder:
            for img_path, label in samples:
                self.img_paths.append(img_path)
                self.labels.append(label)

        # Keep only class indices on the GPU. A dense (N, 1000) float32
        # one-hot matrix would cost N * 4000 bytes of GPU memory up front;
        # the identical one-hot row is produced in __getitem__ instead.
        self.labels = torch.tensor(self.labels, dtype=torch.long, device='cuda')
        print("training dataset load complete")

    def _scan_folder(self, class_folder):
        """Return ``(img_path, label)`` pairs for every file in one class folder."""
        label = self.class_to_idx[class_folder]  # KeyError if the folder is unmapped
        folder_path = os.path.join(self.root_dir, class_folder)
        return [
            (os.path.join(folder_path, img_file), label)
            for img_file in sorted(os.listdir(folder_path))
        ]

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        # Float one-hot target of length 1000, on the same device as self.labels.
        label = F.one_hot(self.labels[idx], num_classes=1000).float()

        img_bytes = read_file(img_path)
        # Request RGB output so every image has exactly three channels. The
        # default (UNCHANGED) mode keeps the file's own channel count, e.g.
        # one channel for a grayscale JPEG.
        img_tensor = decode_jpeg(img_bytes, mode=ImageReadMode.RGB, device='cuda')

        img_tensor = self.transform(img_tensor)

        return img_tensor, label

# Validation ("test") dataset: flat image folder plus a ground-truth label file.
class ImageNetDataset_test(Dataset):
    """Validation split stored as image files directly inside ``root_dir``.

    Labels are matched to images purely by position: the k-th file in sorted
    filename order receives the k-th label from ``mapping_file`` (parsed by
    ``test_parse_mapping_file``). Sub-directories of ``root_dir`` are ignored.

    Attributes:
        img_paths: image paths in sorted filename order.
        labels: int64 CUDA tensor of zero-based class indices aligned with
            ``img_paths`` (the one-hot vector is built per sample).

    ``__getitem__`` returns ``(transform(image), label)``, where ``label`` is
    a float32 one-hot CUDA tensor of length 1000.

    Raises (from ``__init__``):
        ValueError: if the number of images and labels differ.
    """

    def __init__(self, root_dir, mapping_file, transform):
        self.root_dir = root_dir
        self.transform = transform
        self.img_paths = []
        self.labels = test_parse_mapping_file(mapping_file)
        self._scan_folder()
        # Alignment is positional, so a count mismatch would silently assign
        # labels to the wrong images; fail loudly instead.
        if len(self.img_paths) != len(self.labels):
            raise ValueError(
                f"found {len(self.img_paths)} images in {root_dir!r} but "
                f"{len(self.labels)} labels in {mapping_file!r}"
            )
        # Keep only class indices on the GPU; the one-hot row is built in
        # __getitem__ rather than materialising an (N, 1000) float32 matrix.
        self.labels = torch.tensor(self.labels, dtype=torch.long, device='cuda')
        print("test dataset load complete")

    def _scan_folder(self):
        """Collect the paths of the files in ``root_dir``, sorted by name."""
        for img_file in sorted(os.listdir(self.root_dir)):
            img_path = os.path.join(self.root_dir, img_file)
            if os.path.isfile(img_path):  # skip sub-directories
                self.img_paths.append(img_path)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        # Float one-hot target of length 1000, on the same device as self.labels.
        label = F.one_hot(self.labels[idx], num_classes=1000).float()

        img_bytes = read_file(img_path)
        # Request RGB output so every image has exactly three channels. The
        # default (UNCHANGED) mode keeps the file's own channel count.
        img_tensor = decode_jpeg(img_bytes, mode=ImageReadMode.RGB, device='cuda')

        img_tensor = self.transform(img_tensor)

        return img_tensor, label

def test_collate(batch):
    """Collate multi-crop samples (e.g. from TenCrop) into one flat batch.

    Args:
        batch: non-empty list of ``(crops, label)`` pairs, where ``crops`` is
            a sequence of equally shaped image tensors (the same count for
            every sample) and ``label`` is a tensor.

    Returns:
        ``(imgs, labels)``. ``imgs`` has shape ``(B * n_crops, *crop_shape)``,
        with the crops of sample ``i`` in rows ``i*n_crops .. i*n_crops+n_crops-1``.
        ``labels`` repeats each sample's label ``n_crops`` times so that row k
        of ``labels`` belongs to row k of ``imgs``.

    Works for any batch size (including a short final batch) and any crop
    size, unlike a hard-coded 50 x 10 x 3 x 224 x 224 layout.
    """
    crop_lists, labels = zip(*batch)
    n_crops = len(crop_lists[0])
    # (B, n_crops, C, H, W) -> (B * n_crops, C, H, W), sample-major order.
    imgs = torch.stack([torch.stack(tuple(crops)) for crops in crop_lists]).flatten(0, 1)
    labels = torch.repeat_interleave(torch.stack(labels), n_crops, dim=0)
    return imgs, labels


# Despite the "test_" prefix this is a DataLoader collate_fn, not a pytest
# test; stop pytest from collecting it (it would look for a `batch` fixture).
test_collate.__test__ = False

# --- Training split: class sub-folders + class-mapping file -------------------
train_dir = r"C:\Users\sjw00\OneDrive\Desktop\dataset\imagenet\ILSVRC2012_img_train"
train_mapping_file = r"C:\Users\sjw00\OneDrive\Desktop\dataset\imagenet\map_clsloc.txt"
trainset = ImageNetDataset_train(
    root_dir=train_dir,
    mapping_file=train_mapping_file,
    transform=transform_train,
)
train_dataloader = DataLoader(trainset, batch_size=512, shuffle=True)

# --- Validation split (used as the test set): flat folder + ground truth ------
test_dir = r"C:\Users\sjw00\OneDrive\Desktop\dataset\imagenet\ILSVRC2012_img_val"
test_mapping_file = r"C:\Users\sjw00\OneDrive\Desktop\dataset\imagenet\ILSVRC2012_validation_ground_truth.txt"
testset = ImageNetDataset_test(
    root_dir=test_dir,
    mapping_file=test_mapping_file,
    transform=transform_test,
)
test_dataloader = DataLoader(
    testset,
    batch_size=50,
    shuffle=False,
    collate_fn=test_collate,
)

# --- Smoke test: print the shapes of the first batch from each loader ---------
for loader in (train_dataloader, test_dataloader):
    for i, (batch, target) in enumerate(loader):
        print(i, batch.shape, target.shape)
        break
