In [1]:
# Imports
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from snntorch import utils
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets, transforms

from generate_noised_dataset import generate_one_noisy_image

In [2]:
# Directory passed as `root` to the MNIST dataset constructors.
# Defaults to the original location; set the DATASET_ROOT environment variable
# to point somewhere else without editing the notebook.
sys_root = os.environ.get('DATASET_ROOT', '/home/hwkang/jupyter/root/dataset')

In [3]:
# The only preprocessing step is the conversion to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Settings shared by both MNIST splits; only the `train` flag differs.
mnist_options = dict(root=sys_root, download=True, transform=transform)

train_dataset = datasets.MNIST(train=True, **mnist_options)
test_dataset = datasets.MNIST(train=False, **mnist_options)

In [4]:
# NOTE: the second argument of `data_subset` acts as a reduction factor rather
# than a target sample count -- the recorded output of this cell is 10.
subset_factor = 6000
train_dataset = utils.data_subset(train_dataset, subset_factor)

print(len(train_dataset))

10


In [5]:
# One image per step, in dataset order, for both splits.
train_loader, test_loader = (
    DataLoader(dataset=split, batch_size=1, shuffle=False)
    for split in (train_dataset, test_dataset)
)

In [6]:
class NoisedDataset(Dataset):
    """Binary "noisy vs. clean" dataset built eagerly from an existing loader.

    Every image yielded by ``data_loader`` is, when ``np.random.rand() >= 0.5``,
    passed through ``generate_one_noisy_image`` with an intensity drawn from
    ``np.random.rand()`` and stored with label ``1``; otherwise it is stored
    unchanged with label ``0``. The labels delivered by ``data_loader`` are
    ignored. All samples are held in memory after construction.

    Args:
        data_loader: Iterable of ``(images, labels)`` pairs whose ``images``
            carry the batch on their first dimension. Any batch size is
            accepted; every sample of a batch is handled independently.
        noise_type: Forwarded unchanged as ``noise_type`` to
            ``generate_one_noisy_image``.
    """

    def __init__(self, data_loader, noise_type='gaussian'):
        self.x = []
        self.y = []
        for images, _ in data_loader:
            # Iterate over the batch dimension instead of squeezing it away:
            # for batch_size=1 this yields the same single sample as
            # `squeeze(0)`, and for larger batches each image gets its own
            # noisy/clean draw and its own label.
            for image in images:
                if np.random.rand() >= 0.5:
                    noisy = generate_one_noisy_image(
                        image, intensity=np.random.rand(), noise_type=noise_type
                    )
                    self.x.append(noisy)
                    self.y.append(1)
                else:
                    self.x.append(image)
                    self.y.append(0)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        return self.x[idx], self.y[idx]

In [7]:
# Train/test sets using the default (gaussian) noise type.
noised_train_dataset = NoisedDataset(train_loader, noise_type='gaussian')
noised_test_dataset = NoisedDataset(test_loader, noise_type='gaussian')

# One additional set per noise type, all derived from the training loader.
# Built in the order gaussian, snp, uniform, poisson.
noised_gaussian, noised_snp, noised_uniform, noised_poisson = (
    NoisedDataset(train_loader, noise_type=kind)
    for kind in ('gaussian', 'snp', 'uniform', 'poisson')
)

In [8]:
#train_loader = DataLoader(noised_train_dataset, batch_size=128, shuffle=True)
#test_loader = DataLoader(noised_test_dataset, batch_size=128, shuffle=False)

# One loader per noise type, yielding single samples in stored order.
gaussian_loader, snp_loader, uniform_loader, poisson_loader = (
    DataLoader(noisy_set, batch_size=1, shuffle=False)
    for noisy_set in (noised_gaussian, noised_snp, noised_uniform, noised_poisson)
)

In [None]:
# Sanity check

In [11]:
def _unbatch(loader):
    """Collect a loader into a list of (image, label) pairs, squeezing dim 0 of each image."""
    return [(batch.squeeze(0), target) for batch, target in loader]


gaussian_images = _unbatch(gaussian_loader)
snp_images = _unbatch(snp_loader)
uniform_images = _unbatch(uniform_loader)
poisson_images = _unbatch(poisson_loader)

In [12]:
# Inspect the first gaussian sample: image shape and its label.
first_image, first_label = gaussian_images[0]
print(first_image.shape, first_label)

torch.Size([1, 28, 28]) tensor([1])


In [None]:
# 5x2 grid of the first ten gaussian samples, each titled with its label.
fig, axes = plt.subplots(nrows=5, ncols=2, figsize=(6, 12))

