## Online Continual Few-Shot Domain Adaptation: Task-IL Method, Naive Strategy

### Import Libraries

In [1]:
import os
import sys
from subprocess import call
import datetime
import logging 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import json
from dotmap import DotMap
from sklearn import metrics
from tqdm import tqdm
import collections
from pprint import pprint

import random
import shutil
import socket

from PIL import Image
from scipy import stats
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader

from CL.utils import (AverageMeter, datautils, is_div, per, reverse_domain,
                       torchutils, utils)

from CL.models import (CosineClassifier, MemoryBank, SSDALossModule,
                        compute_variance, loss_info, torch_kmeans,
                        update_data_memory)
from CL.utils import check_pretrain_dir, load_json, process_config, set_default

#from pcs.agents import BaseAgent

import json
from dotmap import DotMap


import torch.backends.cudnn as cudnn
from pcs_cl.utils import print_info, torchutils
from torch.utils.tensorboard import SummaryWriter

In [2]:
# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous (see the
# PyTorch "CUDA semantics" notes), so a device-side error is reported at the
# call that triggered it. Useful for debugging; it slows execution.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [3]:
# Record which torchvision build this notebook ran with.
print(torchvision.__version__)

0.6.0a0+35d732a


### Check CUDA status

In [4]:
def print_cuda_statistics(nvidia_smi=True, output=print):
    """Report Python / PyTorch / CUDA versions and the visible GPU devices.

    Args:
        nvidia_smi: If True, additionally run ``nvidia-smi`` to list per-GPU
            memory usage. The table itself is written by the subprocess
            directly to stdout, not through ``output``.
        output: Callable that receives each formatted line (e.g. ``print``
            or ``logger.info``).

    Returns:
        None.
    """
    output(f"Python VERSION: {sys.version}")
    output(f"pytorch VERSION: {torch.__version__}")
    output(f"CUDA VERSION: {torch.version.cuda}")
    output(f"CUDNN VERSION: {torch.backends.cudnn.version()}")

    # get_device_name / current_device raise when no GPU is visible, so only
    # query per-device details when CUDA is actually usable.
    if torch.cuda.is_available():
        output(f"Device NAME: {torch.cuda.get_device_name(0)}")
        output(f"Number CUDA Devices: {torch.cuda.device_count()}")
        output(f"Available devices: {torch.cuda.device_count()}")
        output(f"current CUDA Device: {torch.cuda.current_device()}")
    else:
        output("No CUDA device available")

    if nvidia_smi:
        # Send the header through `output` like every other line, so it goes
        # to the logger when one is passed in.
        output("nvidia-smi:")
        try:
            call(
                [
                    "nvidia-smi",
                    "--format=csv",
                    "--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free",
                ]
            )
        except FileNotFoundError:
            # Diagnostic helper: a missing binary should not abort the caller.
            output("nvidia-smi not found on PATH")
        

# Print the environment report. print_cuda_statistics has no return statement,
# so `output` is bound to None here; the assignment has no further use.
output = print_cuda_statistics(nvidia_smi=True, output=print)

Python VERSION: 3.6.13 |Anaconda, Inc.| (default, Jun  4 2021, 14:25:59) 
[GCC 7.5.0]
pytorch VERSION: 1.5.1
CUDA VERSION: 10.2
CUDNN VERSION: 7605
Device NAME: Tesla V100-PCIE-16GB
Number CUDA Devices: 1
Available devices: 1
current CUDA Device: 0
nvidia-smi:


In [5]:
!nvidia-smi

Sun Oct 16 13:12:29 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-PCIE...  Off  | 00000000:2F:00.0 Off |                    0 |
| N/A   46C    P0    37W / 250W |  10561MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

### Configuration JSON file

In [6]:
#json

def load_json(f_path):
    """Read and decode a JSON file.

    Args:
        f_path: Path to the JSON file.

    Returns:
        The decoded Python object (dict, list, ...).
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
    with open(f_path, "r", encoding="utf-8") as f:
        return json.load(f)
    
    
def save_json(obj, f_path):
    """Serialize ``obj`` to ``f_path`` as indented, UTF-8 encoded JSON.

    Args:
        obj: JSON-serializable object.
        f_path: Destination path (overwritten).
    """
    # ensure_ascii=False writes non-ASCII characters literally, so the file
    # encoding must be pinned; otherwise it depends on the locale and can
    # raise UnicodeEncodeError (e.g. under a C/POSIX locale).
    with open(f_path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=4)


        
def process_config_path(config_path, override_dotmap=None):
    """Load the JSON file at ``config_path`` and pass it to ``process_config``.

    Args:
        config_path: Path to a JSON configuration file.
        override_dotmap: Forwarded unchanged to ``process_config``.

    Returns:
        The result of ``process_config`` on the loaded JSON.
    """
    return process_config(load_json(config_path), override_dotmap=override_dotmap)
        
def exist_key(k):
    """Return whether a config value should be treated as "set".

    Booleans (including ``False``) always count as set. ``None`` and an
    empty ``DotMap`` (what DotMap yields for a missing key) do not. Any
    other value counts as set.
    """
    if isinstance(k, bool):
        return True
    if k is None:
        return False
    return not (isinstance(k, DotMap) and len(k) == 0)

    
def set_default(cur_config, name, value=None, callback=None):
    """Fill ``cur_config[name]`` with a default when it is not already set.

    "Set" is decided by ``exist_key``. When the key is missing, the default is
    chosen in this order:
      1. ``value`` if it is not None;
      2. otherwise the value of the sibling key ``callback``, which must exist;
      3. otherwise ``None``.

    Args:
        cur_config: Mapping-like config node (a DotMap in this notebook).
        name: Key to check / fill.
        value: Explicit default.
        callback: Name of another key in ``cur_config`` to copy from.

    Returns:
        The (possibly updated) value of ``cur_config[name]``.

    Raises:
        AssertionError: if ``callback`` is used but names a missing key.
    """
    if not exist_key(cur_config[name]):
        if value is not None:
            cur_config[name] = value
        elif callback is not None:
            assert exist_key(cur_config[callback]), (
                f"set_default: cannot default '{name}' from missing key '{callback}'"
            )
            cur_config[name] = cur_config[callback]
        else:
            # No explicit value and no fallback key: store an explicit None.
            # (The former `raise NotImplementedError` branch was unreachable.)
            cur_config[name] = None
    return cur_config[name]

In [7]:
def adjust_config(config):
    """Fill missing configuration entries with defaults and derive clustering lists.

    ``config`` is mutated in place and also returned. Entries that are already
    set (per ``set_default`` / ``exist_key``) are left untouched.

    Args:
        config: Configuration DotMap produced by ``process_config``.

    Returns:
        The same ``config`` object.

    Raises:
        AssertionError: if ``loss_params.weight`` does not have one entry per
            loss, or a ``callback`` key used for a default is itself missing.
    """
    set_default(config, "validate_freq", value=1)
    set_default(config, "copy_checkpoint_freq", value=50)
    set_default(config, "debug", value=False)
    set_default(config, "cuda", value=True)
    set_default(config, "gpu_device", value=None)
    set_default(config, "pretrained_exp_dir", value=None)
    set_default(config, "agent", value="CDSAgent")

    # data_params
    set_default(config.data_params, "aug_src", callback="aug")
    set_default(config.data_params, "aug_tgt", callback="aug")
    set_default(config.data_params, "num_workers", value=4)
    set_default(config.data_params, "image_size", value=224)
    set_default(config.data_params, "task_num", value=8)
    set_default(config.data_params, "task", value=True)

    # model_params
    set_default(config.model_params, "load_weight_epoch", value=0)
    set_default(config.model_params, "load_memory_bank", value=True)

    # loss_params: per-loss lists default to one entry per configured loss.
    num_loss = len(config.loss_params.loss)
    set_default(config.loss_params, "weight", value=[1] * num_loss)
    set_default(config.loss_params, "start", value=[0] * num_loss)
    set_default(config.loss_params, "end", value=[1000] * num_loss)
    if not isinstance(config.loss_params.temp, list):
        config.loss_params.temp = [config.loss_params.temp] * num_loss
    assert len(config.loss_params.weight) == num_loss
    set_default(config.loss_params, "m", value=0.5)
    set_default(config.loss_params, "T", value=0.05)
    set_default(config.loss_params, "pseudo", value=True)

    # optim_params: per-domain batch sizes fall back to the shared batch_size.
    set_default(config.optim_params, "batch_size_src", callback="batch_size")
    set_default(config.optim_params, "batch_size_tgt", callback="batch_size")
    # Previously the callback was "batch_size_lbd" itself, so a missing value
    # could never be filled and always failed set_default's assertion.
    set_default(config.optim_params, "batch_size_lbd", callback="batch_size")
    set_default(config.optim_params, "momentum", value=0.9)
    set_default(config.optim_params, "nesterov", value=True)
    set_default(config.optim_params, "lr_decay_rate", value=0.1)
    set_default(config.optim_params, "cls_update", value=True)

    # clustering
    if config.loss_params.clus is not None:
        if config.loss_params.clus.type is None:
            config.loss_params.clus = None
        else:
            if not isinstance(config.loss_params.clus.type, list):
                config.loss_params.clus.type = [config.loss_params.clus.type]
            k = config.loss_params.clus.k
            k_task = config.loss_params.clus.k_task
            n_k = config.loss_params.clus.n_k
            # List repetition: each k value is repeated n_k times, giving one
            # k-means run per entry (e.g. [64, 128] * 32 -> 64 entries).
            config.k_list = k * n_k
            config.k_list_task = k_task * n_k
            # n_kmeans follows the task-level list (this was the value that
            # survived the former double assignment). NOTE(review): it equals
            # len(k_list) only when k and k_task have the same number of entries.
            config.loss_params.clus.n_kmeans = len(config.k_list_task)

    return config


In [9]:

# Load the raw experiment configuration, turn it into a DotMap, then fill in
# defaults with adjust_config.
config_json = load_json("./config_oh_AR_Cl.json")
# check_pretrain_dir comes from CL.utils; its result is not used in the code visible here.
pre_checkpoint_dir = check_pretrain_dir(config_json)    

config = process_config(config_json)

config = adjust_config(config)
print(config.loss_params.clus.n_k)
print(config.data_params.task)
# NOTE(review): the next two bare expressions are not the last line of the
# cell, so the notebook does not display them; they have no visible effect.
config.optim_params.batch_size_src
config.gpu_device
print("task_num", config.data_params.task_num)

[INFO]: Experiment directory is located at ./exps/experiments/officehome/Art->Clipart:naiv-221016131256
[INFO]: Experiment directory is located at ./exps/experiments/officehome/Art->Clipart:naiv-221016131256
[INFO]: Configurations and directories successfully set up.
[INFO]: Configurations and directories successfully set up.


Configuration Loaded:
{'agent': 'CDSAgent',
 'copy_checkpoint_freq': 50,
 'data_params': {'aug': 'aug_0',
                 'fewshot': '6',
                 'image_size': 224,
                 'name': 'office_home',
                 'num_workers': 4,
                 'source': 'Art',
                 'target': 'Clipart',
                 'task': True,
                 'task_num': 8},
 'exp_base': './exps',
 'exp_id': 'Art->Clipart:naiv',
 'exp_name': 'officehome',
 'loss_params': {'T': 0.05,
                 'clus': {'k': [64, 128],
                          'k_task': [8, 16],
                          'kmeans_freq': 1,
                          'n_k': 32,
                          'type': ['each']},
                 'loss': ['cls-so',
                          'proto-src',
                          'proto-tgt',
                          'I2C-cross',
                          'semi-condentmax',
                          'semi-entmin',
                          'tgt-condentmax',
        

### Define number of classes per Domain/Task

In [11]:

def create_image_label(image_list):
    """Parse a split file whose lines look like ``<relative/path> <int label>``.

    Args:
        image_list: Path to the split text file.

    Returns:
        (image_index, label_list): the list of path strings (first
        space-separated field of each line) and a NumPy array of the integer
        labels (second field). Blank lines are ignored.
    """
    image_index = []
    labels = []
    # Read the file once and close it deterministically (the old version
    # opened it twice and never closed either handle).
    with open(image_list) as f:
        for line in f:
            if not line.strip():
                continue
            fields = line.split(" ")
            image_index.append(fields[0])
            labels.append(int(fields[1].strip()))
    return image_index, np.array(labels)



def get_class_map(image_list):
    """Map each integer label in a split file to its class-folder name.

    The class name is the parent directory of the first image listed with
    that label (``.../<class_name>/<file> <label>``). Blank lines are ignored.

    Args:
        image_list: Path to the split text file.

    Returns:
        collections.OrderedDict ``{label: class_name}`` sorted by label.
    """
    class_map = {}
    with open(image_list) as f:
        for line in f:
            if not line.strip():
                continue
            fields = line.split(" ")
            key = int(fields[1].strip())
            if key not in class_map:
                class_map[key] = fields[0].split("/")[-2]
    return collections.OrderedDict(sorted(class_map.items()))


def get_class_num(image_list):
    """Return the class count of a split file as ``max(label) + 1``.

    The largest label is used rather than the number of distinct labels, so
    the count still covers every label id if some ids are absent from this
    split.

    Raises:
        ValueError: if the split file contains no labels.
    """
    labels = get_class_map(image_list).keys()
    return max(labels) + 1


# Source/target domain names, read from the loaded config so this cell stays
# consistent with config.data_params (Agent builds its domain_map the same way).
domain_map = {
    "source": config.data_params.source,
    "target": config.data_params.target,
}

name = config.data_params.name
domain = domain_map

# Class count = largest label in the source split + 1.
num_class = get_class_num(
    f'./data/splits/{name}/{domain["source"]}.txt')


def split(a, n):
    """Split sequence ``a`` into ``n`` contiguous, nearly equal slices.

    The first ``len(a) % n`` slices get one extra element. Returns a lazy
    generator of slices; slicing a ``range`` yields ``range`` objects.
    """
    base, extra = divmod(len(a), n)
    # bounds[i] is where slice i starts; bounds[i + 1] is where it ends.
    bounds = [i * base + min(i, extra) for i in range(n + 1)]
    return (a[lo:hi] for lo, hi in zip(bounds, bounds[1:]))


# Partition class ids 0..num_class-1 into `task_num` contiguous groups; each
# group is the label set of one task.
task_classes_arr = list(split(range(num_class), config.data_params.task_num))
print(task_classes_arr)


# NOTE(review): integer division gives the *smallest* task size. When
# num_class is not divisible by the number of tasks, the first tasks hold one
# extra class (the printed output shows range(0, 9), i.e. 9 classes, while
# this value is 8).
num_class_task = num_class // len((task_classes_arr)) 
print(num_class_task)


[range(0, 9), range(9, 17), range(17, 25), range(25, 33), range(33, 41), range(41, 49), range(49, 57), range(57, 65)]
8


### Datautils

In [None]:



# Root image directory for each supported dataset name. create_dataset passes
# datasets_path[name] as the root that Imagelists joins the relative paths
# from data/splits/<name>/*.txt onto.
datasets_path = {
    "office": "./data/office",
    "office_home": "./data/officehome",
    "visda17": "./data/visda17",
    "domainnet": "./data/domainnet",
}

# NOTE(review): hard-coded domain names; the functions in this cell receive
# the domain as an argument and do not read this mapping.
domain_map = {
            "source": "Art",
            "target": "Clipart",
        }


def get_fewshot_index(lbd_dataset, whl_dataset):
    """Locate the labeled (few-shot) images inside the whole dataset.

    Args:
        lbd_dataset: Dataset with ``imgs`` (paths) and ``labels`` for the
            labeled subset.
        whl_dataset: Dataset whose ``imgs`` list contains every image.

    Returns:
        (fewshot_indices, fewshot_labels): for each labeled path, its position
        in ``whl_dataset.imgs`` (first occurrence), and ``lbd_dataset.labels``.

    Raises:
        ValueError: if a labeled path does not occur in ``whl_dataset.imgs``.
    """
    # One pass to build a path -> position map replaces a list.index() scan
    # per labeled image (O(n*m) -> O(n+m)). setdefault keeps the first
    # occurrence, matching list.index.
    position = {}
    for i, path in enumerate(whl_dataset.imgs):
        position.setdefault(path, i)

    fewshot_indices = []
    for path in lbd_dataset.imgs:
        if path not in position:
            raise ValueError(f"{path!r} is not in list")
        fewshot_indices.append(position[path])
    return fewshot_indices, lbd_dataset.labels


def get_class_map(image_list):
    """Map each integer label in a split file to its class-folder name.

    The class name is the parent directory of the first image listed with
    that label (``.../<class_name>/<file> <label>``). Blank lines are ignored.

    Args:
        image_list: Path to the split text file.

    Returns:
        collections.OrderedDict ``{label: class_name}`` sorted by label.
    """
    class_map = {}
    with open(image_list) as f:
        for line in f:
            if not line.strip():
                continue
            fields = line.split(" ")
            key = int(fields[1].strip())
            if key not in class_map:
                class_map[key] = fields[0].split("/")[-2]
    return collections.OrderedDict(sorted(class_map.items()))


def get_class_num(image_list):
    """Return the class count of a split file as ``max(label) + 1``.

    The largest label is used rather than the number of distinct labels, so
    the count still covers every label id if some ids are absent from this
    split.

    Raises:
        ValueError: if the split file contains no labels.
    """
    labels = get_class_map(image_list).keys()
    return max(labels) + 1



def get_fewshot_index(lbd_dataset, whl_dataset):
    """Locate the labeled (few-shot) images inside the whole dataset.

    Args:
        lbd_dataset: Dataset with ``imgs`` (paths) and ``labels`` for the
            labeled subset.
        whl_dataset: Dataset whose ``imgs`` list contains every image.

    Returns:
        (fewshot_indices, fewshot_labels): for each labeled path, its position
        in ``whl_dataset.imgs`` (first occurrence), and ``lbd_dataset.labels``.

    Raises:
        ValueError: if a labeled path does not occur in ``whl_dataset.imgs``.
    """
    # One pass to build a path -> position map replaces a list.index() scan
    # per labeled image (O(n*m) -> O(n+m)). setdefault keeps the first
    # occurrence, matching list.index.
    position = {}
    for i, path in enumerate(whl_dataset.imgs):
        position.setdefault(path, i)

    fewshot_indices = []
    for path in lbd_dataset.imgs:
        if path not in position:
            raise ValueError(f"{path!r} is not in list")
        fewshot_indices.append(position[path])
    return fewshot_indices, lbd_dataset.labels


def get_augmentation(trans_type="aug_0", image_size=224):
    """Return the torchvision preprocessing pipeline named ``trans_type``.

    Every pipeline ends with ``ToTensor`` plus ImageNet mean/std normalization.
      * "raw":   resize to (image_size+32)^2, center-crop to image_size.
      * "aug_0": resize to (image_size+32)^2, random horizontal flip,
                 random crop to image_size.
      * "aug_1": random resized crop (scale 0.2-1.0), random grayscale
                 (p=0.2), color jitter, random horizontal flip.

    Raises:
        KeyError: if ``trans_type`` is not one of the names above.
    """
    # ImageNet channel statistics.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    # Resize a little larger than the target so the crop has room to move.
    resized = image_size + 32

    def with_tensor_tail(*steps):
        """Compose ``steps`` followed by ToTensor + Normalize."""
        tail = [
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
        return transforms.Compose(list(steps) + tail)

    pipelines = {}
    pipelines["raw"] = with_tensor_tail(
        transforms.Resize((resized, resized)),
        transforms.CenterCrop(image_size),
    )
    pipelines["aug_0"] = with_tensor_tail(
        transforms.Resize((resized, resized)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(image_size),
    )
    pipelines["aug_1"] = with_tensor_tail(
        transforms.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
        transforms.RandomGrayscale(p=0.2),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomHorizontalFlip(),
    )
    return pipelines[trans_type]



class Imagelists(Dataset):
    """Image dataset backed by a split file of ``<relative/path> <label>`` lines.

    Args:
        image_list: Path to the split file (parsed by ``create_image_label``).
        root: Directory that the relative image paths are joined onto.
        transform: Optional callable applied to each loaded image.
        target_transform: Optional callable applied to each label.
        keep_in_mem: If True, decode every image once at construction and
            serve later requests from memory. ``transform`` is still applied
            on every access.
        ret_index: If True, ``__getitem__`` returns ``(index, img, target)``
            instead of ``(img, target)``.
    """

    def __init__(
        self,
        image_list,
        root,
        transform=None,
        target_transform=None,
        keep_in_mem=False,
        ret_index=False,
    ):
        imgs, labels = create_image_label(image_list)
        self.imgs = imgs
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform
        self.root = root
        self.ret_index = ret_index
        self.keep_in_mem = keep_in_mem
        self.loader = pil_loader

        if self.keep_in_mem:
            # Cache decoded images only. The transform used to be baked in
            # here, which froze random augmentations (flip/crop/jitter) to a
            # single draw for the whole training run.
            # NOTE(review): decoded images may use more or less memory than
            # the transformed tensors, depending on source resolution.
            self.images = [self._load(index) for index in range(len(self.imgs))]

    def _load(self, index):
        """Decode the image at ``index`` from disk, without any transform."""
        return self.loader(os.path.join(self.root, self.imgs[index]))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target), or (index, image, target) when
            ``ret_index`` is True; target is the class index of the image.
        """
        img = self.images[index] if self.keep_in_mem else self._load(index)
        if self.transform is not None:
            img = self.transform(img)

        target = self.labels[index]  # label associated with this index
        if self.target_transform is not None:
            target = self.target_transform(target)

        if self.ret_index:
            return index, img, target
        return img, target

    def __len__(self):
        return len(self.imgs)


