In [None]:
from google.colab import drive
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os

# Mount Google Drive at /content/drive; the archive and metadata paths used by
# the following cells all live under this mount point.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Unpack the six arena_mel_<i>.tar archives (i = 1..6) from Drive into /content.
# NOTE(review): `{1..6}` relies on the shell's brace expansion and `${i}` on the
# shell (not IPython) doing the substitution -- confirm on non-Colab runtimes.
!for i in {1..6}; do tar -xf "/content/drive/MyDrive/KUBIG_melon/melon-dataset/arena_mel_${i}.tar" -C "/content"; done

In [None]:
# Read the song metadata JSON from Drive and turn its records into a DataFrame.
songmeta_path = '/content/drive/MyDrive/KUBIG_melon/song_meta.json'
with open(songmeta_path, mode='r', encoding='utf-8') as f:
    song_meta_json = json.loads(f.read())
song_meta = pd.DataFrame(data=song_meta_json)

In [None]:
# Keep only songs whose genre basket holds exactly one code, then drop the
# songs whose (single) code is 'GN0500'.
song_meta_1 = (
    song_meta
    .loc[lambda frame: frame['song_gn_gnr_basket'].apply(len) == 1]
    .loc[lambda frame: ~frame['song_gn_gnr_basket'].apply(lambda basket: 'GN0500' in basket)]
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Device used by load_model(): "cuda" when a GPU is visible, otherwise "cpu".
# This is a plain string; a later cell rebinds DEVICE to a torch.device object.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut branch.

    The main path is conv -> BN -> ReLU -> conv -> BN. Its output is added to
    the shortcut branch and passed through a final ReLU.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first 3x3 convolution and of the 1x1 projection
            on the shortcut (default 1).
    """

    def __init__(self, in_planes, planes, stride = 1):
        super().__init__()
        conv3x3 = dict(kernel_size = 3, padding = 1, bias = False)
        self.conv1 = nn.Conv2d(in_planes, planes, stride = stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, stride = 1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(planes)

        # The shortcut is an empty Sequential (identity) unless `stride != 1`
        # or the channel count changes, in which case a 1x1 conv + BN is used.
        projection = []
        if stride != 1 or in_planes != planes:
            projection = [
                nn.Conv2d(in_planes, planes, kernel_size = 1, stride = stride, bias = False),
                nn.BatchNorm2d(planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return F.relu(main + self.shortcut(x))

class ResNet(nn.Module):
    """Small residual CNN over single-channel 2-D inputs.

    Layout: a 3x7 conv stem (1 -> 32 channels), three stride-(2, 2) ResBlocks
    (32 -> 64 -> 128 -> 256 channels), global average pooling and a linear
    layer producing `num_classes` outputs.

    Args:
        num_classes: Size of the output of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        stem_conv = nn.Conv2d(1, 32, kernel_size=(3,7), padding=(1,3))
        self.stem = nn.Sequential(stem_conv, nn.BatchNorm2d(32), nn.ReLU())

        # Each stage uses stride (2, 2) in both spatial dimensions.
        downsample = (2,2)
        self.layer1 = ResBlock(32, 64, stride=downsample)
        self.layer2 = ResBlock(64, 128, stride=downsample)
        self.layer3 = ResBlock(128, 256, stride=downsample)

        self.pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        """Run stem, the three stages, pooling and the classifier on `x`."""
        for stage in (self.stem, self.layer1, self.layer2, self.layer3):
            x = stage(x)
        pooled = self.pool(x)
        # Drop the two trailing size-1 spatial axes left by the pooling layer.
        features = pooled.squeeze(-1).squeeze(-1)
        return self.fc(features)

def load_model(weight_path, num_classes):
    """Build a ResNet, load weights from `weight_path`, and return it in eval mode.

    Args:
        weight_path: Path to a file produced by saving a model's state_dict.
        num_classes: Output size passed to the ResNet constructor.

    Returns:
        The model, moved to DEVICE and switched to evaluation mode.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(weight_path, map_location=DEVICE)
    net = ResNet(num_classes)
    net.load_state_dict(state)
    net.to(DEVICE)
    return net.eval()

# Lookup from genre code (the values stored in `song_gn_gnr_basket`) to a
# display name. Several codes share one name (e.g. GN0300 and GN1200), so
# anything keyed on the name treats those codes as the same genre.
# There are no entries for GN0500 or GN2500.
GENRE_MAP = {
    'GN0100': "발라드",
    'GN0200': "댄스",
    'GN0300': "랩/힙합", # domestic
    'GN0400': "R&B/Soul", # domestic
    'GN0600': "록/메탈", # domestic
    'GN0700': "성인가요",
    'GN0800': "포크/블루스", # domestic
    'GN0900': "POP",
    'GN1000': "록/메탈", # overseas
    'GN1100': "일렉트로니카", # overseas
    'GN1200': "랩/힙합",
    'GN1300': "R&B/Soul",
    'GN1400': "포크/블루스",
    'GN1500': "OST",
    'GN1600': "클래식",
    'GN1700': "재즈",
    'GN1800': "뉴에이지",
    'GN1900': "J-POP",
    'GN2000': "월드뮤직",
    'GN2100': "CCM",
    'GN2200': "어린이/태교",
    'GN2300': "종교음악",
    'GN2400': "국악",
    'GN2600': "일렉트로니카(스타일)",
    'GN2700': "EDM",
    'GN2800': "뮤직테라피",
    'GN9000': "UNKNOWN"
}

In [None]:
# Work on an explicit copy so the column assignments below do not write into a
# filtered slice of `song_meta` (avoids pandas' SettingWithCopyWarning).
song_meta_1 = song_meta_1.copy()

le = LabelEncoder()
# The basket column holds lists, so `.str[0]` picks the single genre code of
# each song; that code is then translated to its genre name.
song_meta_1['genre'] = song_meta_1['song_gn_gnr_basket'].str[0].map(GENRE_MAP)
# Encode the full genre name. Applying `.str[0]` to this string column (as the
# previous version did) would keep only the first character of each name, making
# `le.classes_` single characters and merging names that share a first character.
# NOTE(review): codes missing from GENRE_MAP become NaN here -- confirm whether
# such rows should be dropped or kept as their own class.
song_meta_1['label'] = le.fit_transform(song_meta_1['genre'])
song_meta_2 = song_meta_1[['id', 'label']]
# Only the first 100,000 rows are used for the train/test split.
song_meta_100k = song_meta_2.iloc[:100000]
# 80/20 split, stratified on the label, with a fixed seed for reproducibility.
train_df, test_df = train_test_split(
    song_meta_100k,
    test_size=0.2,
    stratify=song_meta_100k['label'],
    random_state=42,
)

In [None]:
class MelDataset(Dataset):
    """Dataset of per-song mel arrays stored as ``.npy`` files, with labels.

    Each file is looked up as ``<mel_root>/<song_id // 1000>/<song_id>.npy``.
    The loaded 2-D array is cropped or zero-padded along axis 1 to
    ``target_len`` columns, standardised per row (axis 1), and returned as a
    float32 tensor with a leading channel axis.

    Args:
        df: DataFrame with an ``id`` column and a ``label`` column.
        mel_root: Directory that contains the numbered sub-directories.
        target_len: Number of columns (axis 1) of every returned array.
        random_crop: If True (default, the historical behaviour) arrays longer
            than ``target_len`` are cropped at a random offset drawn with
            ``np.random``; if False the centre window is used, which makes
            repeated reads of the same item identical.
    """

    def __init__(self, df, mel_root, target_len = 1024, random_crop = True):
        self.df = df.reset_index(drop=True)
        self.mel_root = mel_root
        self.target_len = target_len
        self.random_crop = random_crop

    def __len__(self):
        return len(self.df)

    def _fit_length(self, mel):
        """Crop or right-pad `mel` along axis 1 to exactly `target_len` columns."""
        t = mel.shape[1]
        if t > self.target_len:
            extra = t - self.target_len
            if self.random_crop:
                # randint's upper bound is exclusive; `extra + 1` makes the last
                # valid window (start == extra) reachable.
                start = np.random.randint(0, extra + 1)
            else:
                start = extra // 2
            mel = mel[:, start:(start + self.target_len)]
        elif t < self.target_len:
            mel = np.pad(mel, ((0, 0), (0, self.target_len - t)), mode='constant')
        return mel

    def __getitem__(self, idx):
        """Return ``(mel, label)`` for row `idx`.

        If ``np.load`` raises ValueError for a file, the next row (wrapping
        around) is returned instead, together with that row's label.

        Raises:
            IndexError: If `idx` is out of range for the DataFrame.
            RuntimeError: If no row at all could be loaded.
        """
        n_rows = len(self.df)
        cur = idx
        # Bounded retry instead of unbounded recursion: at most one full pass.
        for _ in range(max(n_rows, 1)):
            row = self.df.iloc[cur]  # raises IndexError for an out-of-range idx
            song_id = int(row['id'])
            subdir = song_id // 1000
            mel_path = os.path.join(self.mel_root, str(subdir), f'{song_id}.npy')
            try:
                mel = np.load(mel_path).astype(np.float32)
            except ValueError as exc:
                print(f'error: could not load {mel_path} ({exc}); using the next sample instead')
                cur = (cur + 1) % n_rows
                continue

            mel = self._fit_length(mel)
            # Per-row standardisation; the epsilon guards against zero variance.
            mel = (mel - mel.mean(axis=1, keepdims=True)) / (mel.std(axis=1, keepdims=True) + 1e-6)

            mel = torch.tensor(mel).unsqueeze(0)
            label = torch.tensor(row['label'], dtype=torch.long)
            return mel, label

        raise RuntimeError(f'no loadable .npy file found under {self.mel_root}')

In [None]:
# Wrap both splits in MelDataset objects that read from the extracted archives.
train_ds, test_ds = (
    MelDataset(split, '/content/arena_mel/') for split in (train_df, test_df)
)

# Training batches are reshuffled every epoch; loading happens in the main
# process (num_workers=0).
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True,
                          num_workers=0, pin_memory=True)

# Evaluation batches keep the DataFrame order.
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False,
                         num_workers=0, pin_memory=True)

In [None]:
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut branch.

    The main path is conv -> BN -> ReLU -> conv -> BN. Its output is added to
    the shortcut branch and passed through a final ReLU.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first 3x3 convolution and of the 1x1 projection
            on the shortcut (default 1).
    """

    def __init__(self, in_planes, planes, stride = 1):
        super().__init__()
        conv3x3 = dict(kernel_size = 3, padding = 1, bias = False)
        self.conv1 = nn.Conv2d(in_planes, planes, stride = stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, stride = 1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(planes)

        # The shortcut is an empty Sequential (identity) unless `stride != 1`
        # or the channel count changes, in which case a 1x1 conv + BN is used.
        projection = []
        if stride != 1 or in_planes != planes:
            projection = [
                nn.Conv2d(in_planes, planes, kernel_size = 1, stride = stride, bias = False),
                nn.BatchNorm2d(planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return F.relu(main + self.shortcut(x))

class ResNet(nn.Module):
    """Residual CNN over single-channel 2-D inputs, four downsampling stages.

    Layout: a 3x11 conv stem (1 -> 32 channels), four stride-(2, 2) ResBlocks
    (32 -> 64 -> 128 -> 256 -> 256 channels), global average pooling and a
    linear layer producing `num_classes` outputs.

    Args:
        num_classes: Size of the output of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # NOTE(review): kernel width 11 with padding 3 shortens the last axis
        # by 4 (a width-preserving padding would be 5) -- confirm this is intended.
        stem_conv = nn.Conv2d(1, 32, kernel_size=(3, 11), padding=(1,3))
        self.stem = nn.Sequential(stem_conv, nn.BatchNorm2d(32), nn.ReLU())

        # Each stage uses stride (2, 2) in both spatial dimensions.
        downsample = (2,2)
        self.layer1 = ResBlock(32, 64, stride=downsample)
        self.layer2 = ResBlock(64, 128, stride=downsample)
        self.layer3 = ResBlock(128, 256, stride=downsample)
        self.layer4 = ResBlock(256, 256, stride=downsample)

        self.pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        """Run stem, the four stages, pooling and the classifier on `x`."""
        stages = (self.stem, self.layer1, self.layer2, self.layer3, self.layer4)
        for stage in stages:
            x = stage(x)
        pooled = self.pool(x)
        # Drop the two trailing size-1 spatial axes left by the pooling layer.
        features = pooled.squeeze(-1).squeeze(-1)
        return self.fc(features)

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# One output per class known to the fitted label encoder.
num_classes = len(le.classes_)
model = ResNet(num_classes=num_classes)
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate once the monitored value has not decreased for more
# than `patience` consecutive scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
)

In [None]:
def train(model, train_loader, optimizer, log_interval, epoch=None):
    """Run one training epoch and return ``(avg_loss, acc)``.

    Uses the module-level `criterion` and `DEVICE`.

    Args:
        model: Network to optimise; switched to train mode.
        train_loader: Iterable of ``(mel, label)`` batches with a ``dataset``
            attribute and a length.
        optimizer: Optimizer stepping the model's parameters.
        log_interval: Print a progress line every this many batches; a falsy
            value (0 / None) disables the progress lines.
        epoch: Epoch number shown in the progress line. When omitted, the
            module-level variable `epoch` is used if it exists (this is what
            earlier versions always did), otherwise ``'?'``.

    Returns:
        Tuple ``(avg_loss, acc)``: the mean of the per-batch losses and the
        accuracy as a fraction in [0, 1] (note: `evaluate` reports a
        percentage). Both are NaN when the loader yields no batches.
    """
    if epoch is None:
        # Backward-compatible fallback so existing callers that rely on the
        # global loop variable keep working, without a NameError elsewhere.
        epoch = globals().get('epoch', '?')

    model.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = len(train_loader)
    for batch_idx, (mel, label) in enumerate(tqdm(train_loader)):
        mel = mel.to(DEVICE)
        label = label.to(DEVICE)
        optimizer.zero_grad()
        output = model(mel)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += (pred==label).sum().item()
        total += label.size(0)

        # Guard against log_interval == 0, which would raise on the modulo.
        if log_interval and batch_idx % log_interval == 0:
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}".format(
                epoch, batch_idx * len(mel),
                len(train_loader.dataset), 100. * batch_idx / num_batches,
                loss.item()))

    # An empty loader has nothing to average; report NaN instead of dividing by 0.
    avg_loss = train_loss / num_batches if num_batches else float('nan')
    acc = correct / total if total else float('nan')
    return avg_loss, acc

def evaluate(model, test_loader):
    """Score `model` on `test_loader` without tracking gradients.

    Uses the module-level `criterion` and `DEVICE`.

    Returns:
        Tuple ``(avg_loss, acc)``: the sum of per-batch losses divided by the
        number of batches, and the accuracy as a percentage of
        ``len(test_loader.dataset)``.
    """
    model.eval()
    loss_sum, n_correct = 0, 0

    with torch.no_grad():
        for mel, label in test_loader:
            mel, label = mel.to(DEVICE), label.to(DEVICE)
            logits = model(mel)
            loss_sum += criterion(logits, label).item()
            # Count predictions whose arg-max class matches the label.
            n_correct += (logits.argmax(dim=1) == label).sum().item()

    return loss_sum / len(test_loader), 100. * n_correct / len(test_loader.dataset)

In [None]:
# NOTE(review): BATCH_SIZE is not referenced by the DataLoader cell above, which
# hard-codes batch_size=64 -- keep the two in sync or wire this constant in.
BATCH_SIZE = 64
# Number of passes over the training set performed by the training loop.
EPOCHS = 15

In [None]:
# Train for EPOCHS epochs, checkpointing the weights with the lowest test loss.
best_loss = float('inf')
# The loop variable must stay named `epoch`: train() reads the module-level
# `epoch` for its progress lines when no epoch argument is passed.
for epoch in range(1, EPOCHS + 1):
    # Keep the training metrics (previously discarded) so train and test
    # curves can be compared; train() reports accuracy as a fraction.
    train_loss, train_accuracy = train(model, train_loader, optimizer, log_interval = 200)
    test_loss, test_accuracy = evaluate(model, test_loader)
    # The scheduler lowers the learning rate when the test loss stops improving.
    scheduler.step(test_loss)
    if test_loss < best_loss:
        best_loss = test_loss
        torch.save(model.state_dict(), "resnet_genre_best.pth")
    print("\n[EPOCH: {}], \tTrain Loss: {:.4f}, \tTrain Accuracy: {:.2f} %".format(
        epoch, train_loss, 100. * train_accuracy))
    print("\n[EPOCH: {}], \tTest Loss: {:.4f}, \tTest Accuracy: {:.2f} % \n".format(
        epoch, test_loss, test_accuracy))

  0%|          | 1/1250 [00:00<14:19,  1.45it/s]



 16%|█▌        | 201/1250 [02:20<12:18,  1.42it/s]



 32%|███▏      | 401/1250 [04:40<09:47,  1.44it/s]



 48%|████▊     | 601/1250 [07:00<07:26,  1.45it/s]



 64%|██████▍   | 801/1250 [09:20<05:25,  1.38it/s]



 80%|████████  | 1001/1250 [11:39<02:50,  1.46it/s]



 96%|█████████▌| 1201/1250 [13:58<00:34,  1.41it/s]



100%|██████████| 1250/1250 [14:32<00:00,  1.43it/s]



[EPOCH: 1], 	Test Loss: 1.5314, 	Test Accuracy: 54.38 % 



  0%|          | 1/1250 [00:00<13:41,  1.52it/s]



 16%|█▌        | 201/1250 [02:19<12:22,  1.41it/s]



 32%|███▏      | 401/1250 [04:38<09:46,  1.45it/s]



 48%|████▊     | 601/1250 [06:56<07:30,  1.44it/s]



 64%|██████▍   | 801/1250 [09:15<05:11,  1.44it/s]



 80%|████████  | 1001/1250 [11:34<02:47,  1.49it/s]



 96%|█████████▌| 1201/1250 [13:53<00:34,  1.43it/s]



100%|██████████| 1250/1250 [14:27<00:00,  1.44it/s]



[EPOCH: 2], 	Test Loss: 1.4352, 	Test Accuracy: 56.55 % 



  0%|          | 1/1250 [00:00<14:24,  1.45it/s]



 16%|█▌        | 201/1250 [02:20<12:32,  1.39it/s]



 32%|███▏      | 401/1250 [04:40<09:46,  1.45it/s]



 32%|███▏      | 402/1250 [04:41<09:45,  1.45it/s]

In [None]:
# Save the model's current weights (whatever state training left it in). This is
# a separate file from "resnet_genre_best.pth", which the training loop writes
# whenever the test loss improves.
torch.save(model.state_dict(), "resnet_genre.pth")