In [370]:
from datasets import load_dataset

ds = load_dataset("luli0034/music-tags-to-spectrogram", split='train')

In [371]:
import numpy as np
import torch
from torch import nn
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import classification_report, accuracy_score
from datasets import load_dataset
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.utils.data.dataloader import default_collate

In [None]:
BATCH_SIZE = 64
IMAGE_SIZE = 512

In [373]:
subset_size = int(len(ds))
subset = ds.select(range(subset_size))
print(subset)
ds = subset

Dataset({
    features: ['image', 'text'],
    num_rows: 1543
})


In [374]:
ds = ds.train_test_split(test_size=0.2)

In [375]:
ds_train, ds_test = ds["train"], ds["test"]

In [None]:
class MusicDataset(Dataset):
    def __init__(self, ds, transform=None):
        self.transform = transform
        self.data_frame = ds
    
    def __len__(self):
        return len(self.data_frame)
    
    def __getitem__(self, index):
        try:
            genres = self.data_frame[index]["text"]
            if self.transform:
                image = self.transform(self.data_frame[index]["image"])
            return image, genres
        except Exception as e:
            print(e)
            return self.transform(np.ones((IMAGE_SIZE, IMAGE_SIZE, 3)), "rock"


In [377]:
# def my_collate(batch):
#     batch = list(filter(lambda x: x is not None, batch))
#     return default_collate(batch)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

In [379]:
def extract_image_features(dataloader, model):
    model.eval()
    features = []
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device)
            output = model(inputs)
            features.append(output.cpu().numpy())
    return np.vstack(features)


In [380]:
class LogisticRegressionModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegressionModel, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.linear(x)

### Поэтому я превращу их в квадрат

