Rewriting the old strategy to match the new interface while keeping backward compatibility.

Make the `query` method:
1. a static method, adapted to use the new `ModelHandler` and `DataSetHandler` interfaces for model- and dataset-management operations;
2. return a tuple of two lists:
    * 1st list: the indices of the samples selected from the unlabeled dataset;
    * 2nd list: additional information (for BadgeSampling this may not be needed, so it can be an empty list).
3. 

In [7]:
import gc
import numpy as np
import torch,time,tqdm,pickle,os,itertools
import torch.nn.functional as F
from sklearn.cluster import kmeans_plusplus
from typing import List, Tuple, Dict, overload
from torch import nn
from abc import ABC, abstractmethod
from torch.utils.data import DataLoader,Dataset
from collections import Counter
# from .strategy import Strategy
# from pcdet.models import load_data_to_gpu

In [9]:
class ModelHandler():
    '''Model-side helpers shared by the active-learning query strategies.'''

    def get_grad_embedding(self, model, probs, feat):
        '''
        Create a gradient-style embedding from per-row probabilities and features.

        Each row of ``feat`` is scaled by ``1 - p`` (its own probability) and
        concatenated with its negation, giving ``[e, -e]``. For rows whose
        probability is below 0.5 the two halves are swapped to ``[-e, e]``.

        Args:
            model: unused; kept for interface compatibility.
            probs: one probability per row of ``feat``. It is reshaped to a
                column, so shapes ``(B,)`` and ``(B, 1)`` are both accepted.
            feat: 2-D feature tensor ``(B, D)``. NOTE(review): the 2-D layout
                is required by the row-wise broadcasting below; confirm with callers.

        Returns:
            Tensor of shape ``(B, 2 * D)`` with the halves swapped where ``p < 0.5``.
        '''
        # One probability per row, as a column so it broadcasts across the feature dim.
        probs = probs.reshape(-1, 1)
        embeddings_pos = feat * (1 - probs)
        embeddings = torch.cat((embeddings_pos, -embeddings_pos), dim=-1)

        # A 1-D row mask is required: a (B, 1) mask does not match the (B, 2D)
        # embedding's second dimension and cannot be used for row selection.
        swap = (probs < 0.5).reshape(-1)
        final_embeddings = embeddings.clone()
        final_embeddings[swap] = torch.cat((-embeddings_pos[swap], embeddings_pos[swap]), dim=-1)

        # Return the swapped version; previously the swap was computed and then discarded.
        return final_embeddings

    def enable_dropout(self, model):
        '''
        Put every submodule whose class name starts with 'Dropout' into training mode.

        All other modules keep their current mode, so calling this after
        ``model.eval()`` leaves only the dropout layers stochastic. Prints how
        many layers were switched.
        '''
        i = 0
        for m in model.modules():
            if m.__class__.__name__.startswith('Dropout'):
                i += 1
                m.train()
        print('**found and enabled {} Dropout layers for random sampling**'.format(i))


class DataSetHandler():
    '''Bundle the labeled/unlabeled datasets with the settings used to load them.'''

    unlabeled_idcs : List[int]  # positions 0 .. len(unlabeled_dataset) - 1
    labeled_idcs : List[int]    # positions 0 .. len(labeled_dataset) - 1

    def __init__(self, unlabeled_dataset: Dataset, labeled_dataset: Dataset, batch_size: int, num_workers: int,
                 collate_fn=None):
        '''
        Args:
            unlabeled_dataset: map-style dataset of samples that may be queried.
            labeled_dataset: map-style dataset of already-labeled samples.
            batch_size: batch size of the loaders built by this handler.
            num_workers: worker processes of the loaders built by this handler.
            collate_fn: optional batch-collation callable passed to every
                DataLoader built here. ``None`` (the default) keeps PyTorch's
                default collation, i.e. the previous behavior. Datasets with
                variable-size samples need their own function (BadgeSampling's
                gradient pass, for example, uses the dataset's ``collate_batch``).
        '''
        self.unlabeled_dataset = unlabeled_dataset
        self.labeled_dataset = labeled_dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.unlabeled_idcs = list(range(len(unlabeled_dataset)))
        self.labeled_idcs = list(range(len(labeled_dataset)))

    def _make_loader(self, dataset):
        '''Build a non-shuffling DataLoader over ``dataset`` with this handler's settings.'''
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, collate_fn=self.collate_fn)

    def get_unlabeled_loader(self):
        '''Return a fresh, deterministic-order DataLoader over the unlabeled dataset.'''
        return self._make_loader(self.unlabeled_dataset)

    def get_labeled_loader(self):
        '''Return a fresh, deterministic-order DataLoader over the labeled dataset.'''
        return self._make_loader(self.labeled_dataset)
 


