In [3]:
!pip install -r requirements.txt

Collecting wandb~=0.15.5
  Using cached wandb-0.15.5-py3-none-any.whl (2.1 MB)
Collecting numpy~=1.24.3
  Downloading numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
[K     |████████████████████████████████| 17.3 MB 776 kB/s eta 0:00:01
Collecting opencv-python~=4.8.0.74
  Using cached opencv_python-4.8.0.74-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (61.7 MB)
Collecting requests~=2.31.0
  Using cached requests-2.31.0-py3-none-any.whl (62 kB)
Collecting setuptools~=47.1.0
  Downloading setuptools-47.1.1-py3-none-any.whl (583 kB)
[K     |████████████████████████████████| 583 kB 70.3 MB/s eta 0:00:01
Collecting GitPython!=3.1.29,>=1.0.0
  Downloading GitPython-3.1.32-py3-none-any.whl (188 kB)
[K     |████████████████████████████████| 188 kB 26.7 MB/s eta 0:00:01
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.28.0-py2.py3-none-any.whl (213 kB)
[K     |████████████████████████████████| 213 kB 20.4 MB/s eta 0:00:01
[?25hCollecting 

  Attempting uninstall: requests
    Found existing installation: requests 2.27.1
    Uninstalling requests-2.27.1:
      Successfully uninstalled requests-2.27.1
  Attempting uninstall: numpy
    Found existing installation: numpy 1.21.5
    Uninstalling numpy-1.21.5:
      Successfully uninstalled numpy-1.21.5
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
spyder 5.1.5 requires pyqt5<5.13, which is not installed.
spyder 5.1.5 requires pyqtwebengine<5.13, which is not installed.
daal4py 2021.5.0 requires daal==2021.4.0, which is not installed.
conda-repo-cli 1.0.4 requires pathlib, which is not installed.
anaconda-project 0.10.2 requires ruamel-yaml, which is not installed.
spyder 5.1.5 requires setuptools>=49.6.0, but you have setuptools 47.1.1 which is incompatible.
scipy 1.7.3 requires numpy<1.23.0,>=1.16.5, but you have numpy 1.24.4 which is incompat

In [16]:
import json
import os

import cv2
import torch
import wandb
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
import torch.nn as nn
import numpy as np
from torchvision import transforms

# Define the AutoEncoder model 
### Encoder: 2 blocks of conv + relu + pool 
### Decoder: 2 blocks of unpool + transposed conv

In [None]:
class Encoder(nn.Module):
    def __init__(self, in_channels=3, act_fn=nn.ReLU()):
        super().__init__()

        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels, 100, 5, padding=2),
            act_fn, )
        self.pool = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.conv_2 = nn.Sequential(
            nn.Conv2d(100, 200, 5, padding=2),
            act_fn,
        )

    def forward(self, x):
        output = self.conv_1(x)  # (3, 256, 256) -> (100, 256, 256)
        output, indices_1 = self.pool(output)  # (100, 256, 256) -> (100, 128, 128)
        output = self.conv_2(output)  # (100, 128, 128) -> (200, 128, 128)
        output, indices_2 = self.pool(output)  # (200, 128, 128) -> (200, 64, 64)
        return output, indices_1, indices_2


#  defining decoder
class Decoder(nn.Module):
    def __init__(self, act_fn=nn.ReLU()):
        super().__init__()
        self.unpool = nn.MaxUnpool2d(2, stride=2)
        self.deconv_1 = nn.Sequential(
            nn.ConvTranspose2d(200, 100, 5, padding=2),
            act_fn, )
        self.deconv_2 = nn.Sequential(
            nn.ConvTranspose2d(100, 3, 5, padding=2),
            nn.Sigmoid()
        )

    def forward(self, x, indices_1, indices_2):
        output = self.unpool(x, indices_1)  # (200, 64, 64) -> (200, 128, 128)
        output = self.deconv_1(output)  # (200, 128, 128) -> (100, 128, 128)
        output = self.unpool(output, indices_2) # (100, 128, 128) -> (200, 256, 256)
        output = self.deconv_2(output) # (100, 256, 256) -> (3, 256, 256)
        return output


#  defining autoencoder
class Autoencoder(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.encoder.to(device)

        self.decoder = decoder
        self.decoder.to(device)

    def forward(self, x):
        encoded, indices_2, indices_1 = self.encoder(x) # returning the pooling indices as well
        decoded = self.decoder(encoded, indices_1, indices_2)
        return decoded


# Define the final classifier model

It first used the pretrained encoder to extract the features of the the image, then use fully connected layers for final classification

In [6]:
class CNNClf(nn.Module):
    def __init__(self, encoder,  device, act_fn=nn.ReLU()):
        super().__init__()
        self.encoder = encoder
        for param in self.encoder.parameters():  # freeze the feature extractor
            param.requires_grad = False

        self.encoder.to(device)

        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(200*64*64, 400),
            act_fn,
            nn.Linear(400, 200),
            act_fn,
            nn.Linear(200, 3),
        )


# Define the two datasets

In [7]:
class ImageDataset(Dataset):

    def __init__(self, image_dir, device, preprocess, test=False):
        self.image_dir = image_dir
        self.test = test
        self.device = device
        self.preprocess = preprocess

    def __len__(self):
        return len(os.listdir(self.image_dir))

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, "image_" + str(idx) + ".jpg")
        image = read_image(image_path)
        image = torch.tensor(image)
        image = image[:3, :, :]

        # Works with 3 channels
        mask = np.random.choice([0, 1], size=(256, 256), p=[.2, .8]).astype(np.uint8)
        noise = cv2.bitwise_and(image.permute(1, 2, 0).numpy(), image.permute(1, 2, 0).numpy(), mask=mask)
        noise = torch.tensor(noise)
        noise = noise.permute(2, 0, 1)
        noise = noise[:3, :, :]

        image = self.preprocess(image)
        noise = self.preprocess(noise)


        return noise.to(self.device), image.to(
            self.device)


In [8]:
class ClfImageDataset(Dataset):

    def __init__(self, image_dir, labels, device, preprocess, test=False):
        self.image_dir = image_dir
        self.labels = labels
        self.test = test
        self.device = device
        self.preprocess = preprocess

    def __len__(self):
        return len(os.listdir(self.image_dir))

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, "image_" + str(idx) + ".jpg")
        image = read_image(image_path)
        image = image[:3, :, :]  # Works with 3 channels
        image = self.preprocess(image)

        image = torch.tensor(image)
        label = torch.tensor(self.labels[idx, :])
        return image.to(self.device), label.to(self.device)


# Start Training Convolutional AutoEncoder

In [9]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, save_path, patience=5, verbose=False, delta=0, save=True):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.save_path = save_path
        self.save = save

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            if self.save:
                self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            if self.save:
                self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.save_path)	# save the current best model
        self.val_loss_min = val_loss


