# How To Run

Download the code from the GitHub repository (link: https://github.com/shubham7423/Music-Genre-Recognition). <br>
Place this notebook in the same folder as mgr. <br>
Install all libraries from requirements.txt file. <br>
Download the raw audio files, preprocessed data, and model weights from https://drive.google.com/file/d/1v4FYfKXk6gyzohnz6UZ7BT4cznHUSjFV/view and extract the contents into the same folder as the mgr folder. <br>
The paths to the data and model weights can be updated in the mgr/configuration/configuration.yaml file.

In [None]:
# Run this cell to detect the compute device and store it in the project configuration.

import os

import mgr.configuration
import torch
import yaml

# Built with os.path.join so the separator is correct on every OS; a literal
# backslash path only resolves on Windows and embeds the invalid escape "\c".
# The path is relative, so the notebook must sit next to the mgr folder.
CONFIG_PATH = os.path.join('mgr', 'configuration', 'configuration.yaml')

CFG = mgr.configuration.load.load_configurations()
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
CFG['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

# Write the whole configuration back to disk with the updated 'device' entry.
with open(CONFIG_PATH, 'w') as file:
    yaml.dump(CFG, file)

# Preprocessing

In [None]:
import mgr.preprocessing as preprocessing

In [None]:
# Run the preprocessing entry point exposed by mgr.preprocessing.
# NOTE(review): input/output locations are presumably taken from
# mgr/configuration/configuration.yaml — confirm before running.
preprocessing.process()

# Train

In [None]:
import mgr.train.transformer.cnn_transformer_v3 as cnn_transformer_v3
import mgr.train.transformer.cnn_patch_transformer as cnn_patch_transformer
import mgr.train.lstm.cnn_lstm as cnn_lstm
import mgr.train.cnn.cnn as cnn
import mgr.train.cnn.resnet as resnet

## Select model to train

1.CNN <br>
2.Resnet <br>
3.CNN+LSTM <br>
4.CNN+Transformer <br>
5.CNN+Transformer(Patched)

In [None]:
# Number of the model to train, taken from the list above (1-5); 2 selects Resnet.
model_option = 2

In [None]:
# Menu number -> training entry point of the corresponding model module.
TRAINERS = {
    1: cnn.start_training,
    2: resnet.start_training,
    3: cnn_lstm.start_training,
    4: cnn_transformer_v3.start_training,
    5: cnn_patch_transformer.start_training,
}

start_training = TRAINERS.get(model_option)

if start_training is None:
    # Report an unknown option explicitly instead of silently skipping training;
    # same message as the Predict section uses for an invalid choice.
    print("Enter valid choice!!")
else:
    # Every entry point returns a pair, unpacked here into `model` and `History`.
    model, History = start_training()

# Predict

In [None]:
import mgr.predict as predict
import mgr

import mgr.train.transformer.cnn_transformer_v3 as cnn_transformer_v3
import mgr.train.transformer.cnn_patch_transformer as cnn_patch_transformer
import mgr.train.lstm.cnn_lstm as cnn_lstm
import mgr.train.cnn.cnn as cnn
import mgr.train.cnn.resnet as resnet

import torch
import os

# Load the project configuration; the prediction cell reads the checkpoint
# directories and the 'device' entry from it.
# NOTE(review): the device-setup cell calls mgr.configuration.load.load_configurations()
# while this cell calls mgr.configuration.load_configurations() — confirm both resolve.
CFG = mgr.configuration.load_configurations()

# Model Architecture:
1.CNN <br>
2.Resnet <br>
3.CNN+LSTM <br>
4.CNN+Transformer(v3) <br>
5.CNN+Transformer(Patched)

On Windows, do not use .mp3 files. <br>
Set the AUDIO_PATH variable to the path of an audio file that is longer than 3 seconds.

In [None]:
# Audio clip to classify and the number of the architecture to load (1-5).
AUDIO_PATH = "sample_audio/electronic_1.wav"
model_architecture = 5

# choice -> (model module, configuration section holding the checkpoint
#            directory, checkpoint file name)
architectures = {
    1: (cnn, 'cnn', "cnn.pt"),
    2: (resnet, 'cnn', "resnet.pt"),
    3: (cnn_lstm, 'lstm', "lstm.pt"),
    4: (cnn_transformer_v3, 'transformer', "transformerv3.pt"),
    5: (cnn_patch_transformer, 'transformer', "cnn_patch_transformer.pt"),
}

model = None
selection = architectures.get(model_architecture)

if selection is None:
    print("Enter valid choice!!")
else:
    module, cfg_section, weights_file = selection
    # Build the network first, then restore its weights from the saved checkpoint.
    model = module.getModel()
    weights_path = os.path.join(CFG[cfg_section]['train']['save_model_at'], weights_file)
    ckpts = torch.load(weights_path, map_location=CFG['device'])
    model.load_state_dict(ckpts['model'])

# Only predict when a model was actually loaded above.
if model is not None:
    print("Top 3 genres: ", predict.predict(model, AUDIO_PATH))