In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision.utils import save_image
from torchvision.transforms import Resize, ToPILImage, ToTensor

In [2]:
# Fraction of the half image size (image_size/2) used as the maximum stroke length / cut-off radius.
r = 0.9
# Side length, in pixels, of each square generated image.
image_size = 32
# Success probability of the Bernoulli draw applied to every pixel of a stamped 3x3 patch.
p = 0.95
# Latent coordinates are clipped to [-threshold_norm, threshold_norm] and then divided by it.
threshold_norm = 2.5

In [3]:
def sign(x):
    """Return 1 when ``x`` is non-negative (zero included), otherwise -1."""
    return 1 if x >= 0 else -1

In [5]:
n = 10000
z = torch.randn((n, 2))

# NOTE: no random seed is set in this cell, so torch.randn and
# np.random.binomial produce a different dataset on every run.


def _stamp_patch(image, x_pos, y_pos):
    """Overwrite the 3x3 block of ``image`` centred on (x_pos, y_pos).

    Each of the nine pixels is an independent Bernoulli draw that is 1 with
    probability ``p`` and 0 otherwise. ``image`` is modified in place.
    """
    value_p = np.random.binomial(n=1, p=p, size=(3, 3))
    x_index = [x_pos - 1, x_pos, x_pos + 1]
    y_index = [y_pos - 1, y_pos, y_pos + 1]
    image[np.ix_(x_index, y_index)] = value_p


# Every image is drawn directly into one preallocated array. Concatenating a
# growing tensor inside the loop would re-copy all previous images on each
# iteration (quadratic in n) and would leave x_data undefined when n == 0.
images = np.zeros((n, image_size, image_size))

# Squared pixel distance from the origin beyond which a stroke is cut off.
max_radius_sq = int(image_size / 2 * r) ** 2

for i in range(n):
    if i % 1000 == 0:
        print(f"{i}/{n}")

    # Clip each latent coordinate to [-threshold_norm, threshold_norm] and
    # rescale it to [-1, 1].
    z1 = max(min(z[i][0].item(), threshold_norm), -threshold_norm) / threshold_norm
    z2 = max(min(z[i][1].item(), threshold_norm), -threshold_norm) / threshold_norm

    image = images[i]  # view into `images`; writes land in the shared array

    # The stroke starts at image_size/2 for a non-negative coordinate and one
    # pixel lower for a negative one.
    origin_pos_x = int(image_size / 2) - int((1 - sign(z1)) / 2)
    origin_pos_y = int(image_size / 2) - int((1 - sign(z2)) / 2)

    # Step along the axis with the larger magnitude so that the other
    # coordinate moves by at most one pixel per step (slope magnitude <= 1).
    if abs(z1) > abs(z2):
        for x_pos in range(origin_pos_x, origin_pos_x + round(image_size / 2 * r * z1), sign(z1)):
            y_pos = round(z2 / z1 * (x_pos - origin_pos_x) + origin_pos_y)
            if (x_pos - origin_pos_x) ** 2 + (y_pos - origin_pos_y) ** 2 > max_radius_sq:
                break
            _stamp_patch(image, x_pos, y_pos)
    else:
        # When z2 == 0 here, z1 is 0 as well and the range below is empty, so
        # the division by z2 is never evaluated.
        for y_pos in range(origin_pos_y, origin_pos_y + round(image_size / 2 * r * z2), sign(z2)):
            x_pos = round(z1 / z2 * (y_pos - origin_pos_y) + origin_pos_x)
            if (x_pos - origin_pos_x) ** 2 + (y_pos - origin_pos_y) ** 2 > max_radius_sq:
                break
            _stamp_patch(image, x_pos, y_pos)

# Convert once and insert a channel axis: (n, 1, image_size, image_size).
x_data = torch.tensor(images).unsqueeze(1)

0/10000
1000/10000
2000/10000
3000/10000
4000/10000
5000/10000
6000/10000
7000/10000
8000/10000
9000/10000


In [6]:
# Rebind x to the full generated image tensor; the train/test split further down slices this name.
x = x_data

In [7]:
from torch.utils.data import Dataset
class CustomDataset(Dataset):
    """Dataset pairing each sample in ``x`` with the entry of ``label`` at the same index."""

    def __init__(self, x, label):
        """Keep references to the samples and labels.

        The dataset length is taken from the first dimension of ``x``.
        """
        self.x, self.label = x, label
        self.n = x.shape[0]

    def __len__(self):
        """Return the number of samples."""
        return self.n

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        sample = self.x[idx]
        target = self.label[idx]
        return sample, target

In [13]:
# One label entry per generated sample; every entry is filled with the same value, 0.
label = torch.full((n, ), 0)

In [16]:
import os
import pickle

# The first split_idx samples form the training set, the remainder the test set.
split_idx = 8000
train_x, train_label = x[:split_idx], label[:split_idx]
test_x, test_label = x[split_idx:], label[split_idx:]
train_dataset = CustomDataset(train_x, train_label)
test_dataset = CustomDataset(test_x, test_label)

train_file = './data/model/train_dataset.pkl'
test_file = './data/model/test_dataset.pkl'

# open(..., 'wb') does not create missing parent directories, so make sure
# they exist; otherwise the dump fails with FileNotFoundError on a fresh checkout.
for dataset_path in (train_file, test_file):
    os.makedirs(os.path.dirname(dataset_path), exist_ok=True)

with open(train_file, 'wb') as f:
    pickle.dump(train_dataset, f)
with open(test_file, 'wb') as f:
    pickle.dump(test_dataset, f)

In [17]:
import pickle

train_file = './data/model/train_dataset.pkl'
test_file = './data/model/test_dataset.pkl'


def _read_pickle(path):
    """Deserialize and return the object stored in the pickle file at ``path``."""
    # pickle.load can execute arbitrary code; only point this at trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


train_dataset = _read_pickle(train_file)
test_dataset = _read_pickle(test_file)

# Batches of bs samples; only the training loader reshuffles.
bs = 100
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=bs,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=bs,
    shuffle=False,
)

print(len(train_dataset), len(test_dataset))

8000 2000


In [19]:
import os
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

import shutil

# Pick num_samples distinct positions of train_dataset, in random order.
num_samples = 8000
indices = np.random.choice(len(train_dataset), num_samples, replace=False)
subset_train_dataset = Subset(train_dataset, indices)

save_dir = './samples/base'

print("start deleting")
# isdir() is already False for a path that does not exist, so a separate
# exists() check is not needed.
if os.path.isdir(save_dir):
    shutil.rmtree(save_dir)
else:
    print("Directory does not exist.")
print("done deleting")

# Recreate the output directory empty.
os.makedirs(save_dir, exist_ok=True)

print("start saving")
def save_images(dataset, save_dir):
    """Write every image of ``dataset`` into ``save_dir`` as ``image_<idx>.png``.

    ``dataset`` must yield ``(image, label)`` pairs when iterated; the label is
    ignored and ``idx`` is the position of the pair in iteration order.
    """
    to_pil = transforms.ToPILImage()  # built once and reused for every image
    for idx, (image, _label) in enumerate(dataset):
        target_path = os.path.join(save_dir, f'image_{idx}.png')
        to_pil(image).save(target_path)
# Export the randomly ordered training subset as PNG files into save_dir.
save_images(subset_train_dataset, save_dir)
print("done saving")

start deleting
done deleting
start saving
done saving
