# Setup

In [1]:
%load_ext extensions
%cd_repo_root

'/mnt/batch/tasks/shared/LS_root/mounts/clusters/rubchume1/code/Users/rubchume/VoiceCloningFakeAudioDetection'

In [2]:
import random
from pathlib import Path
from typing import Iterable, List

import mlflow
import numpy as np
import pandas as pd
from pydub import AudioSegment
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer

import directory_structure



In [3]:
def reproduce_audio_file_with_pydub(audio_file):
    audio = AudioSegment.from_file(audio_file)
    display(audio)

# Load dataset

In [4]:
cloned_voices_path = ""
real_voices_path = directory_structure.data_path / "Common Voice/cv-corpus-15-delta-2023-09-08/en"
real_voices_info_file = real_voices_path / "validated.tsv"

In [5]:
real_info = pd.read_csv(real_voices_info_file, delimiter="\t")["path"].map(
    lambda path: str(real_voices_path / "clips" / path)
)
cloned_info = pd.Series([str(path) for path in Path("outputs/OOTB-YourTTS/TIMITexamples/").glob("*.wav")]).rename("path")


# Pytorch Lightning

In [6]:
from IPython.display import Audio
def reproduce_audio_from_pcm_samples(pcm_samples: np.array, sample_rate: int):
    audio = Audio(data=pcm_samples, rate=sample_rate, autoplay=True)
    display(audio)

In [7]:
class AudioBinaryDataset(Dataset):
    def __init__(
        self,
        negative_audio_files: Iterable,
        postive_audio_files: Iterable,
        target_sample_rate: int,
        num_samples: int,
        max_imbalance=1,
        random_seed=0,
    ):
        self.negative_audio_files = list(negative_audio_files)
        self.positive_audio_files = list(postive_audio_files)
        self.target_sample_rate = target_sample_rate
        self.num_samples = num_samples
        
        self.random_instance = random.Random(random_seed)
        
        negative_samples, positive_samples = self._undersample_unbalanced_dataset(
            self.negative_audio_files,
            self.positive_audio_files,
            max_imbalance
        )
        
        negative_samples_with_label = [
            (sample, 0)
            for sample in negative_samples
        ]
        
        positive_samples_with_label = [
            (sample, 1)
            for sample in negative_samples
        ]
        
        self.samples = self.random_instance.sample(
            negative_samples_with_label + positive_samples_with_label,
            len(negative_samples_with_label) + len(positive_samples_with_label)
        )
        
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, index):
        audio_file, label = self.samples[index]
        audio_segment = AudioSegment.from_file(audio_file)
        audio_resampled = audio_segment.set_frame_rate(self.target_sample_rate)
        pcm_samples = self._bytes_to_numpy(
            audio_resampled.raw_data,
            audio_resampled.sample_width
        )
        resized_samples = np.zeros(self.num_samples)
        resized_samples[:len(pcm_samples)] = pcm_samples[:self.num_samples]
        return torch.Tensor(resized_samples), label
    
    def _undersample_unbalanced_dataset(self, dataset_A: List, dataset_B: List, max_imbalance):
        if len(dataset_A) > len(dataset_B):
            dataset_big = dataset_A
            dataset_small = dataset_B
            a_bigger_than_b = True
        else:
            dataset_big = dataset_B
            dataset_small = dataset_A
            a_bigger_than_b = True
        
        if max_imbalance < 1:
            max_imbalance = 1 / max_imbalance
            
        max_samples = int(len(dataset_small) * max_imbalance)
        samples_big = self.random_instance.sample(dataset_big, min(max_samples, len(dataset_big)))
        samples_small = self.random_instance.sample(dataset_small, len(dataset_small))
        
        if a_bigger_than_b:
            return samples_big, samples_small
        else:
            return samples_small, samples_big
    
    @staticmethod
    def _bytes_to_numpy(bytes_stream: bytes, sample_width=2) -> np.array:
        """
        sample_width: number of bytes per sample
        """
        dtype_map = {
            1: np.int8,
            2: np.int16,
            4: np.int32
        }

        if sample_width not in dtype_map:
            raise ValueError(f"Unsupported sample width: {sample_width}")

        return np.frombuffer(bytes_stream, dtype=dtype_map[sample_width])

In [8]:
from torch.utils.data import DataLoader, random_split


class DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, target_sample_rate, num_samples, cloned_samples: pd.Series, real_samples: pd.Series):
        super().__init__()
        self.batch_size = batch_size
        self.target_sample_rate = target_sample_rate
        self.num_samples = num_samples
        self.cloned_samples = cloned_samples
        self.real_samples = real_samples
    
    def prepare_data(self):
        dataset = AudioBinaryDataset(
            self.real_samples,
            self.cloned_samples,
            self.target_sample_rate,
            self.num_samples
        )
        
        self.dataset_training, self.dataset_validation, self.dataset_test = random_split(
            dataset,
            [0.7, 0.1, 0.2],
            generator=torch.Generator().manual_seed(0)
        )
        
    def train_dataloader(self):
        return DataLoader(self.dataset_training, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.dataset_validation, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.dataset_test, batch_size=self.batch_size)

