# BirdCLEF 2023 Training with Supervised Contrastive Learning and PyTorch Lightning

Welcome to my Kaggle notebook! It walks through training a model for the BirdCLEF 2023 competition using supervised contrastive learning and PyTorch Lightning.

## Notebook Overview

- **Training Approach:** I use supervised contrastive learning in two stages: an encoder is first trained to pull embeddings of same-species spectrograms together and push different species apart, and a linear classifier is then trained with cross-entropy on top of the frozen encoder.
- **Framework:** The training is implemented with PyTorch Lightning, which handles the training loop, device placement, and checkpointing.
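
For reference, the supervised contrastive loss from the paper linked in the `SupConLoss` docstring further down can be written as below. The symbols are my own shorthand and not names from the code: $I$ is the set of anchors in the batch, $z_i$ the embedding of anchor $i$, $P(i)$ the other samples sharing its label, $A(i)$ all samples except $i$, and $\tau$ the temperature.

```latex
\mathcal{L}^{\mathrm{sup}}
  = \sum_{i \in I} \frac{-1}{|P(i)|}
    \sum_{p \in P(i)}
    \log \frac{\exp\left(z_i \cdot z_p / \tau\right)}
              {\sum_{a \in A(i)} \exp\left(z_i \cdot z_a / \tau\right)}
```

The `SupConLoss` class used in this notebook averages over anchors instead of summing and rescales by `temperature / base_temperature`, which equals 1 with the settings used here. With `contrast_mode='one'`, as in `Encoder`, only the first view of each sample serves as an anchor.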

## How to Use This Notebook

Feel free to explore the code, experiment with different hyperparameters, and adapt the model for your own use cases. If you find this notebook helpful, please consider upvoting it to show your support.

**Notebook Credits:**
- This notebook is based on the work of [Nischay Dhankhar](https://www.kaggle.com/nischaydnk/) (notebooks: https://www.kaggle.com/nischaydnk/code). I adapted and extended their code for the BirdCLEF 2023 competition.


## Upvote if Useful

If you find this notebook valuable or informative, please consider giving it an upvote. Your feedback and support are greatly appreciated!

Happy coding!

In [1]:
!pip install timm
!pip install pytorch-lightning



In [2]:
# Supervised contrastive learning with PyTorch Lightning on BirdCLEF 2023 mel spectrograms

# Stage 1: Training the encoder

# Imports
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
import pytorch_lightning as pl
import torchmetrics
import timm
import pandas as pd
import numpy as np



In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
class Config:
    use_aug = False
    num_classes = 264
    batch_size = 64
    epochs = 12
    PRECISION = 16    
    PATIENCE = 8    
    seed = 2023
    model = "tf_efficientnet_b0_ns"
    pretrained = True            
    weight_decay = 1e-3
    use_mixup = True
    mixup_alpha = 0.2   
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')    

    data_root = "/kaggle/input/birdclef-2023/"
    train_images = "/kaggle/input/split-creating-melspecs-stage-1/specs/train/"
    valid_images = "/kaggle/input/split-creating-melspecs-stage-1/specs/valid/"
    train_path = "/kaggle/input/bc2023-train-val-df/train.csv"
    valid_path = "/kaggle/input/bc2023-train-val-df/valid.csv"
    
    
    SR = 32000
    DURATION = 5
    MAX_READ_SAMPLES = 5
    LR = 5e-4

In [5]:
pl.seed_everything(Config.seed, workers=True)

2023

In [6]:
df_train = pd.read_csv(Config.train_path)
df_valid = pd.read_csv(Config.valid_path)
df_train.head()

Unnamed: 0,primary_label,secondary_labels,type,latitude,longitude,scientific_name,common_name,author,license,rating,url,filename,len_sec_labels,path,frames,sr,duration
0,yebapa1,[],['song'],-3.3923,36.7049,Apalis flavida,Yellow-breasted Apalis,isaac kilusu,Creative Commons Attribution-NonCommercial-Sha...,3.0,https://www.xeno-canto.org/422175,yebapa1/XC422175.ogg,0,/kaggle/input/birdclef-2023/train_audio/yebapa...,405504,32000,12.672
1,yebapa1,[],['song'],-0.6143,34.0906,Apalis flavida,Yellow-breasted Apalis,James Bradley,Creative Commons Attribution-NonCommercial-Sha...,3.0,https://www.xeno-canto.org/289562,yebapa1/XC289562.ogg,0,/kaggle/input/birdclef-2023/train_audio/yebapa...,796630,32000,24.894687
2,combuz1,[],['call'],51.8585,-8.2699,Buteo buteo,Common Buzzard,Irish Wildlife Sounds,Creative Commons Attribution-NonCommercial-Sha...,4.0,https://www.xeno-canto.org/626969,combuz1/XC626969.ogg,0,/kaggle/input/birdclef-2023/train_audio/combuz...,254112,32000,7.941
3,chibat1,['laudov1'],"['adult', 'sex uncertain', 'song']",-33.1465,26.4001,Batis molitor,Chinspot Batis,Lynette Rudman,Creative Commons Attribution-NonCommercial-Sha...,3.5,https://www.xeno-canto.org/664196,chibat1/XC664196.ogg,1,/kaggle/input/birdclef-2023/train_audio/chibat...,1040704,32000,32.522
4,carcha1,[],['song'],-34.011,18.8078,Cossypha caffra,Cape Robin-Chat,Shannon Ronaldson,Creative Commons Attribution-NonCommercial-Sha...,1.0,https://www.xeno-canto.org/322333,carcha1/XC322333.ogg,0,/kaggle/input/birdclef-2023/train_audio/carcha...,40124,32000,1.253875


In [7]:
CLASS_23 = sorted(os.listdir('/kaggle/input/birdclef-2023/train_audio'))
NUM_CLASSES_23 = len(CLASS_23)
LABEL2NAME23 = dict(zip(list(range(NUM_CLASSES_23)), CLASS_23))
NAME2LABEL23 = {v:k for k, v in LABEL2NAME23.items()}
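
To make the mapping above concrete, here is a minimal sketch using three species codes taken from the dataframe preview. The variable names are illustrative, and the indices differ from the real ones because the real list has 264 entries.

```python
# Three species codes from the dataframe preview; the real folder listing has 264.
class_names = sorted(["combuz1", "carcha1", "chibat1"])

# index -> species code, same result as dict(zip(range(n), class_names))
label2name = dict(enumerate(class_names))
# species code -> index, used to build the integer `target` column
name2label = {name: idx for idx, name in label2name.items()}

print(name2label)
```

Because the folder names are sorted first, the mapping is reproducible across runs as long as the set of folders does not change.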

In [8]:
df_train['target'] = df_train.primary_label.map(NAME2LABEL23)
df_valid['target'] = df_valid.primary_label.map(NAME2LABEL23)
df_train.head()

Unnamed: 0,primary_label,secondary_labels,type,latitude,longitude,scientific_name,common_name,author,license,rating,url,filename,len_sec_labels,path,frames,sr,duration,target
0,yebapa1,[],['song'],-3.3923,36.7049,Apalis flavida,Yellow-breasted Apalis,isaac kilusu,Creative Commons Attribution-NonCommercial-Sha...,3.0,https://www.xeno-canto.org/422175,yebapa1/XC422175.ogg,0,/kaggle/input/birdclef-2023/train_audio/yebapa...,405504,32000,12.672,249
1,yebapa1,[],['song'],-0.6143,34.0906,Apalis flavida,Yellow-breasted Apalis,James Bradley,Creative Commons Attribution-NonCommercial-Sha...,3.0,https://www.xeno-canto.org/289562,yebapa1/XC289562.ogg,0,/kaggle/input/birdclef-2023/train_audio/yebapa...,796630,32000,24.894687,249
2,combuz1,[],['call'],51.8585,-8.2699,Buteo buteo,Common Buzzard,Irish Wildlife Sounds,Creative Commons Attribution-NonCommercial-Sha...,4.0,https://www.xeno-canto.org/626969,combuz1/XC626969.ogg,0,/kaggle/input/birdclef-2023/train_audio/combuz...,254112,32000,7.941,73
3,chibat1,['laudov1'],"['adult', 'sex uncertain', 'song']",-33.1465,26.4001,Batis molitor,Chinspot Batis,Lynette Rudman,Creative Commons Attribution-NonCommercial-Sha...,3.5,https://www.xeno-canto.org/664196,chibat1/XC664196.ogg,1,/kaggle/input/birdclef-2023/train_audio/chibat...,1040704,32000,32.522,66
4,carcha1,[],['song'],-34.011,18.8078,Cossypha caffra,Cape Robin-Chat,Shannon Ronaldson,Creative Commons Attribution-NonCommercial-Sha...,1.0,https://www.xeno-canto.org/322333,carcha1/XC322333.ogg,0,/kaggle/input/birdclef-2023/train_audio/carcha...,40124,32000,1.253875,60


In [9]:
Config.num_classes = len(df_train.primary_label.unique())

In [10]:
from torchvision.transforms.functional import to_pil_image, to_tensor

class BirdDataset(torch.utils.data.Dataset):

    def __init__(self, df, sr = Config.SR, duration = Config.DURATION, transforms = None, train = True):
        self.df = df
        self.sr = sr 
        self.train = train
        self.duration = duration
        self.transforms = transforms
        if train:
            self.img_dir = Config.train_images
        else:
            self.img_dir = Config.valid_images

    def __len__(self):
        return len(self.df)

    @staticmethod
    def normalize(image):
        image = image / 255.0
        return image

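    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Spectrograms are stored as <img_dir>/<filename>.npy
        impath = self.img_dir + f"{row.filename}.npy"

        # Read at most MAX_READ_SAMPLES entries along the first axis and keep only the first one
        image = np.load(str(impath))[:Config.MAX_READ_SAMPLES]
        image = image[0]

        # Convert to a PIL image so torchvision transforms can be applied
        image = to_pil_image(image)
        images = self.transforms(image)

        return images, torch.tensor(row.target)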


In [11]:
# Transformation that produces two augmented views of an image (anchor and positive)
class TwoCropTransform:
    """Create two crops of the same image"""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return [self.transform(x), self.transform(x)]
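
The sketch below shows what `TwoCropTransform` returns and how a `DataLoader` collates it, which is why the training step later reads `images[0]` and `images[1]`. The `add_noise` augmentation and the toy tensor sizes are assumptions for illustration only; the notebook itself uses torchvision transforms on spectrogram images.

```python
import torch
from torch.utils.data import DataLoader

class TwoCropTransform:
    """Same idea as the class above: apply one random transform twice."""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return [self.transform(x), self.transform(x)]

def add_noise(x):
    # Hypothetical stand-in augmentation: small Gaussian noise
    return x + 0.1 * torch.randn_like(x)

torch.manual_seed(0)
spec = torch.zeros(1, 4, 6)  # toy "spectrogram": 1 channel, 4 mel bins, 6 frames
view_a, view_b = TwoCropTransform(add_noise)(spec)

# A list of (views, label) pairs works as a map-style dataset
samples = [(TwoCropTransform(add_noise)(spec), torch.tensor(i)) for i in range(8)]
images, labels = next(iter(DataLoader(samples, batch_size=4)))

print(len(images), images[0].shape, labels.shape)
```

The default collate function keeps the two views as a list of two batched tensors, so concatenating them along dimension 0 yields one batch of size `2 * batch_size`.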

In [12]:
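class To3Channel:
    """Repeat the first channel three times so a single-channel image matches a 3-channel backbone input."""
    def __call__(self, img):
        stacked_img = torch.stack([img[0], img[0], img[0]])
        return stacked_img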
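
A quick shape check for the channel stacking, on a random toy tensor (the sizes are arbitrary). It also shows that `Tensor.repeat` gives the same result, in case you prefer a one-liner.

```python
import torch

class To3Channel:
    """Copy of the transform above, repeated here so the snippet runs on its own."""
    def __call__(self, img):
        return torch.stack([img[0], img[0], img[0]])

gray = torch.rand(1, 4, 6)       # [channels=1, height, width]
rgb_like = To3Channel()(gray)    # [3, height, width]
alt = gray.repeat(3, 1, 1)       # equivalent for single-channel input

print(rgb_like.shape, torch.equal(rgb_like, alt))
```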

In [13]:
# DataModule for stage 1 (supervised contrastive training)
class BirdCLEFDataModuleSCL(pl.LightningDataModule):
    def __init__(self, train_df, valid_df, batch_size, num_workers):
        super().__init__()
        self.train_df = train_df
        self.valid_df = valid_df
        self.batch_size = batch_size
        self.num_workers = num_workers

    # Called once from a single process: download or prepare data here (nothing to do for this dataset)
    def prepare_data(self):
        pass

    # Called on every process/GPU: build the datasets here
    def setup(self, stage):
        transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                       transforms.RandomGrayscale(p=0.2), 
                                       transforms.ToTensor(),
                                       To3Channel()])

        self.train_ds = BirdDataset(self.train_df, transforms = TwoCropTransform(transform))
        self.valid_ds = BirdDataset(self.valid_df, transforms = TwoCropTransform(transform), train = False)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size = self.batch_size,
                          num_workers = self.num_workers, shuffle = True)

    def val_dataloader(self):
        return DataLoader(self.valid_ds, batch_size = self.batch_size,
                          num_workers = self.num_workers, shuffle = False)

    def test_dataloader(self):
        # No separate test split is defined, so reuse the validation set
        return DataLoader(self.valid_ds, batch_size = self.batch_size,
                          num_workers = self.num_workers, shuffle = False)

In [14]:
# SupConLoss: https://github.com/HobbitLong/SupContrast/blob/master/losses.py
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point.
        # Edge case e.g.:-
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan]
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
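
To build intuition for what this loss rewards, here is a small NumPy re-implementation for the single-view case. It is a simplified sketch and not the class above: the function name `supcon_loss_np` and the toy embeddings are my own, and it assumes the embeddings are already L2-normalised.

```python
import numpy as np

def supcon_loss_np(z, labels, temperature=0.07):
    """Mean supervised contrastive loss for embeddings z [N, D] and integer labels [N]."""
    z = np.asarray(z, dtype=float)
    labels = np.asarray(labels)
    n = len(labels)
    logits = z @ z.T / temperature
    logits = logits - logits.max(axis=1, keepdims=True)      # numerical stability
    not_self = 1.0 - np.eye(n)                               # mask out self-contrast
    positives = (labels[:, None] == labels[None, :]) * not_self
    log_prob = logits - np.log((np.exp(logits) * not_self).sum(axis=1, keepdims=True))
    n_pos = np.maximum(positives.sum(axis=1), 1.0)           # avoid division by zero
    return float(-((positives * log_prob).sum(axis=1) / n_pos).mean())

labels = [0, 0, 1, 1]
clustered = [[1, 0], [1, 0], [0, 1], [0, 1]]   # same class -> same direction
mixed = [[1, 0], [0, 1], [1, 0], [0, 1]]       # same class -> orthogonal directions

loss_clustered = supcon_loss_np(clustered, labels)
loss_mixed = supcon_loss_np(mixed, labels)
print(loss_clustered, loss_mixed)
```

The loss is close to zero when same-class embeddings coincide and large when each sample sits next to a sample of the other class, which is the behaviour the encoder is trained towards.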

In [15]:
# Model
class Encoder(pl.LightningModule):
    def __init__(self, model_name, emb_dim):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained = True)
        self.in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(self.in_features, emb_dim)
        self.loss_fn = SupConLoss(0.07, 'one', 0.07)


    def forward(self, x):
        emb = self.backbone(x)
        return emb

    # Unlike plain PyTorch, Lightning keeps the train, validation and test steps inside the module.
    # All three steps compute the same SupCon loss, so they share one helper.
    def _supcon_step(self, batch):
        images, labels = batch
        bsz = len(labels)

        # TwoCropTransform yields two views per sample; stack them into one batch of size 2 * bsz
        images = torch.cat([images[0], images[1]], dim=0)
        features = self.forward(images)

        # Rearrange the features to [bsz, n_views, emb_dim], the layout SupConLoss expects
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)

        # Calculate SupConLoss
        return self.loss_fn(features, labels)

    def training_step(self, batch, batch_idx):
        return self._supcon_step(batch)

    def validation_step(self, batch, batch_idx):
        return self._supcon_step(batch)

    def test_step(self, batch, batch_idx):
        return self._supcon_step(batch)

    # We can add schedulers to this method
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr = 0.001)
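
One detail worth knowing: the reference SupContrast implementation L2-normalises the embeddings before passing them to the loss, while the `Encoder` above feeds the raw `fc` outputs. The sketch below, on random stand-in tensors rather than real encoder outputs, shows how normalisation bounds the logits; treat it as an optional variation, not as what this notebook ran.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, emb_dim = 4, 128
raw = 10.0 * torch.randn(2 * bsz, emb_dim)   # stand-in for encoder outputs of both views

z = F.normalize(raw, dim=1)                  # each row now has unit L2 norm
f1, f2 = torch.split(z, [bsz, bsz], dim=0)
features = torch.stack([f1, f2], dim=1)      # [bsz, n_views, emb_dim] for SupConLoss

temperature = 0.07
max_logit_raw = float((raw @ raw.T).max() / temperature)
max_logit_norm = float((z @ z.T).max() / temperature)
print(features.shape, max_logit_raw, max_logit_norm)
```

With unit-norm embeddings every logit is at most `1 / temperature`, whereas unnormalised logits can grow far beyond the float16 maximum of 65504, which may matter when training with 16-bit precision. If you try this, normalise inside the contrastive step rather than in `forward`, since the classifier stage calls the same `forward` to produce logits.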

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

# Hyperparameters
emb_dim = 128
learning_rate = 0.001
num_epochs = 2
backbone = "resnet50"

#Data Loading
dm = BirdCLEFDataModuleSCL(df_train, df_valid, batch_size = Config.batch_size, num_workers = 1)

# Init Model
model = Encoder(backbone, emb_dim).to(device)

# Trainer
trainer = pl.Trainer(accelerator = "cuda", devices = [0], min_epochs = 1, max_epochs = num_epochs, precision = 16)
trainer.fit(model, dm)

/opt/conda/lib/python3.10/site-packages/lightning_fabric/connector.py:558: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!


/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


In [17]:
# Save the model checkpoint
model_checkpoint_path = './birdclef_supconencoder.ckpt'
trainer.save_checkpoint(model_checkpoint_path)

In [18]:
# DataModule for stage 2 (cross-entropy classifier, no augmentation)
class BirdCLEFDataModuleCE(pl.LightningDataModule):
    def __init__(self, train_df, valid_df, batch_size, num_workers):
        super().__init__()
        self.train_df = train_df
        self.valid_df = valid_df
        self.batch_size = batch_size
        self.num_workers = num_workers
        

    # Called once from a single process: download or prepare data here (nothing to do for this dataset)
    def prepare_data(self):
        pass

    # Called on every process/GPU: build the datasets here
    def setup(self, stage):
        transform = transforms.Compose([transforms.ToTensor(),
                                       To3Channel()])

        self.train_ds = BirdDataset(self.train_df, transforms = transform)
        self.valid_ds = BirdDataset(self.valid_df, transforms = transform, train = False)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size = self.batch_size,
                          num_workers = self.num_workers, shuffle = True)

    def val_dataloader(self):
        return DataLoader(self.valid_ds, batch_size = self.batch_size,
                          num_workers = self.num_workers, shuffle = False)

    def test_dataloader(self):
        return DataLoader(self.valid_ds, batch_size = self.batch_size,
                          num_workers = self.num_workers, shuffle = False)

In [19]:
import sklearn.metrics

def padded_cmap(solution, submission, padding_factor=5):
    # Append `padding_factor` all-ones rows to both frames so that every class has
    # at least one positive and its average precision is always defined
    new_rows = pd.DataFrame(
        [[1] * len(solution.columns) for _ in range(padding_factor)],
        columns=solution.columns,
    )
    padded_solution = pd.concat([solution, new_rows]).reset_index(drop=True).copy()
    padded_submission = pd.concat([submission, new_rows]).reset_index(drop=True).copy()
    score = sklearn.metrics.average_precision_score(
        padded_solution.values,
        padded_submission.values,
        average='macro',
    )
    return score

def map_score(solution, submission):
    # Micro-averaged average precision without padding
    score = sklearn.metrics.average_precision_score(
        solution.values,
        submission.values,
        average='micro',
    )
    return score
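
The padding has a side effect that is easy to miss on small inputs: the all-ones rows count as correct predictions for every class. The sketch below uses a compact restatement of `padded_cmap` (named `padded_cmap_demo` here to avoid confusion) on a made-up 3-row example to compare a perfect prediction with a constant, uninformative one.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import average_precision_score

def padded_cmap_demo(solution, submission, padding_factor=5):
    # Same padding idea as padded_cmap above, written compactly
    pad = pd.DataFrame(
        np.ones((padding_factor, solution.shape[1]), dtype=int),
        columns=solution.columns,
    )
    padded_solution = pd.concat([solution, pad]).reset_index(drop=True)
    padded_submission = pd.concat([submission, pad]).reset_index(drop=True)
    return average_precision_score(
        padded_solution.values, padded_submission.values, average="macro"
    )

# Toy ground truth: class "c" never occurs, so its AP would be undefined without padding
solution = pd.DataFrame({"a": [1, 0, 0], "b": [0, 1, 0], "c": [0, 0, 0]})

perfect = solution.astype(float)
uninformative = pd.DataFrame(0.5, index=solution.index, columns=solution.columns)

perfect_score = padded_cmap_demo(solution, perfect)
uninformative_score = padded_cmap_demo(solution, uninformative)
print(perfect_score, uninformative_score)
```

On such a small input even the constant prediction scores well above 0.9, because the five padding rows outnumber the real positives. The same effect applies when the metric is computed on a single batch of 64 samples with 264 classes, so per-batch values should be read with care.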

In [20]:
# Stage 2: Training a linear classifier on top of the frozen encoder



class SupConCE(pl.LightningModule):
    def __init__(self,):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy(task = 'multiclass', num_classes = 264)
        self.f1_score = torchmetrics.F1Score(task = 'multiclass', num_classes = 264)

        model_path = '/kaggle/working/birdclef_supconencoder.ckpt'
        pretrained_model = Encoder.load_from_checkpoint(model_path, model_name = backbone, emb_dim = 128)


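        # Freeze all the encoder layers
        for param in pretrained_model.parameters():
            param.requires_grad = False


        # Train only the last layer: replace the 128-d embedding head with a fresh
        # 264-way classifier (its parameters are trainable by default)
        pretrained_model.backbone.fc = nn.Linear(in_features=pretrained_model.backbone.fc.in_features, out_features=264)

        # requires_grad_() updates the parameters; assigning `.requires_grad` on a module only sets an attribute
        pretrained_model.backbone.fc.requires_grad_(True)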

        self.model = pretrained_model


    def forward(self, x):
        logits = self.model(x)
        return logits

    def training_step(self, batch, batch_idx):
        images, labels = batch
        y_pred = self.forward(images)
        loss = self.loss_fn(y_pred,labels)
        accuracy = self.accuracy(y_pred,labels)
        f1_score = self.f1_score(y_pred,labels)
        self.log_dict({'train_loss': loss, 'train_accuracy': accuracy, 'train_f1_score': f1_score},
                      on_step = False, on_epoch = True, prog_bar = True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        y_pred = self.forward(images)
        loss = self.loss_fn(y_pred,labels)
        accuracy = self.accuracy(y_pred,labels)
        f1_score = self.f1_score(y_pred,labels)
        
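        one_hot_target = F.one_hot(labels, num_classes=264)

        y_pred = pd.DataFrame(y_pred.cpu().detach().numpy())
        y_true = pd.DataFrame(one_hot_target.cpu().detach().numpy())

        # Note: cMAP is computed per batch and then averaged over batches by Lightning,
        # which is not the same as scoring all validation predictions at once
        cmap_score = padded_cmap(y_true, y_pred)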
        
        self.log_dict({'valid_loss': loss, 'valid_accuracy': accuracy, 'valid_f1_score': f1_score, 'cmap_score': cmap_score},
                      on_step = False, on_epoch = True, prog_bar = True)
        return loss

    def test_step(self, batch, batch_idx):
        images, labels = batch
        y_pred = self.forward(images)
        loss = self.loss_fn(y_pred,labels)
        accuracy = self.accuracy(y_pred,labels)
        f1_score = self.f1_score(y_pred,labels)
        
        one_hot_target = F.one_hot(labels, num_classes=264)
        
        y_pred = pd.DataFrame(y_pred.cpu().detach().numpy())
        y_true = pd.DataFrame(one_hot_target.cpu().detach().numpy())
        
        cmap_score = padded_cmap(y_true, y_pred)
        self.log_dict({'test_loss': loss, 'test_accuracy': accuracy, 'test_f1_score': f1_score, 'cmap_score': cmap_score},
                      on_step = False, on_epoch = True, prog_bar = True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr = 0.001)



dm_ce = BirdCLEFDataModuleCE(df_train, df_valid, batch_size = Config.batch_size, num_workers = 1)

final_model = SupConCE().to(device)

trainer_ce = pl.Trainer(accelerator = "cuda", devices = [0], min_epochs = 1, max_epochs = 1, precision = 16)

trainer_ce.fit(final_model, dm_ce)

trainer_ce.validate(final_model, dm_ce)

trainer_ce.test(final_model, dm_ce)

/opt/conda/lib/python3.10/site-packages/lightning_fabric/connector.py:558: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!


/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


[{'test_loss': 4.297344207763672,
  'test_accuracy': 0.22025391459465027,
  'test_f1_score': 0.22025391459465027,
  'cmap_score': 0.9632693529129028}]
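
When freezing an encoder and swapping in a new head, as `SupConCE.__init__` does, it is worth checking which parameters are actually trainable. The sketch below uses a tiny `nn.Sequential` as a stand-in for the encoder; the layer sizes are arbitrary.

```python
import torch.nn as nn

# Toy stand-in for the pretrained encoder
encoder = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# Freeze everything
for param in encoder.parameters():
    param.requires_grad = False

# Replace the last layer with a fresh head; new parameters are trainable by default
encoder[2] = nn.Linear(4, 3)

trainable = [name for name, p in encoder.named_parameters() if p.requires_grad]
frozen = [name for name, p in encoder.named_parameters() if not p.requires_grad]
print("trainable:", trainable)
print("frozen:", frozen)
```

Only the new head shows up as trainable. The same list comprehension over `final_model.named_parameters()` can be used to confirm the setup in this notebook.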

In [21]:
# Save the model checkpoint
model_checkpoint_path = './birdclef_supconmodel.ckpt'
trainer_ce.save_checkpoint(model_checkpoint_path)