In [None]:
#run
# Dataset processing
# imdb
import collections

import numpy as np
import pandas as pd
from openprompt.data_utils.utils import InputExample
import os
import json, csv
from abc import ABC, abstractmethod
from collections import defaultdict, Counter
from typing import List, Dict, Callable

from openprompt.utils.logging import logger
from openprompt.data_utils.data_processor import DataProcessor

class IMDBProcessor(DataProcessor):
    """
    Reads in-distribution (ID) / out-of-distribution (OOD) text CSV files into
    ``InputExample`` lists.

    ``get_examples`` reads ``<data_dir>/<split>.csv`` (or
    ``<data_dir>/train_<frac>.csv`` when ``split == 'train'`` and ``frac`` is
    truthy) and uses two columns:

    * ``utt``   - the raw text, stored in ``InputExample.tgt_text``;
    * ``index`` - the integer class id, where ``-1`` marks an OOD row.

    ``get_label_words`` additionally reads the ``intent`` column of ``train.csv``.

    Args:
        is_id_only: when True only rows with ``index != -1`` are returned;
            otherwise the ID rows come first, followed by all OOD rows.
        mode: ignored by this processor (the parameter also exists on
            ``ClincProcessor``; presumably kept so both can be constructed the
            same way - confirm).
    """

    def __init__(self, is_id_only=True, mode=None):
        super().__init__()
        self.labels = None
        self.is_id_only = is_id_only
        self.monitor_mode = 'label'

    def get_examples(self, data_dir: str, split: str, frac: int = 0) -> List[InputExample]:
        """Load one split and return its examples (ID rows first, then OOD rows).

        Every example has ``text_a=""``, the utterance in ``tgt_text``,
        ``meta={"is_ood": 0 or 1}`` and ``label`` set to the class id
        (``-1`` for OOD rows). ``guid`` is the position in the returned list.
        """
        if split == 'train' and frac:
            path = os.path.join(data_dir, "{}_{}.csv".format(split, frac))
        else:
            path = os.path.join(data_dir, "{}.csv".format(split))
        data = pd.read_csv(path)
        indexs = list(data["index"])
        utts = list(data["utt"])

        # ID rows keep their original relative order.
        id_rows = [(utt, idx) for utt, idx in zip(utts, indexs) if idx != -1]
        texts = [utt for utt, _ in id_rows]
        labels = [idx for _, idx in id_rows]
        is_oods = [0] * len(id_rows)

        if not self.is_id_only:
            # OOD rows are appended after all ID rows and always labelled -1.
            ood_utt = [utt for utt, idx in zip(utts, indexs) if idx == -1]
            texts += ood_utt
            labels += [-1] * len(ood_utt)
            is_oods += [1] * len(ood_utt)

        return [
            InputExample(guid=str(i), text_a="", tgt_text=text, meta={"is_ood": is_ood}, label=label)
            for i, (text, is_ood, label) in enumerate(zip(texts, is_oods, labels))
        ]

    def get_src_tgt_len_ratio(self, ):
        pass

    def get_label_words(self, data_dir):
        """Return ``[[intent], ...]`` read from ``train.csv``, ordered by class id.

        When one class id appears with several ``intent`` strings, the last
        row in the file wins.
        """
        data = pd.read_csv(os.path.join(data_dir, "train.csv"))
        index_to_intent = dict(zip(list(data["index"]), list(data["intent"])))
        return [[index_to_intent[index]] for index in sorted(index_to_intent)]

In [2]:
#run
# clinc
import collections

import numpy as np
import pandas as pd
from openprompt.data_utils.utils import InputExample
import os
import json, csv
from abc import ABC, abstractmethod
from collections import defaultdict, Counter
from typing import List, Dict, Callable

from openprompt.utils.logging import logger
from openprompt.data_utils.data_processor import DataProcessor


class ClincProcessor(DataProcessor):
    """
    Reads CLINC-style intent CSV files into ``InputExample`` lists for
    in-distribution (ID) / out-of-distribution (OOD) experiments.

    ``get_examples`` reads ``<data_dir>/<split>.csv`` (or
    ``<data_dir>/train_<frac>.csv`` when ``split == 'train'`` and ``frac`` is
    truthy) and uses three columns:

    * ``utt``          - the raw text, stored in ``InputExample.tgt_text``;
    * ``index``        - the intent class id, where ``-1`` marks an OOD row;
    * ``domain_index`` - the domain class id of the row.

    ``get_label_words`` additionally reads the ``intent`` column of ``train.csv``.

    Args:
        is_id_only: when True only rows with ``index != -1`` are returned;
            otherwise the ID rows come first, followed by all OOD rows.
        mode: ``150`` selects the intent id (``index``) as the example label;
            any other value (including the default ``None``) selects
            ``domain_index``.
    """

    def __init__(self, is_id_only=True, mode=None):
        super().__init__()
        self.labels = None
        self.is_id_only = is_id_only
        self.monitor_mode = 'label' if mode == 150 else 'domain'

    def get_examples(self, data_dir: str, split: str, frac: int = 0) -> List[InputExample]:
        """Load one split and return its examples (ID rows first, then OOD rows).

        Every example has ``text_a=""``, the utterance in ``tgt_text``,
        ``meta={"is_ood": 0 or 1}`` and ``label`` set to the intent or domain
        id depending on ``self.monitor_mode`` (always ``-1`` for OOD rows).

        Raises:
            ValueError: if ``self.monitor_mode`` is neither ``'label'`` nor
                ``'domain'``.
        """
        if self.monitor_mode not in ('label', 'domain'):
            raise ValueError("Unknown monitor_mode: {!r}".format(self.monitor_mode))

        if split == 'train' and frac:
            path = os.path.join(data_dir, "{}_{}.csv".format(split, frac))
        else:
            path = os.path.join(data_dir, "{}.csv".format(split))
        data = pd.read_csv(path)
        domain_indexs = list(data['domain_index'])
        indexs = list(data["index"])
        utts = list(data["utt"])

        # A row is ID/OOD according to `index` only. The domain id is carried
        # along with its row, so text and label can never get out of step
        # (the previous code filtered the domain column by its own -1 test).
        rows = list(zip(utts, indexs, domain_indexs))
        use_intent = self.monitor_mode == 'label'

        selected = [
            (utt, 0, index if use_intent else domain)
            for utt, index, domain in rows
            if index != -1
        ]
        if not self.is_id_only:
            # OOD rows are appended after all ID rows and always labelled -1.
            selected += [(utt, 1, -1) for utt, index, _ in rows if index == -1]

        return [
            InputExample(guid=str(i), text_a="", tgt_text=utt, meta={"is_ood": is_ood}, label=label)
            for i, (utt, is_ood, label) in enumerate(selected)
        ]

    def get_src_tgt_len_ratio(self, ):
        pass

    def get_label_words(self, data_dir):
        """Return ``[[intent], ...]`` read from ``train.csv``, ordered by intent id.

        When one id appears with several ``intent`` strings, the last row in
        the file wins.
        """
        data = pd.read_csv(os.path.join(data_dir, "train.csv"))
        index_to_intent = dict(zip(list(data["index"]), list(data["intent"])))
        return [[index_to_intent[index]] for index in sorted(index_to_intent)]

In [3]:
# run
import yaml
def load_parameters_from_yaml(file_path):
    """Parse the YAML file at ``file_path`` with ``yaml.safe_load`` and return the result."""
    with open(file_path, 'r') as stream:
        return yaml.safe_load(stream)
# Load the run parameters once; later cells read them through `config`.
yaml_file_path = 'paras copy.yml'
config = loaded_parameters = load_parameters_from_yaml(yaml_file_path)