def create_dataset(
    name,
    domain,
    txt="",
    suffix="",
    keep_in_mem=False,
    ret_index=False,
    image_transform=None,
    use_mean_std=False,
    image_size=224,
):
    """Build an ``Imagelists`` dataset for one domain of a dataset.

    The split file is ``data/splits/<name>/<txt>.txt``. When ``txt`` is empty
    it becomes ``<domain>`` or ``<domain>_<suffix>`` and is printed.

    Args:
        name: Dataset key in ``datasets_path`` (e.g. "office_home").
        domain: Domain name used to build the default split-file name.
        txt: Explicit split-file stem; overrides ``domain``/``suffix``.
        suffix: Optional split-file suffix.
        keep_in_mem, ret_index: Forwarded to ``Imagelists``.
        image_transform: Pipeline name for ``get_augmentation``.
        use_mean_std: Unused; kept for interface compatibility.
        image_size: Output crop size for the transform.
    """
    if suffix != "":
        suffix = f"_{suffix}"
    if not txt:
        txt = f"{domain}{suffix}"
        print("txt:", txt)

    pipeline = get_augmentation(image_transform, image_size=image_size)
    split_file = f"data/splits/{name}/{txt}.txt"
    return Imagelists(
        split_file,
        datasets_path[name],
        keep_in_mem=keep_in_mem,
        ret_index=ret_index,
        transform=pipeline,
    )



def pil_loader(path):
    """Open the image at ``path`` with PIL and return it converted to RGB.

    ``convert`` runs while the file is still open, so the returned image no
    longer depends on the file handle.
    """
    with open(path, "rb") as handle:
        rgb_image = Image.open(handle).convert("RGB")
    return rgb_image


def worker_init_seed(worker_id):
    """Seed NumPy's and Python's RNGs inside a DataLoader worker.

    PyTorch gives each worker the torch seed ``base_seed + worker_id``, where
    ``base_seed`` is drawn from the main-process RNG each time the loader is
    iterated. Deriving the NumPy/Python seeds from ``torch.initial_seed()``
    makes them follow the experiment seed and differ per worker and per epoch.
    The previous fixed ``12 + worker_id`` ignored the configured seed and
    replayed the same random stream every epoch.

    Args:
        worker_id: Worker index. It is already folded into
            ``torch.initial_seed()`` inside a worker.
    """
    seed = torch.initial_seed() % 2 ** 32  # np.random.seed needs < 2**32
    np.random.seed(seed)
    random.seed(seed)



def create_loader(dataset, batch_size,  num_workers=4, is_train=True):
    """Wrap ``dataset`` in a DataLoader.

    The batch size is capped at ``len(dataset)``. In training mode
    (``is_train``) batches are shuffled and the last incomplete batch is
    dropped. Memory is pinned, and workers are seeded with ``worker_init_seed``.
    """
    effective_batch = min(batch_size, len(dataset))
    loader_kwargs = dict(
        batch_size=effective_batch,
        num_workers=num_workers,
        shuffle=is_train,
        drop_last=is_train,
        pin_memory=True,
        worker_init_fn=worker_init_seed,
    )
    return DataLoader(dataset, **loader_kwargs)



### BaseAgent

In [None]:

import datetime
import logging
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from pcs_cl.utils import print_info, torchutils
from torch.utils.tensorboard import SummaryWriter