In [381]:
image_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [382]:
train_dataset = MusicDataset(ds_train, transform=image_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [383]:
test_dataset = MusicDataset(ds_test, transform=image_transforms)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [384]:
model = models.resnet50(pretrained=True)
model.fc = nn.Identity()
model.to(device) 
resnet = model



In [385]:
all_genres = ds_train.remove_columns('image')
all_genres_test = ds_test.remove_columns('image')

In [386]:
all_genres = [genre["text"].split(" ") for genre in all_genres]
all_genres_test = [genre["text"].split(" ")  for genre in all_genres_test]

In [387]:
all_genres[0]

['soundtrack', 'electronic', 'experimental']

In [388]:
all_genres_test[0]

['orchestral', 'classical', 'soundtrack']

In [389]:
mlb = MultiLabelBinarizer()
y_train = mlb.fit_transform(all_genres)
y_test_encoder = mlb.transform(all_genres_test)



In [390]:
mlb.classes_

array(['60s', '70s', '80s', '90s', 'accordion', 'acidjazz',
       'acousticbassguitar', 'acousticguitar', 'action', 'adventure',
       'advertising', 'african', 'alternative', 'alternativerock',
       'ambient', 'ambiental', 'atmospheric', 'background', 'ballad',
       'bass', 'beat', 'bell', 'blues', 'bongo', 'bossanova', 'brass',
       'breakbeat', 'calm', 'cello', 'celtic', 'chanson', 'children',
       'chillout', 'choir', 'christmas', 'clarinet', 'classical',
       'classicalguitar', 'club', 'commercial', 'computer',
       'contemporary', 'cool', 'corporate', 'country', 'dance', 'dark',
       'darkambient', 'darkwave', 'deep', 'deephouse', 'documentary',
       'doublebass', 'downtempo', 'drama', 'dramatic', 'dream',
       'drummachine', 'drumnbass', 'drums', 'dubstep', 'easylistening',
       'edm', 'electricguitar', 'electricpiano', 'electronic',
       'electronica', 'electropop', 'emotional', 'energetic', 'epic',
       'ethno', 'eurodance', 'experimental', 'fast', 'f

In [391]:
len(mlb.classes_)

182

In [392]:
train_features = extract_image_features(train_loader, model)

unrecognized data stream contents when reading image file


TypeError: expected Tensor as element 197 in argument 0, but got numpy.ndarray

In [None]:
self.data_frame[index]["image"])

616

In [None]:
feature_tensor = torch.tensor(train_features, dtype=torch.float32).to(device)
labels_tensor = torch.tensor(y_train, dtype=torch.float32).to(device)

In [None]:
model = LogisticRegressionModel(input_size=feature_tensor.shape[1], num_classes=labels_tensor.shape[1]).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

### Baseline CV Train

In [None]:
model.train()
for epoch in range(100):
    optimizer.zero_grad()
    outputs = model(feature_tensor)
    loss = criterion(outputs, labels_tensor)
    loss.backward()
    optimizer.step()
    print(f'Epoch [{epoch + 1} /100], Loss: {loss.item():.4f}')

Epoch [1 /100], Loss: 0.6898
Epoch [2 /100], Loss: 0.6828
Epoch [3 /100], Loss: 0.6759
Epoch [4 /100], Loss: 0.6692
Epoch [5 /100], Loss: 0.6626
Epoch [6 /100], Loss: 0.6560
Epoch [7 /100], Loss: 0.6496
Epoch [8 /100], Loss: 0.6432
Epoch [9 /100], Loss: 0.6370
Epoch [10 /100], Loss: 0.6308
Epoch [11 /100], Loss: 0.6247
Epoch [12 /100], Loss: 0.6188
Epoch [13 /100], Loss: 0.6129
Epoch [14 /100], Loss: 0.6071
Epoch [15 /100], Loss: 0.6014
Epoch [16 /100], Loss: 0.5958
Epoch [17 /100], Loss: 0.5902
Epoch [18 /100], Loss: 0.5848
Epoch [19 /100], Loss: 0.5794
Epoch [20 /100], Loss: 0.5741
Epoch [21 /100], Loss: 0.5689
Epoch [22 /100], Loss: 0.5638
Epoch [23 /100], Loss: 0.5587
Epoch [24 /100], Loss: 0.5537
Epoch [25 /100], Loss: 0.5488
Epoch [26 /100], Loss: 0.5440
Epoch [27 /100], Loss: 0.5392
Epoch [28 /100], Loss: 0.5346
Epoch [29 /100], Loss: 0.5299
Epoch [30 /100], Loss: 0.5254
Epoch [31 /100], Loss: 0.5209
Epoch [32 /100], Loss: 0.5165
Epoch [33 /100], Loss: 0.5121
Epoch [34 /100], Lo

In [None]:
test_features = extract_image_features(test_loader, resnet)

In [None]:
feature_tensor_test = torch.tensor(test_features, dtype=torch.float32).to(device)
labels_tensor_test = torch.tensor(y_test_encoder, dtype=torch.float32).to(device)

In [None]:
model.eval()
with torch.no_grad():
    test_ouptputs = model(feature_tensor_test)
    test_predictions = torch.sigmoid(test_ouptputs).cpu().numpy()
    test_predictions = (test_predictions > 0.5).astype(int)

In [None]:
print(classification_report(y_test_encoder, test_predictions, target_names=mlb.classes_))

                    precision    recall  f1-score   support

               60s       0.00      0.00      0.00         0
               70s       0.00      0.00      0.00         0
               80s       0.00      0.00      0.00         1
               90s       0.00      0.00      0.00         1
         accordion       0.00      0.00      0.00         0
          acidjazz       0.00      0.00      0.00         1
acousticbassguitar       0.00      0.00      0.00         0
    acousticguitar       0.00      0.00      0.00         2
           african       0.00      0.00      0.00         0
       alternative       0.00      0.00      0.00         5
   alternativerock       0.00      0.00      0.00         0
           ambient       0.00      0.00      0.00        19
       atmospheric       0.00      0.00      0.00         7
              bass       0.00      0.00      0.00         2
              beat       0.00      0.00      0.00         2
             blues       0.00      0.00

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
torch.save(model.state_dict(), "multiclass_model_simple.pth")