In [1]:
import sys
import os 

# Point both the import path and the working directory at the parent directory
# (where `src/` and the `data/` folder are expected to live).
# NOTE: both paths are relative, so this cell is not idempotent: every re-run
# moves the working directory up one more level.
# NOTE(review): a relative sys.path entry is resolved against the current working
# directory at import time, i.e. after the chdir below -- confirm this is intended.
sys.path.append("../")
os.chdir("../")

from tqdm import tqdm 
import torch.utils.data
import torch.nn.functional as F
import numpy as np
import torchvision

from copy import deepcopy

import torch
import torchvision
from PIL import Image
from src.utils import init_object

In [2]:
# check accuracy function
def check(model, loader, device=None):
    """Compute the top-1 accuracy of ``model`` over the batches of ``loader``.

    The model is switched to eval mode (and left in that mode) and the whole
    pass runs under ``torch.inference_mode()``.

    Args:
        model (torch.nn.Module): classifier whose output holds one score per
            class along the last dimension.
        loader: iterable of batches; each batch is a dict with an "image"
            tensor and a "label" tensor.
        device (optional): device the batches are moved to. When omitted, the
            device of the model's first parameter is used; a model without
            parameters falls back to "cuda", the previously hard-coded value.

    Returns:
        float: correctly classified samples divided by the number of samples
            the loader actually yielded; 0.0 when the loader yields nothing.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cuda"

    correct = 0
    seen = 0
    with torch.inference_mode():
        model.eval()
        for batch in loader:
            images = batch["image"].to(device).float()
            labels = batch["label"].to(device).long()
            logits = model(images)
            # softmax is monotonic, so the argmax of the raw logits is the prediction.
            predictions = logits.argmax(dim=-1)
            correct += (predictions == labels).sum().item()
            # Count what was really evaluated instead of trusting len(loader.dataset),
            # which is wrong for drop_last / sampler-restricted loaders.
            seen += labels.shape[0]
    return correct / seen if seen else 0.0

In [3]:
# Global access tally. This placeholder is rebound to a zero-filled NumPy array
# when a PACS dataset is built, and PACS.__getitem__ increments the slot of
# every sample it serves.
inds = 0

In [4]:
class PACS(torch.utils.data.Dataset):
    """Torch dataset over the PACS image-classification data.

    Image paths and labels are read from ``data/pacs/labels/{domain}_{split}.txt``
    (relative to the current working directory); the images themselves are only
    opened in ``__getitem__``. Dataset paper: https://arxiv.org/abs/1710.03077.

    Every ``__getitem__`` call is tallied in ``self.access_counts``, a NumPy array
    with one slot per sample of the whole dataset. The same array is published
    as the module-level global ``inds`` so that code reading ``inds`` keeps
    working; the global always refers to the most recently built dataset.
    """

    # torchvision only appears in type hints in this class, so those hints are
    # quoted and never evaluated when the class is defined.
    def __init__(
            self,
            dataset_type: list[str],
            domain_list: list[str],
            transforms: "torchvision.transforms.Compose",
            augmentations: "torchvision.transforms.Compose" = None) -> None:
        """Collect image paths and labels for every requested domain and split.

        Args:
            dataset_type (list[str]): split names used in the label file names,
                e.g. values from {'train', 'test'}. A single string is accepted too.
            domain_list (list[str]): domain names used in the label file names,
                e.g. values from {'art_painting', 'cartoon', 'photo', 'sketch'}.
                A single string is accepted too.
            transforms (torchvision.transforms.Compose): callable applied to
                every image.
            augmentations (torchvision.transforms.Compose, optional): callable
                applied to the image before ``transforms`` when provided.
                Defaults to None.
        """
        global inds

        self.domain_list = self._as_list(domain_list)
        self.images = []
        # Seed with an empty float tensor so torch.cat also works for zero domains.
        label_chunks = [torch.Tensor([])]
        for domain in self.domain_list:
            domain_paths, domain_labels = self.get_paths_and_labels(dataset_type, domain)
            self.images += domain_paths
            label_chunks.append(domain_labels)
        self.labels = torch.cat(label_chunks)

        self.transforms = transforms
        self.augmentations = augmentations

        # The counter must cover the whole dataset. Sizing it per domain (as was
        # done before) made indices of earlier domains fall out of range.
        self.access_counts = np.zeros(len(self.images))
        inds = self.access_counts

    @staticmethod
    def _as_list(value) -> list:
        """Wrap a lone string in a list so it is not iterated character by character."""
        return [value] if isinstance(value, str) else list(value)

    def get_paths_and_labels(self,
                             dataset_types: list[str],
                             domain: str) -> tuple[list[str],
                                                   torch.Tensor]:
        """Read the label files of one domain for the given splits.

        Each non-blank line of a label file must be ``<image path> <integer label>``;
        the label is the last whitespace-separated field.

        Args:
            dataset_types (list[str]): split names, e.g. values from {'train', 'test'}.
            domain (str): domain name, e.g. one of 'art_painting', 'cartoon',
                'photo', 'sketch'.

        Returns:
            tuple[list[str], torch.Tensor]: image paths and a float tensor with
                the class labels, in file order.

        Raises:
            ValueError: if a non-blank line does not hold a path and an integer label.
        """
        paths = []
        labels = []
        for ds_type in self._as_list(dataset_types):
            filepath = f"data/pacs/labels/{domain}_{ds_type}.txt"
            # `with` guarantees the file is closed even if parsing fails.
            with open(filepath, 'r') as label_file:
                for line_no, line in enumerate(label_file, start=1):
                    fields = line.strip().rsplit(maxsplit=1)
                    if not fields:
                        continue  # tolerate blank lines
                    if len(fields) != 2:
                        raise ValueError(
                            f"{filepath}:{line_no}: expected '<path> <label>', got {line!r}")
                    path, label = fields
                    paths.append(path)
                    labels.append(int(label))
        return paths, torch.Tensor(labels)

    def __len__(self) -> int:
        """Returns the number of images in the dataset.

        Returns:
            int: dataset len
        """
        return len(self.images)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Load one sample and record the access.

        The image is opened from its stored path, passed through the
        augmentations (when set) and then through the transforms.

        Args:
            idx (int): index of image
        Returns:
            dict[str, torch.Tensor]: dict:
                {
                    "image": output of the transforms,
                    "label": the stored label of that image
                }
        """
        image = Image.open(self.images[idx])
        if self.augmentations:
            image = self.augmentations(image)

        sample = {
            'image': self.transforms(image),
            'label': self.labels[idx],
        }
        self.access_counts[idx] += 1
        return sample

