Coggers is my favorite track from the *Interstellar* soundtrack (Hans Zimmer) — let's see what it looks like.

In [1]:
import scipy.io.wavfile as wav
import matplotlib.pyplot as plt
import numpy as np
import librosa as libr

from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# Load the track resampled to 8 kHz; mono=False keeps the channels separate
# (they are indexed individually further down).
coggers, sample_rate = libr.load(
    'data/hans_zimmer/coggers.wav',
    sr=8000,
    mono=False,
)

In [3]:
# Samples per channel, displayed next to the sample rate.
length = len(coggers[0])
(length, sample_rate)

(1921150, 8000)

In [4]:
# The track runs about 240 seconds, so samples / 240 should come out
# close to the 8 kHz sample rate requested at load time.
length / 240

8004.791666666667

In [5]:
# Pull the two stereo channels apart; each gets its own model below.
a, b = coggers[0], coggers[1]

In [6]:
# plt.plot(a, 'b')
# plt.plot(b, 'r')

In [7]:
# HYPERPARAMETERS

# Input window: 0.5 seconds of audio, expressed as a sample count.
SEQ_LEN = int(0.5 * sample_rate)
print(SEQ_LEN)

VAL_PCT = 0.2  # NOTE(review): not referenced by any visible cell

# Powers of two, kept for reference when tuning HIDDEN_DIM:
# 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152
HIDDEN_DIM = 2048
N_LAYERS = 2

OUT_DIM = 100  # samples predicted per step

EPOCHS = 4
LR = 0.001

4000


In [8]:
# Fit an independent [-1, 1] min-max scaler per channel. MinMaxScaler expects a
# 2-D (n_samples, n_features) array, hence the reshape to a single column.
scaler_a = MinMaxScaler(feature_range=(-1, 1))
a = scaler_a.fit_transform(a.reshape(-1, 1))

scaler_b = MinMaxScaler(feature_range=(-1, 1))
b = scaler_b.fit_transform(b.reshape(-1, 1))

In [9]:
# Copy both scaled channels onto the GPU as float32 tensors.
a, b = [torch.tensor(channel).cuda().float() for channel in (a, b)]