In [4]:
# run
# Select the processor class, data directory and tokenizer length for the
# configured dataset. An unknown name fails here, instead of surfacing later
# as a NameError on OOD_DataProcessor / datasets_dir / max_seq_length.
if config['dataset'] == 'IMDB':
    OOD_DataProcessor = IMDBProcessor
    datasets_dir = "./datasets/imdb_yelp"
    max_seq_length = 256
    batch_size = config['batch_size']

elif config['dataset'] == 'clinc':
    OOD_DataProcessor = ClincProcessor
    datasets_dir = "./datasets/clinc150/"
    max_seq_length = 128
    batch_size = config['batch_size']

else:
    raise ValueError(
        "Unsupported dataset {!r}; expected 'IMDB' or 'clinc'".format(config['dataset'])
    )


In [None]:
# run
# Build all splits. Processors created with True return ID rows only;
# processors created with False return ID rows followed by OOD rows.
dataset = {
    'train': OOD_DataProcessor(True).get_examples(datasets_dir, "train"),
    # NOTE(review): 'val' is read from the *test* file (ID rows only), not from
    # "valid" - confirm this is intended, since model selection then sees test data.
    'val': OOD_DataProcessor(True).get_examples(datasets_dir, "test"),
    'val_ood': OOD_DataProcessor(False).get_examples(datasets_dir, "valid"),
    'test': OOD_DataProcessor(False).get_examples(datasets_dir, "test"),
}
print(dataset['val'][0])

In [6]:
# run
# dataloader
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from typing import List, Dict