In [5]:
# dataset config
# The transform pipeline is instantiated right away via `init_object`
# (ToTensor followed by Normalize) instead of being patched into the dict afterwards.
dataset = {
    "name": "PACS",
    "kwargs": {
        "domain_list": ["art_painting", "photo", "sketch", "cartoon"],
        "transforms": torchvision.transforms.Compose([
            init_object(
                torchvision.transforms,
                {"name": "ToTensor", "kwargs": {}},
            ),
            init_object(
                torchvision.transforms,
                {
                    "name": "Normalize",
                    "kwargs": {
                        "mean": [0.5, 0.5, 0.5],
                        "std": [0.5, 0.5, 0.5],
                    },
                },
            ),
        ]),
    },
}

# Shortcut to the configured domain names (the same list object as in the config).
domains = dataset["kwargs"]["domain_list"]

In [6]:
def create_loader(test_domain, batch_size=64, shuffle=True):
    """Build a PACS dataset and a DataLoader for a single domain.

    The module-level ``dataset`` config is deep-copied, so the shared config is
    left untouched. The dataset covers both the 'train' and 'test' label files
    of the chosen domain and uses no augmentations.

    Args:
        test_domain (int): index into the module-level ``domains`` list.
        batch_size (int, optional): batch size of the loader. Defaults to 64.
        shuffle (bool, optional): whether the loader reshuffles the samples on
            every pass. Defaults to True.

    Returns:
        tuple: ``(loader, dataset)`` -- the DataLoader and the PACS dataset it wraps.
    """
    config = deepcopy(dataset)
    config["kwargs"].update(
        dataset_type=["train", "test"],
        domain_list=[domains[test_domain]],
        augmentations=None,
    )

    test_dataset = PACS(**config["kwargs"])
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=shuffle)
    return test_loader, test_dataset

In [7]:
from itertools import cycle, islice
# Loader and dataset for the first configured domain (domains[0]).
test_loader, test_dataset = create_loader(test_domain=0)

In [8]:
def num_iters_loader(loader: torch.utils.data.DataLoader, num_iters: int):
    """Yield exactly ``num_iters`` batches, restarting ``loader`` as often as needed.

    Args:
        loader (torch.utils.data.DataLoader): re-iterable source of batches.
        num_iters (int): total number of batches to yield; nothing is yielded
            when it is zero or negative.

    Yields:
        The batches produced by ``loader``, in order, across repeated passes.
        If a full pass over ``loader`` produces no batch at all, the generator
        stops early instead of looping forever.
    """
    produced = 0
    while produced < num_iters:
        produced_before_pass = produced
        for batch in loader:
            yield batch
            produced += 1
            if produced == num_iters:
                return
        if produced == produced_before_pass:
            # Empty loader: another pass would not yield anything either.
            return

In [9]:
# Running total of samples delivered over 500 loader iterations,
# printed after every batch.
num = 0
for batch in num_iters_loader(test_loader, 500):
    num += batch["image"].size(0)
    print(num)

64
128
192
256
320
384
448
512
576
640
704
768
832
896
960
1024
1088
1152
1216
1280
1344
1408
1472
1536
1600
1664
1728
1792
1856
1920
1984
2048
2112
2176
2240
2304
2368
2432
2496
2560
2624
2688
2752
2816
2880
2944
3008
3072
3136
3200
3264
3328
3392
3456
3520
3584
3648
3712
3776
3840
3904
3968
4032
4096
4160
4224
4288
4352
4416
4480
4544
4608
4672


In [None]:
# Total number of samples accumulated by the counting loop in the previous cell.
num

32000

In [None]:
# Size of the global access counter and the total number of `__getitem__`
# calls it has recorded.
inds.shape[0], inds.sum()

(2048, 2048.0)