In [41]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from PIL import Image
import numpy as np
import os

## Creating the dataset with random preferences

In [2]:
# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Initialize dataset with preferences
train_dataset = datasets.FashionMNIST(root='./fashion', train=True, download=True, transform=transform)

# Define user preferences (example)
preferences = torch.randint(0, 3, (len(train_dataset),))  # Random preferences (0: casual, 1: formal, 2: sporty)

# Modify the data loading process to include preferences
class FashionMNISTWithPreferences(datasets.FashionMNIST):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, preferences=None):
        super(FashionMNISTWithPreferences, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)
        self.preferences = preferences

    def __getitem__(self, index):
        img, target = super(FashionMNISTWithPreferences, self).__getitem__(index)
        if self.preferences is not None:
            return img, target, self.preferences[index]
        else:
            return img, target

# Initialize dataset with preferences
train_dataset = FashionMNISTWithPreferences(root='./data', train=True, download=True, transform=transform, preferences=preferences)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./fashion/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 13441076.92it/s]


Extracting ./fashion/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./fashion/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./fashion/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 150806.06it/s]


Extracting ./fashion/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./fashion/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./fashion/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:05<00:00, 750132.06it/s] 


Extracting ./fashion/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./fashion/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./fashion/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 6811443.85it/s]


Extracting ./fashion/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./fashion/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNISTWithPreferences/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 14070717.99it/s]


Extracting ./data/FashionMNISTWithPreferences/raw/train-images-idx3-ubyte.gz to ./data/FashionMNISTWithPreferences/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNISTWithPreferences/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 268418.50it/s]


Extracting ./data/FashionMNISTWithPreferences/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNISTWithPreferences/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNISTWithPreferences/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3038712.74it/s]


Extracting ./data/FashionMNISTWithPreferences/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNISTWithPreferences/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNISTWithPreferences/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5031991.84it/s]

Extracting ./data/FashionMNISTWithPreferences/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNISTWithPreferences/raw





## CVAE model

In [18]:
class CVAE(nn.Module):
    def __init__(self, latent_dim, image_size, preference_size):
        super(CVAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(image_size + preference_size, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )
        self.mu = nn.Linear(256, latent_dim)
        self.logvar = nn.Linear(256, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + preference_size, 256),
            nn.ReLU(),
            nn.Linear(256, 512),  # Adjusted size here
            nn.ReLU(),
            nn.Linear(512, image_size),  
            nn.Tanh()
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, preferences):
        # Flatten the input images
        x = x.view(x.size(0), -1)
        
        # Convert preferences to one-hot encoding
        preferences_onehot = F.one_hot(preferences, num_classes=preference_size).float()
        
        # Concatenate flattened images with preferences
        x = torch.cat((x, preferences_onehot), dim=1)
        
        x = self.encoder(x)
        mu = self.mu(x)
        logvar = self.logvar(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(torch.cat((z, preferences_onehot), dim=1)), mu, logvar


In [19]:
batch_size = 128
latent_dim = 20
image_size = 784  # 28x28
preference_size = 3  # Number of preference categories

model = CVAE(latent_dim, image_size, preference_size)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

num_epochs = 10
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


In [20]:
for epoch in range(num_epochs):
    for i, (images, labels, preferences) in enumerate(train_loader):
        images = images.view(images.size(0), -1)
        recon_images, mu, logvar = model(images, preferences)

        recon_loss = criterion(recon_images, images)
        kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        loss = recon_loss + kl_divergence
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [1/10], Loss: 0.3815
Epoch [2/10], Loss: 0.3563
Epoch [3/10], Loss: 0.3622
Epoch [4/10], Loss: 0.3349
Epoch [5/10], Loss: 0.3669
Epoch [6/10], Loss: 0.3333
Epoch [7/10], Loss: 0.3462
Epoch [8/10], Loss: 0.3170
Epoch [9/10], Loss: 0.3786
Epoch [10/10], Loss: 0.3211


## Image Generation

In [63]:
model.eval()

latent_samples = torch.randn(64, latent_dim)
random_preferences = torch.randint(0, preference_size, (64,))
preferences_onehot = F.one_hot(random_preferences, num_classes=preference_size).float()
inputs = torch.cat((latent_samples, preferences_onehot), dim=1)

# Generate images
with torch.no_grad():
    generated_images = model.decoder(inputs)

# Set directory to save images
save_dir = "generated_images"
os.makedirs(save_dir, exist_ok=True)

for j in range(preference_size):
    latent_sample = torch.randn(1, latent_dim)
    preference = torch.tensor([j])
    preference_onehot = F.one_hot(preference, num_classes=preference_size).float()
    input_ = torch.cat((latent_sample, preference_onehot), dim=1)
    
    with torch.no_grad():
        generated_image = model.decoder(input_)
    
    image_vector = generated_image[0].cpu().detach().numpy()
    image_array = image_vector.reshape(28, 28)
    image_uint8 = ((image_array + 1) / 2 * 255).astype(np.uint8)
    image_pil = Image.fromarray(image_uint8, mode='L')  # 'L' mode for grayscale
    image_pil.save(os.path.join(save_dir, f"image_{j}.png"))

print("Images saved successfully.")


Images saved successfully.
