In [36]:
# `!python` runs whichever `python` is first on PATH -- the same interpreter the
# `!python train.py` / `!python test.py` cells below invoke (which is not
# necessarily the kernel's own interpreter).
!python -V

Python 3.10.15


In [37]:
import os
from torchvision import datasets, transforms
from torchvision.transforms.functional import to_pil_image
from PIL import Image, ImageFilter

In [38]:
# Root directory for the raw MNIST download, and the output directory for the
# side-by-side (original | blurred) pairs later passed to train.py/test.py as
# --dataroot. paired_dir is derived from data_dir so the two cannot drift apart.
data_dir = './data'
paired_dir = os.path.join(data_dir, 'mnist_blur_pairs')

# exist_ok=True keeps this cell safe to re-run.
os.makedirs(paired_dir, exist_ok=True)

In [39]:
# MNIST training split, stored under data_dir; download=True lets torchvision
# fetch the files there. ToTensor turns each digit into a float tensor, which
# the pairing cell converts back to a PIL image.
transform = transforms.Compose([
    transforms.ToTensor(),
])
mnist_dataset = datasets.MNIST(
    root=data_dir,
    train=True,
    transform=transform,
    download=True,
)

In [40]:
# Build paired training images for pix2pix. Each output PNG is twice as wide as
# the source digit: the left half is the original (the "A" side, "Real A" in the
# results) and the right half is its Gaussian-blurred copy (the "B" side), which
# is what --direction AtoB trains on.
BLUR_RADIUS = 2  # Gaussian blur radius in pixels


def make_blur_pair(image, radius=BLUR_RADIUS):
    """Return a new image with `image` on the left and its blurred copy on the right.

    Args:
        image: A PIL image, or a tensor/ndarray accepted by ``to_pil_image``
            (e.g. the output of ``transforms.ToTensor()``).
        radius: Gaussian blur radius in pixels.

    Returns:
        A PIL image of size (2 * width, height), in the same mode as the source.
    """
    # Accept PIL input directly so the dataset can also be loaded without a
    # transform, avoiding a tensor -> PIL round-trip.
    pil_image = image if isinstance(image, Image.Image) else to_pil_image(image)
    blurred_image = pil_image.filter(ImageFilter.GaussianBlur(radius=radius))

    # Use the source mode instead of a hard-coded 'L' so paste() never has to
    # convert; for MNIST tensors to_pil_image yields 'L' anyway.
    paired_image = Image.new(pil_image.mode, (pil_image.width * 2, pil_image.height))
    paired_image.paste(pil_image, (0, 0))
    paired_image.paste(blurred_image, (pil_image.width, 0))
    return paired_image


# Index explicitly rather than relying on Python's implicit iteration over
# __getitem__ (which stops only when an IndexError happens to be raised).
for i in range(len(mnist_dataset)):
    image, _ = mnist_dataset[i]
    make_blur_pair(image).save(os.path.join(paired_dir, f'pair_{i}.png'))

print("Paired dataset created with original and blurred MNIST images.")

Paired dataset created with original and blurred MNIST images.


In [41]:
import shutil
from sklearn.model_selection import train_test_split

# Move the generated pair PNGs from paired_dir into train/ and test/ sub-folders.
TEST_FRACTION = 0.2  # share of images held out for test/
SPLIT_SEED = 42      # fixed seed so the split is reproducible

# Define directories
paired_dir = './data/mnist_blur_pairs'
train_dir = os.path.join(paired_dir, 'train')
test_dir = os.path.join(paired_dir, 'test')

# Create train and test directories (exist_ok keeps the cell re-runnable)
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

# os.listdir order is arbitrary and filesystem-dependent; sort it so the seeded
# split selects the same files on every machine. The '.png' filter also skips
# the train/ and test/ sub-directories themselves.
all_images = sorted(f for f in os.listdir(paired_dir) if f.endswith('.png'))
print(f"Total images in paired directory: {len(all_images)}")

if not all_images:
    # Nothing left at the top level (e.g. this cell already ran). sklearn's
    # train_test_split raises ValueError for n_samples=0, so skip instead.
    print("No unsplit images found; train/test folders left unchanged.")
else:
    train_images, test_images = train_test_split(
        all_images, test_size=TEST_FRACTION, random_state=SPLIT_SEED
    )
    assert len(train_images) + len(test_images) == len(all_images)

    # Move images to train and test directories
    for subset, dest_dir in ((train_images, train_dir), (test_images, test_dir)):
        for image in subset:
            shutil.move(os.path.join(paired_dir, image), os.path.join(dest_dir, image))

    print("Dataset organized into train and test folders.")


Total images in paired directory: 60000
Dataset organized into train and test folders.


In [42]:
# List the GPUs visible on this machine before choosing --gpu_ids for train.py
# (the training cell below uses --gpu_ids 0).
!nvidia-smi

Thu Nov 14 15:49:00 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA RTX A6000               Off | 00000000:21:00.0 Off |                  Off |
| 30%   38C    P8              22W / 300W |     17MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000               Off | 00000000:22:00.0 Off |  

In [46]:
# Train pix2pix on the paired images. The pairing cell puts the original digit in
# the left half and the blurred copy in the right half, so AtoB learns
# original -> blurred. --dataroot is the folder holding the train/ split made above.
# NOTE(review): --display_id 1 presumably enables train.py's live (Visdom)
# display; confirm in options/ and set 0 if no display server is running.
# --batch_size 128 is a tuning choice; lower it if GPU memory runs out.
# The recorded run below was interrupted (KeyboardInterrupt) during imports,
# so no checkpoint was produced.
!python train.py --dataroot ./data/mnist_blur_pairs --name mnist_blur_pix2pix --model pix2pix --direction AtoB --display_id 1 --gpu_ids 0 --batch_size 128


