In [288]:
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
# from util.image_pool import ImagePool
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
import torchvision

In [289]:
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.InstanceNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.InstanceNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


In [290]:
import torch.nn as nn
import torch.nn.init
import torch as torch
import torch . nn . functional as F

class Generator_Model(torch.nn.Module):
    
    def relu(self,x):
        return nn.ReLU(x)
    def sigmoid(self,x):
        return nn.Sigmoid(x)      
    
    def __init__(self):
        
        super ( Generator_Model , self ). __init__ ()

        self.conv1 = torch.nn.Conv2d(1, 64, 7, stride=1, bias=True)
        torch.nn.init.xavier_normal(self.conv1.weight)
        self.bn1 = torch.nn.InstanceNorm2d(64)

        self.conv2 = torch.nn.Conv2d(64, 128, 3, stride=2, padding=0, bias=True)
        torch.nn.init.xavier_normal(self.conv2.weight)
        self.bn2 = torch.nn.InstanceNorm2d(128)

        self.conv3 = torch.nn.Conv2d(128, 256, 3, stride=2, padding=0, bias=True)
        torch.nn.init.xavier_normal(self.conv3.weight)
        self.bn3 = torch.nn.InstanceNorm2d(256)

        self.resnet10 = torch.nn.Sequential(*[BasicBlock(256,256) for x in range(10)])
        
        self.conv3_2 = torch.nn.ConvTranspose2d(256, 128, 3, stride=2, output_padding=1, bias=True)
        torch.nn.init.xavier_normal(self.conv3_2.weight)
        self.bn3_2 = torch.nn.InstanceNorm2d(128)
        
        self.conv2_1 = torch.nn.ConvTranspose2d(128, 64, 3, stride=2, output_padding=1, bias=True)
        torch.nn.init.xavier_normal(self.conv2_1.weight)
        self.bn2_1 = torch.nn.InstanceNorm2d(64)
        
        self.conv1_0 = torch.nn.ConvTranspose2d(64, 1, 7, stride=1, padding=0, bias=True)
        torch.nn.init.xavier_normal(self.conv1_0.weight)
        
        
    def forward(self,data):
        x = F.relu(self.bn1(self.conv1(data)))
#         print(x.shape)
        x = F.relu(self.bn2(self.conv2(x)))
#         print(x.shape)
        x = F.relu(self.bn3(self.conv3(x)))
#         print(x.shape)
        
        x = self.resnet10(x)
#         print(x.shape)
        x = F.relu(self.bn3_2(self.conv3_2(x)))
#         print(x.shape)
        x = F.relu(self.bn2_1(self.conv2_1(x)))
#         print(x.shape)
        x = F.sigmoid(self.conv1_0(x))
#         print(x.shape)
#         print('end')
        return x

In [291]:
class Discriminator_Model(torch.nn.Module):
    
    def lrelu(self,x):
        return nn.LeakyReLU(x)
    def sigmoid(self,x):
        return nn.Sigmoid(x)      
    
    def __init__(self):
        super ( Discriminator_Model , self ). __init__ ()
        self.conv1 = torch.nn.Conv2d(1, 64, 4, stride=2, padding=1, bias=True)
        torch.nn.init.xavier_normal(self.conv1.weight)

        self.conv2 = torch.nn.Conv2d(64, 256, 4, stride=2, padding=1, bias=True)
        torch.nn.init.xavier_normal(self.conv2.weight)
        self.bn2 = torch.nn.InstanceNorm2d(256)

        self.conv3 = torch.nn.Conv2d(256, 1, 1, stride=1, padding=0, bias=True)
        torch.nn.init.xavier_normal(self.conv3.weight)

    def forward(self,data):
#         print(data.shape)
        x = F.leaky_relu(self.conv1(data))
#         print(x.shape)
        x = F.leaky_relu(self.bn2(self.conv2(x)))
#         print(x.shape)
        x = self.conv3(x)
#         print(x.shape)
        return x