In [39]:
from enum import Enum
from collections import defaultdict

import pytorch_lightning as pl
import torch.nn as nn
import torchmetrics


class Stage(Enum):
    TRAIN = "TRAIN"
    VALIDATION = "VALIDATION"
    TEST = "TEST"


class ClonedAudioDetector(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self._create_model()
        self._prepare_metrics()
        self.debug = False
        
    def _create_model(self):
        num_labels = 2

        label2id = dict(
            cloned=1,
            real=0)

        id2label = {
            1: "cloned",
            0: "real"
        }

        self.model = AutoModelForAudioClassification.from_pretrained(
            "facebook/wav2vec2-base",
            num_labels=num_labels,
            label2id=label2id,
            id2label=id2label
        )
             
    def _prepare_metrics(self):
        self.precision = torchmetrics.Precision(task='binary')
        self.recall = torchmetrics.Recall(task='binary')
        self.f1 = torchmetrics.F1Score(task='binary')
        self.confmat = torchmetrics.ConfusionMatrix(task="binary")

        self.targets_scores = {}
        self.targets_predicted = {}
        self.targets = {}
        
        self._reset_target_registries(Stage.TRAIN)
        self._reset_target_registries(Stage.VALIDATION)
        self._reset_target_registries(Stage.TEST)
 
    def forward(self, x):
        return self.model.forward(x)
    
    def criterion(self, logits, labels):
        return nn.functional.cross_entropy(logits, labels)
    
    def training_step(self, batch, batch_index):
        return self._step(batch, Stage.TRAIN)

    def validation_step(self, batch, batch_index):
        return self._step(batch, Stage.VALIDATION)
        
    def test_step(self, batch, batch_index):
        return self._step(batch, Stage.TEST)
    
    def on_train_epoch_start(self):
        self._reset_target_registries(Stage.TRAIN)
    
    def on_train_epoch_end(self):
        self._log_epoch_metrics(Stage.TRAIN)
    
    def on_validation_epoch_start(self):
        self._reset_target_registries(Stage.VALIDATION)
    
    def on_validation_epoch_end(self):
        self._log_epoch_metrics(Stage.VALIDATION)
    
    def on_test_epoch_start(self):
        self._reset_target_registries(Stage.TEST)
    
    def on_test_epoch_end(self):
        self._log_epoch_metrics(Stage.TEST)
                                
    def _reset_target_registries(self, stage: Stage):
        self.targets_scores[stage] = []
        self.targets_predicted[stage] = []
        self.targets[stage] = []

    def _step(self, batch, stage: Stage):
        audios, targets = batch
        logits, targets_predicted = self._predict(audios)
        if self.debug:
            import pdb; pdb.set_trace()
        self.targets_scores[stage].append(logits)
        self.targets_predicted[stage].append(targets_predicted)
        self.targets[stage].append(targets)
        
        loss = self.criterion(logits, targets)
        
        metric_name = {
            stage.TRAIN: "train_loss",
            stage.VALIDATION: "val_loss",
            stage.TEST: "test_loss",
        }
        
        self.log(metric_name[stage], loss, prog_bar=True)
        return loss
        
    def _predict(self, data):
        logits = self.forward(data).logits
        if self.debug:
            import pdb; pdb.set_trace()
        if torch.any(torch.isnan(logits)):
            import pdb; pdb.set_trace()
        targets_predicted = (logits[:, 1] > logits[:, 0]) * 1
        return logits, targets_predicted
        
    def _log_epoch_metrics(self, stage: Stage):
        targets_predicted = torch.cat(self.targets_predicted[stage], dim=0).squeeze()
        targets = torch.cat(self.targets[stage], dim=0)

        precision = self.precision(targets_predicted, targets)
        recall = self.recall(targets_predicted, targets)
        f1_score = self.f1(targets_predicted, targets)

        self.log(f'{stage.value}_precision', precision, prog_bar=True)
        self.log(f'{stage.value}_recall', recall, prog_bar=True)
        self.log(f'{stage.value}_f1', f1_score, prog_bar=True)
        
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def get_targets(self, stage: Stage):
        return torch.cat(self.targets[stage], dim=0).to(torch.device("cpu"))
    
    def get_targets_scores(self, stage: Stage):
        if self.debug:
            import pdb; pdb.set_trace()
        return torch.cat(self.targets_scores[stage], dim=0).squeeze().to(torch.device("cpu"))
    
    def get_targets_predicted(self, stage: Stage):
        if self.debug:
            import pdb; pdb.set_trace()
        return torch.cat(self.targets_predicted[stage], 0).squeeze().to(torch.device("cpu"))
        

In [38]:
from pytorch_lightning.loggers import TensorBoardLogger


data_module = DataModule(4, 16000, 64000, real_info, cloned_info)

logger = TensorBoardLogger(str(directory_structure.training_artifacts_path), name="wav2vec2")
detector = ClonedAudioDetector()
trainer = pl.Trainer(
    logger=logger,
    max_epochs=3,
    accelerator="auto",
    log_every_n_steps=10,
    callbacks=[],
    limit_train_batches=10,
    # limit_val_batches=5,
)

trainer.fit(detector, data_module)
trainer.test(detector, data_module)

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['projector.weight', 'classifier.bias', 'projector.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type                              | Params
----------------------------------------------------------------
0 | model     | Wav2Vec2ForSequenceClassification | 94.6 M
1 | precision | BinaryPrecision                   | 0     
2 | recall    | BinaryRecall                      | 0     
3 | f1        | BinaryF1Score                     | 0     
4 | confmat   | BinaryConfusionMatrix             | 0     
----------------------------------------------------------------
94.6 M    Trainable params
0

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


Testing: 0it [00:00, ?it/s]

[{'test_loss': nan, 'TEST_precision': 0.0, 'TEST_recall': 0.0, 'TEST_f1': 0.0}]

In [51]:
next(iter(data_module.train_dataloader()))[0]

torch.Size([4, 64000])

In [52]:
detector.model(next(iter(data_module.train_dataloader()))[0])

SequenceClassifierOutput(loss=None, logits=tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [54]:
label2id = dict(
            cloned=1,
            real=0)

id2label = {
    1: "cloned",
    0: "real"
}
    
modelraw = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base",
    num_labels=2,
    label2id=label2id,
    id2label=id2label
)

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['projector.weight', 'classifier.bias', 'projector.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [56]:
modelraw(next(iter(data_module.train_dataloader()))[0])

SequenceClassifierOutput(loss=None, logits=tensor([[-0.0192,  0.0464],
        [-0.0110,  0.0339],
        [-0.0276,  0.0096],
        [-0.0205,  0.0120]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [14]:
trainer.save_checkpoint(directory_structure.models_path / "wav2vec2.ckpt")

# Evaluation

In [18]:
from sklearn.metrics import ConfusionMatrixDisplay


prediction_metric_functions = dict(
    recall=torchmetrics.functional.recall,
    precision=torchmetrics.functional.precision,
    f1=torchmetrics.functional.f1_score,
    acc=torchmetrics.functional.accuracy,
)


score_metric_functions = dict(
    roc_auc=torchmetrics.AUROC(task="binary"),
)


def get_metrics(model, stage: Stage, threshold=0.5) -> pd.Series:
    targets = model.get_targets(stage)
    targets_scores = model.get_targets_scores(stage)
    targets_predicted = model.get_targets_predicted(stage)
    
    prediction_metrics = {
        metric: function(preds=targets_predicted, target=targets, task="binary", threshold=threshold)
        for metric, function in prediction_metric_functions.items()
    }
    
    import pdb; pdb.set_trace()
    score_metrics = {
        metric: function(targets_scores, targets)
        for metric, function in score_metric_functions.items()
    }
    
    return pd.Series(prediction_metrics | score_metrics)


def plot_metrics(model, stage: Stage, threshold=0.5):
    metrics = get_metrics(model, stage, threshold)
    return go.Figure(
        data=go.Bar(x=metrics.index, y=metrics),
        layout_title=f"Threshold: {threshold}"
    )


def plot_confusion_matrix(model, stage: Stage, threshold=0.5):
    targets = model.get_targets(stage)
    targets_scores = model.get_targets_scores(stage)
    # targets_predicted = targets_scores > 0#model.get_targets_predicted(stage)
    cm = torchmetrics.ConfusionMatrix(task="binary", threshold=threshold)(targets_scores, targets).numpy()
    
    ConfusionMatrixDisplay(confusion_matrix=cm).plot()


def draw_roc(model, stage: Stage, threshold=0.5):
    targets = model.get_targets(stage)
    targets_scores = model.get_targets_scores(stage)
    
    fpr, tpr, thresholds = torchmetrics.ROC(task="binary", threshold=threshold)(targets_scores, targets)
    
    index = np.argmin(abs(thresholds - threshold))
    
    return go.Figure(
        data=[
            go.Scatter(x=fpr, y=tpr),
            go.Scatter(x=[fpr[index]], y=[tpr[index]], showlegend=False, mode="markers+text", text=f"Threshold = {threshold}", textposition="middle right")
        ],
        layout=dict(
            height=500,
            width=500,
            xaxis_title="False positive rate",
            yaxis_title="True positive rate",
            title="ROC"
        )
    )

In [35]:
detector = ClonedAudioDetector.load_from_checkpoint(checkpoint_path=directory_structure.models_path / "wav2vec2.ckpt")
detector.eval()

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['projector.weight', 'classifier.bias', 'projector.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ClonedAudioDetector(
  (model): Wav2Vec2ForSequenceClassification(
    (wav2vec2): Wav2Vec2Model(
      (feature_extractor): Wav2Vec2FeatureEncoder(
        (conv_layers): ModuleList(
          (0): Wav2Vec2GroupNormConvLayer(
            (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
            (activation): GELUActivation()
            (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
          )
          (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
            (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (activation): GELUActivation()
          )
          (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
            (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
            (activation): GELUActivation()
          )
        )
      )
      (feature_projection): Wav2Vec2FeatureProjection(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (projection): Linear(in_features=512, ou

In [36]:
data_module = DataModule(4, 16000, 64000, real_info, cloned_info)

trainer = pl.Trainer()
# trainer.validate(detector, data_module)
trainer.test(detector, data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Testing: 0it [00:00, ?it/s]

> [0;32m/tmp/ipykernel_4183/1450390607.py[0m(114)[0;36m_predict[0;34m()[0m
[0;32m    112 [0;31m        [0mlogits[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mforward[0m[0;34m([0m[0mdata[0m[0;34m)[0m[0;34m.[0m[0mlogits[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    113 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 114 [0;31m        [0mtargets_predicted[0m [0;34m=[0m [0;34m([0m[0mlogits[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;36m1[0m[0;34m][0m [0;34m>[0m [0mlogits[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;36m0[0m[0;34m][0m[0;34m)[0m [0;34m*[0m [0;36m1[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    115 [0;31m        [0;32mreturn[0m [0mlogits[0m[0;34m,[0m [0mtargets_predicted[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    116 [0;31m[0;34m[0m[0m
[0m


ipdb>  logits


tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]])


ipdb>  data


tensor([[   0.,    0.,    0.,  ...,  234.,  441.,  659.],
        [   0.,    0.,    0.,  ...,    4.,    4.,    4.],
        [   0.,    0.,    0.,  ...,  185., -114.,  -90.],
        [   0.,    0.,    0.,  ...,    0.,  -12.,  -16.]])


ipdb>  data.shape


torch.Size([4, 64000])


ipdb>  self.forward(data)


SequenceClassifierOutput(loss=None, logits=tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), hidden_states=None, attentions=None)


ipdb>  torch.isnan(data)


tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])


ipdb>  any(torch.isnan(data))


*** RuntimeError: Boolean value of Tensor with more than one value is ambiguous


ipdb>  torch.any(torch.isnan(data))


tensor(False)


ipdb>  exit


In [30]:
detector.get_targets_scores(Stage.TEST)

> [0;32m/tmp/ipykernel_4183/3190872002.py[0m(137)[0;36mget_targets_scores[0;34m()[0m
[0;32m    135 [0;31m    [0;32mdef[0m [0mget_targets_scores[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mstage[0m[0;34m:[0m [0mStage[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    136 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 137 [0;31m        [0;32mreturn[0m [0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mtargets_scores[0m[0;34m[[0m[0mstage[0m[0;34m][0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m0[0m[0;34m)[0m[0;34m.[0m[0msqueeze[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mdevice[0m[0;34m([0m[0;34m"cpu"[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    138 [0;31m[0;34m[0m[0m
[0m[0;32m    139 [0;31m    [0;32mdef[0m [0mget_targets_

ipdb>  torch.cat(self.targets_scores[stage], dim=0).squeeze().to(torch.device("cpu"))


tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],


ipdb>  self.targets_scores[stage]


[tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]]), tensor([[nan, n

ipdb>  exit


In [19]:
plot_metrics(detector, Stage.TEST).show()
draw_roc(detector, Stage.TEST).show()
plot_confusion_matrix(detector, Stage.TEST)

> [0;32m/tmp/ipykernel_4183/294121932.py[0m(28)[0;36mget_metrics[0;34m()[0m
[0;32m     26 [0;31m[0;34m[0m[0m
[0m[0;32m     27 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 28 [0;31m    score_metrics = {
[0m[0;32m     29 [0;31m        [0mmetric[0m[0;34m:[0m [0mfunction[0m[0;34m([0m[0mtargets_scores[0m[0;34m,[0m [0mtargets[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     30 [0;31m        [0;32mfor[0m [0mmetric[0m[0;34m,[0m [0mfunction[0m [0;32min[0m [0mscore_metric_functions[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  targets_scores


tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],


ipdb>  exit
