In [None]:
import numpy as np 
import pandas as pd 
import os
from pandas.api.types import is_numeric_dtype
from typing import List, Iterable, Callable, Tuple

import sklearn
from sklearn import metrics
from sklearn import preprocessing
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score, confusion_matrix
from sklearn.metrics import precision_recall_curve, auc, roc_curve, recall_score


import matplotlib.pyplot as plt
from plotnine import *

from tqdm import tqdm
from tqdm.notebook import trange

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

from collections import OrderedDict
from typing import Dict, List, Callable, Union

import copy
import random
import time

# Single seed shared by every random number generator used in this notebook.
SEED = 1729

# Seed Python's `random`, NumPy and PyTorch (CPU generator, then CUDA), in
# that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                 torch.cuda.manual_seed):
    _seed_fn(SEED)

# Request deterministic cuDNN kernels.
torch.backends.cudnn.deterministic = True

In [None]:
# Directory holding the abundance CSV exports. Set the AMILI_DATA_DIR
# environment variable to run the notebook on another machine; the default is
# the folder the notebook was originally written against.
DATA_DIR = os.environ.get(
    "AMILI_DATA_DIR",
    r"C:\Users\rongq\OneDrive\0 Rong\01 Documents\AMILI\amili-ml-probiotics\working_drafts\datasets",
)

# Full abundance table and its stool-only subset.
abundance_full_df = pd.read_csv(os.path.join(DATA_DIR, "abundance.csv"))
abundance_stool_full_df = pd.read_csv(
    os.path.join(DATA_DIR, "abundance_stoolsubset.csv"))
abundance_full_df.head()

In [None]:
# Distinct disease labels, in order of first appearance.
disease_list = abundance_full_df['disease'].unique()

# Number of rows (samples) carrying each label.
disease_obs_dict = {
    label: len(abundance_full_df[abundance_full_df['disease'] == label])
    for label in disease_list
}
num_obs_list = list(disease_obs_dict.values())
print(" Number of observations for each disease:\n", disease_obs_dict)

# Labels backed by fewer than 60 samples.
low_data_diseases = [label for label, n_obs in disease_obs_dict.items()
                     if n_obs < 60]
print('\n diseases with less 60 samples:\n', low_data_diseases)

In [None]:
def _merge_disease_label(label):
    """Map a raw disease label onto its merged category.

    * 'n', 'nd' and 'leaness' become 'control'
    * any label containing 'ibd' becomes 'ibd'
    * any label containing 'adenoma' becomes 'adenoma'

    Every other label is returned unchanged. Non-string values (e.g. NaN from
    a missing entry) are passed through as-is instead of raising on the
    substring test.
    """
    if not isinstance(label, str):
        return label
    if label in ('n', 'nd', 'leaness'):
        return 'control'
    if 'ibd' in label:
        return 'ibd'
    if 'adenoma' in label:
        return 'adenoma'
    return label


### merge 'n', 'nd', 'leaness' into 'control'; the 'ibd_*' labels into 'ibd';
### and the '*_adenoma' labels into 'adenoma'.
abundance_full_df.loc[:, 'disease'] = abundance_full_df['disease'].apply(
    _merge_disease_label)

### retaining data only for the following diseases
diseases = ['control', 'obesity', 'ibd', 'stec2-positive',
            'impaired_glucose_tolerance', 'cirrhosis', 't2d', 'cancer', 'adenoma']
# .copy() makes the selection an independent frame, so columns added to it
# later never write into (or warn about) a slice of `abundance_full_df`.
abundance_df = abundance_full_df.loc[
    abundance_full_df['disease'].isin(diseases)].copy()

print(f'original dataframe shape is:  {abundance_full_df.shape}')
print(f'selected dataframe shape is: {abundance_df.shape}')
disease_list = abundance_df['disease'].unique()
print(f"total number of diseases is: {disease_list.size}")

# Number of retained samples per disease, in order of first appearance.
disease_obs_dict = {
    label: len(abundance_df[abundance_df['disease'] == label])
    for label in disease_list
}
num_obs_list = list(disease_obs_dict.values())
print(f"# of observations for each disease:\n", disease_obs_dict)

