In [2]:
# a simple cnn model to separate noisy images from clean images

In [2]:
# https://www.kaggle.com/ateplyuk/pytorch-efficientnet

In [27]:
import numpy as np
import pandas as pd
import os
import matplotlib.image as mpimg

import torch
import torch.nn as nn
import torch.optim as optim 

import torchvision
from torch.utils.data import DataLoader, Dataset
import torch.utils.data as utils
from torchvision import transforms

import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

In [28]:
# dataloader
import glob
import cv2
class ImageData(Dataset):
    def __init__(self, pos_path_list, neg_path_list, transform): # pos = clean, neg = noisy
        super().__init__()
        self.imgs = []
        self.labels = []
        
        for p_path in pos_path_list:
            for f in glob.glob(p_path):
                self.imgs.append(f)
                self.labels.append(1.)
        
        for n_path in neg_path_list:
            for f in glob.glob(n_path):
                self.imgs.append(f)
                self.labels.append(0.)
                
        self.transform = transform

    def __len__(self):
        return len(self.imgs)
    
    def __getitem__(self, index):       
        img_name = self.imgs[index]
        label = self.labels[index]
        
        image = cv2.resize(mpimg.imread(img_name), (192, 96))
        image = self.transform(image)
        return image, label

In [29]:
data_transf = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
train_data = ImageData(["../gan_denoising/trainA/*.jpg"], ["../gan_denoising/trainB/*.jpg"], transform = data_transf)
train_loader = DataLoader(dataset = train_data, batch_size = 8)

In [17]:
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')

Loaded pretrained weights for efficientnet-b1


In [18]:
print(model)

EfficientNet(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
  )
  (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
        (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
      )
      (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv2dStaticSamePadding(
        32, 8, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        8, 32, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        32, 16, kernel_size=

In [19]:
# Unfreeze model weights
for param in model.parameters():
    param.requires_grad = True

In [20]:
num_ftrs = model._fc.in_features
model._fc = nn.Linear(num_ftrs, 1)

In [21]:
model = model.to('cuda')

In [22]:
optimizer = optim.Adam(model.parameters())
loss_func = nn.BCELoss()

In [31]:
%%time
# Train model
loss_log = []

for epoch in range(3):    
    model.train()
    train_loss = 0.0
    for ii, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        target = target.float()                

        optimizer.zero_grad()
        output = model(data)                
    
        m = nn.Sigmoid()
        loss = loss_func(m(output), target)
        loss.backward()

        optimizer.step()  
        
        if ii % 1000 == 0:
            train_loss += loss.item()
            loss_log.append(train_loss)
       
    print('Epoch: {} - Loss: {:.6f}'.format(epoch + 1, train_loss))

Epoch: 1 - Loss: 10.105850
Epoch: 2 - Loss: 9.751593
Epoch: 3 - Loss: 10.077042
Wall time: 19min 38s


In [33]:
test_dir = "../dump/clean1/*.jpg"

model.eval()

for f in glob.glob(test_dir):
    image = cv2.resize(mpimg.imread(f), (192, 96))
    image = data_transf(image)
    image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2])
    image = image.cuda()
    out = model(image)
    pred = torch.sigmoid(out)
    if pred[0][0] >= 0.5:
        print(pred[0][0])