In [2]:
import numpy as np
import pandas as pd
from pathlib import Path
from enum import Enum
from collections import OrderedDict

from PIL import Image

import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

%matplotlib inline

In [3]:
device = torch.device('mps') if torch.backends.mps.is_available() else 'cpu'

print(device)

mps


In [4]:
csv = pd.read_csv('../data/gztan/features_30_sec.csv')
display(csv.head())

Unnamed: 0,filename,length,chroma_stft_mean,chroma_stft_var,rms_mean,rms_var,spectral_centroid_mean,spectral_centroid_var,spectral_bandwidth_mean,spectral_bandwidth_var,...,mfcc16_var,mfcc17_mean,mfcc17_var,mfcc18_mean,mfcc18_var,mfcc19_mean,mfcc19_var,mfcc20_mean,mfcc20_var,label
0,blues.00000.wav,661794,0.350088,0.088757,0.130228,0.002827,1784.16585,129774.064525,2002.44906,85882.761315,...,52.42091,-1.690215,36.524071,-0.408979,41.597103,-2.303523,55.062923,1.221291,46.936035,blues
1,blues.00001.wav,661794,0.340914,0.09498,0.095948,0.002373,1530.176679,375850.073649,2039.036516,213843.755497,...,55.356403,-0.731125,60.314529,0.295073,48.120598,-0.283518,51.10619,0.531217,45.786282,blues
2,blues.00002.wav,661794,0.363637,0.085275,0.17557,0.002746,1552.811865,156467.643368,1747.702312,76254.192257,...,40.598766,-7.729093,47.639427,-1.816407,52.382141,-3.43972,46.63966,-2.231258,30.573025,blues
3,blues.00003.wav,661794,0.404785,0.093999,0.141093,0.006346,1070.106615,184355.942417,1596.412872,166441.494769,...,44.427753,-3.319597,50.206673,0.636965,37.31913,-0.619121,37.259739,-3.407448,31.949339,blues
4,blues.00004.wav,661794,0.308526,0.087841,0.091529,0.002303,1835.004266,343399.939274,1748.172116,88445.209036,...,86.099236,-5.454034,75.269707,-0.916874,53.613918,-4.404827,62.910812,-11.703234,55.19516,blues


In [5]:
class DatasetType(Enum):
    TRAIN = 0
    TEST = 1

In [6]:
class Preprocessor:
    def __init__(self, csv_path: str, save_crops: bool=False):
        self.csv_path = Path(csv_path)
        
        self.cropped_dir = self.csv_path.parent / 'cropped'
        if not self.cropped_dir.exists():
            self.cropped_dir.mkdir(exist_ok=False)
            print("Created cropped dir.")
        
        self.csv = pd.read_csv(csv_path)

        # Add a column for the filepath.
        self.csv['im_path'] = self.csv.apply(lambda x: x.filename.strip('wav').replace('.', '') + '.png', axis=1)
        
        self.genre_ix = self._make_ix_splits()
        self.genre_statistics = self.preprocess_image_and_compute_stats(save_crops)

    def _make_ix_splits(self):
        genre_dict = dict()
        for genre in self.csv.label.unique():
            index = self.csv.loc[self.csv.label == genre, :].index
            # shuffle the indices.
            shuffled = np.random.permutation(index)
            split = int(len(shuffled) * 0.8)
            genre_dict[genre] = {DatasetType.TRAIN: shuffled[:split], DatasetType.TEST: shuffled[split:]}
        return genre_dict

    def preprocess_image_and_compute_stats(self, save_crops=False):

        def _save_images(genre, im_paths, path_to_save):
            for im_path in im_paths:
                joined_path = self.csv_path.parent / 'images_original' / genre / im_path
                if not joined_path.exists():
                    continue
                
                with Image.open(joined_path, 'r') as image:
                    # Convert from CMYK to RGB.
                    image = image.convert('RGB')
                    # Crop the image.
                    # (224, 352)
                    im = image.crop(box=(44, 29, 396, 253))
                    im_name = 'cropped_' + im_path
                    im.save(path_to_save / im_name)

        def _compute_stats(im_path):
            images = list(im_path.glob('*.png'))
            means = []
            stds = []
            for image in images:
                with Image.open(image, 'r') as im:
                    # Convert the cropped images to ndarrays
                    arr = np.asarray(im)
                    # Compute channel-wise means
                    means.append(np.mean(arr, axis=(0, 1), keepdims=True))

                    # Compute the channel-wise
                    stds.append(np.std(arr, axis=(0, 1), keepdims=True))

            return (
                np.mean(np.concatenate(means, axis=0), axis=0).flatten(),
                np.mean(np.concatenate(stds, axis=0), axis=0).flatten()
            )

        genre_stats = dict()
        
        for genre, genre_dict in self.genre_ix.items():
            train_df = self.csv.loc[genre_dict[DatasetType.TRAIN]]
            test_df = self.csv.loc[genre_dict[DatasetType.TEST]]

            print(f"genre: {genre}, train: {train_df.shape}, test: {test_df.shape}")

            genre_dir = self.cropped_dir / genre
            if not genre_dir.exists():
                genre_dir.mkdir(exist_ok=False)
            
            train_dir = genre_dir / 'train'
            if not train_dir.exists():
                train_dir.mkdir(exist_ok=False)
                
            test_dir = genre_dir / 'test'
            if not test_dir.exists():
                test_dir.mkdir(exist_ok=False)
                
            if save_crops:
                _save_images(genre, train_df.im_path, train_dir)
                _save_images(genre, test_df.im_path, test_dir)

            genre_mean, genre_std = _compute_stats(train_dir)
            genre_stats[genre] = {'mean': genre_mean, 'std': genre_std}

        return genre_stats

