In [1]:
from torchvision import datasets, transforms
from timm.data import create_transform
import torch
import numpy as np
from fast_slic.avx2 import SlicAvx2
import os
import shutil
import math
import cv2
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# User variables
data_path = "/home/sk138/data/"
data_folder_name = "cifar-100-python/"

save_path = "/home/sk138/data/cifar-100-python-segmented/"
input_size = 224
num_workers = 10
n_segments=196
pin_mem = True
batch_size=2000
DEVICE='cuda:1'

# Pathing
save_parent = f"cifar-{n_segments}-BoW-rect/"
full_save_path = os.path.join(save_path, save_parent)

# Transforms - do not do the normalization
train_transform = create_transform(input_size, is_training=True, no_aug=True)
test_transform = create_transform(input_size, is_training=False, no_aug=True)

train_transform.transforms = train_transform.transforms[:-1]
test_transform.transforms = test_transform.transforms[:-1]

grayscale_transform = transforms.Grayscale()

# stats = torch.load('train_stats.pt')
# standard_transform_mag = transforms.Normalize(mean=stats[0], std=stats[1])
# standard_transform_phase = transforms.Normalize(mean=stats[2], std=stats[3])


print(train_transform)
print(test_transform)

# Make directories / delete them
if not os.path.exists(full_save_path):
    # Create the directory
    os.makedirs(full_save_path)

else: 
    user_input = input(f"The directory specified has already been created. Would you like to delete the directory and replace with new data? (yes/no)")

    if user_input.lower() == "yes":
        shutil.rmtree(full_save_path)
        os.makedirs(full_save_path)
    else:
        print("Program execution canceled. Please change pathing inputs to save to a new directory.")



Compose(
    Resize(size=224, interpolation=bilinear, max_size=None, antialias=warn)
    CenterCrop(size=(224, 224))
    ToTensor()
)
Compose(
    Resize(size=256, interpolation=bilinear, max_size=None, antialias=warn)
    CenterCrop(size=(224, 224))
    ToTensor()
)
Program execution canceled. Please change pathing inputs to save to a new directory.


In [3]:
# Create data loaders
train_dataset = datasets.CIFAR100(data_path, train=True, transform=train_transform, download=True)
test_dataset = datasets.CIFAR100(data_path, train=False, transform=test_transform, download=True)

class_names = train_dataset.classes

train_sampler = torch.utils.data.SequentialSampler(train_dataset)
test_sampler = torch.utils.data.SequentialSampler(test_dataset)

data_loader_train = torch.utils.data.DataLoader(
    train_dataset,
    sampler=train_sampler,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_mem,
    drop_last=False,
)