In [None]:
# Integer id of each disease = its position in `diseases`.
disease_index_dict = {d: i for i, d in enumerate(diseases)}
print(f'associating an id with each disease:')
print(disease_index_dict)

# Dictionary lookup instead of a per-row list.index() scan. A label missing
# from the mapping would silently become NaN, so fail loudly instead.
disease_id_col = abundance_df['disease'].map(disease_index_dict)
if disease_id_col.isna().any():
    raise ValueError("abundance_df contains a disease that is not in `diseases`.")
# assign() builds a new frame, so this is safe even when `abundance_df` is a
# selection taken from another DataFrame.
abundance_df = abundance_df.assign(disease_id=disease_id_col.astype('int64'))

# Columns whose name starts with 'k_' are treated as species abundances;
# everything else (including the new 'disease_id') is metadata.
cols = abundance_df.columns.tolist()
species = [x for x in cols if x.startswith('k_')]
print(f'number of species is {len(species)}')
metadata = [x for x in cols if not x.startswith('k_')]
print(f'number of metadata columns is {len(metadata)}')

# Cast the species columns to float32 and rebuild the frame as
# metadata columns followed by species columns.
species_df = abundance_df.loc[:, species].astype('float32')
abundance_df = pd.concat([abundance_df.loc[:, metadata], species_df], axis=1)

In [None]:
# Candidate ids for holding out: every position in `diseases` except 0.
disease_ids = list(range(1, len(diseases)))

# Two diseases are held out for testing ...
test_disease_ids = random.sample(disease_ids, 2)
print(f'test diseases:', [diseases[i] for i in test_disease_ids])

# ... two of the remaining ones for validation ...
remaining_disease_ids = [i for i in disease_ids
                         if not (i in test_disease_ids)]
valid_disease_ids = random.sample(remaining_disease_ids, 2)
print(f'validation diseases:', [diseases[i] for i in valid_disease_ids])

# ... and everything else (including id 0) is used for meta-training.
train_disease_ids = [
    i for i in range(len(diseases))
    if not (i in test_disease_ids or i in valid_disease_ids)
]
print(f'train diseases:', [diseases[i] for i in train_disease_ids])

# Row-wise split of the samples according to the disease split above.
sample_disease_id = abundance_df['disease_id']

train_df = abundance_df[sample_disease_id.isin(train_disease_ids)]
print(train_df['disease_id'].unique())

valid_df = abundance_df[sample_disease_id.isin(valid_disease_ids)]
print(valid_df['disease_id'].unique())

test_df = abundance_df[sample_disease_id.isin(test_disease_ids)]
print(test_df['disease_id'].unique())

In [None]:
# Training samples that are not controls.
control_id = diseases.index('control')
diseased_train_df = train_df[train_df['disease_id'] != control_id]
print('train df has the following diseases:')
print(diseased_train_df['disease'].unique())

# Per-species standard deviation across the diseased training samples:
# plot its distribution, then keep the species whose std exceeds 1.
species_std = diseased_train_df[species].std()
counts, bins, histplot = plt.hist(species_std, bins=20)
select_species = np.array(species)[species_std.values > 1.]
print(f"number of total species are: {len(species)}")
print(f"number of selected species are: {select_species.size}")

# Restrict every split to the selected species plus the label, and add a
# 0..N-1 'id' column that numbers the rows of that split.
select_species_id = np.append(select_species, ['disease_id'])
train_df = train_df.loc[:, select_species_id].assign(
    id=np.arange(train_df.shape[0]))
valid_df = valid_df.loc[:, select_species_id].assign(
    id=np.arange(valid_df.shape[0]))
test_df = test_df.loc[:, select_species_id].assign(
    id=np.arange(test_df.shape[0]))