In [7]:
preprocessor = Preprocessor('../data/gztan/features_30_sec.csv', save_crops=True)

genre: blues, train: (80, 61), test: (20, 61)
genre: classical, train: (80, 61), test: (20, 61)
genre: country, train: (80, 61), test: (20, 61)
genre: disco, train: (80, 61), test: (20, 61)
genre: hiphop, train: (80, 61), test: (20, 61)
genre: jazz, train: (80, 61), test: (20, 61)
genre: metal, train: (80, 61), test: (20, 61)
genre: pop, train: (80, 61), test: (20, 61)
genre: reggae, train: (80, 61), test: (20, 61)
genre: rock, train: (80, 61), test: (20, 61)


In [8]:
class AutoencoderDataset(Dataset):
    def __init__(self, dset: DatasetType, dir: Path, preprocessor: Preprocessor):
        self.dset = dset
        self.dir = dir
        self.preprocessor = preprocessor
        self.transforms = self._setup_transforms()
        self.data = self._load()

    def _load(self):
        rows = []
        for genre in self.preprocessor.genre_statistics.keys():
            subdir = self.dir / genre / 'train' if self.dset == DatasetType.TRAIN else self.dir / genre / 'test'
            pngs = list(subdir.glob('*.png'))
            for png_file in pngs:
                rows.append((genre, png_file))
        
        return pd.DataFrame(rows, columns=['genre', 'filename'])

    def _setup_transforms(self):
        transforms_dict = dict()
        for genre, tsfm in self.preprocessor.genre_statistics.items():
            transforms_dict[genre] = transforms.Compose([
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(torch.float),
                transforms.Normalize(mean=tsfm['mean'], std=tsfm['std'])
            ])
        return transforms_dict

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        im_name = self.data.loc[i, 'filename']
        genre = self.data.loc[i, 'genre']
        
        with Image.open(im_name, 'r') as im:
            im = self.transforms[genre](im)

        # X and Y are the same.
        return im, im

In [9]:
train_dset = AutoencoderDataset(DatasetType.TRAIN, preprocessor.cropped_dir, preprocessor)
train_loader = DataLoader(train_dset, batch_size=8, shuffle=True)

In [10]:
X, Y = next(iter(train_loader))

In [11]:
X.size(), Y.size()

(torch.Size([8, 3, 224, 352]), torch.Size([8, 3, 224, 352]))

