In [3]:
from torchvision import datasets, transforms
from timm.data import create_transform
import torch
import numpy as np
from fast_slic.avx2 import SlicAvx2
import os
import shutil
import math


In [4]:
# ---------------------------------------------------------------------------
# User-tunable settings
# ---------------------------------------------------------------------------
# Where the raw CIFAR-100 data lives and where the processed output goes.
# (data_folder_name is kept for reference; the cells shown do not read it.)
data_path = "/home/sk138/data/"
data_folder_name = "cifar-100-python/"
save_path = "/home/sk138/data/cifar-100-python-segmented/"

# Image size fed to the transforms, number of SLIC superpixels requested per
# image, and the per-axis size passed to the 2-D FFT of each segment.
input_size = 224
n_segments = 196
n_points = 64

# Data loading and compute device.
batch_size = 1000
num_workers = 10
pin_mem = True
DEVICE = 'cuda:1'

# ---------------------------------------------------------------------------
# Output location: one sub-folder per (n_segments, n_points) configuration
# ---------------------------------------------------------------------------
save_parent = "cifar-{}-{}-standard/".format(n_segments, n_points)
full_save_path = os.path.join(save_path, save_parent)

# ---------------------------------------------------------------------------
# Transforms
# ---------------------------------------------------------------------------
# timm transforms with augmentation disabled; the printed Compose objects
# below show exactly which steps each one applies.
train_transform = create_transform(input_size, is_training=True, no_aug=True)
test_transform = create_transform(input_size, is_training=False, no_aug=True)
grayscale_transform = transforms.Grayscale()

# Pre-computed dataset statistics, indexed as
# [magnitude mean, magnitude std, phase mean, phase std].
stats = torch.load('train_stats.pt')
standard_transform_mag = transforms.Normalize(mean=stats[0], std=stats[1])
standard_transform_phase = transforms.Normalize(mean=stats[2], std=stats[3])

# Show both pipelines so the notebook records what was applied.
print(train_transform, test_transform, sep="\n")

# Prepare the output directory.
# - If it does not exist yet, create it.
# - If it already exists, ask before wiping it. Declining now aborts the cell
#   with an exception so that a "Run All" cannot silently continue and mix new
#   files into the existing directory.
if not os.path.exists(full_save_path):
    os.makedirs(full_save_path)
else:
    user_input = input(f"The directory specified has already been created. Would you like to delete the directory and replace with new data? (yes/no)")

    if user_input.strip().lower() == "yes":
        shutil.rmtree(full_save_path)
        os.makedirs(full_save_path)
    else:
        raise FileExistsError(
            "Program execution canceled. Please change pathing inputs to save to a new directory."
        )



Compose(
    Resize(size=224, interpolation=bilinear, max_size=None, antialias=warn)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
Compose(
    Resize(size=256, interpolation=bilinear, max_size=None, antialias=warn)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)


In [5]:
# Datasets: CIFAR-100 train / test splits, each with its own timm transform.
train_dataset = datasets.CIFAR100(data_path, train=True, transform=train_transform, download=True)
test_dataset = datasets.CIFAR100(data_path, train=False, transform=test_transform, download=True)

# Class-index -> class-name lookup, used later to name the output folders.
class_names = train_dataset.classes

# Sequential samplers keep the samples in dataset order.
train_sampler = torch.utils.data.SequentialSampler(train_dataset)
test_sampler = torch.utils.data.SequentialSampler(test_dataset)

# Both loaders share every setting except the dataset and its sampler.
_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_mem,
    drop_last=False,
)

