In [None]:
import os
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from torch.autograd import Variable
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import models, datasets, transforms

In [None]:
class U_Net_Encoder(nn.Module):
    """AlexNet-style encoder that exposes every intermediate layer output.

    The layer geometry follows AlexNet: a (B, 3, 227, 227) input is reduced to a
    (B, 256, 6, 6) map before the fully connected head. ``fc6`` hard-codes
    ``256*6*6`` input features, so other input sizes fail there.

    ``forward`` returns the *pre-activation* outputs of conv1-conv5 and fc6-fc8,
    in that order, so that a decoder can use them as skip connections.
    """

    def __init__(self):
        super().__init__()

        # Registration order is deliberately kept stable: it fixes the
        # state_dict keys, the parameters() order and seeded initialisation.
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0)
        self.pool1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75),
        )
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2, groups=2)
        self.pool2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75),
        )
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1, groups=2)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1, groups=2)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2)

        # fc6 flattens the (256, 6, 6) map first; being a Sequential, its
        # weights are stored under the "fc6.1.*" keys.
        self.fc6 = nn.Sequential(nn.Flatten(), nn.Linear(256 * 6 * 6, 4096))
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, 1000)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Encode ``x`` and return all intermediate pre-ReLU activations.

        Returns:
            tuple ``(conv1, conv2, conv3, conv4, conv5, fc6, fc7, fc8)``.
            Shapes in the comments below assume a (B, 3, 227, 227) input.
        """
        # Each stage: conv -> record pre-activation -> ReLU -> optional pooling.
        conv_stages = (
            (self.conv1, self.pool1),   # (B, 96, 55, 55)  -> pooled (B, 96, 27, 27)
            (self.conv2, self.pool2),   # (B, 256, 27, 27) -> pooled (B, 256, 13, 13)
            (self.conv3, None),         # (B, 384, 13, 13)
            (self.conv4, None),         # (B, 384, 13, 13)
            (self.conv5, self.pool3),   # (B, 256, 13, 13) -> pooled (B, 256, 6, 6)
        )

        outputs = []
        hidden = x
        for conv, pool in conv_stages:
            pre_act = conv(hidden)
            outputs.append(pre_act)
            hidden = self.relu(pre_act)
            if pool is not None:
                hidden = pool(hidden)

        # Fully connected head: fc6 and fc7 are followed by ReLU + dropout,
        # fc8 is returned without any activation.
        for fc in (self.fc6, self.fc7):    # (B, 4096) each
            pre_act = fc(hidden)
            outputs.append(pre_act)
            hidden = self.dropout(self.relu(pre_act))
        outputs.append(self.fc8(hidden))   # (B, 1000)

        return tuple(outputs)


In [None]:
class U_Net_Decoder(nn.Module):
    """Mirror of U_Net_Encoder that reconstructs the input image.

    ``forward`` takes the eight tensors produced by the encoder, in the order
    ``(conv1, conv2, conv3, conv4, conv5, fc6, fc7, fc8)``. It upsamples back to
    a (B, 3, 227, 227) image squashed into [0, 1] by a sigmoid. Encoder
    features are concatenated along dim 1 as skip connections. This is why
    several layers take twice the channels of the preceding layer's output.
    """

    def __init__(self):
        super().__init__()

        # Fully connected head (reverse of fc8 -> fc7 -> fc6).
        self.rfc8 = nn.Linear(1000, 4096)
        self.rfc7 = nn.Linear(4096 * 2, 4096)          # input: rfc8 output ++ fc7
        self.rfc6 = nn.Linear(4096 * 2, 256 * 6 * 6)   # input: rfc7 output ++ fc6

        # Convolutional part, deepest first. The stride-2, kernel-3 transposed
        # convolutions invert the 3x3/2 max-pools: 6 -> 13, 13 -> 27, 27 -> 55.
        self.rpool3 = nn.ConvTranspose2d(256, 256, kernel_size=3, stride=2)
        self.rconv5 = nn.ConvTranspose2d(256 + 256, 384, kernel_size=3, stride=1, padding=1, groups=2)
        self.rconv4 = nn.ConvTranspose2d(384 + 384, 384, kernel_size=3, stride=1, padding=1, groups=2)
        self.rconv3 = nn.ConvTranspose2d(384 + 384, 256, kernel_size=3, stride=1, padding=1)
        self.rpool2 = nn.ConvTranspose2d(256, 256, kernel_size=3, stride=2)
        self.rconv2 = nn.ConvTranspose2d(256 + 256, 96, kernel_size=5, stride=1, padding=2, groups=2)
        self.rpool1 = nn.ConvTranspose2d(96, 96, kernel_size=3, stride=2)
        # 55 -> (55 - 1) * 4 + 11 = 227, inverting the encoder's conv1.
        self.rconv1 = nn.ConvTranspose2d(96 + 96, 3, kernel_size=11, stride=4, padding=0)

        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()
        self.loss_func = nn.MSELoss()
        self.sigmoid = nn.Sigmoid()

    def forward(self, conv1, conv2, conv3, conv4, conv5, fc6, fc7, fc8):
        """Reconstruct an image from the encoder's intermediate outputs."""
        # Dense part: Linear -> concat skip -> ReLU -> dropout.
        h = self.rfc8(fc8)
        h = self.dropout(self.relu(torch.cat([h, fc7], dim=1)))   # (B, 8192)
        h = self.rfc7(h)
        h = self.dropout(self.relu(torch.cat([h, fc6], dim=1)))   # (B, 8192)
        h = self.dropout(self.relu(self.rfc6(h)))                 # (B, 256*6*6)
        h = h.view(-1, 256, 6, 6)                                 # (B, 256, 6, 6)

        # 13x13 stage.
        h = torch.cat([self.rpool3(h), conv5], dim=1)             # (B, 512, 13, 13)
        h = self.relu(torch.cat([self.rconv5(h), conv4], dim=1))  # (B, 768, 13, 13)
        h = self.relu(torch.cat([self.rconv4(h), conv3], dim=1))  # (B, 768, 13, 13)
        h = self.relu(self.rconv3(h))                             # (B, 256, 13, 13)

        # 27x27 stage.
        h = torch.cat([self.rpool2(h), conv2], dim=1)             # (B, 512, 27, 27)
        h = self.relu(self.rconv2(h))                             # (B, 96, 27, 27)

        # 55x55 stage, then back to image resolution.
        h = torch.cat([self.rpool1(h), conv1], dim=1)             # (B, 192, 55, 55)
        return self.sigmoid(self.rconv1(h))                       # (B, 3, 227, 227)

    def loss(self, x, x_recon):
        """Mean squared error between the original ``x`` and its reconstruction ``x_recon``."""
        # nn.MSELoss takes (input, target). Squared error is symmetric, so giving
        # the reconstruction as `input` yields the same value and gradients.
        return self.loss_func(x_recon, x)