In [12]:
class MultiModalEncoder(nn.Module):
    def __init__(self, im_size: int, config: dict):
        super(MultiModalEncoder, self).__init__()
        self.im_size = im_size
        self.config = config
        self.encoder = self._make_encoder()
        self.decoder = self._make_decoder()

    def _conv_block(self, config: dict):
        layers = list()
        layers.append(
            nn.Conv2d(
                in_channels=config['in'], 
                out_channels=config['out'], 
                padding=config['padding'], 
                kernel_size=config['ksize'],
                stride=config['stride'], 
                bias=False
            )
        )
        if config['act'] == 'relu':
            layers.append(nn.ReLU(inplace=False))
        elif config['act'] == 'lrelu':
            layers.append(nn.LeakyReLU(negative_slope=0.01, inplace=False))
        return nn.Sequential(*layers)

    def _deconv_block(self, config: dict):
        layers = list()
        layers.append(
            nn.ConvTranspose2d(
                in_channels=config['in'], 
                out_channels=config['out'], 
                padding=config['padding'], 
                kernel_size=config['ksize'],
                output_padding=config['output_padding'] if 'output_padding' in config else 0,
                stride=config['stride'], 
                bias=False
            )
        )
        if config['act'] == 'relu':
            layers.append(nn.ReLU(inplace=False))
        elif config['act'] == 'lrelu':
            layers.append(nn.LeakyReLU(negative_slope=0.01, inplace=False))
        return nn.Sequential(*layers)

    def _linear(self, config: dict):
        layers = list()
        layers.append(
            nn.Linear(
                in_features=config['in'], 
                out_features=config['out'],
                bias=config['bias']
            )
        )
        if 'act' in config:
            if config['act'] == 'relu':
                layers.append(nn.ReLU(inplace=False))
            elif config['act'] == 'lrelu':
                layers.append(nn.LeakyReLU(negative_slope=0.01, inplace=False))

        return nn.Sequential(*layers)

    def _flatten(self):
        return nn.Sequential(nn.Flatten())

    def _reshape(self, config: dict):
        return nn.Sequential(nn.Unflatten(dim=1, unflattened_size=config['out']))

    def _make_encoder(self):
        encoder_config = self.config['encoder']
        encoder = nn.ModuleList()
        
        for layer_type, config in encoder_config.items():
            if 'conv' in layer_type:
                encoder.append(self._conv_block(config))
            elif 'linear' in layer_type:
                encoder.append(self._linear(config))
            elif 'flatten' in layer_type:
                encoder.append(self._flatten())

        return encoder

    def _make_decoder(self):
        decoder_config = self.config['decoder']
        decoder = nn.ModuleList()

        for layer_type, config in decoder_config.items():
            if 'linear' in layer_type:
                decoder.append(self._linear(config))
            elif 'reshape' in layer_type:
                decoder.append(self._reshape(config))
            elif 'deconv' in layer_type:
                decoder.append(self._deconv_block(config))

        return decoder
    
    def forward(self, im):
        
        # Pass through the encoder.
        for module in self.encoder:
            im = module(im)

        # Obtain a reference to the bottleneck.
        x = im

        # Pass through the decoder.
        for module in self.decoder:
            im = module(im)
        
        return x, im

In [13]:
encoder_config = OrderedDict()

# (336, 210, 3) -> (168, 105, 8)
encoder_config['conv1'] = {'in': 3, 'out': 8, 'ksize': 5, 'padding': 1, 'stride': 2, 'act': 'relu'}
# (168, 105, 8) -> (168, 105, 8)
encoder_config['conv2'] = {'in': 8, 'out': 8, 'ksize': 3, 'padding': 1, 'stride': 1, 'act': 'relu'}
# (168, 105, 8) -> (84, 52, 16)
encoder_config['conv3'] = {'in': 8, 'out': 16, 'ksize': 3, 'padding': 1, 'stride': 2, 'act': 'relu'}
# (84, 52, 16) -> (84, 52, 16)
encoder_config['conv4'] = {'in': 16, 'out': 16, 'ksize': 3, 'padding': 1, 'stride': 1, 'act': 'relu'}
# (84, 52, 32) -> (42, 26, 32)
encoder_config['conv5'] = {'in': 16, 'out': 32, 'ksize': 3, 'padding': 1, 'stride': 2, 'act': 'relu'}
# (42, 26, 32) -> (42, 26, 32)
encoder_config['conv6'] = {'in': 32, 'out': 32, 'ksize': 3, 'padding': 1, 'stride': 1, 'act': 'relu'}
# (42, 26, 32) -> (21, 13, 64)
encoder_config['conv7'] = {'in': 32, 'out': 64, 'ksize': 3, 'padding': 1, 'stride': 2, 'act': 'relu'}
# (21, 13, 64) -> (21, 13, 64)
encoder_config['conv8'] = {'in': 64, 'out': 64, 'ksize': 3, 'padding': 1, 'stride': 1, 'act': 'relu'}
# (21, 13, 64) -> (10, 6, 128)
encoder_config['conv9'] = {'in': 64, 'out': 128, 'ksize': 3, 'padding': 1, 'stride': 2, 'act': 'relu'}
# (10, 6, 128) -> (10, 6, 128)
encoder_config['conv10'] = {'in': 128, 'out': 128, 'ksize': 3, 'padding': 1, 'stride': 1, 'act': 'relu'}
# (10, 6, 128) -> (7680,)
encoder_config['flatten'] = {}
# (7680,) -> (512,)
encoder_config['linear1'] = {'in': 9856, 'out': 512, 'bias': True, 'act': 'relu'}


## Decoder config
decoder_config = OrderedDict()

