In [1]:
import torch
from tqdm import tqdm

from model import Model

In [2]:
# Re-import edited project modules (e.g. model.py, imported above) before each cell executes,
# so changes to the framework are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### 0. Load the data

In [3]:
# Directory holding the .pkl datasets, relative to the notebook's working directory.
# Must end with "/" because the dataset file paths are built by plain string concatenation.
data_path = "../data/"

In [4]:
train_path = f"{data_path}train_data.pkl"
val_path = f"{data_path}val_data.pkl"

# Each file unpickles to an (input, target) pair of tensors. torch.load runs the unpickler,
# so only point data_path at files from a trusted source.
train_input, train_target = torch.load(train_path)
val_input, val_target = torch.load(val_path)

# Cast to float and divide by 255 so pixel values land in [0, 1]
# (assumes the stored tensors hold 8-bit values in [0, 255] — TODO confirm).
train_input, train_target, val_input, val_target = (
    tensor.float() / 255.0
    for tensor in (train_input, train_target, val_input, val_target)
)

### 1. Build noise2noise model using our framework

#### Train model

In [15]:
# Turn gradient tracking off globally while working with our own framework (section 1),
# presumably because Model does not rely on autograd — TODO confirm against model.py.
# It is switched back on before the PyTorch section. The trailing ";" stops the cell from
# echoing the returned context-manager object as its output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7ff530d300d0>

In [30]:
# Instantiate the denoiser from our own framework (model.py, imported at the top).
# NOTE(review): this cell ran as In [30], after the evaluation cell (In [25]), so the PSNR
# saved below was produced by an earlier Model instance — Restart & Run All before trusting it.
model = Model()

In [None]:
# Train with the framework's own training routine (Model.train from model.py), on the full
# training tensors that were rescaled to [0, 1] above.
# NOTE(review): this cell shows In [None], i.e. it was not executed in the saved session.
model.train(train_input, train_target, num_epochs=20, verbose=True)

#### Test model

In [24]:
def compute_psnr(x, y, max_range=1.0):
    """Return the mean per-image PSNR, in dB, between two batches of images.

    Args:
        x, y: tensors of identical shape with exactly 4 dimensions. The MSE is taken per
            entry of dim 0 over the remaining three dims, converted to PSNR, then averaged.
        max_range: peak signal value (1.0 for images scaled to [0, 1]).

    Returns:
        A 0-dim tensor: 20*log10(max_range) - mean_i(10*log10(mse_i)).

    Raises:
        AssertionError: if the shapes differ or the inputs are not 4-dimensional.
    """
    assert x.shape == y.shape and x.ndim == 4
    per_image_mse = (x - y).pow(2).mean(dim=(1, 2, 3))
    peak_db = 20 * torch.log10(torch.tensor(max_range))
    noise_db = 10 * torch.log10(per_image_mse).mean()
    return peak_db - noise_db

In [25]:
# Denoise the validation set batch by batch. `split` also yields the final, possibly smaller,
# batch; `narrow(0, b, model.batch_size)` raised a RuntimeError whenever the number of
# validation images was not an exact multiple of model.batch_size.
val_batches = val_input.split(model.batch_size)
raw_outputs = torch.cat([model.predict(batch) for batch in tqdm(val_batches)], dim=0)

# NOTE(review): val_input was already rescaled to [0, 1] when loading, yet the predictions are
# divided by 255 here. That is only right if Model.predict returns values in [0, 255]; if it
# returns [0, 1], the outputs collapse towards 0 and the PSNR is meaningless. Confirm in model.py.
model_outputs = raw_outputs / 255.0

output_psnr = compute_psnr(model_outputs, val_target)
print(f"[PSNR {2}: {output_psnr:.2f} dB]")

100%|██████████| 100/100 [00:00<00:00, 162.59it/s]

[PSNR 2: 12.59 dB]





### 2. Build noise2noise model using PyTorch

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# The PyTorch baseline below is trained with autograd, so re-enable the gradient tracking that
# was switched off globally for section 1. The trailing ";" stops the cell from echoing the
# returned context-manager object as its output.
torch.set_grad_enabled(True);

In [None]:
class Net(nn.Module):
    """Small convolutional encoder-decoder used as the PyTorch baseline denoiser.

    Two stride-2 convolutions shrink the spatial size twice (3 -> 10 -> 20 channels), then two
    stride-2 transposed convolutions grow it back (20 -> 10 -> 3 channels). The output has the
    same spatial size as the input only when H and W are multiples of 4.

    Input: float tensor of shape (N, 3, H, W). Output: tensor of shape (N, 3, H', W') with
    values in [0, 1] (sigmoid), where H' == H and W' == W for multiples of 4.
    """

    def __init__(self):
        super().__init__()
        # Encoder: kernel 3, stride 2, padding 1 maps a spatial size s to ceil(s / 2).
        self.conv1 = nn.Conv2d(3, 10, (3, 3), stride=2, padding=1)
        self.conv2 = nn.Conv2d(10, 20, (3, 3), stride=2, padding=1)
        # Decoder: kernel 4, stride 2, padding 1 maps a spatial size s to exactly 2 * s.
        self.upsampling1 = nn.ConvTranspose2d(20, 10, (4, 4), stride=2, padding=1)
        self.upsampling2 = nn.ConvTranspose2d(10, 3, (4, 4), stride=2, padding=1)

    def forward(self, x):
        """Denoise a batch of images; the result is squashed into [0, 1]."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.upsampling1(x))
        # torch.sigmoid is the canonical API (F.sigmoid is the older alias some PyTorch releases
        # flag as deprecated). The [0, 1] range matches the targets rescaled when loading.
        return torch.sigmoid(self.upsampling2(x))

In [None]:
# Seed PyTorch's global RNG so the random weight initialisation of the baseline is
# reproducible across Restart & Run All.
torch.manual_seed(0)
net = Net()

In [None]:
# Pixel-wise mean squared error between the network output and the target image.
criterion = nn.MSELoss(reduction="mean")
# Plain SGD (default: no momentum, no weight decay) over all of the network's parameters.
optimizer = optim.SGD(net.parameters(), lr=1e-2)

In [None]:
# Training hyper-parameters for the PyTorch baseline: logging flag, epochs, mini-batch size.
verbose, num_epochs, batch_size = True, 50, 10

In [None]:
# Mini-batch SGD over the training set in a fixed order (no shuffling).
for e in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    for inputs, targets in zip(train_input.split(batch_size), train_target.split(batch_size)):
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Accumulate a plain float via .item(): summing the loss tensors themselves keeps each
        # batch's autograd history chained together for the whole epoch.
        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean per-batch MSE rather than the sum, so the value does not scale with the
    # number of batches (max(..., 1) guards against an empty training set).
    epoch_loss /= max(num_batches, 1)

    if verbose:
        print(f'Epoch #{e+1}: MSE Loss = {epoch_loss:.6f}')