In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data

import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable

import matplotlib.pyplot as plt
%matplotlib inline
import pickle
import os
import json
import numpy as np

# Dataloader

In [9]:
clean_dir = '/Users/Terry/Documents/Research/test_block/' 
mix_dir = '/Users/Terry/Documents/Research/test_block/' 
clean_label_dir = '/Users/Terry/Documents/Research/test_block/' 
mix_label_dir = '/Users/Terry/Documents/Research/test_block/' 

# cleanfolder = os.listdir(clean_dir)
# cleanfolder.sort()

# mixfolder = os.listdir(mix_dir)
# mixfolder.sort()

# cleanlabelfolder = os.listdir(clean_label_dir)
# cleanlabelfolder.sort()

# mixlabelfolder = os.listdir(mix_label_dir)
# mixlabelfolder.sort()

# clean_list = []
# mix_list = []
# clean_label_list = []
# mix_label_list = []

In [10]:
class MSourceDataSet(Dataset):
    
    def __init__(self, clean_dir, mix_dir, clean_label_dir, mix_label_dir):
        
        with open(clean_dir + 'clean0.json') as f:
            clean0 = torch.Tensor(json.load(f))
        
#         with open(clean_label_dir + 'clean_label0.json') as f:
#             cleanlabel0 = torch.Tensor(json.load(f))
            
        self.spec = clean0
#         self.label = cleanlabel0
            
        
#         for i in cleanfolder:
#             with open(clean_dir + '{}'.format(i)) as f:
#                 clean_list.append(torch.Tensor(json.load(f)))

#         for i in mixfolder:
#             with open(mix_dir + '{}'.format(i)) as f:
#                 mix_list.append(torch.Tensor(json.load(f)))
                
#         for i in cleanlabelfolder:
#             with open(clean_label_dir + '{}'.format(i)) as f:
#                 clean_label_list.append(torch.Tensor(json.load(f)))

#         for i in mixlabelfolder:
#             with open(mix_label_dir + '{}'.format(i)) as f:
#                 mix_label_list.append(torch.Tensor(json.load(f)))
        
#         cleanblock = torch.cat(clean_list, 0)
#         mixblock = torch.cat(mix_list, 0)
#         self.spec = torch.cat([cleanblock, mixblock], 0)
                
#         cleanlabel = torch.cat(clean_label_list, 0)
#         mixlabel = torch.cat(mix_label_list, 0)
#         self.label = torch.cat([cleanlabel, mixlabel], 0)

        
    def __len__(self):
        return self.spec.shape[0]

                
    def __getitem__(self, index): 

        spec = self.spec[index]
        return spec

In [11]:
trainset = MSourceDataSet(clean_dir, mix_dir, clean_label_dir, mix_label_dir)

trainloader = torch.utils.data.DataLoader(dataset = trainset,
                                                batch_size = 4,
                                                shuffle = True)

In [14]:
a = trainset.__len__()
print (a)

1000


# Model

In [15]:
''' ResBlock '''
class ResBlock(nn.Module):
    def __init__(self, channels_in, channels_out):
        super(ResBlock, self).__init__()

        self.channels_in = channels_in
        self.channels_out = channels_out

        self.conv1 = nn.Conv2d(in_channels=channels_in, out_channels=channels_out, kernel_size=(3,3), padding=1)
        self.conv2 = nn.Conv2d(in_channels=channels_out, out_channels=channels_out, kernel_size=(3,3), padding=1)

    def forward(self, x):
        if self.channels_out > self.channels_in:
            x1 = F.relu(self.conv1(x))
            x1 =        self.conv2(x1)
            x  = self.sizematch(self.channels_in, self.channels_out, x)
            return x + x1
        elif self.channels_out < self.channels_in:
            x = F.relu(self.conv1(x))
            x1 =       self.conv2(x)
            x = x + x1
            return x
        else:
            x1 = F.relu(self.conv1(x))
            x1 =        self.conv2(x1)
            x = x + x1
            return x

    def sizematch(self, channels_in, channels_out, x):
        zeros = torch.zeros( (x.size()[0], channels_out - channels_in, x.shape[2], x.shape[3]), dtype = torch.float32)
        return torch.cat((x, zeros), dim=1)