In [292]:
class CycleGAN(torch.nn.Module):
    
    def __init__(self):
        super ( CycleGAN , self ). __init__ ()
        self.G_A_to_B = Generator_Model()
        self.G_B_to_A = Generator_Model()
        self.D_A = Discriminator_Model()
        self.D_B = Discriminator_Model()
        self.D_A_m = Discriminator_Model()
        self.D_B_m = Discriminator_Model()
        
        if torch.cuda.is_available():
            self.G_A_to_B = self.G_A_to_B.cuda()
            self.G_B_to_A = self.G_B_to_A.cuda()
            self.D_A = self.D_A.cuda()
            self.D_B = self.D_B.cuda()
            self.D_A_m = self.D_A_m.cuda()
            self.D_B_m = self.D_B_m.cuda()

        
    def forward(self,x_A,x_B,x_M):
        self.x_A = x_A
        self.x_B = x_B
        self.x_M = x_M
        
        self.x_A_hat = self.G_B_to_A(self.x_B)
        self.x_B_hat = self.G_A_to_B(self.x_A)
        self.x_B_til = self.G_A_to_B(self.x_A_hat)
        self.x_A_til = self.G_B_to_A(self.x_B_hat)        
        
        self._D_A = self.D_A(self.x_A)
        self._D_B = self.D_B(self.x_B)
        self._D_A_hat = self.D_A(self.x_A_hat)
        self._D_B_hat = self.D_B(self.x_B_hat)
        
        self._D_A_m = self.D_A_m(self.x_M)
        self._D_B_m = self.D_B_m(self.x_M)
        self._D_A_m_hat = self.D_A_m(self.x_A_hat)
        self._D_B_m_hat = self.D_B_m(self.x_B_hat)
        
        
    def loss_A_to_B(self):
        return torch.norm(self._D_B_hat.view(-1,16*21)-1,2,dim=1)
        
    def loss_B_to_A(self):
        return torch.norm(self._D_A_hat.view(-1,16*21)-1,2,dim=1)
    
    def loss_C(self):
        return torch.norm((self.x_A_til-self.x_A).view(-1,64*84),1,dim=1)+torch.norm((self.x_B_til-self.x_B).view(-1,64*84),1,dim=1)
    
    def loss_G(self,_lambda):
        return self.loss_A_to_B()+self.loss_B_to_A()+self.loss_C()*_lambda
    
    def loss_DA(self):
        return 0.5*(torch.norm(self._D_A.view(-1,16*21)-1,2,dim=1) + torch.norm(self._D_A_hat.view(-1,16*21),2,dim=1))
    
    def loss_DB(self):
        return 0.5*(torch.norm(self._D_B.view(-1,16*21)-1,2,dim=1) + torch.norm(self._D_B_hat.view(-1,16*21),2,dim=1))
    
    def loss_DA_m(self):
        return 0.5*(torch.norm(self._D_A_m.view(-1,16*21)-1,2,dim=1) + torch.norm(self._D_A_m_hat.view(-1,16*21),2,dim=1))
    
    def loss_DB_m(self):
        return 0.5*(torch.norm(self._D_B_m.view(-1,16*21)-1,2,dim=1) + torch.norm(self._D_B_m_hat.view(-1,16*21),2,dim=1))
    
    def loss_D(self,_gamma):
        return self.loss_DA() + self.loss_DB() + _gamma*(self.loss_DA_m() + self.loss_DB_m())
    
    def total_loss(self,_gamma,_lambda):
        loss = self.loss_D(_gamma)+self.loss_G(_lambda)
        return torch.sum(loss)


In [293]:
classic_test = np.load('classic_test_piano.npy').astype(np.float32)
classic_train = np.load('classic_train_piano.npy').astype(np.float32)
jazz_test = np.load('jazz_test_piano.npy').astype(np.float32)
jazz_train = np.load('jazz_train_piano.npy').astype(np.float32)
pop_test = np.load('pop_test_piano.npy').astype(np.float32)
pop_train = np.load('pop_train_piano.npy').astype(np.float32)

In [294]:
cycleGAN = CycleGAN()

if torch.cuda.is_available():
    cycleGAN = cycleGAN.cuda()

if not torch.cuda.is_available():
    classic_test_tensor = torch.from_numpy(classic_test)
    classic_train_tensor = torch.from_numpy(classic_train)
    jazz_test_tensor = torch.from_numpy(jazz_test)
    jazz_train_tensor = torch.from_numpy(jazz_train)
    pop_test_tensor = torch.from_numpy(pop_test)
    pop_train_tensor = torch.from_numpy(pop_train)