class TextClassificationDataset(Dataset):
    """Map-style dataset that tokenizes one record per access.

    ``data`` is a sequence of dicts with a ``"tgt_text"`` entry (the text) and
    a ``"label"`` entry (converted to a ``torch.long`` tensor). Each text is
    padded / truncated to ``max_length`` tokens.
    """

    def __init__(self, data: List[Dict[str, str]], tokenizer: BertTokenizer, max_length: int = 128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` (both flattened to 1-D) and ``label``."""
        record = self.data[idx]
        text, label = record["tgt_text"], record["label"]
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a leading batch axis; flatten() removes it.
        item = {key: encoded[key].flatten() for key in ("input_ids", "attention_mask")}
        item["label"] = torch.tensor(label, dtype=torch.long)
        return item

def create_data_loader(data: List[Dict[str, str]], tokenizer: BertTokenizer, batch_size: int = 32, max_length: int = 128):
    """Wrap ``data`` in a ``TextClassificationDataset`` and return a shuffling ``DataLoader``."""
    return DataLoader(
        TextClassificationDataset(data, tokenizer, max_length),
        batch_size=batch_size,
        shuffle=True,
    )

def create_test_data_loader(data: List[Dict[str, str]], tokenizer: BertTokenizer, batch_size: int = 32, max_length: int = 128):
    """Wrap ``data`` in a ``TextClassificationDataset`` and return a ``DataLoader`` that keeps the input order."""
    return DataLoader(
        TextClassificationDataset(data, tokenizer, max_length),
        batch_size=batch_size,
        shuffle=False,
    )

In [17]:
# run
# Evaluation metrics
import numpy as np
from sklearn import metrics


def compute_all_metrics(conf, label, pred):
    """Return ``[fpr, auroc, aupr_in, aupr_out, accuracy]`` for an OOD evaluation.

    ``fpr`` is taken at a 0.95 true-positive rate. ``label == -1`` marks OOD
    samples; accuracy is measured on the remaining (ID) samples only.
    Also sets NumPy's print precision to 3 as a side effect.
    """
    np.set_printoptions(precision=3)
    auroc, aupr_in, aupr_out, fpr = auc_and_fpr_recall(conf, label, 0.95)
    return [fpr, auroc, aupr_in, aupr_out, acc(pred, label)]


# accuracy
def acc(pred, label):
    """Classification accuracy over in-distribution samples (``label != -1``)."""
    is_id = label != -1
    id_pred, id_label = pred[is_id], label[is_id]
    return np.sum(id_pred == id_label) / len(id_label)


# fpr_recall
def fpr_recall(conf, label, tpr):
    """Return ``(fpr, threshold)`` at the first ROC point whose TPR reaches ``tpr``.

    ID samples (``label != -1``) are the positive class and ``conf`` is the score.
    """
    # 1 for ID, 0 for OOD, in the same dtype as `label`.
    is_id = (label != -1).astype(label.dtype)
    fpr_list, tpr_list, threshold_list = metrics.roc_curve(is_id, conf)
    first_hit = np.argmax(tpr_list >= tpr)
    return fpr_list[first_hit], threshold_list[first_hit]


# auc
def auc_and_fpr_recall(conf, label, tpr_th):
    """Return ``(auroc, aupr_in, aupr_out, fpr)`` for OOD detection.

    OOD samples (``label == -1``) are the positive class for the ROC curve and
    for ``aupr_out``. ``conf`` is expected to be larger for ID samples, so it
    is negated wherever OOD is the positive class. ``fpr`` is read at the
    first ROC point whose TPR reaches ``tpr_th``.
    """
    # 1 for OOD, 0 for ID, in the same dtype as `label`.
    is_ood = (label == -1).astype(label.dtype)
    ood_score = -conf

    fpr_list, tpr_list, _ = metrics.roc_curve(is_ood, ood_score)
    fpr = fpr_list[np.argmax(tpr_list >= tpr_th)]

    precision_in, recall_in, _ = metrics.precision_recall_curve(1 - is_ood, conf)
    precision_out, recall_out, _ = metrics.precision_recall_curve(is_ood, ood_score)

    auroc = metrics.auc(fpr_list, tpr_list)
    aupr_in = metrics.auc(recall_in, precision_in)
    aupr_out = metrics.auc(recall_out, precision_out)
    return auroc, aupr_in, aupr_out, fpr


# ccr_fpr
def ccr_fpr(conf, fpr, pred, label):
    """Correct-classification rate on ID samples at a given OOD false-positive rate.

    Args:
        conf: 1-D array of confidence scores (larger means more ID-like).
        fpr: target fraction of OOD samples allowed at or above the threshold.
        pred: 1-D array of predicted class ids.
        label: 1-D array of true class ids; ``-1`` marks OOD samples.

    Returns:
        Fraction of ID samples that are both correctly classified and have
        ``conf`` strictly above the threshold.
    """
    is_id = label != -1
    ind_conf = conf[is_id]
    ind_pred = pred[is_id]
    ind_label = label[is_id]
    ood_conf = conf[label == -1]

    num_ind = len(ind_conf)
    num_ood = len(ood_conf)

    fp_num = int(np.ceil(fpr * num_ood))
    if num_ood == 0:
        # No OOD sample can be falsely accepted, so nothing needs rejecting.
        thresh = -np.inf
    elif fp_num == 0:
        # Index -0 equals index 0 and would select the *smallest* OOD score.
        # For zero allowed false positives use the largest OOD score, so that
        # no OOD sample lies strictly above the threshold.
        thresh = np.max(ood_conf)
    else:
        thresh = np.sort(ood_conf)[-fp_num]

    num_tp = np.sum((ind_conf > thresh) & (ind_pred == ind_label))
    return num_tp / num_ind


def detection(ind_confidences,
              ood_confidences,
              n_iter=100000,
              return_data=False):
    """Scan thresholds and return the minimum detection error.

    For each threshold ``delta`` between the smallest and largest confidence
    (``n_iter`` evenly spaced steps) the error is the mean of
    (fraction of ID confidences below ``delta``) and
    (fraction of OOD confidences above ``delta``).

    Args:
        ind_confidences: 1-D array of confidences for ID samples.
        ood_confidences: 1-D array of confidences for OOD samples.
        n_iter: number of threshold steps.
        return_data: when True also return every error and threshold visited.

    Returns:
        ``(best_error, best_delta)`` or, with ``return_data=True``,
        ``(best_error, best_delta, all_errors, all_thresholds)``.
        ``best_delta`` stays ``None`` if no threshold beats an error of 1.0.
    """
    X1 = ind_confidences
    Y1 = ood_confidences

    start = np.min([np.min(X1), np.min(Y1)])
    end = np.max([np.max(X1), np.max(Y1)])
    gap = (end - start) / n_iter

    # The builtin float replaces np.float, which NumPy 1.24 removed.
    num_ind = float(len(X1))
    num_ood = float(len(Y1))

    best_error = 1.0
    best_delta = None
    all_thresholds = []
    all_errors = []
    for delta in np.arange(start, end, gap):
        ind_rejected = np.sum(X1 < delta) / num_ind   # ID samples below the threshold
        ood_accepted = np.sum(Y1 > delta) / num_ood   # OOD samples above the threshold
        detection_error = (ind_rejected + ood_accepted) / 2.0

        if return_data:
            all_thresholds.append(delta)
            all_errors.append(detection_error)

        if detection_error < best_error:
            best_error = detection_error
            best_delta = delta

    if return_data:
        return best_error, best_delta, all_errors, all_thresholds
    return best_error, best_delta


In [None]:
# run
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Tokenizer and classifier are both loaded from the local checkpoint directory.
tokenizer = BertTokenizer.from_pretrained(
    './pretrained_models',
    padding=True,
    truncation=True,
    return_tensors='pt',
    do_lower_case=True,
)
model = BertForSequenceClassification.from_pretrained(
    './pretrained_models',
    num_labels=config['K'],
)

# Reduce each InputExample to the two fields TextClassificationDataset reads.
train_data, val_data, test_data = (
    [{"tgt_text": example.tgt_text, "label": example.label} for example in dataset[split]]
    for split in ('train', 'val', 'test')
)

# Only the training loader shuffles; validation / test keep the input order.
train_dataloader = create_data_loader(
    train_data, tokenizer, batch_size=config['batch_size'], max_length=max_seq_length)
val_dataloader = create_test_data_loader(
    val_data, tokenizer, batch_size=config['batch_size'], max_length=max_seq_length)
test_dataloader = create_test_data_loader(
    test_data, tokenizer, batch_size=config['batch_size'], max_length=max_seq_length)


In [9]:
# run
import torch
import torch.nn as nn
import torch.nn.functional as F

class GRODNet(nn.Module):
    """Backbone wrapper with LDA / PCA projections and a ``num_classes + 1``-way head.

    The backbone must expose a ``bert`` sub-module; element ``[1]`` of its
    output is used as the feature vector (assumed to be the 768-dim pooled
    output, since both heads are ``nn.Linear(768, ...)`` - TODO confirm).
    The extra last logit of ``head`` is treated as the OOD class in
    ``intermediate_forward``.
    """

    def __init__(self, backbone, feat_dim, num_classes):
        super().__init__()

        self.backbone = backbone
        if hasattr(self.backbone, 'fc'):
            # Drop the fc layer; per the original note, DDP would otherwise
            # report unused parameters.
            self.backbone.fc = nn.Identity()

        self.lda = LDA(n_components=feat_dim)
        self.pca = PCA(n_components=feat_dim)

        self.n_cls = num_classes
        self.head1 = nn.Linear(768, 2 * num_classes)
        self.head = nn.Linear(768, self.n_cls + 1)
        self.k = nn.Parameter(torch.tensor([0.1], dtype=torch.float32, requires_grad=True))

    def forward(self, x, y, attention): #x: input ids, y: labels
        """Return ``(features, lda_projection, pca_projection)``.

        LDA (using labels ``y``) and PCA are re-fitted on the features of
        every call before projecting them.
        """
        pooled = self.backbone.bert(x, attention)[1] #(b,768)

        self.lda.fit(pooled, y)
        lda_feat = self.lda.transform(pooled) #(b, feat_dim)

        self.pca.fit(pooled)
        pca_feat = self.pca.transform(pooled)

        return pooled, lda_feat, pca_feat

    def intermediate_forward(self, x, attention):
        """Return softmax scores over the ID classes only.

        Rows whose arg-max over all ``n_cls + 1`` outputs is the last (OOD)
        class get their ID logits replaced by the constant 0.1, which gives a
        uniform distribution after the final softmax.
        """
        pooled = self.backbone.bert(x, attention)[1]
        logits = self.head(pooled)

        # Decide which rows are predicted OOD before the logits are overwritten.
        ood_class = logits.size(1) - 1
        predicted_ood = torch.argmax(torch.softmax(logits, dim=1), dim=1) == ood_class

        id_logits = logits[:, :-1]  # a view: the assignment below writes into `logits`
        id_logits[predicted_ood] = 0.1
        return torch.softmax(id_logits, dim=1)



class LDA(nn.Module):
    """Projection onto ``n_components`` directions derived from class scatter matrices.

    ``fit`` stores the projection matrix in ``self.components``;
    ``transform`` multiplies its input by that matrix.
    """

    def __init__(self, n_components):
        super(LDA, self).__init__()
        self.n_components = n_components

    def fit(self, X, y):
        """Compute ``self.components`` from features ``X`` and labels ``y``.

        X is unpacked as (n_samples, n_features); a fallback path handles the
        case where that unpacking fails (presumably a single 1-D feature
        vector - TODO confirm).
        """
        try:
            n_samples, n_features = X.shape
        except:
            # Unpacking failed, so take the first (only) axis as the feature axis.
            n_features = X.shape[0]
        classes = torch.unique(y)
        n_classes = len(classes)
        
        # Per-class mean, one row per class present in y.
        means = torch.zeros(n_classes, n_features).to(X.device)
        for i, c in enumerate(classes):
            try:
                means[i] = torch.mean(X[y==c], dim=0)
            except:
                # NOTE(review): bare except; on any failure X gets a leading
                # axis and the mean is retried. This also rebinds X for the
                # rest of the method - confirm which errors are expected here.
                X = torch.unsqueeze(X, dim=0)
                means[i] = torch.mean(X[y==c], dim=0)
        
        overall_mean = torch.mean(X, dim=0)
        
        # Within-class scatter: sum over classes of deviation^T @ deviation.
        within_class_scatter = torch.zeros(n_features, n_features).to(X.device)
        for i, c in enumerate(classes):
            class_samples = X[y==c]
            deviation = class_samples - means[i]
            within_class_scatter += torch.mm(deviation.t(), deviation)
        
        # Between-class scatter: class-size-weighted outer products of
        # (class mean - overall mean).
        between_class_scatter = torch.zeros(n_features, n_features).to(X.device)
        for i, c in enumerate(classes):
            n = len(X[y==c])
            mean_diff = (means[i] - overall_mean).unsqueeze(1)
            between_class_scatter += n * torch.mm(mean_diff, mean_diff.t())

        # torch.backends.cuda.preferred_linalg_library('magma')
        # print((torch.inverse(within_class_scatter) @ between_class_scatter).size()) #(768,768)
        # NOTE(review): the matrix decomposed below is
        #   inverse(S_w @ S_b + 0.01 * I)
        # not inverse(S_w) @ S_b as in the commented-out line above. S_w @ S_b
        # is also not symmetric in general, while torch.linalg.eigh assumes a
        # symmetric input - confirm this is the intended computation.
        eigenvalues, eigenvectors = torch.linalg.eigh(
        torch.inverse(within_class_scatter @ between_class_scatter  + 1e-2 * torch.eye((within_class_scatter @ between_class_scatter).size(0)).to(X.device)
        ))
        # Keep the eigenvectors belonging to the n_components largest eigenvalues.
        _, top_indices = torch.topk(eigenvalues, k=self.n_components, largest=True)
        self.components = eigenvectors[:, top_indices]

    def transform(self, X):
        """Project ``X`` with the matrix computed by ``fit`` (no centering is applied)."""
        return torch.mm(X, self.components)

class PCA(nn.Module):
    """Principal component projection computed with ``torch.linalg.eigh``.

    ``fit`` stores the feature mean in ``self.mean`` and the eigenvectors of
    the ``n_components`` largest covariance eigenvalues in ``self.components``;
    ``transform`` centers its input with that mean and projects it.
    """

    def __init__(self, n_components):
        super().__init__()
        self.n_components = n_components

    def fit(self, X):
        """Estimate mean and principal directions from ``X`` (n_samples, n_features)."""
        # Anything that is not 2-D is counted as a single sample.
        n_samples = X.shape[0] if len(X.shape) == 2 else 1

        self.mean = torch.mean(X, dim=0)
        centered = X - self.mean

        # Unbiased covariance; the divisor is floored at 1 for a single sample.
        divisor = max((n_samples - 1), 1)
        covariance = torch.mm(centered.t(), centered) / divisor

        eigvals, eigvecs = torch.linalg.eigh(covariance)
        top = torch.topk(eigvals, k=self.n_components, largest=True).indices
        self.components = eigvecs[:, top]

    def transform(self, X):
        """Center ``X`` with the fitted mean and project it onto the components."""
        return torch.mm(X - self.mean, self.components)


In [11]:
# run grod v2
import faiss.contrib.torch_utils
import math
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from tqdm import tqdm
from einops import repeat

# Turns on autograd anomaly detection for every later backward pass in this kernel.
# NOTE(review): PyTorch documents this as a debugging aid that slows execution;
# consider disabling it for real training runs.
torch.autograd.set_detect_anomaly(True)
class GRODTrainer_Soft_Label:
    def __init__(self, net: nn.Module, train_loader: DataLoader,
                 config) -> None:
        self.device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
        self.net = net.to(self.device)  
        self.train_loader = train_loader
        self.config = config

        self.n_cls = config['dataset']['num_classes']


        self.optimizer = torch.optim.AdamW(
            params=net.parameters(),
            lr=config['optimizer']['lr'],
            weight_decay=config['optimizer']['weight_decay'],
        )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max = 10
        )

        self.head = self.net.head
        self.head1 = self.net.head1
        self.alpha = config['trainer']['alpha']
        self.nums_rounded = config['trainer']['nums_rounded']
        self.gamma = config['trainer']['gamma']
        self.stat_smooth = 0.3
        self.batch_size = config['dataset']['batch_size']
        self.threshold = 20
        
        self.k = self.net.k
        
        self.best_accuracy = 0.0  # Update best accuracy
        self.best_model_state = None  # Save current model state    

    def train(self, epochs):
        """Train for ``epochs`` passes over ``self.train_loader``.

        Each epoch has two phases:

        1. Warmup (no gradients): the first
           ``int(threshold * n_cls / batch_size)`` batches are only encoded.
           Their features give a global mean and, per class, a mean, an
           inverse covariance and the largest Mahalanobis distance seen.
        2. Training: for every remaining batch, extra feature vectors are
           produced by ``grod_generate_pca`` / ``grod_generate_lda`` (defined
           elsewhere in this class), filtered by their distance to the class
           statistics, given soft ``n_cls + 1``-way labels, and appended to
           the real features before the loss is computed.

        After each epoch ``test_model()`` is called; when its result is at
        least ``self.best_accuracy`` a copy of the weights is stored in
        ``self.best_model_state``.

        Args:
            epochs: number of passes over the training loader.
        """
        import copy  # local import: only used to snapshot the best weights

        self.net.train()
        self.net.to(self.device)
        for epoch_idx in range(epochs):
            loss_avg = 0.0
            train_dataiter = iter(self.train_loader)

            sub_datasets_in_mu = torch.zeros((self.n_cls, 768)).to(self.device) #(K,f)
            dataset_in_mu = torch.zeros(768).to(self.device) #(f)
            dataset_in_cov = torch.zeros(768, 768).to(self.device)
            sub_datasets_in_cov = torch.zeros((self.n_cls, 768, 768)).to(self.device)
            sub_datasets_in_distances = torch.zeros(self.n_cls).to(self.device)

            # The former bare `torch.autograd.detect_anomaly(True)` call here only
            # built an unused context manager; anomaly detection is already
            # enabled at module level via set_detect_anomaly(True).

            #### Warmup: encode the first batches without generating outliers and
            #### estimate mean / inverse covariance for each class instead ###
            warmup = int(self.threshold * self.n_cls / self.batch_size)
            print("Warmup...")
            if warmup > 0:
                warmup_feats = []
                warmup_targets = []
                for train_step in tqdm(range(1,
                                            warmup + 1),
                                    desc='Epoch {:03d}: '.format(epoch_idx),
                                    position=0,
                                    leave=True):
                    with torch.no_grad():
                        batch = next(train_dataiter)
                        data = batch['input_ids'].to(self.device)
                        target = batch['label'].to(self.device)
                        attention_mask = batch['attention_mask'].to(self.device)

                        data_in, feat_lda, feat_pca = self.net(data, target, attention_mask)

                    warmup_feats.append(data_in)
                    # Keep the labels of *every* warmup batch. Previously only
                    # the last batch's labels were used to index the
                    # concatenated features, so class statistics were computed
                    # from wrongly selected rows.
                    warmup_targets.append(target)

                data_warmup = torch.cat(warmup_feats, dim=0)
                target_warmup = torch.cat(warmup_targets, dim=0)

                dataset_in_mu = torch.mean(data_warmup, dim = 0)
                cov0 = (self.calculate_covariance_matrix(data_warmup).detach()
                        + 1e-4 * torch.eye(dataset_in_mu.size(0)).to(self.device).detach()).to(torch.double)
                # Invert the regularised covariance through its Cholesky factor:
                # cov0 = L L^T  =>  cov0^-1 = L^-T L^-1.
                L = torch.linalg.cholesky(cov0).detach()
                L_inv = torch.linalg.inv(L).detach()
                dataset_in_cov = torch.mm(L_inv.t(), L_inv).unsqueeze(0).detach().to(torch.float)

                sub_datasets_in = [Subset(data_warmup, torch.where(target_warmup == i)[0]) for i in range(self.n_cls)]

                for i in range(len(sub_datasets_in)):
                    # Batch size threshold * n_cls is at least the number of
                    # warmup samples, so each class arrives as a single batch
                    # (classes without samples yield no batch and keep zeros).
                    dataloader = DataLoader(sub_datasets_in[i], batch_size=int(self.threshold * self.n_cls), shuffle=False)
                    for batch in dataloader:
                        tensor_data_in = batch

                        mean =  torch.mean(tensor_data_in, dim = 0)
                        cov0 = (self.calculate_covariance_matrix(tensor_data_in)+1e-4 * torch.eye(mean.size(0)).to(self.device)).detach()
                        L = torch.linalg.cholesky(cov0).detach()
                        L_inv = torch.linalg.inv(L).detach()

                        # Inverse of the symmetric positive definite matrix via
                        # the inverse of its lower-triangular Cholesky factor.
                        cov = torch.mm(L_inv.t(), L_inv)

                        sub_datasets_in_cov[i,:,:] = cov.detach()
                        sub_datasets_in_mu[i,:] = mean.detach()
                        sub_datasets_in_distances[i] = torch.max(self.mahalanobis(tensor_data_in, sub_datasets_in_mu.clone(), sub_datasets_in_cov.clone())[:,i]).detach()
            #### End of warmup ###

            # test_model() leaves the network in eval mode after each epoch.
            self.net.train()
            for train_step in tqdm(range(warmup + 1,
                                        len(train_dataiter)),
                                desc='Epoch {:03d}: '.format(epoch_idx),
                                position=0,
                                leave=True):

                batch = next(train_dataiter)
                data = batch['input_ids'].to(self.device)
                target = batch['label'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)

                data_in, feat_lda, feat_pca = self.net(data, target, attention_mask)

                # `data` keeps the graph-attached features for the loss; the
                # detached copies are only used to generate outliers.
                data = data_in
                data_in = data_in.detach()
                feat_lda = feat_lda.detach()
                feat_pca = feat_pca.detach()

                # Split the batch features by class.
                sub_datasets_in = [Subset(data_in, torch.where(target == i)[0]) for i in range(self.n_cls)]
                sub_datasets_lda = [Subset(feat_lda, torch.where(target == i)[0]) for i in range(self.n_cls)]

                # Number of classes with more than two samples in this batch.
                dataset_lengths = torch.tensor([len(subset) for subset in sub_datasets_lda])
                enough_samples = dataset_lengths > 2
                lda_class = len(dataset_lengths[enough_samples])

                reshaped_rounded_data, dataset_in_mu = self.grod_generate_pca(data_in, feat_lda, feat_pca, train_step, dataset_in_mu)

                data = torch.cat((data, reshaped_rounded_data), dim = 0)

                if lda_class > 0:
                    reshaped_rounded_data, sub_datasets_in_mu, sub_datasets_in_cov, sub_datasets_in_distances = self.grod_generate_lda(feat_lda, sub_datasets_in, sub_datasets_lda, sub_datasets_in_mu, sub_datasets_in_cov, sub_datasets_in_distances, lda_class)

                    data = torch.cat((data, reshaped_rounded_data), dim = 0)

                # NOTE(review): this appends the most recent generated batch a
                # second time (the LDA batch if the branch above ran, else the
                # PCA batch). Kept as in the original - confirm it is intended.
                data = torch.cat((data, reshaped_rounded_data), dim = 0)

                # Everything after the real batch rows is generated data.
                data_add = data[data_in.size(0):]

                distances = self.mahalanobis(data_add, sub_datasets_in_mu, sub_datasets_in_cov).to(self.device) #(n,k)

                # Minimum distance and the matching class index for each generated point.
                min_distances, min_distances_clas = torch.min(distances, dim=1)
                # Stored maximum distance of that nearest class.
                sub_distances = sub_datasets_in_distances[min_distances_clas.to(self.device)]

                ### soft label of outliers ###
                target_add = torch.zeros((data_add.size(0), self.n_cls + 1)).to(self.device) #(n, K+1)
                # Ratio of each class's stored maximum distance to the point's distance.
                extend = torch.clamp(sub_datasets_in_distances / (distances + 1e-5), -1e5, 1e5)
                target_add[:,:-1] = torch.exp(- (1 - extend))
                extend_ood, _ = torch.max(extend, dim=1)
                assert not torch.isnan(extend).any(), "NaN values found in `extend`"
                assert not torch.isinf(extend).any(), "Inf values found in `extend`"
                assert not torch.isnan(min_distances_clas).any(), "NaN values found in `min_distances_clas`"
                target_add[:,-1] = torch.exp(1 - extend_ood)
                ### soft label of outliers ###

                k_init = (torch.mean(min_distances / sub_distances) - 1) * 10

                # Keep only generated points farther from their nearest class
                # than that class's (scaled) stored maximum distance.
                mask = min_distances > (1 + k_init * self.k.to(self.device)[0]) * sub_distances
                cleaned_data_add = data_add[mask.to(self.device)]
                cleaned_target_add = target_add[mask.to(self.device)]

                # Cap the number of kept points at batch_rows // n_cls + 2.
                max_added = data_in.size(0) // self.n_cls + 2
                if cleaned_data_add.size(0) > max_added:
                    indices = torch.randperm(cleaned_data_add.size(0))[:max_added].to(self.device)
                    cleaned_data_add_de = cleaned_data_add[indices]
                    cleaned_target_add_de = cleaned_target_add[indices]
                else:
                    cleaned_data_add_de = cleaned_data_add
                    cleaned_target_add_de = cleaned_target_add

                data = torch.cat((data[:data_in.size(0)], cleaned_data_add_de), dim = 0)

                # Real samples get hard one-hot targets with a zero OOD column.
                target = F.one_hot(target, num_classes=self.n_cls + 1)

                target = torch.cat((target, cleaned_target_add_de), dim = 0)

                output = self.head(data)
                loss1 = F.cross_entropy(output, target)

                # Two-column view of logits and targets:
                # column 0 = sum over the ID classes, column 1 = the OOD class.
                label_matrix = output
                biclas = torch.zeros(label_matrix.size(0), 2)
                biclas[:,-1] = label_matrix[:,-1]
                biclas[:,0] = torch.sum(label_matrix[:,:-1],-1)

                label_biclas = torch.zeros(target.size(0), 2)
                label_biclas[:,-1] = target[:,-1]
                label_biclas[:,0] = torch.sum(target[:,:-1],-1)

                loss2 = F.cross_entropy(biclas.to(self.device), label_biclas.to(self.device))
                loss = (1 - self.gamma) * loss1 + self.gamma * loss2

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

                # exponential moving average, show smooth values
                with torch.no_grad():
                    loss_avg = loss_avg * 0.8 + float(loss) * 0.2

            # The denominator is the number of epochs (it used to print the batch count).
            print(f'Epoch {epoch_idx + 1}/{epochs}, Average Training Loss: {loss_avg:.4f}')
            accuracy = self.test_model()  # Test model after each epoch
            if accuracy >= self.best_accuracy:  # ties also refresh the snapshot
                self.best_accuracy = accuracy
                # state_dict() returns references to the live parameters; deep-copy
                # so later optimizer steps cannot overwrite the recorded weights.
                self.best_model_state = copy.deepcopy(self.net.state_dict())
                
    def test_model(self):
        self.net.eval()  # Switch to evaluation mode
        self.net.to(self.device)
        correct = 0
        total = 0
        val_dataiter = iter(self.train_loader)
        with torch.no_grad():
            for train_step in tqdm(range(1, len(val_dataiter) + 1),
                                    position=0,
                                    leave=True):
                batch = next(val_dataiter)
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch['label'].to(self.device)
                
                data_in, feat_lda, feat_pca = self.net(input_ids, labels, attention_mask)
                outputs = self.head(data_in)
                # outputs = self.net.backbone(input_ids, attention_mask).logits
                predicted = torch.argmax(outputs[:,:-1], dim=1)
                # print(predicted, labels)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f'Accuracy on validation set: {accuracy:.4f}')
        return accuracy    

    def save_best_model(self, filename):
        """Write ``self.best_model_state`` to ``filename`` with ``torch.save``.

        Nothing is written when no best state has been recorded yet
        (``self.best_model_state`` is ``None``).
        """
        best_state = self.best_model_state
        if best_state is None:
            return
        torch.save(best_state, filename)

    
    def grod_generate_pca(self, data_in, feat_lda, feat_pca, train_step, dataset_in_mu):
        """Sample synthetic points around the PCA-extreme rows of a batch.

        For every column of ``feat_pca`` the rows of ``data_in`` attaining the
        column maximum and minimum are selected, pushed ``self.alpha`` further
        along the unit direction pointing away from the running batch mean, and
        used as centres. ``self.nums_rounded`` samples are then drawn around each
        centre with standard deviation ``self.alpha / 3``.

        Args:
            data_in: 2-D tensor ``(N, D)``; rows are indexed by the argmax/argmin.
            feat_lda: only ``feat_lda.size(1)`` is read; it sets the number of
                rows of the returned mean.
            feat_pca: 2-D tensor ``(N, P)`` whose column-wise extremes pick rows
                of ``data_in``.
            train_step: when equal to 1 the running mean is reset to the mean of
                ``data_in``; otherwise it is blended with ``dataset_in_mu`` using
                ``self.stat_smooth``.
            dataset_in_mu: previous running mean. Either a ``(D,)`` vector or the
                tiled ``(b, D)`` tensor returned by an earlier call. Ignored when
                ``train_step == 1``.

        Returns:
            tuple: ``(samples, dataset_in_mu)``. ``samples`` has shape
            ``(2 * P * self.nums_rounded, D)``, max-centres first, then
            min-centres. ``dataset_in_mu`` is the updated mean tiled to
            ``(feat_lda.size(1), D)``.
        """
        # Rows of data_in at the extreme of each PCA component (detached: the
        # synthetic points must not back-propagate into the encoder).
        idx_max = torch.argmax(feat_pca, dim=0).to(data_in.device)
        idx_min = torch.argmin(feat_pca, dim=0).to(data_in.device)
        extreme_hi = data_in[idx_max].detach()
        extreme_lo = data_in[idx_min].detach()

        ### mu smoothing ###
        if train_step == 1:
            running_mu = torch.mean(data_in, dim=0)
        else:
            previous_mu = dataset_in_mu.detach()
            # The value returned by an earlier call is tiled to (b, D) with
            # identical rows. Collapse it back to one row so the update works
            # for any b, not only when feat_lda has a single column.
            if previous_mu.dim() > 1:
                previous_mu = previous_mu[0]
            running_mu = ((1 - self.stat_smooth) * torch.mean(data_in.detach(), dim=0)
                          + self.stat_smooth * previous_mu)
        ### mu smoothing ###

        # Tiled copy handed back to the caller (broadcast view, rows identical).
        dataset_in_mu = running_mu.unsqueeze(0).expand(feat_lda.size(1), -1)

        # Push each extreme point alpha further away from the running mean.
        direction_hi = F.normalize(extreme_hi - running_mu.unsqueeze(0), dim=1)
        direction_lo = F.normalize(extreme_lo - running_mu.unsqueeze(0), dim=1)
        centers_hi = torch.add(extreme_hi, self.alpha * direction_hi).detach()  # (P, D)
        centers_lo = torch.add(extreme_lo, self.alpha * direction_lo).detach()  # (P, D)
        mean_matrix = torch.cat((centers_hi, centers_lo), dim=0)  # (2P, D)

        std = 1 / 3 * self.alpha
        mu = mean_matrix.T.unsqueeze(2).to(self.device)  # (D, 2P, 1)
        # One (D, nums_rounded) noise block, shared by every centre.
        rand_data = torch.randn(mean_matrix.size(1), self.nums_rounded).to(self.device)
        gaussian_data = mu + std * rand_data.unsqueeze(1)  # (D, 2P, nums_rounded)
        nums = gaussian_data.size(1)
        nums_rounded = gaussian_data.size(2)
        # (2P * nums_rounded, D); samples of one centre are consecutive rows.
        reshaped_rounded_data = gaussian_data.permute(1, 2, 0).contiguous().view(
            nums * nums_rounded, mean_matrix.size(1))
        return reshaped_rounded_data, dataset_in_mu
        
    def grod_generate_lda(self, feat_lda, sub_datasets_in, sub_datasets_lda, sub_datasets_in_mu, sub_datasets_in_cov, sub_datasets_in_distances, lda_class):
        """Sample synthetic points around the LDA-extreme rows of the largest classes.

        The ``lda_class`` largest per-class subsets are processed. For each one,
        the rows attaining the column-wise max/min of its LDA features are
        selected from the matching input subset, pushed ``self.alpha`` away from
        the running class mean, and used as centres. The running per-class mean,
        inverse covariance and maximum Mahalanobis value are updated in place
        along the way (exponential smoothing with ``self.stat_smooth``; a class
        whose stored mean is still ~0 is initialised instead of smoothed).

        Fixes relative to the previous implementation:
          * the ``return`` sat inside the class loop, so only the single largest
            class was ever used whatever ``lda_class`` was;
          * only the LAST batch of 64 items of a subset was kept; all items are
            gathered now (identical when a subset has <= 64 items);
          * an empty selection returned ``None``; it now returns an empty sample
            tensor;
          * distances were computed against every class just to read one column.

        Args:
            feat_lda: only ``feat_lda.size(1)`` is meaningful here (number of LDA
                columns); kept for interface compatibility.
            sub_datasets_in: per-class datasets of input features, indexable and
                consumable by ``DataLoader``.  # assumes each item is a (D,) tensor - TODO confirm
            sub_datasets_lda: per-class datasets of LDA features, parallel to
                ``sub_datasets_in``.
            sub_datasets_in_mu: ``(n_cls, D)`` running class means (updated in place).
            sub_datasets_in_cov: ``(n_cls, D, D)`` running inverse covariances
                (updated in place).
            sub_datasets_in_distances: ``(n_cls,)`` running maxima of the value
                returned by ``self.mahalanobis`` (updated in place).
            lda_class: how many of the largest classes to use.

        Returns:
            tuple: ``(samples, sub_datasets_in_mu, sub_datasets_in_cov,
            sub_datasets_in_distances)``. ``samples`` has shape
            ``(2 * K * feat_dim * self.nums_rounded, D)`` for ``K`` processed
            classes (max-centres of all classes first, then min-centres), or
            ``(0, D)`` when no class could be processed.
        """
        subset_sizes = [len(subset) for subset in sub_datasets_lda]
        # Indices of the sub-datasets with the largest amount of data.
        top_indices = sorted(range(len(subset_sizes)),
                             key=lambda idx: subset_sizes[idx],
                             reverse=True)[:lda_class]

        centers_hi, centers_lo = [], []
        for i in top_indices:
            if subset_sizes[i] == 0:
                # Nothing to extrapolate from for an empty class.
                continue

            # Materialise the whole subset (all batches, in order).
            tensor_data_lda = torch.cat(
                list(DataLoader(sub_datasets_lda[i], batch_size=64, shuffle=False)), dim=0)
            tensor_data_in = torch.cat(
                list(DataLoader(sub_datasets_in[i], batch_size=64, shuffle=False)), dim=0)

            # Rows at the extreme of every LDA column.
            idx_hi = torch.argmax(tensor_data_lda, dim=0).to(tensor_data_in.device)
            idx_lo = torch.argmin(tensor_data_lda, dim=0).to(tensor_data_in.device)
            extreme_hi = tensor_data_in[idx_hi].detach().to(self.device)  # (feat_dim, D)
            extreme_lo = tensor_data_in[idx_lo].detach().to(self.device)  # (feat_dim, D)

            if tensor_data_in.size(0) > 1:
                mean = torch.mean(tensor_data_in, dim=0)
                # Small ridge keeps the covariance positive definite for Cholesky.
                cov0 = (self.calculate_covariance_matrix(tensor_data_in)
                        + 1e-4 * torch.eye(mean.size(0)).to(self.device)).detach()
                chol = torch.linalg.cholesky(cov0).detach()
                chol_inv = torch.linalg.inv(chol).detach()
                # Inverse of the SPD matrix from the inverse of its Cholesky factor.
                precision = torch.mm(chol_inv.t(), chol_inv).detach()

                ### mu and std smoothing ###
                first_update = bool(torch.max(torch.abs(sub_datasets_in_mu[i, :])) < 1e-7)
                if first_update:
                    sub_datasets_in_cov[i, :, :] = precision
                    sub_datasets_in_mu[i, :] = mean.detach()
                else:
                    sub_datasets_in_cov[i, :, :] = ((1 - self.stat_smooth) * precision.clone().to(self.device)
                                                    + self.stat_smooth * sub_datasets_in_cov[i, :, :].detach().clone())
                    sub_datasets_in_mu[i, :] = ((1 - self.stat_smooth) * mean.detach().clone().to(self.device)
                                                + self.stat_smooth * sub_datasets_in_mu[i, :].detach().clone())

                # Only class i is needed, so pass just its statistics.
                class_dists = self.mahalanobis(tensor_data_in,
                                               sub_datasets_in_mu[i:i + 1].clone(),
                                               sub_datasets_in_cov[i:i + 1].clone())[:, 0]
                max_dist = torch.max(class_dists).detach()
                if first_update:
                    sub_datasets_in_distances[i] = max_dist
                else:
                    sub_datasets_in_distances[i] = ((1 - self.stat_smooth) * max_dist.to(self.device).clone()
                                                    + self.stat_smooth * sub_datasets_in_distances[i].detach().clone())
                ### mu and std smoothing ###

            # Push the extreme points alpha further away from the running class mean.
            class_mu = sub_datasets_in_mu[i, :].clone().to(self.device).unsqueeze(0)
            direction_hi = F.normalize(extreme_hi - class_mu, dim=1).detach()
            direction_lo = F.normalize(extreme_lo - class_mu, dim=1).detach()
            centers_hi.append(extreme_hi + self.alpha * direction_hi)  # (feat_dim, D)
            centers_lo.append(extreme_lo + self.alpha * direction_lo)  # (feat_dim, D)

        if not centers_hi:
            no_samples = torch.empty((0, sub_datasets_in_mu.size(-1)), device=self.device)
            return no_samples, sub_datasets_in_mu, sub_datasets_in_cov, sub_datasets_in_distances

        mean_matrix = torch.cat(centers_hi + centers_lo, dim=0)  # (num, D)
        std = 1 / 3 * self.alpha
        mu = mean_matrix.T.unsqueeze(2).to(self.device)  # (D, num, 1)
        # One (D, nums_rounded) noise block, shared by every centre.
        rand_data = torch.randn(mean_matrix.size(1), self.nums_rounded).to(self.device)
        gaussian_data = mu + std * rand_data.unsqueeze(1)  # (D, num, nums_rounded)
        nums = gaussian_data.size(1)
        nums_rounded = gaussian_data.size(2)
        reshaped_rounded_data = gaussian_data.permute(1, 2, 0).contiguous().view(
            nums * nums_rounded, mean_matrix.size(1))  # (num * nums_rounded, D)

        return reshaped_rounded_data, sub_datasets_in_mu, sub_datasets_in_cov, sub_datasets_in_distances
    
    def mahalanobis(self, x, support_mean, inv_covmat): #(n,d), (k,d), (k,d,d)
        """Quadratic form ``(x - mu_c)^T M_c (x - mu_c)`` for every sample/class pair.

        No square root is taken, so with ``M_c`` an inverse covariance the result
        is the *squared* Mahalanobis distance.

        Args:
            x: ``(n, d)`` samples; moved to ``inv_covmat``'s device. Gradients
                can flow through ``x``.
            support_mean: ``(k, d)`` class means (used detached).
            inv_covmat: ``(k, d, d)`` per-class matrices (used detached).

        Returns:
            Tensor of shape ``(n, k)``; entry ``[j, c]`` is the value for sample
            ``j`` and class ``c``. With ``k == 0`` an ``(n, 0)`` tensor is returned.
        """
        x = x.to(inv_covmat.device)
        support_mean = support_mean.to(inv_covmat.device)

        n_classes = inv_covmat.size(0)
        if n_classes == 0:
            return x.new_empty((x.size(0), 0))

        maha_dists = []
        for i in range(n_classes):
            centered = x - support_mean[i].detach().unsqueeze(0)     # (n, d)
            projected = torch.matmul(centered, inv_covmat[i].detach())  # (n, d)
            # Row-wise dot product == diagonal of projected @ centered.T, without
            # materialising the (n, n) matrix.
            maha_dists.append((projected * centered).sum(dim=1))

        return torch.stack(maha_dists, dim=1)
    
    def calculate_covariance_matrix(self, data):
        """Sample covariance of the rows of a 2-D tensor.

        Args:
            data: ``(n, d)`` tensor, one observation per row.

        Returns:
            ``(d, d)`` tensor ``X_c^T X_c / (n - 1 + 1e-7)`` where ``X_c`` is
            ``data`` with the column means removed. The ``1e-7`` keeps the
            denominator non-zero when ``n == 1``.
        """
        n_samples = data.size(0)
        deviations = data - data.mean(dim=0, keepdim=True)
        scatter = torch.mm(deviations.t(), deviations)
        return scatter / (n_samples - 1 + 1e-7)


In [None]:
# Load the GROD configuration from YAML.
yaml_file_path = 'grod copy.yml'
loaded_parameters = load_parameters_from_yaml(yaml_file_path)
config_grod = loaded_parameters

# Prefer the second GPU as before, but fall back to the first GPU (or the CPU)
# when it does not exist: 'cuda:1' is an invalid device on a single-GPU machine
# even though torch.cuda.is_available() is True there.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Wrap the backbone `model` (defined in an earlier cell) and train it.
model_grod = GRODNet(model, 1, config_grod['dataset']['num_classes']).to(device)
trainer = GRODTrainer_Soft_Label(model_grod, train_dataloader, config_grod)
trainer.train(config_grod['optimizer']['num_epochs'])

# Save the best model state
trainer.save_best_model('best_model_grod2_clinc_0.001.ckpt')

In [None]:
# Load the saved state dict. map_location='cpu' makes the checkpoint loadable on
# machines that lack the GPU it was saved from; load_state_dict below copies the
# values into the model's own parameters, wherever those live.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
state_dict = torch.load('best_model_grod2_clinc_0.001.ckpt', map_location='cpu')
model_grod = GRODNet(model, 1, config_grod['dataset']['num_classes'])
# Copy the loaded weights into the freshly built model.
model_grod.load_state_dict(state_dict)
print(model_grod)

In [14]:
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from numpy.linalg import norm, pinv
from scipy.special import logsumexp
from sklearn.covariance import EmpiricalCovariance
from tqdm import tqdm

# run
from typing import Any
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader


class BasePostprocessor:
    """Baseline postprocessor: maximum softmax probability as the confidence.

    Subclasses override ``setup`` and ``postprocess``; ``inference`` drives
    ``postprocess`` over a data loader and collects the results.
    """

    def __init__(self, config):
        self.config = config

    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
        """No-op hook; subclasses fit their statistics here."""
        pass

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any, attention,
                    alpha=None, w=None, b=None, u=None, NS=None):
        """Return ``(pred, conf)`` from the softmax of ``net(...).logits``.

        Args:
            net: called as ``net(input_ids=data, attention_mask=attention)``; its
                result must expose ``.logits``.
            data: token ids passed as ``input_ids``.
            attention: attention mask.
            alpha, w, b, u, NS: unused here. They are accepted because
                ``inference`` always forwards them, which previously made the
                base implementation raise ``TypeError``.

        Returns:
            tuple: predicted class index and its softmax probability per sample.
        """
        output = net(input_ids=data, attention_mask=attention)
        score = torch.softmax(output.logits, dim=1)
        conf, pred = torch.max(score, dim=1)
        return pred, conf

    def inference(self,
                  net: nn.Module,
                  data_loader: DataLoader,
                  alpha, w, b, u, NS,
                  progress: bool = True):
        """Run ``postprocess`` on every batch and gather the results.

        Each batch must be a mapping with ``'input_ids'``, ``'attention_mask'``
        and ``'label'`` tensors.

        Args:
            net: model to evaluate; moved to the GPU when one is available,
                otherwise left on / moved to the CPU.
            data_loader: iterable of batches.
            alpha, w, b, u, NS: forwarded unchanged to ``postprocess``.
            progress: show a tqdm progress bar when True.

        Returns:
            tuple of numpy arrays ``(pred_list, conf_list, label_list)``;
            predictions and labels are cast to ``int``.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        net = net.to(device)  # move once instead of once per batch

        pred_list, conf_list, label_list = [], [], []
        for batch in tqdm(data_loader, disable=not progress):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch['label'].to(device)
            pred, conf = self.postprocess(net, input_ids, attention_mask,
                                          alpha, w, b, u, NS)

            pred_list.append(pred.cpu())
            conf_list.append(conf.cpu())
            label_list.append(labels.cpu())

        # convert values into numpy array
        pred_list = torch.cat(pred_list).numpy().astype(int)
        conf_list = torch.cat(conf_list).numpy()
        label_list = torch.cat(label_list).numpy().astype(int)

        return pred_list, conf_list, label_list


class GRODPostprocessor(BasePostprocessor):
    """OOD scoring for a GROD model from feature residuals plus an energy term.

    ``setup`` fits, on in-distribution training data, the origin ``u``, the
    residual subspace ``NS`` (all but the ``dim`` leading principal directions of
    the centred features) and the scale ``alpha``. ``postprocess`` then scores a
    batch as ``logsumexp(id_scores) - alpha * ||(feature - u) @ NS||``.

    The model is expected to expose ``net.backbone.bert`` (element ``[1]`` of its
    output is used as the feature) and ``net.head`` whose LAST output column is
    the extra OOD class.
    """

    def __init__(self, config):
        super().__init__(config)
        self.args = self.config['postprocessor']['postprocessor_args']
        self.args_dict = self.config['postprocessor']['postprocessor_sweep']
        self.dim = self.args['dim']
        self.setup_flag = False

    @staticmethod
    @torch.no_grad()
    def _features_and_id_scores(net, data, attention):
        """Return ``(feature, id_scores)`` for one batch.

        ``id_scores`` is the softmax over the in-distribution columns (all but
        the last). Rows whose argmax over ALL columns is the last (OOD) column
        are overwritten with the constant 0.1.
        """
        feature = net.backbone.bert(data, attention)[1]
        logit = net.head(feature)
        id_scores = torch.softmax(logit[:, :-1], dim=1)
        flagged_ood = torch.softmax(logit, dim=1).argmax(dim=1) == logit.size(1) - 1
        id_scores[flagged_ood] = 0.1
        return feature, id_scores

    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict=None):
        """Fit ``alpha``, ``u`` and ``NS`` on in-distribution data (first call only).

        Args:
            net: GROD model; put in eval mode and moved to the GPU if available.
            id_loader_dict: iterable of batches with ``'input_ids'`` and
                ``'attention_mask'`` tensors.
            ood_loader_dict: unused; accepted for compatibility with the base
                class signature.

        Returns:
            tuple: ``(alpha, w, b, u, NS)``, where ``w``/``b`` are the head's
            weight and bias without the OOD row. Later calls return the cached
            values (also after ``set_hyperparam``, since ``setup_flag`` stays set).
        """
        if not self.setup_flag:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            net.eval()
            net.to(device)
            with torch.no_grad():
                # Head parameters of the in-distribution classes only.
                self.w = net.head.weight[:-1, :].detach().cpu().numpy()
                self.b = net.head.bias[:-1].detach().cpu().numpy()
                print('Extracting id training feature')
                feature_id_train = []
                logit_id_train = []
                for batch in tqdm(id_loader_dict,
                                  desc='Setup: ',
                                  position=0,
                                  leave=True):
                    data = batch['input_ids'].to(device)
                    attention_mask = batch["attention_mask"].to(device)
                    feature, id_scores = self._features_and_id_scores(
                        net, data, attention_mask)
                    feature_id_train.append(feature.cpu().numpy())
                    logit_id_train.append(id_scores.cpu().numpy())
                feature_id_train = np.concatenate(feature_id_train, axis=0)
                logit_id_train = np.concatenate(logit_id_train, axis=0)

            # Origin at which the linear head outputs zero (least-squares sense).
            self.u = -np.matmul(pinv(self.w), self.b)
            ec = EmpiricalCovariance(assume_centered=True)
            ec.fit(feature_id_train - self.u)
            # The covariance is symmetric: eigh guarantees real eigenvalues and
            # orthonormal eigenvectors, which a general eig does not.
            eig_vals, eigen_vectors = np.linalg.eigh(ec.covariance_)
            # Residual subspace: everything except the `dim` largest directions.
            residual_idx = np.argsort(eig_vals * -1)[self.dim:]
            self.NS = np.ascontiguousarray(eigen_vectors[:, residual_idx])

            vlogit_id_train = norm(np.matmul(feature_id_train - self.u,
                                             self.NS),
                                   axis=-1)

            # Scale that matches the mean residual norm to the mean max ID score.
            self.alpha = logit_id_train.max(
                axis=-1).mean() / vlogit_id_train.mean()
            print(f'{self.alpha=:.4f}')

            self.setup_flag = True
        return self.alpha, self.w, self.b, self.u, self.NS

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any, attention, alpha, w, b, u, NS):
        """Score one batch.

        Args:
            net: GROD model, already on the same device as ``data``.
            data: token ids.
            attention: attention mask.
            alpha, u, NS: values returned by ``setup``.
            w, b: accepted for interface compatibility; not used.

        Returns:
            tuple: ``(pred, score)``. ``pred`` is the argmax over the
            in-distribution scores (CPU tensor); ``score`` is a CPU tensor where
            larger means more in-distribution.
        """
        feature_ood, id_scores = self._features_and_id_scores(net, data, attention)
        logit_ood = id_scores.cpu()
        feature_ood = feature_ood.cpu()

        _, pred = torch.max(logit_ood, dim=1)
        energy_ood = logsumexp(logit_ood.numpy(), axis=-1)
        vlogit_ood = norm(np.matmul(feature_ood.numpy() - u, NS),
                          axis=-1) * alpha
        score_ood = -vlogit_ood + energy_ood
        return pred, torch.from_numpy(score_ood)

    def set_hyperparam(self, hyperparam: list):
        """Set ``dim`` (number of principal directions kept) from ``hyperparam[0]``."""
        self.dim = hyperparam[0]

    def get_hyperparam(self):
        """Return the current ``dim``."""
        return self.dim


In [None]:
# Reload the GROD configuration used for post-processing.
yaml_file_path = 'grod copy.yml'
config_grod = loaded_parameters = load_parameters_from_yaml(yaml_file_path)

# Fit the post-processor statistics on the training data ...
alpha, w, b, u, NS = GRODPostprocessor(config_grod).setup(
    model_grod, train_dataloader)

# ... then score the test set with them.
pred_list, conf_list, label_list = GRODPostprocessor(config_grod).inference(
    model_grod, test_dataloader, alpha, w, b, u, NS)

# Presumably reports [fpr, auroc, aupr_in, aupr_out, accuracy] - confirm against
# compute_all_metrics.
compute_all_metrics(conf_list, label_list, pred_list)