In [7]:
import os

import muspy
import numpy as np
import torch
from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer import Trainer
from torch import nn, optim, autograd
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader

In [8]:
# Model / data dimensions.
#   ch : base channel width of the first convolution in both networks.
#   T  : number of columns of one generated bar (R x T tensor); the sampling
#        cell transposes to (time, pitch) for muspy, so T is the time axis.
#   R  : number of rows of one bar (the axis the 128-tall kernels span).
#   Z  : spatial (height, width) of the generator's latent noise map.
ch, T, R, Z = 64, 16, 128, (1, 3)
# Channel multiples as plain Python ints. Multiplying through a numpy array
# (as before) yields numpy int64 scalars, which then get stored as layer
# sizes (e.g. in_channels) on every module built from them.
ch2, ch3, ch4, ch6, ch8, ch16, ch32 = (ch * k for k in (2, 3, 4, 6, 8, 16, 32))
device = "cuda" if torch.cuda.is_available() else "cpu"
# Training hyper-parameters.
#   gStep / dStep : how many optimizer slots the generator / discriminator
#                   get per batch (see GAN.configure_optimizers).
#   smooth        : one-sided label smoothing; real targets are 1 - smooth.
#   L1            : weight of the reconstruction term ||x - G(px)||_2.
#   L2            : weight of the feature-matching term on D's first conv block.
gStep, dStep, smooth, L1, L2 = 2, 1, 0.05, 0.1, 0.01

class Discriminator(nn.Module):
    """Convolutional discriminator for single-channel bars.

    The layer sizes fix the accepted input to (B, 1, 128, 16): the first
    kernel spans all 128 rows, and the final linear layer expects exactly
    three positions along the width after the second convolution.

    ``forward`` returns ``(logits, c1)`` where ``logits`` has shape ``(B,)``
    (raw scores, no sigmoid applied) and ``c1`` is the activation of the
    first conv block, shape ``(B, ch, 1, 8)``.
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            # (B, 1, 128, 16) -> (B, ch, 1, 8)
            nn.Conv2d(1, ch, (128, 2), 2),
            nn.BatchNorm2d(ch),
            nn.LeakyReLU(0.2, inplace = True)
        )
        self.conv2 = nn.Sequential(
            # (B, ch, 1, 8) -> (B, ch2, 1, 3)
            nn.Conv2d(ch, ch2, (1, 4), 2),
            nn.BatchNorm2d(ch2),
            nn.LeakyReLU(0.2, inplace = True),
        )
        # Flattened (ch2, 1, 3) feature map -> one logit per sample.
        self.nn = nn.Linear(ch2 * 3, 1)
    def forward(self, x):
        batch_size = x.shape[0]
        c1 = self.conv1(x)
        x = self.conv2(c1)
        x = x.view(batch_size, -1)
        # Squeeze only the feature axis: a bare .squeeze() would also drop the
        # batch axis when batch_size == 1 and return a 0-dim tensor.
        return self.nn(x).squeeze(1), c1

class Generator(nn.Module):
    """Encoder/decoder generator conditioned on a previous bar.

    The input bar is encoded by two strided convolutions; Gaussian noise is
    concatenated with the bottleneck, and two transposed convolutions decode
    back to the input resolution with skip connections from the encoder.
    The final sigmoid keeps every output value in (0, 1).

    For a (B, 1, 128, 16) input the output is (B, 1, 128, 16).
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, ch, (128, 2), 2),
            nn.BatchNorm2d(ch),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(ch, ch2, (1, 4), 2),
            nn.BatchNorm2d(ch2),
            nn.ReLU()
        )
        # Input channels: ch2 (noise) + ch2 (bottleneck) = ch4.
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(ch4, ch2, (1, 4), 2),
            nn.BatchNorm2d(ch2),
            nn.ReLU()
        )
        # Input channels: ch2 (deconv1 output) + ch (skip from conv1) = ch3.
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose2d(ch3, 1, (128, 2), 2),
            nn.Sigmoid()
        )
    def forward(self, x):
        c1 = self.conv1(x)   # (B, ch, 1, 8)
        c2 = self.conv2(c1)  # (B, ch2, 1, 3)
        # Noise with mean 0.5 and std 0.5, shaped like the bottleneck. Using
        # randn_like keeps it on the same device/dtype as the activations
        # rather than on the notebook-level `device`, and follows the
        # bottleneck's spatial size instead of a hard-coded one.
        z = (torch.randn_like(c2) + 1) / 2
        x = self.deconv1(torch.cat((z, c2), dim = 1))  # (B, ch4, 1, 3) -> (B, ch2, 1, 8)
        x = self.deconv2(torch.cat((x, c1), dim = 1))  # (B, ch3, 1, 8) -> (B, 1, 128, 16)
        return x

