In [2]:
from learner import Learner

import os

import warnings
warnings.filterwarnings(action='ignore')

import pandas as pd
import librosa
import numpy as np
import torch
from PIL import Image
from sklearn.utils import shuffle
from tqdm import tqdm
from copy import deepcopy
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random

In [4]:
# Geometry shared by the network definition and the data sampler. These values
# have to agree with the n_way / imgsz arguments given to BirdCallNShot and with
# the channel / size arguments given to Learner elsewhere in this notebook.
N_WAY = 5            # classes per task
IMG_CHANNELS = 1     # the sampled images have a single channel
IMG_SIZE = 28        # images are IMG_SIZE x IMG_SIZE
FINAL_CHANNELS = 8   # channels produced by the last conv layer below

# Layer-by-layer network description handed to Learner.
# NOTE(review): the parameter lists appear to follow
#   conv2d: [out_channels, in_channels, kernel_h, kernel_w, stride, padding]
#   bn:     [num_features]
#   relu:   [inplace]
#   linear: [out_features, in_features]
# -- confirm against learner.py. Under that reading every conv keeps the
# IMG_SIZE x IMG_SIZE spatial extent, which is what the flattened input size of
# the final linear layer relies on.
config = [
    ('conv2d', [32, IMG_CHANNELS, 7, 7, 1, 3]),
    ('relu', [True]),
    ('bn', [32]),
    ('conv2d', [16, 32, 5, 5, 1, 2]),
    ('relu', [True]),
    ('bn', [16]),
    ('conv2d', [FINAL_CHANNELS, 16, 3, 3, 1, 1]),
    ('relu', [True]),
    ('bn', [FINAL_CHANNELS]),
    ('flatten', []),
    ('linear', [N_WAY, FINAL_CHANNELS * IMG_SIZE * IMG_SIZE])
]

In [5]:
from birdCallNShot import BirdCallNShot

# Settings for the task sampler. With these values the support tensor drawn
# further down has shape (32, 400, 1, 28, 28): the leading 32 matches batchsz
# and 400 equals n_way * k_shot.
sampler_settings = {
    'batchsz': 32,
    'n_way': 5,
    'k_shot': 80,
    'k_query': 20,
    'imgsz': 28,
}

# NOTE(review): the first positional argument is presumably the data root
# directory -- confirm against birdCallNShot.py.
db_train = BirdCallNShot('birdCall', **sampler_settings)

DB: train (99, 100, 28, 28, 1) test (33, 100, 28, 28, 1)


In [21]:
# Everything in this notebook runs on the CPU; define the device once and reuse it.
device = torch.device('cpu')

# Draw one batch from the sampler and turn each NumPy array it returns into a
# tensor living on `device` (support data/labels, then query data/labels).
x_spt, y_spt, x_qry, y_qry = (
    torch.from_numpy(array).to(device) for array in db_train.next()
)

In [18]:
# Sanity check of the support-set layout; the recorded output is
# (32, 400, 1, 28, 28), i.e. 32 tasks with 400 single-channel 28x28 images each.
x_spt.shape

torch.Size([32, 400, 1, 28, 28])

In [28]:
def check_accuracy(data, labels, model, idx=0, device=None):
    """Return the classification accuracy of ``model`` on a single task.

    Args:
        data: Tensor indexed by task along its first dimension; ``data[idx]``
            is fed to ``model`` as one batch.
        labels: Tensor of integer class labels indexed the same way as
            ``data``; ``labels[idx]`` holds one label per sample of the batch.
        model: Callable module producing a score tensor with the classes
            along dimension 1.
        idx: Index of the task to evaluate.
        device: Device the inputs are moved to before the forward pass.
            Defaults to the CPU, which is what this function previously
            hard-coded.

    Returns:
        The fraction of correctly classified samples as a Python float.

    Raises:
        ZeroDivisionError: If the selected batch contains no samples.

    Note:
        The model is switched to evaluation mode and is left in that mode
        when the function returns.
    """
    if device is None:
        device = torch.device("cpu")

    model.eval()  # set model to evaluation mode
    with torch.no_grad():  # no gradients are needed for evaluation
        x = data[idx].to(device=device, dtype=torch.float32)
        y = labels[idx].to(device=device, dtype=torch.long)

        scores = model(x)
        _, preds = scores.max(1)  # index of the highest score per sample

        # .item() turns the 0-dim tensor into a plain int so the result is
        # an ordinary Python float.
        num_correct = (preds == y).sum().item()
        num_samples = preds.size(0)

    return num_correct / num_samples
        