In [None]:
class NShotTaskSampler(data.Sampler):
    """Batch sampler yielding the row ids of n-shot, k-way, q-query tasks.

    The wrapped dataset must expose a ``data_frame`` attribute (a pandas
    DataFrame) with a ``'disease_id'`` class column and an ``'id'`` column
    whose values are the indices handed back to the dataset.
    """

    def __init__(self,
                 dataset: torch.utils.data.Dataset,
                 episodes_per_epoch: int = None,
                 n: int = None,
                 k: int = None,
                 q: int = None,
                 num_tasks: int = 1):
        """PyTorch Sampler subclass that generates batches of n-shot, k-way,
        q-query tasks.

        Each n-shot task contains a "support set" of `k` sets of `n` samples and
        a "query set" of `k` sets of `q` samples. Within a task the first n * k
        ids belong to the support set and the remaining q * k ids to the query
        set; the ids of `num_tasks` such tasks are concatenated into one batch.

        For every class the query samples are drawn from the rows that were not
        picked as support samples, so the two sets are disjoint.

        # Arguments
            dataset: Instance of torch.utils.data.Dataset from which to draw samples
            episodes_per_epoch: Arbitrary number of batches of n-shot tasks to
                                generate in one epoch
            n: int. Number of support samples for each class in a task.
            k: int. Number of classes in a task.
            q: int. Number of query samples for each class in a task.
            num_tasks: Number of n-shot tasks to group into a single batch

        # Raises
            ValueError: if `num_tasks` is smaller than 1.
        """
        try:
            super(NShotTaskSampler, self).__init__(dataset)
        except TypeError:
            # NOTE(review): fallback for a base class whose __init__ takes no
            # data_source argument — confirm against the installed PyTorch.
            super(NShotTaskSampler, self).__init__()
        self.episodes_per_epoch = episodes_per_epoch
        self.dataset = dataset
        if num_tasks < 1:
            raise ValueError('num_tasks must be >= 1.')

        self.num_tasks = num_tasks
        # TODO: Raise errors if initialise badly
        self.k = k
        self.n = n
        self.q = q

    def __len__(self):
        """Number of batches produced per epoch."""
        return self.episodes_per_epoch

    def __iter__(self):
        """Yield one 1-D array of row ids per episode.

        Each array has length ``num_tasks * (n * k + q * k)``.
        """
        frame = self.dataset.data_frame

        for _ in range(self.episodes_per_epoch):
            batch = []

            for _task in range(self.num_tasks):
                # Get random classes
                episode_classes = np.random.choice(
                    frame['disease_id'].unique(), size=self.k, replace=False)

                # Support ids first, one class after the other. The ids are
                # read straight from the 'id' column so they keep its dtype
                # (row-wise iteration would upcast them to float).
                support_ids = {}
                for cls in episode_classes:
                    class_rows = frame[frame['disease_id'] == cls]
                    support = class_rows.sample(self.n)
                    support_ids[cls] = support['id']
                    batch.extend(support['id'].tolist())

                # Then the query ids, drawn from the rows of each class that
                # were not used as support samples.
                for cls in episode_classes:
                    class_rows = frame[frame['disease_id'] == cls]
                    candidates = class_rows[
                        ~class_rows['id'].isin(support_ids[cls])]
                    query = candidates.sample(self.q)
                    batch.extend(query['id'].tolist())

            yield np.asarray(batch)
            

class SpeciesAbundanceDataset(data.Dataset):
    """Species Abundance dataset backed by a pandas DataFrame."""

    def __init__(self, abundance_df, species_columns, target_column):
        """
        Args:
            abundance_df: DataFrame holding one sample per row.
            species_columns: Column labels used as the feature vector.
            target_column: Column label holding the class of each sample.
        """
        self.data_frame = abundance_df
        self.species_cols = species_columns
        self.target_col = target_column

    def __len__(self):
        """Number of rows in the wrapped DataFrame."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return ``(species_abundance, target)`` for positional index `idx`.

        `idx` may be a plain int, anything `int()` accepts (e.g. a NumPy
        scalar), a list/tuple of positions, or a tensor of positions. For a
        single position the features are a 1-D float32 array and the target is
        the value stored in the target column; for several positions the
        features are a 2-D float32 array and the targets an integer array.
        """
        if torch.is_tensor(idx):
            # A 0-d tensor becomes a Python scalar, anything else a list.
            idx = idx.tolist()

        if isinstance(idx, (list, tuple)):
            positions = [int(i) for i in idx]
            target = self.data_frame[self.target_col].iloc[positions].to_numpy(
                dtype='int')
        else:
            # Covers plain ints as well as NumPy / float scalars; `target` is
            # assigned on every path.
            positions = int(idx)
            target = self.data_frame[self.target_col].iloc[positions]

        species_abundance = self.data_frame[self.species_cols].iloc[
            positions].to_numpy(dtype='float32')

        return species_abundance, target