class ResTranspose(nn.Module):
    def __init__(self, channels_in, channels_out):
        super(ResTranspose, self).__init__()

        self.channels_in = channels_in
        self.channels_out = channels_out

        self.deconv1 = nn.ConvTranspose2d(in_channels=channels_in, out_channels=channels_out, kernel_size=(2,2), stride=2)
        self.deconv2 = nn.Conv2d(in_channels=channels_out, out_channels=channels_out, kernel_size=(3,3), padding=1)

    def forward(self, x):
        # cin = cout
        x1 = F.relu(self.deconv1(x))
        x1 =        self.deconv2(x1)
        x = self.sizematch(x)
        return x + x1

    def sizematch(self, x):
        # expand
        x2 = torch.zeros(x.shape[0], self.channels_in, x.shape[2]*2, x.shape[3]*2)

        row_x  = torch.zeros(x.shape[0], self.channels_in, x.shape[2], 2*x.shape[3])
        row_x[:,:,:,odd(x.shape[3]*2)]   = x
        row_x[:,:,:,even(x.shape[3]*2)]  = x
        x2[:,:, odd(x.shape[2]*2),:] = row_x
        x2[:,:,even(x.shape[2]*2),:] = row_x

        return x2


def initialize(m):
    if isinstance(m, nn.Conv2d):
        init.xavier_normal_(m.weight)
        init.constant_(m.bias, 0)
    if isinstance(m, nn.ConvTranspose2d):
        init.xavier_normal_(m.weight)