data_loader_test = torch.utils.data.DataLoader(
    test_dataset,
    sampler=test_sampler,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_mem,
    drop_last=False,
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
def process_segment(x, n_segments=196):
    B, C, H, W = x.shape
    x = x.permute(0, 2, 3, 1) # Change channels to be last dimension
    x = x.cpu().numpy()

    if not x.flags['C_CONTIGUOUS']:
        x = x.copy(order='C')

    assert x.flags['C_CONTIGUOUS']

     # Iterate over each image in batch and get segmentation mask
    save_mask = torch.zeros((B, H, W), device=DEVICE)
    for i, img in enumerate(x):
        cp_img = np.squeeze(img)

        # slic = Slic(num_components=n_segments, min_size_factor=0)
        # slic = SlicAvx2(num_components=n_segments, min_size_factor=0)
        # segmentation_mask = slic.iterate(cp_img)
        # print(i, len(np.unique(segmentation_mask)))
        # assert len(np.unique(segmentation_mask)) == n_segments, f"Got {len(np.unique(segmentation_mask))} segments from SLIC, but expected {n_segments}"

        segmentation_mask = np.zeros((224, 224))

        for w in range(0, 224, 16):
            for j in range(0, 224, 16):
                segmentation_mask[w:w+16, j:j+16] = int(w/16 + (j/16)*14)

        save_mask[i, :, :] = torch.from_numpy(segmentation_mask).to(DEVICE)
        
    # print(torch.unique(save_mask, return_counts=True))    
    return save_mask

def process_BoW(x, save_mask, n_segments=196):

    # Format
    x = x.permute(0, 2, 3, 1)
    B = x.shape[0]
    
    # Allocate
    seg_out = torch.zeros((B, n_segments, 256*3))
    pos_out = torch.zeros((B, n_segments, 5))

    # TODO: Verify implementation
    for i in tqdm(range(n_segments)):
        # print(i)
        # Get mask for segment i for all images in batch
        binary_mask = (save_mask == i)

        for j, image_mask in enumerate(binary_mask):
            # Crop image j based on binary mask for segment i
            cpu_image_mask = np.uint8(image_mask.cpu().numpy())
            coords = cv2.findNonZero(cpu_image_mask)
            bbox_x, bbox_y, bbox_w, bbox_h = cv2.boundingRect(coords)
            cropped_image = x[j, bbox_y:bbox_y+bbox_h, bbox_x:bbox_x+bbox_w, :]

            # Get cropped binary mask and find statistics needed for distribution
            cropped_mask = image_mask[bbox_y:bbox_y+bbox_h, bbox_x:bbox_x+bbox_w]
            num_zeros_false = torch.sum(~cropped_mask) # Number of zeros in each mask (to use in pdf calculation)
            norm_factor = torch.sum(cropped_mask) # Number of pixels in each mask (divide in pdf calculation to shift to probability)
            norm_factor[norm_factor == 0] = 1 # So we don't divide by 0

            # Get unique distribution for each channel of the image
            distribution = torch.stack([ torch.bincount(channel.flatten().to(torch.int), minlength=256) for channel in cropped_image.permute(2,0,1)], dim=-1) # Iterate over each channel. Find distribution and stack

            distribution[0, :] = distribution[0, :] - num_zeros_false.unsqueeze(-1) # Subtract number of False values in binary mask from each image (as it was counted in distribution). Broadcasted across channels
            distribution = distribution / norm_factor.unsqueeze(-1) # Divide by total number of pixels in mask to shift to a probability distribution    

            # Save
            # assert torch.sum(torch.isnan(distribution)).item() == 0, "NaN element in the distribution"
            # assert torch.sum(distribution, dim=0).all() == 1, f"Not a valid probability distribution, sum across al channels is {torch.sum(distribution)}, expected 3.0"
            # assert torch.sum(distribution) >= 0, "Negative value"
            
            seg_out[j, i, :] = distribution.flatten()
        
        # num_zeros_false = torch.sum(~binary_mask, axis=(1,2)) # Number of zeros in each mask (to use in pdf calculation)
        # area = torch.sum(binary_mask, axis=(1,2)) # Number of pixels in each mask (divide in pdf calculation to shift to probability)
        # norm_factor = area.clone()
        # norm_factor[area == 0] = 1 # So that we don't divide by 0

        # # Get segment data for all images in batch
        # segmented_imgs = binary_mask.unsqueeze(-1) * x # Broadcast mask and get segment i in all images in batch (1000, 224, 224, 3)

        # # Get unique distribution for each image and each channe;
        # distribution = torch.stack([ torch.stack([ torch.bincount(channel.flatten().to(torch.int), minlength=256) for channel in img.permute(2,0,1)], dim=-1) for img in segmented_imgs]) # Iterate over each image. Iterate over each channel. Find distribution and stack

        # distribution[:, 0, :] = distribution[:, 0, :] - num_zeros_false.unsqueeze(-1) # Subtract number of False values in binary mask from each image (as it was counted in distribution). Broadcasted across channels
        # distribution = distribution / norm_factor.unsqueeze(-1).unsqueeze(-1) # Divide by total number of pixels in mask to shift to a probability distribution      


        # Find positional info (Area, center coordinates, max width/height)
        area = torch.sum(binary_mask, dim=(1,2)) / (binary_mask.shape[1] * binary_mask.shape[2])

        # If the binary mask has any True values for segment i, then find the center coordinate. Else, return 0
        centroid = torch.stack([ (torch.mean(torch.argwhere(img_mask).float(),axis=0) / torch.tensor(img_mask.shape, device=DEVICE)) if torch.argwhere(img_mask).numel() > 0 else torch.zeros((2), device=DEVICE) for img_mask in binary_mask])
        center_x = centroid[:, 1]
        center_y = centroid[:, 0]

        # If the binary mask has any True values for segment i, then find the max width/height of the segment. Else, return 0
        rect = torch.stack([ (1 + torch.max(torch.argwhere(img_mask), axis=0).values - torch.min(torch.argwhere(img_mask), axis=0).values) / torch.tensor(img_mask.shape, device=DEVICE) if torch.argwhere(img_mask).numel() > 0 else torch.zeros((2), device=DEVICE) for img_mask in binary_mask])

        width = rect[:, 1]
        height = rect[:, 0]

        # Save
        pos = torch.stack([area, center_x, center_y, width, height])
        assert torch.sum(torch.isnan(pos)).item() == 0, "NaN element in the positional stats"
        pos_save = torch.transpose(pos, 1, 0) #TODO: make this device (how does this work with DataParallel)
        pos_out[:, i, :] = pos_save

        # assert torch.sum(torch.isnan(distribution)).item() == 0, "NaN element in the distribution"
        # seg_out[:, i, :] = distribution.reshape(B, -1)
            
    return seg_out, pos_out

In [5]:
data_type = "train"
use = data_loader_train if data_type=="train" else data_loader_test

# Iterate over dataloader
for i, (inputs, labels) in enumerate(use):
    print(i)
    inputs = inputs.to(DEVICE)
    inputs *= 255
    inputs = inputs.to(torch.uint8)
    

    # Segment image, take FT, and find pos embed. 
    seg_mask = process_segment(inputs, n_segments=n_segments)

    seg_out, pos_out = process_BoW(inputs, seg_mask, n_segments=n_segments)

    save_data = torch.cat((seg_out, pos_out), dim=2).cpu().numpy() # Stack data and format for save

    # Iterate over data and save
    for i, data in enumerate(save_data):
        label = labels[i]
        class_name = class_names[label]
        class_path = os.path.join(full_save_path, data_type, class_name)
        # print(class_path)

        # Create class folder if it does not exist
        if not os.path.exists(class_path):
            os.makedirs(class_path)

        save_num = len(os.listdir(class_path)) 
        np.savez_compressed(os.path.join(class_path, f'{save_num}.npz'), data=data)    
    

0


100%|██████████| 196/196 [05:20<00:00,  1.64s/it]


1


100%|██████████| 196/196 [05:25<00:00,  1.66s/it]


2


100%|██████████| 196/196 [05:24<00:00,  1.66s/it]


3


100%|██████████| 196/196 [05:24<00:00,  1.66s/it]


4


100%|██████████| 196/196 [05:24<00:00,  1.66s/it]


5


100%|██████████| 196/196 [05:24<00:00,  1.65s/it]


6


100%|██████████| 196/196 [05:24<00:00,  1.66s/it]


7


100%|██████████| 196/196 [05:23<00:00,  1.65s/it]


8


100%|██████████| 196/196 [05:24<00:00,  1.66s/it]


9


100%|██████████| 196/196 [05:23<00:00,  1.65s/it]


10


100%|██████████| 196/196 [05:24<00:00,  1.65s/it]


11


100%|██████████| 196/196 [05:25<00:00,  1.66s/it]


12


100%|██████████| 196/196 [05:24<00:00,  1.66s/it]


13


100%|██████████| 196/196 [05:24<00:00,  1.66s/it]


14


100%|██████████| 196/196 [05:25<00:00,  1.66s/it]


15


100%|██████████| 196/196 [05:24<00:00,  1.66s/it]


16


100%|██████████| 196/196 [05:24<00:00,  1.65s/it]


17


100%|██████████| 196/196 [05:24<00:00,  1.66s/it]


18


100%|██████████| 196/196 [05:37<00:00,  1.72s/it]


19


100%|██████████| 196/196 [05:23<00:00,  1.65s/it]


20


100%|██████████| 196/196 [05:25<00:00,  1.66s/it]


21


100%|██████████| 196/196 [05:24<00:00,  1.65s/it]


22


100%|██████████| 196/196 [05:24<00:00,  1.66s/it]


23


100%|██████████| 196/196 [05:25<00:00,  1.66s/it]


24


100%|██████████| 196/196 [05:25<00:00,  1.66s/it]
