<a href="https://colab.research.google.com/github/sunmulim/-/blob/main/Untitled20.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch

from torchvision.datasets.mnist import MNIST

from torchvision.transforms import ToTensor

In [None]:
def gaussian_noise(x, scale=0.5):
    """Return a noisy float32 copy of an image, clipped to [0, 1].

    Args:
        x: Image as a NumPy array or CPU tensor. Integer-typed input (e.g. raw
            uint8 pixels in 0-255) is first rescaled to [0, 1] by dividing by
            255; floating-point input is assumed to already lie in [0, 1].
        scale: Standard deviation of the zero-mean Gaussian noise.

    Returns:
        A float32 torch tensor with the same shape as ``x``.
    """
    image = np.asarray(x)
    if np.issubdtype(image.dtype, np.integer):
        # Without rescaling, every non-zero 0-255 pixel would saturate to 1
        # after the clip below, while the clean targets live in [0, 1].
        image = image / 255.0
    noise = np.random.normal(loc=0.0, scale=scale, size=image.shape)
    noisy = np.clip(image + noise, 0.0, 1.0)
    return torch.from_numpy(noisy).float()

In [None]:
def _load_mnist(train):
    """Download (if needed) and return one MNIST split with ToTensor applied."""
    return MNIST(root="./data", train=train, download=True, transform=ToTensor())


training_dataset = _load_mnist(train=True)
test_dataset = _load_mnist(train=False)

# Take the first raw training digit and a noisy copy of it.
img = training_dataset.data[0]
noise_img = gaussian_noise(img)

# Side-by-side comparison: clean on the left, noisy on the right.
plt.subplot(121)
plt.title("original")
plt.imshow(img, cmap="gray")
plt.subplot(122)
plt.title("noisy")
plt.imshow(noise_img, cmap="gray")

In [None]:
from torch.utils.data.dataset import Dataset


class Denoising(Dataset):
    """MNIST dataset of (noisy image, clean image) pairs for training a denoiser.

    The noisy inputs are generated once in ``__init__`` and cached, so every
    epoch sees the same noisy copy of each image.
    """

    def __init__(self, root="./data", train=True):
        """Load MNIST and precompute one noisy copy of every image.

        Args:
            root: Directory MNIST is downloaded to / read from.
            train: Use the training split when True, the test split otherwise.
        """
        self.mnist = MNIST(
            root=root,
            train=train,
            download=True,
            transform=ToTensor()
        )
        # gaussian_noise already returns a new float tensor, so no extra
        # torch.tensor() copy is needed; unsqueeze adds the leading channel
        # dimension that the 1-input-channel conv layers expect.
        self.noise_data = [
            torch.unsqueeze(gaussian_noise(image), dim=0)
            for image in self.mnist.data
        ]

    def __len__(self):
        return len(self.noise_data)

    def __getitem__(self, idx):
        """Return ``(noisy_input, clean_target)`` for sample ``idx``.

        The target is the raw pixel tensor divided by 255.
        """
        data = self.noise_data[idx]
        label = self.mnist.data[idx] / 255
        return data, label

In [None]:
# Build the noisy training set so the next cell can inspect one (input, label) pair.
train_dataset = Denoising()

In [None]:
# Shapes of the first (noisy input, clean target) pair.
print(*(tensor.shape for tensor in train_dataset[0]))

In [None]:
import torch.nn as nn


class BasicBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU.

    With ``padding=1`` (and the default stride of 1) the spatial size is
    unchanged; channels go ``in_channels -> hidden_dim -> out_channels``.
    """

    def __init__(self, in_channels, out_channels, hidden_dim):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, hidden_dim, **conv_kwargs)
        self.conv2 = nn.Conv2d(hidden_dim, out_channels, **conv_kwargs)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.relu(self.conv2(hidden))

In [None]:
# Bare expression as the cell's last line, so Jupyter displays the module's layer summary.
BasicBlock(1, 16, hidden_dim=8)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder: two BasicBlocks, each followed by 2x2 average pooling.

    Maps a 1-channel input to an 8-channel feature map; each pooling step
    halves height and width (rounding down).
    """

    def __init__(self):
        super().__init__()
        self.block1 = BasicBlock(in_channels=1, out_channels=16, hidden_dim=16)
        self.block2 = BasicBlock(in_channels=16, out_channels=8, hidden_dim=8)
        # One parameter-free pooling layer, reused after both blocks.
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        for block in (self.block1, self.block2):
            x = self.pool(block(x))
        return x