class ResDAE(nn.Module):
    def __init__(self):
        super(ResDAE, self).__init__()

        # 128x128x1

        self.upward_net1 = nn.Sequential(
            ResBlock(1, 8),
            ResBlock(8, 8),
            ResBlock(8, 8),
            nn.BatchNorm2d(8),
        )

        # 64x64x8

        self.upward_net2 = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(2,2), stride=2),
            nn.ReLU(),
            ResBlock(8, 8),
            ResBlock(8, 16),
            ResBlock(16, 16),
            nn.BatchNorm2d(16),
        )

        # 32x32x16

        self.upward_net3 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(2,2), stride=2),
            nn.ReLU(),
            ResBlock(16, 16),
            ResBlock(16, 32),
            ResBlock(32, 32),
            nn.BatchNorm2d(32),
        )

        # 16x16x32

        self.upward_net4 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(2,2), stride=2),
            nn.ReLU(),
            ResBlock(32, 32),
            ResBlock(32, 64),
            ResBlock(64, 64),
            nn.BatchNorm2d(64),
        )

        # 8x8x64

        self.upward_net5 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(2,2), stride=2),
            nn.ReLU(),
            ResBlock(64, 64),
            ResBlock(64, 128),
            ResBlock(128, 128),
            nn.BatchNorm2d(128),
        )

        # 4x4x128

        self.upward_net6 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(2,2), stride=2),
            nn.ReLU(),
            ResBlock(128, 128),
            ResBlock(128, 256),
            ResBlock(256, 256),
            nn.BatchNorm2d(256),
        )

        # 2x2x256

        self.upward_net7 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(2,2), stride=2),
            nn.ReLU(),
            ResBlock(256, 256),
            ResBlock(256, 512),
            ResBlock(512, 512),
            nn.BatchNorm2d(512),
        )

        # 1x1x512
        self.downward_net7 = nn.Sequential(
            ResBlock(512, 512),
            ResBlock(512, 256),
            ResBlock(256, 256),
            ResTranspose(256, 256),
            nn.BatchNorm2d(256),
        )

        # 2x2x256

        self.downward_net6 = nn.Sequential(
            # 8x8x64
            ResBlock(256, 256),
            ResBlock(256, 128),
            ResBlock(128, 128),
            ResTranspose(128, 128),
            nn.BatchNorm2d(128),
        )

        # 4x4x128
        # cat -> 4x4x256
        self.uconv5 = nn.Conv2d(256, 128, kernel_size=(3,3), padding=(1,1))
        # 4x4x128
        self.downward_net5 = nn.Sequential(
            ResBlock(128, 128),
            ResBlock(128, 64),
            ResBlock(64, 64),
            ResTranspose(64, 64),
            nn.BatchNorm2d(64),
        )

        # 8x8x64
        # cat -> 8x8x128
        self.uconv4 = nn.Conv2d(128, 64, kernel_size=(3,3), padding=(1,1))
        # 8x8x64
        self.downward_net4 = nn.Sequential(
            ResBlock(64, 64),
            ResBlock(64, 32),
            ResBlock(32, 32),
            ResTranspose(32, 32),
            nn.BatchNorm2d(32),
        )

        # 16x16x32
        # cat -> 16x16x64
        self.uconv3 = nn.Conv2d(64, 32, kernel_size=(3,3), padding=(1,1))
        # 16x16x32
        self.downward_net3 = nn.Sequential(
            ResBlock(32, 32),
            ResBlock(32, 16),
            ResBlock(16, 16),
            ResTranspose(16, 16),
            nn.BatchNorm2d(16),
        )

        # 32x32x16
        # cat -> 32x32x32
        self.uconv2 = nn.Conv2d(32, 16, kernel_size=(3,3), padding=(1,1))
        # 32x32x16
        self.downward_net2 = nn.Sequential(
            ResBlock(16, 16),
            ResBlock(16, 8),
            ResBlock(8, 8),
            ResTranspose(8, 8),
            nn.BatchNorm2d(8),
        )

        # 64x64x8
        self.downward_net1 = nn.Sequential(
            ResBlock(8, 8),
            ResBlock(8, 4),
            ResBlock(4, 1),
            ResBlock(1, 1),
            ResTranspose(1, 1),
            nn.BatchNorm2d(1),
        )

        # 128x128x1


    def upward(self, x, a7=None, a6=None, a5=None, a4=None, a3=None, a2=None):
        x = x.view(bs, 1, 1025, 16)
        # 1x128x128

        x = self.upward_net1(x)
        # 8x64x64

        x = self.upward_net2(x)
        if a2 is not None: x = x * a2
        self.x2 = x
        # 16x32x32

        x = self.upward_net3(x)
        if a3 is not None: x = x * a3
        self.x3 = x
        # 32x16x16

        x = self.upward_net4(x)
        if a4 is not None: x = x * a4
        self.x4 = x
        # 64x8x8

        x = self.upward_net5(x)
        if a5 is not None: x = x * a5
        self.x5 = x
        # 128x4x4

        x = self.upward_net6(x)
        if a6 is not None: x = x * a6
        # 256x2x2

        x = self.upward_net7(x)
        if a7 is not None: x = x * a7
        # 512x1x1

        return x


    def downward(self, y, shortcut=True):

        # 512x1x1
        y = self.downward_net7(y)

        # 256x2x2
        y = self.downward_net6(y)

        # 128x4x4
        if shortcut:
            y = torch.cat((y, self.x5), 1)
            y = F.relu(self.uconv5(y))
        y = self.downward_net5(y)

        # 64x8x8
        if shortcut:
            y = torch.cat((y, self.x4), 1)
            y = F.relu(self.uconv4(y))
        y = self.downward_net4(y)
        
        # 32x16x16
        if shortcut:
            y = torch.cat((y, self.x3), 1)
            y = F.relu(self.uconv3(y))
        y = self.downward_net3(y)

        # 16x32x32
        if shortcut:
            y = torch.cat((y, self.x2), 1)
            y = F.relu(self.uconv2(y))
        y = self.downward_net2(y)

        # 8x64x64
        y = self.downward_net1(y)
        
        # 1x128x128

        return y