class GAN(LightningModule):
    """LightningModule that trains ``Generator`` against ``Discriminator``.

    Each batch is a pair ``(x, px)``; the generator receives ``px`` and is
    trained so that its output both fools the discriminator and stays close
    to ``x``.  # NOTE(review): presumably px is the bar preceding x -- confirm
    against the dataset builder.

    Generator loss:     BCE(D(G(px)), 1)
                        + L1 * ||x - G(px)||_2
                        + L2 * ||D.conv1(x) - D.conv1(G(px))||_2
    Discriminator loss: BCE(D(x), 1 - smooth) + BCE(D(G(px)), 0)
    """

    def __init__(self, batch_size):
        """Build both networks.

        Args:
            batch_size: accepted for backward compatibility only. Target
                tensors are now created per step from the logits, so the
                module no longer depends on a fixed batch size.
        """
        super().__init__()
        self.G = Generator()
        self.D = Discriminator()
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, z):
        """Run the generator on the conditioning input ``z``."""
        return self.G(z)

    @staticmethod
    def _unpack_batch(batch):
        """Cast both halves of a batch to float and insert a channel axis at dim 1."""
        x, px = batch
        return x.float().unsqueeze(1), px.float().unsqueeze(1)

    def gStep(self, batch):
        """Return the generator loss for one batch."""
        x, px = self._unpack_batch(batch)
        Gz = self(px)
        DGz, GzC1 = self.D(Gz)
        # Real-sample features are a fixed target for feature matching; no
        # gradient with respect to the generator flows through them.
        with torch.no_grad():
            _, xC1 = self.D(x)
        adversarial = self.criterion(DGz, torch.ones_like(DGz))
        return adversarial + L1 * (x - Gz).norm(2) + L2 * (xC1 - GzC1).norm(2)

    def dStep(self, batch):
        """Return the discriminator loss for one batch and log its accuracy."""
        x, px = self._unpack_batch(batch)
        Dx, _ = self.D(x)
        # Detach so the discriminator loss does not backpropagate into G.
        Gz = self(px).detach()
        DGz, _ = self.D(Gz)
        # Logit > 0 means "real"; count correct calls on both halves.
        n_correct = (DGz < 0).sum() + (Dx > 0).sum()
        self.log('dAcc', n_correct.float() / (2 * x.shape[0]), on_epoch = True, prog_bar = True, logger = True)
        # Targets are built from the logits so they always match the current
        # batch size and device. Fakes are labelled 0 (the previous cached
        # "zeros" tensor was filled with ones, which trained D to call fakes
        # real).
        real_loss = self.criterion(Dx, torch.full_like(Dx, 1 - smooth))
        fake_loss = self.criterion(DGz, torch.zeros_like(DGz))
        return real_loss + fake_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Dispatch to the generator or discriminator step by optimizer slot.

        ``gStep`` in the comparison below is the module-level integer (number
        of generator slots), not the method of the same name.
        """
        if optimizer_idx < gStep:
            loss = self.gStep(batch)
            self.log('gLoss', loss, on_epoch = True, prog_bar = True, logger = True)
            return loss
        else:
            loss = self.dStep(batch)
            self.log('dLoss', loss, on_epoch = True, prog_bar = True, logger = True)
            return loss

    def configure_optimizers(self):
        """Return the optimizer list consumed by ``training_step``.

        The same optimizer object is repeated so the generator is stepped
        ``gStep`` times and the discriminator ``dStep`` times per batch.
        """
        gOpt = optim.AdamW(self.G.parameters(), lr = 0.0004)
        dOpt = optim.AdamW(self.D.parameters(), lr = 0.0001)
        return [gOpt] * gStep + [dOpt] * dStep, []

In [9]:
from torch.utils.data import Dataset

class ProcessedDataset(Dataset):
    """Dataset over two index-aligned sequences, yielding ``(G[i], I[i])`` pairs.

    NOTE(review): presumably this class must stay defined here so that a
    pickled instance can be restored by ``torch.load`` -- confirm.
    """

    def __init__(self, G, I):
        # Both sequences are stored as given; nothing is copied or validated.
        self.Gdata, self.Idata = G, I

    def __len__(self):
        # The length is taken from the first sequence only.
        n_items = len(self.Gdata)
        return n_items

    def __getitem__(self, idx):
        first = self.Gdata[idx]
        second = self.Idata[idx]
        return first, second

# Mini-batch size used by the loader below.
batch_size = 128
# NOTE(review): torch.load unpickles arbitrary objects -- only load files
# from a trusted source.
dataset = torch.load(f'./GandI_len128_10240.pt')
# Reshuffle the samples every epoch.
data_loader = DataLoader(
    dataset,
    shuffle = True,
    batch_size = batch_size,
)

In [10]:
# Instantiate the GAN with the same batch size the data loader uses.
model = GAN(batch_size = batch_size)

In [11]:
# Use one GPU when CUDA is available, otherwise run on CPU
# (int(True) == 1, int(False) == 0).
trainer = Trainer(
    gpus = int(torch.cuda.is_available()),
    max_epochs = 10000,
    log_every_n_steps = 10,
    track_grad_norm = 2,
)
trainer.fit(model, data_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | G         | Generator         | 230 K 
1 | D         | Discriminator     | 50.1 K
2 | criterion | BCEWithLogitsLoss | 0     
------------------------------------------------
280 K     Trainable params
0         Non-trainable params
280 K     Total params
1.122     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [12]:
# Generate sample MIDI files by running the generator autoregressively:
# each bar is produced from a blend of the most recent generated bars.
testMod = model.G.to(device)
testMod.eval()
# Number of MIDI files to write. The loop previously hard-coded 10 while
# nMusic (set to 20) was never read; 10 is kept so the output is unchanged.
nMusic = 10
nBar = 8
# Writing fails if the target directory is missing, so create it up front.
os.makedirs('./samples', exist_ok = True)
with torch.no_grad():  # inference only; no autograd graph is needed
    for i in range(nMusic):
        # Random seed bar: Gaussian noise with mean 0.5 and std 0.5.
        px = (torch.randn(1, 1, R, T).float() + 1) / 2
        bars = []
        for _ in range(nBar):
            bars.append(testMod(px.to(device)).cpu())
            # Next conditioning input: mostly the latest bar, with small
            # contributions from the two before it (falling back to the
            # latest bar while fewer than three exist).
            px = bars[-1] * 0.9 + (bars[-2] if len(bars) > 1 else bars[-1]) * 0.09 + (bars[-3] if len(bars) > 2 else bars[-1]) * 0.01
        # Concatenate bars along the last axis, drop the batch and channel
        # axes explicitly, and binarise at 0.5.
        roll = torch.cat(bars, dim = 3)[0, 0] > 0.5
        music = muspy.from_pianoroll_representation(
            roll.numpy().T,
            resolution = 4,
            encode_velocity = False
        )
        music.write(f'./samples/{i + 1}.mid')
# NOTE(review): this pickles the whole module object, so loading it later
# requires the class definitions from this notebook to be available.
torch.save(model, './model.pt')