In [None]:
# images
# Load the pre-resized image array (channels-last; the next cell permutes it
# from (N, H, W, C) to (N, C, H, W)).

image_dir = '/home/shunosuga/data/img_npy'   # NOTE(review): absolute, machine-specific path

images = np.load(os.path.join(image_dir, 'img_npy_227.npy'))
print(images.dtype, flush=True)
print(images.shape, flush=True)

# Fail fast with a readable message: the encoder's fc6 hard-codes a 256*6*6
# input, which only 227x227 images produce, and the next cell assumes NHWC RGB.
assert images.ndim == 4 and tuple(images.shape[1:]) == (227, 227, 3), \
    'expected images of shape (N, 227, 227, 3), got {}'.format(images.shape)

In [None]:
BATCH_SIZE = 128
N_EPOCHS = 100

# Convert the (N, H, W, C) array to a float32 (N, C, H, W) tensor in [0, 1]
# (assumes 8-bit pixel values 0-255; the dtype is printed in the previous cell).
# torch.tensor(..., dtype=float32) always copies, so the in-place div_ can never
# modify `images`. Dividing in place avoids materialising a second full-size
# float copy of the dataset, which `img_input / 255` would allocate.
img_input = torch.tensor(images, dtype=torch.float32).div_(255)
img_input = img_input.permute(0, 3, 1, 2)

# Hold out the first 50 images for testing; train on the rest.
train_input = img_input[50:]
test_input = img_input[:50]

train_dataset = TensorDataset(train_input)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = TensorDataset(test_input)
test_loader = DataLoader(test_dataset, batch_size=50, shuffle=False)

print(train_input.size(), test_input.size(), flush=True)    # (N - 50, 3, 227, 227), (50, 3, 227, 227)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Do NOT call torch.set_default_tensor_type('torch.cuda.FloatTensor') here:
# - it fails on CPU-only machines even though `device` falls back to CPU;
# - it is deprecated;
# - a CUDA default device conflicts with the CPU torch.Generator that
#   DataLoader(shuffle=True) uses to draw permutations.
# The models are moved with .to(device) below and batches in the training loop.

n_epochs = 10   # NOTE(review): unused -- the training loop iterates N_EPOCHS; reconcile the two.
lr = 1e-3
decay_lr = 0.75   # ExponentialLR gamma: lr is multiplied by this on every lr_.step()

encoder = U_Net_Encoder().to(device)
decoder = U_Net_Decoder().to(device)

# Encoder and decoder are optimised jointly (two param groups, same lr).
optimizer = Adam(params=[
    {"params": encoder.parameters()},
    {"params": decoder.parameters()},
], lr=lr)

lr_ = ExponentialLR(optimizer, gamma=decay_lr)

In [None]:
# Train encoder + decoder jointly to reconstruct their input (MSE loss).
# NOTE(review): this runs N_EPOCHS (=100) epochs while the previous cell defines an
# unused n_epochs = 10. With ExponentialLR(gamma=0.75) stepped every epoch, the lr
# falls to ~1e-3 * 0.75**100 ~= 3e-16, so later epochs barely train. Confirm the
# intended epoch count.
losses = []  # per-batch training loss as Python floats

# Dropout must be active while training; set it explicitly in case another cell
# switched the models to eval mode.
encoder.train()
decoder.train()

for i in range(N_EPOCHS):

    for j, (x,) in enumerate(train_loader):

        x_in = x.to(device)   # plain tensor; inputs never require grad

        conv1, conv2, conv3, conv4, conv5, fc6, fc7, fc8 = encoder(x_in)

        x_out = decoder(conv1, conv2, conv3, conv4, conv5, fc6, fc7, fc8)

        # Reuse the existing decoder's loss. Building a new U_Net_Decoder here would
        # allocate and initialise >100M parameters on every batch just to get MSE.
        loss = decoder.loss(x_in, x_out)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_n = loss.item()

        losses.append(loss_n)

        print('epoch:{}'.format(i), 'batch:{}'.format(j), 'loss:{}'.format(loss_n))

    lr_.step()

In [None]:
# TODO: save the trained encoder and decoder (e.g. their state_dicts) in /home/shunosuga/data/model/u_net