# compute stats for line drawings

In [3]:
import numpy as np

np.array([0.485, 0.456, 0.406]).mean(), np.array([0.229, 0.224, 0.225]).mean()

(0.449, 0.226)

In [3]:
import os
import torch.nn as nn
from torchvision.datasets import ImageNet
from torchvision import transforms

class InvertTensorImageColors(nn.Module):
    def __init__(self):
        super(InvertTensorImageColors, self).__init__()

    def __call__(self, img):
        # Check if the image is in uint8 format
        if img.dtype == torch.uint8:
            return 255 - img
        # assume the image is in float format
        else:
            return 1 - img

In [4]:
import torch
from torch.utils.data import DataLoader
from fastprogress import progress_bar

from PIL import Image
import cv2

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
    
def cv2_loader(path, to_rgb=True):
    img = cv2.imread(path)
    if to_rgb: img = img[:,:,::-1]
    
    return img

def load_image(p, to_rgb=True):
    '''Our default image loader, takes `filename` and returns a PIL Image. 
        Speedwise, turbo_loader > pil_loader > cv2, but cv2 is the most robust, so 
        we try to load jpg images with turbo_loader, fall back to PIL, then cv2.
        
        This fallback behavior is needed, e.g., with ImageNet there are a few images
        that either aren't JPEGs or have issues that turbo_loader crashes on, but cv2 
        doesn't.
    '''
    if p.lower().endswith('.jpg') or p.lower().endswith('.jpeg'): 
        try:
            img = pil_loader(p)
        except:
            img = cv2.imread(p)
            if to_rgb: img = img[:,:,::-1]
    else:
        try:
            img = pil_loader(p)
        except:
            img = cv2.imread(p)
            if to_rgb: img = img[:,:,::-1]
                
    if img is not None and not isinstance(img, Image.Image):
        img = Image.fromarray(img)
        
    return img

# Custom collate function
def custom_collate(batch):
    # Extract images and labels
    images, labels = zip(*batch)
    return images, torch.tensor(labels)

def get_dataloader(dataset, batch_size=256, prefetch_factor=10,
                   num_workers=len(os.sched_getaffinity(0))):
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, 
                            shuffle=False, pin_memory=True, prefetch_factor=prefetch_factor,
                            collate_fn=custom_collate)
    
    return dataloader

# compute line-drawing stats

In [5]:
style_name = "anime_style"
split = "train"
root_directory = os.path.join(os.environ['SHARED_DATA_DIR'], 'imagenet1k-line', f"imagenet1k-{style_name}")
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    InvertTensorImageColors(),
])
dataset = ImageNet(root_directory, split=split, transform=transform)
assert len(dataset)==1281167, f"Expected 1281167 val images, got {len(dataset)}"

dataset

Dataset ImageNet
    Number of datapoints: 1281167
    Root location: /n/alvarez_lab_tier1/Users/alvarez/datasets/imagenet1k-line/imagenet1k-anime_style
    Split: train
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               InvertTensorImageColors()
           )

In [6]:
dataloader = get_dataloader(dataset, batch_size=256, prefetch_factor=2)
print(dataloader.dataset)
dataloader

Dataset ImageNet
    Number of datapoints: 1281167
    Root location: /n/alvarez_lab_tier1/Users/alvarez/datasets/imagenet1k-line/imagenet1k-anime_style
    Split: train
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               InvertTensorImageColors()
           )


<torch.utils.data.dataloader.DataLoader at 0x1488c97ea8b0>

In [8]:
from pdb import set_trace
from collections import defaultdict

line_stats = defaultdict(list)
line_stats_nonzero = defaultdict(list)
for batch_idx,(imgs,labels) in enumerate(progress_bar(dataloader)):
    for img in imgs:
        img_flat = img.flatten(start_dim=-2) # flattened over space
        
        # including zeros
        img_mean = img_flat.mean(dim=-1) # averaged over space
        img_std = img_flat.std(dim=-1) # std over space
        line_stats['mean'].append(img_mean) 
        line_stats['std'].append(img_std)   
        
        # excluding zeros
        non_zero_mask = img_flat != 0
        
        # Initialize a tensor to store the means and stds
        non_zero_means = torch.zeros(3)
        non_zero_stds = torch.zeros(3)
        
        # Compute mean for each channel
        for channel_idx in range(3):
            # Select non-zero elements for the current channel
            non_zero_elements = img_flat[channel_idx][non_zero_mask[channel_idx]]

            # Compute the mean of these elements
            if non_zero_elements.nelement() != 0:  # Check to avoid division by zero
                non_zero_means[channel_idx] = non_zero_elements.mean()
                non_zero_stds[channel_idx] = non_zero_elements.std()

        line_stats_nonzero['mean'].append(non_zero_means) 
        line_stats_nonzero['std'].append(non_zero_stds)
    
for k,v in line_stats.items():
    line_stats[k] = torch.stack(line_stats[k])
    print(k, line_stats[k].shape)
    
for k,v in line_stats_nonzero.items():
    line_stats_nonzero[k] = torch.stack(line_stats_nonzero[k])
    print(k, line_stats_nonzero[k].shape)

mean torch.Size([1281167, 3])
std torch.Size([1281167, 3])
mean torch.Size([1281167, 3])
std torch.Size([1281167, 3])


In [9]:
for k,v in line_stats.items():
    print(k, line_stats[k].mean(dim=0))

mean tensor([0.0676, 0.0676, 0.0676])
std tensor([0.1459, 0.1459, 0.1459])


In [10]:
for k,v in line_stats_nonzero.items():
    print(k, line_stats_nonzero[k].mean(dim=0))

mean tensor([0.1241, 0.1241, 0.1241])
std tensor([0.1814, 0.1814, 0.1814])


# line stats

line drawing stats (including zeros):
- mean [0.0676, 0.0676, 0.0676]
- std [0.1459, 0.1459, 0.1459]

line drawing stats (excluding zeros):
- mean [0.1241, 0.1241, 0.1241]
- std [0.1814, 0.1814, 0.1814]

In [12]:
import numpy as np

np.array([0.1241, 0.1241, 0.1241])/0.1814

array([0.68412348, 0.68412348, 0.68412348])

In [13]:
np.array([0.1241, 0.1241, 0.1241])/np.mean([0.229, 0.224, 0.225])

array([0.54911504, 0.54911504, 0.54911504])