class StrategyNewInterface(ABC):
    '''
    Abstract base for query strategies exposed through a single static ``query``.

    ``query`` is abstract, so a subclass that does not override it cannot be
    instantiated.
    '''

    def __init__(self, *args, **kwargs):
        # Accepts and ignores any constructor arguments.
        pass

    # No @typing.overload here: overload is only for type-checker stub
    # signatures. Applied to a real definition, it replaces the method with a
    # stub and drops the @abstractmethod flag, leaving the ABC unenforced.
    @staticmethod
    @abstractmethod
    def query(model : nn.Module, handler: ModelHandler, dataset_handler: DataSetHandler,
              training_config: Dict, query_size: int, device,
              *args, **kwargs) -> Tuple[List, List]:
        '''
        Select samples from the unlabeled dataset.

        Args:
            model: the model used to score unlabeled samples.
            handler: model-side helper operations.
            dataset_handler: access to the labeled/unlabeled datasets and loaders.
            training_config: strategy-specific settings.
            query_size: number of samples to select.
            device: device to run computations on.

        Returns:
            ``(selected, info)``: ``selected`` holds the indices of the chosen
            unlabeled samples; ``info`` holds strategy-specific extra information
            and may be empty.
        '''
        raise NotImplementedError



class BadgeSampling(StrategyNewInterface):
    '''
    BADGE query strategy for the detector, exposed through the new static interface.

    The strategy makes two passes over the unlabeled set:
      1. Inference (eval mode, dropout re-enabled). The argmax RPN class
         predictions of every sample are recorded as pseudo-labels.
      2. One backward pass per sample (batch_size=1, train mode). The gradient
         of the RPN classification loss against those pseudo-labels, taken with
         respect to ``model.dense_head.conv_cls.weight``, is that sample's embedding.

    k-means++ seeding over the flattened embeddings then chooses the frames to label.
    '''

    @staticmethod
    def query(model: nn.Module, handler: ModelHandler, dataset_handler: DataSetHandler,
              training_config: Dict, query_size: int, device,
              *args, **kwargs) -> Tuple[List, List]:
        '''
        Select up to ``query_size`` unlabeled frames with BADGE.

        Args:
            model: detector exposing ``dense_head.num_class``,
                ``dense_head.get_cls_layer_loss`` and ``dense_head.conv_cls``.
            handler: provides ``enable_dropout``.
            dataset_handler: provides the unlabeled dataset, its loader and
                ``num_workers``. The dataset must expose ``collate_batch`` and
                ``sample_id_list``.
            training_config: optional keys (with defaults): ``rank`` (0),
                ``cur_epoch`` (0), ``leave_pbar`` (True), ``save_points``
                (no-op ``(frame_id, pred_dict)`` callable) and
                ``active_label_dir`` ('./'). A
                ``grad_embeddings_epoch_{cur_epoch}.pkl`` file in
                ``active_label_dir`` is loaded instead of recomputing.
            query_size: number of frames to select. It is clamped to the number
                of unlabeled samples; values <= 0 select nothing.
            device: device the pseudo-labels are moved to for the loss.

        Returns:
            List of selected frame ids (entries of ``sample_id_list``).
            NOTE(review): this is a flat list, not the ``Tuple[List, List]``
            the signature declares. ``StrategyAdapter.query`` passes it
            unchanged to old-interface callers, so switching to
            ``(indices, info)`` must be done together with the adapter.
        '''
        rank = training_config.get('rank', 0)
        cur_epoch = training_config.get('cur_epoch', 0)
        leave_pbar = training_config.get('leave_pbar', True)
        save_points = training_config.get('save_points', lambda frame_id, pred_dict: None)
        active_label_dir = training_config.get('active_label_dir', './')

        unlabelled_set = dataset_handler.unlabeled_dataset
        if query_size <= 0 or len(unlabelled_set) == 0:
            return []
        unlabelled_loader = dataset_handler.get_unlabeled_loader()

        # The inference pass unpacks model(...) into 2 values, while the gradient
        # pass (after model.train()) unpacks 3. The inference pass therefore
        # presumably needs eval mode: switch to eval, then re-enable only dropout.
        model.eval()
        handler.enable_dropout(model)

        cache_path = os.path.join(active_label_dir, f'grad_embeddings_epoch_{cur_epoch}.pkl')
        if os.path.isfile(cache_path):
            print(f'found {cur_epoch} epoch grad embeddings... start resuming...')
            rpn_grad_embeddings = BadgeSampling._load_grad_embeddings(cache_path)
        else:
            rpn_preds_results = BadgeSampling._collect_rpn_pseudo_labels(
                model, unlabelled_loader, save_points, rank, cur_epoch, leave_pbar)
            rpn_grad_embeddings = BadgeSampling._collect_rpn_grad_embeddings(
                model, unlabelled_set, unlabelled_loader.sampler, dataset_handler.num_workers,
                rpn_preds_results, device, rank, cur_epoch, leave_pbar)

        num_sample = rpn_grad_embeddings.shape[0]
        rpn_grad_embeddings = rpn_grad_embeddings.reshape(num_sample, -1)

        start_time = time.time()
        # kmeans_plusplus requires n_clusters <= n_samples.
        _, selected_idx = kmeans_plusplus(rpn_grad_embeddings.detach().cpu().numpy(),
                                          n_clusters=min(query_size, num_sample), random_state=0)
        print("--- kmeans++ running time: %s seconds for rpn grads---" % (time.time() - start_time))

        model.zero_grad()
        model.eval()
        return [unlabelled_set.sample_id_list[idx] for idx in selected_idx]

    @staticmethod
    def _load_grad_embeddings(cache_path):
        '''Load cached per-sample gradient embeddings and return them as one stacked tensor.'''
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # active_label_dir at files this pipeline produced itself.
        with open(cache_path, 'rb') as f:
            cached = pickle.load(f)
        # NOTE(review): no writer for this cache is visible here, so the payload
        # format is an assumption: either already stacked (tensor/array) or a
        # sequence of per-sample gradients. Confirm against whatever writes it.
        if isinstance(cached, (list, tuple)):
            return torch.stack([torch.as_tensor(g) for g in cached], 0)
        return torch.as_tensor(cached)

    @staticmethod
    def _collect_rpn_pseudo_labels(model, loader, save_points, rank, cur_epoch, leave_pbar):
        '''Run inference over ``loader``; return all samples' argmax RPN class predictions, concatenated on CPU.'''
        pbar = None
        if rank == 0:
            pbar = tqdm.tqdm(total=len(loader), leave=leave_pbar,
                             desc='evaluating_unlabelled_set_epoch_%d' % cur_epoch, dynamic_ncols=True)
        rpn_preds_results = []
        pred_dicts = rpn_preds = None
        for unlabelled_batch in loader:
            with torch.no_grad():
                # load_data_to_gpu(unlabelled_batch)
                pred_dicts, _ = model(unlabelled_batch)
                for batch_inx, pred_dict in enumerate(pred_dicts):
                    save_points(unlabelled_batch['frame_id'][batch_inx], pred_dict)
                # No batch mask is applied: the RPN output is read from the first
                # dict only, and its leading dim is taken as the batch size.
                rpn_preds = pred_dicts[0]['rpn_preds']
                batch_size = rpn_preds.shape[0]
                rpn_preds = torch.argmax(rpn_preds.view(batch_size, -1, model.dense_head.num_class), -1)
                rpn_preds_results.append(rpn_preds.cpu())
            if pbar is not None:
                pbar.update()
                pbar.refresh()
        if pbar is not None:
            pbar.close()
        # Drop the last batch's device tensors before releasing cached GPU memory.
        del pred_dicts, rpn_preds
        torch.cuda.empty_cache()
        print('start stacking cls and reg results as gt...')
        return torch.cat(rpn_preds_results, 0)

    @staticmethod
    def _collect_rpn_grad_embeddings(model, unlabelled_set, sampler, num_workers,
                                     rpn_preds_results, device, rank, cur_epoch, leave_pbar):
        '''Return stacked per-sample gradients of the RPN cls loss w.r.t. ``dense_head.conv_cls.weight``.'''
        print('retrieving grads on the training mode...')
        model.train()
        grad_loader = DataLoader(
            unlabelled_set, batch_size=1, pin_memory=True, num_workers=num_workers,
            shuffle=False, collate_fn=unlabelled_set.collate_batch,
            drop_last=False, sampler=sampler, timeout=0
        )
        pbar = None
        if rank == 0:
            pbar = tqdm.tqdm(total=len(grad_loader), leave=leave_pbar,
                             desc='inf_grads_unlabelled_set_epoch_%d' % cur_epoch, dynamic_ncols=True)
        rpn_grad_embedding_list = []
        for cur_it, unlabelled_batch in enumerate(grad_loader):
            # load_data_to_gpu(unlabelled_batch)
            pred_dicts, _, _ = model(unlabelled_batch)
            # batch_size=1 and the same sampler as the inference pass, so row
            # ``cur_it`` of the pseudo-labels belongs to this sample (this assumes
            # the sampler yields the same order on both passes).
            new_data = {'box_cls_labels': rpn_preds_results[cur_it, :].to(device).unsqueeze(0),
                        'cls_preds': pred_dicts['rpn_preds']}
            # Classification loss only: the RPN head has no dropout, so no
            # MC-dropout labels are available for regression.
            loss = model.dense_head.get_cls_layer_loss(new_data=new_data)[0]
            model.zero_grad()
            loss.backward()
            rpn_grad_embedding_list.append(model.dense_head.conv_cls.weight.grad.clone().detach().cpu())
            if pbar is not None:
                pbar.update()
                pbar.refresh()
        if pbar is not None:
            pbar.close()
        rpn_grad_embeddings = torch.stack(rpn_grad_embedding_list, 0)
        del rpn_grad_embedding_list
        gc.collect()
        return rpn_grad_embeddings