class BaseAgent(object):
    """
    General agent class: owns the config, logging/TensorBoard, device
    selection and the epoch loop. Subclasses supply data, model and
    optimisation.

    Abstract Methods to be implemented:
        _load_datasets
        _create_model
        _create_optimizer
        train_one_epoch
        validate
        load_checkpoint
        save_checkpoint
    """

    def __init__(self, config):
        """Seed RNGs, set up logging, choose the device, then call the
        subclass hooks _load_datasets, _create_model and _create_optimizer
        (in that order)."""
        self.config = config
        # set seed as early as possible
        torchutils.set_seed(self.config.seed)

        self.model = None
        self.optim = None
        self.logger = logging.getLogger("Agent")
        self.summary_writer = SummaryWriter(log_dir=self.config.summary_dir)

        self.current_epoch = 0
        self.current_iteration = 0
        self.current_val_iteration = 0
        self.val_acc = []
        self.train_loss = []
        self.lr_scheduler_list = []

        print_info(self.logger.info)
        self.starttime = datetime.datetime.now()
        self._choose_device()

        # Subclass hooks.
        self._load_datasets()
        self._create_model()
        self._create_optimizer()

        # Book-keeping for best-model selection and early stopping.
        self.current_loss = 0.0
        self.current_val_metric = 0.0
        self.best_val_metric = 0.0
        self.best_val_epoch = 0
        self.iter_with_no_improv = 0

    def get_attr(self, domain, name):
        """Return ``self.<name>_<domain>``."""
        return getattr(self, f"{name}_{domain}")

    def set_attr(self, domain, name, value):
        """Set ``self.<name>_<domain> = value`` and return the stored value."""
        setattr(self, f"{name}_{domain}", value)
        return self.get_attr(domain, name)

    def _choose_device(self):
        """Select CUDA or CPU and record which GPU ids to use.

        Sets ``is_cuda``, ``cuda``, ``device``, ``gpu_devices`` and ``multigpu``.
        On GPU it also sets ``parallel_helper_idxs``. On CPU ``gpu_devices`` is
        an empty list and ``multigpu`` is False, so later code can read these
        attributes unconditionally.
        """
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.config.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )
        self.cuda = bool(self.is_cuda and self.config.cuda)

        if self.cuda:
            self.device = torch.device("cuda")
            cudnn.benchmark = True

            # Normalise gpu_device to a list of ids (None -> all visible GPUs).
            if self.config.gpu_device is None:
                self.config.gpu_device = list(range(torch.cuda.device_count()))
            elif not isinstance(self.config.gpu_device, list):
                self.config.gpu_device = [self.config.gpu_device]
            self.gpu_devices = self.config.gpu_device

            # set device when only one gpu
            num_gpus = len(self.gpu_devices)
            self.multigpu = num_gpus > 1 and torch.cuda.device_count() > 1
            if not self.multigpu:
                torch.cuda.set_device(self.gpu_devices[0])

            gpu_devices = ",".join([str(_gpu_id) for _gpu_id in self.gpu_devices])
            self.logger.info(f"User specified {num_gpus} GPUs: {gpu_devices}")
            print("length gpu devices", len(self.gpu_devices))
            self.parallel_helper_idxs = torch.arange(len(self.gpu_devices)).to(
                self.device
            )

            self.logger.info("Program will run on *****GPU-CUDA***** ")
            torchutils.print_cuda_statistics(output=self.logger.info, nvidia_smi=False)
        else:
            self.device = torch.device("cpu")
            # Defined here too so CPU runs don't hit AttributeError later.
            self.gpu_devices = []
            self.multigpu = False
            self.logger.info("Program will run on *****CPU*****\n")

    def _load_datasets(self):
        raise NotImplementedError

    def _create_model(self):
        raise NotImplementedError

    def _create_optimizer(self):
        raise NotImplementedError

    def run(self):
        """
        The main operator: train, then clean up.

        On KeyboardInterrupt a backup checkpoint is written, cleanup runs and
        the interrupt is re-raised. Any other exception is logged with its
        traceback and not re-raised.
        """
        try:
            print("Train function is called in run ")
            self.train()
            print("train function is running")
            self.cleanup()
            print("After cleanup")
        except KeyboardInterrupt as e:
            self.logger.info("Interrupt detected. Saving data...")
            self.backup()
            self.cleanup()
            raise e
        except Exception as e:
            self.logger.error(e, exc_info=True)

    def train(self):
        """
        Main training loop: resume after ``current_epoch`` and run up to
        ``config.num_epochs``. Stops early when ``iter_with_no_improv``
        exceeds ``optim_params.patience``. Each epoch steps the LR schedulers
        and saves a checkpoint.
        """
        print("******Main training lOOP starts***********")

        for epoch in range(self.current_epoch + 1, self.config.num_epochs + 1):
            # early stop
            patience = self.config.optim_params.patience
            if patience and self.iter_with_no_improv > patience:
                self.logger.info(
                    f"accuracy not improved in {patience} epoches, stopped"
                )
                break
            # train
            self.current_epoch = epoch
            print("******Before training one epoch********")
            self.train_one_epoch()

            for sch in self.lr_scheduler_list:
                sch.step()
            # save
            self.save_checkpoint()

    def train_one_epoch(self):
        """
        One epoch of training
        :return:
        """
        raise NotImplementedError

    def validate(self):
        """
        One cycle of model validation
        :return:
        """
        raise NotImplementedError

    def backup(self):
        """
        Backs up the model upon interrupt: closes the summary writer and
        saves ``backup.pth.tar``.
        """
        self.summary_writer.close()
        self.save_checkpoint(filename="backup.pth.tar")

    def finalise(self):
        """
        Do appropriate saving after the model has finished training.
        """
        self.backup()

    def cleanup(self):
        """
        Log the best validation result and the total wall-clock run time.
        """
        if hasattr(self, "best_val_epoch"):
            self.logger.info(
                f"Best Val acc at {self.best_val_epoch}: {self.best_val_metric:.3}"
            )
        endtime = datetime.datetime.now()
        exe_time = endtime - self.starttime
        # total_seconds(): timedelta.seconds drops whole days, so runs longer
        # than 24h were previously under-reported.
        self.logger.info(
            f"End at time: {endtime.strftime('%Y.%m.%d-%H:%M:%S')}, total time: {int(exe_time.total_seconds())}s"
        )

    def copy_checkpoint(self, filename="checkpoint.pth.tar"):
        """Every ``copy_checkpoint_freq`` epochs, copy ``filename`` to
        ``checkpoint_epoch_<epoch>.pth.tar`` in ``config.checkpoint_dir``."""
        if (
            self.config.copy_checkpoint_freq
            and self.current_epoch % self.config.copy_checkpoint_freq == 0
        ):
            self.logger.info(f"Backup checkpoint_epoch_{self.current_epoch}.pth.tar")
            torchutils.copy_checkpoint(
                filename=filename,
                folder=self.config.checkpoint_dir,
                copyname=f"checkpoint_epoch_{self.current_epoch}.pth.tar",
            )

    def load_checkpoint(self, filename):
        """
        Latest checkpoint loader
        :param filename: name of the checkpoint file
        :return:
        """
        raise NotImplementedError

    def save_checkpoint(self, filename="checkpoint.pth.tar"):
        """
        Checkpoint saver
        :param filename: name of the checkpoint file
        :return:
        """
        raise NotImplementedError

### CDSAgent (implemented below as the `Agent` class)

In [None]:
 ls_abbr = {
    "cls-so": "cls",
    "proto-each": "P",
    "proto-src": "Ps",
    "proto-tgt": "Pt",
    "cls-info": "info",
    "I2C-cross": "C",
    "semi-condentmax": "sCE",
    "semi-entmin": "sE",
    "tgt-condentmax": "tCE",
    "tgt-entmin": "tE",
    "ID-each": "I",
    "CD-cross": "CD",
   }


class Agent(BaseAgent):
    
    
    def __init__(self, config):
        """Set up the Task-IL agent.

        Order matters: task flags and ``task_classes_arr`` are set before
        ``BaseAgent.__init__`` (which calls ``_load_datasets``), because the
        per-task loaders read below are looked up with ``get_attr`` after it.
        """
        self.config = config
        self._define_task(config)
        self.is_features_computed = False
        self.current_iteration_source = self.current_iteration_target = 0
        self.domain_map = {
            "source": self.config.data_params.source,
            "target": self.config.data_params.target,
            }
        
        # Reads the module-level `task_classes_arr` computed in an earlier cell.
        self.task_classes_arr = task_classes_arr
        print("task_classes_arr:" , self.task_classes_arr)
        
        self.task_num = config.data_params.task_num
        print("tasks_num", self.task_num)
          
        # Seeds, device selection, _load_datasets, _create_model, _create_optimizer.
        super(Agent, self).__init__(config)
        
 
       
        #For TASK-IL 
        # One MomentumSoftmax per task and domain, sized to the task's class
        # count; m is the number of batches in that task's loader
        # (train_loader_<domain>_task_<task_id>). These attributes are
        # presumably set in _load_datasets — TODO confirm.
        self.momentum_softmax_target_task = []  
        self.momentum_softmax_source_task = []   
        for task_id, task_classes in enumerate(self.task_classes_arr):
            momentumsoftmax_tgt = torchutils.MomentumSoftmax(
               len(task_classes),  m=len(self.get_attr(task_id, "train_loader_target_task"))) 

            self.momentum_softmax_target_task.append(momentumsoftmax_tgt)
            
            momentumsoftmax_src = torchutils.MomentumSoftmax(
                len(task_classes),  m=len(self.get_attr(task_id, "train_loader_source_task")))
            
            self.momentum_softmax_source_task.append(momentumsoftmax_src)

        
      
        # init loss  
        # The loss module is wrapped in DataParallel over the chosen GPU ids
        # and moved to CUDA unconditionally.
        loss_fn = SSDALossModule(self.config, gpu_devices=self.gpu_devices)  
        loss_fn = nn.DataParallel(loss_fn, device_ids=self.gpu_devices).cuda()
        self.loss_fn = loss_fn

        # Build a fresh memory bank only when not starting from a pretrained experiment.
        if self.config.pretrained_exp_dir is None: 
            self._init_memory_bank()   

        # init statics    
        self._init_labels_task()
        self._load_fewshot_to_cls_weight()



    
    def _define_task(self, config):
        """Derive the loss-family flags used by the agent from ``config``.

        Sets on ``self``:
            fewshot, task: copied from ``config.data_params``.
            clus: whether a clustering config is present.
            cls / semi / tgt: whether any loss name starts with
                "cls" / "semi" / "tgt".
            ssl: whether any loss belongs to another (self-supervised) family.
            is_pseudo_src / is_pseudo_tgt: whether source / target pseudo
                labels are needed.
        """
        self.fewshot = config.data_params.fewshot
        self.task = config.data_params.task
        self.clus = config.loss_params.clus is not None

        losses = list(config.loss_params.loss)
        self.cls = any(ls.startswith("cls") for ls in losses)
        self.semi = any(ls.startswith("semi") for ls in losses)
        self.tgt = any(ls.startswith("tgt") for ls in losses)
        self.ssl = any(ls.split("-")[0] not in ("cls", "semi", "tgt") for ls in losses)
        self.is_pseudo_src = any(ls.startswith("semi-pseudo") for ls in losses)
        self.is_pseudo_tgt = any(ls.startswith("tgt-pseudo") for ls in losses)

        # Pseudo-labelling switched on globally also enables it per domain;
        # source pseudo labels additionally require few-shot labels.
        self.is_pseudo_src = bool(
            self.is_pseudo_src
            or (config.loss_params.pseudo and self.fewshot is not None)
        )
        self.is_pseudo_tgt = bool(self.is_pseudo_tgt or config.loss_params.pseudo)
        self.semi = bool(self.semi or self.is_pseudo_src)
        if self.clus:
            # The clustering types live under loss_params.clus. The old code
            # read config.clus.type, so this condition could never be True.
            self.is_pseudo_tgt = bool(
                self.is_pseudo_tgt
                or (
                    config.loss_params.clus.tgt_GC == "PGC"
                    and "GC" in config.loss_params.clus.type
                )
            )
    
    
  
    #.........................Modified_2
    def _init_labels_task(self):
        """Create per-task pseudo-label buffers where -1 means "no label".

        For every task ``t`` this stores, via ``set_attr``:
          * ``predict_ordered_labels_pseudo_source_tsk_t`` (only when
            ``self.fewshot`` is truthy): a CUDA long tensor of length
            ``train_len_init_source_task_t``. The few-shot positions hold
            their label and every other entry is -1.
          * ``predict_ordered_labels_pseudo_target_tsk_t``: a CUDA long
            tensor of -1 of length ``train_len_init_target_task_t``.

        As before, ``self.task_id`` and the un-suffixed
        ``predict_ordered_labels_pseudo_*_tsk`` attributes are left holding
        the last task's values.
        """
        for task_id, _ in enumerate(self.task_classes_arr):
            self.task_id = task_id
            train_len_tgt_tsk = self.get_attr(task_id, "train_len_init_target_task")
            train_len_src_tsk = self.get_attr(task_id, "train_len_init_source_task")

            if self.fewshot:
                # Start from all -1, then write the few-shot labels in place.
                labels_src = (
                    torch.zeros(train_len_src_tsk, dtype=torch.long).detach().cuda() - 1
                )
                # NOTE(review): fewshot_*_source_tasks are indexed [task_id, :],
                # presumably (num_tasks, num_shots) arrays built elsewhere.
                for ind, lbl in zip(
                    self.fewshot_index_source_tasks[task_id, :],
                    self.fewshot_label_source_tasks[task_id, :],
                ):
                    labels_src[int(ind)] = int(lbl)
                self.predict_ordered_labels_pseudo_source_tsk = labels_src
                # Only publish the source buffer when it was built. Previously
                # it was read unconditionally, raising AttributeError when
                # fewshot was disabled.
                self.set_attr(task_id, "predict_ordered_labels_pseudo_source_tsk", labels_src)

            labels_tgt = (
                torch.zeros(train_len_tgt_tsk, dtype=torch.long).detach().cuda() - 1
            )
            self.predict_ordered_labels_pseudo_target_tsk = labels_tgt
            self.set_attr(task_id, "predict_ordered_labels_pseudo_target_tsk", labels_tgt)
 
        

    #...............................
    def _load_datasets(self):

        name = self.config.data_params.name
        num_workers = self.config.data_params.num_workers
        fewshot = self.config.data_params.fewshot
        domain = self.domain_map
        image_size = self.config.data_params.image_size
        aug_src = self.config.data_params.aug_src
        aug_tgt = self.config.data_params.aug_tgt
        raw = "raw"

        
        #..................................
        self.num_class = get_class_num(
            f'./data/splits/{name}/{domain["source"]}.txt'
        )


        self.class_map = get_class_map(
            f'data/splits/{name}/{domain["target"]}.txt'
        )

        
        
        self.num_class_task = self.num_class // self.task_num 
        print("num_class_task:" , self.num_class_task)
        
        #....................................
        batch_size_dict = {
            "test": self.config.optim_params.batch_size,
            "source": self.config.optim_params.batch_size_src,
            "target": self.config.optim_params.batch_size_tgt,
            "labeled": self.config.optim_params.batch_size_lbd,
        }
        self.batch_size_dict = batch_size_dict
 

        #............................self-supervised Dataset.......
    
        for domain_name in ("source", "target"):
            aug_name = {"source": aug_src, "target": aug_tgt}[domain_name]

            #print("domain_name: ", domain_name) 

            # Training datasets
            train_dataset = create_dataset(
                name,
                domain[domain_name],
                suffix="",
                ret_index=True,
                image_transform=aug_name,
                use_mean_std=False,
                image_size=image_size,)

            train_loader = create_loader(
                train_dataset,
                batch_size_dict[domain_name],
                is_train=True,
                num_workers=num_workers,
            )
            train_init_loader = create_loader(          
                train_dataset,
                batch_size_dict[domain_name],
                is_train=False,
                num_workers=num_workers,
            )
            
    
            train_labels = torch.from_numpy(train_dataset.labels).detach().cuda()
            print("debug-train-labels-shape:", train_labels.shape)
            
            self.set_attr(domain_name, "train_dataset", train_dataset)  #setattr(domain, name, value) --> train_dataset_source/train_dataset_target
            self.set_attr(domain_name, "train_ordered_labels", train_labels)
            self.set_attr(domain_name, "train_loader", train_loader)
            self.set_attr(domain_name, "train_init_loader", train_init_loader)
            self.set_attr(domain_name, "train_len", len(train_dataset))    
            
           
            #.........................train_task
            print(".......Task IL starts......")
            
            #task_classes_arr = [(0,1,2,3,4), (5,6,7,8,9), (10,11,12,13,14), (15,16,17,18,19), (20,21,22,23,24), (25,26,27,28,29) ] #,30
            #task_classes_arr = [tuple(list(range(i*tasks_num,((i+1)*tasks_num)+2))) for i in range(0,int(len(np.unique(labels))/tasks_num))]  #with overlap and without overlap
            #tasks_num = len(task_classes_arr)

            
            #********************************train-dataset-Source
            if domain_name == "source":
                
                train_dataset_source =  self.get_attr("source", "train_dataset") 
                train_loader_source = self.get_attr("source","train_loader")
                train_init_loader_source = self.get_attr("source","train_init_loader")
                
                #.................train-dataset-not-drop_lost-source
                x_train_source_idx = torch.Tensor([0])#.cuda()
                x_train_source_img = torch.zeros((1,3,224,224))#.cuda()
                x_train_source_lbl = torch.Tensor([0])#.cuda()

                for idx , data in enumerate(train_loader_source):

                    indx, img, label = data
