In [370]:
# from datasets import load_dataset

# ds = load_dataset("luli0034/music-tags-to-spectrogram", split='train')

In [49]:
import numpy as np
import torch
from torch import nn
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import classification_report, accuracy_score
from datasets import load_dataset
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.utils.data.dataloader import default_collate
import pandas as pd
from tqdm import tqdm

In [10]:
BATCH_SIZE = 64
IMAGE_SIZE = 512
MY_PWD = "/Users/akovel/Documents/HSE/Music-Predictor/"
DS_PATH = f"{MY_PWD}data/spectograms"
MIN_NUM_GENRES = 40


In [373]:
# subset_size = int(len(ds))
# subset = ds.select(range(subset_size))
# print(subset)
# ds = subset

Dataset({
    features: ['image', 'text'],
    num_rows: 1543
})


In [7]:
def read_dataset_from_json(path):
    path += "/spectogramsgenres.json"
    df = pd.read_json(path)

    return df.T

df = read_dataset_from_json(DS_PATH)

In [11]:
def make_simple_genres(df):
    genre_counts = df['genres'].str.split(expand=True).stack().value_counts()
    rare_genres = genre_counts[genre_counts < MIN_NUM_GENRES].index

    def transform_genres_to_simple(genres_text):
        genres = genres_text.split()
        return ' '.join([genre for genre in genres if genre not in rare_genres])
    df['simple_genre'] = df['genres'].apply(transform_genres_to_simple)
    df = df[df['simple_genre'] != ""]
    return df

In [12]:
df = make_simple_genres(df)