In [None]:
class MLP(nn.Module):
    """Perceptron with one hidden ReLU layer: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim: int, output_dim: int, layer_size: int = 128):
        super().__init__()

        # Attribute names double as the keys looked up in `functional_forward`.
        self.input_fc = nn.Linear(input_dim, layer_size)
        self.output_fc = nn.Linear(layer_size, output_dim)

    def forward(self, x):
        """Flatten all but the first dimension of `x` and return the output scores."""
        flat = x.view(x.shape[0], -1)
        hidden = self.input_fc(flat)
        hidden = F.relu(hidden)
        return self.output_fc(hidden)

    def functional_forward(self, x, weights):
        """Run the same forward pass with externally supplied parameters.

        `weights` maps 'input_fc.weight', 'input_fc.bias', 'output_fc.weight'
        and 'output_fc.bias' to the tensors to use in place of the module's
        own parameters.
        """
        flat = x.view(x.shape[0], -1)
        hidden = F.relu(
            F.linear(flat, weights['input_fc.weight'], weights['input_fc.bias']))
        return F.linear(hidden,
                        weights['output_fc.weight'],
                        weights['output_fc.bias'])

In [None]:
def create_nshot_task_label(k: int, q: int) -> torch.Tensor:
    """Creates an n-shot task label.

    Label has the structure:
        [0]*q + [1]*q + ... + [k-1]*q

    # Arguments
        k: Number of classes in the n-shot classification task
        q: Number of query samples for each class in the n-shot classification task

    # Returns
        y: int64 label vector for n-shot task of shape [q * k, ]
    """
    # Built from integers only: repeating each class index q times cannot be
    # affected by the floating point rounding a fractional arange step is
    # exposed to.
    y = torch.arange(k).repeat_interleave(q)

    return y

def prepare_meta_batch(n, k, q, meta_batch_size):
    """Build the function that turns a DataLoader batch into meta-learning input.

    The returned callable takes a `(features, targets)` batch and gives back
    `(x, y)` where `x` is the features reshaped to
    `(meta_batch_size, n*k + q*k, num_features)`, cast to double and moved to
    the module-level `device`, and `y` is the query label vector for one task
    repeated `meta_batch_size` times. The targets coming from the loader are
    discarded.
    """
    def prepare_meta_batch_(batch):
        features, _ = batch
        # One row of samples per task: n*k support samples followed by
        # q*k query samples.
        per_task = n*k + q*k
        features = features.reshape(meta_batch_size, per_task,
                                    features.shape[-1])
        features = features.double().to(device)
        # Query labels of a single task, tiled once per task in the batch.
        task_labels = create_nshot_task_label(k, q).to(device)
        return features, task_labels.repeat(meta_batch_size)

    return prepare_meta_batch_


def replace_grad(parameter_gradients, parameter_name):
    """Return a one-argument hook that ignores its argument and returns
    `parameter_gradients[parameter_name]`, looked up when the hook is called."""
    return lambda module: parameter_gradients[parameter_name]

def meta_gradient_step(model: nn.Module,
                       optimizer: optim.Optimizer,
                       loss_fn: Callable,
                       x: torch.Tensor,
                       y: torch.Tensor,
                       n_shot: int,
                       k_way: int,
                       q_queries: int,
                       order: int,
                       inner_train_steps: int,
                       inner_lr: float,
                       train: bool,
                       device: Union[str, torch.device]):
    """
    Perform a gradient step on a meta-learner.

    For every task in `x` the model's parameters are copied into "fast
    weights", adapted on the task's support samples with
    `inner_train_steps` manual gradient steps, and evaluated on the task's
    query samples. The query losses drive the meta-update.

    # Arguments
        model: Base model of the meta-learner being trained. Must provide
            `functional_forward(x, weights)`.
        optimizer: Optimizer to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        x: Input samples for all few shot tasks; the first dimension indexes
            the tasks and, within a task, the first n_shot * k_way samples
            are the support set and the rest the query set
        y: Input labels of all few shot tasks. NOTE: this argument is not
            read; labels are rebuilt with `create_nshot_task_label`.
        n_shot: Number of examples per class in the support set of each task
        k_way: Number of classes in the few shot classification task of each task
        q_queries: Number of examples per class in the query set of each task.
            The query set is used to calculate meta-gradients after the
            fast weights have been fitted on the support set.
        order: Whether to use 1st order MAML
            (update meta-learner weights with gradients of the updated weights
            on the query set) or 2nd order MAML (use 2nd order updates by
            differentiating through the gradients of the updated weights on
            the query with respect to the original weights).
        inner_train_steps: Number of gradient steps to fit the fast weights
                            during each inner update
        inner_lr: Learning rate used to update the fast weights on the inner update
        train: Whether to update the meta-learner weights at the end of the episode.
        device: Device on which to run computation

    # Returns
        Tuple of (mean of the per-task query losses, softmax predictions on
        the query samples of all tasks concatenated along dim 0).

    # Raises
        ValueError: if `order` is neither 1 nor 2 (raised after the per-task
            loop has already run).
    """
    # Shape of a single sample, used for the dummy forward pass below.
    data_shape = x.shape[2:]
    # The graph of the inner updates is only kept for 2nd order training.
    create_graph = (True if order == 2 else False) and train

    task_gradients = []
    task_losses = []
    task_predictions = []
    for meta_batch in x:
        # Iterating over the first dimension of x walks through the tasks of
        # the meta batch; each `meta_batch` holds the n*k support samples
        # followed by the q*k query samples of one task.
        # NOTE(review): in this notebook x is presumably 3D
        # (meta_batch_size, n*k + q*k, num_features) — confirm against
        # prepare_meta_batch.
        x_task_train = meta_batch[:n_shot * k_way]
        x_task_val = meta_batch[n_shot * k_way:]

        # Create a fast model using the current meta model weights
        fast_weights = OrderedDict(model.named_parameters())

        # Train the model for `inner_train_steps` iterations
        for inner_batch in range(inner_train_steps):
            # Perform update of model weights
            # Support labels: n_shot samples for each of the k_way classes.
            y = create_nshot_task_label(k_way,n_shot).to(device)
            y_pred = model.functional_forward(x_task_train, fast_weights)
            loss = loss_fn(y_pred, y)
            gradients = torch.autograd.grad(loss, fast_weights.values(), 
                                            create_graph=create_graph)

            # Update weights manually
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), gradients))

        # Do a pass of the model on the validation data from the current task
        y = create_nshot_task_label(k_way,q_queries).to(device)
        y_pred = model.functional_forward(x_task_val, fast_weights)
        loss = loss_fn(y_pred, y)
        # retain_graph=True keeps the graph alive for the autograd.grad call
        # (and, for order 2, the meta backward pass) further down.
        loss.backward(retain_graph=True)

        # Get post-update accuracies
        y_prob = y_pred.softmax(dim=1)
        task_predictions.append(y_prob)

        # Accumulate losses and gradients
        task_losses.append(loss)
        # Gradients of the query loss with respect to the adapted weights.
        gradients = torch.autograd.grad(loss, fast_weights.values(), 
                                        create_graph=create_graph)
        named_grads = {name: g for ((name, _), g) in zip(fast_weights.items(), 
                                                         gradients)}
        task_gradients.append(named_grads)

    if order == 1:
        if train:
            # Average each parameter's gradient over the tasks.
            sum_task_gradients = {k: torch.stack(
                [grad[k] for grad in task_gradients]).mean(dim=0)
                                  for k in task_gradients[0].keys()}
            # Hooks that substitute the averaged task gradient for whatever
            # gradient the dummy backward pass below computes.
            hooks = []
            for name, param in model.named_parameters():
                hooks.append(
                    param.register_hook(replace_grad(sum_task_gradients, name))
                )

            model.train()
            optimizer.zero_grad()
            # Dummy pass in order to create `loss` variable
            # Replace dummy gradients with mean task gradients using hooks
            y_pred = model(torch.zeros((k_way, ) + data_shape).to(device, 
                                                              dtype=torch.double))
            loss = loss_fn(y_pred, create_nshot_task_label(k_way, 1).to(device))
            loss.backward()
            optimizer.step()

            # The hooks must not outlive this step.
            for h in hooks:
                h.remove()

        return torch.stack(task_losses).mean(), torch.cat(task_predictions)

    elif order == 2:
        model.train()
        optimizer.zero_grad()
        meta_batch_loss = torch.stack(task_losses).mean()

        if train:
            # Backpropagate the mean query loss to the meta parameters.
            meta_batch_loss.backward()
            optimizer.step()

        return meta_batch_loss, torch.cat(task_predictions)
    else:
        raise ValueError('Order must be either 1 or 2.')

In [None]:
def count_parameters(model):
    """Return the total number of elements in the parameters of `model`
    that have `requires_grad` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def calculate_accuracy(y_pred, y):
    """Fraction of rows of `y_pred` whose arg-max along the last dimension
    equals the corresponding entry of `y`."""
    predicted_class = y_pred.argmax(dim=-1)
    n_correct = predicted_class.eq(y).sum().item()
    return n_correct / y_pred.shape[0]