^C
Traceback (most recent call last):
  File "/home/project/GAN_project/pose_blur/train.py", line 22, in <module>
    from options.train_options import TrainOptions
  File "/home/project/GAN_project/pose_blur/options/train_options.py", line 1, in <module>
    from .base_options import BaseOptions
  File "/home/project/GAN_project/pose_blur/options/base_options.py", line 3, in <module>
    from util import util
  File "/home/project/GAN_project/pose_blur/util/util.py", line 3, in <module>
    import torch
  File "/home/project/anaconda3/envs/W-Net/lib/python3.10/site-packages/torch/__init__.py", line 367, in <module>
    from torch._C import *  # noqa: F403
  File "<frozen importlib._bootstrap>", line 216, in _lock_unlock_module
KeyboardInterrupt


In [None]:
# NOTE: this script only serves the pretrained models it lists (edges2shoes,
# sat2map, map2sat, facades_label2photo, day2night). Requesting
# mnist_blur_pix2pix returns HTTP 404, so this model has to come from train.py above.
!bash ./scripts/download_pix2pix_model.sh mnist_blur_pix2pix

Note: available models are edges2shoes, sat2map, map2sat, facades_label2photo, and day2night
Specified [mnist_blur_pix2pix]
for details.

--2024-11-14 15:44:46--  http://efrosgans.eecs.berkeley.edu/pix2pix/models-pytorch/mnist_blur_pix2pix.pth
Resolving efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)... 128.32.244.190
Connecting to efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)|128.32.244.190|:80... connected.
HTTP request sent, awaiting response... 404 Not Found
2024-11-14 15:44:46 ERROR 404: Not Found.



In [None]:
# Run the trained generator on the test/ split. The visualization cell below reads
# its outputs from ./results/mnist_blur_pix2pix/test_latest/images.
# NOTE(review): --gpu_ids -1 presumably selects CPU (training used GPU 0);
# --use_wandb presumably needs the wandb package and a login -- drop it if unused.
!python test.py --dataroot ./data/mnist_blur_pairs --direction AtoB --model pix2pix --name mnist_blur_pix2pix --gpu_ids -1 --use_wandb


Traceback (most recent call last):
  File "/home/project/GAN_project/pose_blur/test.py", line 30, in <module>
    from options.test_options import TestOptions
  File "/home/project/GAN_project/pose_blur/options/test_options.py", line 1, in <module>
    from .base_options import BaseOptions
  File "/home/project/GAN_project/pose_blur/options/base_options.py", line 6, in <module>
    import data
  File "/home/project/GAN_project/pose_blur/data/__init__.py", line 15, in <module>
    from data.base_dataset import BaseDataset
  File "/home/project/GAN_project/pose_blur/data/base_dataset.py", line 9, in <module>
    import torchvision.transforms as transforms
  File "/home/project/anaconda3/envs/W-Net/lib/python3.10/site-packages/torchvision/__init__.py", line 10, in <module>
    from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils  # usort:skip
  File "/home/project/anaconda3/envs/W-Net/lib/python3.10/site-packages/torchvision/_meta_registrations.py", li

In [None]:
# Visualize the Results ----
# Show the first test sample's input, generated output and ground truth side by
# side. Requires matplotlib in the kernel environment (it was missing in the run
# recorded here; install it in a setup cell if needed).
import os

import matplotlib.pyplot as plt

# Note: Adjust this path if necessary to match the output structure of your experiment
results_dir = './results/mnist_blur_pix2pix/test_latest/images'
fake_image_path = os.path.join(results_dir, '0_fake_B.png')
real_image_path = os.path.join(results_dir, '0_real_A.png')
target_image_path = os.path.join(results_dir, '0_real_B.png')

# (path, title) in reading order: input A -> generated B -> real B.
panels = [
    (real_image_path, "Original MNIST Image (Real A)"),
    (fake_image_path, "Generated (Fake) Blurred Image"),
    (target_image_path, "Blurred MNIST Image (Real B)"),
]

# Fail with one clear message instead of an opaque imread error per file.
missing = [path for path, _ in panels if not os.path.exists(path)]
if missing:
    raise FileNotFoundError(f"Missing result images {missing}; run test.py first.")

fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
for ax, (path, title) in zip(axes, panels):
    ax.imshow(plt.imread(path), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.show()


ModuleNotFoundError: No module named 'matplotlib'

In [None]:
# Zip the model checkpoint folder and, when running on Google Colab, offer it for
# download. shutil.make_archive replaces `!zip` so the zip CLI is not required.
import os
import shutil

checkpoint_dir = os.path.join('checkpoints', 'mnist_blur_pix2pix')
archive_base = 'mnist_blur_pix2pix_checkpoints'

if not os.path.isdir(checkpoint_dir):
    raise FileNotFoundError(
        f"{checkpoint_dir!r} does not exist; run train.py to completion first."
    )

# root_dir='.' with base_dir=checkpoint_dir keeps the same in-archive layout as
# `zip -r ... checkpoints/mnist_blur_pix2pix/` (entries start with checkpoints/).
archive_path = shutil.make_archive(archive_base, 'zip', root_dir='.', base_dir=checkpoint_dir)
print(f"Wrote {archive_path}")

# google.colab is only importable inside a Colab runtime; on a local Jupyter
# server the archive is simply left on disk at the path printed above.
try:
    from google.colab import files
except ImportError:
    print("Not running on Colab; download the archive from the path above.")
else:
    files.download(archive_path)