for idx, ax in enumerate(axes.ravel()):
    image, label = gaussian_images[idx]
    # imshow wants channels last, so move (C, H, W) to (H, W, C).
    ax.imshow(image.permute(1, 2, 0), cmap='gray')
    ax.set_title(f'Image[{idx + 1}] y:{label}')
    ax.axis('off')

print()

In [None]:
# 5x2 grid of the first ten 'snp' samples, each titled with its label.
fig, axes = plt.subplots(nrows=5, ncols=2, figsize=(6, 12))

for idx, ax in enumerate(axes.ravel()):
    image, label = snp_images[idx]
    # imshow wants channels last, so move (C, H, W) to (H, W, C).
    ax.imshow(image.permute(1, 2, 0), cmap='gray')
    ax.set_title(f'Image[{idx + 1}] y:{label}')
    ax.axis('off')

print()

In [None]:
# 5x2 grid of the first ten uniform-noise samples, each titled with its label.
fig, axes = plt.subplots(nrows=5, ncols=2, figsize=(6, 12))

for idx, ax in enumerate(axes.ravel()):
    image, label = uniform_images[idx]
    # imshow wants channels last, so move (C, H, W) to (H, W, C).
    ax.imshow(image.permute(1, 2, 0), cmap='gray')
    ax.set_title(f'Image[{idx + 1}] y:{label}')
    ax.axis('off')

print()

In [None]:
# 5x2 grid of the first ten poisson-noise samples, each titled with its label.
fig, axes = plt.subplots(nrows=5, ncols=2, figsize=(6, 12))

for idx, ax in enumerate(axes.ravel()):
    image, label = poisson_images[idx]
    # imshow wants channels last, so move (C, H, W) to (H, W, C).
    ax.imshow(image.permute(1, 2, 0), cmap='gray')
    ax.set_title(f'Image[{idx + 1}] y:{label}')
    ax.axis('off')

print()

In [17]:
class SimpleCNN(nn.Module):
    """Three-stage CNN that emits one sigmoid probability per input image.

    Each stage is a 3x3 convolution (padding 1) followed by ReLU and 2x2 max
    pooling. The flattening step expects 64 x 3 x 3 features after the third
    stage, which is what a 1 x 28 x 28 input produces (28 -> 14 -> 7 -> 3).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor; one pooling layer is shared by all stages.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Fully connected head ending in a single output unit.
        self.fc1 = nn.Linear(64 * 3 * 3, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Return sigmoid outputs of shape (N, 1) for the given input batch."""
        # Convolution -> ReLU -> pooling, three times in a row.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # Flatten to one row of 64 * 3 * 3 features per sample.
        flat = x.view(-1, 64 * 3 * 3)
        hidden = F.relu(self.fc1(flat))
        # Output layer (binary classification).
        return torch.sigmoid(self.fc2(hidden))

In [18]:
# Use CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = SimpleCNN().to(device)

# Binary cross-entropy on the model's sigmoid outputs, optimised with Adam.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
num_epochs = 50

# Train on the noisy-vs-clean dataset: its labels are already 0/1, which is
# what the single sigmoid output and BCELoss expect. Iterating `train_loader`
# here would feed the raw MNIST digit labels (0-9) to the loss as targets.
# A separate name is used so `train_loader` itself is left untouched.
noised_train_loader = DataLoader(noised_train_dataset, batch_size=128, shuffle=True)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in noised_train_loader:
        # (batch,) integer labels -> (batch, 1) float targets, matching the model output.
        labels = labels.float().unsqueeze(1)

        inputs, labels = inputs.to(device), labels.to(device)

        # Reset gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass.
        outputs = model(inputs)

        # Loss for this batch.
        loss = criterion(outputs, labels)

        # Backward pass.
        loss.backward()

        # Parameter update.
        optimizer.step()

        # Weight the batch loss by the batch size so the epoch figure below
        # is a per-sample average even when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(noised_train_loader.dataset)
    # Report every tenth epoch only.
    if epoch % 10 == 9:
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

In [None]:
# Model evaluation
model.eval()

# Evaluate on the noisy-vs-clean test set: its labels are 0/1 and therefore
# comparable with the thresholded predictions. `test_loader` yields the raw
# MNIST digit labels (0-9), which would make the accuracy meaningless.
noised_test_loader = DataLoader(noised_test_dataset, batch_size=128, shuffle=False)

correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in noised_test_loader:
        # (batch,) integer labels -> (batch, 1) float targets, matching the model output.
        labels = labels.float().unsqueeze(1)
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # Threshold the sigmoid output at 0.5 to obtain a hard 0/1 prediction.
        predicted = (outputs >= 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy: {accuracy:.2f}%')