def epoch_time(start_time, end_time):
    """Split the interval `end_time - start_time` (in seconds) into whole
    minutes and the remaining whole seconds, both truncated to int."""
    elapsed = end_time - start_time
    whole_minutes = int(elapsed / 60)
    leftover_seconds = int(elapsed - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

In [None]:
def train(model: nn.Module, 
            taskloader: data.DataLoader, 
            optimizer: optim.Optimizer, 
            loss_fn: Callable, 
            prepare_batch: Callable, 
            **kwargs):
    """Run `meta_gradient_step` on every batch of `taskloader`.

    `prepare_batch` converts a loader batch into the `(x, y)` pair passed to
    `meta_gradient_step`; `**kwargs` is forwarded to it unchanged.

    Returns the loss and the accuracy, each averaged over all predictions
    seen (weighted by the number of predictions per batch).
    """
    n_seen = 0
    loss_sum = 0
    acc_sum = 0

    for batch in taskloader:
        x, y = prepare_batch(batch)
        loss, y_pred = meta_gradient_step(model, optimizer, loss_fn, x, y,
                                          **kwargs)

        # Weight the batch statistics by the number of predictions they cover.
        n_batch = y_pred.shape[0]
        n_seen += n_batch
        loss_sum += loss.item() * n_batch
        acc_sum += calculate_accuracy(y_pred, y) * n_batch

    return loss_sum / n_seen, acc_sum / n_seen

def evaluate(model: nn.Module, 
            taskloader: data.DataLoader, 
            optimizer: optim.Optimizer, 
            loss_fn: Callable, 
            prepare_batch: Callable, 
            **kwargs):
    """Run `meta_gradient_step` on every batch of `taskloader` and report
    the averaged loss and accuracy.

    `**kwargs` is forwarded to `meta_gradient_step` unchanged, so whether the
    model weights are updated is decided by the `train` entry the caller
    passes in.

    Returns the loss and the accuracy, each averaged over all predictions
    seen (weighted by the number of predictions per batch).
    """
    n_seen = 0
    loss_sum = 0
    acc_sum = 0

    for batch in taskloader:
        x, y = prepare_batch(batch)
        loss, y_pred = meta_gradient_step(model, optimizer, loss_fn, x, y,
                                          **kwargs)

        # Weight the batch statistics by the number of predictions they cover.
        n_batch = y_pred.shape[0]
        n_seen += n_batch
        loss_sum += loss.item() * n_batch
        acc_sum += calculate_accuracy(y_pred, y) * n_batch

    return loss_sum / n_seen, acc_sum / n_seen

In [None]:
### 3-shot, 2-way classification ###
n_shot = 3
k_way = 2
q_queries = n_shot
meta_batch_size = 2          # tasks per meta-batch

INPUT_DIM = len(select_species)
OUTPUT_DIM = k_way
LAYER_SIZE = 128
inner_lr = 0.01              # step size of the fast-weight updates
meta_lr = 0.001              # learning rate of the meta-optimizer
epoch_len = 800              # meta-batches per training epoch
epochs = 30
first_order_epochs = 10 #int(0.4*epochs)
second_order_epochs = epochs - first_order_epochs
eval_batches = 80            # meta-batches per validation / test pass
inner_train_steps = 5
inner_val_steps = 5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
meta_model = MLP(INPUT_DIM, OUTPUT_DIM, LAYER_SIZE).to(device, dtype=torch.double)
meta_optimizer = optim.Adam(meta_model.parameters(), lr=meta_lr)
loss_fn = nn.CrossEntropyLoss().to(device)

print(f'The model has {count_parameters(meta_model):,} trainable parameters')

train_dataset = SpeciesAbundanceDataset(train_df, select_species, 'disease_id')
valid_dataset = SpeciesAbundanceDataset(valid_df, select_species, 'disease_id')
test_dataset = SpeciesAbundanceDataset(test_df, select_species, 'disease_id')


def _make_taskloader(dataset, num_episodes):
    """DataLoader yielding `num_episodes` meta-batches of few-shot tasks
    sampled from `dataset` with the task layout configured above."""
    # NOTE(review): num_workers=1 starts a worker process; if loading hangs or
    # fails when the notebook is run on Windows, try num_workers=0.
    return data.DataLoader(
        dataset,
        batch_sampler=NShotTaskSampler(dataset, num_episodes, n=n_shot,
                                       k=k_way, q=q_queries,
                                       num_tasks=meta_batch_size),
        num_workers=1
    )


def _make_step_kwargs(inner_steps, train):
    """Keyword arguments for `meta_gradient_step` (1st order by default)."""
    return {'n_shot': n_shot, 'k_way': k_way, 'q_queries': q_queries,
            'inner_train_steps': inner_steps, 'inner_lr': inner_lr,
            'order': 1, 'train': train, 'device': device}


train_taskloader = _make_taskloader(train_dataset, epoch_len)
valid_taskloader = _make_taskloader(valid_dataset, eval_batches)
test_taskloader = _make_taskloader(test_dataset, eval_batches)

train_kwargs = _make_step_kwargs(inner_train_steps, train=True)
# Validation and test passes must not update the meta-model: with
# train=True every evaluation batch would trigger an optimizer step on the
# held-out tasks.
valid_kwargs = _make_step_kwargs(inner_val_steps, train=False)
test_kwargs = _make_step_kwargs(inner_val_steps, train=False)

In [None]:
best_valid_loss = float('inf')

# The batch-preparation closure only depends on the task layout, so one
# instance serves both the training and the validation passes.
meta_batch_preparer = prepare_meta_batch(n_shot, k_way, q_queries,
                                         meta_batch_size)

# Phase 1: first-order MAML; phase 2: second-order MAML. The best validation
# loss is tracked across both phases.
for order, num_epochs in ((1, first_order_epochs), (2, second_order_epochs)):
    train_kwargs['order'] = order
    # Validation uses its own kwargs (not the test kwargs) with the same
    # order as training.
    valid_kwargs['order'] = order

    for epoch in trange(num_epochs):
        start_time = time.monotonic()

        train_loss, train_acc = train(meta_model,
                                      train_taskloader,
                                      meta_optimizer,
                                      loss_fn,
                                      meta_batch_preparer,
                                      **train_kwargs
                                     )

        valid_loss, valid_acc = evaluate(meta_model,
                                         valid_taskloader,
                                         meta_optimizer,
                                         loss_fn,
                                         meta_batch_preparer,
                                         **valid_kwargs
                                        )

        # Checkpoint whenever the validation loss improves.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(meta_model.state_dict(),'MLP-classifier.pt')

        end_time = time.monotonic()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        print(f'Epoch: {epoch+1:02}/{num_epochs} | Epoch Time: '
              f'{epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.3f}| Best Val. Loss: '
              f'{best_valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

In [None]:
def calculate_metrics(y_pred, y):
    """Return (confusion matrix, F1 score, ROC AUC score) for one batch.

    The predicted label of each row is the arg-max of `y_pred` along the last
    dimension; all three metrics, including the ROC AUC, are computed from
    these hard labels against `y`.
    """
    predicted = y_pred.argmax(dim=-1).cpu()
    actual = y.cpu()
    return (
        metrics.confusion_matrix(actual, predicted),
        metrics.f1_score(actual, predicted),
        metrics.roc_auc_score(actual, predicted),
    )

def evaluate_metrics(model: nn.Module, 
            taskloader: data.DataLoader, 
            optimizer: optim.Optimizer, 
            loss_fn: Callable, 
            prepare_batch: Callable, 
            **kwargs):
    """Run `meta_gradient_step` over `taskloader` and collect metrics.

    `prepare_batch` converts a loader batch into the `(x, y)` pair passed to
    `meta_gradient_step`; `**kwargs` is forwarded to it unchanged.

    Returns:
        (loss, accuracy, confusion matrix, F1 score, ROC AUC score) where the
        loss and accuracy are averaged over all predictions, the confusion
        matrix is summed over the batches, and the F1 / ROC AUC scores are
        averaged over the batches.

    Raises:
        ValueError: if `taskloader` yields no batches.
    """
    seen = 0
    total_loss = 0
    total_acc = 0
    num_batches = 0
    total_cm = None
    f1_sum = 0.0
    roc_auc_sum = 0.0

    for batch in taskloader:
        x, y = prepare_batch(batch)

        loss, y_pred = meta_gradient_step(
            model,
            optimizer,
            loss_fn,
            x,
            y,
            **kwargs)

        seen += y_pred.shape[0]

        total_loss += loss.item() * y_pred.shape[0]
        total_acc += calculate_accuracy(y_pred, y) * y_pred.shape[0]

        batch_cm, batch_f1, batch_roc_auc = calculate_metrics(y_pred, y)
        total_cm = batch_cm if total_cm is None else total_cm + batch_cm
        f1_sum += batch_f1
        roc_auc_sum += batch_roc_auc
        num_batches += 1

    if num_batches == 0:
        raise ValueError('taskloader did not yield any batches.')

    total_loss = total_loss/seen
    total_acc = total_acc/seen
    # Average over the number of batches (not the index of the last batch,
    # which is one too small and zero for a single batch).
    avg_f1_score = f1_sum/num_batches
    avg_roc_auc_score = roc_auc_sum/num_batches
    return total_loss, total_acc, total_cm, avg_f1_score, avg_roc_auc_score

In [None]:
# Restore the checkpoint with the best validation loss. map_location puts the
# tensors on the device in use, so a checkpoint written on a GPU machine also
# loads on a CPU-only one.
meta_model.load_state_dict(
    torch.load('MLP-classifier.pt', map_location=device))

test_kwargs['order'] = 2
# Reporting test metrics must not update the model on the test tasks.
test_kwargs['train'] = False
eval_loss, eval_acc, eval_cm, eval_f1, eval_roc_auc = evaluate_metrics(
                                meta_model, 
                                test_taskloader, 
                                meta_optimizer, 
                                loss_fn, 
                                prepare_meta_batch(n_shot, k_way, q_queries, 
                                                            meta_batch_size),
                                **test_kwargs)

print(f'Test Loss: {eval_loss:.3f}| Best Val. Loss: '
      f'{best_valid_loss:.3f} |  Test. Acc: {eval_acc*100:.2f}%')
print(f'Test F1 score: {eval_f1} | Test ROC AUC score: {eval_roc_auc}')
print(f'Confusion matrix:')
print(eval_cm)