#                     indx = indx.cuda()
#                     img = img.cuda()
#                     label = label.cuda()

                    x_train_source_idx=x_train_source_idx.to(torch.float32)
                    indx=indx.to(torch.float32)
                    label=label.to(torch.float32)
                    x_train_source_idx = torch.cat((x_train_source_idx, indx), 0)
                    x_train_source_img = torch.cat((x_train_source_img, img), 0)
                    x_train_source_lbl = torch.cat((x_train_source_lbl, label), 0)

                image = x_train_source_img[1:]     
                label = x_train_source_lbl[1:]
                index = x_train_source_idx[1:]

                # Divide the data over the different tasks
                task_data_source = [] 
                for task_id, task_classes in enumerate(self.task_classes_arr):
                    
                    
                    train_mask = np.isin(label, task_classes)
                    idx_train_task, x_train_task, y_train_task = index[train_mask], image[train_mask], label[train_mask]
                    task_data_source.append((idx_train_task, x_train_task, y_train_task))  #- (task_id * 2) --> with overlap
           

            
            
                #create tensor dataset
                for task_id, _ in enumerate(self.task_classes_arr):
                    dataset_source_task = TensorDataset(task_data_source[task_id][0], task_data_source[task_id][1], task_data_source[task_id][2])
                    dataloader_source_task = DataLoader(dataset_source_task, batch_size= batch_size_dict[domain_name]) 
                    

                    train_labels_source_task = task_data_source[task_id][2] #dataset_source_task.labels.detach() #.cuda() #torch.from_numpy(dataset_source_task.labels).detach().cuda()

        
                    self.set_attr(task_id, "train_data_source_task", dataset_source_task)  
                    self.set_attr(task_id, "train_loader_source_task", dataloader_source_task)  
                    
                    train_data_source_task = self.get_attr(task_id, "train_data_source_task")
                    self.set_attr(task_id, "train_len_source_task", len(train_data_source_task))  
                    self.set_attr(task_id, "train_ordered_labels_source_task", train_labels_source_task)
               
                
                
                #*****************train-init-dataset-Source-not-drop-last-unfulled-batch
                
                x_train_init_source_idx = torch.tensor([0])
                x_train_init_source_img = torch.zeros((1,3,224,224))
                x_train_init_source_lbl = torch.tensor([0])

                for idx , data in enumerate(train_init_loader_source):

                    indx, img, label = data
#                     print(indx.shape)
        
                    x_train_init_source_idx = torch.cat((x_train_init_source_idx, indx), 0)
                    x_train_init_source_img = torch.cat((x_train_init_source_img, img), 0)
                    x_train_init_source_lbl = torch.cat((x_train_init_source_lbl, label), 0)

                image = x_train_init_source_img[1:]     
                label = x_train_init_source_lbl[1:]
                index = x_train_init_source_idx[1:]

                # Divide the data over the different tasks
                task_data_init_source = []   
                for task_id, task_classes in enumerate(self.task_classes_arr):

                    train_mask = np.isin(label, task_classes)
                    print(train_mask)
                    idx_train_task, x_train_task, y_train_task = index[train_mask], image[train_mask], label[train_mask]
                    task_data_init_source.append((idx_train_task, x_train_task, y_train_task))  
                   
              

                #create tensor dataset
                for task_id, _ in enumerate(self.task_classes_arr):
                    dataset_init_source_task = TensorDataset(task_data_init_source[task_id][0], task_data_init_source[task_id][1], task_data_init_source[task_id][2])
                    dataloader_init_source_task = DataLoader(dataset_init_source_task, batch_size= batch_size_dict[domain_name]) # create your dataloader
                    
                    
                    
                    self.set_attr(task_id, "train_data_init_source_task", dataset_init_source_task)  
                    self.set_attr(task_id, "train_loader_init_source_task", dataloader_init_source_task)  
                    train_data_init_source_task = self.get_attr(task_id, "train_data_init_source_task")
                    self.set_attr(task_id, "train_len_init_source_task", len(train_data_init_source_task)) 
                



            #...............................Target
            print("*******Create tarin-target-task************")
            
            if domain_name == "target":
                
                
                train_dataset_target =  self.get_attr("target", "train_dataset")   
                train_loader_target = self.get_attr("target", "train_loader") 
                train_init_loader_target = self.get_attr("target","train_init_loader")
                
                
            
                x_train_target_idx = torch.tensor([0])#.cuda()
                x_train_target_img = torch.zeros((1,3,224,224))#.cuda()
                x_train_target_lbl = torch.tensor([0])#.cuda()

                for idx , data in enumerate(train_loader_target):

                    indx, img, label = data
#                     indx = indx.cuda()
#                     img = img.cuda()
#                     label = label.cuda()
                    
#                     x_train_target_idx=x_train_target_idx.to(torch.float32)
#                     indx=indx.to(torch.float32)
#                     label=label.to(torch.float32)
                
                    x_train_target_idx = torch.cat((x_train_target_idx, indx), 0)
                    x_train_target_img = torch.cat((x_train_target_img, img), 0)
                    x_train_target_lbl = torch.cat((x_train_target_lbl, label), 0)


                image = x_train_target_img[1:]
                label = x_train_target_lbl[1:]
                index = x_train_target_idx[1:]

                # Divide the data over the different tasks
                task_data_target = []
                for task_id, task_classes in enumerate(self.task_classes_arr):

                    train_mask = np.isin(label, task_classes)
                    idx_train_task, x_train_task, y_train_task = index[train_mask], image[train_mask], label[train_mask]
                    task_data_target.append((idx_train_task,x_train_task, y_train_task))  #- (task_id * 2) #with overlap
                    #print("length of train_task:" ,len(x_train_task) ) 

                    
                for task_id, _ in enumerate(self.task_classes_arr):
        
                    dataset_target_task = TensorDataset(task_data_target[task_id][0], task_data_target[task_id][1], task_data_target[task_id][2])
                    dataloader_target_task = DataLoader(dataset_target_task, batch_size= batch_size_dict[domain_name])
                    
                    #train_labels_target_task = torch.from_numpy(dataset_target_task.labels).detach().cuda()
                    train_labels_target_task = task_data_target[task_id][2]
                    
                    
                  
                    self.set_attr(task_id, "train_data_target_task", dataset_target_task)  
                    self.set_attr(task_id, "train_loader_target_task", dataloader_target_task)  
                    train_data_target_task = self.get_attr(task_id, "train_data_target_task")
                    self.set_attr(task_id, "train_len_target_task", len(train_data_target_task))
                    self.set_attr(task_id, "train_ordered_labels_target_task", train_labels_target_task)
                  
                    
                train_len_target_task = self.get_attr(0, "train_len_target_task")   
                print("*****train_len_target_task******", train_len_target_task)
                       
                #*******************************train-init-dataset-target-not-drop_last-unfulled_batch
                print("Create train-init-target-dataset")
                x_train_init_target_idx = torch.tensor([0])#.cuda()   
                x_train_init_target_img = torch.zeros((1,3,224,224))#.cuda()
                x_train_init_target_lbl = torch.tensor([0])#.cuda()

                for idx , data in enumerate(train_init_loader_target):

                    indx, img, label = data
#                     indx = indx.cuda()
#                     img = img.cuda()
#                     label = label.cuda()
    
                    x_train_init_target_idx = torch.cat((x_train_init_target_idx, indx), 0)
                    x_train_init_target_img = torch.cat((x_train_init_target_img, img), 0)
                    x_train_init_target_lbl = torch.cat((x_train_init_target_lbl, label), 0)

                image = x_train_init_target_img[1:]     
                label = x_train_init_target_lbl[1:]
                index = x_train_init_target_idx[1:]

                # Divide the data over the different tasks
                task_data_init_target = []   
                for task_id, task_classes in enumerate(self.task_classes_arr):

                    train_mask = np.isin(label.cpu(), task_classes)
                    idx_train_task, x_train_task, y_train_task = index[train_mask], image[train_mask], label[train_mask]
                    task_data_init_target.append((idx_train_task, x_train_task, y_train_task))  
                   
              

                #create tensor dataset
                for task_id, _ in enumerate(self.task_classes_arr):
                    dataset_init_target_task = TensorDataset(task_data_init_target[task_id][0], task_data_init_target[task_id][1], task_data_init_target[task_id][2])
                    dataloader_init_target_task = DataLoader(dataset_init_target_task, batch_size= batch_size_dict[domain_name]) # create your dataloader
                    
                    
                    
                    self.set_attr(task_id, "train_data_init_target_task", dataset_init_target_task)  
                    self.set_attr(task_id, "train_loader_init_target_task", dataloader_init_target_task)  
                    train_data_init_target_task = self.get_attr(task_id, "train_data_init_target_task")
                    self.set_attr(task_id, "train_len_init_target_task", len(train_data_init_target_task))  
                
                
                train_len_init_target_task = self.get_attr(0, "train_len_init_target_task")   
                print("*****train_len_init_target_task******",train_len_init_target_task)
                       

        #...............................Classification and Fewshot Dataset
        if fewshot:
#             print(int(fewshot))
            print("**********Create few-shot dataset********")
            train_lbd_dataset_source = create_dataset(  #Retrun an imagelist .txt file where the labeled samples are defined
                name,
                domain["source"],
                suffix=f"labeled_{fewshot}",
                ret_index=True,
                image_transform=aug_src,
                image_size=image_size,
            )
            
            self.train_lbd_loader_source = create_loader(           
                train_lbd_dataset_source,
                batch_size_dict["labeled"],
                is_train=False,
                num_workers=num_workers,
            )
             
           
            src_dataset = self.get_attr("source", "train_dataset")
            
            (
                self.fewshot_index_source,
                self.fewshot_label_source,
            ) = get_fewshot_index(train_lbd_dataset_source, src_dataset)
            
 

            test_unl_dataset_source = create_dataset(
                name,
                domain["source"],
                suffix=f"unlabeled_{fewshot}",
                ret_index=True,
                image_transform=raw,
                image_size=image_size,
            )
            self.test_unl_loader_source = create_loader(             
                test_unl_dataset_source,
                batch_size_dict["test"],
                is_train=False,
                num_workers=num_workers,
            )
            

            #........................ task_devision -fewshot
   
            x_train_lbd_idx = torch.tensor([0])
            x_train_lbd_img = torch.zeros((1,3,224,224))
            x_train_lbd_lbl = torch.tensor([0])

            for idx , data in enumerate(self.train_lbd_loader_source):   

                indx, img, label = data

                print("idx", idx)
                
                print("index_fewshot_totall:" , indx)
                x_train_lbd_idx = torch.cat((x_train_lbd_idx, indx), 0)
                x_train_lbd_img = torch.cat((x_train_lbd_img, img), 0)
                x_train_lbd_lbl = torch.cat((x_train_lbd_lbl, label), 0)


            image = x_train_lbd_img[1:]
            label = x_train_lbd_lbl[1:]
            index = x_train_lbd_idx[1:]

            # Divide the data over the different tasks
            task_train_lbd = []
            for task_id, task_classes in enumerate(self.task_classes_arr):

                train_mask = np.isin(label.cpu(), task_classes)
                idx_train_task, x_train_task, y_train_task = index[train_mask] , image[train_mask], label[train_mask]
                task_train_lbd.append(( idx_train_task, x_train_task, y_train_task))  #- (task_id * 2) #with overlap

      
            #...............................create dataset
            for task_id, _ in enumerate(self.task_classes_arr):
                dataset_task_train_lbd_source = TensorDataset(task_train_lbd[task_id][0], task_train_lbd[task_id][1],task_train_lbd[task_id][2])
                dataloader_task_train_lbd_source = DataLoader(dataset_task_train_lbd_source, batch_size= batch_size_dict["test"]) 

                
                self.set_attr(task_id, "train_data_lbd_source_task", dataset_task_train_lbd_source) 
                self.set_attr(task_id, "train_lbd_loader_source_task", dataloader_task_train_lbd_source)  
                
                
               
            
             
            #.............colculate fewshot_index and fewshot_labels
        
           
            num_labelled_each_task =  self.num_class_task          
            self.fewshot_index_source_tasks = np.zeros((self.task_num, num_labelled_each_task*int(fewshot)))
            self.fewshot_label_source_tasks = np.zeros((self.task_num, num_labelled_each_task*int(fewshot)))
            length_task = np.zeros(self.task_num)
            
            for task_id, task_classes in enumerate(self.task_classes_arr):
                self.task_id = task_id
             
                length_task[self.task_id] = self.get_attr(self.task_id, "train_len_init_source_task") 
                print(length_task[self.task_id]) 
                
                if self.task_id == 0:
                    self.fewshot_index_source_tasks[self.task_id,:] = np.array(self.fewshot_index_source)[np.isin(self.fewshot_label_source, task_classes)]
                    self.fewshot_label_source_tasks[self.task_id,:] = np.array(self.fewshot_label_source)[np.isin(self.fewshot_label_source, task_classes)]
                    
                   
                    
                else:
                    self.fewshot_index_source_tasks[self.task_id,:] = np.array(self.fewshot_index_source)[np.isin(self.fewshot_label_source, task_classes)] - np.sum(length_task[:self.task_id])
                    self.fewshot_label_source_tasks[self.task_id,:] = np.array(self.fewshot_label_source)[np.isin(self.fewshot_label_source, task_classes)]
                    
                    
                    
                self.set_attr(self.task_id, "fewshot_index_src_task", self.fewshot_index_source_tasks[self.task_id,:])
                self.set_attr(self.task_id, "fewshot_label_src_task", self.fewshot_label_source_tasks[self.task_id,:]) 
          
             
            
        
            fewshot_index_src_task_0 = self.get_attr(0, "fewshot_index_src_task" )
            #print("fewshot_index_src_task_0 ", fewshot_index_src_task_0) 
            
            fewshot_index_src_task_1 = self.get_attr(1, "fewshot_index_src_task" )
            #print("fewshot_index_src_task_1 ", fewshot_index_src_task_1) 
            

            fewshot_index_src_task_2 = self.get_attr(2, "fewshot_index_src_task" )
            #print("fewshot_index_src_task_2 ", fewshot_index_src_task_2) 
                
            

            
                 
            # labels for fewshot
            train_len_src_tsk = self.get_attr(task_id, "train_len_init_source_task")
            #print("task_id", task_id)
            
            fewshot_labels_tsk = (
                torch.zeros(train_len_src_tsk, dtype=torch.long).detach().cuda() - 1
            )
            for ind, lbl in zip(self.fewshot_index_source_tasks[task_id ,:], self.fewshot_label_source_tasks[task_id,:]): 

                fewshot_labels_tsk[int(ind)] = int(lbl)
               

                        
            #.............. Test unlabled task loader source in case fewshot                      
                         
            x_test_source_idx = torch.tensor([0])  
            x_test_source_img = torch.zeros((1,3,224,224))
            x_test_source_lbl = torch.tensor([0]

            for idx , data in enumerate(self.test_unl_loader_source):
                indx, img, label = data

                x_test_source_idx = torch.cat((x_test_source_idx, indx), 0)
                x_test_source_img = torch.cat((x_test_source_img, img), 0)
                x_test_source_lbl = torch.cat((x_test_source_lbl, label), 0)

            image = x_test_source_img[1:]     
            label = x_test_source_lbl[1:]
            index = x_test_source_idx[1:]

            # Divide the data over the different tasks
            task_testdata_source = []  
            for task_id, task_classes in enumerate(self.task_classes_arr):
                self.task_id = task_id
                test_mask = np.isin(label.cpu(), task_classes)
                idx_test_task, x_test_task, y_test_task = index[test_mask], image[test_mask], label[test_mask]
                task_testdata_source.append((idx_test_task, x_test_task, y_test_task))  #- (task_id * 2) #with overlap
                  
           
             #create tensor dataset
            for task_id, _ in enumerate(self.task_classes_arr):
                self.task_id = task_id
                test_unl_data_source_task = TensorDataset(task_testdata_source[self.task_id][0], task_testdata_source[self.task_id][1], task_testdata_source[self.task_id][2])
                test_unl_loader_source_task = DataLoader(test_unl_data_source_task, batch_size= batch_size_dict["test"]) 

                

                self.set_attr(self.task_id, "test_data_source_task", test_unl_data_source_task)  
                self.set_attr(self.task_id, "test_unl_loader_source_task", test_unl_loader_source_task) 
                test_data_source_task = self.get_attr(task_id, "test_data_source_task")
                self.set_attr(task_id, "test_len_data_source_task", len(test_data_source_task))
           
            
        
        #.........if not fewshot         
        else:
            print("******ELSE******")
            train_lbd_dataset_source = create_dataset(
                name,
                domain["source"],
                ret_index= True,
                image_transform = aug_src,
                image_size= image_size,
            )
                                  
                                  
            train_lbd_loader_source = create_loader(
                train_lbd_dataset_source,
                batch_size_dict["labeled"],
                num_workers=num_workers,
            )
            
            
            #.........train_lbd_dataset_source_task _incase_no_fewshot.........                  
                                  
            x_train_source_lbd_idx = torch.tensor([0]).cuda() 
            x_train_source_lbd_img = torch.zeros((1,3,224,224)).cuda()
            x_train_source_lbd_lbl = torch.tensor([0]).cuda()
            
            for idx , data in enumerate(train_lbd_loader_source):

                indx, img, label = data
                indx = indx.cuda()
                img = img.cuda()
                label = label.cuda()

                x_train_source_lbd_idx = torch.cat((x_train_source_lbd_idx, indx), 0)
                x_train_source_lbd_img = torch.cat(( x_train_source_lbd_img , img), 0)
                x_train_source_lbd_lbl = torch.cat((x_train_source_lbd_lbl, label), 0)

            image = x_train_source_lbd_img[1:]     
            label = x_train_source_lbd_lbl[1:]
            index = x_train_source_lbd_idx[1:]

            # Divide the data over the different tasks
            task_lbd_data_source = []   
            for task_id, task_classes in enumerate(task_classes_arr):

                train_mask = np.isin(label.cpu(), task_classes)
                idx_train_task, x_train_task, y_train_task = index[train_mask], image[train_mask], label[train_mask]
                task_lbd_data_source.append((idx_train_task, x_train_task, y_train_task))  #- (task_id * 2) #with overlap
                  
              

             #create tensor dataset
            for task_id, _ in enumerate(task_classes_arr):
                self.task_id = task_id
                train_lbd_data_source_task = TensorDataset(task_lbd_data_source[task_id][0], task_lbd_data_source[self.task_id][1], task_lbd_data_source[self.task_id][2])
                train_lbd_loader_source_task = DataLoader(train_lbd_data_source_task, batch_size= batch_size_dict["labeled"]) 


                self.set_attr(self.task_id, "train_lbd_data_source_task", train_lbd_data_source_task)  
                self.set_attr(self.task_id, "train_lbd_loader_source_task", train_lbd_loader_source_task) 

                   
            
        
            
            
        #........Start creating test unlabled target/task dataset   

        test_suffix = "test" if self.config.data_params.train_val_split else ""

        test_unl_dataset_target = create_dataset(
            name,
            domain["target"],
            suffix=test_suffix,
            ret_index=True,
            image_transform=raw,
            image_size=image_size,
        )


        test_unl_loader_target = create_loader(
            test_unl_dataset_target,
            batch_size_dict["test"],
            is_train=False,
            num_workers=num_workers,
        )


                                  
             
        #.......... Test unlabled target dataset task                   
        x_test_target_idx = torch.tensor([0]).cuda()    
        x_test_target_img = torch.zeros((1,3,224,224)).cuda()
        x_test_target_lbl = torch.tensor([0]).cuda()

        for idx , data in enumerate(test_unl_loader_target):

            indx, img, label = data
            indx = indx.cuda()
            img = img.cuda()
            label = label.cuda()

            x_test_target_idx = torch.cat((x_test_target_idx, indx), 0)
            x_test_target_img = torch.cat((x_test_target_img, img), 0)
            x_test_target_lbl = torch.cat((x_test_target_lbl, label), 0)

        image = x_test_target_img[1:]     
        label = x_test_target_lbl[1:]
        index = x_test_target_idx[1:]

        # Divide the data over the different tasks
        task_test_data_target = []   
        for task_id, task_classes in enumerate(task_classes_arr):

            test_mask = np.isin(label.cpu(), task_classes)
            idx_test_task, x_test_task, y_test_task = index[test_mask], image[test_mask], label[test_mask]
            task_test_data_target.append((idx_test_task, x_test_task, y_test_task))  



        #create tensor dataset
        for task_id, _ in enumerate(task_classes_arr):
            self.task_id = task_id
            test_unl_data_target_task = TensorDataset(task_test_data_target[self.task_id][0], task_test_data_target[self.task_id][1], task_test_data_target[self.task_id][2])
            test_unl_loader_target_task = DataLoader(test_unl_data_target_task, batch_size= batch_size_dict["test"]) 

            self.set_attr(self.task_id, "test_unl_data_target_task", test_unl_data_target_task)  
            self.set_attr(self.task_id, "test_unl_loader_target_task", test_unl_loader_target_task) 
            test_unl_data_target_task = self.get_attr(task_id, "test_unl_data_target_task")
            self.set_attr(task_id, "test_len_data_target_task", len(test_unl_data_target_task))
           


        self.logger.info(
            f"Dataset {name}, source {self.config.data_params.source}, target {self.config.data_params.target}"
        )
        
 
    #................................Model
    def _create_model(self):

        version_grp = self.config.model_params.version.split("-")
     
        version = version_grp[-1]
        pretrained = "pretrain" in version_grp
        if pretrained:
            self.logger.info("Imagenet pretrained model used")
                                  
        out_dim = self.config.model_params.out_dim  
     
        # backbone
        if "resnet" in version:
            net_class = getattr(torchvision.models, version)

            if pretrained:
                model = net_class(pretrained=pretrained)
                model.fc = nn.Linear(model.fc.in_features, out_dim)
                torchutils.weights_init(model.fc)
            else:
               
                model = net_class(pretrained=False, num_class_task=out_dim)   
                                  
        else:
            raise NotImplementedError
                                  
        model = nn.DataParallel(model, device_ids=self.gpu_devices)  
        model = model.cuda()
        self.model = model 
                                  


        if self.cls:
            self.criterion = nn.CrossEntropyLoss().cuda()
                
            self.task_heads = []
            for task_id in range(len(task_classes_arr)):
                self.task_id = task_id
                cls_head = CosineClassifier(
        
                    num_class = self.num_class_task, inc=out_dim, temp= self.config.loss_params.T  )  
                self.task_heads.append(cls_head.cuda())
                torchutils.weights_init(self.task_heads[self.task_id]) 


                
    #.............................Optimizer                            
    def _create_optimizer(self):
        """Create ``self.optim`` (SGD) and any configured learning-rate schedulers.

        Three parameter groups are passed to SGD, in this order:

        1. the first list returned by ``torchutils.split_params_by_name(self.model, "bn")``
           (presumably the batch-norm parameters), with weight decay disabled;
        2. the second list returned by
           ``torchutils.split_params_by_name(self.model, ["fc", "bn"])``
           (presumably everything that is neither fc nor bn, i.e. the conv trunk);
        3. the first list returned by ``torchutils.split_params_by_name(self.model, "fc")``,
           extended -- when ``self.cls`` and ``optim_params.cls_update`` are both
           truthy -- with the parameters of ``self.task_heads[self.task_id]`` only.

        If ``optim_params.conv_lr_ratio`` is truthy, groups 1 and 2 get
        ``learning_rate * conv_lr_ratio``; otherwise all groups use ``learning_rate``.

        Side effects: sets ``self.optim``; appends a ``MultiStepLR`` to
        ``self.lr_scheduler_list`` when ``optim_params.lr_decay_schedule`` is truthy;
        sets ``self.optim_iterdecayLR`` when ``optim_params.decay`` is truthy.
        """
        opt_cfg = self.config.optim_params
        base_lr = opt_cfg.learning_rate
        mom = opt_cfg.momentum
        wd = opt_cfg.weight_decay
        ratio = opt_cfg.conv_lr_ratio

        # Group 1: batch-norm parameters are never weight-decayed.
        bn_params, _ = torchutils.split_params_by_name(self.model, "bn")
        bn_group = {"params": bn_params, "weight_decay": 0.0}

        # Group 2: the remaining (conv) parameters.
        _, conv_params = torchutils.split_params_by_name(self.model, ["fc", "bn"])
        conv_group = {"params": conv_params}

        if ratio:
            # The backbone trunk (bn + conv) trains with a scaled-down learning rate.
            scaled_lr = base_lr * ratio
            bn_group["lr"] = scaled_lr
            conv_group["lr"] = scaled_lr

        # Group 3: backbone fc layer, plus the head of the task selected by self.task_id.
        fc_params, _ = torchutils.split_params_by_name(self.model, "fc")
        if self.cls and opt_cfg.cls_update:
            fc_params.extend(list(self.task_heads[self.task_id].parameters()))
        fc_group = {"params": fc_params}

        self.optim = torch.optim.SGD(
            [bn_group, conv_group, fc_group],
            lr=base_lr,
            weight_decay=wd,
            momentum=mom,
            nesterov=opt_cfg.nesterov,
        )

        # Optional step-wise decay at the configured milestones.
        if opt_cfg.lr_decay_schedule:
            self.lr_scheduler_list.append(
                torch.optim.lr_scheduler.MultiStepLR(
                    self.optim,
                    milestones=opt_cfg.lr_decay_schedule,
                    gamma=opt_cfg.lr_decay_rate,
                )
            )

        # Optional per-iteration inverse decay from the project's torchutils.
        if opt_cfg.decay:
            self.optim_iterdecayLR = torchutils.lr_scheduler_invLR(self.optim)
    
         
                
    def train_one_epoch(self):
    
        loss_all_tasks = [[] for i in range(len(self.task_classes_arr))]                 
        self.loss_all_tasks = loss_all_tasks
        self.accs_naive =[]
        clus_inf = []
        self.clus_inf = clus_inf
                                  
   
        self.len_task_src = np.zeros(self.task_num)
        self.len_task_tgt = np.zeros(self.task_num)
        
        for task_id, task_classes in enumerate(task_classes_arr):
            self.task_id = task_id
            print(f"Training starts for task {self.task_id}" )
            self.model = self.model.train()
            #print("Model is in the train mode")
            loss_all_tasks[self.task_id].append({})


            if self.cls:
                self.task_heads[self.task_id].train()                   
                                                        
            self.loss_fn.module.epoch = self.current_epoch 

            loss_list = self.config.loss_params.loss
            loss_weight = self.config.loss_params.weight
            loss_warmup = self.config.loss_params.start
            loss_giveup = self.config.loss_params.end

            num_loss = len(loss_list) 

            task_source_loader = self.get_attr(self.task_id, "train_loader_init_source_task",)
            task_target_loader = self.get_attr(self.task_id, "train_loader_init_target_task",) 
            
            len_data_src_task = self.get_attr(self.task_id, "train_len_init_source_task")
            len_data_tgt_task = self.get_attr(self.task_id, "train_len_init_target_task")
            
            self.len_task_src[self.task_id] =  len_data_src_task 
            self.len_task_tgt[self.task_id] =  len_data_tgt_task 
                
            if self.config.steps_epoch is None:
                num_batches = max(len(task_source_loader), len(task_target_loader)) + 1 
                self.logger.info(f"source task loader batches: {len(task_source_loader)}")
                self.logger.info(f"target task loader batches: {len(task_target_loader)}")
            else:
                num_batches = self.config.steps_epoch

            epoch_loss = AverageMeter() 
            epoch_loss_parts = [AverageMeter() for _ in range(num_loss)]
      
            # cluster
            if self.clus:
                if self.config.loss_params.clus.kmeans_freq:    
                    kmeans_batches = num_batches // self.config.loss_params.clus.kmeans_freq   
                else:
                    kmeans_batches = 1
            else:
                kmeans_batches = None

                
            # load weight
            self._load_fewshot_to_cls_weight() 
            if self.fewshot:
   
                fewshot_index_tsk = torch.tensor(self.get_attr(self.task_id, "fewshot_index_src_task")).cuda()
               
                                         
            tqdm_batch = tqdm( 
                total = num_batches, desc=f"[Task {self.task_id} Epoch {self.current_epoch}]", leave=False
            )
            tqdm_post = {} 

            for batch_i in range(num_batches):
                # Kmeans
                if is_div(kmeans_batches, batch_i):
                    self._update_cluster_labels()
                                  
                if not self.config.optim_params.cls_update: 
                    self._load_fewshot_to_cls_weight() 

                # iteration over all source images
                if not batch_i % len(task_source_loader):
                    source_iter_tsk = iter(task_source_loader) 

                    if "semi-condentmax" in self.config.loss_params.loss: 
                        momentum_prob_source_task = (
                            self.momentum_softmax_source_task[self.task_id].softmax_vector.cuda()
                        )
                        self.momentum_softmax_source_task[self.task_id].reset()

                # iteration over all target images
                if not batch_i % len(task_target_loader):
                    target_iter_tsk = iter(task_target_loader) 

                    if "tgt-condentmax" in self.config.loss_params.loss:  
                        momentum_prob_target_task = (
                            self.momentum_softmax_target_task[self.task_id].softmax_vector.cuda()
                        )
                        self.momentum_softmax_target_task[self.task_id].reset()

                
                train_loader_lbd_task = self.get_attr(self.task_id, "train_lbd_loader_source_task")
                # iteration over all labeled source images
                if self.cls and not batch_i % len(train_loader_lbd_task): 
                    source_lbd_iter_task = iter(train_loader_lbd_task)

                               
              
                # calculate loss
                for domain_name in ("source", "target"):
                    loss_tsk = torch.tensor(0).cuda()
                    loss_d_tsk = 0
                    loss_part_d_tsk = [0] * num_loss
                    batch_size = self.batch_size_dict[domain_name]

                    if self.cls and domain_name == "source":
                        indices_lbld_tsk, images_lbd_tsk, labels_lbld_tsk = next(source_lbd_iter_task)
                       
                        if self.task_id == 0:
                            indices_lbd_tsk = indices_lbld_tsk
                            labels_lbd_tsk = labels_lbld_tsk
                        else:
                            indices_lbd_tsk = indices_lbld_tsk - torch.min(indices_lbld_tsk)
                            labels_lbd_tsk = labels_lbld_tsk - torch.min(labels_lbld_tsk)
                        
                        indices_lbd_tsk = indices_lbd_tsk.cuda()
                       
                        images_lbd_tsk = images_lbd_tsk.cuda()
                        labels_lbd_tsk = labels_lbd_tsk.cuda()
                       
                        feat_lbd_tsk = self.model(images_lbd_tsk)
                        feat_lbd_tsk = F.normalize(feat_lbd_tsk, dim=1)
                        out_lbd_tsk = self.task_heads[self.task_id](feat_lbd_tsk)


                    # Matching & ssl
                    if (self.tgt and domain_name == "target") or self.ssl:
                        loader_iter_tsk = (
                            source_iter_tsk if domain_name == "source" else target_iter_tsk
                        )
                       
                        indices_unl_tsk, images_unl_tsk, _ = next(loader_iter_tsk)
                        images_unl_tsk = images_unl_tsk.cuda()
                       
                        if domain_name == "source":
                            if self.task_id == 0:
                                index_unl_tsk = indices_unl_tsk 
                                
                            else:
                                
                                index_unl_tsk= indices_unl_tsk - torch.tensor(np.sum(self.len_task_src[:self.task_id]))
                              
                        else :
                            if self.task_id == 0:
                                index_unl_tsk = indices_unl_tsk 
                            else:
                             
                                index_unl_tsk= indices_unl_tsk - torch.tensor(np.sum(self.len_task_tgt[:self.task_id]))
                                
                            
                        #indices_unl_tsk = indices_unl_tsk.cuda()
                        index_unl_tsk = index_unl_tsk.cuda()
                        feat_unl_tsk = self.model(images_unl_tsk)
                        feat_unl_tsk = F.normalize(feat_unl_tsk, dim=1)
                        out_unl_tsk = self.task_heads[self.task_id](feat_unl_tsk)

                    # Semi Supervised
                    if self.semi and domain_name == "source":
                       
                        semi_mask_tsk = ~torchutils.isin(index_unl_tsk, fewshot_index_tsk)
                    
                        indices_semi_tsk = index_unl_tsk[semi_mask_tsk] 
                        
                        out_semi_tsk = out_unl_tsk[semi_mask_tsk] 

                    # Self-supervised Learning
                    if self.ssl:
                        _, new_data_memory_tsk, loss_ssl_tsk, aux_list_tsk = self.loss_fn( 
                            self.task_id, index_unl_tsk.long(), feat_unl_tsk, domain_name, self.parallel_helper_idxs  
                        )
                        loss_ssl_tsk = [torch.mean(ls) for ls in loss_ssl_tsk] 

                    # pseudo
                    loss_pseudo_tsk = torch.tensor(0).cuda()
                    is_pseudo_tsk = {"source": self.is_pseudo_src, "target": self.is_pseudo_tgt}
                    thres_dict = {
                        "source": self.config.loss_params.thres_src, #Related to high confidence pseudo
                        "target": self.config.loss_params.thres_tgt, #Related to high confidence pseudo
                    }

                    if is_pseudo_tsk[domain_name]:
                        if domain_name == "source":
                            indices_pseudo_tsk = indices_semi_tsk
                            out_pseudo_tsk = out_semi_tsk
                            pseudo_domain_tsk = self.get_attr(self.task_id, "predict_ordered_labels_pseudo_source_tsk")
                            print("***pseudo_domain_tsk_source***", pseudo_domain_tsk )
                            print("*****Length of pseudo_domain_tsk_source**** ", len(pseudo_domain_tsk))
                        else:
                            indices_pseudo_tsk = index_unl_tsk
                            out_pseudo_tsk = out_unl_tsk 
                            pseudo_domain_tsk = self.get_attr(self.task_id, "predict_ordered_labels_pseudo_target_tsk")
                            print("***pseudo_domain_tsk_target***", pseudo_domain_tsk )
                            print("*****Length of pseudo_domain_tsk_target**** ", len(pseudo_domain_tsk))
                        thres = thres_dict[domain_name] 

                        # calculate loss
                        loss_pseudo_tsk, aux_tsk = torchutils.pseudo_label_loss( 
                            out_pseudo_tsk,
                            thres=thres,
                            mask=None,
                            #num_class=self.num_class,
                            num_class = self.num_class_task , 
                            aux=True
                        )
                        mask_pseudo_tsk = aux_tsk["mask"] 
                       

                        # fewshot memory bank
                        mb_tsk = self.get_attr(self.task_id, "memory_bank_wrapper_src")
                       
                        indices_lbd_tounl_tsk = fewshot_index_tsk[indices_lbd_tsk] 
                        
                        
                        mb_feat_lbd_tsk = mb_tsk.at_idxs(indices_lbd_tounl_tsk.long()) 
        
                        
                        fewshot_data_memory_tsk = update_data_memory(mb_feat_lbd_tsk, feat_lbd_tsk) 

                        # stat
                        pred_selected_tsk = out_pseudo_tsk.argmax(dim=1)[mask_pseudo_tsk] 
                          
                        indices_selected_tsk = indices_pseudo_tsk[mask_pseudo_tsk] 
                        
                        indices_unselected_tsk = indices_pseudo_tsk[~mask_pseudo_tsk] 
                        
                        pseudo_domain_tsk[indices_selected_tsk.long()] = pred_selected_tsk 
                    
                        pseudo_domain_tsk[indices_unselected_tsk.long()] = -1 
                       
                    # Compute Loss
                    for ind, ls in enumerate(loss_list):  
                        if (
                            self.current_epoch < loss_warmup[ind] 
                            or self.current_epoch >= loss_giveup[ind] 
                        ):
                            continue
                        
                        print("Compute loss is happening..")
                                  
                        loss_part_tsk = torch.tensor(0).cuda()
                        # *** handler for different loss ***
                        # classification on few-shot
                        if ls == "cls-so" and domain_name == "source":
                            loss_part_tsk = self.criterion(out_lbd_tsk, labels_lbd_tsk)
                                  
                        elif ls == "cls-info" and domain_name == "source":
                            loss_part_tsk = loss_info(feat_lbd_tsk, mb_feat_lbd_tsk, labels_lbd_tsk) 
                                  
                        # semi-supervision learning on unlabled source
                        elif ls == "semi-entmin" and domain_name == "source":
                            loss_part_tsk = torchutils.entropy(out_semi_tsk) 
                                  
                        elif ls == "semi-condentmax" and domain_name == "source":
                            bs_tsk = out_semi_tsk.size(0)
                            prob_semi_tsk = F.softmax(out_semi_tsk, dim=1)
                            prob_mean_semi_tsk = prob_semi_tsk.sum(dim=0) / bs_tsk

                            # update momentum
                            self.momentum_softmax_source_task[self.task_id].update(  
                                prob_mean_semi_tsk.cpu().detach(), bs_tsk
                            )
                            # get momentum probability
                            momentum_prob_source_task = ( 
                                self.momentum_softmax_source_task[self.task_id].softmax_vector.cuda() 
                            )
                                  
                            # compute loss
                            entropy_cond_tsk = -torch.sum( 
                                prob_mean_semi_tsk * torch.log(momentum_prob_source_task + 1e-5)
                            )
                            loss_part_tsk = -entropy_cond_tsk

                        # learning on unlabeled target domain
                        elif ls == "tgt-entmin" and domain_name == "target":
                            loss_part_tsk = torchutils.entropy(out_unl_tsk)
                                  
                        elif ls == "tgt-condentmax" and domain_name == "target":
                            bs_tsk = out_unl_tsk.size(0)
                            prob_unl_tsk = F.softmax(out_unl_tsk, dim=1)
                            prob_mean_unl_tsk = prob_unl_tsk.sum(dim=0) / bs_tsk

                            # update momentum
                            self.momentum_softmax_source_task[self.task_id].update(
                                prob_mean_unl_tsk.cpu().detach(), bs_tsk
                            )
                            # get momentum probability
                            momentum_prob_target_task = (
                                self.momentum_softmax_target_task[self.task_id].softmax_vector.cuda()
                            )
                            # compute loss
                            entropy_cond_tsk = -torch.sum(
                                prob_mean_unl_tsk * torch.log(momentum_prob_target_task + 1e-5)
                            )
                            loss_part_tsk = -entropy_cond_tsk

                        # self-supervised learning
                        elif ls.split("-")[0] in ["ID", "CD", "proto", "I2C", "C2C"]:
                            loss_part_tsk = loss_ssl_tsk[ind] #!

                        print(f"***********End of loss computation for {domain_name}*******")
                        
                        loss_part_tsk = loss_weight[ind] * loss_part_tsk
                        loss_tsk = loss_tsk + loss_part_tsk
                        loss_d_tsk = loss_d_tsk + loss_part_tsk.item()
                        loss_part_d_tsk[ind] = loss_part_tsk.item()

                                  
                    # Backpropagation
                    self.optim.zero_grad()
                    if len(loss_list) and loss_tsk != 0:
                        loss_tsk.backward()
                    self.optim.step()

                    # update memory_bank
                    if self.ssl:
                        self._update_memory_bank(domain_name, index_unl_tsk, new_data_memory_tsk) 
                        if domain_name == "source":
                            self._update_memory_bank( 
                                domain_name, indices_lbd_tounl_tsk, fewshot_data_memory_tsk
                            )

                    # update lr info
                    tqdm_post["lr"] = torchutils.get_lr(self.optim, g_id=-1) 

                    # update loss info
                    epoch_loss.update(loss_d_tsk, batch_size) 
                    tqdm_post["loss"] = epoch_loss.avg 
                    self.summary_writer.add_scalars( 
                        "train/loss", {"loss": epoch_loss.val}, self.current_iteration
                    )
                    self.train_loss.append(epoch_loss.val)  
                                  
                    loss_all_tasks[task_id][0][domain_name] = [loss_part_tsk, loss_tsk, loss_d_tsk, loss_part_d_tsk, self.train_loss]


                    # update loss part info
                    domain_iteration = self.get_attr(domain_name, "current_iteration") 
                    self.summary_writer.add_scalars( 
                        f"train/{self.domain_map[domain_name]}_loss",
                        {"loss": epoch_loss.val},
                        domain_iteration,
                    )
                    for i, ls in enumerate(loss_part_d_tsk): 
                        ls_name = loss_list[i]
                        epoch_loss_parts[i].update(ls, batch_size)
                        tqdm_post[ls_abbr[ls_name]] = epoch_loss_parts[i].avg
                        self.summary_writer.add_scalars(
                            f"train/{self.domain_map[domain_name]}_loss",
                            {ls_name: epoch_loss_parts[i].val},
                            domain_iteration,
                        )

                    # adjust lr
                    if self.config.optim_params.decay: 
                        self.optim_iterdecayLR.step()

                    self.current_iteration += 1 
                tqdm_batch.set_postfix(tqdm_post) 
                tqdm_batch.update() 
                self.current_iteration_source += 1
                self.current_iteration_target += 1
            tqdm_batch.close() 

            
            self.current_loss = epoch_loss.avg
                                  
            loss_all_tasks[task_id].append(self.current_loss) 
    
    
            
            val_acc_all_tasks = []
            self.val_acc_all_tasks = val_acc_all_tasks
            for idx in range(0, self.task_id + 1): 
                self.idx = idx
                self.validate()
                #self.save_checkpoint_tsk_val
                
            # For unseen tasks, we don't test
            if self.task_id < (self.task_num - 1):
                self.val_acc_all_tasks.extend([np.nan] * (7 - self.task_id))
                # Collect all test accuracies
            self.accs_naive.append(self.val_acc_all_tasks)

            

        
                                  
    @torch.no_grad()
    def _load_fewshot_to_cls_weight(self):
        """Load memory-bank class centroids into the current task head.

        For every class id in ``self.task_classes_arr[self.task_id]`` the
        memory-bank features of the samples assigned to that class are averaged,
        L2-normalised and written into the matching row of
        ``self.task_heads[self.task_id].fc.weight``. Classes with fewer samples
        than the per-domain threshold keep their current weights.

        The strategy is read from ``config.model_params.load_weight``:

        * ``None``: do nothing.
        * ``"fewshot"`` / ``"fewshot-tgt"``: source centroids come from the
          few-shot labelled samples (all ordered source labels when
          ``self.fewshot`` is False).
        * ``"src"`` / ``"src-tgt"``: source centroids come from confident source
          pseudo-labels (entries different from -1).
        * ``"tgt"`` / ``"src-tgt"`` / ``"fewshot-tgt"``: once ``current_epoch``
          reaches ``config.model_params.load_weight_epoch``, target centroids from
          confident target pseudo-labels are written after the source ones
          (overwriting rows that both domains provide).

        Raises:
            AssertionError: if ``load_weight`` is not one of the supported methods.
        """
        method = self.config.model_params.load_weight
        print("task_id in load fewshot:", self.task_id)
        if method is None:
            return
        assert method in ["fewshot", "src", "tgt", "src-tgt", "fewshot-tgt"]

        task_id = self.task_id
        # Minimum number of samples a class needs before its centroid is used.
        thres = {"src": 1, "tgt": self.config.model_params.load_weight_thres}
        is_tgt = (
            method in ["tgt", "fewshot-tgt", "src-tgt"]
            and self.current_epoch >= self.config.model_params.load_weight_epoch
        )
        pseudo_key = {
            "src": "predict_ordered_labels_pseudo_source_tsk",
            "tgt": "predict_ordered_labels_pseudo_target_tsk",
        }

        def labels_and_indices(domain):
            """Return (labels, memory-bank indices) of the samples used for ``domain``."""
            if domain == "src" and method in ["fewshot", "fewshot-tgt"]:
                if self.fewshot:
                    labels = torch.as_tensor(
                        self.get_attr(task_id, "fewshot_label_src_task")
                    )
                    indices = torch.as_tensor(
                        self.get_attr(task_id, "fewshot_index_src_task")
                    )
                else:
                    labels = torch.as_tensor(
                        self.get_attr(task_id, "train_ordered_labels_source_task")
                    )
                    # Every source sample is used, so its bank index is its position.
                    indices = torch.arange(
                        int(self.get_attr(task_id, "train_len_source_task"))
                    )
                return labels, indices
            # Pseudo-label path: -1 marks samples without a confident prediction.
            # The pseudo-labels themselves are the class assignment; reading the
            # ground-truth ordered labels here would leak labels of unlabelled data.
            pseudo = torch.as_tensor(self.get_attr(task_id, pseudo_key[domain]))
            mask = pseudo != -1
            return pseudo[mask], mask.nonzero().squeeze(1)

        # ``.data`` aliases the parameter storage, so row writes update the head.
        weight = self.task_heads[task_id].fc.weight.data
        self.set_attr(task_id, "weight_tsk", weight)
        num_rows = weight.size(0)

        # Source rows first, then target rows (target overrides where both exist).
        domains = []
        if method != "tgt":
            domains.append("src")
        if is_tgt:
            domains.append("tgt")

        for domain in domains:
            labels, indices = labels_and_indices(domain)
            bank = self.get_attr(task_id, f"memory_bank_wrapper_{domain}").as_tensor()
            for label in self.task_classes_arr[task_id]:
                row = int(label)
                if not 0 <= row < num_rows:
                    # NOTE(review): task_classes_arr may hold global class ids while
                    # the head is indexed by task-local ids; confirm the mapping.
                    self.logger.warning(
                        f"class {row} outside task {task_id} head ({num_rows} rows); "
                        "skipping centroid load"
                    )
                    continue
                class_mask = labels == row
                if class_mask.sum() < thres[domain]:
                    continue
                bank_vecs = bank[indices[class_mask]]
                weight[row] = F.normalize(bank_vecs.mean(dim=0), dim=0)

    # Validate

    @torch.no_grad()
    def validate(self):
        """Evaluate the head of task ``self.idx`` and update validation statistics.

        Stores the task's test loaders on ``self.test_unl_loader_source_task`` and
        ``self.test_unl_loader_target_task`` and records per-task source test-set
        sizes in ``self.length_task_src``. When ``self.cls`` is set, it scores the
        source loader (only for few-shot configs whose dataset is not
        ``visda17``/``digits``) and then the target loader, whose accuracy becomes
        ``self.current_val_metric``. Finally it updates the best-metric and
        no-improvement counters, appends the metric to ``self.val_acc`` and
        ``self.val_acc_all_tasks``, and calls ``self.clear_train_features()``.

        NOTE(review): when ``self.cls`` is False, ``self.current_val_metric`` is not
        recomputed here, so the previous value (if any) is reused.
        """
        self.model.eval()
        print(f"validation for task {self.idx} starts ")

        self.test_unl_loader_source_task = self.get_attr(self.idx, "test_unl_loader_source_task")
        self.test_unl_loader_target_task = self.get_attr(self.idx, "test_unl_loader_target_task")

        # Source test-set size of every task up to and including ``self.idx``, so
        # that ``np.sum(self.length_task_src[:k])`` is the cumulative source offset
        # of task ``k`` (filling only the current entry left earlier ones at 0).
        self.length_task_src = np.zeros(self.task_num)
        for prev_idx in range(self.idx + 1):
            self.length_task_src[prev_idx] = self.get_attr(
                prev_idx, "test_len_data_source_task"
            )

        # Domain Adaptation
        if self.cls:
            self.task_heads[self.idx].eval()
            score_source = (
                self.config.data_params.fewshot
                and self.config.data_params.name not in ["visda17", "digits"]
            )
            if score_source:
                print("***********Score calculation for source started")
                self.score(
                    self.test_unl_loader_source_task,
                    name=f"unlabeled {self.domain_map['source'] } task {self.idx}",
                )
            print("***********Score calculation for target started")
            self.current_val_metric = self.score(
                self.test_unl_loader_target_task,
                name=f"unlabeled  {self.domain_map['target']}  task {self.idx}",
            )

        # update information
        self.current_val_iteration += 1
        if self.current_val_metric >= self.best_val_metric:
            self.best_val_metric = self.current_val_metric
            self.best_val_epoch = self.current_epoch
            self.iter_with_no_improv = 0
        else:
            self.iter_with_no_improv += 1

        self.val_acc.append(self.current_val_metric)
        self.val_acc_all_tasks.append(self.current_val_metric)

        self.clear_train_features()

    @torch.no_grad()
    def score(self, loader, name="test"):
        """Evaluate the head of task ``self.idx`` on ``loader`` and log the results.

        Args:
            loader: iterable yielding ``(indices, images, labels)`` batches.
            name: label used for the TensorBoard scalars and the log line.

        Returns:
            float: top-1 accuracy in ``[0, 1]``; ``0.0`` if ``loader`` yields no samples.
        """
        correct = 0
        size = 0
        epoch_loss = AverageMeter()
        head = self.task_heads[self.idx]

        # Only accuracy and the sample-weighted mean loss are computed; the sample
        # indices yielded by the loader are not needed for either.
        for _indices, images, labels in loader:
            images = images.cuda()
            labels = labels.cuda()
            if self.idx != 0:
                # Shift labels to the head's task-local ids.
                # NOTE(review): this subtracts the smallest label of *this batch*, so a
                # batch lacking the task's lowest class is mis-shifted; confirm whether
                # the task's class offset should be used instead.
                labels = labels - torch.min(labels)

            feat = F.normalize(self.model(images), dim=1)
            output = head(feat)
            loss = self.criterion(output, labels)
            pred = torch.max(output, dim=1)[1]

            correct += pred.eq(labels).sum().item()
            size += pred.size(0)
            # Store a Python float so the meter does not keep GPU tensors alive.
            epoch_loss.update(loss.item(), pred.size(0))

        acc = correct / size if size else 0.0
        self.summary_writer.add_scalars(
            "test/acc", {f"{name}": acc}, self.current_epoch
        )
        self.summary_writer.add_scalars(
            "test/loss", {f"{name}": epoch_loss.avg}, self.current_epoch
        )
        self.logger.info(
            f"[Epoch {self.current_epoch} task {self.idx} {name}] loss={epoch_loss.avg:.5f}, acc={correct}/{size}({100. * acc:.3f}%)"
        )

        return acc

    # Load & Save checkpoint

    def load_checkpoint(
        self,
        filename,
        checkpoint_dir=None,
        load_memory_bank=False,
        load_model=True,
        load_optim=False,
        load_epoch=False,
        load_cls=True,
    ):
        """Restore agent state from ``checkpoint_dir/filename``.

        Args:
            filename: checkpoint file name.
            checkpoint_dir: directory holding the file; defaults to
                ``self.config.checkpoint_dir``.
            load_memory_bank: force loading the memory banks stored in the
                checkpoint.
            load_model: restore ``self.model`` weights.
            load_optim: restore the optimizer state, then reset every param
                group's learning rate to ``config.optim_params.learning_rate``.
            load_epoch: restore epoch / iteration counters.
            load_cls: restore task-head weights (only when ``self.cls`` is
                truthy and the checkpoint has ``"cls_state_dict"``). Both a
                single head's state dict and the per-task layout
                ``{"head_task_<i>": state_dict, ...}`` are accepted.

        Raises:
            OSError: if the checkpoint file cannot be read.
        """
        checkpoint_dir = checkpoint_dir or self.config.checkpoint_dir
        filename = os.path.join(checkpoint_dir, filename)
        try:
            self.logger.info(f"Loading checkpoint '{filename}'")
            checkpoint = torch.load(filename, map_location="cpu")

            if load_epoch:
                self.current_epoch = checkpoint["epoch"]
                for domain_name in ("source", "target"):
                    self.set_attr(
                        domain_name,
                        "current_iteration",
                        checkpoint[f"iteration_{domain_name}"],
                    )
                self.current_iteration = checkpoint["iteration"]
                self.current_val_iteration = checkpoint["val_iteration"]

            if load_model:
                self.model.load_state_dict(checkpoint["model_state_dict"])

            if load_cls and self.cls and "cls_state_dict" in checkpoint:
                cls_state_dict = checkpoint["cls_state_dict"]
                head_prefix = "head_task_"
                if cls_state_dict and all(
                    str(key).startswith(head_prefix) for key in cls_state_dict
                ):
                    # Per-task layout (as written by save_checkpoint_tsk_train /
                    # save_checkpoint_tsk_val): one entry per head index.
                    for key, head_state in cls_state_dict.items():
                        head_idx = int(key[len(head_prefix):])
                        self.task_heads[head_idx].load_state_dict(head_state)
                else:
                    # Plain state dict of a single head: load into the current
                    # task's head.
                    self.task_heads[self.task_id].load_state_dict(cls_state_dict)

            if load_optim:
                self.optim.load_state_dict(checkpoint["optim_state_dict"])

                # The configured learning rate wins over the restored one.
                lr_config = self.config.optim_params.learning_rate
                if self.optim.param_groups[0]["lr"] != lr_config:
                    for param_group in self.optim.param_groups:
                        param_group["lr"] = lr_config

            self._init_memory_bank()
            # _init_memory_bank only fills the banks from train features when
            # config.model_params.load_memory_bank is truthy; otherwise (or when
            # explicitly requested) take the banks from the checkpoint.
            if (
                load_memory_bank or self.config.model_params.load_memory_bank == False
            ):
                self._load_memory_bank(
                    {
                        "source": checkpoint["memory_bank_source"],
                        "target": checkpoint["memory_bank_target"],
                    }
                )

            self.logger.info(
                f"Checkpoint loaded successfully from '{filename}' at (epoch {checkpoint['epoch']}) at (iteration s:{checkpoint['iteration_source']} t:{checkpoint['iteration_target']}) with loss = {checkpoint['loss']}\nval acc = {checkpoint['val_acc']}\n"
            )

        except OSError:
            self.logger.info(f"Checkpoint doesnt exists: [{filename}]")
            raise
            
            
    def save_checkpoint_tsk_train(self, filename="checkpoint_task_train.pth.tar"):
        """Save model/optimizer state, every task's memory banks and task heads.

        The dict is written with ``torchutils.save_checkpoint`` into
        ``config.checkpoint_dir``; afterwards ``self.copy_checkpoint_tsk_train()``
        is called.

        Keys: ``memory_bank_{source,target}_tsk_<i>`` for every task in
        ``self.task_classes_arr``; ``cls_state_dict`` (when ``self.cls``) maps
        ``head_task_<i>`` to the state dict of ``self.task_heads[i]``.
        """
        num_tasks = len(self.task_classes_arr)
        out_dict = {
            "model_state_dict": self.model.state_dict(),
            "optim_state_dict": self.optim.state_dict(),
        }
        # _init_memory_bank creates one source and one target bank per task;
        # save all of them (source banks for the last tasks used to be missing).
        for task in range(num_tasks):
            out_dict[f"memory_bank_source_tsk_{task}"] = self.get_attr(
                task, "memory_bank_wrapper_src"
            )
        for task in range(num_tasks):
            out_dict[f"memory_bank_target_tsk_{task}"] = self.get_attr(
                task, "memory_bank_wrapper_tgt"
            )
        out_dict["iteration_source"] = self.get_attr("source", "current_iteration")
        out_dict["iteration_target"] = self.get_attr("target", "current_iteration")
        out_dict["loss"] = self.current_loss
        out_dict["train_loss"] = np.array(self.train_loss)

        if self.cls:
            out_dict["cls_state_dict"] = {
                f"head_task_{task}": self.task_heads[task].state_dict()
                for task in range(len(self.task_heads))
            }

        torchutils.save_checkpoint(
            out_dict, filename=filename, folder=self.config.checkpoint_dir
        )
        self.copy_checkpoint_tsk_train()

        
        
    def save_checkpoint_tsk_val(self, filename="checkpoint_task_val.pth.tar"):
        """Save model/optimizer state, validation stats and every task head.

        The dict is written with ``torchutils.save_checkpoint`` into
        ``config.checkpoint_dir`` with ``is_best`` set when the current
        validation metric equals the best one, or when ``config.validate_freq``
        is falsy; afterwards ``self.copy_checkpoint_tsk_val()`` is called.
        ``cls_state_dict`` (when ``self.cls``) maps ``head_task_<i>`` to the
        state dict of ``self.task_heads[i]``.
        """
        out_dict = {
            "model_state_dict": self.model.state_dict(),
            "optim_state_dict": self.optim.state_dict(),
            "val_iteration": self.current_val_iteration,
            "val_acc": np.array(self.val_acc),
            "val_metric": self.current_val_metric,
        }

        if self.cls:
            # Each head is stored under its own index (heads 6 and 7 used to be
            # saved from head 5 by mistake).
            out_dict["cls_state_dict"] = {
                f"head_task_{task}": self.task_heads[task].state_dict()
                for task in range(len(self.task_heads))
            }

        is_best = (
            self.current_val_metric == self.best_val_metric
        ) or not self.config.validate_freq
        torchutils.save_checkpoint(
            out_dict, is_best, filename=filename, folder=self.config.checkpoint_dir
        )
        self.copy_checkpoint_tsk_val()


    def save_checkpoint(self, filename="checkpoint.pth.tar"):
        """Write the full training checkpoint, then call ``self.copy_checkpoint()``.

        The state (config, model/optimizer state dicts, counters, validation
        and loss history, cluster info) is saved with
        ``torchutils.save_checkpoint`` into ``config.checkpoint_dir``. ``is_best``
        is true when the current validation metric equals the best one, or
        when ``config.validate_freq`` is falsy. Task heads and memory banks are
        not part of this dict.
        """
        state = dict(
            config=self.config,
            model_state_dict=self.model.state_dict(),
            optim_state_dict=self.optim.state_dict(),
            epoch=self.current_epoch,
            iteration=self.current_iteration,
            # TODO(author marked "fix it"): per-domain iteration counters.
            iteration_source=self.get_attr("source", "current_iteration"),
            iteration_target=self.get_attr("target", "current_iteration"),
            val_iteration=self.current_val_iteration,
            val_acc=np.array(self.val_acc),
            val_metric=self.current_val_metric,
            loss=self.current_loss,
            loss_all_tasks=self.loss_all_tasks,
            clus_inf=self.clus_inf,
            val_acc_all_tasks=self.val_acc_all_tasks,
            train_loss=np.array(self.train_loss),
        )

        # "best" is judged on the source-to-target validation metric.
        metric_is_best = self.current_val_metric == self.best_val_metric
        is_best = metric_is_best or not self.config.validate_freq

        torchutils.save_checkpoint(
            state, is_best, filename=filename, folder=self.config.checkpoint_dir
        )
        self.copy_checkpoint()

    # compute train features

    @torch.no_grad()
    def compute_train_features(self):
        """Embed the current task's source and target training sets.

        For task ``self.task_idm`` both domains are passed through ``self.model``
        (switched to eval mode and left there). Features are L2-normalised and
        re-ordered by ascending dataset index, then stored via ``set_attr``:

        * source: ``train_features_src_tsk``, ``train_labels_scr_task``,
          ``train_indices_src_tsk``
        * target: ``train_features_tgt_tsk``, ``train_labels_tgt_task``,
          ``train_indices_tgt_tsk``

        Labels are float tensors on the loader's device; indices are float tensors
        on ``self.device``. Also sets ``self.is_features_computed`` and resets
        ``self.length_task`` (only the current task's entry is filled).
        """
        # Task ids 0-7 are always recomputed; the cached flag only
        # short-circuits other ids.
        if self.is_features_computed and self.task_idm not in [0, 1, 2, 3, 4, 5, 6, 7]:
            return
        self.is_features_computed = True
        self.model.eval()

        # ---- source domain ----
        train_loader_source_tsk = self.get_attr(
            self.task_idm, "train_loader_init_source_task"
        )
        self.length_task = np.zeros(self.task_num)
        self.length_task[self.task_idm] = self.get_attr(
            self.task_idm, "train_len_init_source_task"
        )
        features, labels, indices = self._collect_sorted_train_features(
            train_loader_source_tsk, "source"
        )
        self.set_attr(self.task_idm, "train_features_src_tsk", features)
        # The "scr" spelling of this key is kept byte-identical in case other
        # code reads it.
        self.set_attr(self.task_idm, "train_labels_scr_task", labels)
        self.set_attr(self.task_idm, "train_indices_src_tsk", indices)

        # ---- target domain ----
        train_loader_target_tsk = self.get_attr(
            self.task_idm, "train_loader_init_target_task"
        )
        features, labels, indices = self._collect_sorted_train_features(
            train_loader_target_tsk, "target"
        )
        self.set_attr(self.task_idm, "train_features_tgt_tsk", features)
        self.set_attr(self.task_idm, "train_labels_tgt_task", labels)
        self.set_attr(self.task_idm, "train_indices_tgt_tsk", indices)

    @torch.no_grad()
    def _collect_sorted_train_features(self, loader, domain):
        """Return ``(features, labels, indices)`` for ``loader`` ordered by dataset index."""
        features, labels, indices = [], [], []
        tqdm_batch = tqdm(
            total=len(loader),
            desc=f"[Compute train features of task {self.task_idm} for {domain}]",
        )
        for batch_indices, images, batch_labels in loader:
            feat = self.model(images.to(self.device))
            features.append(F.normalize(feat, dim=1))
            labels.append(batch_labels)
            indices.append(batch_indices.to(torch.int64))
            tqdm_batch.update()
        tqdm_batch.close()

        features = torch.cat(features)
        labels = torch.cat(labels)
        indices = torch.cat(indices)

        # A single tensor sort replaces sorting Python tuples of tensors, which
        # was slow and failed on duplicate indices (it then compared features).
        sorted_indices, order = torch.sort(indices)
        features = features[order.to(features.device)]
        labels = labels[order.to(labels.device)]

        # Keep the default-float dtypes previously produced by torch.Tensor(list).
        default_dtype = torch.get_default_dtype()
        return (
            features,
            labels.to(default_dtype),
            sorted_indices.to(default_dtype).to(self.device),
        )
                


    def clear_train_features(self):
        """Reset the ``is_features_computed`` flag read by compute_train_features()."""
        setattr(self, "is_features_computed", False)

   
               
    # Define one memory bank per task and per domain (source and target), i.e. 2 x number of tasks.
    @torch.no_grad()
    def _init_memory_bank(self):
        """Create one source and one target MemoryBank per task and publish them.

        For every task in ``self.task_classes_arr`` a bank with
        ``train_len_init_{source,target}_task`` rows of
        ``config.model_params.out_dim`` features is built. When
        ``config.model_params.load_memory_bank`` is truthy, the banks are filled
        with freshly computed train features. Indices are shifted by the total
        size of the preceding tasks.

        Banks are stored via ``set_attr`` as ``memory_bank_wrapper_{src,tgt}``.
        Their lengths are passed to ``self.loss_fn.module.set_attr`` and their
        tensors to ``self.loss_fn.module.set_broadcast``. Also fills
        ``self.length_task_source`` / ``self.length_task_target`` and sets
        ``self.task_idm`` to each task in turn.
        """
        out_dim = self.config.model_params.out_dim
        fill_from_features = self.config.model_params.load_memory_bank
        self.length_task_source = np.zeros(self.task_num)
        self.length_task_target = np.zeros(self.task_num)
        # (domain name, short attr tag, long attr tag, per-task length array)
        domains = (
            ("source", "src", "source", self.length_task_source),
            ("target", "tgt", "target", self.length_task_target),
        )

        for task_idm, _ in enumerate(self.task_classes_arr):
            self.task_idm = task_idm
            print("Task_idm for memory", self.task_idm)

            if fill_from_features:
                # compute_train_features() embeds BOTH domains of self.task_idm,
                # so one call per task suffices (it used to run once per
                # domain, doubling the full feature pass).
                self.compute_train_features()

            for domain_name, tag, long_tag, task_lengths in domains:
                data_len = self.get_attr(task_idm, f"train_len_init_{long_tag}_task")
                if domain_name == "source":
                    print(f"data_len_source_task_{task_idm}", data_len)
                task_lengths[task_idm] = data_len

                # data_len: number of bank rows; out_dim: feature dimension.
                memory_bank = MemoryBank(data_len, out_dim)

                if fill_from_features:
                    idx = self.get_attr(task_idm, f"train_indices_{tag}_tsk")
                    feat = self.get_attr(task_idm, f"train_features_{tag}_tsk")
                    if task_idm > 0:
                        # Shift by the size of all preceding tasks so indices
                        # address rows of this task's bank (assumes loader
                        # indices are global across tasks — confirm).
                        idx = idx - torch.tensor(np.sum(task_lengths[:task_idm]))
                    memory_bank.update(idx.to(torch.int64), feat)

                    if self.config.data_params.name in ["visda17", "domainnet"]:
                        # NOTE(review): features are stored via set_attr under
                        # train_{indices,features}_{src,tgt}_tsk; these names
                        # look stale and may raise AttributeError — confirm.
                        delattr(self, f"train_indices_{domain_name}")
                        delattr(self, f"train_features_{domain_name}")

                self.set_attr(task_idm, f"memory_bank_wrapper_{tag}", memory_bank)
                self.loss_fn.module.set_attr(
                    task_idm, f"data_len_{long_tag}_task", data_len
                )
                self.loss_fn.module.set_broadcast(
                    task_idm, f"memory_bank_{long_tag}_task", memory_bank.as_tensor()
                )




  
    @torch.no_grad()
    def _update_memory_bank(self, domain_name, indices, new_data_memory_tsk): #Fix it
        
        if domain_name == "source":
            memory_bank_wrapper_src = self.get_attr(self.task_id, "memory_bank_wrapper_src")
            memory_bank_wrapper_src.update(indices, new_data_memory_tsk)
            updated_bank_src_tsk = memory_bank_wrapper_src.as_tensor()
            self.loss_fn.module.set_broadcast(self.task_id, "memory_bank_source_task", updated_bank_src_tsk)  #Not sure

        
        if domain_name == "taregt":
            memory_bank_wrapper_tgt = self.get_attr(self.task_id, "memory_bank_wrapper_tgt")
            memory_bank_wrapper_tgt.update(indices, new_data_memory_tsk)
            updated_bank_tgt_tsk = memory_bank_wrapper_tgt.as_tensor()
            self.loss_fn.module.set_broadcast(self.task_id, "memory_bank_target_task", updated_bank_tgt_tsk)  #Not sure

            
        

    def _load_memory_bank(self, memory_bank_dict):
        """load memory bank from checkpoint

        Args:
            memory_bank_dict (dict): memory_bank dict of source and target domain
        """
        for domain_name in ("source", "target"):
                                  
            if domain_name == "source":       
                memory_bank_src = memory_bank_dict[domain_name]._bank.cuda()
                self.get_attr(self.task_id, "memory_bank_wrapper_src")._bank = memory_bank
                self.loss_fn.module.set_broadcast(task_id, "memory_bank", memory_bank)
                                  
            if domain_name == "target":    
                memory_bank_tgt = memory_bank_dict[domain_name]._bank.cuda() 
                self.get_attr(self.task_id, "memory_bank_wrapper_tgt")._bank = memory_bank_tgt
                self.loss_fn.module.set_broadcast(self.task_id, "memory_bank", memory_bank)

    # Cluster

    @torch.no_grad()
    def _update_cluster_labels(self):
        """Re-cluster the current task's source and target memory banks.

        For each clustering type in ``config.loss_params.clus.type`` (only
        ``"each"`` is supported), ``torch_kmeans`` is run separately on the
        source and target memory bank of task ``self.task_id``. It uses
        ``config.k_list_task`` and seed ``current_epoch + current_iteration``.

        Results are stored on ``self.cluster_each_{centroids,labels,phi}_domain``
        and appended to ``self.clus_inf``. They are also broadcast to the loss
        module keyed by domain name.

        NOTE(review): the author flagged this method as needing fixes.

        Raises:
            NotImplementedError: for any clustering type other than ``"each"``.
        """
        k_list_task = self.config.k_list_task
        bank_keys = (
            ("source", "memory_bank_wrapper_src"),
            ("target", "memory_bank_wrapper_tgt"),
        )

        for clus_type in self.config.loss_params.clus.type:
            if clus_type != "each":
                print(clus_type)
                raise NotImplementedError

            labels_by_domain, centroids_by_domain, phi_by_domain = {}, {}, {}
            for domain_name, wrapper_key in bank_keys:
                bank_tensor = self.get_attr(self.task_id, wrapper_key).as_tensor()
                (
                    labels_by_domain[domain_name],
                    centroids_by_domain[domain_name],
                    phi_by_domain[domain_name],
                ) = torch_kmeans(
                    k_list_task,
                    bank_tensor,
                    seed=self.current_epoch + self.current_iteration,
                )

            self.cluster_each_centroids_domain = centroids_by_domain
            self.cluster_each_labels_domain = labels_by_domain
            self.cluster_each_phi_domain = phi_by_domain
            # Keep a per-call history of clustering results.
            self.clus_inf.append([centroids_by_domain, labels_by_domain, phi_by_domain])

            # Publish to the loss module (keyed by domain name, not task id).
            for domain_name, _ in bank_keys:
                self.loss_fn.module.set_broadcast(
                    domain_name,
                    f"cluster_labels_{clus_type}",
                    labels_by_domain[domain_name],
                )
                self.loss_fn.module.set_broadcast(
                    domain_name,
                    f"cluster_centroids_{clus_type}",
                    centroids_by_domain[domain_name],
                )
                if phi_by_domain:
                    self.loss_fn.module.set_broadcast(
                        domain_name,
                        f"cluster_phi_{clus_type}",
                        phi_by_domain[domain_name],
                    )


In [None]:

#pre_checkpoint_dir = check_pretrain_dir(config_json)
# NOTE(review): the assignment above is commented out, so `pre_checkpoint_dir`
# (like `config`) must already be defined by an earlier notebook cell;
# otherwise the `if` below raises NameError.
AgentClass = Agent #globals()
agent = AgentClass(config)


# Optionally warm-start from "model_best.pth.tar" in the pretrained directory.
if pre_checkpoint_dir is not None:
    agent.load_checkpoint("model_best.pth.tar", pre_checkpoint_dir)
    
# Run training; Ctrl-C stops it without a traceback.
# agent.finalise() is currently disabled (commented out).
try:
    agent.run()
    #agent.finalise()
except KeyboardInterrupt:
    pass 