In [10]:
class LSTM(nn.Module):
    """Stacked LSTM followed by a linear projection to ``out_dim`` values.

    Args:
        in_dim: number of features per input row.
        hidden_dim: LSTM hidden size.
        n_layers: number of stacked LSTM layers.
        out_dim: width of the final linear output.
    """

    def __init__(self, in_dim, hidden_dim, n_layers, out_dim):
        super().__init__()
        # Kept for reference; forward() does not read them.
        self.hidden_dim, self.n_layers = hidden_dim, n_layers

        # Submodule names ("lstm", "fc1") define the state_dict keys of saved checkpoints.
        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layers, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Project every row of ``x`` to an ``out_dim`` prediction.

        Each row becomes its own length-1 sequence (batch_first layout), so the
        result has shape ``(len(x), 1, out_dim)``. No hidden state is passed in,
        so nothing carries over between calls.
        """
        as_sequences = x.view(len(x), 1, -1)
        lstm_out, _state = self.lstm(as_sequences)
        return self.fc1(lstm_out)

In [11]:
# losses_a = []
# losses_b = []
def train():
    """Train one LSTM per stereo channel and save both state dicts.

    Both channels are cut into the same windows. Step ``i`` feeds the
    ``SEQ_LEN`` samples starting at ``i * OUT_DIM`` and regresses (MSE) the
    ``OUT_DIM`` samples immediately after them. Windows advance by
    ``OUT_DIM``, so consecutive targets tile the track; this matches how the
    inference cell walks the data.

    Reads the module-level tensors ``a``/``b`` and the hyperparameter constants.
    Writes ``models/hans_zimmer/first_a.pt`` and ``first_b.pt``, creating the
    directory if needed.
    """
    import os

    model_a = LSTM(SEQ_LEN, HIDDEN_DIM, N_LAYERS, OUT_DIM).cuda()
    optimizer_a = optim.Adam(model_a.parameters(), lr=LR)

    model_b = LSTM(SEQ_LEN, HIDDEN_DIM, N_LAYERS, OUT_DIM).cuda()
    optimizer_b = optim.Adam(model_b.parameters(), lr=LR)

    loss_fn = nn.MSELoss()

    def _step(model, optimizer, track, i):
        # One optimisation step on window ``i`` of ``track``.
        start = i * OUT_DIM
        x = track[start : start + SEQ_LEN].view(-1, SEQ_LEN)
        y = track[start + SEQ_LEN : start + SEQ_LEN + OUT_DIM]
        out = model(x)  # (1, 1, OUT_DIM), see LSTM.forward
        # Give the target the prediction's shape. Otherwise MSELoss broadcasts
        # (1, 1, OUT_DIM) against (OUT_DIM, 1) into an OUT_DIM x OUT_DIM grid.
        loss = loss_fn(out, y.reshape_as(out))
        model.zero_grad()
        loss.backward()
        optimizer.step()

    # Use the shorter channel defensively so neither track is read past its end.
    n_windows = (min(len(a), len(b)) - SEQ_LEN) // OUT_DIM
    for _ in range(EPOCHS):
        for i in tqdm(range(n_windows)):
            _step(model_a, optimizer_a, a, i)
            _step(model_b, optimizer_b, b, i)

    os.makedirs('models/hans_zimmer', exist_ok=True)
    torch.save(model_a.state_dict(), 'models/hans_zimmer/first_a.pt')
    torch.save(model_b.state_dict(), 'models/hans_zimmer/first_b.pt')

In [12]:
# Rebuild both networks with the training hyperparameters, then restore the
# weights written by train(). The last expression shows the key-match report.
model_a, model_b = (LSTM(SEQ_LEN, HIDDEN_DIM, N_LAYERS, OUT_DIM).cuda() for _ in range(2))

model_a.load_state_dict(torch.load('models/hans_zimmer/first_a.pt'))
model_b.load_state_dict(torch.load('models/hans_zimmer/first_b.pt'))

<All keys matched successfully>

In [13]:
# Teacher-forced reconstruction: every window of real audio predicts the next
# OUT_DIM samples, and the predictions are concatenated per channel.
# no_grad() stops autograd from building (and keeping) a graph for every step.
model_a.eval()
model_b.eval()

n_windows = (min(len(a), len(b)) - SEQ_LEN) // OUT_DIM
chunks_a, chunks_b = [], []
with torch.no_grad():
    for i in tqdm(range(n_windows)):
        start = i * OUT_DIM
        Xa = a[start : start + SEQ_LEN].view(-1, SEQ_LEN)
        Xb = b[start : start + SEQ_LEN].view(-1, SEQ_LEN)
        # view(-1) instead of squeeze() keeps each chunk 1-D even when OUT_DIM == 1.
        chunks_a.append(model_a(Xa).view(-1))
        chunks_b.append(model_b(Xb).view(-1))

# One concatenation and one device transfer, instead of re-growing a tensor every step.
pred_a = torch.cat(chunks_a).cpu() if chunks_a else torch.tensor([])
pred_b = torch.cat(chunks_b).cpu() if chunks_b else torch.tensor([])

100%|██████████| 19171/19171 [01:42<00:00, 187.93it/s]


In [14]:
# Hand the predictions over to NumPy for scipy's WAV writer.
pred_a, pred_b = pred_a.detach().numpy(), pred_b.detach().numpy()

In [16]:
# The models predict in the per-channel MinMax-scaled domain. Undo that scaling
# so the output matches the source amplitude and offset, clip to the [-1, 1]
# range expected for float WAV data, and stay in float32 (inverse_transform may
# upcast). Result shape: (2, n_samples).
song = np.stack((
    scaler_a.inverse_transform(pred_a.reshape(-1, 1)).ravel(),
    scaler_b.inverse_transform(pred_b.reshape(-1, 1)).ravel(),
))
song = np.clip(song, -1.0, 1.0).astype(np.float32)

In [17]:
# scipy wants (n_samples, n_channels), so transpose the (2, n_samples) stack.
wav.write('data/hans_zimmer/out.wav', rate=sample_rate, data=song.T)