In [None]:
# Bare expression as the cell's last line, so Jupyter displays the encoder's structure.
Encoder()

In [None]:
class Decoder(nn.Module):
    """Convolutional decoder: (BasicBlock, 2x transposed-conv upsample) twice,
    then a 3x3 convolution down to a single output channel.

    Each ConvTranspose2d with kernel_size=2, stride=2 doubles height and width.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in the original order so parameter
        # initialisation and state_dict key order stay the same.
        self.block1 = BasicBlock(in_channels=8, out_channels=8, hidden_dim=8)
        self.block2 = BasicBlock(in_channels=8, out_channels=16, hidden_dim=16)
        self.output_conv = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
        self.upsample1 = nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2)
        self.upsample2 = nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2)

    def forward(self, x):
        stages = ((self.block1, self.upsample1), (self.block2, self.upsample2))
        for block, upsample in stages:
            x = upsample(block(x))
        return self.output_conv(x)

In [None]:
# Bare expression as the cell's last line, so Jupyter displays the decoder's structure.
Decoder()

In [None]:
class CAE(nn.Module):
    """Convolutional autoencoder: an Encoder followed by a Decoder.

    ``forward`` drops the single output channel, turning (N, 1, H, W) into
    (N, H, W) so it can be compared directly with (N, H, W) target batches
    such as those produced from ``Denoising``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        x = self.decoder(self.encoder(x))
        # Squeeze only the channel dim: a bare squeeze() would also remove the
        # batch dim when N == 1 and break the shape match with the targets.
        return torch.squeeze(x, dim=1)

In [None]:
# Bare expression as the cell's last line, so Jupyter displays the full autoencoder.
CAE()

In [None]:
import tqdm

from torch.utils.data.dataloader import DataLoader
from torch.optim.adam import Adam

# NOTE: this rebuilds the dataset (re-sampling its noise) even though an
# earlier cell already created train_dataset.
train_dataset = Denoising()
# Shuffle so each epoch visits the samples in a different order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [None]:
# Use the GPU when available instead of crashing on .cuda() in CPU-only runtimes.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CAE().to(device)


lr = 0.001
optim = Adam(params=model.parameters(), lr=lr)
criterion = nn.MSELoss()

model.train()
for epoch in range(20):
    iterator = tqdm.tqdm(train_loader)

    for inputs, labels in iterator:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optim.zero_grad()
        pred = model(inputs)
        loss = criterion(pred, labels)
        loss.backward()
        optim.step()

        iterator.set_description(f"[Epoch {epoch + 1}] loss: {loss.item()}")

In [None]:

# Save only the learned parameters (state_dict); the next cell restores them via load_state_dict.
torch.save(model.state_dict(), "./CAE.pt")

In [None]:

# Run on the GPU when available; map_location lets a CUDA-saved checkpoint load on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.load_state_dict(torch.load("./CAE.pt", map_location=device))
model = model.to(device)
# Put the model in inference mode before evaluating it.
model.eval()


with torch.no_grad():
    img = test_dataset.data[0]
    noise_img = gaussian_noise(img)

    # Add batch and channel dimensions, (H, W) -> (1, 1, H, W), for Conv2d.
    # (The original `.type(torch.FloatTensor)` call discarded its result and
    # gaussian_noise already returns a float tensor, so it is dropped.)
    model_input = torch.unsqueeze(torch.unsqueeze(noise_img, dim=0), dim=0).to(device)
    denoised = model(model_input).cpu()

    plt.subplot(1, 3, 1)
    plt.title("noisy")
    plt.imshow(torch.squeeze(noise_img), cmap="gray")
    plt.subplot(1, 3, 2)
    plt.title("denoised")
    plt.imshow(torch.squeeze(denoised), cmap="gray")
    plt.subplot(1, 3, 3)
    plt.title("original")
    plt.imshow(torch.squeeze(img), cmap="gray")
    plt.show()