In [13]:
df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 1001 entries, 0 to 1538
Data columns (total 3 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   genres        1001 non-null   object
 1   image_path    1001 non-null   object
 2   simple_genre  1001 non-null   object
dtypes: object(3)
memory usage: 31.3+ KB


In [14]:
df.head()

Unnamed: 0,genres,image_path,simple_genre
0,soundtrack classical,/Users/akovel/Documents/HSE/Music-Predictor/da...,soundtrack classical
1,hiphop electronic latin,/Users/akovel/Documents/HSE/Music-Predictor/da...,electronic
3,soundtrack ambient classical,/Users/akovel/Documents/HSE/Music-Predictor/da...,soundtrack ambient classical
4,soundtrack ambient classical,/Users/akovel/Documents/HSE/Music-Predictor/da...,soundtrack ambient classical
5,soundtrack ambient classical,/Users/akovel/Documents/HSE/Music-Predictor/da...,soundtrack ambient classical


In [15]:
ds_train, ds_test = train_test_split(df, test_size=0.2)

In [45]:
class MusicDataset(Dataset):
    def __init__(self, ds, transform=None):
        self.transform = transform
        self.data_frame = ds
    
    def __len__(self):
        return len(self.data_frame)
    
    def __getitem__(self, index):
        genres = self.data_frame["simple_genre"].iloc[index]
        if self.transform:
            image = Image.open(self.data_frame["image_path"].iloc[index])
            image = self.transform(image)
        return image, genres


In [17]:
# def my_collate(batch):
#     batch = list(filter(lambda x: x is not None, batch))
#     return default_collate(batch)

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

In [52]:
def extract_image_features(dataloader, model):
    model.eval()
    features = []
    with torch.no_grad():
        for inputs, _ in tqdm(dataloader):
            inputs = inputs.to(device)
            output = model(inputs)
            features.append(output.cpu().numpy())
    return np.vstack(features)


In [20]:
class LogisticRegressionModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegressionModel, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.linear(x)

### Поэтому я превращу их в квадрат

In [21]:
image_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [46]:
train_dataset = MusicDataset(ds_train, transform=image_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [47]:
test_dataset = MusicDataset(ds_test, transform=image_transforms)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [24]:
model = models.resnet50(pretrained=True)
model.fc = nn.Identity()
model.to(device) 
resnet = model



In [32]:
ds_train["simple_genre"][0].split()

['soundtrack', 'classical']

In [34]:
all_genres = [genre.split() for genre in ds_train["simple_genre"]]
all_genres_test = [genre.split()  for genre in ds_test["simple_genre"]]

In [35]:
all_genres[0]

['piano']

In [36]:
all_genres_test[0]

['pop']

In [37]:
mlb = MultiLabelBinarizer()
y_train = mlb.fit_transform(all_genres)
y_test_encoder = mlb.transform(all_genres_test)

In [38]:
mlb.classes_

array(['ambient', 'bass', 'chillout', 'classical', 'dance', 'drums',
       'easylistening', 'electricguitar', 'electronic', 'emotional',
       'film', 'guitar', 'happy', 'newage', 'orchestral', 'piano', 'pop',
       'relaxing', 'rock', 'soundtrack', 'synthesizer'], dtype=object)

In [39]:
len(mlb.classes_)

21

In [53]:
train_features = extract_image_features(train_loader, model)

100%|██████████| 13/13 [09:30<00:00, 43.89s/it]


In [54]:
feature_tensor = torch.tensor(train_features, dtype=torch.float32).to(device)
labels_tensor = torch.tensor(y_train, dtype=torch.float32).to(device)

In [55]:
model = LogisticRegressionModel(input_size=feature_tensor.shape[1], num_classes=labels_tensor.shape[1]).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

### Baseline CV Train

In [56]:
model.train()
for epoch in range(100):
    optimizer.zero_grad()
    outputs = model(feature_tensor)
    loss = criterion(outputs, labels_tensor)
    loss.backward()
    optimizer.step()
    print(f'Epoch [{epoch + 1} /100], Loss: {loss.item():.4f}')

Epoch [1 /100], Loss: 0.7006
Epoch [2 /100], Loss: 0.6761
Epoch [3 /100], Loss: 0.6533
Epoch [4 /100], Loss: 0.6320
Epoch [5 /100], Loss: 0.6120
Epoch [6 /100], Loss: 0.5934
Epoch [7 /100], Loss: 0.5760
Epoch [8 /100], Loss: 0.5598
Epoch [9 /100], Loss: 0.5445
Epoch [10 /100], Loss: 0.5303
Epoch [11 /100], Loss: 0.5170
Epoch [12 /100], Loss: 0.5045
Epoch [13 /100], Loss: 0.4927
Epoch [14 /100], Loss: 0.4817
Epoch [15 /100], Loss: 0.4714
Epoch [16 /100], Loss: 0.4617
Epoch [17 /100], Loss: 0.4526
Epoch [18 /100], Loss: 0.4440
Epoch [19 /100], Loss: 0.4360
Epoch [20 /100], Loss: 0.4283
Epoch [21 /100], Loss: 0.4212
Epoch [22 /100], Loss: 0.4144
Epoch [23 /100], Loss: 0.4080
Epoch [24 /100], Loss: 0.4019
Epoch [25 /100], Loss: 0.3962
Epoch [26 /100], Loss: 0.3908
Epoch [27 /100], Loss: 0.3857
Epoch [28 /100], Loss: 0.3808
Epoch [29 /100], Loss: 0.3762
Epoch [30 /100], Loss: 0.3718
Epoch [31 /100], Loss: 0.3677
Epoch [32 /100], Loss: 0.3637
Epoch [33 /100], Loss: 0.3599
Epoch [34 /100], Lo

In [57]:
test_features = extract_image_features(test_loader, resnet)

100%|██████████| 4/4 [02:11<00:00, 32.98s/it]


In [58]:
feature_tensor_test = torch.tensor(test_features, dtype=torch.float32).to(device)
labels_tensor_test = torch.tensor(y_test_encoder, dtype=torch.float32).to(device)

In [71]:
model.eval()
with torch.no_grad():
    test_ouptputs = model(feature_tensor_test)
    test_predictions = torch.sigmoid(test_ouptputs).cpu().numpy()
    test_predictions = (test_predictions > 0.1).astype(int)

In [72]:
print(classification_report(y_test_encoder, test_predictions, target_names=mlb.classes_))

                precision    recall  f1-score   support

       ambient       0.10      0.95      0.17        20
          bass       0.04      0.64      0.08        11
      chillout       0.04      0.88      0.08         8
     classical       0.09      1.00      0.16        17
         dance       0.05      1.00      0.09         6
         drums       0.07      1.00      0.13        13
 easylistening       0.07      0.93      0.13        14
electricguitar       0.03      0.67      0.07         6
    electronic       0.12      1.00      0.21        24
     emotional       0.06      0.75      0.11         8
          film       0.06      0.67      0.11        12
        guitar       0.06      1.00      0.11         9
         happy       0.06      0.82      0.12        11
        newage       0.04      0.86      0.09         7
    orchestral       0.03      0.71      0.06         7
         piano       0.15      1.00      0.26        30
           pop       0.06      0.92      0.12  

In [63]:
torch.save(model.state_dict(), "multiclass_model_simple.pth")