# Classes and variables

In [None]:
from birdclassification.preprocessing.filtering import filter_recordings_30
from torch.utils.data import DataLoader
import torch
from sklearn.model_selection import train_test_split
from training.dataset import Recordings30
from birdclassification.visualization.plots import plot_torch_spectrogram
from training.cnn_training_torch.CNN_model import CNNNetwork
from torchsummary import summary

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Seed used as random_state for the dataset splits.
SEED = 123

# Data locations on the local machine.
RECORDINGS_DIR = '/mnt/d/recordings_30/'
NOISES_DIR = '/aaa/'  # NOTE(review): looks like a placeholder path - confirm

# Audio and batching settings.
SAMPLE_RATE = 32_000
NUM_SAMPLES = 1 * SAMPLE_RATE  # presumably the length of a 1-second clip - confirm
BATCH_SIZE = 32
NUM_WORKERS = 8

# Optimisation settings.
LEARNING_RATE = 1e-4
EPOCHS = 5

# Prepare dataset and dataloaders, visualize dataset

In [None]:
# Build the recordings metadata frame from the two CSV files.
df = filter_recordings_30(
    "../../data/xeno_canto_recordings.csv",
    "../../data/bird-list-extended.csv",
)

# Uncomment to work on a 10% subset for quick test runs.
# df = df.sample(frac=0.1, random_state=SEED)

# 80% train, then the remaining 20% is halved into validation and test.
# Both splits are stratified on 'Latin name' and seeded for reproducibility.
train_df, test_val_df = train_test_split(
    df, stratify=df['Latin name'], test_size=0.2, random_state=SEED
)
val_df, test_df = train_test_split(
    test_val_df, stratify=test_val_df['Latin name'], test_size=0.5, random_state=SEED
)

# Every split is built with the same settings; the sample rate comes from
# SAMPLE_RATE everywhere so the three datasets cannot drift apart.
# NOTE(review): the datasets receive device=DEVICE while the loaders use
# worker processes - confirm Recordings30 does not touch CUDA inside workers.
dataset_kwargs = dict(
    recording_dir=RECORDINGS_DIR,
    noises_dir=NOISES_DIR,
    sample_rate=SAMPLE_RATE,
    device=DEVICE,
)
train_ds = Recordings30(train_df, **dataset_kwargs)
val_ds = Recordings30(val_df, **dataset_kwargs)
test_ds = Recordings30(test_df, **dataset_kwargs)

# Reshuffle the training data each epoch so batches differ between epochs;
# validation and test keep a fixed order.
# NOTE(review): shuffle=True assumes Recordings30 is a map-style dataset - confirm.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [None]:
# Visual sanity check of the training data.
# NOTE(review): the meaning of the two arguments is defined by
# Recordings30.visualize_dataset - confirm 3207 is valid for smaller subsets.
train_ds.visualize_dataset(3207, 5)

# Prepare the model, loss function and optimizer

In [None]:
# Instantiate the network, then move its parameters to the selected device.
cnn = CNNNetwork()
cnn = cnn.to(DEVICE)

# Print a layer-by-layer summary for the input size expected by the model.
input_size = (1, 64, 251)
summary(cnn, input_size)

In [None]:
# Put the model in evaluation mode; as the last expression of the cell this
# also displays the module structure.
cnn.train(False)

In [None]:
# Adam over all model parameters with the configured learning rate.
optimizer = torch.optim.Adam(cnn.parameters(), lr=LEARNING_RATE)

# Cross-entropy loss applied to the model outputs and the labels.
loss_fn = torch.nn.CrossEntropyLoss()

# Train loop

In [None]:
from datetime import datetime
import sys
from training.training_utils import train_one_epoch
from torch.utils.tensorboard import SummaryWriter
from training.validation_metrics import calculate_metric
from sklearn.metrics import f1_score

# Timestamp shared by the TensorBoard run directory and the checkpoint names.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# NOTE(review): the "fashion_trainer" prefix looks like a leftover from a
# tutorial - consider renaming the log directory.
writer = SummaryWriter(f'logs/fashion_trainer_{timestamp}')
epoch_number = 0

# Lowest validation loss seen so far; starts at the largest float so that any
# finite loss counts as an improvement.
best_vloss = sys.float_info.max

for epoch in range(EPOCHS):
    print(f'EPOCH {epoch_number + 1}:')

    # Switch to training mode and do one pass over the training data.
    cnn.train(True)
    avg_loss = train_one_epoch(epoch_number, writer, train_dl, optimizer, loss_fn, cnn, DEVICE)

    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    cnn.eval()
    running_vloss = 0.0
    num_val_batches = 0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vinputs, vlabels in val_dl:
            # Insert a dimension at position 1 before feeding the model.
            vinputs = torch.unsqueeze(vinputs, dim=1).to(DEVICE)
            voutputs = cnn(vinputs)
            vloss = loss_fn(voutputs, vlabels.to(DEVICE))
            # .item() keeps the running sum a plain Python float instead of a
            # tensor living on DEVICE.
            running_vloss += vloss.item()
            num_val_batches += 1

    # Count batches explicitly rather than relying on the loop index, which is
    # undefined (or stale from an earlier cell) when the loader is empty.
    if num_val_batches == 0:
        raise RuntimeError('Validation dataloader yielded no batches')
    avg_vloss = running_vloss / num_val_batches

    print("#############################################################")
    print("Epoch results:")
    print(f'Loss train {avg_loss} valid loss: {avg_vloss}')
    validation_f1_score = calculate_metric(cnn, val_dl, DEVICE, metric=f1_score)
    # The training F1 score is not computed; the placeholder is printed as None.
    train_f1_score = None
    print(f'F1 score train {train_f1_score} valid f1 score {validation_f1_score}')
    print("#############################################################\n\n")

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                       {'Training': avg_loss, 'Validation': avg_vloss},
                       epoch_number + 1)

    # NOTE(review): the tag says macro-averaged; confirm calculate_metric
    # actually applies macro averaging to f1_score.
    writer.add_scalars('Macro_averaged_f1_score',
                       {'Validation': validation_f1_score},
                       epoch_number + 1)

    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = f'model_{timestamp}_{epoch_number}'
        torch.save(cnn.state_dict(), model_path)

    epoch_number += 1

# Save the model

In [None]:
from pathlib import Path

# Persist the current weights of the model (its state after the last epoch run).
final_model_path = "../saved_models/cnn_1.pt"
# Create the target directory first so a missing folder cannot make the save
# fail after training has finished.
Path(final_model_path).parent.mkdir(parents=True, exist_ok=True)
torch.save(cnn.state_dict(), final_model_path)