else:
    classic_test_tensor = torch.from_numpy(classic_test).cuda()
    classic_train_tensor = torch.from_numpy(classic_train).cuda()
    jazz_test_tensor = torch.from_numpy(jazz_test).cuda()
    jazz_train_tensor = torch.from_numpy(jazz_train).cuda()
    pop_test_tensor = torch.from_numpy(pop_test).cuda()
    pop_train_tensor = torch.from_numpy(pop_train).cuda()
    
classic_test_var = Variable(classic_test_tensor.permute(0,3,1,2))
classic_train_var = Variable(classic_train_tensor.permute(0,3,1,2))
jazz_test_var = Variable(jazz_test_tensor.permute(0,3,1,2))
jazz_train_var = Variable(jazz_train_tensor.permute(0,3,1,2))    
pop_test_var = Variable(pop_test_tensor.permute(0,3,1,2))
pop_train_var = Variable(pop_train_tensor.permute(0,3,1,2))   
    

In [295]:
x_A = classic_train_var
x_B = pop_train_var
x_M = torch.cat((x_A,x_B),0)
n_M = np.random.permutation(x_M.shape[0])
x_M = x_M[n_M]

cycleGAN = CycleGAN()

lr = .0002
epochs = 30
B = 16
N = min(x_A.shape[0],x_B.shape[0])
NB = (N + B - 1) / B
optimizer = torch . optim . Adam ( cycleGAN.parameters() , lr = 1, betas=(0.5, 0.999))