model = ResDAE()
print (model)

ResDAE(
  (upward_net1): Sequential(
    (0): ResBlock(
      (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): ResBlock(
      (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (2): ResBlock(
      (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True)
  )
  (upward_net2): Sequential(
    (0): Conv2d(8, 16, kernel_size=(2, 2), stride=(2, 2))
    (1): ReLU()
    (2): ResBlock(
      (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (3): ResBlock(
      (conv1): Conv2d(8, 16, kernel_size=(3

## Optimizer

In [16]:
epoch = 1
lr = 0.005
mom = 0.005
bs = 4

In [17]:
criterion = nn.MSELoss(size_average = True)
optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum = mom)

# Training

In [18]:
## 這部分還沒加進去

def odd(w):
    return list(np.arange(1, w, step=2, dtype='long'))

def even(w):
    return list(np.arange(0, w, step=2, dtype='long'))

def white(x):
    fw, tw = x.shape[1], x.shape[2]

    first = F.relu(torch.normal(mean=torch.zeros(fw, tw), std=torch.ones(fw, tw)) ) * 0.05
    second_seed = F.relu(torch.normal(mean=torch.zeros(fw//2, tw//2), std=torch.ones(fw//2, tw//2))) * 0.03
    second = torch.zeros(fw, tw)

    row_x  = torch.zeros(int(fw//2), tw)
    # row_x = torch.zeros(int(fw/2), tw)

    row_x[:, odd(tw)]  = second_seed
    row_x[:, even(tw)] = second_seed

    second[odd(fw), :]  = row_x
    second[even(fw), :] = row_x

    return second + first

In [24]:
loss_record = []
every_loss = []
epoch_loss = []

model.train()
for epo in range(epoch):
    for i, data in enumerate(trainloader, 0):
        inputs = data
        inputs = Variable(inputs)
        print (inputs.shape)
        optimizer.zero_grad()

        top = model.upward(inputs + white(inputs))
        outputs = model.downward(top, shortcut = True)

        
        loss = criterion(inputs, outputs)
        loss.backward()
        optimizer.step()

        loss_record.append(loss.item())
        every_loss.append(loss.item())
        
        if i % 10 == 0:
            print ('[%d, %5d] loss: %.3f' % (epo, i, loss.item()))
        
    epoch_loss.append(np.mean(every_loss))
    every_loss = []
        
        
# # forward
# top = model.upward(source + white(source))
# recover = model.downward(top).view(end - block_pos, 128, 128)

# # loss-bp
# loss = criterion(optouts, recover)
# loss.backward()


# # step
# optimizer.step()
# optimizer.zero_grad()


torch.Size([4, 1025, 16])


TypeError: torch.normal received an invalid combination of arguments - got (std=torch.FloatTensor, mean=torch.FloatTensor, ), but expected one of:
 * (torch.FloatTensor std)
 * (torch.FloatTensor means)
 * (float mean, torch.FloatTensor std)
      didn't match because some of the arguments have invalid types: ([31;1mmean=torch.FloatTensor[0m, [32;1mstd=torch.FloatTensor[0m, )
 * (torch.FloatTensor means, float std)
      didn't match because some of the keywords were incorrect: mean
 * (torch.FloatTensor means, torch.FloatTensor std)
      didn't match because some of the keywords were incorrect: mean
 * (torch.Generator generator, torch.FloatTensor std)
      didn't match because some of the keywords were incorrect: mean
 * (torch.Generator generator, torch.FloatTensor means)
      didn't match because some of the keywords were incorrect: std, mean
 * (torch.Generator generator, float mean, torch.FloatTensor std)
 * (torch.Generator generator, torch.FloatTensor means, float std)
 * (torch.Generator generator, torch.FloatTensor means, torch.FloatTensor std)
