In [6]:
import torch
import pickle
from model.EffectDecoder import EffectDecoder
from transformers import ASTModel
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import tqdm

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [15]:
class EffectClassifier(torch.nn.Module):
    """Effect classifier: a pretrained AST encoder followed by a linear head.

    Args:
        n_classes: Number of output classes of the linear head.
        embed_dim: Size of the encoder's pooled output that feeds the head
            (defaults to 768).
    """

    def __init__(self, n_classes,embed_dim=768):
        super().__init__()
        self.pretrained = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
        self.cls = torch.nn.Linear(embed_dim, n_classes)
        # Only used by `predict_proba`; deliberately NOT applied in `forward`.
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """Return unnormalised class scores (logits).

        Logits are returned instead of softmax probabilities because
        `torch.nn.CrossEntropyLoss` applies log-softmax itself; feeding it
        already-normalised probabilities squashes the gradients. `argmax`
        over the last dimension is unaffected by this change.

        Args:
            x: Mapping whose items are unpacked as keyword arguments into the
                pretrained encoder.

        Returns:
            Tensor of logits produced by the linear head from the encoder's
            `pooler_output`.
        """
        pooled = self.pretrained(**x).pooler_output
        return self.cls(pooled)

    def predict_proba(self, x):
        """Return class probabilities (softmax over the last dimension of the logits)."""
        return self.softmax(self.forward(x))

In [8]:
# NOTE(review): unpickling executes arbitrary code embedded in the file, so
# only load this dataset from a trusted source.
with open("data/guitar_sample_dataset.pkl", mode="rb") as dataset_file:
    dataset = pickle.load(dataset_file)

In [9]:
# Hold out 20% of the samples for evaluation. The seed is fixed so that a
# Restart & Run All reproduces the same split (and comparable metrics).
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)

In [11]:
def eval(model, loss_fn, dl):
    """Evaluate `model` over `dl` and print accuracy and summed loss.

    Note: this function shadows the builtin `eval`; the name is kept because
    `train` calls it.

    Args:
        model: Module called on each batch's `'wet_tone_features'`.
        loss_fn: Callable `(output, target) -> scalar tensor`.
        dl: Iterable of batches; each batch is indexed with
            `'wet_tone_features'` and `'effects'`, and both values must
            support `.to(device)`.

    Returns:
        None. The metrics are printed.
    """
    # Remember the mode so that calling this in the middle of training does
    # not silently leave the model in eval mode afterwards.
    was_training = model.training
    model.eval()
    total_loss = 0
    labels = []
    preds = []
    with torch.no_grad():
        for batch in tqdm.tqdm(dl):
            features = batch['wet_tone_features'].to(device)
            label = batch['effects'].to(device)
            output = model(features)
            # Compare against this batch's target, not the accumulator list.
            total_loss += loss_fn(output, label).item()

            pred_ids = torch.argmax(output, dim=-1)
            # Targets with one entry per prediction are taken to be class
            # indices already; anything else is assumed to be one-hot / score
            # vectors and is reduced along the last dimension like the output.
            # NOTE(review): confirm against the dataset's label format.
            if label.numel() == pred_ids.numel():
                true_ids = label
            else:
                true_ids = torch.argmax(label, dim=-1)

            # accuracy_score needs flat sequences of plain values, not tensors.
            preds.extend(pred_ids.reshape(-1).tolist())
            labels.extend(true_ids.reshape(-1).tolist())

    if was_training:
        model.train()

    # An empty iterable has no defined accuracy; report NaN instead of failing.
    accuracy = accuracy_score(labels, preds) if labels else float("nan")
    print(f"Accuracy:{accuracy} | Total Loss:{total_loss}")
    return

In [12]:
def train(model, optimizer, loss_fn, train_loader,test_loader, epochs=10):
    """Train `model` for `epochs` passes, evaluating after every epoch.

    Args:
        model: Module called on each batch's `'wet_tone_features'`.
        optimizer: Optimizer over the model's parameters.
        loss_fn: Callable `(output, target) -> scalar tensor`.
        train_loader: Iterable of training batches; each batch is indexed with
            `'wet_tone_features'` and `'effects'`, and both values must
            support `.to(device)`.
        test_loader: Iterable of batches handed to `eval` after each epoch.
        epochs: Number of passes over `train_loader`.

    Returns:
        None. Per-epoch summed loss is printed, followed by the metrics that
        `eval` prints.
    """
    for epoch in range(epochs):
        # Re-enter training mode every epoch: the evaluation at the end of the
        # previous epoch puts the model in eval mode, which would otherwise
        # leave dropout disabled from the second epoch onwards.
        model.train()
        total_loss = 0
        for batch in tqdm.tqdm(train_loader):
            optimizer.zero_grad()
            features = batch['wet_tone_features'].to(device)
            labels = batch['effects'].to(device)
            output = model(features)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss}")
        eval(model, loss_fn, test_loader)
    return

In [13]:
# Two-class classifier, moved to the selected device.
model = EffectClassifier(n_classes=2)
model = model.to(device)

# Adam over every model parameter (the pretrained encoder included).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
# NOTE(review): the raw splits are passed where `train` expects loaders, so
# each dataset item is consumed as one "batch" and the order is not
# reshuffled between epochs; wrap them in DataLoader if mini-batching is
# intended.
train(
    model,
    optimizer,
    loss_fn,
    train_loader=train_data,
    test_loader=test_data,
    epochs=5,
)