# torch.Size([8, 3, 220, 336])
# torch.Size([8, 8, 109, 167])
# torch.Size([8, 8, 109, 167])
# torch.Size([8, 16, 55, 84])
# torch.Size([8, 16, 55, 84])
# torch.Size([8, 32, 28, 42])
# torch.Size([8, 32, 28, 42])
# torch.Size([8, 64, 14, 21])
# torch.Size([8, 64, 14, 21])
# torch.Size([8, 128, 7, 11])
# torch.Size([8, 128, 7, 11])
# torch.Size([8, 9856])

decoder_config['linear1'] = {'in': 512, 'out': 9856, 'bias': True, 'acr': 'relu'}
decoder_config['reshape'] = {'in': 9856, 'out': (128, 7, 11)}
decoder_config['deconv1'] = {'in': 128, 'out': 128, 'ksize': 3, 'stride': 1, 'padding': 1, 'act': 'relu'}
decoder_config['deconv2'] = {'in': 128, 'out': 64, 'ksize': 3, 'stride': 2, 'padding': 1, 'output_padding': 1, 'act': 'relu'}
decoder_config['deconv3'] = {'in': 64, 'out': 64, 'ksize': 3, 'stride': 1, 'padding': 1, 'act': 'relu'}
decoder_config['deconv4'] = {'in': 64, 'out': 32, 'ksize': 3, 'stride': 2, 'padding': 1, 'output_padding': 1, 'act': 'relu'}
decoder_config['deconv5'] = {'in': 32, 'out': 32, 'ksize': 3, 'stride': 1, 'padding': 1, 'act': 'relu'}
decoder_config['deconv6'] = {'in': 32, 'out': 16, 'ksize': 3, 'stride': 2, 'padding': 1, 'output_padding': 1, 'act': 'relu'}
decoder_config['deconv7'] = {'in': 16, 'out': 16, 'ksize': 3, 'stride': 1, 'padding': 1, 'act': 'relu'}
decoder_config['deconv8'] = {'in': 16, 'out': 8, 'ksize': 3, 'stride': 2, 'padding': 1, 'output_padding': 1, 'act': 'relu'}
decoder_config['deconv9'] = {'in': 8, 'out': 8, 'ksize': 3, 'stride': 1, 'padding': 1, 'act': 'relu'}
decoder_config['deconv10'] = {'in': 8, 'out': 3, 'ksize': 3, 'stride': 2, 'padding': 1, 'output_padding': 1, 'act': 'relu'}

sample_config = {
    'encoder': encoder_config,
    'decoder': decoder_config
}

In [21]:
autoenc = MultiModalEncoder(im_size=1, config=sample_config)

In [22]:
autoenc = autoenc.to(device)

In [23]:
print(sum(p.numel() for p in autoenc.parameters()))

10692400


In [25]:
optimizer = optim.Adam(autoenc.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

In [26]:
def train(model, optim, dataloader):
    model.train()
    optim.zero_grad()

    for ix, (X, Y) in enumerate(dataloader):
        X = X.to(device)
        Y = Y.to(device)

        z, y_pred = model(X)
        loss = loss_fn(y_pred, Y)

        print(f"Loss: {loss.item()}")

        loss.backward()
        optim.step()

In [27]:
train(autoenc, optimizer, train_loader)

Loss: 0.40295109152793884
Loss: 0.5041136145591736
Loss: 0.3473391830921173
Loss: 0.406434029340744
Loss: 0.4231855273246765
Loss: 0.4638163447380066
Loss: 0.3942844867706299
Loss: 0.33832645416259766
Loss: 0.473519504070282
Loss: 0.5593042969703674
Loss: 0.45632055401802063
Loss: 0.4605395197868347
Loss: 0.5319695472717285
Loss: 0.37657880783081055
Loss: 0.5240647196769714
Loss: 0.5458301901817322
Loss: 0.49803802371025085
Loss: 0.4092603921890259
Loss: 0.4578917622566223
Loss: 0.38769233226776123
Loss: 0.35605019330978394
Loss: 0.4764402508735657
Loss: 0.5691649317741394
Loss: 0.4800168573856354
Loss: 0.5159707069396973
Loss: 0.40255218744277954
Loss: 0.41760018467903137
Loss: 0.3906107246875763
Loss: 0.532140851020813
Loss: 0.4584776759147644
Loss: 0.4789537787437439
Loss: 0.5543557405471802
Loss: 0.45214441418647766
Loss: 0.42188596725463867
Loss: 0.5485376715660095
Loss: 0.4799860417842865
Loss: 0.4641628861427307
Loss: 0.40561196208000183
Loss: 0.587681770324707
Loss: 0.448297142