### Adapter Pattern for Compatibility

To ensure compatibility with the old interface, I created a new adapter class. This class will adapt the new interface to be compatible with the existing code that expects the old interface. 

Calling `query` hides from the user whether the new or the old interface is in use: the adapter takes the old-interface parameters and uses them to call the new interface's `query` function.

This keeps the code backward compatible.

In [11]:
class StrategyAdapter:
    '''
    Expose a new-interface strategy through the old ``query(leave_pbar, cur_epoch)`` call.

    Holds the model, builds the ``ModelHandler`` and ``DataSetHandler`` that the
    new static ``query`` expects, and translates the old-style arguments into
    its ``training_config``.
    '''

    def __init__(self, model : nn.Module, new_strategy: StrategyNewInterface,
                 unlabeled_dataset: Dataset, labeled_dataset: Dataset,
                 batch_size: int, num_workers: int, cfg):
        '''
        Args:
            model: the model passed to the strategy on every query.
            new_strategy: strategy implementing the new static ``query``.
            unlabeled_dataset, labeled_dataset: datasets wrapped in a ``DataSetHandler``.
            batch_size, num_workers: loader settings for the ``DataSetHandler``.
            cfg: object with a ``get`` method. Its optional ``'training_config'``
                entry (a mapping) supplies ``query_size`` (default 10) and the
                strategy's settings.
        '''
        self.model = model
        self.new_strategy = new_strategy
        self.handler = ModelHandler()
        self.dataset_handler = DataSetHandler(unlabeled_dataset, labeled_dataset, batch_size, num_workers)
        self.cfg = cfg

    def query(self, leave_pbar=True, cur_epoch=None):
        '''
        Run the wrapped strategy with old-interface arguments.

        Args:
            leave_pbar: stored as ``training_config['leave_pbar']``.
            cur_epoch: stored as ``training_config['cur_epoch']`` when not None;
                otherwise any value already in the config (or the strategy's own
                default) is used.

        Returns:
            Whatever the wrapped strategy's ``query`` returns, unchanged.
        '''
        # Work on a copy so per-call overrides do not leak back into self.cfg
        # and persist across calls.
        training_config = dict(self.cfg.get('training_config', {}))
        query_size = training_config.get('query_size', 10)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        training_config["leave_pbar"] = leave_pbar
        # Storing None would shadow the default that BadgeSampling reads via
        # training_config.get('cur_epoch', 0), and None then fails its '%d' formatting.
        if cur_epoch is not None:
            training_config["cur_epoch"] = cur_epoch

        # Expected (optional) keys in training_config:
        #   "rank", "active_label_dir", "cur_epoch", "leave_pbar",
        #   "save_points" (callable(frame_id, pred_dict)), "query_size".
        return self.new_strategy.query(
            self.model, self.handler, self.dataset_handler,
            training_config, query_size, device)