In [None]:
for epoch in range(epochs):
    
    running_loss = 0.0
    idxminibatches_A = np.random.permutation(NB)
    idxminibatches_B = np.random.permutation(NB)
    idxminibatches_M = np.random.permutation(NB)
    
    for k in range(NB):
        i_A = idxminibatches_A[k]
        i_B = idxminibatches_B[k]
        i_M = idxminibatches_M[k]
        
        idx_A = np.arange(B*i_A, min(B*(i_A+1), N))
        idx_B = np.arange(B*i_B, min(B*(i_B+1), N))
        idx_M = np.arange(B*i_M, min(B*(i_M+1), N))
        
        inputs_A = x_A[idx_A]
        inputs_B = x_B[idx_B]
        inputs_M = x_M[idx_M]
        
        # Initialize the gradients to zero
        optimizer.zero_grad()
        
        # Forward propogation
        outputs = cycleGAN(inputs_A,inputs_B,inputs_M)
        
        # Error evaluation
        loss = cycleGAN.total_loss(1,10)
        # Back propogation
        loss.backward()
        # Optimize step
        optimizer.step()
        # Print statistics
        running_loss += loss[0]
        if k % 100 == 99:
            print('[%d, %5d] loss: %.3f' %
            (epoch + 1, k + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')

[1,   100] loss: 128829.211
[1,   200] loss: 113629.391
[1,   300] loss: 113690.414
[1,   400] loss: 111855.828
[1,   500] loss: 114210.070
[1,   600] loss: 109126.930
[1,   700] loss: 115502.391
[1,   800] loss: 110880.938
[1,   900] loss: 112756.094
[2,   100] loss: 112692.844
[2,   200] loss: 109057.156
[2,   300] loss: 110255.117
[2,   400] loss: 114279.070
[2,   500] loss: 113894.805
[2,   600] loss: 110809.000
[2,   700] loss: 116717.289
[2,   800] loss: 115049.320
[2,   900] loss: 112786.445
[3,   100] loss: 111585.867
[3,   200] loss: 111764.789
[3,   300] loss: 112040.656
[3,   400] loss: 110828.836
[3,   500] loss: 114746.078
[3,   600] loss: 112817.539
[3,   700] loss: 114045.234
[3,   800] loss: 113594.641
[3,   900] loss: 112686.891
[4,   100] loss: 112409.641
[4,   200] loss: 115409.125
[4,   300] loss: 115408.008
[4,   400] loss: 114350.766
[4,   500] loss: 109603.836
[4,   600] loss: 110706.438
[4,   700] loss: 111814.414
[4,   800] loss: 112214.320
[4,   900] loss: 117

In [None]:
classic_train[0][0].shape

In [65]:
print(torch.__version__)


0.3.1.post3


In [74]:
[x[0] for x in cycleGAN.named_parameters() if 'G' not in x[0]]

['D_A.conv1.weight',
 'D_A.conv1.bias',
 'D_A.conv2.weight',
 'D_A.conv2.bias',
 'D_A.conv3.weight',
 'D_A.conv3.bias',
 'D_B.conv1.weight',
 'D_B.conv1.bias',
 'D_B.conv2.weight',
 'D_B.conv2.bias',
 'D_B.conv3.weight',
 'D_B.conv3.bias',
 'D_A_m.conv1.weight',
 'D_A_m.conv1.bias',
 'D_A_m.conv2.weight',
 'D_A_m.conv2.bias',
 'D_A_m.conv3.weight',
 'D_A_m.conv3.bias',
 'D_B_m.conv1.weight',
 'D_B_m.conv1.bias',
 'D_B_m.conv2.weight',
 'D_B_m.conv2.bias',
 'D_B_m.conv3.weight',
 'D_B_m.conv3.bias']

In [69]:
print(list(cycleGAN.named_parameters()))


[('G_A_to_B.conv1.weight', Parameter containing:
(0 ,0 ,.,.) = 
1.00000e-02 *
  2.0263 -1.1158  3.2483  ...  -1.8913  3.1456 -3.2938
  3.7425  2.7359 -1.2952  ...   0.2737 -3.5323  2.9033
  4.8793 -1.8771  2.2856  ...  -1.5156 -0.1007 -3.5183
           ...             ⋱             ...          
 -0.5586 -0.4971  2.2749  ...   0.8657  1.5532 -0.4738
 -3.2061  3.0685  0.4130  ...  -1.0171  0.8625 -1.3910
  0.1185  1.7487  0.6266  ...   2.0194  3.2241 -0.8578
     ⋮ 

(1 ,0 ,.,.) = 
1.00000e-02 *
  1.2109  1.9695  3.5980  ...  -2.1914 -1.2175 -2.6734
  5.1610  1.3584  1.6061  ...  -2.7371 -2.2486  2.2956
 -3.5624 -1.2408 -0.0073  ...   1.6168 -1.7344  2.5738
           ...             ⋱             ...          
  0.4630  0.0799  1.5488  ...  -1.4148  2.1237 -1.9850
  2.7106 -2.8638  0.9619  ...  -2.9043 -4.1627 -3.0951
 -4.5182  0.8214  0.1946  ...   6.5370  1.1103 -3.2793
     ⋮ 

(2 ,0 ,.,.) = 
1.00000e-02 *
  2.2461 -1.7950  0.2517  ...  -2.7830  1.4677  1.1283
 -0.4615  4.8965  2.0

In [70]:
 cycleGAN.state_dict()

OrderedDict([('G_A_to_B.conv1.weight', 
              (0 ,0 ,.,.) = 
              1.00000e-02 *
                2.0263 -1.1158  3.2483  ...  -1.8913  3.1456 -3.2938
                3.7425  2.7359 -1.2952  ...   0.2737 -3.5323  2.9033
                4.8793 -1.8771  2.2856  ...  -1.5156 -0.1007 -3.5183
                         ...             ⋱             ...          
               -0.5586 -0.4971  2.2749  ...   0.8657  1.5532 -0.4738
               -3.2061  3.0685  0.4130  ...  -1.0171  0.8625 -1.3910
                0.1185  1.7487  0.6266  ...   2.0194  3.2241 -0.8578
                   ⋮ 
              
              (1 ,0 ,.,.) = 
              1.00000e-02 *
                1.2109  1.9695  3.5980  ...  -2.1914 -1.2175 -2.6734
                5.1610  1.3584  1.6061  ...  -2.7371 -2.2486  2.2956
               -3.5624 -1.2408 -0.0073  ...   1.6168 -1.7344  2.5738
                         ...             ⋱             ...          
                0.4630  0.0799  1.5488  ...  -1.41