data_loader_train = torch.utils.data.DataLoader(train_dataset, sampler=train_sampler, **_loader_kwargs)
data_loader_test = torch.utils.data.DataLoader(test_dataset, sampler=test_sampler, **_loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [15]:
def process_segment(x, n_segments=196):
    """Compute a SLIC superpixel label map for every image in a batch.

    Args:
        x: image batch of shape (B, C, H, W). It is moved to host memory and
            handed to ``SlicAvx2`` one image at a time in channels-last layout.
            # NOTE(review): the caller passes uint8 RGB tensors; SlicAvx2's
            # exact dtype/channel requirements are not visible here - confirm.
        n_segments: number of components requested from SLIC.

    Returns:
        Float tensor of shape (B, H, W) on the same device as ``x`` holding the
        per-pixel segment id returned by SLIC for each image.
    """
    B, _, H, W = x.shape
    device = x.device  # follow the input instead of relying on a global device

    # SLIC runs on the CPU and needs a C-contiguous, channels-last array.
    images = np.ascontiguousarray(x.permute(0, 2, 3, 1).cpu().numpy())

    # Fill the masks on the host and move them to the device once at the end,
    # rather than issuing one small host-to-device copy per image.
    save_mask = torch.zeros((B, H, W))
    for idx, img in enumerate(images):
        # A fresh SLIC object per image, as in the original implementation.
        slic = SlicAvx2(num_components=n_segments, min_size_factor=0)
        segmentation_mask = slic.iterate(np.squeeze(img))
        save_mask[idx] = torch.from_numpy(segmentation_mask)

    return save_mask.to(device)

def _axis_extent(occupied):
    """Return (last - first) occupied index along axis 1 for each row.

    Args:
        occupied: bool tensor of shape (B, L).

    Returns:
        int64 tensor of shape (B,); rows without any True entry give 0.
    """
    length = occupied.shape[1]
    index = torch.arange(length, device=occupied.device).unsqueeze(0)
    # Sentinels (-1 / length) guarantee that unoccupied positions never win
    # the max / min; an all-False row yields a negative span, clamped to 0.
    last = torch.where(occupied, index, index.new_full((), -1)).amax(dim=1)
    first = torch.where(occupied, index, index.new_full((), length)).amin(dim=1)
    return (last - first).clamp(min=0)


def process_ft(x, save_mask, n_segments=196, n_points=64, grayscale=True):
    """Per-segment Fourier features and positional statistics for a batch.

    For every segment id in ``range(n_segments)`` the image is masked to that
    segment, a real 2-D FFT of size (n_points, n_points) is taken, and the
    magnitude and phase (phase scaled by 1/pi) are passed through the
    module-level ``standard_transform_mag`` / ``standard_transform_phase``.

    Args:
        x: single-channel image batch of shape (B, 1, H, W).
        save_mask: segment-id map of shape (B, H, W), on the same device as x.
        n_segments: number of segment ids to process.
        n_points: FFT size along each spatial axis.
        grayscale: accepted for interface compatibility; not used.

    Returns:
        seg_out: CPU tensor (B, n_segments, 2 * n_points * (n_points // 2 + 1)).
            Each row is one image's flattened magnitude followed by its
            flattened phase.
        pos_out: CPU tensor (B, n_segments, 5) with, in this order:
            area fraction, centroid along H, centroid along W, extent along H,
            extent along W (all divided by the image size; zeros for a segment
            that does not occur in an image).

    Raises:
        AssertionError: if NaNs appear in the magnitude, phase or positional
            statistics.
    """
    # Drop only the channel axis. A bare .squeeze() would also drop the batch
    # axis when B == 1 and break every shape below.
    x = x.permute(0, 2, 3, 1).squeeze(-1)
    B = x.shape[0]
    H, W = save_mask.shape[1], save_mask.shape[2]
    device = save_mask.device

    # rfft2 returns (n_points, n_points // 2 + 1) bins; x2 for magnitude + phase.
    n_features = 2 * n_points * (n_points // 2 + 1)

    # Outputs stay on the CPU (as before); each segment's result is copied in.
    seg_out = torch.zeros((B, n_segments, n_features))
    pos_out = torch.zeros((B, n_segments, 5))

    # Pixel coordinates, broadcastable against a (B, H, W) mask.
    row_coords = torch.arange(H, device=device, dtype=torch.float32).view(1, H, 1)
    col_coords = torch.arange(W, device=device, dtype=torch.float32).view(1, 1, W)

    for seg_id in range(n_segments):
        # Pixels belonging to this segment, for all images at once.
        binary_mask = (save_mask == seg_id)
        segmented_imgs = binary_mask * x

        # Fourier transform, split into magnitude and phase.
        fourier_transform = torch.fft.rfft2(segmented_imgs, s=(n_points, n_points))
        magnitude = torch.abs(fourier_transform)
        phase = torch.angle(fourier_transform)

        assert not torch.isnan(magnitude).any(), "NaN element in the magnitude before normalization"

        # Standardize with the dataset statistics; phase is first scaled to [-1, 1].
        # NOTE(review): Normalize receives a (B, n_points, n_points // 2 + 1)
        # tensor here; whether the loaded stats broadcast as intended depends
        # on their shape, which is not visible in this notebook - confirm.
        magnitude = standard_transform_mag(magnitude)
        phase = standard_transform_phase(phase / math.pi)

        assert not torch.isnan(magnitude).any(), "NaN element in the magnitude after normalization"
        assert not torch.isnan(phase).any(), "NaN element in the phase"

        # Stack along dim=1 so the reshape keeps each image's magnitude and
        # phase in that image's own row. Stacking along dim 0 (the previous
        # behaviour) interleaved data from different images when B > 1.
        seg_out[:, seg_id, :] = torch.stack((magnitude, phase), dim=1).reshape(B, -1)

        # Positional statistics, vectorized over the batch.
        mask_f = binary_mask.to(torch.float32)
        count = mask_f.sum(dim=(1, 2))
        safe_count = count.clamp(min=1)  # empty segment: numerators are 0 -> result 0
        area = count / (H * W)
        center_row = (mask_f * row_coords).sum(dim=(1, 2)) / safe_count / H
        center_col = (mask_f * col_coords).sum(dim=(1, 2)) / safe_count / W
        extent_row = _axis_extent(binary_mask.any(dim=2)).to(torch.float32) / H
        extent_col = _axis_extent(binary_mask.any(dim=1)).to(torch.float32) / W

        pos = torch.stack([area, center_row, center_col, extent_row, extent_col], dim=1)
        assert not torch.isnan(pos).any(), "NaN element in the positional stats"
        pos_out[:, seg_id, :] = pos

    return seg_out, pos_out

In [17]:
# Which split to export; switch to "train" and re-run this cell for the other split.
data_type = "test"
use = data_loader_train if data_type=="train" else data_loader_test

# Next free file index for each class directory. It is seeded from the
# directory contents the first time a class is seen, so file names continue
# from what is already on disk without rescanning the folder for every sample.
next_save_num = {}

# Iterate over dataloader
for batch_idx, (inputs, labels) in enumerate(use):
    print(batch_idx)
    inputs = inputs.to(DEVICE)
    # NOTE(review): the transforms printed above end with ToTensor + Normalize,
    # so `inputs` holds normalized floats at this point. Casting those directly
    # to uint8 discards almost all pixel information (and negative values are
    # out of range for uint8). Behaviour is left unchanged here because
    # train_stats.pt was presumably computed with the same pipeline - confirm
    # and, if needed, de-normalize to [0, 255] and regenerate the stats.
    inputs = inputs.to(torch.uint8)

    # Segment image, take FT, and find pos embed.
    seg_mask = process_segment(inputs, n_segments=n_segments)
    input_gray = grayscale_transform(inputs)
    seg_out, pos_out = process_ft(input_gray, seg_mask, n_segments=n_segments, n_points=n_points)
    save_data = torch.cat((seg_out, pos_out), dim=2).cpu().numpy() # Stack data and format for save

    # Save every sample of the batch into its class folder. The loop variables
    # are named distinctly so they no longer shadow the batch index.
    for data, label in zip(save_data, labels.tolist()):
        class_name = class_names[label]
        class_path = os.path.join(full_save_path, data_type, class_name)

        # First sample of this class: make sure the folder exists and find out
        # how many files it already holds.
        if class_path not in next_save_num:
            os.makedirs(class_path, exist_ok=True)
            next_save_num[class_path] = len(os.listdir(class_path))

        save_num = next_save_num[class_path]
        np.savez_compressed(os.path.join(class_path, f'{save_num}.npz'), data=data)
        next_save_num[class_path] = save_num + 1
    

0
1
2
3
4
5
6
7
8
9