In [13]:
def train_cae():
    image_dir = "./data/train_cae/image/"
    noise_dir = "./data/train_cae/noise/"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Preprocess the image to normalize the pixel values
    preprocess = transforms.Compose(
        [
            transforms.ConvertImageDtype(torch.float32),
        ]
    )

    # Randomly select 5500 images and discard the rest
    full_dataset = ImageDataset(image_dir, device=device, preprocess=preprocess)
    use_size = 5500
    rest_size = len(full_dataset) - use_size
    use_dataset, _ = torch.utils.data.random_split(full_dataset, [use_size, rest_size])

    # Split into train and validation dataset
    train_size = 5000
    val_size = 500
    train_dataset, val_dataset = torch.utils.data.random_split(use_dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=32)
    valid_loader = DataLoader(val_dataset, batch_size=32)


    model = Autoencoder(Encoder(), Decoder(), device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    n_epoch = 30
    step = 0

    # Check and specify the checkpoint path
    ckpt_path = "/cluster/scratch/zhiychen/DeepPainter/checkpoint" if torch.cuda.is_available() else './checkpoint'
    os.makedirs(ckpt_path, exist_ok=True)
    save_filename = '%s_%s.pth' % (n_epoch, "cae")
    save_path = os.path.join(ckpt_path, save_filename)
    early_stopping = EarlyStopping(save_path=save_path, patience=5, verbose=False, delta=0)

    # Training
    for epoch in range(1, n_epoch + 1):
        train_loss = 0.0
        for data in train_loader:
            input, label = data
            optimizer.zero_grad()
            # temp_1 = input[0, :, :, :].permute(1, 2, 0).detach().numpy()
            # cv2.imshow("image", temp_1)
            # cv2.waitKey(0)

            output = model(input)

            # temp_2 = output[0, :, :, :].permute(1, 2, 0).detach().numpy()
            # cv2.imshow("image", temp_2)
            # cv2.waitKey(0)

            loss = criterion(output, label)

            loss.backward()
            optimizer.step()
            step += 1
            train_loss += loss.item()
            # logging
            print({"Train_loss": loss.item()})
        print({"Total_Train_loss": train_loss})

        # Validation after each epoch
        total_loss, batch_count = 0, 0
        model.eval()
        with torch.no_grad():
            for batch in valid_loader:
                input, label = batch
                output = model(input)
                loss = criterion(output, label)
                batch_count += len(input)
                total_loss += loss.item() * len(input)
        valid_loss = total_loss / batch_count
        print({"Valid_loss": valid_loss, "Epochs": epoch})

        # for early stopping
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

In [None]:
train_cae()

  image = torch.tensor(image)