def train(data, labels, model, optimizer, epochs=50, idx=0,
          qry_data=None, qry_labels=None):
    """Fit ``model`` on one task and record its query accuracy after every epoch.

    Each epoch is a single full-batch gradient step on ``data[idx]`` /
    ``labels[idx]``, followed by an evaluation with ``check_accuracy``.

    Args:
        data: Tensor of training inputs indexed by task along its first
            dimension.
        labels: Tensor of integer class labels indexed the same way.
        model: Module to train; it is moved to the notebook-level ``device``.
        optimizer: Optimizer already constructed over ``model``'s parameters.
        epochs: Number of gradient steps to take.
        idx: Index of the task to train and evaluate on.
        qry_data: Optional evaluation inputs, indexed by task like ``data``.
        qry_labels: Optional evaluation labels matching ``qry_data``.
            When both are omitted the notebook-level ``x_qry`` / ``y_qry``
            tensors are used, which is what this function always did before.

    Returns:
        A list with one accuracy value (float) per epoch.

    Raises:
        ValueError: If only one of ``qry_data`` / ``qry_labels`` is given.
    """
    if (qry_data is None) != (qry_labels is None):
        raise ValueError("qry_data and qry_labels must be given together")
    if qry_data is None:
        # Backward-compatible default: fall back to the notebook globals.
        qry_data, qry_labels = x_qry, y_qry

    # Model and data go to the same (notebook-level) device.
    model = model.to(device=device)

    # The batch is identical in every epoch, so select and convert it once.
    x = data[idx].to(device=device, dtype=torch.float32)
    y = labels[idx].to(device=device, dtype=torch.long)

    test_accuracies = []
    for _ in range(epochs):
        # check_accuracy leaves the model in eval mode, so training mode has
        # to be re-enabled at the start of every epoch.
        model.train()

        scores = model(x)
        loss = F.cross_entropy(scores, y)

        # Clear stale gradients, back-propagate, then apply the update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        test_accuracies.append(check_accuracy(qry_data, qry_labels, model, idx))
    return test_accuracies

In [35]:
# Optimisation settings for the per-task baselines.
lr = 0.001
num_epochs = 50

# One independently initialised model is trained per sampled task; the number
# of tasks is read from the support tensor instead of repeating the batch size.
num_tasks = x_spt.shape[0]

# NOTE(review): Learner's positional arguments are presumably
# (config, image_channels, image_size) -- confirm against learner.py.
models = [Learner(config, 1, 28) for _ in range(num_tasks)]

test_accs = []  # one list of per-epoch query accuracies per task
for idx in range(num_tasks):
    print(f"On iteration {idx}")
    # train() puts the model into training mode itself, so no mode switch is
    # needed here.
    optimizer = optim.SGD(models[idx].parameters(), lr=lr, momentum=0.9, nesterov=True)
    test_accs.append(train(x_spt, y_spt, models[idx], optimizer, epochs=num_epochs, idx=idx))

# Average over tasks (axis 0), leaving one mean accuracy per epoch.
avg_test_accuracies = np.average(np.array(test_accs), axis=0)

On iteration 0
On iteration 1
On iteration 2
On iteration 3
On iteration 4
On iteration 5
On iteration 6
On iteration 7
On iteration 8
On iteration 9
On iteration 10
On iteration 11
On iteration 12
On iteration 13
On iteration 14
On iteration 15
On iteration 16
On iteration 17
On iteration 18
On iteration 19
On iteration 20
On iteration 21
On iteration 22
On iteration 23
On iteration 24
On iteration 25
On iteration 26
On iteration 27
On iteration 28
On iteration 29
On iteration 30
On iteration 31


In [36]:
# Mean query-set accuracy after each training epoch, averaged over the per-task models.
avg_test_accuracies

array([0.2221875, 0.240625 , 0.264375 , 0.28875  , 0.320625 , 0.34625  ,
       0.3584375, 0.375    , 0.3875   , 0.40125  , 0.4125   , 0.4275   ,
       0.4403125, 0.4465625, 0.4565625, 0.466875 , 0.4746875, 0.483125 ,
       0.48625  , 0.4915625, 0.495    , 0.4990625, 0.5046875, 0.5071875,
       0.510625 , 0.511875 , 0.5184375, 0.52375  , 0.52625  , 0.5284375,
       0.528125 , 0.53     , 0.5325   , 0.535625 , 0.536875 , 0.5375   ,
       0.54125  , 0.54125  , 0.5434375, 0.546875 , 0.5490625, 0.5484375,
       0.549375 , 0.5509375, 0.5540625, 0.554375 , 0.5546875, 0.5575   ,
       0.559375 , 0.5603125])